create_token XLA error on GPU (and rarely CPU)
I ran the example from the README on version 0.2.7 (installation fails for other versions because of this issue). Below are my code and the error. Am I doing something wrong?
code:
import jax
import mpi4jax
from mpi4py import MPI  # required for MPI.COMM_WORLD and MPI.SUM

comm = MPI.COMM_WORLD
a = jax.numpy.ones((5, 4))

# fails already here, before the jitted call below (see traceback)
b = mpi4jax.Allreduce(a, op=MPI.SUM, comm=comm)

# jitted call
b_jit = jax.jit(lambda x: mpi4jax.Allreduce(x, op=MPI.SUM, comm=comm))(a)
error:
File "test.py", line 7, in <module>
b = mpi4jax.Allreduce(a, op=MPI.SUM, comm=comm)
File "/opt/conda/lib/python3.7/site-packages/mpi4jax/validation.py", line 90, in wrapped
return function(*args, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/mpi4jax/collective_ops/allreduce.py", line 66, in Allreduce
return mpi_allreduce_p.bind(x, token, op=op, comm=comm, transpose=_transpose)
File "/opt/conda/lib/python3.7/site-packages/jax/core.py", line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/opt/conda/lib/python3.7/site-packages/jax/core.py", line 628, in process_primitive
return primitive.impl(*tracers, **params)
File "/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py", line 238, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File "/opt/conda/lib/python3.7/site-packages/jax/_src/util.py", line 198, in wrapper
return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/jax/_src/util.py", line 191, in cached
return f(*args, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py", line 288, in xla_primitive_callable
compiled = backend_compile(backend, built_c, options)
File "/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py", line 352, in backend_compile
return backend.compile(built_c, compile_options=options)
RuntimeError: Internal: Unexpected shape kind for %after-all.2 = token[] after-all(), metadata={op_type="allreduce_mpi" op_name="allreduce_mpi[ comm=<mpi4jax.utils.HashableMPIType object at 0x7fd1025dfa20>\n op=<mpi4jax.utils.HashableMPIType object at 0x7fd1025df9b0>\n transpose=False ]"} and shape index {}
Most likely this is JAX starting to validate token data, which we do not properly propagate along our operations.
I opened Jax#5707 for that.
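For reference, here is a minimal sketch of the token-threading pattern this refers to. It is written against the later lowercase mpi4jax.allreduce API, in which each collective accepts and returns a token; that API (and its token keyword) is an assumption here and differs from the 0.2.7 API used above:

import jax
import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax  # assumed: a later release with the lowercase, token-based API

comm = MPI.COMM_WORLD

def two_reductions(x):
    # Root the token chain explicitly so XLA sees a well-formed dependency
    # chain between the collectives rather than a bare after-all().
    token = jax.lax.create_token(x)
    y, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm, token=token)
    z, token = mpi4jax.allreduce(y, op=MPI.SUM, comm=comm, token=token)
    return z

out = jax.jit(two_reductions)(jnp.ones((5, 4)))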
@fgvbrt I'll try to quickly fix mpi4jax once the JAX people tell us how to treat this case. In the meantime, the only workaround is to use an older version of jax, jaxlib, and mpi4jax.
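As a small aside on that workaround: before pinning anything, it can help to check which combination is currently installed. A minimal sketch, assuming each package exposes a __version__ attribute (true for jax and jaxlib; assumed for mpi4jax):

import jax
import jaxlib
import mpi4jax

# Print the installed versions so you know which jax / jaxlib / mpi4jax
# combination triggered the error before downgrading.
print("jax", jax.__version__)
print("jaxlib", jaxlib.__version__)
print("mpi4jax", mpi4jax.__version__)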
Fixed in the new jaxlib release.