
scan with gradient checkpointing

See original GitHub issue

It would be great to have a version of lax.scan that uses recursive gradient checkpointing (e.g., “binomial checkpointing”), allowing differentiation through long time series with logarithmic time/space costs.

In principle this could be built on top of the experimental remat decorator: https://github.com/google/jax/pull/1749
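For context, the remat decorator already lets you trade compute for memory at the granularity of a single scan step. A minimal sketch of that building block (not the binomial scheme requested here):

import jax
import jax.numpy as jnp

def body(carry, x):
  # A toy scan step; normally each step's intermediates are stored for the
  # backward pass.
  carry = jnp.tanh(carry + x)
  return carry, carry

# jax.checkpoint (aliased as jax.remat) makes the backward pass recompute the
# step's internals instead of saving them.
carry, ys = jax.lax.scan(jax.checkpoint(body), jnp.zeros(()), jnp.arange(10.0))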

Issue Analytics

  • State: open
  • Created: 4 years ago
  • Reactions: 6
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

8 reactions
patrick-kidger commented, Feb 15, 2022

So Diffrax actually implements a bounded_while_loop that does exactly this – early exit by nesting scan-conds, and managing memory using recursive checkpointing. In Diffrax’s case it’s used to handle the stepping of a differential equation solver.

The implementation is here: https://github.com/patrick-kidger/diffrax/blob/2b4e4d863c15abc7143919bac7825090bbfe50be/diffrax/misc/bounded_while_loop.py
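For illustration, the core trick can be sketched in a few lines (hypothetical names; this omits the recursive checkpointing and the performance workarounds that make Diffrax’s version practical):

import jax

def bounded_while(cond_fun, body_fun, init_val, max_steps):
  def scan_body(carry, _):
    # Run body_fun only while cond_fun holds; afterwards every remaining
    # iteration is a pass-through, giving an effective early exit inside a
    # fixed-length scan.
    carry = jax.lax.cond(cond_fun(carry), body_fun, lambda c: c, carry)
    return carry, None
  carry, _ = jax.lax.scan(scan_body, init_val, None, length=max_steps)
  return carry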

It’s worth noting that there are a lot of caveats that need to be worked around in order to make something like this feasible.


In practice most of these details are hidden from an end-user. (You just end up with a funny-looking extra argument to body_fun, and in many cases have to suffer subpar performance.) But I thought I’d record them here for anyone who ends up going down the same path I did. Implementing a bounded_while_loop that exhibits reasonable performance was easily the single hardest part of implementing Diffrax, by a very large margin.

1 reaction
shoyer commented, Jul 19, 2022

A few other reference points for anyone who finds this issue:

  1. Flax has flax.linen.remat_scan for scanning over Flax modules.
  2. I wrote a simpler version of scanning with nested gradient checkpointing, based on some of the same design principles as Diffrax’s bounded_while_loop:
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Callable, Optional, Sequence, Tuple, TypeVar

import jax
import jax.numpy as jnp


Carry = TypeVar('Carry')
Input = TypeVar('Input')
Output = TypeVar('Output')
Func = TypeVar('Func', bound=Callable)


def nested_checkpoint_scan(
    f: Callable[[Carry, Input], Tuple[Carry, Output]],
    init: Carry,
    xs: Input,
    length: Optional[int] = None,
    *,
    nested_lengths: Sequence[int],
    scan_fn: Callable = jax.lax.scan,
    checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,
) -> Tuple[Carry, Output]:
  """A version of lax.scan that supports recursive gradient checkpointing.

  The interface of `nested_checkpoint_scan` exactly matches lax.scan, except for
  the required `nested_lengths` argument.

  The key feature of `nested_checkpoint_scan` is that gradient calculations
  require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested
  scans, which it achieves by re-evaluating the forward pass
  `len(nested_lengths) - 1` times.

  `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a
  single element.

  Args:
    f: function to scan over.
    init: initial value.
    xs: scanned over values.
    length: optional integer giving the leading axis length of all arrays in
      ``xs``.
    nested_lengths: required list of lengths to scan over for each level of
      checkpointing. The product of nested_lengths must match length (if
      provided) and the size of the leading axis for all arrays in ``xs``.
    scan_fn: function matching the API of lax.scan.
    checkpoint_fn: function matching the API of jax.checkpoint.

  Returns:
    Carry and output values.
  """
  if length is not None and length != math.prod(nested_lengths):
    raise ValueError(f'inconsistent {length=} and {nested_lengths=}')

  def nested_reshape(x):
    x = jnp.asarray(x)
    new_shape = tuple(nested_lengths) + x.shape[1:]
    return x.reshape(new_shape)

  sub_xs = jax.tree_map(nested_reshape, xs)
  return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn,
                            checkpoint_fn)


def _inner_nested_scan(f, init, xs, lengths, scan_fn, checkpoint_fn):
  """Recursively applied scan function."""
  if len(lengths) == 1:
    return scan_fn(f, init, xs, lengths[0])

  @checkpoint_fn
  def sub_scans(carry, xs):
    return _inner_nested_scan(f, carry, xs, lengths[1:], scan_fn, checkpoint_fn)

  carry, out = scan_fn(sub_scans, init, xs, lengths[0])
  stacked_out = jax.tree_map(jnp.concatenate, out)
  return carry, stacked_out
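As a hypothetical usage example, scanning 64 steps as two nested levels of 8:

import jax
import jax.numpy as jnp

def step(carry, x):
  return carry + x, carry + x

xs = jnp.arange(64.0)
carry, ys = nested_checkpoint_scan(step, 0.0, xs, nested_lengths=(8, 8))

# The backward pass re-runs the forward computation once per extra nesting
# level, but only stores O(max(nested_lengths)) intermediates.
grads = jax.grad(
    lambda xs: nested_checkpoint_scan(step, 0.0, xs, nested_lengths=(8, 8))[0]
)(xs)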