
TypeError when using None as custom_vjp cotangent for custom pytree and asarray

from jax import custom_vjp, vjp
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        self.a = jnp.asarray(a, dtype=float)  # TypeError: float() argument must be a string or a real number, not 'object'
        # self.a = jnp.asarray(a)  # TypeError: Value '<object object at 0x7fff5220fc20>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@custom_vjp
def f(x, y):
    return 2. * x  # doesn't depend on y in this simple example


def f_fwd(x, y):
    z = f(x, y)
    res = None
    return z, res

def f_bwd(res, z_cot):
    x_cot = 2. * z_cot
    y_cot = None
    return x_cot, y_cot

f.defvjp(f_fwd, f_bwd)

c = C()
vjp(f, 3., c)[1](3.)  # TypeError with different messages (see above) depending on whether c.a is weakly typed or not
vjp(f, 3., [1., 2])[1](3.)  # this works

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 15 (11 by maintainers)

Top GitHub Comments

jakevdp commented, Apr 12, 2022 (2 reactions)

Hi @eelregit, this is a common issue when using custom PyTrees in JAX transforms. Various transforms pass None values or object() placeholder values to the PyTree constructor, and this causes errors if your PyTree performs overly strict input validation at initialization. Here’s an example of how we deal with this in the PyTree used to represent sparse matrices: https://github.com/google/jax/blob/3136004c623be4cc7b25f8477ffdce0b3a110a2e/jax/experimental/sparse/bcoo.py#L1716-L1720

https://github.com/google/jax/blob/3136004c623be4cc7b25f8477ffdce0b3a110a2e/jax/experimental/sparse/util.py#L44-L60

You’ll have to do some kind of similar check if you have your own pytrees that you want to use with JAX transformations.
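The linked sparse-matrix code is not reproduced here. Below is a minimal sketch of the same idea applied to the class `C` from the issue, assuming (as the comment above describes) that the only non-array leaves a transform hands to the constructor are `None` and bare `object()` sentinels. The helper `_is_placeholder` is hypothetical and is not a JAX API.

```python
import jax.numpy as jnp
from jax import custom_vjp, vjp
from jax.tree_util import register_pytree_node_class


def _is_placeholder(x):
    # Hypothetical helper (not part of JAX): treat None and bare object()
    # sentinels as placeholders that must pass through untouched.
    return x is None or type(x) is object


@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        # Convert/validate real leaf values only; keep placeholders as-is.
        self.a = a if _is_placeholder(a) else jnp.asarray(a, dtype=float)

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@custom_vjp
def f(x, y):
    return 2. * x  # does not depend on y


def f_fwd(x, y):
    return f(x, y), None


def f_bwd(res, z_cot):
    # None as the cotangent for y stands for a zero cotangent for that argument.
    return 2. * z_cot, None


f.defvjp(f_fwd, f_bwd)

# The call that raised TypeError in the issue.
x_cot, c_cot = vjp(f, 3., C())[1](3.)
```

With this constructor the `vjp` call is expected to return `6.0` for `x` and a `C` whose `a` is zero. The trade-off is that `C(None)` is now accepted silently, which is the concern raised in the next comment.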

YouJiacheng commented, Apr 12, 2022 (1 reaction)

@jakevdp Could we check for placeholders only in tree_unflatten? That way we could keep the strictest validation for normal initialization. Moreover, do you think we need a dedicated class to represent placeholders? Using None and object() may surprise users and produce vague error messages like the ones in this issue.
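One possible way to realize the suggestion of handling placeholders only in `tree_unflatten` is sketched below. It assumes it is acceptable for unflattened instances to skip `__init__` entirely: `tree_unflatten` builds the object with `object.__new__` and assigns the leaves directly, so `__init__` can keep strict validation for user code. This is an illustrative pattern, and it differs from the constructor-side check described in the previous comment.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        # Strict validation, applied only to user-facing construction.
        self.a = jnp.asarray(a, dtype=float)

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: transforms may pass None or object()
        # placeholders, which are stored here without validation.
        obj = object.__new__(cls)
        obj.a = children[0]
        return obj


# A placeholder leaf, like the ones transforms use internally, no longer raises.
placeholder = C.tree_unflatten(None, [object()])

# Ordinary tree operations still round-trip through tree_unflatten.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, C(1.5))
```

The cost of this design is that leaves produced by any tree operation are trusted as-is: a `tree_map` that returns, say, Python lists yields a `C` holding a list. Whether that is acceptable depends on how the class is used.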
