Summing NamedTuples as if they were arrays with named axes
I heavily use NamedTuples (maybe too heavily), as I find it quite convenient to treat them as arrays with named axes.
The only problem is that some basic primitives do not work on them.
Addition does work with the default operator +, but it has a different meaning: concatenation.
Would it be possible to allow numpy operations on NamedTuples?
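The concatenation behaviour of + mentioned above comes from NamedTuple inheriting tuple.__add__. A minimal illustration, using a hypothetical Pair class with plain floats instead of arrays so that it runs without JAX:

```python
from typing import NamedTuple


class Pair(NamedTuple):  # hypothetical example class
    a: float
    b: float


p = Pair(1.0, 2.0)
# NamedTuple inherits tuple.__add__, so "+" concatenates the fields
# instead of adding them elementwise, and the result is a plain tuple.
q = p + p
print(q)        # (1.0, 2.0, 1.0, 2.0)
print(type(q))  # <class 'tuple'>
```

Note that the field names are lost as well: q is an ordinary four-element tuple, not a Pair.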
from typing import NamedTuple
import jax
import jax.numpy as jnp

class NamedArray(NamedTuple):
    a: jnp.ndarray
    b: jnp.ndarray

x = jnp.ones((2,), float)
a = NamedArray(x, x)

def add_named_array(l, r):
    return jnp.add(l, r)  # raises TypeError: a NamedArray is not array-like


print(add_named_array(a, a))
Trace:
TypeError                                 Traceback (most recent call last)
<ipython-input-8-999792ade930> in <module>()
     14
     15
---> 16 print(add_named_array(a, a))

<ipython-input-8-999792ade930> in add_named_array(l, r)
     11
     12 def add_named_array(l, r):
---> 13     return jnp.add(l, r)
     14
     15

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in fn(x1, x2)
    383 def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
    384   def fn(x1, x2):
--> 385     x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    386     return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
    387   return _wraps(numpy_fn)(fn)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _promote_args(fun_name, *args)
    320 def _promote_args(fun_name, *args):
    321   """Convenience function to apply Numpy argument shape and dtype promotion."""
--> 322   _check_arraylike(fun_name, *args)
    323   _check_no_float0s(fun_name, *args)
    324   return _promote_shapes(fun_name, *_promote_dtypes(*args))

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    304                     if not _arraylike(arg))
    305     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 306     raise TypeError(msg.format(fun_name, type(arg), pos))
    307
    308 def _check_no_float0s(fun_name, *args):

TypeError: add requires ndarray or scalar arguments, got <class '__main__.NamedArray'> at position 0.
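The error arises because jnp.add only accepts arrays and scalars. One widely used workaround, not taken from this thread, relies on JAX treating NamedTuples as pytrees: jax.tree_util.tree_map can apply jnp.add field by field and rebuild a NamedArray. The sketch below assumes a JAX version that provides jax.tree_util.tree_map. The ArithNamedArray class and its __add__ override are hypothetical illustrations of making + elementwise; they are not part of the JAX API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class NamedArray(NamedTuple):
    a: jnp.ndarray
    b: jnp.ndarray


def add_named_array(l, r):
    # NamedTuples are pytrees in JAX, so tree_map applies jnp.add to each
    # pair of matching fields and rebuilds a NamedArray from the results.
    return jax.tree_util.tree_map(jnp.add, l, r)


class ArithNamedArray(NamedTuple):  # hypothetical convenience class
    a: jnp.ndarray
    b: jnp.ndarray

    # Override tuple concatenation so that "+" adds the fields elementwise.
    # Both operands must have the same pytree structure.
    def __add__(self, other):
        return jax.tree_util.tree_map(jnp.add, self, other)


x = jnp.ones((2,), float)
a = NamedArray(x, x)
s = add_named_array(a, a)             # NamedArray with fields equal to 2.0
t = ArithNamedArray(x, x) + ArithNamedArray(x, x)
print(type(s).__name__, s.a, s.b)
print(type(t).__name__, t.a, t.b)
```

Because the result is still a NamedTuple pytree, add_named_array can also be wrapped in jax.jit. A design limitation of the __add__ override is that adding a scalar, or a NamedTuple of a different class, fails with a tree-structure mismatch; supporting those cases would need extra handling.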
Top GitHub Comments
I’m going to close for now. Let us know if other questions come up!
I’d say it’s not on the jax.numpy roadmap because such operations are not supported by NumPy.
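For comparison, plain NumPy does not treat a NamedTuple as a structure with named fields either. A minimal sketch, assuming the fields have equal shapes: np.add converts the NamedTuple like any other sequence and stacks its fields into a single array, so the names are dropped rather than preserved.

```python
from typing import NamedTuple

import numpy as np


class NamedArray(NamedTuple):
    a: np.ndarray
    b: np.ndarray


x = np.ones((2,))
a = NamedArray(x, x)

# NumPy converts the NamedTuple as an ordinary sequence, stacking its two
# (2,) fields into one (2, 2) array; the result has no field names.
out = np.add(a, a)
print(type(out), out.shape)  # <class 'numpy.ndarray'> (2, 2)
```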