support consts in custom batching rules
To save memory in a setup where `jax.checkpoint` was insufficient, I implemented a custom transposition via `jax.custom_derivatives.linear_call`. This worked fine and yielded the desired savings. However, when applying `vmap`, I hit a roadblock. `jax.custom_derivatives.linear_call` does not implement a batching rule, so naively applying `vmap` fails even if the arguments to `linear_call` support batching. I also cannot implement the batching manually via `jax.custom_batching.custom_vmap` in my case, because `custom_vmap` does not work with constants in the jaxpr.
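As background, here is a minimal sketch of the kind of custom transposition `linear_call` enables. It assumes the call form `linear_call(fun, fun_transpose, residual_args, linear_args)`, imported from `jax.custom_derivatives` as in the repro below. The helper `scale` is hypothetical. `jax.linear_transpose` is used here so that the hand-written `scale_T` is applied to the cotangent, rather than a transpose derived automatically from `scale_fwd`.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import linear_call


def scale(x, r):
    """Hypothetical helper: x -> x * r, linear in x, with r as a residual."""

    def scale_fwd(res, lin):
        # Forward map, evaluated with the residual `res` and linear input `lin`.
        return lin * res

    def scale_T(res, ct):
        # Hand-written transpose: elementwise scaling is self-adjoint.
        return ct * res

    return linear_call(scale_fwd, scale_T, r, x)


x = jnp.arange(3.0)
r = jnp.asarray(2.0)
y = scale(x, r)  # [0., 2., 4.]

# linear_transpose stages out the linear_call and applies scale_T to the cotangent.
(x_ct,) = jax.linear_transpose(lambda x: scale(x, r), x)(jnp.ones(3))  # [2., 2., 2.]
```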
I think the best approach would be to make `linear_call` batchable and to let `custom_vmap` work with constants. Either of the two would unblock me 😃
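For comparison, `jax.custom_batching.custom_vmap` attaches a hand-written batching rule to a function. The sketch below assumes the rule signature `rule(axis_size, in_batched, *args) -> (out, out_batched)`. It avoids closed-over constants by passing `c` as an explicit argument, and `scaled_mul` and its rule are hypothetical names. `custom_vmap` traces its function abstractly, so `c` is a tracer inside it. This pattern therefore does not help when `c` must be a concrete NumPy array at trace time, as in the repro.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def scaled_mul(a, c):
    # Unbatched semantics: a * c with ordinary broadcasting.
    return a * c


@scaled_mul.def_vmap
def scaled_mul_vmap_rule(axis_size, in_batched, a, c):
    a_batched, c_batched = in_batched
    # Placeholder rule: defer to an ordinary vmap over whichever inputs are
    # batched (batch axis 0). A real rule would call a hand-written batched kernel.
    out = jax.vmap(
        lambda a_i, c_i: a_i * c_i,
        in_axes=(0 if a_batched else None, 0 if c_batched else None),
        axis_size=axis_size,
    )(a, c)
    return out, True  # the single output is batched along axis 0


a = jnp.arange(4.0)
c = jnp.array([[2.0], [4.0]])
out = jax.vmap(scaled_mul, in_axes=(0, None))(a, c)  # shape (4, 2, 1)
```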
Below is a minimal reproducible example:
```python
from functools import partial

import jax
from jax.custom_batching import sequential_vmap
from jax.custom_derivatives import linear_call
import numpy as np


def _mul(residual_args, a, *, c):
    # Forward function passed to linear_call: a * b * c
    b, = residual_args
    c = np.array(c)  # needs to be known at trace-time
    return a * b * c


def _mul_T(residual_args, out, *, c):
    # Custom transpose function passed to linear_call
    b, = residual_args
    c = np.array(c)  # needs to be known at trace-time
    return out * b * c


def mul(a, b, c):
    print(a.shape)
    return linear_call(partial(_mul, c=c), partial(_mul_T, c=c), (b, ), a)


a, b, c = np.arange(12, dtype=float), 10, np.array([2., 4.]).reshape(2, 1)
m = partial(mul, b=b, c=c)
jax.vmap(sequential_vmap(m), in_axes=(0, ))(a)
```
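For reference, the values the repro is after can be computed with a plain `jax.vmap` that bypasses both `linear_call` and `sequential_vmap`. This assumes the only goal is the forward product `a * b * c` per example, and `mul_reference` is a hypothetical name. Each scalar `a_i` broadcasts against `c` of shape `(2, 1)`, so the batched result has shape `(12, 2, 1)`.

```python
import jax
import numpy as np

a, b, c = np.arange(12, dtype=float), 10, np.array([2., 4.]).reshape(2, 1)


def mul_reference(a_i):
    # Same arithmetic as _mul, with b and c closed over as constants.
    return a_i * b * c


out = jax.vmap(mul_reference)(a)
print(out.shape)  # (12, 2, 1)
```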
Awesome! Thank you very much for resolving this issue so quickly!
I’ll repurpose this issue. We’ll certainly need to make sure consts are supported before we consider custom batching complete. It doesn’t hurt to have an open issue reminding us to do this, with your particular example registered. Thanks for filing!