jnp.einsum does not catch shape mismatch when optimize=True
There seems to be a bug in jnp.einsum: the function does not catch a shape mismatch when the optimize flag is set to True.
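import jax.numpy as jnp

A = jnp.ones(2)[:, None]  # shape (2, 1)
x = jnp.ones(2)           # shape (2,)

# Plain matmul: the contracted dimensions (1 and 2) do not match
try:
    A @ x
    print('Passed?')
except Exception:
    print('Failed as expected')

# einsum without contraction-path optimization
try:
    jnp.einsum('ij,j->i', A, x, optimize=False)
    print('Passed?')
except Exception:
    print('Failed as expected')

# einsum with contraction-path optimization: the mismatch slips through
try:
    out = jnp.einsum('ij,j->i', A, x, optimize=True)
    print('Passed?')
    print(out)
except Exception:
    print('Failed as expected')

# Output reported in the issue:
# Failed as expected
# Failed as expected
# Passed?
# [2. 2.]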
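For reference, the subscripts 'ij,j->i' describe a matrix-vector product. The index j is summed over, so A's second axis and x's only axis must share one size J:

```latex
y_i = \sum_{j=1}^{J} A_{ij}\, x_j, \qquad A \in \mathbb{R}^{I \times J},\quad x \in \mathbb{R}^{J}
```

In the report A has shape (2, 1) and x has shape (2,), so j would have to be both 1 and 2, which is why the call should fail regardless of the optimize setting.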
Issue analytics: created 3 years ago; 7 comments (all by maintainers).
Thanks for catching this, @EddieCunningham, and wow, nice digging, @IgorWilbert! So should we consider this a bug in opt_einsum? Maybe we should do our own shape checking.
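A minimal sketch of the "do our own shape checking" idea, assuming subscripts without an ellipsis ('...'). check_einsum_shapes is a hypothetical helper, not part of JAX or opt_einsum. It works on shapes alone, so it could run before any contraction-path optimization and would reject the operands from the report.

```python
def check_einsum_shapes(subscripts, *shapes):
    """Raise ValueError if any einsum index is bound to two different sizes.

    Hypothetical helper (not a JAX or opt_einsum API). Handles explicit
    ('ij,j->i') and implicit ('ij,j') subscripts, but not ellipsis.
    Returns the mapping from index letter to size on success.
    """
    # Keep only the input side of the subscripts and split it per operand.
    inputs = subscripts.replace(" ", "").split("->")[0].split(",")
    if len(inputs) != len(shapes):
        raise ValueError(f"expected {len(inputs)} operands, got {len(shapes)}")

    sizes = {}  # index letter -> size from the first operand that uses it
    for pos, (labels, shape) in enumerate(zip(inputs, shapes)):
        if len(labels) != len(shape):
            raise ValueError(
                f"operand {pos}: subscripts {labels!r} name {len(labels)} axes "
                f"but shape {tuple(shape)} has {len(shape)}"
            )
        for label, size in zip(labels, shape):
            # Exact equality only: a size-1 axis is NOT broadcast.
            if sizes.setdefault(label, size) != size:
                raise ValueError(
                    f"index {label!r} has size {sizes[label]} in an earlier "
                    f"operand but size {size} in operand {pos}"
                )
    return sizes
```

For the report's operands, check_einsum_shapes('ij,j->i', A.shape, x.shape) raises ValueError, because j is bound to 1 by A and to 2 by x. The check deliberately requires exact size equality rather than allowing size-1 axes to broadcast.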
I think both NumPy and JAX should be strict about requiring exactly matching sizes. The entire point of einsum is to avoid implicit broadcasting…
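One way to read the [2. 2.] output from the report: it equals what implicit broadcasting would give if A's size-1 j axis were stretched to size 2 before summing over j. The sketch below reproduces that number with explicit broadcasting. It is shown with NumPy, whose broadcasting rules jax.numpy follows, and it illustrates the arithmetic only; it is not a claim about what opt_einsum does internally.

```python
import numpy as np

A = np.ones(2)[:, None]  # shape (2, 1), the matrix from the report
x = np.ones(2)           # shape (2,), the vector from the report

# Broadcasting stretches A's size-1 'j' axis to size 2 so it lines up with x;
# summing over 'j' then gives 1*1 + 1*1 = 2 for each i.
broadcast_result = (A * x[None, :]).sum(axis=1)
print(broadcast_result)  # [2. 2.]
```

A strict einsum would refuse these operands instead of silently producing this result.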