scan with gradient checkpointing
See original GitHub issue

It would be great to have a version of `lax.scan` that uses recursive gradient checkpointing (e.g. “binomial checkpointing”), allowing differentiation through long time series with logarithmic time/space costs.

In principle this could be built on top of the experimental `remat` decorator: https://github.com/google/jax/pull/1749
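For intuition, here is a minimal sketch of what such a combinator might look like, built from `lax.scan` plus `jax.checkpoint`. This is not an existing JAX API: `checkpointed_scan` and `leaf_size` are names made up for illustration, and it assumes the sequence length is static and that `f`'s per-step output is a single array.

```python
import jax
import jax.numpy as jnp
from jax import lax

def checkpointed_scan(f, init, xs, leaf_size=16):
    """Hypothetical sketch: like lax.scan, but recursively splits the
    sequence in half and wraps each half in jax.checkpoint, so the
    backward pass stores roughly O(log n) carries and recomputes the
    rest instead of storing every intermediate."""
    n = xs.shape[0]  # must be a static (trace-time) length
    if n <= leaf_size:
        return lax.scan(f, init, xs)
    mid = n // 2
    run = lambda carry, chunk: checkpointed_scan(f, carry, chunk, leaf_size)
    carry, ys_lo = jax.checkpoint(run)(init, xs[:mid])
    carry, ys_hi = jax.checkpoint(run)(carry, xs[mid:])
    # Assumes f's per-step output is a single array; a pytree version
    # would concatenate leaves with jax.tree_util.tree_map.
    return carry, jnp.concatenate([ys_lo, ys_hi])

# Example: differentiate through a long recurrence.
f = lambda carry, x: (0.99 * carry + x, carry)
xs = jnp.ones(1024)
grad_init = jax.grad(lambda c0: checkpointed_scan(f, c0, xs)[0])(1.0)
```

The trade-off is the usual one for recursive checkpointing: the forward pass is recomputed O(log n) times during backprop in exchange for the memory savings.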
So Diffrax actually implements a `bounded_while_loop` that does exactly this – early exit by nesting scan-conds, and managing memory using recursive checkpointing. In Diffrax’s case it’s used to handle the stepping of a differential equation solver.

The implementation is here: https://github.com/patrick-kidger/diffrax/blob/2b4e4d863c15abc7143919bac7825090bbfe50be/diffrax/misc/bounded_while_loop.py
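To make the structure concrete, here is a deliberately naive sketch (my own, not Diffrax’s code) of the underlying “scan over cond” trick, before any nesting, early exit, or checkpointing is added:

```python
import jax
from jax import lax

def naive_bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """A while loop with a static bound on the number of iterations,
    written as a scan of conds so that (unlike lax.while_loop) it is
    reverse-mode differentiable."""
    def step(val, _):
        # Take one step of body_fun, or pass val through unchanged
        # once cond_fun has become False.
        return lax.cond(cond_fun(val), body_fun, lambda v: v, val), None
    val, _ = lax.scan(step, init_val, None, length=max_steps)
    return val
```

This naive version always runs all `max_steps` iterations and saves every carry for the backward pass; the real `bounded_while_loop` nests these scans (so whole blocks of iterations can be skipped at once) and applies recursive checkpointing to keep memory under control.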
It’s worth noting that there are a lot of caveats that need to be worked around in order to make something like this feasible:

- Under `vmap`, the `lax.cond`s used to handle whether to evaluate `body_fun` or simply perform an identity operation will get turned into `lax.select`s, and the efficiency gains that are the whole point are lost. This has to be worked around with some custom vmap behaviour (in particular an `unvmap` operation); see the sketch after this list.
- `body_fun` has to be provided with a custom way of handling in-place updates: https://github.com/patrick-kidger/diffrax/blob/2b4e4d863c15abc7143919bac7825090bbfe50be/diffrax/misc/bounded_while_loop.py#L25 Moreover, XLA clearly has some bugs, because you can improve the performance of nested in-place updates (in nested `bounded_while_loop`s) by adding dead code that doesn’t actually evaluate to anything (!!!): https://github.com/patrick-kidger/diffrax/blob/2b4e4d863c15abc7143919bac7825090bbfe50be/diffrax/integrate.py#L242 https://github.com/patrick-kidger/diffrax/blob/2b4e4d863c15abc7143919bac7825090bbfe50be/diffrax/integrate.py#L256 https://github.com/patrick-kidger/diffrax/blob/2b4e4d863c15abc7143919bac7825090bbfe50be/diffrax/integrate.py#L274 Which is something I file under “voodoo magic”.
- The recursive structure gives O(1) compile times, but the backpropagation time scales logarithmically with the bound on the maximum number of steps.
- Nesting `cond`s produces trace times that are exponential in the depth, due to JAX issues #8184 and #8193. In practice `bounded_while_loop` works around these by monkey-patching the JAX tracing mechanisms. Hopefully the JAX tracing mechanisms can be updated at some point to make this unnecessary.
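As a concrete illustration of the first caveat above, one can inspect the jaxpr of a batched cond and observe that the branch structure disappears (`step` and `body` here are illustrative stand-ins, not Diffrax code):

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(pred, x):
    body = lambda v: jnp.sin(v) * 2.0  # stands in for an expensive body_fun
    return lax.cond(pred, body, lambda v: v, x)

preds = jnp.array([False, False, False])
xs = jnp.ones(3)
# Under vmap the cond is lowered to a select (select_n in recent JAX
# versions): both branches are computed for every batch element, even
# though every predicate here is False, so skipping the body saves nothing.
print(jax.make_jaxpr(jax.vmap(step))(preds, xs))
```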
In practice most of these details are hidden from an end-user. (You just end up with a funny-looking extra argument to `body_fun`, and in many cases have to suffer subpar performance.) But I thought I’d record them here for anyone who ends up treading down the same path I did. Implementing a `bounded_while_loop` that exhibits reasonable performance was easily the single hardest part of implementing Diffrax, by a very large margin.

A few other reference points for anyone who finds this issue:
- `flax.linen.remat_scan` for scanning over Flax modules (a usage sketch follows below).
- `bounded_while_loop`:
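For completeness, usage of `flax.linen.remat_scan` looks roughly like the following. This is adapted from memory of the Flax docs, so treat the exact signature (in particular the `lengths` argument) as an assumption and check the current API:

```python
import flax.linen as nn

class DeepMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Applies nn.Dense 10 * 10 = 100 times via nested scans,
        # rematerializing activations so that backprop memory stays
        # sub-linear in depth.
        DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
        return DenseStack(features=x.shape[-1], name="dense_stack")(x)
```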