
Type demotion in `jax.jacobian`

See original GitHub issue

Example:

from jax import jacobian
from jax.config import config
import jax.numpy as jnp
# Enable 64-bit types so float64 values are not silently truncated.
config.update('jax_enable_x64', True)

def f(x, y):
  return x * y

x = jnp.ones((), jnp.float32)
y = jnp.ones((), jnp.float64)

# Compare the dtype of the primal output with the dtype of the Jacobian.
f(x, y).dtype, jacobian(f)(x, y).dtype

Gives

(dtype('float64'), dtype('float32'))

However, since df/dx = y, the result should have dtype jnp.float64.

(This appears to cause a type mismatch in https://github.com/google/neural-tangents/issues/112.)
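For comparison, here is a minimal sketch of the same setup that queries both reverse- and forward-mode Jacobians; the dtypes in the comments reflect the semantics explained in the maintainer comments below, and exact behavior may differ across JAX versions:

from jax import jacfwd, jacrev
from jax.config import config
import jax.numpy as jnp
config.update('jax_enable_x64', True)

def f(x, y):
  return x * y

x = jnp.ones((), jnp.float32)
y = jnp.ones((), jnp.float64)

# jacobian is an alias for jacrev, which returns a cotangent typed like the
# input x (float32); jacfwd returns an output tangent (float64).
print(jacrev(f)(x, y).dtype)  # float32
print(jacfwd(f)(x, y).dtype)  # float64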

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 7 (4 by maintainers)

Top GitHub Comments

1 reaction
froystig commented on May 4, 2021

The observed behavior is correct, or at least consistent with our differentiation semantics. The cotangent of the first argument of f has type float32, since the first argument has type float32. The tangent of the output of f has type float16, since the output has type float16 (referring to the float16 repro in @jakevdp's comment below).

What’s perhaps unexpected is that jacobian is an alias for jacrev rather than jacfwd, as @jakevdp notes. I’d go further and offer that calling either function jacobian (via alias) is imprecise in this setting. Especially when a function has differing input/output types, it helps to take the more general view of its Jacobian (at a primal point) as a function.

Namely, for a fixed primal point x, the Jacobian J(x) is a (linear) function from input tangents to output tangents. Its transpose J'(x) is a function from output (co)tangents to input (co)tangents. In this view, jacfwd is the Jacobian map J(x) applied to the standard basis vectors of its domain, and jacrev is the Jacobian map’s transpose J'(x) applied to the standard basis vectors of its domain.

So the outputs of jacfwd and jacrev are elements of different vector spaces, typed accordingly. Neither object exactly captures the full notion of “the Jacobian”—that’s a function.

Perhaps we can improve documentation around jacobian, jacfwd, and jacrev to clarify their meaning.
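To make those two maps concrete, here is a small sketch that applies jax.jvp and jax.vjp directly to the example from the issue; the dtypes in the comments follow the typing described in this comment, so treat it as an illustration rather than a normative reference:

import jax
from jax.config import config
import jax.numpy as jnp
config.update('jax_enable_x64', True)

def f(x, y):
  return x * y

x = jnp.ones((), jnp.float32)
y = jnp.ones((), jnp.float64)

# J(x): input tangents -> output tangents. Tangents must match the primal dtypes.
_, out_tangent = jax.jvp(f, (x, y), (jnp.ones((), jnp.float32), jnp.zeros((), jnp.float64)))
print(out_tangent.dtype)  # float64 -- an element of the output tangent space

# J'(x): output cotangents -> input cotangents, typed like the respective inputs.
_, f_vjp = jax.vjp(f, x, y)
x_cot, y_cot = f_vjp(jnp.ones((), jnp.float64))
print(x_cot.dtype, y_cot.dtype)  # float32 float64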

1 reaction
jakevdp commented on May 4, 2021

Here’s a repro that’s slightly more to the point:

from jax import jacobian
import jax.numpy as jnp

def f(x):
  return x.astype('float16')

x = jnp.array(1.0)
print(f(x).dtype, jacobian(f)(x).dtype)
# float16 float32
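For contrast, a sketch applying jacfwd to the same function shows an output-typed result (the dtype comment follows the semantics described above):

from jax import jacfwd, jacrev
import jax.numpy as jnp

def f(x):
  return x.astype('float16')

x = jnp.array(1.0)  # float32 by default
print(jacrev(f)(x).dtype, jacfwd(f)(x).dtype)
# float32 float16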

