`jax.nn.sigmoid` raises `Exception: Leaked trace` errors

After two days of debugging, I was lucky to find that jax.nn.sigmoid was causing a very vague error. The following code reproduces the problem:

from functools import partial
import traceback

import jax
import jax.numpy as jnp
import haiku as hk

jax.config.update('jax_platform_name', 'gpu')  # same error on 'cpu' or 'gpu'
jax.config.update('jax_check_tracer_leaks', True)
jax.config.update('jax_log_compiles', True)
jax.config.update('jax_enable_checks', True)

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

class Adjust(hk.Module):
    def __init__(self, size):
        super().__init__(name=None)
        self.__f = hk.Linear(size)
        
    def __call__(self, x):
        h = self.__f(x)
        return jax.nn.sigmoid(h) # <- Changing this to sigmoid(h) solves the problem!
    
def wrap_module(module, *module_args, **module_kwargs):
    def wrap(*args, **kwargs):
        model = module(*module_args, **module_kwargs)
        return model(*args, **kwargs)
    return wrap

prng_key = jax.random.PRNGKey(42)
x = jnp.zeros((50,))
init, adjust = hk.without_apply_rng(hk.transform(wrap_module(Adjust, size=50)))
params = init(prng_key, x)

@jax.jit
def loss(params, x):
    h = adjust(params, x)
    return jnp.sum((h - x)**2)

try:
    grads = jax.grad(loss)(params, x)
except Exception as e:
    tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__))
    print(tb_str)
    raise e

This raises `Exception: Leaked sublevel 1.`, but only when the loss function is jitted. The logged traceback doesn’t point to the sigmoid line at all (attached output: jax-error.txt).
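A minimal Haiku-free sketch of the same pattern (a jitted loss differentiated with `jax.grad`, with `jax_check_tracer_leaks` enabled) is shown below. The elementwise weight `w` standing in for `hk.Linear` and the helper `make_loss` are assumptions for illustration, not the original repro. On a release where the bug is fixed, both activations should run without a leak error and produce matching gradients.

```python
import jax
import jax.numpy as jnp

def manual_sigmoid(x):
    # Hand-written logistic function: the workaround from the report.
    return 1.0 / (1.0 + jnp.exp(-x))

def make_loss(act):
    # Hypothetical helper: a jitted squared-error loss around `act`.
    # The elementwise weight `w` stands in for the hk.Linear layer.
    @jax.jit
    def loss(w, x):
        h = act(w * x)
        return jnp.sum((h - x) ** 2)
    return loss

w = jnp.ones((50,))
x = jnp.linspace(-1.0, 1.0, 50)

# Same leak checker as in the report; switched off again afterwards
# so it does not affect unrelated code.
jax.config.update("jax_check_tracer_leaks", True)
try:
    g_builtin = jax.grad(make_loss(jax.nn.sigmoid))(w, x)
    g_manual = jax.grad(make_loss(manual_sigmoid))(w, x)
finally:
    jax.config.update("jax_check_tracer_leaks", False)

# The two activations are mathematically identical, so their
# gradients should agree up to float32 rounding.
max_diff = float(jnp.max(jnp.abs(g_builtin - g_manual)))
print(max_diff)
```

If only the `jax.nn.sigmoid` call raises, the installed release is likely affected; later in the thread, jax==0.2.19 is reported to resolve the issue.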

Library versions:

dm-haiku==0.0.5.dev0
jax==0.2.18
jaxlib==0.1.69+cuda101

However, the following older versions produced a more helpful error message that pointed to the sigmoid line:

dm-haiku==0.0.4
jax==0.2.12
jaxlib==0.1.64+cuda101

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 14 (10 by maintainers)

Top GitHub Comments

mattjj commented, Aug 13, 2021

A bisection suggests that 21907346370320209a26fbf7a4f619bb62978352 fixed this bug (thanks, @LenaMartens!), and fd7b286ec9e8c89488494506a3591e6698f8fa05 introduced it or at least exposed it (curse you, @mattjj!), or at least was the first commit where the above jax-only repro started failing. (According to my process, it seems that jax==v0.2.12 also had an error, though maybe a different one, and I had to go back to jax==v0.2.10 to find a good pypi release.)
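The bisection described above is what `git bisect` automates. Its core logic can be sketched as a binary search for the first commit at which the repro fails. `first_bad` and the `reproduces` predicate (which would, for example, install a given commit and run the repro) are hypothetical names used only for illustration, and the ordering assumptions are listed in the docstring.

```python
from typing import Callable, Sequence

def first_bad(commits: Sequence[str], reproduces: Callable[[str], bool]) -> str:
    """Return the earliest commit for which `reproduces` is True.

    Assumes (hypothetically) that commits are ordered oldest -> newest,
    commits[0] is known good, commits[-1] is known bad, and that once the
    bug appears it stays present in every later commit.
    """
    lo, hi = 0, len(commits) - 1  # invariant: commits[lo] good, commits[hi] bad
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if reproduces(commits[mid]):
            hi = mid  # bug already present: first bad commit is at or before mid
        else:
            lo = mid  # still good: first bad commit is after mid
    return commits[hi]
```

To locate the fixing commit instead, run the same search over the range where the bug is present, using the negated predicate `lambda c: not reproduces(c)`.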

A-Alaa commented, Aug 13, 2021

Just pushed jax==0.2.19 to PyPI! Can you confirm the bug no longer reproduces against that version?

Solved now! Thanks a lot

They stage out all branches (and in doing so execute and trace the Python callables representing each branch). They must, because their purpose is to stage out control flow that can’t be executed in Python: either branch could be taken later, so both branches must be traced. Is that what you mean? (It might help to link the other issues, if you have them handy and if I’m missing the point you’re making.)
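The staging behaviour described above can be observed directly. Under `jax.jit` the predicate is a tracer, so `lax.cond` has to trace both branch callables even though only one result is needed for this input. The branch functions and the `traced` list below are illustrative; the Python-side `append` runs only while tracing, not on later calls that hit the compilation cache.

```python
import jax
import jax.numpy as jnp
from jax import lax

traced = []  # records which branch callables Python actually executed

def true_branch(x):
    traced.append("true")  # side effect runs at trace time only
    return jnp.sin(x)

def false_branch(x):
    traced.append("false")
    return jnp.cos(x)

@jax.jit
def f(x):
    # `x > 0` is a tracer here, so Python cannot pick a branch;
    # lax.cond must trace both callables into the staged-out program.
    return lax.cond(x > 0, true_branch, false_branch, x)

y = f(1.0)
print(sorted(traced))
print(y)
```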

Even after first-time compilation, when the branches consist of complicated functions like odeint, lax.cond executes all branches and then selects the result from the two branches based on the condition value. I will post a new request after closing this issue.
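Whether both branches actually run depends on how `lax.cond` is transformed. One documented case: when `lax.cond` is vmapped over a batch of predicates, it is converted to a `select`, so both branches are evaluated for every element. Without batching, the staged-out `cond` primitive is kept, and XLA's conditional executes only the chosen branch at run time. The sketch below compares the two jaxprs; `g` is an illustrative function, not code from the thread.

```python
import jax
import jax.numpy as jnp
from jax import lax

def g(x):
    # Illustrative function: sin on the true branch, cos on the false branch.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)

# Scalar input: the jaxpr keeps a `cond` primitive containing both branches,
# but only one of them executes at run time.
scalar_jaxpr = str(jax.make_jaxpr(g)(1.0))

# Batched input: vmap turns the batched predicate into a select, so sin and
# cos are both computed for every element before choosing.
xs = jnp.array([-1.0, 2.0])
batched_jaxpr = str(jax.make_jaxpr(jax.vmap(g))(xs))

print(scalar_jaxpr)
print(batched_jaxpr)
```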

