Cannot print a doubly-reshaped ShardedDeviceArray
The issue can be reproduced with the following code:
import jax
x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
print(x.reshape((-1, 2)).reshape(-1))
The error message is
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-129-21f9347dfcf4> in <module>
1 import jax
2 x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
----> 3 print(x.reshape((-1, 2)).reshape(-1))
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _forward_method(attrname, self, fun, *args)
969
970 def _forward_method(attrname, self, fun, *args):
--> 971 return fun(getattr(self, attrname), *args)
972 _forward_to_value = partial(_forward_method, "_value")
973
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/pxla.py in _value(self)
587 def _value(self):
588 if self._npy_value is None:
--> 589 self.copy_to_host_async()
590 npy_value = onp.empty(self.aval.shape, self.aval.dtype)
591 for i in self.one_replica_buffer_indices:
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/pxla.py in copy_to_host_async(self)
566 def copy_to_host_async(self):
567 for buffer_index in self.one_replica_buffer_indices:
--> 568 self.device_buffers[buffer_index].copy_to_host_async()
569
570 def delete(self):
IndexError: list index out of range
In fact, we can’t do anything with a doubly-reshaped ShardedDeviceArray. Something like x.reshape((-1, 2)).reshape(-1) * 1 fails as well, though with a better error message:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-136-b3cc850a373d> in <module>
1 import jax
2 x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
----> 3 y = x.reshape((-1, 2)).reshape(-1) * 1
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in deferring_binary_op(self, other)
4257 if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)):
4258 return NotImplemented
-> 4259 return binary_op(self, other)
4260 return deferring_binary_op
4261
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in fn(x1, x2)
338 def fn(x1, x2):
339 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
--> 340 return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
341 return _wraps(numpy_fn)(fn)
342
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/lax/lax.py in mul(x, y)
306 def mul(x: Array, y: Array) -> Array:
307 r"""Elementwise multiplication: :math:`x \times y`."""
--> 308 return mul_p.bind(x, y)
309
310 def div(x: Array, y: Array) -> Array:
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **kwargs)
271 top_trace = find_top_trace(args)
272 if top_trace is None:
--> 273 return self.impl(*args, **kwargs)
274
275 tracers = map(top_trace.full_raise, args)
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
227 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
228 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
--> 229 return compiled_fun(*args)
230
231 @cache()
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, result_handler, *args)
326 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
327 device, = compiled.local_devices()
--> 328 input_bufs = [device_put(x, device) for x in args if x is not token]
329 out_bufs = compiled.execute(input_bufs)
330 if FLAGS.jax_debug_nans:
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in <listcomp>(.0)
326 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
327 device, = compiled.local_devices()
--> 328 input_bufs = [device_put(x, device) for x in args if x is not token]
329 out_bufs = compiled.execute(input_bufs)
330 if FLAGS.jax_debug_nans:
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in device_put(x, device)
115 x = canonicalize_dtype(x)
116 try:
--> 117 return device_put_handlers[type(x)](x, device)
118 except KeyError as err:
119 raise TypeError(f"No device_put handler for type: {type(x)}") from err
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _device_put_array(x, device)
121 def _device_put_array(x, device: Optional[Device]):
122 backend = xb.get_device_backend(device)
--> 123 return backend.buffer_from_pyval(x, device)
124
125 def _device_put_scalar(x, device):
RuntimeError: Invalid argument: from_python argument must be an array.
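The thread does not propose a workaround. One possibility, offered only as a sketch and not verified against the affected JAX version, is to copy the pmap output to the host with np.asarray before reshaping, so that both reshapes run on a plain NumPy array rather than on a ShardedDeviceArray. The names n, host and flat are illustrative, and the leading axis is sized with jax.local_device_count() (instead of the hard-coded 2 in the repro) so the snippet also runs on a single-device machine.

```python
import jax
import jax.numpy as jnp
import numpy as np

# pmap maps over the leading axis, which must not exceed the number of
# local devices; the original repro used 2 and therefore needs two devices.
n = jax.local_device_count()
x = jax.pmap(lambda x: x)(jnp.ones((n, 2, 2)))

# Assumed workaround (not confirmed in the thread): materialize the result
# on the host first, then do both reshapes in NumPy.
host = np.asarray(x)
flat = host.reshape((-1, 2)).reshape(-1)

print(flat.shape)
```

This trades the device-resident array for a host copy, so it is only suitable when the data is small enough to transfer, for example when inspecting or printing values while debugging.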
Issue Analytics
- State:
- Created 3 years ago
- Comments:7 (7 by maintainers)
@fehiepsi, you’re too kind. Thanks for the heads up and the offer.
For the last ~month, the only updates to #3370 have been rebasing it on master as I struggle to land it in Google’s monorepo. The trouble is that with a monorepo I need to make sure all the existing code works with any updates to JAX, and #3370 breaks a lot of code by making some common performance bugs into hard errors. New code is being added fast enough that I’m not sure I’ve even made any progress.
I’ve been really stuck on this, so I hesitate to give any new projected date on when #3370 will land. I have one more idea to try this week. If it doesn’t look like it’s going to work by Wednesday evening, I’ll try to fix this issue directly on master rather than blocking it on #3370. Does that work for you?
I am, of course, behind schedule 😃 Maybe it’s worth just fixing this issue directly, and not blocking on #3370…