segment_sum primitives / advanced indexing in jax.ops.index_add
Would it be useful to others (aside from me) to have better support for sorted and unsorted segment sums?
https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
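For reference, an unsorted segment sum computes out[k] as the sum of data[i] over every i with segment_ids[i] == k, with ids in any order. The sorted variant computes the same sums but expects segment_ids to be sorted. Below is a minimal NumPy sketch of these semantics. The function name is hypothetical, and it assumes every id lies in [0, num_segments). It is not TensorFlow's implementation.

```python
import numpy as np

def unsorted_segment_sum_reference(data, segment_ids, num_segments):
    """Hypothetical reference: out[k] = sum of data[i] where segment_ids[i] == k.

    Assumes 0 <= segment_ids[i] < num_segments; out-of-range ids are not handled.
    Segments that receive no elements stay zero.
    """
    data = np.asarray(data)
    segment_ids = np.asarray(segment_ids)
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # np.add.at is unbuffered, so repeated ids accumulate instead of overwriting
    # (plain out[segment_ids] += data would keep only one update per id).
    np.add.at(out, segment_ids, data)
    return out

data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
ids = np.array([2, 0, 2, 1, 0])
result = unsorted_segment_sum_reference(data, ids, 4)  # segment 3 is empty
```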
In NumPy, one way to compute unsorted segment sums is to sort by segment ID and then call np.add.reduceat, but this doesn't seem to be available in JAX or autograd:
>>> import jax.numpy as np
>>> np.add
<function _one_to_one_binop.<locals>.<lambda> at 0x1430cdbf8>
>>> np.add.reduceat
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'function' object has no attribute 'reduceat'
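The sort-then-reduceat approach mentioned above can be sketched in plain NumPy as follows. One subtlety is that np.add.reduceat returns array[indices[i]] rather than 0 when indices[i] >= indices[i+1], so empty segments are masked out and left at zero. The function name is hypothetical, and ids are assumed to lie in [0, num_segments).

```python
import numpy as np

def segment_sum_via_reduceat(data, segment_ids, num_segments):
    """Hypothetical sketch: unsorted segment sum via sort + np.add.reduceat.

    Assumes 0 <= segment_ids[i] < num_segments for every i.
    """
    data = np.asarray(data)
    segment_ids = np.asarray(segment_ids)
    # 1. Sort so that each segment's elements are contiguous.
    order = np.argsort(segment_ids, kind="stable")
    sorted_ids = segment_ids[order]
    sorted_data = data[order]
    # 2. [starts[k], ends[k]) is segment k's slice of the sorted array.
    seg = np.arange(num_segments)
    starts = np.searchsorted(sorted_ids, seg, side="left")
    ends = np.searchsorted(sorted_ids, seg, side="right")
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # 3. reduceat misbehaves on empty slices, so pass only non-empty starts.
    #    They are strictly increasing, and each slice up to the next start
    #    covers exactly one segment.
    nonempty = ends > starts
    if nonempty.any():
        out[nonempty] = np.add.reduceat(sorted_data, starts[nonempty], axis=0)
    return out

sums = segment_sum_via_reduceat([1.0, 2.0, 3.0, 4.0, 5.0], [2, 0, 2, 1, 0], 4)
```

Sorting costs O(n log n), so a direct scatter-add such as np.add.at avoids the sort. The reduceat route mainly matters when a vectorized sorted reduction is the only primitive available.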
Issue Analytics
- State:
- Created 4 years ago
- Reactions: 2
- Comments: 5 (5 by maintainers)
Yup, that’s the idea!
Ah, I guess it doesn't work yet because jax.ops.index_add does not support advanced indexing. We should fix that.
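In later JAX releases, jax.ops.index_add was removed in favor of the indexed-update syntax x.at[idx].add(v), which accepts integer-array (advanced) indices and accumulates duplicates. JAX now also provides jax.ops.segment_sum as a dedicated function. The sketch below assumes a recent JAX version, and its data is illustrative only.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([2, 0, 2, 1, 0])

# Scatter-add through an integer index array: duplicate indices accumulate,
# which is exactly an unsorted segment sum.
via_at = jnp.zeros(4).at[segment_ids].add(data)  # [7, 4, 4, 0]

# Advanced indexing with several index arrays: updates land on (row, col)
# pairs, and the repeated pair (0, 2) accumulates 1.0 + 100.0.
rows = jnp.array([0, 1, 0])
cols = jnp.array([2, 0, 2])
grid = jnp.zeros((2, 3)).at[rows, cols].add(jnp.array([1.0, 10.0, 100.0]))

# Dedicated function; num_segments fixes the output shape statically (needed under jit).
via_segment_sum = jax.ops.segment_sum(data, segment_ids, num_segments=4)

# Sorted variant: indices_are_sorted=True is only a hint that the ids are sorted.
order = jnp.argsort(segment_ids)
sorted_sum = jax.ops.segment_sum(
    data[order], segment_ids[order], num_segments=4, indices_are_sorted=True)

# Differentiable w.r.t. data: d out[0] / d data is 1 where segment_ids == 0.
grad0 = jax.grad(
    lambda d: jax.ops.segment_sum(d, segment_ids, num_segments=4)[0])(data)
```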