How do you nest implementations of VJP?
If I have a function f that has a VJP, can I somehow build a VJP on top of it for some function g? My use case is that the backward pass of g is quite complicated and makes multiple calls to f_vjp. Here's a simplified example:
from jax import vjp, custom_vjp

def f(x):
    return x ** 2

@custom_vjp
def g(x):
    return f(x)

def g_fwd(x):
    return vjp(f, x)

def g_bwd(f_vjp, y_bar):
    return f_vjp(y_bar)

g.defvjp(g_fwd, g_bwd)

y, g_vjp = vjp(g, 1.0)
This raises TypeError: <class 'functools.partial'> is not a valid JAX type, because vjp returns a functools.partial instance instead of a pytree. It should be possible for vjp to return a pytree-like callable instead, since internally vjp(f, x) produces residuals that must be pytree-like. It's just unfortunate that when they're wrapped up into a callable, that callable is not pytree-like.
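To make the distinction concrete, here is a minimal sketch of how JAX's tree utilities see a functools.partial. The helper scale and the bound value 3.0 are made up for illustration and are not part of the issue. The whole callable comes back as one opaque leaf, so the array bound inside it is never exposed as a JAX value:

```python
import functools

import jax.numpy as jnp
from jax.tree_util import tree_leaves


def scale(a, x):
    # Hypothetical helper, used only to have something to bind an array to.
    return a * x


a = jnp.float32(3.0)
plain = functools.partial(scale, a)

# functools.partial is not a registered pytree node, so flattening it
# yields the callable itself as the only leaf; the bound array stays hidden.
leaves = tree_leaves(plain)
```

Since custom_vjp requires every leaf of the residuals to be a valid JAX type, a leaf that is itself a functools.partial object is what the error message is complaining about.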
Issue Analytics
- State:
- Created 3 years ago
- Comments:11 (11 by maintainers)
This seems like a case where perhaps you are pushing the limits of what custom_vjp is designed to support and should be considering other options 😃

That said, one way to fix this immediate issue would be to replace partial inside jax.api._vjp with tree_util.Partial, which is serializable as a pytree. You could also do this wrapping in user code.

In general, it's fine to use Partial inside the forward pass of custom_vjp functions as long as you are careful not to close over any tracers, like the value x in this case. The dtypes, constants, jaxprs and treedefs used in the closures should all be fine. (Otherwise you should get a nasty error message.)

(@mattjj please correct me if I'm mis-stating anything here)
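As a rough sketch of using Partial inside the forward pass of a custom_vjp function: the version below binds the residual x as an explicit argument of jax.tree_util.Partial, so it travels as a pytree leaf instead of being closed over. Note the assumptions: f_pullback is a hand-written, hypothetical stand-in for the pullback that vjp(f, x) would produce, used here so the example does not depend on what vjp returns in a given JAX version, and the inputs 3.0 and the all-ones cotangent are arbitrary.

```python
import jax.numpy as jnp
from jax import custom_vjp, vjp
from jax.tree_util import Partial


def f(x):
    return x ** 2


def f_pullback(x, y_bar):
    # Hypothetical hand-written pullback of f: d(x**2)/dx = 2 * x.
    # Returns a tuple with one cotangent per input of f.
    return (2.0 * x * y_bar,)


@custom_vjp
def g(x):
    return f(x)


def g_fwd(x):
    # x is passed as an argument of Partial (a pytree leaf),
    # not captured in a closure.
    return f(x), Partial(f_pullback, x)


def g_bwd(f_vjp, y_bar):
    # The residual is itself callable, so the backward pass can call it
    # as many times as it needs.
    return f_vjp(y_bar)


g.defvjp(g_fwd, g_bwd)

y, g_vjp = vjp(g, 3.0)
(x_bar,) = g_vjp(jnp.ones_like(y))
```

The same pattern should extend to a backward pass that calls the residual several times, since Partial only changes how the callable is stored, not how it is invoked.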
Thanks for your help. I made a pull request that addresses this issue. When you have time, do you mind taking a look?