
custom_jvp fails when the function is real-valued and its gradient is complex-valued

See original GitHub issue

First of all, thank you for supporting complex numbers! Other similar tools have spent years on this problem and still don’t have it working. I was wondering if there’s a good way to solve the problem below.

Background: The exponential families are important families of probability distributions. All exponential families are identified by a real-valued scalar function called a log-normalizer (among other things). The gradient of the log-normalizer is also very important. In particular, it shows up when optimizing a predictive distribution. One example of a gradient log-normalizer is the logistic sigmoid function that shows up when optimizing a Bernoulli model.
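As a concrete illustration of that last point (my own sketch, not from the issue): the Bernoulli log-normalizer is θ ↦ log(1 + exp(θ)), and its gradient is exactly the logistic sigmoid.

```python
import jax.numpy as jnp
from jax import grad

def bernoulli_log_normalizer(theta):
    # log(1 + exp(theta)), computed stably via logaddexp
    return jnp.logaddexp(0.0, theta)

# d/dtheta log(1 + exp(theta)) = exp(theta) / (1 + exp(theta)) = sigmoid(theta)
print(grad(bernoulli_log_normalizer)(0.0))  # sigmoid(0) = 0.5
```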

The problem is that some exponential families like the complex normal distribution have a complex-valued gradient log-normalizer. This causes a problem with custom_jvp:

from jax import grad, custom_jvp
import jax.numpy as jnp

def f(q):
    return jnp.real(jnp.sum(q))

def nat_to_exp(q):
    return q

if True:  # When false, no errors are produced.
    f = custom_jvp(f)

    @f.defjvp
    def f_jvp(primals, tangents):
        q, = primals
        q_prime, = tangents
        a = nat_to_exp(q)
        return f(q), a * q_prime


some_q = jnp.array([2+1j, 1+1j], dtype=jnp.complex64)
print(grad(f)(some_q))

gives

TypeError: mul requires arguments to have the same dtypes, got float32, complex64.

What is the best way to fix this? One option would be to make the log-normalizer itself complex-valued, but that is unfortunate from both a usability standpoint (it shows up in many real-valued calculations, like the cross entropy and the density) and an efficiency standpoint (complex numbers need twice the memory, etc.).

I’m wondering if it’s possible to relax the check in JVP evaluation so that the multiplication supports widening from real to complex?
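For what it’s worth, one workaround (my own sketch, specialized to the `f` in the snippet above, not an official fix from this thread): since `f` is real-valued, its directional derivative along any tangent is also real, so the JVP rule can return a real tangent explicitly, keeping the tangent dtype consistent with the primal output.

```python
from jax import grad, custom_jvp
import jax.numpy as jnp

@custom_jvp
def f(q):
    return jnp.real(jnp.sum(q))

@f.defjvp
def f_jvp(primals, tangents):
    q, = primals
    q_dot, = tangents
    # f's output is float32, so its output tangent must be float32 as well;
    # taking the real part makes the dtypes agree.
    return f(q), jnp.real(jnp.sum(q_dot))

some_q = jnp.array([2 + 1j, 1 + 1j], dtype=jnp.complex64)
print(grad(f)(some_q))  # runs without the dtype error
```

Note this hand-writes the tangent for this particular `f` (whose derivative along `q_dot` is Re(Σ q̇)); the general question of widening real tangents to complex in the JVP check still stands as asked above.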

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 12 (12 by maintainers)

Top GitHub Comments

2 reactions
mattjj commented, Mar 24, 2020

If there’s one thing I love, it’s exponential families. It’s no joke to say an original motivation for JAX was to do research involving exponential families (particularly this work and this work). So you’ve got allies here! Though TBH I’ve never worked with exponential families involving complex numbers, so I think I’m going to learn some things here.

EDIT: and this talk!

0 reactions
mattjj commented, Nov 26, 2020

Heh, thanks for the trust. Sorry if I’m not expressing things clearly… I realized I’m using nonstandard notation and ideas that we haven’t written down from start to finish anywhere. (I prepared some slides on AD for a Stanford course recently, and those provide some clues about the notation I’m using, but they don’t actually cover complex numbers at all…)

I’ll take this as a +1 to having a careful exposition on complex numbers in any future AD docs we write (like the “Autodiff Cookbook Part 2” we’ve never gotten around to writing…).

Read more comments on GitHub >
