Autodidax has some bugs involving reduce_sum
This works in JAX:
import jax
from jax import grad, vmap
import jax.numpy as np
def simple(a):
    b = a
    t = a + b
    return t * b

def f(a):
    L0 = np.sum(simple(a))
    return L0

def g(a):
    dL0_da = vmap(grad(f), in_axes=0)(a)
    L1 = np.sum(dL0_da * dL0_da)
    return L1

key = jax.random.PRNGKey(0)
print(grad(g)(jax.random.normal(key,(2,4))))
However, the same code does not work in Autodidax:
def simple(a):
    b = a
    t = a + b
    return t * b

def f(a):
    L0 = reduce_sum(simple(a))
    return L0

def g(a):
    dL0_da = vmap(grad(f), in_axes=0)(a)
    L1 = reduce_sum(dL0_da * dL0_da)
    return L1

print(grad(g)(np.random.rand(2,4)))
which gives:
<ipython-input-27-03f392c346bd> in batched_f(*args)
     17     args_flat, in_tree = tree_flatten(args)
     18     in_axes_flat, in_tree2 = tree_flatten(in_axes)
---> 19     if in_tree != in_tree2: raise TypeError
     20     f_flat, out_tree = flatten_fun(f, in_tree)
     21     outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)

TypeError:
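As an aside, a hedged reading of that traceback: Autodidax's vmap wrapper flattens both the argument tuple and in_axes as pytrees and requires the two structures to match, so a bare in_axes=0 does not match the one-element args tuple (a,). Wrapping the axis spec in a tuple may get past that check (this only addresses the TypeError above, not the transpose question below):

dL0_da = vmap(grad(f), in_axes=(0,))(a)  # assumption: in_axes must mirror the structure of the args tuple (a,)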
Actually, I expected it to fail because there doesn't appear to be any transpose rule for reduce_sum, but Autodidax never got that far. What I was really trying to do was see how Autodidax formulates the transpose rule for sum, because once you support vmap plus reducing over arbitrary dimensions, it gets pretty complicated.
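For reference, here is a minimal sketch of how such a transpose rule can be formulated, written against jax.lax.broadcast_in_dim rather than Autodidax's own primitives; the helper name sum_transpose and the example shapes are mine, not anything from Autodidax:

import jax
import jax.numpy as jnp

def sum_transpose(ct, x_shape, axis):
    # Transpose of y = x.sum(axis): broadcast the cotangent `ct` (shaped like
    # y) back to x's shape along the reduced axes. The kept (non-reduced)
    # axes of x say where each dim of ct lands in the output.
    kept = tuple(d for d in range(len(x_shape)) if d not in axis)
    return jax.lax.broadcast_in_dim(ct, x_shape, broadcast_dimensions=kept)

x_shape = (2, 3, 4)
ct = jnp.ones((2, 4))  # cotangent of x.sum(axis=1)
print(sum_transpose(ct, x_shape, axis=(1,)).shape)  # (2, 3, 4)

In JAX proper the transpose of reduce_sum is expressed with a broadcast-in-dim as well, which I believe is what the BroadcastInDim comment further down is getting at.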
Issue Analytics
- State:
- Created 2 years ago
- Comments: 7 (5 by maintainers)
Oh this is great! I think BroadcastInDim answers the question; in PyTorch we don’t natively have this operator so we have to manually implement it with unsqueezes and expands.
Thanks a lot!
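For anyone reading along, a rough sketch of that manual emulation in PyTorch; the helper name broadcast_in_dim and the example shapes are mine, and this is just one way to write it:

import torch

def broadcast_in_dim(x, shape, broadcast_dimensions):
    # Emulate XLA-style BroadcastInDim with unsqueeze + expand: dim i of x
    # lands at output position broadcast_dimensions[i]; every other output
    # position gets a new size-1 axis, which expand then broadcasts.
    out = x
    for i in range(len(shape)):
        if i not in broadcast_dimensions:
            out = out.unsqueeze(i)
    return out.expand(*shape)

x = torch.ones(2, 4)
print(broadcast_in_dim(x, (2, 3, 4), (0, 2)).shape)  # torch.Size([2, 3, 4])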