Error in custom JVP involving tokens
In mpi4jax, we encountered some errors when working with gradients of primitives that use tokens.
Issue 1
The following throws an error:
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD


def foo(x):
    # a single token is created once and threaded through both calls
    token = jax.lax.create_token()
    x1, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm, token=token)
    x2, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm, token=token)
    return x1 + x2


x = 0.0
gfoo = jax.grad(foo)
gx = gfoo(x)
Traceback:
Traceback (most recent call last):
  File "bug.py", line 18, in <module>
    gx = gfoo(x)
jax._src.traceback_util.FilteredStackTrace: AssertionError

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "bug.py", line 18, in <module>
    gx = gfoo(x)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 760, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 823, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 1896, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 114, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 103, in linearize
    assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
AssertionError
This is an abridged version of our custom JVP:
def mpi_allreduce_value_and_jvp(in_args, tan_args, op, comm, transpose):
    x, token = in_args
    x_tan, _ = tan_args

    # sum x across all ranks
    val, token = mpi_allreduce_p.bind(x, token, op=op, comm=comm, transpose=transpose)

    # sum of x_tan across all ranks
    jvp, token = mpi_allreduce_p.bind(x_tan, token, op=op, comm=comm, transpose=transpose)

    return (val, token), (jvp, ad.Zero.from_value(token))
Issue 2
The error above doesn’t occur when I generate a new token before every call:
def foo(x):
    token = jax.lax.create_token()
    x1, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm, token=token)
    token = jax.lax.create_token()
    x2, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm, token=token)
    return x1 + x2
(Of course, this now gives bogus results; it is just for illustration.)
But now, I get a different error:
Traceback (most recent call last):
  File "bug.py", line 19, in <module>
    gx = gfoo(x)
jax._src.traceback_util.FilteredStackTrace: KeyError: <class 'jax.interpreters.xla.Token'>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "bug.py", line 19, in <module>
    gx = gfoo(x)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 760, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 829, in value_and_grad_f
    g = vjp_py(np.ones((), dtype=dtype))
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/api.py", line 1816, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 121, in unbound_vjp
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 231, in backward_pass
    map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/_src/util.py", line 41, in safe_map
    return list(map(f, *args))
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 176, in write_cotangent
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/interpreters/ad.py", line 482, in add_tangents
    return add_jaxvals(x, y)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/ad_util.py", line 37, in add_jaxvals
    return add_jaxvals_p.bind(x, y)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/core.py", line 284, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/core.py", line 622, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/dion/.virtualenvs/mpi4jax/lib/python3.8/site-packages/jax/ad_util.py", line 44, in add_impl
    return jaxval_adders[type(xs)](xs, ys)
KeyError: <class 'jax.interpreters.xla.Token'>
It looks like the tangent accumulator is trying to add the input tokens.
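For context, here is a minimal pure-JAX sketch (no MPI and no tokens) of the accumulation step that the traceback points at. In reverse mode, a value that feeds more than one equation receives one cotangent per use, and these are summed; the `add_tangents` call in `write_cotangent` is that sum. The function `f` is a hypothetical example and is not taken from mpi4jax.

```python
import jax
import jax.numpy as jnp


def f(x):
    # x is consumed by two separate equations
    a = jnp.sin(x)
    b = 3.0 * x
    return a + b


# The backward pass yields one cotangent for x per use (cos(x) and 3.0)
# and adds them together to form the gradient.
g = jax.grad(f)(0.0)
print(g)
```

For floats this sum is ordinary addition. The `KeyError` suggests that, in this JAX version, no such addition is registered for `jax.interpreters.xla.Token`, so accumulation fails as soon as two cotangents arrive for the same token variable.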
Top GitHub Comments
Another approach to consider: if you are only expecting reverse-mode usage, and if there is an implementable ordering under reverse-mode (as your most recent comment suggests), then you could consider disallowing forward-mode altogether and defining only a VJP with jax.custom_vjp.
Aha, I see! Thanks for clearing this up. I think we’re on the same page now.
Correct.
Actually, the situation is a bit more complicated in my case - executing everything in mirrored order is actually fine, so I wouldn’t be worried about forward vs. reverse AD here.
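The jax.custom_vjp suggestion from the first comment can be sketched roughly as follows. This is a single-process illustration under loud assumptions: `stand_in_allreduce` is a hypothetical placeholder that just returns its input, and its backward rule passes the cotangent through unchanged. A real rule would issue the appropriate MPI call and thread the token in whatever order is valid for the backward pass. Only `jax.custom_vjp` and `defvjp` are actual JAX API here.

```python
import jax


@jax.custom_vjp
def stand_in_allreduce(x):
    # Hypothetical placeholder for a sum-allreduce; on one process it is the identity.
    return x


def stand_in_allreduce_fwd(x):
    # Forward pass: return the primal output plus residuals (none are needed here).
    return stand_in_allreduce(x), None


def stand_in_allreduce_bwd(residuals, ct):
    # Backward pass: placeholder rule that forwards the cotangent unchanged.
    # A real implementation would communicate here.
    return (ct,)


stand_in_allreduce.defvjp(stand_in_allreduce_fwd, stand_in_allreduce_bwd)


def foo(x):
    x1 = stand_in_allreduce(x)
    x2 = stand_in_allreduce(x)
    return x1 + x2


gx = jax.grad(foo)(0.0)
print(gx)
```

Because only a VJP is defined, JAX should reject forward-mode differentiation of `foo` (for example via `jax.jvp`), which is the trade-off the comment describes.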