
Batching rule for 'gather' not implemented


Here is a minimal example to reproduce the error.

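import jax
import jax.numpy as np

a = np.ones((3, 3))

def f(a):
    # Integer-array ("fancy") indexing of the per-example row.
    inds = np.arange(a.shape[0])
    return a[inds]

# vmap maps f over the leading axis, so each call sees one row of shape (3,).
jax.vmap(f)(a)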

NotImplementedError: Batching rule for 'index_take' not implemented

So if you want to use vmap, the only form of indexing you can do is scalar indexing. Am I interpreting this correctly? For me that’s a huge limitation.

I would try to implement the batching rule myself, but I think it would be complicated, and I have a workaround: np.stack([a[ind] for ind in inds]) instead of a[inds]. With an ideal JIT compiler the two would be equivalent, but my guess is that the workaround's performance is not optimal.
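For comparison, here is a minimal sketch of both forms under jax.vmap, assuming a recent JAX release in which the gather batching rule discussed further down is available (the original report predates it). The names f_fancy and f_stack are hypothetical and used only for this illustration. Both produce the same values; the difference is that the stacked version traces one indexing operation per element plus a stack, instead of a single vectorized indexing operation.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(9.0).reshape(3, 3)

def f_fancy(row):
    # Advanced integer-array indexing: a single indexing operation.
    inds = jnp.arange(row.shape[0])
    return row[inds]

def f_stack(row):
    # The workaround from the issue: index one element at a time, then stack.
    inds = jnp.arange(row.shape[0])
    return jnp.stack([row[ind] for ind in inds])

out_fancy = jax.vmap(f_fancy)(a)
out_stack = jax.vmap(f_stack)(a)
print(jnp.array_equal(out_fancy, out_stack))  # True
```

Because inds is just arange over the row, both results also equal a itself; the example only checks that the two formulations agree.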

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Comments: 8 (5 by maintainers)

Top GitHub Comments

mattjj commented, Feb 3, 2019 (1 reaction)

Since #307 removed index_take as a primitive, running the OP’s code now results in

NotImplementedError: Batching rule for 'gather' not implemented

I changed the title of the issue to reflect the new topic: we want a batching rule for lax.gather!
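To see why the error now names gather, one can inspect the traced program. This sketch assumes a recent JAX release and uses jax.make_jaxpr to show which primitives x[inds] traces to. The exact jaxpr text varies between versions (there may be extra operations, for example for negative-index handling), so only the presence of gather is checked.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((3, 3))
inds = jnp.arange(3)

# Trace NumPy-style integer-array indexing into a jaxpr.
jaxpr = jax.make_jaxpr(lambda x, i: x[i])(x, inds)
print(jaxpr)

# The indexing expression is expressed in terms of the lax.gather primitive.
print("gather" in str(jaxpr))  # True

# Tracing the same expression under vmap (batching over the rows of x, indices
# shared across the batch) also yields a gather; producing it is the job of
# gather's batching rule.
batched = jax.make_jaxpr(jax.vmap(lambda row, i: row[i], in_axes=(0, None)))(x, inds)
print("gather" in str(batched))  # True
```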

mattjj commented, Feb 11, 2019

Whew! We finally got this one done in #351. The challenge was that XLA’s Gather HLO is a bit complex. In fact, it supports more fancy indexing than you can even express in NumPy. But we finally completed the batching rule for it, and as a result we can vmap through all of NumPy’s advanced integer array indexing now.
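As an illustration of what the completed rule enables, here is a minimal sketch, assuming a current JAX release, in which both the array and the per-example integer indices are batched. The result is checked against the equivalent single NumPy advanced-indexing expression. The function name pick and the example data are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as onp  # plain NumPy, used only for the reference result

x = jnp.arange(12.0).reshape(3, 4)         # 3 examples, 4 features each
idx = jnp.array([[0, 3], [1, 1], [2, 0]])  # 2 integer indices per example

def pick(row, ind):
    # Per-example advanced integer-array indexing.
    return row[ind]

# vmap batches both the operand and the indices along axis 0.
out = jax.vmap(pick)(x, idx)
print(out)  # values: [[0, 3], [5, 5], [10, 8]]

# Equivalent NumPy expression: pair each row number with that row's indices.
expected = onp.asarray(x)[onp.arange(3)[:, None], onp.asarray(idx)]
print(onp.array_equal(onp.asarray(out), expected))  # True
```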

(The story isn’t totally over until we get #353 in too, but that should get in momentarily.)

