Trace leak with `odeint` and `vmap`

I am trying to rewrite PointFlow in JAX, which involves backpropagating through and inside odeint. When I vmap a function that uses odeint internally, I get a leaked-tracer error. I am not 100% sure that vmap over a while_loop is supported, but I am not getting any explicit error saying that it is not. A standalone repro is below:

import jax
import jax.numpy as jnp
import haiku as hk
from jax.experimental.ode import odeint

def ODEnet():
    def ODEnet_inner(y, context):
        return hk.Linear(y.shape[-1])(y)

    return ODEnet_inner

class ODEfunc(hk.Module):
    def __init__(self, diffeq):
        super(ODEfunc, self).__init__()
        self.diffeq = diffeq

    def __call__(self, states, t):
        assert t.shape == ()
        y, _logpx = states

        dy, div = self.dy_and_div(t[None], y)
        return dy, -div
    
    def dy_and_div(self, ctx, y):
        e = jax.random.normal(hk.next_rng_key(), y.shape)
        dy, dy_vjp_fn = jax.vjp(lambda y: self.diffeq(y, ctx), y)
        vjp, = dy_vjp_fn(e)
        div = jnp.dot(vjp, e)

        return dy, div

class CNF(hk.Module):
    def __init__(self, odefunc):
        super(CNF, self).__init__()

        self.odefunc = odefunc
        self.T = jnp.array(1.)

    def __call__(self, x, logpx):
        integration_times = jnp.array([0., self.T])

        v_call = jax.vmap(
            lambda x, lpx, t: self._call_inner(x, lpx, t),
            in_axes=(0, 0, None),
            out_axes=0,
        )

        z_t, logpz_t = v_call(x, logpx, integration_times)

        if logpx is None:
            return z_t
        else:
            return z_t, logpz_t

    def _call_inner(self, x, logpx, integration_times):
        assert logpx.shape == (*x.shape[:-1], 1)

        states = (x, logpx)

        state_t = odeint(
            self.odefunc, 
            states,
            integration_times,
        )

        z_t, logpz_t = state_t[:2]
        return z_t, logpz_t

import unittest

class Tests(unittest.TestCase):
    def test_cnf_unconditional(self):
        import numpy as np

        N, zdim, context_dim = 128, 32, 48

        def create(*args, **kwargs):
            return CNF(
                ODEfunc(ODEnet()),
            )(*args, **kwargs)

        cnf_fn = hk.transform_with_state(create)
        rng = jax.random.PRNGKey(42)

        x = np.random.randn(N, 3)
        logpx = np.zeros((N, 1))

        data = dict(x=x, logpx=logpx)
        params, state = cnf_fn.init(rng, **data)
        (z_t, logpz_t), state = cnf_fn.apply(params, state, rng, **data)

        self.assertEqual(z_t.shape, x.shape)
        self.assertEqual(logpz_t.shape, x.shape[:-1])

if __name__ == '__main__':
    unittest.main()

which fails with

(base) jatentaki@drozd:~/PhD/data2/home/tyszkiew/PhD/flow/jflow/jflow/repro$ JAX_CHECK_TRACER_LEAKS=1 python flow.py 
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
E
======================================================================
ERROR: test_cnf_unconditional (__main__.Tests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "flow.py", line 97, in <module>
    unittest.main()
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/runner.py", line 176, in run
    test(result)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/suite.py", line 122, in run
    test(result)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/suite.py", line 122, in run
    test(result)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/case.py", line 736, in __call__
    return self.run(*args, **kwds)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/case.py", line 676, in run
    self._callTestMethod(testMethod)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/case.py", line 633, in _callTestMethod
    method()
  File "flow.py", line 90, in test_cnf_unconditional
    params, state = cnf_fn.init(rng, **data)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/transform.py", line 297, in init_fn
    f(*args, **kwargs)
  File "flow.py", line 78, in create
    return CNF(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "flow.py", line 48, in __call__
    z_t, logpz_t = v_call(x, logpx, integration_times)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/api.py", line 1294, in batched_fun
    out_flat = batching.batch(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "flow.py", line 43, in <lambda>
    lambda x, lpx, t: self._call_inner(x, lpx, t),
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "flow.py", line 60, in _call_inner
    state_t = odeint(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/experimental/ode.py", line 172, in odeint
    converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 825, in closure_convert
    return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 832, in _closure_convert_for_avals
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1253, in trace_to_jaxpr_dynamic
    del main, fun
  File "/home/jatentaki/miniconda3/lib/python3.8/contextlib.py", line 120, in __exit__
    next(self.gen)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/core.py", line 765, in new_main
    raise Exception(f'Leaked trace {t()}')
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked trace MainTrace(2,DynamicJaxprTrace)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "flow.py", line 90, in test_cnf_unconditional
    params, state = cnf_fn.init(rng, **data)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/transform.py", line 297, in init_fn
    f(*args, **kwargs)
  File "flow.py", line 78, in create
    return CNF(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "flow.py", line 48, in __call__
    z_t, logpz_t = v_call(x, logpx, integration_times)
  File "flow.py", line 43, in <lambda>
    lambda x, lpx, t: self._call_inner(x, lpx, t),
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "flow.py", line 60, in _call_inner
    state_t = odeint(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/experimental/ode.py", line 172, in odeint
    converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
  File "/home/jatentaki/miniconda3/lib/python3.8/contextlib.py", line 120, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(2,DynamicJaxprTrace)

----------------------------------------------------------------------
Ran 1 test in 0.167s

FAILED (errors=1)

If that helps, before reducing the example I was getting Exception: Leaked trace MainTrace(4,JVPTrace) instead.

Library versions:

jatentaki@drozd:~$ python -c "import jax, jaxlib, haiku; print(jax.__version__, jaxlib.__version__, haiku.__version__)"
0.2.16 0.1.68 0.0.4

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments:7 (3 by maintainers)

Top GitHub Comments

LenaMartens commented, Jun 29, 2021 (2 reactions)

Indeed, this is a Haiku interaction issue: you’re calling odeint on a Haiku module before hk.transforming that module, meaning this Haiku module is still side-effecting. One of the side-effects is creating the module parameters (for the Linear layer in this case) which are the tracers that cause this UnexpectedTracerError.

Some more context here: https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html (and this is also the reason why we need things like hk.vmap and hk.scan). We don’t have hk.odeint, but there are ways to get around this error.

One option is to not call odeint when initializing your module. There is hk.running_init() which tells you if you’re running your initialization. I replaced the odeint call in your module with the below, and it seems to work:

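if hk.running_init():
  # Skip odeint during initialization; a single direct call creates the parameters.
  # A scalar time is passed because ODEfunc.__call__ asserts t.shape == ().
  state_t = self.odefunc(states, integration_times[0])
else:
  state_t = odeint(
      self.odefunc,
      states,
      integration_times,
  )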

I don’t think we document all this super clearly, so I opened this issue to document this better. Happy to follow-up there if any of this is unclear (or if the above doesn’t work).

mattjj commented, Jun 30, 2021 (0 reactions)

Thanks for the follow-ups, @LenaMartens and @jatentaki !

By the way, the leak checker can be overly sensitive and give false positives.

I’ll close this issue for now if we’re planning to follow up on the Haiku issue tracker, but please reopen if the situation changes!
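
As a point of comparison for the leak checker mentioned above, the following sketch shows what a genuine tracer leak looks like in plain JAX, without Haiku. The list captured and the function leaky are hypothetical names used only to provoke the leak. It uses jax.checking_leaks(), the context-manager counterpart of the JAX_CHECK_TRACER_LEAKS environment variable used in the original report.

```python
import jax
import jax.numpy as jnp

captured = []  # hypothetical global, used only to provoke a leak

try:
    with jax.checking_leaks():

        @jax.jit
        def leaky(x):
            # Side effect: the tracer outlives the trace that created it.
            captured.append(x)
            return x * 2.0

        leaky(jnp.float32(3.0))
    leak_detected = False
except Exception as err:
    # The checker raises a plain Exception whose message mentions "Leaked".
    leak_detected = "Leaked" in str(err)

print(leak_detected)
```

Storing a traced value in state that lives outside the traced function is the same kind of side effect that creating Haiku parameters inside odeint produces during initialization.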
