Error in custom JVP involving tokens

In mpi4jax, we encountered some errors when working with gradients of primitives that use tokens.

Issue 1

The following throws an error:

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

def foo(x):
    token = jax.lax.create_token()
    x1, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm, token=token)
    x2, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm, token=token)
    return x1 + x2

x = 0.0
gfoo = jax.grad(foo)
gx = gfoo(x)

Traceback:

Traceback (most recent call last):
  File "bug.py", line 18, in <module>
    gx = gfoo(x)
jax._src.traceback_util.FilteredStackTrace: AssertionError

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "bug.py", line 18, in <module>
    gx = gfoo(x)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 760, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 823, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 1896, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 114, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 103, in linearize
    assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
AssertionError

This is an abridged version of our custom JVP:

def mpi_allreduce_value_and_jvp(in_args, tan_args, op, comm, transpose):
    x, token = in_args
    x_tan, _ = tan_args

    # sum x across all ranks
    val, token = mpi_allreduce_p.bind(x, token, op=op, comm=comm, transpose=transpose)

    # sum of x_tan across all ranks
    jvp, token = mpi_allreduce_p.bind(x_tan, token, op=op, comm=comm, transpose=transpose)

    return (val, token), (jvp, ad.Zero.from_value(token))
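
The rule above follows the usual pattern for a linear operation: the tangent output is the same collective applied to the input tangent, and the token gets a symbolic-zero tangent. As a rough, token-free illustration of that pattern, here is a sketch that uses jax.custom_jvp instead of a custom primitive. The names allreduce_sum and _allreduce_single_rank are hypothetical stand-ins (on a single process, a SUM allreduce is just the identity), not mpi4jax API, and the token threading that triggers the errors in this issue is deliberately left out.

```python
import jax


def _allreduce_single_rank(x):
    # Hypothetical stand-in: with a single MPI rank, a SUM allreduce is the identity.
    return x


@jax.custom_jvp
def allreduce_sum(x):
    return _allreduce_single_rank(x)


@allreduce_sum.defjvp
def allreduce_sum_jvp(primals, tangents):
    (x,), (x_tan,) = primals, tangents
    # Primal output: reduce x across ranks.
    val = _allreduce_single_rank(x)
    # Tangent output: the operation is linear, so reduce the tangent the same way.
    jvp = _allreduce_single_rank(x_tan)
    return val, jvp


def foo(x):
    x1 = allreduce_sum(x)
    x2 = allreduce_sum(x)
    return x1 + x2


print(jax.jvp(foo, (1.0,), (3.0,)))  # forward mode
print(jax.grad(foo)(0.0))            # reverse mode (JAX transposes the linear JVP)
```

Because the tangent computation is linear in x_tan, JAX can transpose it for jax.grad; this sketch does not reproduce the token-related failures reported here.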

Issue 2

The error above doesn’t occur when I generate a new token before every call:

def foo(x):
    token = jax.lax.create_token()
    x1, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm, token=token)
    token = jax.lax.create_token()
    x2, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm, token=token)
    return x1 + x2

(Of course, this gives bogus results now; it is just for illustration.)

But now, I get a different error:

Traceback (most recent call last):
  File "bug.py", line 19, in <module>
    gx = gfoo(x)
jax._src.traceback_util.FilteredStackTrace: KeyError: <class 'jax.interpreters.xla.Token'>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "bug.py", line 19, in <module>
    gx = gfoo(x)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 760, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 829, in value_and_grad_f
    g = vjp_py(np.ones((), dtype=dtype))
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 1816, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 121, in unbound_vjp
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 231, in backward_pass
    map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/util.py", line 41, in safe_map
    return list(map(f, *args))
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 176, in write_cotangent
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 482, in add_tangents
    return add_jaxvals(x, y)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/ad_util.py", line 37, in add_jaxvals
    return add_jaxvals_p.bind(x, y)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/core.py", line 284, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/core.py", line 622, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/ad_util.py", line 44, in add_impl
    return jaxval_adders[type(xs)](xs, ys)
KeyError: <class 'jax.interpreters.xla.Token'>

It looks like the cotangent accumulator in the backward pass is trying to add the input tokens, and add_jaxvals has no adder registered for the Token type.

Top GitHub Comments

froystig commented, Apr 20, 2021

Another approach to consider: if you are only expecting reverse-mode usage, and if there is an implementable ordering under reverse-mode (as your most recent comment suggests), then you could consider disallowing forward-mode altogether and defining only a VJP with jax.custom_vjp.
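
A minimal sketch of that reverse-mode-only approach, assuming a hypothetical single-rank stand-in allreduce_sum (not mpi4jax's allreduce) and leaving tokens out: with jax.custom_vjp only a forward pass and a backward pass are defined, so jax.grad works while forward-mode jax.jvp raises a TypeError.

```python
import jax


def _allreduce_single_rank(x):
    # Hypothetical stand-in: with a single MPI rank, a SUM allreduce is the identity.
    return x


@jax.custom_vjp
def allreduce_sum(x):
    return _allreduce_single_rank(x)


def allreduce_sum_fwd(x):
    # Linear operation: no residuals need to be saved for the backward pass.
    return _allreduce_single_rank(x), None


def allreduce_sum_bwd(_residuals, g):
    # The transpose of a SUM allreduce is a SUM allreduce of the cotangent.
    return (_allreduce_single_rank(g),)


allreduce_sum.defvjp(allreduce_sum_fwd, allreduce_sum_bwd)


def foo(x):
    return allreduce_sum(x) + allreduce_sum(x)


print(jax.grad(foo)(0.0))  # reverse mode works

try:
    jax.jvp(foo, (0.0,), (1.0,))  # forward mode is not defined for custom_vjp
except TypeError as err:
    print("forward mode rejected:", err)
```

The design trade-off is that any caller relying on jax.jvp (or on transformations built on it) loses support, in exchange for full control over what runs in the backward pass.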

dionhaefner commented, Apr 15, 2021

Aha, I see! Thanks for clearing this up. I think we’re on the same page now.

> I had understood from your comment that this reordering would be problematic, such that you would like to enforce the original ordering strictly, even if that violates a technical requirement of JVP routines. Is that right?

Correct.

> But in reverse-mode AD, the analogue of the original strict ordering is not possible even in principle. That’s what I attempted to highlight in my previous question.

Actually, the situation is a bit more complicated in my case: executing everything in mirrored order is fine, so I wouldn’t be worried about forward vs. reverse AD here.
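
The mirrored order of reverse mode can be made visible by logging when each forward and backward rule runs. This is an illustrative sketch under stated assumptions: tagged_identity is a hypothetical jax.custom_vjp function that receives a string label through nondiff_argnums, and the log relies on Python side effects, so it is only meaningful without jax.jit. Because the second call consumes the output of the first, the backward rule for "second" has to run before the one for "first".

```python
from functools import partial

import jax

log = []  # records ("fwd" | "bwd", label) in execution order


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def tagged_identity(label, x):
    return x


def tagged_identity_fwd(label, x):
    # Forward rule: same signature as the primal function.
    log.append(("fwd", label))
    return x, None


def tagged_identity_bwd(label, _residuals, g):
    # Backward rule: non-differentiable args first, then residuals, then cotangent.
    log.append(("bwd", label))
    return (g,)


tagged_identity.defvjp(tagged_identity_fwd, tagged_identity_bwd)


def foo(x):
    x1 = tagged_identity("first", x)
    x2 = tagged_identity("second", x1)
    return x2


jax.grad(foo)(1.0)
print(log)
```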
