
jnp.einsum does not catch shape mismatch when optimize=True


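There seems to be a bug in jnp.einsum: the function does not catch a shape mismatch when the optimize flag is set to True. In the example below, A has shape (2, 1) and x has shape (2,), so the contracted index j has size 1 in A but size 2 in x.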

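```python
import jax.numpy as jnp

A = jnp.ones(2)[:, None]  # shape (2, 1)
x = jnp.ones(2)           # shape (2,)

# Matrix-vector product: inner dimensions 1 and 2 do not match.
try:
    A @ x
    print('Passed?')
except Exception:
    print('Failed as expected')

# einsum without path optimization rejects the mismatch.
try:
    jnp.einsum('ij,j->i', A, x, optimize=False)
    print('Passed?')
except Exception:
    print('Failed as expected')

# einsum with optimize=True silently accepts the mismatch.
try:
    out = jnp.einsum('ij,j->i', A, x, optimize=True)
    print('Passed?')
    print(out)
except Exception:
    print('Failed as expected')

# Reported output:
# Failed as expected
# Failed as expected
# Passed?
# [2. 2.]
```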

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

mattjj commented, Jul 2, 2020 (3 reactions)

Thanks for catching this @EddieCunningham, and wow, nice digging @IgorWilbert! So should we consider this a bug in opt_einsum?

Maybe we should do our own shape checking.
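As a rough illustration of what "our own shape checking" could look like, here is a minimal sketch of a hypothetical helper, `check_einsum_shapes`. It is not part of JAX, NumPy, or opt_einsum. It assumes subscripts without an ellipsis and takes plain shape tuples, and it rejects any label that is bound to two different sizes, including the size-1 versus size-2 case from the report.

```python
def check_einsum_shapes(subscripts, *shapes):
    """Hypothetical strict shape check for einsum-style subscripts.

    Returns a dict mapping each label to its size, or raises ValueError if a
    label is bound to two different sizes. Size-1 axes are NOT broadcast.
    Assumes explicit or implicit subscripts without an ellipsis ("...").
    """
    subscripts = subscripts.replace(" ", "")
    if "..." in subscripts:
        raise NotImplementedError("this sketch does not handle '...'")
    parts = subscripts.split("->")
    inputs = parts[0].split(",")
    if len(inputs) != len(shapes):
        raise ValueError(f"expected {len(inputs)} operands, got {len(shapes)}")

    sizes = {}  # label -> (size, index of the operand where it first appeared)
    for k, (labels, shape) in enumerate(zip(inputs, shapes)):
        if len(labels) != len(shape):
            raise ValueError(
                f"operand {k}: {len(labels)} labels {labels!r} "
                f"but {len(shape)} dimensions {tuple(shape)}"
            )
        for label, size in zip(labels, shape):
            if label in sizes and sizes[label][0] != size:
                first_size, first_k = sizes[label]
                raise ValueError(
                    f"label {label!r} has size {first_size} in operand "
                    f"{first_k} but size {size} in operand {k}"
                )
            sizes.setdefault(label, (size, k))

    # In explicit mode, every output label must come from some input.
    if len(parts) == 2:
        for label in parts[1]:
            if label not in sizes:
                raise ValueError(f"output label {label!r} not in any input")

    return {label: size for label, (size, _) in sizes.items()}


# Usage with the shapes from the report: A.shape == (2, 1), x.shape == (2,)
try:
    check_einsum_shapes('ij,j->i', (2, 1), (2,))
except ValueError as e:
    print(e)  # label 'j' has size 1 in operand 0 but size 2 in operand 1
```

Because the check only looks at shapes, calling it as `check_einsum_shapes('ij,j->i', A.shape, x.shape)` before `jnp.einsum` would reject the example above regardless of the `optimize` setting. Where such a check would actually live inside JAX is left open here.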

shoyer commented, Apr 21, 2021 (0 reactions)

I think both NumPy and JAX should be strict about requiring exactly matching sizes. The entire point of einsum is to avoid implicit broadcasting…
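The reported output `[2. 2.]` is consistent with exactly this kind of implicit broadcasting. If the size-1 `j` axis of A (shape (2, 1)) is stretched to match x (shape (2,)), every entry of the broadcast A is 1, and each output element sums two products (indices below are 1-based):

```latex
\mathrm{out}_i = \sum_{j=1}^{2} A_{i1}\, x_j = \sum_{j=1}^{2} 1 \cdot 1 = 2, \qquad i = 1, 2
```

A strict check would instead reject the call, since `j` is bound to size 1 in A and size 2 in x.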


