deepcopy of jax.numpy array generates regular numpy array
While updating some existing code that used copy.deepcopy, I got a rather unexpected result: deepcopy returns a regular numpy array when I pass in a jax.numpy array. Is this expected behaviour? If so, perhaps a note in the docs indicating this would be good.
The following is a snippet that replicates the problem.
With jax.version.__version__ == '0.1.62':

```python
import copy

import jax.numpy as np

array_in = np.array([1, 2, 3], dtype=np.int32)
array_out = copy.deepcopy(array_in)
print("type in:" + str(type(array_in)) + "type out:" + str(type(array_out)))
```

gives:

```
type in:<class 'jax.interpreters.xla.DeviceArray'>type out:<class 'numpy.ndarray'>
```
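The type change can be reproduced without JAX. When a class defines no __deepcopy__, copy.deepcopy falls back to __reduce_ex__, which calls the class's own __reduce__, and whatever that recipe rebuilds is what comes back. The sketch below assumes, as a maintainer notes in this thread, that the array class forwards pickling to NumPy's __reduce__. ToyDeviceArray is a hypothetical stand-in for DeviceArray, not JAX code.

```python
import copy
import pickle

import numpy as np


class ToyDeviceArray:
    """Hypothetical stand-in for DeviceArray: wraps a host NumPy array."""

    def __init__(self, value):
        self._value = np.asarray(value)

    def __reduce__(self):
        # Forward to NumPy's __reduce__. The reconstruction recipe it returns
        # rebuilds a plain numpy.ndarray, not a ToyDeviceArray.
        return self._value.__reduce__()


x = ToyDeviceArray([1, 2, 3])
# No __deepcopy__ is defined, so deepcopy uses __reduce_ex__ -> __reduce__.
print(type(copy.deepcopy(x)))               # <class 'numpy.ndarray'>
print(type(pickle.loads(pickle.dumps(x))))  # <class 'numpy.ndarray'>
```

Both copy.deepcopy and pickle follow the same recipe, which is why the two problems show up together.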
Issue Analytics
- Created 3 years ago
- Comments: 13 (6 by maintainers)
There are at least two surprising behaviors here that should be fixed:
1. Pickling DeviceArray objects should return another DeviceArray. To do so, we'll need to define our own methods for pickling instead of using NumPy's __reduce__ directly.
2. DeviceArray should define __deepcopy__ instead of relying on Python's default behavior. (There is some ambiguity about whether we need to copy device buffers with deepcopy, given that JAX arrays themselves are immutable, but I think doing so would be within the spirit of deep copy.)

@tawe141 I'm not sure what's going on with your example. I believe deepcopy() on JAX tracers works just fine (preserving gradients); the surprising implementation of copy only arises for arrays that are actually loaded into memory (DeviceArray).

Fixed in #10659