`jax.numpy.sum` and `jax.numpy.mean` inconsistency on python lists
This works:
import jax.numpy as jnp
x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]
jnp.sum(x)
and gives me `DeviceArray(2., dtype=float32)`. But this doesn’t:
import jax.numpy as jnp
x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]
jnp.mean(x)
and I get `TypeError: data type not understood`.
In NumPy, both work:
import numpy as np
x = [np.ones((), dtype=np.float32), np.ones((), dtype=np.float32)]
np.mean(x) # 1.0
np.sum(x) # 2.0
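NumPy accepts both calls because it effectively converts array-like inputs, including a list of 0-d arrays, into a single `ndarray` before reducing. A minimal sketch of that equivalence, assuming only NumPy:

```python
import numpy as np

x = [np.ones((), dtype=np.float32), np.ones((), dtype=np.float32)]

# Converting the list stacks the two 0-d arrays into one 1-d float32 array.
arr = np.asarray(x)
print(arr.shape, arr.dtype)  # (2,) float32

# Reducing the list and reducing the converted array give the same results.
print(np.mean(x), np.mean(arr))  # 1.0 1.0
print(np.sum(x), np.sum(arr))    # 2.0 2.0
```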
Is this intended?
Created 3 years ago; 5 comments (1 by maintainers).
Both `sum` and `mean` now raise errors at head!
From discussions with team members, I think the preferred approach here would be to not accept lists as inputs to JAX functions (in particular, converting lists to arrays may lead to strange interactions with how we handle pytrees internally).
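To see why lists are awkward here: JAX's tree utilities treat a Python list as a pytree container whose elements are separate leaves, not as one array-like value. A minimal sketch using the public `jax.tree_util` API; it only illustrates the pytree view and does not reproduce JAX's internal argument handling:

```python
import jax
import jax.numpy as jnp

x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]

# Flattening the list yields one leaf per element, plus a tree definition
# that records the list structure.
leaves, treedef = jax.tree_util.tree_flatten(x)
print(len(leaves))  # 2

# Mapping over the pytree preserves the container: the result is still a list.
doubled = jax.tree_util.tree_map(lambda a: a * 2, x)
print(type(doubled).__name__)  # list
```

So the same list means "two separate values" to tree utilities but "one array of shape (2,)" to an implicit `asarray` conversion, which is the ambiguity behind rejecting lists.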
So to make the API more consistent, we should ensure that `sum()` and other routines raise an appropriate `TypeError` when passed a Python list.
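Given that direction, a portable way to reduce over a list of arrays is to build one array explicitly before calling the reduction. A minimal sketch, assuming the list holds arrays of equal shape (which `jnp.stack` requires):

```python
import jax.numpy as jnp

x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]

# Stack the 0-d arrays along a new leading axis: the result has shape (2,).
arr = jnp.stack(x)
print(arr.shape)  # (2,)

# The reductions now receive a genuine array, so sum and mean behave consistently.
total = jnp.sum(arr)
avg = jnp.mean(arr)
print(float(total), float(avg))  # 2.0 1.0

# jnp.asarray also builds an array from the list explicitly.
print(jnp.asarray(x).shape)  # (2,)
```

Converting explicitly means the code does not depend on whether a particular JAX version accepts or rejects list inputs to reductions.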