Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging 3rd party libraries. It collects links to all the places you might be looking at while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

How do you nest implementations of VJP?

See original GitHub issue

If I have a function f that has a VJP, can I somehow build a VJP on top of it for some function g? My use case is that the backward pass of g is quite complicated and makes multiple calls to f_vjp. Here’s a simplified example:

```python
from jax import vjp, custom_vjp

def f(x):
    return x ** 2

@custom_vjp
def g(x):
    return f(x)

def g_fwd(x):
    return vjp(f, x)

def g_bwd(f_vjp, y_bar):
    return f_vjp(y_bar)

g.defvjp(g_fwd, g_bwd)
y, g_vjp = vjp(g, 1.0)
```

This fails with: TypeError: <class 'functools.partial'> is not a valid JAX type. The cause is that vjp returns a functools.partial instance rather than a pytree. It should be possible for vjp to return a pytree-like callable instead, since internally vjp(f, x) produces residuals that must be pytree-like. It’s just unfortunate that once those residuals are wrapped up into a callable, the callable itself is not pytree-like.
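A small sketch can make the question’s distinction concrete. It assumes a reasonably recent JAX release and uses jax.tree_util.Partial, the pytree-registered partial mentioned in the reply below; the helper scale and the numbers are made up for illustration. A plain functools.partial is not a registered pytree node, so JAX sees it as one opaque leaf, which is why it is rejected as a residual. tree_util.Partial instead exposes its bound arguments as leaves and keeps the wrapped function as static metadata.

```python
import functools

import jax
import jax.numpy as jnp
from jax import tree_util


def scale(factor, v):
    # Hypothetical helper used only for this illustration.
    return factor * v


# A plain functools.partial is not a registered pytree node, so JAX treats
# the whole object as a single opaque leaf.
plain = functools.partial(scale, jnp.float32(3.0))
plain_leaves = tree_util.tree_leaves(plain)
print(isinstance(plain_leaves[0], functools.partial))  # True

# tree_util.Partial is a registered pytree: the bound argument becomes a
# leaf and the wrapped function is stored in the treedef.
wrapped = tree_util.Partial(scale, jnp.float32(3.0))
leaves, treedef = tree_util.tree_flatten(wrapped)
print(len(leaves))  # 1

# Because it is a pytree, it can be rebuilt with new leaves...
rebuilt = tree_util.tree_unflatten(treedef, [jnp.float32(5.0)])
print(rebuilt(2.0))  # 10.0


# ...and passed straight through transformations such as jax.jit.
@jax.jit
def apply(fn, v):
    return fn(v)


print(apply(wrapped, 2.0))  # 6.0
```

Since the bound arguments are leaves, JAX can trace or replace them, while the wrapped function travels as static data, so for jit it should be hashable.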

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 11 (11 by maintainers)

Top GitHub Comments

1 reaction
shoyer commented, Jul 6, 2020

This seems like a case where perhaps you are pushing the limits of what custom_vjp is designed to support, and you should be considering other options 😃

That said, one way to fix this immediate issue would be to replace partial inside jax.api._vjp with tree_util.Partial, which is serializable as a pytree. You could also do this wrapping in user code:

```python
from jax import vjp, custom_vjp, tree_util

def f(x):
    return x ** 2

@custom_vjp
def g(x):
    return f(x)

def g_fwd(x):
    y, f_vjp = vjp(f, x)
    return y, tree_util.Partial(f_vjp)

def g_bwd(f_vjp, y_bar):
    return f_vjp(y_bar)

g.defvjp(g_fwd, g_bwd)
y, g_vjp = vjp(g, 1.0)
print(y, g_vjp(1.0))
# 1.0 (DeviceArray(2., dtype=float32),)
```

In general, it’s fine to use Partial inside the forward pass of custom_vjp functions, as long as you are careful not to close over any tracers, like the value x in this case. The dtypes, constants, jaxprs and treedefs used in the closures should all be fine. (Otherwise you should get a nasty error message.)

(@mattjj please correct me if I’m mis-stating anything here)
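The advice above about not closing over tracers, combined with the original use case of a backward pass that calls f’s VJP several times, can be sketched as follows. This is a minimal sketch under stated assumptions, not the approach from the thread: g is a made-up variant, g(x) = 2 * f(x); _f_pullback is a hypothetical helper; and instead of saving the pullback that vjp returns, the forward pass saves only x, bound as a tree_util.Partial argument so that it is a pytree leaf, and the backward pass re-runs jax.vjp(f, x). That trades recomputation of f’s forward pass for not having to store the pullback.

```python
import jax
from jax import custom_vjp, tree_util, vjp


def f(x):
    return x ** 2


def _f_pullback(x, y_bar):
    # Hypothetical module-level helper. It receives x as an explicit argument
    # instead of closing over it, so x can travel as a pytree leaf.
    _, f_vjp = vjp(f, x)  # re-runs f's forward pass (rematerialization)
    return f_vjp(y_bar)   # a 1-tuple: (x_bar,)


@custom_vjp
def g(x):
    # Made-up variant for illustration: g(x) = 2 * f(x).
    return 2.0 * f(x)


def g_fwd(x):
    # Bind x as a Partial *argument* (a leaf), not via a closure, so the
    # residual is a valid pytree of arrays.
    return 2.0 * f(x), tree_util.Partial(_f_pullback, x)


def g_bwd(pullback, y_bar):
    # The backward pass calls f's VJP more than once: d/dx [2 f(x)] is
    # assembled from two calls, each contributing f'(x) * y_bar.
    (first,) = pullback(y_bar)
    (second,) = pullback(y_bar)
    return (first + second,)


g.defvjp(g_fwd, g_bwd)

print(g(3.0))            # 18.0
print(jax.grad(g)(3.0))  # 12.0
```

Capturing x in a lambda and wrapping that lambda in tree_util.Partial would instead hide x inside the Partial’s static function, which is the situation the comment warns about.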

0 reactions
NeilGirdhar commented, Jul 9, 2020

Thanks for your help. I made a pull request that addresses this issue. When you have time, do you mind taking a look?
