custom_jvp fails when the function is real-valued and its gradient is complex-valued
First of all, thank you for supporting complex numbers! Other similar tools have spent years on this problem and still don’t have it working. I was wondering if there’s a good way to solve the problem below.
Background: The exponential families are important families of probability distributions. Every exponential family is identified (among other things) by a real-valued scalar function called a log-normalizer. The gradient of the log-normalizer is also very important; in particular, it shows up when optimizing a predictive distribution. One example of a gradient log-normalizer is the logistic sigmoid function, which appears when optimizing a Bernoulli model.
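For concreteness, here is a minimal sketch of that Bernoulli case (the function names are mine, just for illustration): the log-normalizer is real-valued, and its gradient is the logistic sigmoid.

```python
import jax
import jax.numpy as jnp

def bernoulli_log_normalizer(eta):
    # A(eta) = log(1 + exp(eta)), written in a numerically stable form.
    return jnp.logaddexp(0.0, eta)

eta = 0.3
print(jax.grad(bernoulli_log_normalizer)(eta))  # gradient of the log-normalizer
print(jax.nn.sigmoid(eta))                      # the logistic sigmoid; same value
```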
The problem is that some exponential families, like the complex normal distribution, have a complex-valued gradient log-normalizer. This causes a problem with custom_jvp:
```python
from jax import grad, custom_jvp
import jax.numpy as jnp

def f(q):
    # Stand-in for a real-valued log-normalizer.
    return jnp.real(jnp.sum(q))

def nat_to_exp(q):
    # Stand-in for the (complex-valued) gradient log-normalizer.
    return q

if True:  # When False, no errors are produced.
    f = custom_jvp(f)

    @f.defjvp
    def f_jvp(primals, tangents):
        q, = primals
        q_prime, = tangents
        a = nat_to_exp(q)
        return f(q), a * q_prime

some_q = jnp.array([2 + 1j, 1 + 1j], dtype=jnp.complex64)
print(grad(f)(some_q))
```
gives

```
TypeError: mul requires arguments to have the same dtypes, got float32, complex64.
```
What is the best way to fix this? One way would be to make the log-normalizer complex-valued, but this is unfortunate from both a usability standpoint (it shows up in many real-valued calculations, like the cross entropy and the density) and an efficiency standpoint (complex numbers need twice the memory, etc.)
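For reference, here is a rough sketch (my own reconstruction, not code from this issue) of that complex-valued alternative: if the log-normalizer returns a complex scalar, the primal and tangent dtypes agree, and the gradient can presumably be taken with `holomorphic=True`:

```python
from jax import grad, custom_jvp
import jax.numpy as jnp

@custom_jvp
def log_normalizer(q):
    return jnp.sum(q)  # complex-valued output instead of jnp.real(...)

def nat_to_exp(q):
    return q  # stand-in for the gradient log-normalizer, as in the snippet above

@log_normalizer.defjvp
def log_normalizer_jvp(primals, tangents):
    q, = primals
    q_prime, = tangents
    return log_normalizer(q), nat_to_exp(q) * q_prime

some_q = jnp.array([2 + 1j, 1 + 1j], dtype=jnp.complex64)
# holomorphic=True because both the input and the output are complex here.
print(grad(log_normalizer, holomorphic=True)(some_q))
```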
I’m wondering if it’s possible to relax the check in JVP evaluation to make the multiplication support widening from real to complex?
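For what it’s worth, the built-in rules already seem to handle this widening when no custom JVP is involved; as the `if True:` comment in the snippet above notes, switching it to `False` makes the error go away:

```python
from jax import grad
import jax.numpy as jnp

def f(q):
    # Same real-valued function as above, but without a custom JVP rule.
    return jnp.real(jnp.sum(q))

some_q = jnp.array([2 + 1j, 1 + 1j], dtype=jnp.complex64)
print(grad(f)(some_q))  # complex64 gradient, no TypeError
```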
Issue Analytics
- Created 3 years ago
- Comments: 12 (12 by maintainers)
If there’s one thing I love, it’s exponential families. It’s no joke to say an original motivation for JAX was to do research involving exponential families (particularly this work and this work). So you’ve got allies here! Though TBH I’ve never worked with exponential families involving complex numbers, so I think I’m going to learn some things here.
EDIT: and this talk!
Heh, thanks for the trust. Sorry if I’m not expressing things clearly… I realized I’m using nonstandard notation and ideas that we haven’t written down from start to finish anywhere. (I prepared some slides on AD for a Stanford course recently, and those provide some clues about the notation I’m using, but they don’t actually cover complex numbers at all…)
I’ll take this as a +1 to having a careful exposition on complex numbers in any future AD docs we write (like the “Autodiff Cookbook Part 2” we’ve never gotten around to writing…).