Cannot print a doubly-reshaped ShardedDeviceArray


The issue can be reproduced with the following code:

import jax
x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
print(x.reshape((-1, 2)).reshape(-1))
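The repro needs two devices, because pmap maps the leading axis (of size 2) across devices. Below is one way to run it on a single machine without two accelerators. It is a sketch, not from the thread. It assumes XLA's --xla_force_host_platform_device_count flag to simulate two CPU devices, and that flag only takes effect if it is set before JAX initializes its backends. It also pins pmap to those CPU devices through pmap's devices argument, and sizes the leading axis to however many CPU devices are actually available.

```python
import os

# Assumption: ask XLA to expose 2 host (CPU) devices so that pmap over a
# leading axis of size 2 works without two accelerators. This must happen
# before JAX initializes its backends (before the first JAX computation).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
).strip()

import jax
import jax.numpy as jnp

cpus = jax.devices("cpu")[:2]
n_dev = len(cpus)  # 2 when the flag above took effect

# Same repro as in the report, pinned to the CPU devices; the leading axis
# matches the number of devices so pmap places one shard on each device.
x = jax.pmap(lambda v: v, devices=cpus)(jnp.ones((n_dev, 2, 2)))
y = x.reshape((-1, 2)).reshape(-1)  # the doubly-reshaped array

# With the JAX version in the report, printing y raised IndexError.
print(y)
```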

The error message is:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-129-21f9347dfcf4> in <module>
      1 import jax
      2 x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
----> 3 print(x.reshape((-1, 2)).reshape(-1))

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _forward_method(attrname, self, fun, *args)
    969 
    970 def _forward_method(attrname, self, fun, *args):
--> 971   return fun(getattr(self, attrname), *args)
    972 _forward_to_value = partial(_forward_method, "_value")
    973 

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/pxla.py in _value(self)
    587   def _value(self):
    588     if self._npy_value is None:
--> 589       self.copy_to_host_async()
    590       npy_value = onp.empty(self.aval.shape, self.aval.dtype)
    591       for i in self.one_replica_buffer_indices:

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/pxla.py in copy_to_host_async(self)
    566   def copy_to_host_async(self):
    567     for buffer_index in self.one_replica_buffer_indices:
--> 568       self.device_buffers[buffer_index].copy_to_host_async()
    569 
    570   def delete(self):

IndexError: list index out of range

Actually, we can’t do anything with a doubly-reshaped ShardedDeviceArray. Something like x.reshape((-1, 2)).reshape(-1) * 1 produces a more informative error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-136-b3cc850a373d> in <module>
      1 import jax
      2 x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
----> 3 y = x.reshape((-1, 2)).reshape(-1) * 1

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in deferring_binary_op(self, other)
   4257     if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)):
   4258       return NotImplemented
-> 4259     return binary_op(self, other)
   4260   return deferring_binary_op
   4261 

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in fn(x1, x2)
    338   def fn(x1, x2):
    339     x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
--> 340     return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
    341   return _wraps(numpy_fn)(fn)
    342 

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/lax/lax.py in mul(x, y)
    306 def mul(x: Array, y: Array) -> Array:
    307   r"""Elementwise multiplication: :math:`x \times y`."""
--> 308   return mul_p.bind(x, y)
    309 
    310 def div(x: Array, y: Array) -> Array:

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **kwargs)
    271     top_trace = find_top_trace(args)
    272     if top_trace is None:
--> 273       return self.impl(*args, **kwargs)
    274 
    275     tracers = map(top_trace.full_raise, args)

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    227   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
    228   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
--> 229   return compiled_fun(*args)
    230 
    231 @cache()

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, result_handler, *args)
    326 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
    327   device, = compiled.local_devices()
--> 328   input_bufs = [device_put(x, device) for x in args if x is not token]
    329   out_bufs = compiled.execute(input_bufs)
    330   if FLAGS.jax_debug_nans:

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in <listcomp>(.0)
    326 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
    327   device, = compiled.local_devices()
--> 328   input_bufs = [device_put(x, device) for x in args if x is not token]
    329   out_bufs = compiled.execute(input_bufs)
    330   if FLAGS.jax_debug_nans:

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in device_put(x, device)
    115   x = canonicalize_dtype(x)
    116   try:
--> 117     return device_put_handlers[type(x)](x, device)
    118   except KeyError as err:
    119     raise TypeError(f"No device_put handler for type: {type(x)}") from err

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _device_put_array(x, device)
    121 def _device_put_array(x, device: Optional[Device]):
    122   backend = xb.get_device_backend(device)
--> 123   return backend.buffer_from_pyval(x, device)
    124 
    125 def _device_put_scalar(x, device):

RuntimeError: Invalid argument: from_python argument must be an array.
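Until a fix lands, one possible workaround is to copy the original pmap output to the host before doing any reshaping, so that no reshape is ever applied to the ShardedDeviceArray itself. This is a sketch, not a method from the thread, and it has not been tested against the JAX version in the report. It relies on jax.device_get and plain NumPy reshapes. As in the repro above, the two-CPU-device setup via --xla_force_host_platform_device_count is an assumption that lets the example run on a single machine.

```python
import os

# Assumption: simulate 2 CPU devices; must be set before JAX initializes.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
).strip()

import jax
import jax.numpy as jnp
import numpy as np

cpus = jax.devices("cpu")[:2]
n_dev = len(cpus)  # 2 when the flag above took effect
x = jax.pmap(lambda v: v, devices=cpus)(jnp.ones((n_dev, 2, 2)))

# Workaround sketch: transfer the un-reshaped pmap output to the host as a
# NumPy array, then do the double reshape and the arithmetic on the host.
host = np.asarray(jax.device_get(x))
z = host.reshape((-1, 2)).reshape(-1) * 1

# If a device array is needed afterwards, put the flat result back on the
# default device as an ordinary (unsharded) array.
z_dev = jnp.asarray(z)
print(z_dev)
```

Because the reshapes happen on a NumPy array, this avoids the sharded-array code paths that appear in both tracebacks. The cost is a round trip through host memory.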

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

2 reactions
mattjj commented, Jul 21, 2020

@fehiepsi, you’re too kind. Thanks for the heads up and the offer.

For the last ~month, the only updates to #3370 have been rebasing it on master as I struggle to land it in Google’s monorepo. The trouble is that with a monorepo I need to make sure all the existing code works with any updates to JAX, and #3370 breaks a lot of code by making some common performance bugs into hard errors. New code is being added fast enough that I’m not sure I’ve even made any progress.

I’ve been really stuck on this, so I hesitate to give any new projected date on when #3370 will land. I have one more idea to try this week. If it doesn’t look like it’s going to work by Wednesday evening, I’ll try to fix this issue directly on master rather than blocking it on #3370. Does that work for you?

1 reaction
mattjj commented, Jul 23, 2020

I am, of course, behind schedule 😃 Maybe it’s worth just fixing this issue directly and not blocking on #3370.
