support consts in custom batching rules

To save memory in a setup where jax.checkpoint was insufficient, I implemented a custom transposition via jax.custom_derivatives.linear_call. This worked just fine and yielded the desired savings. However, when vmapping, I hit a roadblock: jax.custom_derivatives.linear_call does not implement a batching rule, so naively applying vmap does not work even if the arguments to linear_call support batching. Furthermore, in my case it is not possible to implement the batching manually via jax.custom_batching.custom_vmap, because custom_vmap does not work with constants in the jaxpr.

I think the best approach would be to allow linear_call to be batched and custom_vmap to work with constants. Either of the two would unblock me 😃

Below is a minimal reproducible example:

from functools import partial

import jax
from jax.custom_batching import sequential_vmap
from jax.custom_derivatives import linear_call
import numpy as np


def _mul(residual_args, a, *, c):
    b, = residual_args
    c = np.array(c)  # needs to be known at trace-time
    return a * b * c


def _mul_T(residual_args, out, *, c):
    b, = residual_args
    c = np.array(c)  # needs to be known at trace-time
    return out * b * c


def mul(a, b, c):
    print(a.shape)
    return linear_call(partial(_mul, c=c), partial(_mul_T, c=c), (b, ), a)


a, b, c = np.arange(12, dtype=float), 10, np.array([2., 4.]).reshape(2, 1)

m = partial(mul, b=b, c=c)
jax.vmap(sequential_vmap(m), in_axes=(0, ))(a)
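
For context, here is a minimal sketch of how a batching rule is registered with jax.custom_batching.custom_vmap, the manual route mentioned above. It is not part of the original report and does not reproduce the constants problem: the function closes over no arrays, and the names scale and _scale_vmap_rule are hypothetical, chosen only for illustration. The sketch assumes scalar per-example inputs, so plain broadcasting is enough inside the rule.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def scale(a, b):
    # Unbatched definition: product of two scalars (hypothetical example).
    return a * b


@scale.def_vmap
def _scale_vmap_rule(axis_size, in_batched, a, b):
    # in_batched mirrors the argument structure; every batched argument
    # arrives with its batch axis at the front (axis 0).
    a_batched, b_batched = in_batched
    # With scalar per-example inputs, broadcasting covers each
    # batched/unbatched combination, so the rule is the product itself.
    out = a * b
    # Second return value: whether the output carries a batch axis.
    return out, a_batched or b_batched


xs = jnp.arange(3.0)
# Batch over the first argument only; the second is shared by all examples.
ys = jax.vmap(scale, in_axes=(0, None))(xs, jnp.asarray(2.0))
```

The rule receives the batch size, a per-argument flag saying which inputs are batched, and the inputs themselves, and it returns the batched output together with a matching flag. sequential_vmap, used in the example above, is a ready-made rule of this kind that maps over the batch sequentially.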

Issue Analytics

  • State: closed
  • Created: 10 months ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

Edenhofer commented, Nov 23, 2022 (1 reaction)

Awesome! Thank you very much for resolving this issue so quickly!

froystig commented, Nov 21, 2022 (0 reactions)

I’ll repurpose this issue. We’ll certainly need to make sure consts are supported before we consider custom batching complete. It doesn’t hurt to have an open issue reminding us to do this, with your particular example registered. Thanks for filing!
