Trace leak with `odeint` and `vmap`
I am trying to rewrite PointFlow in JAX, which involves backpropagation through and inside `odeint`. I am running into an issue where my tracers get leaked when I try to `vmap` a function that uses `odeint` internally. I am not 100% sure that `vmap` over a `while_loop` is supported, but I am not getting any explicit errors saying that it is wrong. A standalone repro is below:
```python
import jax
import jax.numpy as jnp
import haiku as hk
from jax.experimental.ode import odeint


def ODEnet():
    def ODEnet_inner(y, context):
        return hk.Linear(y.shape[-1])(y)
    return ODEnet_inner


class ODEfunc(hk.Module):
    def __init__(self, diffeq):
        super(ODEfunc, self).__init__()
        self.diffeq = diffeq

    def __call__(self, states, t):
        assert t.shape == ()
        y, _logpx = states
        dy, div = self.dy_and_div(t[None], y)
        return dy, -div

    def dy_and_div(self, ctx, y):
        e = jax.random.normal(hk.next_rng_key(), y.shape)
        dy, dy_vjp_fn = jax.vjp(lambda y: self.diffeq(y, ctx), y)
        vjp, = dy_vjp_fn(e)
        div = jnp.dot(vjp, e)
        return dy, div


class CNF(hk.Module):
    def __init__(self, odefunc):
        super(CNF, self).__init__()
        self.odefunc = odefunc
        self.T = jnp.array(1.)

    def __call__(self, x, logpx):
        integration_times = jnp.array([0., self.T])
        v_call = jax.vmap(
            lambda x, lpx, t: self._call_inner(x, lpx, t),
            in_axes=(0, 0, None),
            out_axes=0,
        )
        z_t, logpz_t = v_call(x, logpx, integration_times)
        if logpx is None:
            return z_t
        else:
            return z_t, logpz_t

    def _call_inner(self, x, logpx, integration_times):
        assert logpx.shape == (*x.shape[:-1], 1)
        states = (x, logpx)
        state_t = odeint(
            self.odefunc,
            states,
            integration_times,
        )
        z_t, logpz_t = state_t[:2]
        return z_t, logpz_t


import unittest


class Tests(unittest.TestCase):
    def test_cnf_unconditional(self):
        import numpy as np
        N, zdim, context_dim = 128, 32, 48

        def create(*args, **kwargs):
            return CNF(
                ODEfunc(ODEnet()),
            )(*args, **kwargs)

        cnf_fn = hk.transform_with_state(create)
        rng = jax.random.PRNGKey(42)
        x = np.random.randn(N, 3)
        logpx = np.zeros((N, 1))
        data = dict(x=x, logpx=logpx)
        params, state = cnf_fn.init(rng, **data)
        (z_t, logpz_t), state = cnf_fn.apply(params, state, rng, **data)
        self.assertEqual(z_t.shape, x.shape)
        self.assertEqual(logpz_t.shape, x.shape[:-1])


if __name__ == '__main__':
    unittest.main()
```
which fails with
```
(base) jatentaki@drozd:~/PhD/data2/home/tyszkiew/PhD/flow/jflow/jflow/repro$ JAX_CHECK_TRACER_LEAKS=1 python flow.py
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
E
======================================================================
ERROR: test_cnf_unconditional (__main__.Tests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "flow.py", line 97, in <module>
    unittest.main()
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/runner.py", line 176, in run
    test(result)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/suite.py", line 122, in run
    test(result)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/suite.py", line 122, in run
    test(result)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/case.py", line 736, in __call__
    return self.run(*args, **kwds)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/case.py", line 676, in run
    self._callTestMethod(testMethod)
  File "/home/jatentaki/miniconda3/lib/python3.8/unittest/case.py", line 633, in _callTestMethod
    method()
  File "flow.py", line 90, in test_cnf_unconditional
    params, state = cnf_fn.init(rng, **data)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/transform.py", line 297, in init_fn
    f(*args, **kwargs)
  File "flow.py", line 78, in create
    return CNF(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "flow.py", line 48, in __call__
    z_t, logpz_t = v_call(x, logpx, integration_times)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/api.py", line 1294, in batched_fun
    out_flat = batching.batch(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "flow.py", line 43, in <lambda>
    lambda x, lpx, t: self._call_inner(x, lpx, t),
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "flow.py", line 60, in _call_inner
    state_t = odeint(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/experimental/ode.py", line 172, in odeint
    converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 825, in closure_convert
    return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 832, in _closure_convert_for_avals
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1253, in trace_to_jaxpr_dynamic
    del main, fun
  File "/home/jatentaki/miniconda3/lib/python3.8/contextlib.py", line 120, in __exit__
    next(self.gen)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/core.py", line 765, in new_main
    raise Exception(f'Leaked trace {t()}')
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked trace MainTrace(2,DynamicJaxprTrace)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "flow.py", line 90, in test_cnf_unconditional
    params, state = cnf_fn.init(rng, **data)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/transform.py", line 297, in init_fn
    f(*args, **kwargs)
  File "flow.py", line 78, in create
    return CNF(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "flow.py", line 48, in __call__
    z_t, logpz_t = v_call(x, logpx, integration_times)
  File "flow.py", line 43, in <lambda>
    lambda x, lpx, t: self._call_inner(x, lpx, t),
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "flow.py", line 60, in _call_inner
    state_t = odeint(
  File "/home/jatentaki/miniconda3/lib/python3.8/site-packages/jax/experimental/ode.py", line 172, in odeint
    converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
  File "/home/jatentaki/miniconda3/lib/python3.8/contextlib.py", line 120, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(2,DynamicJaxprTrace)

----------------------------------------------------------------------
Ran 1 test in 0.167s

FAILED (errors=1)
```
If it helps: before I reduced the example, I was getting `Exception: Leaked trace MainTrace(4,JVPTrace)` instead.
Library versions:

```
jatentaki@drozd:~$ python -c "import jax, jaxlib, haiku; print(jax.__version__, jaxlib.__version__, haiku.__version__)"
0.2.16 0.1.68 0.0.4
```
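On the side question of whether `vmap` over a `while_loop` is supported: it is, and the adaptive-step loop inside `odeint` batches under `jax.vmap`. The minimal plain-JAX sketch below (no Haiku) illustrates this. The dynamics `dy/dt = -k y`, the function names, and the tolerances are illustrative assumptions, not taken from the issue. Every parameter reaches `odeint` explicitly through its `*args` instead of being created as a side effect during tracing.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dynamics(y, t, k):
    # dy/dt = -k * y; a pure function, with k passed in explicitly.
    return -k * y


def solve_one(y0, k):
    ts = jnp.array([0.0, 1.0])
    # Output has a leading time axis: shape (len(ts),) + y0.shape.
    ys = odeint(dynamics, y0, ts, k, rtol=1e-6, atol=1e-6)
    return ys[-1]  # state at t = 1


# Batch over initial conditions; k is shared across the batch.
batched_solve = jax.vmap(solve_one, in_axes=(0, None))

y0s = jnp.array([1.0, 2.0, 3.0])
k = jnp.array(0.5)
final = batched_solve(y0s, k)

# Backpropagating through the batched solve works too (odeint has a custom VJP).
dk = jax.grad(lambda k: batched_solve(y0s, k).sum())(k)
```

Analytically, `final` should be close to `y0s * exp(-0.5)` (about `[0.607, 1.213, 1.820]`) and `dk` close to `-6 * exp(-0.5)` (about `-3.639`).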
Issue analytics: created 2 years ago; 7 comments (3 by maintainers).
Top GitHub Comments
Indeed, this is a Haiku interaction issue: you’re calling `odeint` on a Haiku module before `hk.transform`-ing that module, meaning this Haiku module is still side-effecting. One of the side effects is creating the module parameters (for the `Linear` layer in this case), which are the tracers that cause this `UnexpectedTracerError`. Some more context here: https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html (this is also the reason why we need things like `hk.vmap` and `hk.scan`). We don’t have `hk.odeint`, but there are ways to get around this error.

One option is to not call `odeint` when initializing your module. There is `hk.running_init()`, which tells you whether you’re running your initialization. I replaced the `odeint` call in your module with the below, and it seems to work:

I don’t think we document all this super clearly, so I opened this issue to document this better. Happy to follow up there if any of this is unclear (or if the above doesn’t work).
Thanks for the follow-ups, @LenaMartens and @jatentaki!
By the way, the leak checker can be overly sensitive and give false positives.
I’ll close this issue for now if we’re planning to follow up on the Haiku issue tracker, but please reopen if the situation changes!
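The repro turns on leak checking for the whole process with `JAX_CHECK_TRACER_LEAKS=1`. Since the checker can be overly sensitive, it may help to scope it to just the code under suspicion. The minimal sketch below assumes a JAX version that provides the `jax.checking_leaks()` context manager. The functions `clean` and `leaky` are hypothetical.

```python
import jax
import jax.numpy as jnp

stash = []  # a global list: anything appended during tracing outlives the trace


def clean(x):
    return jnp.sin(x) * 2.0


def leaky(x):
    stash.append(x)  # stores a tracer outside the traced function: a real leak
    return x * 2.0


# Leak checking is active only inside this block.
with jax.checking_leaks():
    out = jax.jit(clean)(jnp.ones(3))

# With the checker on, tracing the leaky function should be reported.
try:
    with jax.checking_leaks():
        jax.jit(leaky)(jnp.ones(3))
    leak_reported = False
except Exception as err:
    leak_reported = True
    print("leak reported:", type(err).__name__)
```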