
deepcopy of jax.numpy array generates regular numpy array

See original GitHub issue

While updating some existing code that used copy.deepcopy, I got a rather unexpected result: deepcopy returns a regular NumPy array when passed a jax.numpy array. Is this expected behaviour? If so, perhaps a note in the docs indicating this would be good.

The following is a snippet that replicates the problem.

With jax.version.__version__ == '0.1.62':

import copy
import jax.numpy as np

array_in = np.array([1, 2, 3], dtype=np.int32)
array_out = copy.deepcopy(array_in)

print("type in:" + str(type(array_in)) + "type out:" + str(type(array_out)))

gives:

type in:<class 'jax.interpreters.xla.DeviceArray'>type out:<class 'numpy.ndarray'>
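Why the type changes can be sketched in plain Python. When a class defines no __deepcopy__, copy.deepcopy falls back to the object's __reduce_ex__/__reduce__ and rebuilds whatever that returns. The sketch below assumes, as the maintainers' comments indicate, that DeviceArray reused NumPy's __reduce__; HostArray and DeviceArrayLike are hypothetical stand-ins, not JAX classes.

```python
import copy


class HostArray(list):
    """Hypothetical stand-in for numpy.ndarray (an array in host memory)."""


class DeviceArrayLike:
    """Hypothetical stand-in for the old DeviceArray, which borrowed the
    host array's pickling protocol instead of defining its own."""

    def __init__(self, values):
        self._values = list(values)

    def __reduce__(self):
        # Reduction converts to the host type, so anything rebuilt from this
        # tuple is a HostArray, not a DeviceArrayLike.
        return (HostArray, (self._values,))


x = DeviceArrayLike([1, 2, 3])
# No __deepcopy__ is defined, so copy.deepcopy falls back to __reduce_ex__,
# which calls the overridden __reduce__ above.
y = copy.deepcopy(x)
print(type(x).__name__, type(y).__name__)  # DeviceArrayLike HostArray
```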

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 13 (6 by maintainers)

Top GitHub Comments

1 reaction
shoyer commented, Jan 6, 2022

There are at least two surprising behaviors here that should be fixed:

  • Pickling/unpickling a JAX DeviceArray object should return another DeviceArray. To do so, we'll need to define our own methods for pickling instead of using NumPy's __reduce__ directly.
  • Deep-copying JAX arrays should possibly copy device buffers, but certainly should not do so by serializing to the host with pickle. Here we also need to define __deepcopy__ instead of relying on Python's default behavior. (There is some ambiguity about whether deepcopy needs to copy device buffers at all, given that JAX arrays themselves are immutable, but I think doing so would be within the spirit of deep copy.)
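Both fixes proposed above can be sketched with hypothetical toy classes (HostArray standing in for numpy.ndarray, DeviceArrayLike for a device array): a custom __reduce__ that rebuilds the device type, and a __deepcopy__ that skips the host round trip. This illustrates the idea under those assumptions; it is not the code that landed in JAX.

```python
import copy


class HostArray(list):
    """Hypothetical stand-in for numpy.ndarray (host memory)."""


class DeviceArrayLike:
    """Hypothetical device array that defines its own copy/pickle hooks."""

    def __init__(self, values):
        self._values = tuple(values)  # immutable, like a JAX array

    def to_host(self):
        # Hypothetical helper: materialize the data as a host array.
        return HostArray(self._values)

    def __reduce__(self):
        # Serialize via host data, but rebuild a DeviceArrayLike on load,
        # so pickle/unpickle (and copy.copy) preserve the type.
        return (DeviceArrayLike, (self.to_host(),))

    def __deepcopy__(self, memo):
        # Build the copy directly, with no host round trip. A real device
        # array might allocate a fresh device buffer here; in this toy the
        # immutable tuple is simply shared.
        result = DeviceArrayLike(self._values)
        memo[id(self)] = result
        return result

    def __eq__(self, other):
        return isinstance(other, DeviceArrayLike) and self._values == other._values

    def __repr__(self):
        return f"DeviceArrayLike({list(self._values)})"


x = DeviceArrayLike([1, 2, 3])
deep = copy.deepcopy(x)   # uses __deepcopy__
shallow = copy.copy(x)    # no __copy__, so this goes through __reduce__
print(type(deep).__name__, type(shallow).__name__)  # DeviceArrayLike DeviceArrayLike
```

copy.copy is used to exercise __reduce__ because, without a __copy__ method, it relies on the same reduction hook that pickle uses; pickle.loads(pickle.dumps(x)) would likewise return a DeviceArrayLike.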

@tawe141 I'm not sure what's going on with your example. I believe deepcopy() on JAX tracers works just fine (preserving gradients); the surprising implementation of copy only arises for arrays that are actually loaded into memory (DeviceArray).

0 reactions
jakevdp commented, May 23, 2022

Fixed in #10659


