`jax.nn.sigmoid` raises `Exception: Leaked trace` errors
After two days of debugging, I was lucky enough to find that `jax.nn.sigmoid` was causing a very vague error. The following code should reproduce the problem:
from functools import partial
import traceback
import jax
import jax.numpy as jnp
import haiku as hk
jax.config.update('jax_platform_name', 'gpu')  # same error on 'cpu' or 'gpu'
jax.config.update('jax_check_tracer_leaks', True)
jax.config.update('jax_log_compiles', True)
jax.config.update('jax_enable_checks', True)
def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

class Adjust(hk.Module):
    def __init__(self, size):
        super().__init__(name=None)
        self.__f = hk.Linear(size)

    def __call__(self, x):
        h = self.__f(x)
        return jax.nn.sigmoid(h)  # <- Changing this to sigmoid(h) solves the problem!

def wrap_module(module, *module_args, **module_kwargs):
    def wrap(*args, **kwargs):
        model = module(*module_args, **module_kwargs)
        return model(*args, **kwargs)
    return wrap
prng_key = jax.random.PRNGKey(42)
x = jnp.zeros((50,))
init, adjust = hk.without_apply_rng(hk.transform(wrap_module(Adjust, size=50)))
params = init(prng_key, x)
@jax.jit
def loss(params, x):
    h = adjust(params, x)
    return jnp.sum((h - x)**2)

try:
    grads = jax.grad(loss)(params, x)
except Exception as e:
    tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__))
    print(tb_str)
    raise e
I get the following exception, `Exception: Leaked sublevel 1.`, only when the loss function is jitted.
The logged traceback doesn't point to the sigmoid line at all (attached output: jax-error.txt).
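Since the workaround swaps `jax.nn.sigmoid` for the hand-written `sigmoid`, it may be worth confirming that the two agree in both value and gradient before relying on it. The following is a minimal sketch, not part of the original report; the input range and tolerance are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    # Hand-written version used as the workaround in the report.
    return 1. / (1. + jnp.exp(-x))

# Illustrative inputs; the range and the tolerance below are assumptions.
x = jnp.linspace(-5.0, 5.0, 11)

# Compare forward values.
values_match = bool(jnp.allclose(sigmoid(x), jax.nn.sigmoid(x), atol=1e-5))

# Compare gradients of a scalar reduction of each version.
grad_manual = jax.grad(lambda v: jnp.sum(sigmoid(v)))(x)
grad_builtin = jax.grad(lambda v: jnp.sum(jax.nn.sigmoid(v)))(x)
grads_match = bool(jnp.allclose(grad_manual, grad_builtin, atol=1e-5))

print(values_match, grads_match)
```

Note that the hand-written form can be less numerically robust than the library version for inputs of very large magnitude, so this check only covers a moderate range.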
Library versions:
dm-haiku==0.0.5.dev0
jax==0.2.18
jaxlib==0.1.69+cuda101
However, the following older versions produced a more helpful error message that pointed to the sigmoid line:
dm-haiku==0.0.4
jax==0.2.12
jaxlib==0.1.64+cuda101
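When comparing behaviour across versions, it can help to print the installed version and to scope the leak checker to a single call with the `jax.checking_leaks()` context manager instead of setting the flag globally. The sketch below uses a hypothetical `toy_loss` that stands in for the Haiku model above (it is not the original repro), so whether it triggers the error on an affected version is an assumption.

```python
import jax
import jax.numpy as jnp

@jax.jit
def toy_loss(w, x):
    # Hypothetical stand-in for the Haiku model in the report:
    # an elementwise "layer" followed by jax.nn.sigmoid.
    return jnp.sum((jax.nn.sigmoid(w * x) - x) ** 2)

w = jnp.ones((50,))
x = jnp.zeros((50,))

# Record the version alongside the result so runs can be compared.
print("jax", jax.__version__)

# Enable tracer-leak checking only for this call.
with jax.checking_leaks():
    grads = jax.grad(toy_loss)(w, x)

print(grads.shape)
```

If the call raises a leak exception inside the `with` block but succeeds outside it, that suggests the leak checker itself is involved, not the model code.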
Issue Analytics
- State:
- Created 2 years ago
- Comments:14 (10 by maintainers)
A bisection suggests that 21907346370320209a26fbf7a4f619bb62978352 fixed this bug (thanks, @LenaMartens!), and fd7b286ec9e8c89488494506a3591e6698f8fa05 introduced it or at least exposed it (curse you, @mattjj!), or at least was the first commit where the above jax-only repro started failing. (According to my process, it seems that jax==v0.2.12 also had an error, though maybe a different one, and I had to go back to jax==v0.2.10 to find a good pypi release.)
Solved now! Thanks a lot
Even after the first compilation, when the branches consist of complicated functions like `odeint`, `lax.cond` executes all branches and then selects the result from the two branches based on the condition value. I will post a new request after closing this issue.
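One documented situation in which `lax.cond` evaluates every branch is when it is transformed with `vmap` over a batch of predicates: JAX then converts the `cond` into a select. Whether that is what happens in the `odeint` case described above is an assumption. The sketch below uses a hypothetical `branchy` function with cheap branches and inspects the jaxpr with and without `vmap`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def branchy(pred, x):
    # Hypothetical two-branch function; sin/cos stand in for expensive branches.
    return lax.cond(pred, lambda v: jnp.sin(v), lambda v: jnp.cos(v), x)

preds = jnp.array([True, False, True])
xs = jnp.array([0.0, 1.0, 2.0])

# With a scalar predicate the jaxpr keeps a real `cond` primitive.
scalar_jaxpr = str(jax.make_jaxpr(branchy)(True, 1.0))

# With a batched predicate under vmap, both branches are computed
# and the results are combined with a select.
batched_jaxpr = str(jax.make_jaxpr(jax.vmap(branchy))(preds, xs))

batched_out = jax.vmap(branchy)(preds, xs)
reference = jnp.where(preds, jnp.sin(xs), jnp.cos(xs))

print(scalar_jaxpr)
print(batched_jaxpr)
```

Printing the jaxpr for the real function is a quick way to see whether a `cond` survived or was rewritten into a select before filing a new request.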