what's a good way to save jax arrays?
hi all,
thanks for making such a cool library! I was wondering how I should go about saving jax arrays to file, since np.save isn’t yet implemented. Apologies if I missed something!
Rowan
Issue Analytics
- Created 4 years ago
- Comments: 5 (2 by maintainers)

Thanks for the question!

JAX’s device-backed ndarray class, DeviceArray, is effectively a subclass of numpy.ndarray. You can turn it into a regular ndarray using numpy.asarray. DeviceArrays also try to turn themselves into ndarrays automatically when appropriate, so you can use NumPy’s save on them directly.

We could also override __getstate__ so that they work with pickle. (We did that in an older version of JAX, but it looks like we lost that code at some point.)

I think we should probably bring numpy.save into jax.numpy, so that you don’t have to import raw NumPy (as onp here) to get the same saving behavior… Let’s leave this issue open until we either add jax.numpy.save or else figure out there’s some reason not to.

In the meantime, does import numpy as onp and onp.save work for you?

Thanks for the response! Sorry, I’m still learning about the library and didn’t know that you could convert DeviceArrays into regular NumPy arrays like that – that helps a lot and solves my problem (and will hopefully help others if they search for it) 😃
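
For anyone landing here from a search, here is a minimal sketch of the approach described in the answer: converting a JAX array with numpy.asarray, and passing it straight to NumPy's save. The array contents, the file name x.npy, and the temporary directory are illustrative assumptions, not part of the original thread. The onp alias follows the convention mentioned above.

```python
import os
import tempfile

import jax.numpy as jnp
import numpy as onp  # "original" NumPy, using the alias from the thread

# A small JAX array to save (values are arbitrary, for illustration only).
x = jnp.arange(6.0).reshape(2, 3)

# Explicit conversion to a regular NumPy ndarray.
x_host = onp.asarray(x)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "x.npy")  # hypothetical file name

    # NumPy's save accepts the JAX array directly and converts it itself.
    onp.save(path, x)

    # Loading returns a NumPy ndarray; jnp.asarray turns it back into a JAX array.
    loaded = onp.load(path)
    x_restored = jnp.asarray(loaded)

# True if the saved-and-loaded data matches the explicitly converted array.
roundtrip_ok = bool(onp.array_equal(x_host, loaded))
```

Note that the loaded value is a plain NumPy array; the final jnp.asarray call is only needed if you want to keep working with it in JAX.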