Broadcasting error when using eqx.nn.Linear with use_bias=True
See original GitHub issue
I'm pretty new to both JAX and Equinox, so there's a chance I'm missing something, but I think there's a bug in the implementation of eqx.nn.Linear when using a bias. Namely, adding the bias to the product weight @ input fails whenever the input has more than one dimension, due to a broadcasting error. Example:
import jax
import jax.numpy as jnp
import equinox as eqx
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
layer1 = eqx.nn.Linear(10, 8, use_bias=True, key=k1)
layer2 = eqx.nn.Linear(10, 8, use_bias=False, key=k2)
arr1 = jax.random.normal(k3, shape=(10, 100))
arr2 = jax.random.normal(k4, shape=(10,))
layer1(arr1)  # fails with "Incompatible shapes for broadcasting: ((8, 100), (1, 8))"
layer2(arr1)  # ok
layer1(arr2)  # ok
layer2(arr2)  # ok
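As the maintainer comment below notes, Equinox layers act on a single unbatched vector, and batching is left to jax.vmap. A minimal sketch of that pattern, using a hand-rolled linear function rather than eqx.nn.Linear itself (the names and shapes mirror the example above, but this is an illustration, not Equinox's implementation):

```python
import jax
import jax.numpy as jnp

def linear(w, b, x):
    # Acts on a single vector x of shape (in_features,).
    return w @ x + b

key = jax.random.PRNGKey(0)
kw, kb, kx = jax.random.split(key, 3)
w = jax.random.normal(kw, (8, 10))
b = jax.random.normal(kb, (8,))
X = jax.random.normal(kx, (10, 100))  # batch of 100 vectors along the trailing axis

# Instead of passing the 2-D array directly (which would mis-broadcast the
# bias), map the single-vector function over axis 1 of X:
out = jax.vmap(linear, in_axes=(None, None, 1), out_axes=1)(w, b, X)
print(out.shape)  # (8, 100)
```

The same idea applies to the example above: wrap the layer call in jax.vmap with the appropriate in_axes rather than feeding it a batched array.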
Thanks for the cool library!
Issue Analytics
- State:
- Created a year ago
- Comments: 5 (2 by maintainers)
Top GitHub Comments
Depends on the library author’s preferences! But yes, everything I write leaves the vectorisation to vmap 😃
Fwiw, the place where this behavior has bitten me the most is with LayerNorm as it happily broadcasts, so it ends up being a silent issue. Happy to open another issue if you think you’d like to do something about it, but I could also see leaving it as is.
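To illustrate the silent-broadcast hazard mentioned above: a normalisation written for a single vector raises no error when handed a batch, it just computes its statistics over the wrong axes. This is a minimal sketch with a hand-rolled layer norm, not Equinox's actual eqx.nn.LayerNorm code:

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-5):
    # Normalises over every element of x -- intended for a single vector.
    return (x - x.mean()) / jnp.sqrt(x.var() + eps)

X = jnp.arange(12.0).reshape(3, 4)  # a batch of 3 vectors, passed in by mistake

# No error is raised: the mean and variance are taken over the whole 2-D
# array, mixing the batch elements together.
silently_wrong = layer_norm(X)

# Correct per-vector behaviour comes from mapping over the batch axis:
per_row = jax.vmap(layer_norm)(X)
```

Here per_row has (approximately) zero mean in every row, while silently_wrong does not, yet only the shapes ever get checked, which is why this failure mode is easy to miss.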