What's a good way to save JAX arrays?
Hi all,
Thanks for making such a cool library! I was wondering how I should go about saving JAX arrays to a file, since `np.save` (i.e. `jax.numpy.save`) isn't implemented yet. Apologies if I missed something!
Rowan
(Issue created 4 years ago; 5 comments, 2 by maintainers.)
Thanks for the question!
JAX's device-backed ndarray class, `DeviceArray`, is effectively a subclass of `numpy.ndarray`. You can turn it into a regular ndarray using `numpy.asarray`. `DeviceArray`s also try to turn themselves into ndarrays automatically when appropriate, so you can use NumPy's `save` on them directly.

We could also override `__getstate__` so that they work with pickle. (We did that in an older version of JAX, but it looks like we lost that code at some point.)

I think we should probably bring `numpy.save` into `jax.numpy`, so that you don't have to import raw NumPy (as `onp`) to get the same saving behavior. Let's leave this issue open until we either add `jax.numpy.save` or figure out there's some reason not to.

In the meantime, does using `import numpy as onp` and `onp.save` work for you?

Thanks for the response! Sorry, I'm still learning about the library and didn't know that you could convert `DeviceArray`s into regular NumPy arrays like that. That helps a lot and solves my problem (and will hopefully help others if they search for it) 😃
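A minimal sketch of the workaround discussed in this thread: convert the JAX array to a host NumPy array with `numpy.asarray`, write it with `numpy.save`, then load it and move it back into JAX with `jax.numpy.asarray`. It assumes a recent JAX release, where `DeviceArray` has been replaced by `jax.Array`, and a writable temporary directory. The helper names `save_jax_array` and `load_jax_array` are hypothetical, chosen for illustration only.

```python
import os
import tempfile

import jax.numpy as jnp
import numpy as onp  # raw NumPy, imported as `onp` as in the thread


def save_jax_array(path, x):
    """Hypothetical helper: write a JAX array to a .npy file via NumPy."""
    # onp.asarray brings the data into host memory as a plain ndarray.
    # onp.save(path, x) would also work, since NumPy converts x itself.
    onp.save(path, onp.asarray(x))


def load_jax_array(path):
    """Hypothetical helper: read a .npy file back into a JAX array."""
    return jnp.asarray(onp.load(path))


x = jnp.arange(6.0).reshape(2, 3)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "x.npy")
    save_jax_array(path, x)
    y = load_jax_array(path)

print(y.shape)  # (2, 3)
```

Because the file is an ordinary `.npy` file, it can be read back with NumPy alone; `jnp.asarray` is only needed to bring the data back into JAX.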