It would be nice if jax.device_get returned numpy scalars

The following behavior is not quite idiomatic NumPy:

In [5]: jax.device_get(jnp.abs(7))
Out[5]: array(7, dtype=int32)

Ideally, this would instead print 7. That is what would happen if jax.device_get converted rank-0 JAX arrays to NumPy scalars, which is the default behavior in both NumPy and TensorFlow.

Is it possible to change jax.device_get so that it returns NumPy scalars?
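Without changing jax.device_get itself, a caller can post-process its output. The sketch below is a minimal illustration, not part of JAX's API: `device_get_scalars` is a hypothetical helper name, and it relies on the NumPy idiom that indexing a 0-d array with an empty tuple (`a[()]`) returns a NumPy scalar. It uses `jax.tree_util.tree_map` so that nested containers (dicts, lists, tuples) are handled the same way `jax.device_get` handles them.

```python
import jax
import jax.numpy as jnp
import numpy as np


def device_get_scalars(tree):
    """Hypothetical helper: jax.device_get, then 0-d arrays -> NumPy scalars."""
    host_tree = jax.device_get(tree)  # leaves become host-side NumPy arrays

    def to_scalar(leaf):
        # a[()] on a 0-d ndarray yields a NumPy scalar (e.g. np.int32);
        # arrays with ndim >= 1 and non-array leaves are returned unchanged.
        if isinstance(leaf, np.ndarray) and leaf.ndim == 0:
            return leaf[()]
        return leaf

    return jax.tree_util.tree_map(to_scalar, host_tree)


x = device_get_scalars(jnp.abs(7))
print(type(x).__name__, x)  # a NumPy scalar type name, then 7

batch = device_get_scalars({"loss": jnp.float32(1.5), "w": jnp.ones(3)})
print(type(batch["loss"]).__name__, batch["w"].shape)
```

Because the conversion happens at the call site, code that prefers consistent array return types can keep calling jax.device_get directly.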

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 1
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

2 reactions
shoyer commented, Feb 5, 2020

NumPy does return scalars instead of 0-d arrays almost all the time, but personally I find that behavior pretty confusing and inconsistent. I would rather functions have consistent return types, e.g., always return arrays.

If I were redesigning NumPy, I would even try to get rid of NumPy scalars entirely. So I consider this a bit of an improvement in JAX’s design over the way NumPy works.

That’s just my two cents, though, I’d love to hear from anyone else who has a strong opinion 😃.
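The inconsistency described above can be seen in a few lines. This sketch assumes a recent JAX in which arrays are instances of `jax.Array`. It contrasts NumPy, where a full reduction or arithmetic on a 0-d array yields a NumPy scalar, with `jax.numpy`, where the result is always an array (possibly of shape `()`).

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy: the return type depends on the result's dimensionality.
full = np.sum(np.ones((2, 3)))             # reduces to 0-d -> NumPy scalar
partial = np.sum(np.ones((2, 3)), axis=0)  # shape (3,) -> ndarray
zero_d = np.array(7)                       # an explicit 0-d ndarray...
bumped = zero_d + 1                        # ...but ufunc arithmetic returns a scalar

print(isinstance(full, np.ndarray), isinstance(partial, np.ndarray))   # False True
print(isinstance(zero_d, np.ndarray), isinstance(bumped, np.ndarray))  # True False

# JAX: every result is an array; a full reduction simply has shape ().
jfull = jnp.sum(jnp.ones((2, 3)))
jpartial = jnp.sum(jnp.ones((2, 3)), axis=0)
print(isinstance(jfull, jax.Array), jfull.shape, jpartial.shape)  # True () (3,)
```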

0 reactions
mattjj commented, Feb 6, 2020

Well actually, to be more precise, they repr as array(7, dtype=int32) but they print just like Python integers:

In [1]: from jax import jit

In [2]: jit(lambda x: x)(7)
Out[2]: DeviceArray(7, dtype=int32)

In [3]: print(jit(lambda x: x)(7))
7

We want the repr method to act that way because, per Python conventions, it should indicate that the type of the object is different (i.e., repr conveys more information and isn’t made for printing ergonomics).
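The same repr/str split holds for the host-side NumPy array that `jax.device_get` returns, and for the JAX array itself. The sketch below only relies on the relationship between the two forms; the exact repr text (for example, whether `dtype=int32` is shown) depends on the platform's default integer type and on the JAX and NumPy versions, so it is given only as an example.

```python
import jax
import jax.numpy as jnp

host = jax.device_get(jnp.abs(7))  # 0-d NumPy array on the host
dev = jnp.abs(7)                   # device-side JAX array

print(repr(host))  # type-revealing form, e.g. array(7, dtype=int32)
print(str(host))   # printing-friendly form: 7
print(host)        # print() uses str(), so this also shows 7
print(str(dev))    # the JAX array prints the same way: 7
```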

If you want to convert back to a Python integer or float, why not just use the built-in int or float?
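Following that suggestion, the built-in `int` and `float` convert a 0-d array (JAX or NumPy) to a plain Python number, and NumPy's `ndarray.item()` does the same for the host copy. A minimal sketch:

```python
import jax
import jax.numpy as jnp

n = int(jnp.abs(7))                    # works directly on a 0-d JAX array
f = float(jnp.float32(1.5))            # float() likewise
m = jax.device_get(jnp.abs(7)).item()  # ndarray.item() on the host copy

print(type(n).__name__, n)  # int 7
print(type(f).__name__, f)  # float 1.5
print(type(m).__name__, m)  # int 7
```

These conversions need a concrete value, so they belong outside `jit`-compiled functions; calling `int()` on a traced value inside `jit` raises an error.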
