TypeError when using None as custom_vjp cotangent for custom pytree and asarray
```python
import jax.numpy as jnp
from jax import custom_vjp, vjp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        self.a = jnp.asarray(a, dtype=float)  # TypeError: float() argument must be a string or a real number, not 'object'
        # self.a = jnp.asarray(a)  # TypeError: Value '<object object at 0x7fff5220fc20>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@custom_vjp
def f(x, y):
    return 2. * x  # doesn't depend on y in this simple example


def f_fwd(x, y):
    z = f(x, y)
    res = None
    return z, res


def f_bwd(res, z_cot):
    x_cot = 2. * z_cot
    y_cot = None  # no cotangent for y
    return x_cot, y_cot


f.defvjp(f_fwd, f_bwd)

c = C()
vjp(f, 3., c)[1](3.)  # TypeError with different messages (see above) depending on whether c.a is weakly typed or not
vjp(f, 3., [1., 2])[1](3.)  # this works
```
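To isolate the failure from `custom_vjp`, the sketch below rebuilds `C` from its treedef directly with `jax.tree_util.tree_flatten` and `tree_unflatten`, once with the real leaf and once with an `object()` placeholder leaf. It assumes, as the maintainer's reply in this thread explains, that the transform reconstructs the pytree with placeholder leaves; it illustrates the mechanism and is not JAX's internal code path.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        # Strict validation: any non-numeric leaf is rejected here.
        self.a = jnp.asarray(a, dtype=float)

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Unflattening goes through __init__, so placeholder leaves reach asarray.
        return cls(*children)


leaves, treedef = tree_flatten(C())

# Rebuilding with the real leaves works.
ok = tree_unflatten(treedef, leaves)

# Rebuilding with an object() placeholder leaf fails inside __init__.
try:
    tree_unflatten(treedef, [object()])
    placeholder_failed = False
except TypeError:
    placeholder_failed = True

print(placeholder_failed)
```

The error therefore comes from `__init__`, not from the cotangent arithmetic in `f_bwd`.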
Hi @eelregit, this is a common issue when using custom PyTrees in JAX transforms. Various transforms will pass `None` values or `object()` placeholder values to the PyTree constructor, and this causes problems if your pytree performs overly strict input validation at initialization. Here's an example of how we deal with this in the PyTree used to represent sparse matrices:

https://github.com/google/jax/blob/3136004c623be4cc7b25f8477ffdce0b3a110a2e/jax/experimental/sparse/bcoo.py#L1716-L1720
https://github.com/google/jax/blob/3136004c623be4cc7b25f8477ffdce0b3a110a2e/jax/experimental/sparse/util.py#L44-L60

You'll have to do a similar check if you have your own pytrees that you want to use with JAX transformations.
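A minimal sketch of such a check, assuming the only placeholders to tolerate are `None` and bare `object()` instances (the two cases named above). The helper `_is_placeholder` is hypothetical and is not the code behind the links: it passes placeholders through unchanged and validates everything else, which should be enough for the `vjp` call from the issue to succeed.

```python
import jax.numpy as jnp
from jax import custom_vjp, vjp
from jax.tree_util import register_pytree_node_class


def _is_placeholder(x):
    # Hypothetical helper: treat None and bare object() sentinels as placeholders.
    return x is None or type(x) is object


@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        # Pass placeholders through untouched; validate and convert real inputs.
        self.a = a if _is_placeholder(a) else jnp.asarray(a, dtype=float)

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# The custom_vjp function from the issue, with the same behavior.
@custom_vjp
def f(x, y):
    return 2. * x


def f_fwd(x, y):
    return f(x, y), None


def f_bwd(res, z_cot):
    # None signals "no cotangent" for y.
    return 2. * z_cot, None


f.defvjp(f_fwd, f_bwd)

x_ct, c_ct = vjp(f, 3., C())[1](3.)
```

The trade-off is that `C(None)` or `C(object())` no longer raises when called directly by a user.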
@jakevdp Could we check for placeholders only in `tree_unflatten`? That way we could keep the strictest validation for normal initialization. Moreover, do you think we need a special class to represent placeholders? Using `None` and `object()` may surprise users and produce a vague error message like the one in this issue.
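A minimal sketch of the first suggestion, assuming reconstruction via `tree_unflatten` should accept whatever leaves JAX passes in (placeholders included) while direct construction stays strict. Here `tree_unflatten` bypasses `__init__` with `object.__new__`; this is one possible pattern, not an API that JAX prescribes.

```python
import jax.numpy as jnp
from jax import custom_vjp, vjp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class C:
    def __init__(self, a=1.):
        # Strict validation is kept for user-facing construction.
        self.a = jnp.asarray(a, dtype=float)

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so placeholder leaves are stored as-is.
        obj = object.__new__(cls)
        (obj.a,) = children
        return obj


@custom_vjp
def f(x, y):
    return 2. * x


def f_fwd(x, y):
    return f(x, y), None


def f_bwd(res, z_cot):
    return 2. * z_cot, None


f.defvjp(f_fwd, f_bwd)

x_ct, c_ct = vjp(f, 3., C())[1](3.)

# Direct construction still rejects non-numeric input.
try:
    C(object())
    rejected = False
except TypeError:
    rejected = True
```

The cost is that any invariants `__init__` establishes must either hold for unflattened leaves or be tolerated elsewhere in the class.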