Error in differentiating a function with multiple components
I encountered a runtime error when I ran the following code with one MPI process. In short, JAX fails to differentiate a function that adds two values allreduced by mpi4jax. Is this intended behavior?
- mpi4jax 0.2.16 and 0.2.16+5.ga28c335
- jax 0.2.11
- jaxlib 0.1.64
- mpi4py 3.0.3
- Python 3.8.5
Code
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

def foo(x):
    token = None
    x1, token = mpi4jax.allreduce(x**2, op=MPI.SUM, comm=comm, token=token)
    x2, token = mpi4jax.allreduce(jnp.abs(x), op=MPI.SUM, comm=comm, token=token)
    return x1 + x2

x = 0.0
fx = foo(x)
gfoo = jax.grad(foo)
gx = gfoo(x)
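With a single rank, an allreduce with MPI.SUM should return its input unchanged, so in the one-process case above foo reduces to x**2 + |x|. A pure-JAX reference like the sketch below gives the gradient the MPI version would be expected to return. The name foo_reference is hypothetical and not part of mpi4jax, and the point x = 2.0 is chosen instead of 0.0 to sidestep whatever convention JAX uses for the derivative of abs at zero.

```python
import jax
import jax.numpy as jnp


def foo_reference(x):
    # Hypothetical single-process stand-in for foo: with one rank,
    # allreduce(..., op=MPI.SUM) leaves its input unchanged.
    x1 = x**2
    x2 = jnp.abs(x)
    return x1 + x2


x = 2.0
fx = foo_reference(x)            # 2**2 + |2| = 6
gx = jax.grad(foo_reference)(x)  # d/dx (x**2 + |x|) = 2*x + 1 = 5 for x > 0
print(fx, gx)
```

If differentiation through mpi4jax.allreduce works, running the MPI version with one process at the same x should give the same gx as this reference.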
Error messages
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "mpi.py", line 18, in <module>
    gx = gfoo(x)
jax._src.traceback_util.FilteredStackTrace: KeyError: <class 'jax.interpreters.xla.Token'>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "mpi.py", line 18, in <module>
    gx = gfoo(x)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/api.py", line 752, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/api.py", line 821, in value_and_grad_f
    g = vjp_py(np.ones((), dtype=dtype))
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/api.py", line 1808, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/interpreters/ad.py", line 121, in unbound_vjp
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/interpreters/ad.py", line 231, in backward_pass
    map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/_src/util.py", line 40, in safe_map
    return list(map(f, *args))
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/interpreters/ad.py", line 176, in write_cotangent
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/interpreters/ad.py", line 482, in add_tangents
    return add_jaxvals(x, y)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/ad_util.py", line 37, in add_jaxvals
    return add_jaxvals_p.bind(x, y)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/core.py", line 259, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/core.py", line 597, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/ad_util.py", line 44, in add_impl
    return jaxval_adders[type(xs)](xs, ys)
KeyError: <class 'jax.interpreters.xla.Token'>
Issue Analytics
- State:
- Created 2 years ago
- Comments:9 (1 by maintainers)
https://github.com/google/jax/issues/6285
Hope that fixes it for now, let us know if you run into further problems.