
namedtuple support in arguments to transformed functions


It would be great if xla.abstractify also accepted namedtuples. Loop states can consist of quite a lot of values, and organizing them in a namedtuple rather than a plain tuple would make things nicer.
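As a rough illustration of the use case described above, here is a minimal sketch of a `lax.while_loop` whose carry is a namedtuple. It assumes a reasonably recent JAX release, in which namedtuples are accepted as loop carries without any registration; `LoopState` and `sum_below` are hypothetical names chosen for the example, and `xla.abstractify` itself is not called.

```python
from collections import namedtuple

import jax.numpy as jnp
from jax import lax

# Hypothetical loop state for illustration: a counter and a running sum.
LoopState = namedtuple("LoopState", ["i", "total"])


def sum_below(n):
    """Sum the integers 0, 1, ..., n - 1 using a namedtuple loop carry."""

    def cond(state):
        # Fields are read by name rather than by tuple position.
        return state.i < n

    def body(state):
        # Return a new LoopState with the same structure and dtypes as the input.
        return LoopState(i=state.i + 1, total=state.total + state.i)

    # Explicit int32 scalars keep the carry's dtypes stable across iterations.
    init = LoopState(i=jnp.int32(0), total=jnp.int32(0))
    final = lax.while_loop(cond, body, init)
    return final.total


print(sum_below(5))  # 10
```

The carry's structure (two named fields) must be the same on every iteration, which is why `body` always builds a fresh `LoopState` rather than returning a bare tuple.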

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

18 reactions
mattjj commented, Feb 25, 2019

There’s actually a convenient way to add support for custom container types throughout JAX, not just in loop carries but also for grad, jit, vmap, etc, all at once. Of course it’s not documented at all… 😃

You can register a custom type as a “pytree” (tree-like Python container) like this:

from collections import namedtuple
from jax.tree_util import register_pytree_node
from jax import grad, jit
import jax.numpy as np

Point = namedtuple("Point", ["x", "y"])

register_pytree_node(
    Point,
    lambda xs: (tuple(xs), None),  # tell JAX how to unpack to an iterable
    lambda _, xs: Point(*xs)       # tell JAX how to pack back into a Point
)


def f(pt):
  return np.sqrt(pt.x**2 + pt.y**2)

pt = Point(1., 2.)

print(f(pt))        # 2.236068
print(grad(f)(pt))  # Point(x=..., y=...)

g = jit(f)
print(g(pt))  # 2.236068
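To see what the two lambdas in the registration actually do, the sketch below flattens a `Point` into its leaves and rebuilds it with `jax.tree_util.tree_flatten` and `tree_unflatten`. It repeats the registration so that it runs on its own; in a session where `Point` has already been registered, drop that call, since registering the same type twice raises an error.

```python
from collections import namedtuple

from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten

Point = namedtuple("Point", ["x", "y"])

register_pytree_node(
    Point,
    lambda p: (tuple(p), None),  # flatten: children (x, y), no aux data
    lambda _, xs: Point(*xs),    # unflatten: ignore aux data, rebuild the Point
)

# Flattening yields the leaves plus a treedef that remembers the structure.
leaves, treedef = tree_flatten(Point(1.0, 2.0))
print(leaves)  # [1.0, 2.0]

# Unflattening calls the second lambda to pack the leaves back into a Point.
rebuilt = tree_unflatten(treedef, leaves)
print(rebuilt)  # Point(x=1.0, y=2.0)
```

Transformations such as grad and jit follow the same two steps internally: they work on the flat list of leaves and use the treedef to hand results back in the original container type.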

So that’s an easy and general way to get your code working now. It also means you can have your namedtuple classes contain nested tuples/lists/dicts, or have them nested in other tuples/lists/dicts.
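The nesting behavior can be checked with `jax.tree_util.tree_leaves`, `jax.tree_util.tree_map`, and `jax.grad`. This is a minimal sketch assuming a recent JAX version in which namedtuples are pytrees by default, so no registration is shown; `Params`, its fields, and `loss` are hypothetical names for illustration.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical container: a namedtuple holding a dict and a list of arrays.
Params = namedtuple("Params", ["weights", "biases"])

params = Params(
    weights={"w1": jnp.ones(2), "w2": jnp.ones(3)},
    biases=[jnp.zeros(2), jnp.zeros(3)],
)

# All four arrays are found, however deeply they are nested.
print(len(jax.tree_util.tree_leaves(params)))  # 4

# tree_map applies a function to every leaf and rebuilds the same structure,
# so the result is again a Params containing a dict and a list.
doubled = jax.tree_util.tree_map(lambda a: 2 * a, params)
print(type(doubled).__name__)  # Params


def loss(p):
    # Sum of squares over every array in the nested structure.
    return sum(jnp.sum(a ** 2) for a in jax.tree_util.tree_leaves(p))


# The gradient has the same nested structure as the input.
grads = jax.grad(loss)(params)
print(grads.weights["w1"])  # [2. 2.]
```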

(By the way, the extra data that the to-iterable function can return, and that the to-pytree function consumes, is for things like dict keys. In the example above, we just return None when mapping to an iterable and then ignore it when reconstructing.)
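Here is a minimal sketch of that extra ("auxiliary") data in use, assuming a hypothetical dict-like class `Record` that is not part of JAX. The flatten function returns the values as children and the sorted key tuple as aux data; the unflatten function uses those keys to put each child back under its name. Aux data is treated as static structure, so it should be hashable, which a tuple of strings is.

```python
import jax
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten


class Record:
    """Hypothetical dict-like container used only for this illustration."""

    def __init__(self, fields):
        self.fields = dict(fields)

    def __repr__(self):
        return f"Record({self.fields!r})"


def record_flatten(rec):
    # Children are the values; the static, hashable key tuple is the aux data.
    keys = tuple(sorted(rec.fields))
    return tuple(rec.fields[k] for k in keys), keys


def record_unflatten(keys, values):
    # The aux data says which key each child belongs to.
    return Record(zip(keys, values))


register_pytree_node(Record, record_flatten, record_unflatten)

rec = Record({"b": 2.0, "a": 1.0})
leaves, treedef = tree_flatten(rec)
print(leaves)  # [1.0, 2.0]  (values in sorted-key order)

rebuilt = tree_unflatten(treedef, leaves)
print(rebuilt)  # Record({'a': 1.0, 'b': 2.0})

# grad only sees the leaves, but returns a Record with the same keys.
grads = jax.grad(lambda r: r.fields["a"] * r.fields["b"])(rec)
print(grads)
```

Sorting the keys makes the flattened order deterministic, so two records with the same keys always produce matching tree structures.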

However, we should consider making JAX work with all namedtuple classes by default, without having to register them. Any thoughts on that, or objections to it?
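For comparison, here is a minimal sketch of the default behavior proposed above, assuming a recent JAX release in which namedtuples are treated as pytrees without registration. It mirrors the `Point` example but omits the `register_pytree_node` call; `norm` is a hypothetical function name.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Note: no register_pytree_node call here.
Point = namedtuple("Point", ["x", "y"])


def norm(pt):
    return jnp.sqrt(pt.x ** 2 + pt.y ** 2)


pt = Point(3.0, 4.0)
print(jax.jit(norm)(pt))  # 5.0

# The gradient (x / r, y / r) comes back as a Point, roughly (0.6, 0.8).
g = jax.grad(norm)(pt)
print(g)
```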

4 reactions
rsepassi commented, Mar 8, 2019

+1 to having JAX work with all namedtuple classes


