Error in differentiating a function with multiple components

I encountered a runtime error when running the following code with a single MPI process. In short, JAX fails to differentiate a function that adds two values, each reduced across processes with mpi4jax.allreduce. Is this intended behavior?

  • mpi4jax 0.2.16 and 0.2.16+5.ga28c335
  • jax 0.2.11
  • jaxlib 0.1.64
  • mpi4py 3.0.3
  • Python 3.8.5

Code

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

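def foo(x):
    token = None  # no token yet; the first allreduce starts the chain
    x1, token = mpi4jax.allreduce(x**2, op=MPI.SUM, comm=comm, token=token)
    x2, token = mpi4jax.allreduce(jnp.abs(x), op=MPI.SUM, comm=comm, token=token)  # reuse the token to keep both calls ordered
    return x1 + x2

x = 0.0
fx = foo(x)  # the forward evaluation runs without error
gfoo = jax.grad(foo)
gx = gfoo(x)  # line 18: raises KeyError: <class 'jax.interpreters.xla.Token'>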
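
For reference: on a single process, an allreduce with op=MPI.SUM returns its input unchanged, so foo reduces to x**2 + abs(x), whose derivative for x != 0 is 2*x + sign(x). The plain-Python sketch below (no JAX or MPI needed) checks that value with a central difference. allreduce_sum is a hypothetical toy model of a SUM allreduce over a list of per-rank values, not mpi4jax's implementation. The evaluation point x = 1.5 is an assumption chosen instead of the issue's x = 0.0, because abs(x) is not differentiable at 0.

```python
# Pure-Python sketch (no JAX or MPI): the gradient one would expect
# jax.grad(foo) to return on a single MPI process.

def allreduce_sum(per_rank_values):
    """Hypothetical toy SUM allreduce: every rank receives the total."""
    total = sum(per_rank_values)
    return [total] * len(per_rank_values)

def foo_single_process(x):
    # With only one rank, each allreduce returns the local value unchanged.
    x1 = allreduce_sum([x**2])[0]
    x2 = allreduce_sum([abs(x)])[0]
    return x1 + x2

def central_difference(f, x, h=1e-6):
    """Approximate f'(x) with a symmetric finite difference."""
    return (f(x + h) - f(x - h)) / (2 * h)

x = 1.5  # away from the kink of abs(x) at 0
g = central_difference(foo_single_process, x)
print(g)  # approximately 2*x + 1 = 4.0
```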

Error messages

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "mpi.py", line 18, in <module>
    gx = gfoo(x)
jax._src.traceback_util.FilteredStackTrace: KeyError: <class 'jax.interpreters.xla.Token'>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "mpi.py", line 18, in <module>
    gx = gfoo(x)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/api.py", line 752, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/api.py", line 821, in value_and_grad_f
    g = vjp_py(np.ones((), dtype=dtype))
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/api.py", line 1808, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/interpreters/ad.py", line 121, in unbound_vjp
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/interpreters/ad.py", line 231, in backward_pass
    map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/_src/util.py", line 40, in safe_map
    return list(map(f, *args))
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/interpreters/ad.py", line 176, in write_cotangent
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/interpreters/ad.py", line 482, in add_tangents
    return add_jaxvals(x, y)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/ad_util.py", line 37, in add_jaxvals
    return add_jaxvals_p.bind(x, y)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/core.py", line 259, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/core.py", line 597, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/shinaoka/.local/lib/python3.8/site-packages/jax/ad_util.py", line 44, in add_impl
    return jaxval_adders[type(xs)](xs, ys)
KeyError: <class 'jax.interpreters.xla.Token'>
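
The last frames show where it fails: write_cotangent found an existing cotangent for a variable and called add_jaxvals to sum the two, and the adder lookup jaxval_adders[type(xs)] had no entry for Token. The toy model below mimics that dispatch in plain Python. It is a simplified sketch, not JAX's actual code; ADDERS, add_cotangents and the Token class here are hypothetical. It suggests the backward pass received two cotangent contributions for the same token variable, which is an inference from the traceback rather than a confirmed diagnosis.

```python
# Toy model (NOT JAX's real implementation) of how a reverse-mode backward
# pass accumulates cotangents and dispatches addition on the value's type.

class Token:
    """Hypothetical stand-in for an ordering token: it carries no numeric value."""

# Hypothetical adder table keyed by type, mirroring the idea behind
# `jaxval_adders[type(xs)]` in the traceback. Token has no entry.
ADDERS = {
    float: lambda a, b: a + b,
}

def add_cotangents(a, b):
    # Raises KeyError when no adder is registered for type(a).
    return ADDERS[type(a)](a, b)

def write_cotangent(ct_env, var, ct):
    # The first contribution is stored as-is; later ones are summed.
    ct_env[var] = add_cotangents(ct_env[var], ct) if var in ct_env else ct

ct_env = {}
write_cotangent(ct_env, "x", 1.0)
write_cotangent(ct_env, "x", 2.0)  # float + float: fine
print(ct_env["x"])  # 3.0

write_cotangent(ct_env, "token", Token())  # first write: nothing to add
try:
    write_cotangent(ct_env, "token", Token())  # second write must add two Tokens
except KeyError as err:
    # Lookup fails: no adder is registered for Token.
    print("KeyError:", err)
```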

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 9 (1 by maintainers)

Top GitHub Comments

0 reactions
dionhaefner commented, Apr 5, 2021

Hope that fixes it for now, let us know if you run into further problems.
