out_axes specification error when running the script with sudo; works fine without sudo
I get an out_axes specification error when I run the script with sudo; without sudo it works fine. I want to run the script through a systemd service, and there it fails with this error.
Even when I run it directly as
    sudo python3 device_serve.py
it gives the same error, while it works fine with
    python3 device_serve.py
Here is the stack trace:
key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
read from disk/gcs in 6.40554s
Traceback (most recent call last):
  File "simple.py", line 149, in <module>
    output = network.generate(batched_tokens, length, gen_length, {"top_p": np.ones(total_batch) * 0.9,
  File "/home/ahmedjawed/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 309, in generate
    return self.generate_xmap(self.state,
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 615, in fun_mapped
    out_flat = xmap_p.bind(
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 818, in bind
    return core.call_bind(self, fun, *args, **params) # type: ignore
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 821, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 646, in xmap_impl
    xmap_callable = make_xmap_callable(
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 673, in make_xmap_callable
    _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 1454, in _check_out_avals_vs_out_axes
    raise TypeError(f"One of xmap results has an out_axes specification of "
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard
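One plausible cause, offered as an assumption rather than a confirmed diagnosis: sudo python3 runs as root, so it may resolve a different set of site-packages than your normal user does (for example the system-wide copy under /usr/local/lib/python3.8/dist-packages that appears in the traceback above, instead of a per-user install under ~/.local). The sketch below (the file name check_env.py is hypothetical) reports which interpreter and which JAX install are visible without importing jax; run it once as python3 check_env.py and once as sudo python3 check_env.py and compare the output.

```python
# check_env.py (hypothetical file name): report which interpreter and which
# JAX installation the current environment resolves to, without importing jax.
import importlib.metadata
import importlib.util
import os
import site
import sys


def describe_environment():
    """Return interpreter, user-site and JAX install details as a dict."""
    spec = importlib.util.find_spec("jax")  # locate jax without importing it
    try:
        jax_version = importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        jax_version = None
    return {
        # True under sudo, since the effective user id becomes 0 (root).
        "running_as_root": hasattr(os, "geteuid") and os.geteuid() == 0,
        "python": sys.executable,
        "user_site_enabled": site.ENABLE_USER_SITE,
        # Derived from HOME, so under sudo this may point into /root.
        "user_site": site.getusersitepackages(),
        "jax_version": jax_version,
        "jax_location": spec.origin if spec is not None else None,
    }


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If the two runs report a different jax_version or jax_location, the sudo run is picking up a different JAX build. In that case one option for the systemd setup would be to run the service as the same non-root user (systemd's User= setting in the [Service] section) so it sees the same packages.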
- Created 2 years ago
- Comments: 6 (4 by maintainers)
Top GitHub Comments
@srulikbd please install JAX using this command:
    pip install "jax[tpu]==0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Yeah, I tried installing the packages with sudo, but it gives the same error. Also, jax 0.2.16 works fine for me without sudo. jax 0.2.12 gives this error on the TPU VM:
    F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
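Since the thread reports jax 0.2.16 working and 0.2.12 crashing on the TPU VM, a fail-fast check at the top of the serving script could make a version mismatch show up clearly in the systemd journal instead of surfacing later as an xmap error. This is a minimal sketch: the pinned version and the helper name require_jax_version are assumptions for illustration, not part of mesh-transformer-jax. It uses importlib.metadata, which needs Python 3.8 or newer (the traceback shows python3.8).

```python
import importlib.metadata
import sys

# Assumed pin: the version reported to work in this thread.
REQUIRED_JAX_VERSION = "0.2.16"


def require_jax_version(required=REQUIRED_JAX_VERSION,
                        get_version=importlib.metadata.version):
    """Exit with a clear message unless the installed jax equals `required`.

    `get_version` is injectable so the check can be exercised without jax.
    Returns the installed version string when it matches.
    """
    try:
        installed = get_version("jax")
    except importlib.metadata.PackageNotFoundError:
        sys.exit(f"jax is not installed for {sys.executable}")
    if installed != required:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(
            f"jax {installed} found for {sys.executable}, expected {required}; "
            f'reinstall with: pip install "jax[tpu]=={required}" -f '
            "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
        )
    return installed
```

In device_serve.py this would be called once, before jax itself is imported, so the service stops immediately with a readable message if root's environment has a different JAX than expected.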