
`jax.numpy.sum` and `jax.numpy.mean` inconsistency on python lists


This works:

import jax.numpy as jnp

x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]
jnp.sum(x)

and gives me DeviceArray(2., dtype=float32). But this doesn’t:

import jax.numpy as jnp

x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]
jnp.mean(x)

and I got TypeError: data type not understood.

In numpy both would work:

import numpy as np

x = [np.ones((), dtype=np.float32), np.ones((), dtype=np.float32)]
np.mean(x)  # 1.0
np.sum(x)  # 2.0

Is this intended?

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 5 (1 by maintainers)

Top GitHub Comments

1 reaction
hawkinsp commented, Dec 2, 2020

Both sum and mean now raise errors at head!
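Since both functions now reject Python lists, the caller has to convert the list to an array explicitly. The following is a minimal sketch, assuming a JAX version recent enough to include this change: `jnp.asarray` (or `jnp.stack`) turns the list of 0-d arrays into a 1-D array that both reductions accept.

```python
import jax.numpy as jnp

x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]

# Assumes a recent JAX: passing the Python list directly raises TypeError.
try:
    jnp.sum(x)
    list_rejected = False
except TypeError:
    list_rejected = True

# Convert explicitly instead; jnp.stack(x) gives the same result here.
arr = jnp.asarray(x)  # shape (2,), dtype float32
total = jnp.sum(arr)
avg = jnp.mean(arr)
print(total, avg)
```

Converting explicitly also makes sum and mean behave the same way, which was the inconsistency reported in the original question.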

0 reactions
jakevdp commented, May 15, 2020

From discussions with team members, I think the preferred approach here would be to not accept lists as inputs to JAX functions (in particular, converting lists to arrays may lead to strange interactions with how we handle pytrees internally).

So to make the API more consistent, we should ensure that sum() and other routines raise an appropriate TypeError when passed a Python list.
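The policy described here (raise a TypeError up front rather than silently converting a list) can be sketched as a small input check. The decorator `reject_lists` below is hypothetical and is not JAX's internal code; it wraps NumPy's `sum` and `mean` purely to illustrate the idea, and it inspects only the first positional argument (the array) so that a tuple passed as `axis` is still allowed.

```python
import functools

import numpy as np


def reject_lists(fn):
    """Hypothetical decorator: refuse a Python list/tuple as the array argument."""
    name = getattr(fn, "__name__", "function")

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Only the first positional argument is treated as the array input.
        if args and isinstance(args[0], (list, tuple)):
            raise TypeError(
                f"{name} requires ndarray or scalar arguments, "
                f"got {type(args[0]).__name__} at position 0."
            )
        return fn(*args, **kwargs)

    return wrapper


strict_sum = reject_lists(np.sum)
strict_mean = reject_lists(np.mean)

x = [np.ones((), dtype=np.float32), np.ones((), dtype=np.float32)]

try:
    strict_mean(x)  # a list: rejected before NumPy sees it
    list_rejected = False
except TypeError:
    list_rejected = True

arr = np.asarray(x)  # explicit conversion by the caller
total = strict_sum(arr)
avg = strict_mean(arr)
```

With a check like this applied uniformly, sum and mean fail the same way on lists, and the fix on the caller's side is the same one-line conversion for every routine.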

