
Autodidax has some bugs involving reduce_sum

See original GitHub issue

This works in JAX:

import jax
from jax import grad, vmap
import jax.numpy as np

def simple(a):
  b = a
  t = a + b
  return t * b

def f(a):
  L0 = np.sum(simple(a))
  return L0

def g(a):
  dL0_da = vmap(grad(f), in_axes=0)(a)
  L1 = np.sum(dL0_da * dL0_da)
  return L1

key = jax.random.PRNGKey(0)
print(grad(g)(jax.random.normal(key,(2,4))))

However, the same code does not work in autodidax:

def simple(a):
  b = a
  t = a + b
  return t * b

def f(a):
  L0 = reduce_sum(simple(a))
  return L0

def g(a):
  dL0_da = vmap(grad(f), in_axes=0)(a)
  L1 = reduce_sum(dL0_da * dL0_da)
  return L1

print(grad(g)(np.random.rand(2,4)))

gives

[<ipython-input-27-03f392c346bd>](https://localhost:8080/#) in batched_f(*args)
     17     args_flat, in_tree = tree_flatten(args)
     18     in_axes_flat, in_tree2 = tree_flatten(in_axes)
---> 19     if in_tree != in_tree2: raise TypeError
     20     f_flat, out_tree = flatten_fun(f, in_tree)
     21     outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)

TypeError:
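Judging from the traceback, the likely immediate cause is the tree check in the `vmap` wrapper: `batched_f` flattens both the argument tuple and `in_axes` and requires their tree structures to match exactly, and (unlike `jax.vmap`) it apparently does not broadcast a bare-int `in_axes` over the argument tree. The sketch below is a toy illustration of that check, not autodidax's actual `tree_flatten`:

```python
# Toy treedef (hypothetical, NOT autodidax's real tree_flatten): tuples are
# internal nodes, everything else is a leaf.
def tree_structure(x):
    if isinstance(x, tuple):
        return ("tuple", tuple(tree_structure(c) for c in x))
    return "leaf"

args = ("<array a>",)   # stand-in for the single-argument call f(a)
in_axes = 0             # a bare int, as in vmap(grad(f), in_axes=0)

# The trees differ: args is a 1-tuple of leaves, in_axes is a bare leaf,
# so a check like `if in_tree != in_tree2: raise TypeError` fires.
print(tree_structure(args) != tree_structure(in_axes))   # True

# Matching the structure of args, e.g. in_axes=(0,), avoids the mismatch.
print(tree_structure(args) == tree_structure((0,)))      # True
```

If that reading is right, writing `in_axes=(0,)` in the autodidax version should get past this particular `TypeError`.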

Actually, I was thinking that it would fail because there doesn’t appear to be any transpose rule for reduce_sum, but it doesn’t seem that Autodidax got that far. What I was actually trying to do was see how Autodidax formulates the transpose rule for sum, because once you support vmap plus arbitrary dims to reduce over, it gets pretty complicated.
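For reference, the transpose rule in question can be sketched in plain NumPy: the transpose (VJP) of a sum over a set of axes reinserts those axes as size-1 dimensions and broadcasts the cotangent back to the input shape, which is exactly the reduce_sum/broadcast pairing the discussion below arrives at. This is a hedged sketch, not autodidax's actual rule; the helper names are made up:

```python
import numpy as np

def reduce_sum(x, axes):
    return np.sum(x, axis=axes)

def reduce_sum_transpose(ct, in_shape, axes):
    # Reinsert each reduced axis with size 1 (an "unsqueeze") ...
    ct_expanded = np.expand_dims(ct, axes)
    # ... then broadcast the cotangent back to the input shape.
    return np.broadcast_to(ct_expanded, in_shape)

x = np.arange(6.0).reshape(2, 3)
y = reduce_sum(x, axes=(0,))                           # shape (3,)
ct = np.ones_like(y)                                   # incoming cotangent
grad_x = reduce_sum_transpose(ct, x.shape, axes=(0,))
print(grad_x.shape)  # (2, 3): every input element receives the cotangent
```

The complication the comment alludes to is visible here: the transpose has to track which axes were reduced so it can reinsert them in the right positions before broadcasting.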

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 7 (5 by maintainers)

Top GitHub Comments

1 reaction
ezyang commented, Mar 16, 2022

Oh this is great! I think BroadcastInDim answers the question; in PyTorch we don’t natively have this operator so we have to manually implement it with unsqueezes and expands.
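As a rough illustration of that comment, a BroadcastInDim-style operator can be emulated with the unsqueeze-then-expand recipe. The NumPy sketch below (where `reshape` plays the role of unsqueeze and `broadcast_to` of expand) is an assumption-laden toy, not PyTorch's or XLA's actual implementation; the helper name and signature are made up:

```python
import numpy as np

def broadcast_in_dim(x, out_shape, broadcast_dims):
    # broadcast_dims[i] says which output dimension input dimension i
    # maps to; every other output dimension is newly broadcast.
    # Step 1 ("unsqueeze"): build a shape with size 1 in each new dim.
    reshaped = [1] * len(out_shape)
    for i, d in enumerate(broadcast_dims):
        reshaped[d] = x.shape[i]
    x = x.reshape(reshaped)
    # Step 2 ("expand"): broadcast the size-1 dims to the output shape.
    return np.broadcast_to(x, out_shape)

x = np.array([1.0, 2.0, 3.0])                        # shape (3,)
y = broadcast_in_dim(x, (2, 3), broadcast_dims=(1,))
print(y.shape)  # (2, 3): x replicated along the new leading dimension
```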

0 reactions
ezyang commented, Mar 16, 2022

Thanks a lot!


