Type demotion in `jax.jacobian`
Example:
```python
from jax import jacobian
import jax
import jax.numpy as jnp

# Note: the original report used `from jax.config import config; config.update(...)`,
# which has been removed from recent JAX releases.
jax.config.update('jax_enable_x64', True)

def f(x, y):
    return x * y

x = jnp.ones((), jnp.float32)
y = jnp.ones((), jnp.float64)

f(x, y).dtype, jacobian(f)(x, y).dtype
```

This gives

```
(dtype('float64'), dtype('float32'))
```

However, df/dx = y, so the Jacobian should have dtype `jnp.float64`.
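The demotion comes from which mode computes the Jacobian. A minimal sketch (assuming a recent JAX release with `jax.config.update`) comparing `jacfwd` and `jacrev` on the same function:

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

def f(x, y):
    return x * y

x = jnp.ones((), jnp.float32)
y = jnp.ones((), jnp.float64)

# Forward mode: the result lives in the output tangent space,
# so it picks up the promoted output dtype.
print(jax.jacfwd(f)(x, y).dtype)  # float64

# Reverse mode (what `jacobian` aliases): the result lives in the
# input cotangent space, so it is cast back to x's dtype.
print(jax.jacrev(f)(x, y).dtype)  # float32
```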
(this seems to cause a type mismatch in https://github.com/google/neural-tangents/issues/112)
Issue Analytics
- Created: 2 years ago
- Comments: 7 (4 by maintainers)
Top GitHub Comments
The observed behavior is correct, or at least consistent with our differentiation semantics. The cotangent of the first argument of `f` has type `float32`, since the first argument has type `float32`. The tangent of the output of `f` has type `float16`, since the output has type `float16`.

What's perhaps unexpected is that `jacobian` is an alias for `jacrev` rather than `jacfwd`, as @jakevdp notes. I'd go further and offer that calling either function `jacobian` (via alias) is imprecise in this setting. Especially when a function has differing input/output types, it helps to take the more general view of its Jacobian (at a primal point) as a function.

Namely, for a fixed primal point `x`, the Jacobian `J(x)` is a (linear) function from input tangents to output tangents. Its transpose `J'(x)` is a function from output (co)tangents to input (co)tangents. In this view, `jacfwd` is the Jacobian map `J(x)` applied to the standard basis vectors of its domain, and `jacrev` is the Jacobian map's transpose `J'(x)` applied to the standard basis vectors of its domain.

So the outputs of `jacfwd` and `jacrev` are elements of different vector spaces, typed accordingly. Neither object exactly captures the full notion of "the Jacobian"; that's a function.

Perhaps we can improve documentation around `jacobian`, `jacfwd`, and `jacrev` to clarify their meaning. Here's a repro that's slightly more to the point: