
np.arange weirdness with jit.


I found an issue related to jit compilation with np.arange when the array it returns is indexed dynamically. In particular, the code below produces an error:

import jax.numpy as np
from jax import jit

def fun(x):
  r = np.arange(x.shape[0])[x]
  return r

jit(fun)(np.array([0, 1, 2], dtype=np.int32))
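According to the maintainer's reply in this thread, the failure comes from dispatching on a raw NumPy ndarray: in the JAX version of the time, np.arange with concrete arguments apparently produced a plain NumPy array, and a plain NumPy array cannot be indexed by the traced x. The sketch below isolates that difference. It assumes a recent JAX release, where jnp.arange returns a jax.Array; the function names index_raw_numpy and index_jax_array are hypothetical.

```python
import numpy as onp
import jax
import jax.numpy as jnp


def index_raw_numpy(x):
    # onp.arange builds a plain NumPy ndarray; NumPy cannot turn the
    # traced `x` into an index array, so tracing raises an error.
    return onp.arange(x.shape[0])[x]


def index_jax_array(x):
    # jnp.arange builds a JAX array, whose indexing accepts traced values.
    return jnp.arange(x.shape[0])[x]


x = jnp.array([0, 1, 2], dtype=jnp.int32)

try:
    jax.jit(index_raw_numpy)(x)
    raw_failed = False
except Exception as err:  # exact exception type varies across JAX/NumPy versions
    raw_failed = True
    print("raw ndarray failed with", type(err).__name__)

result = jax.jit(index_jax_array)(x)
print(result)
```

The only difference between the two functions is which library builds the array being indexed, which is the "raw ndarray" distinction the maintainer points to.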

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 6 (5 by maintainers)

Top GitHub Comments

mattjj commented, Mar 31, 2019

You can probably also call np.take here. So long as we’re not dispatching on a raw ndarray (that is, indexing a plain NumPy array with a traced value), things are okay.
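A minimal sketch of the np.take suggestion, assuming a recent JAX release with jax.numpy imported as jnp; take_fun is a hypothetical name. The lookup goes through a jax.numpy function instead of ndarray indexing, so the gather is performed by JAX.

```python
import jax
import jax.numpy as jnp


def take_fun(x):
    # jnp.take performs the gather inside JAX, so the traced `x`
    # never reaches NumPy's own indexing machinery.
    return jnp.take(jnp.arange(x.shape[0]), x)


result = jax.jit(take_fun)(jnp.array([0, 1, 2], dtype=jnp.int32))
print(result)
```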

sschoenholz commented, Mar 31, 2019

Ah, thanks Matt, that makes sense! It turns out one can also do:

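import jax.numpy as np
from jax import device_put, jit

def fun(x):
  # device_put moves the arange result onto the device as a JAX array,
  # which can then be indexed by the traced x.
  r = device_put(np.arange(x.shape[0]))[x]
  return r

jit(fun)(np.array([0, 1, 2], dtype=np.int32))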

which might be useful, if less efficient, in cases where the more advanced functionality of np.arange is required.
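As an illustration of the kind of case this workaround targets, here is a sketch that uses the start/stop/step form of arange with a floating-point step and then indexes the result with a traced array. It assumes a recent JAX release; stepped_lookup and the specific values are hypothetical. With current JAX the device_put call is harmless but should no longer be needed for the indexing to trace, since jnp.arange already returns a jax.Array.

```python
import jax
import jax.numpy as jnp


def stepped_lookup(idx):
    # arange with start, stop, and a float step: [0.0, 0.25, 0.5, 0.75].
    # The bounds are Python constants, so the shape is static under jit.
    table = jax.device_put(jnp.arange(0.0, 1.0, 0.25))
    # Index the device array with the traced integer indices.
    return table[idx]


result = jax.jit(stepped_lookup)(jnp.array([3, 0], dtype=jnp.int32))
print(result)
```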

