namedtuple support in arguments to transformed functions
It would be great if xla.abstractify would also accept namedtuples. Loop states can consist of quite a lot of values, and organizing them in a namedtuple rather than a tuple would make things nicer.
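Here is a minimal sketch of the kind of loop this request has in mind. It assumes a JAX version that accepts namedtuples as loop state; recent releases treat namedtuples as pytrees by default. LoopState and its fields are hypothetical names used only for illustration.

```python
from collections import namedtuple

import jax.numpy as jnp
from jax import lax

# Hypothetical loop state with named fields instead of a bare tuple.
LoopState = namedtuple("LoopState", ["step", "total", "scale"])


def body(i, state):
    # Fields are read by name; the body must return a LoopState with the
    # same structure, shapes and dtypes as the initial state.
    return LoopState(step=state.step + 1,
                     total=state.total + state.scale * i,
                     scale=state.scale)


def run(n):
    init = LoopState(step=jnp.int32(0),
                     total=jnp.float32(0.0),
                     scale=jnp.float32(2.0))
    # The carry goes in and comes out as a LoopState.
    return lax.fori_loop(0, n, body, init)


final = run(5)
print(int(final.step), float(final.total))  # 5 20.0
```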
Created 5 years ago; 6 comments (6 by maintainers)
There’s actually a convenient way to add support for custom container types throughout JAX, not just in loop carries but also for grad, jit, vmap, etc., all at once. Of course it’s not documented at all… 😃 You can register a custom type as a “pytree” (tree-like Python container) with jax.tree_util.register_pytree_node, giving it a to-iterable function and a to-pytree function.
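A minimal sketch of such a registration, assuming the jax.tree_util.register_pytree_node(nodetype, flatten_fn, unflatten_fn) API. Point, point_to_iterable and point_from_iterable are hypothetical names. Once registered, the type passes through jit and grad like a tuple would:

```python
from collections import namedtuple

import jax
from jax.tree_util import register_pytree_node

# Hypothetical namedtuple used for illustration.
Point = namedtuple("Point", ["x", "y"])


def point_to_iterable(p):
    # Return (children, extra_data). A namedtuple needs no extra data.
    return tuple(p), None


def point_from_iterable(extra_data, children):
    # extra_data is the None returned above; it is ignored here.
    return Point(*children)


register_pytree_node(Point, point_to_iterable, point_from_iterable)


@jax.jit
def norm_sq(p):
    return p.x ** 2 + p.y ** 2


print(float(norm_sq(Point(3.0, 4.0))))  # 25.0

# The gradient comes back with the same structure as the input: a Point.
g = jax.grad(norm_sq)(Point(3.0, 4.0))
print(type(g).__name__, float(g.x), float(g.y))  # Point 6.0 8.0
```

Recent JAX releases handle namedtuples without this step, so explicit registration matters mostly for other custom classes or for older versions.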
So that’s an easy and general way to get your code working now. It also means you can have your namedtuple classes contain nested tuples/lists/dicts, or have them nested in other tuples/lists/dicts.
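A short sketch of that nesting, assuming namedtuples are treated as pytrees (by default in recent JAX, or after registering them as described above). Params is a hypothetical namedtuple whose fields hold a tuple and a dict, and two instances sit inside a list:

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical container: a namedtuple whose fields hold other containers.
Params = namedtuple("Params", ["weights", "biases"])

params = Params(weights=(jnp.ones(2), jnp.ones(3)),
                biases={"hidden": jnp.zeros(2), "out": jnp.zeros(1)})
batch = [params, params]  # namedtuples nested inside a list

# Every array at every level of nesting is a leaf: 4 per Params, 2 Params.
leaves = jax.tree_util.tree_leaves(batch)
print(len(leaves))  # 8

# tree_map rebuilds the same nesting: a list of Params holding a tuple and a dict.
scaled = jax.tree_util.tree_map(lambda a: 2 * a, batch)
print(type(scaled[0]).__name__, float(scaled[0].weights[0][0]))  # Params 2.0
```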
(By the way, the extra data that the to-iterable function can return, and that the to-pytree function consumes, is for things like dict keys. A namedtuple needs none, so the to-iterable function can just return None and the to-pytree function can ignore it when reconstructing.)
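To show what that extra data is for, here is a sketch with a hypothetical dict-like Record class (not part of JAX). Its keys travel as the extra data and its values become the children. It assumes the same jax.tree_util.register_pytree_node API:

```python
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten


class Record:
    """Hypothetical dict-like container used for illustration."""

    def __init__(self, fields):
        self.fields = dict(fields)

    def __repr__(self):
        return f"Record({self.fields!r})"


def record_to_iterable(rec):
    # Sort the keys so the flattening order is deterministic. The keys go
    # into the (hashable) extra data; only the values become children.
    keys = tuple(sorted(rec.fields))
    return [rec.fields[k] for k in keys], keys


def record_from_iterable(keys, children):
    # The keys returned above come back here to rebuild the mapping.
    return Record(zip(keys, children))


register_pytree_node(Record, record_to_iterable, record_from_iterable)

leaves, treedef = tree_flatten(Record({"b": 2.0, "a": 1.0}))
print(leaves)  # [1.0, 2.0]

rebuilt = tree_unflatten(treedef, [10.0, 20.0])
print(rebuilt)  # Record({'a': 10.0, 'b': 20.0})
```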
However, we should consider making JAX work with all namedtuple classes by default, without having to register them. Any thoughts on that, or objections to it?
+1 to having JAX work with all namedtuple classes