Batching rule for 'gather' not implemented
Here is a minimal example to reproduce the error:
import jax
import jax.numpy as np
a = np.ones((3, 3))
def f(a):
    inds = np.arange(a.shape[0])
    return a[inds]
jax.vmap(f)(a)
NotImplementedError: Batching rule for 'index_take' not implemented
So if you want to use vmap, the only form of indexing you can do is scalar indexing. Am I interpreting this correctly? For me, that’s a huge limitation.
I would try to implement the batching rule, but I think it will be complicated, and I have a workaround:
np.stack([a[ind] for ind in inds])
instead of a[inds]. With an ideal JIT compiler, those two would be equivalent, but my guess is that the performance of this workaround is not optimal.
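A minimal sketch comparing the two forms, assuming a current JAX release (where vmap can batch integer-array indexing, which was not yet the case when this issue was opened). The function names f_indexed and f_stacked are hypothetical, and the input uses distinct values so that the comparison is meaningful.

```python
import jax
import jax.numpy as jnp


def f_indexed(a):
    # Advanced integer-array indexing: one vectorized indexing
    # operation covering all indices at once.
    inds = jnp.arange(a.shape[0])
    return a[inds]


def f_stacked(a):
    # The workaround: one scalar index per element, then stack the results.
    # This traces a separate indexing op for each element, so the traced
    # program grows linearly with a.shape[0].
    inds = jnp.arange(a.shape[0])
    return jnp.stack([a[ind] for ind in inds])


# Assumption: 3 rows of length 3 with distinct values.
a = jnp.arange(9.0).reshape(3, 3)

# vmap maps over axis 0, so inside each function `a` has shape (3,).
out_indexed = jax.vmap(f_indexed)(a)
out_stacked = jax.vmap(f_stacked)(a)
print(out_indexed)
```

Both produce the same values; the difference is in program size, since the stacked form emits one indexing operation per element rather than a single vectorized one.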
Issue Analytics
- Created 5 years ago
- Comments:8 (5 by maintainers)
Since #307 removed index_take as a primitive, running the OP’s code now fails with a NotImplementedError saying that the batching rule for 'gather' is not implemented.
I changed the title of the issue to reflect the new topic: we want a batching rule for lax.gather.
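As a sketch of what "a batching rule for lax.gather" refers to, assuming a current JAX release: NumPy-style integer-array indexing is traced into the gather primitive, which jax.make_jaxpr makes visible. The same row selection can also be written directly with lax.gather. The dimension numbers below are one hand-chosen configuration for selecting rows, not necessarily what jax.numpy generates internally, and the helper take_rows is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def take_rows(a, inds):
    # NumPy-style integer-array indexing on a JAX array.
    return a[inds]


a = jnp.arange(9.0).reshape(3, 3)
inds = jnp.array([2, 0, 1])

# Trace the function; the printed jaxpr lists the primitives it uses.
# The exact output varies between JAX versions, but it contains a gather.
jaxpr = jax.make_jaxpr(take_rows)(a, inds)
print(jaxpr)

# The same row selection written directly against lax.gather.
# Each start index picks one row: operand dimension 0 is indexed
# (start_index_map) and then dropped from the output (collapsed_slice_dims),
# while the full row of length 3 becomes output dimension 1 (offset_dims).
dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),
    collapsed_slice_dims=(0,),
    start_index_map=(0,),
)
# lax.gather expects the index vector in the last dimension, hence shape (3, 1).
explicit = lax.gather(
    a,
    inds[:, None],
    dimension_numbers=dnums,
    slice_sizes=(1, 3),
)
print(explicit)
```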
Whew! We finally got this one done in #351. The challenge was that XLA’s Gather HLO is a bit complex; in fact, it supports more kinds of fancy indexing than NumPy can even express. But we finally completed the batching rule for it, and as a result we can now vmap through all of NumPy’s advanced integer array indexing.
(The story isn’t totally over until we get #353 in too, but that should get in momentarily.)
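A minimal sketch of what the completed batching rule enables, assuming a current JAX release: vmap through integer-array indexing where the indices themselves are batched, either alongside a batched array or against one shared array selected with in_axes. The helper take_rows and the example shapes are hypothetical.

```python
import jax
import jax.numpy as jnp


def take_rows(a, inds):
    # Select rows of a 2-D array with an integer index array.
    return a[inds]


# Case 1: batch over both arguments (leading axis of each).
# a_batch holds two 3x3 matrices; inds_batch holds one index vector per matrix.
a_batch = jnp.arange(18.0).reshape(2, 3, 3)
inds_batch = jnp.array([[2, 0, 1], [1, 1, 0]])
out_both = jax.vmap(take_rows)(a_batch, inds_batch)

# Case 2: share one matrix and batch only the indices.
# in_axes=(None, 0) tells vmap not to map over the first argument.
m = jnp.arange(9.0).reshape(3, 3)
out_shared = jax.vmap(take_rows, in_axes=(None, 0))(m, inds_batch)

print(out_both.shape, out_shared.shape)
```

Each batch element behaves like the unbatched call, for example out_both[1] equals a_batch[1][inds_batch[1]].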