
what's a good way to save jax arrays?


hi all,

thanks for making such a cool library! I was wondering how I should go about saving jax arrays to file, since np.save isn’t yet implemented. Apologies if I missed something!

Rowan

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

6 reactions
mattjj commented, May 14, 2019

Thanks for the question!

JAX’s device-backed ndarray class, DeviceArray, is effectively a subclass of numpy.ndarray. You can turn it into a regular ndarray using numpy.asarray, like this:

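import numpy as onp
import jax.numpy as jnp
device_array = jnp.arange(4.0)           # any JAX array
numpy_array = onp.asarray(device_array)  # copy it into a regular NumPy ndarray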

DeviceArrays try to turn themselves into ndarrays automatically when appropriate, so you can also use NumPy's save on them directly:

onp.save('array.npy', device_array)  # NumPy converts the DeviceArray to an ndarray before writing
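A minimal round-trip sketch of the approach above, assuming a current JAX install where `jax.numpy.asarray` turns a loaded NumPy array back into a JAX array. The helper name `save_and_reload`, the temporary directory, and the file name are illustrative choices, not part of the original answer:

```python
import os
import tempfile

import numpy as onp
import jax.numpy as jnp


def save_and_reload(x):
    """Hypothetical helper: save a JAX array as .npy, then load it back as a JAX array."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "array.npy")
        # Converting explicitly makes the device-to-host copy visible.
        onp.save(path, onp.asarray(x))
        # onp.load reads the whole file into a plain numpy.ndarray.
        loaded = onp.load(path)
    # Move the data back into a JAX array on the default device.
    return jnp.asarray(loaded)


x = jnp.linspace(0.0, 1.0, 5)
y = save_and_reload(x)
print(bool(jnp.array_equal(x, y)))
```

The file on disk is an ordinary `.npy` file, so any NumPy-based tool can read it without JAX installed.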

We could also override __getstate__ so that they work with pickle. (We did that in an older version of JAX, but it looks like we lost that code at some point.)
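Whether a JAX array pickles directly depends on the JAX version. A version-independent sketch, assuming it is acceptable to pickle a host-side NumPy copy instead of the device-backed array, looks like this; `dumps_jax` and `loads_jax` are hypothetical helper names:

```python
import pickle

import numpy as onp
import jax.numpy as jnp


def dumps_jax(x):
    # Hypothetical helper: pickle a NumPy copy rather than the JAX array itself.
    return pickle.dumps(onp.asarray(x))


def loads_jax(data):
    # Hypothetical helper: unpickle the NumPy array and convert it back to JAX.
    return jnp.asarray(pickle.loads(data))


x = jnp.arange(6).reshape(2, 3)
restored = loads_jax(dumps_jax(x))
```

Because the pickled payload is a plain `numpy.ndarray`, it can also be unpickled in an environment without JAX.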

I think we should probably bring numpy.save into jax.numpy, so that you don’t have to import raw NumPy (as onp here) to get the same saving behavior… Let’s leave this issue open until we either add jax.numpy.save or else figure out there’s some reason not to.

In the meantime, does import numpy as onp and onp.save work for you?
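JAX's documentation now includes `jax.numpy.save`, which writes the standard NumPy `.npy` format, together with `jax.numpy.load` for reading it back as a JAX array. A short sketch, assuming a JAX version that provides both functions; the temporary directory and the file name `eye.npy` are illustrative:

```python
import os
import tempfile

import jax.numpy as jnp

x = jnp.eye(3)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "eye.npy")
    jnp.save(path, x)   # writes an ordinary NumPy .npy file
    y = jnp.load(path)  # reads it back and returns a JAX array
```

With this, no separate `import numpy as onp` is needed just for saving.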

2 reactions
rowanz commented, May 14, 2019

Thanks for the response! Sorry, I’m still learning about the library and didn’t know that you could convert DeviceArrays into regular numpy arrays like that – that helps a lot and solves my problem (and will hopefully help others if they search for it) 😃


Top Results From Across the Web

jax.numpy.save - JAX documentation - Read the Docs
Save an array to a binary file in NumPy .npy format. Parameters. file (file, str, or pathlib.Path) – File or filename to which...

TF_JAX_tutorials - Part 4 (JAX and DeviceArray) | Kaggle
JAX arrays are immutable, just like TensorFlow tensors. Meaning, JAX arrays don't support item assignment as you do in ndarray. Let's take...

Is there a way to speed up indexing a vector with JAX?
The short answer: to speed things up in JAX, use jit . The long answer: You should generally expect single operations using JAX...

TUTORIAL / Eric Ma / Magical NumPy with JAX - YouTube
The greatest contribution of the age the decade in which deep learning exploded was not these big models, but a generalized toolkit to...

How to convert NumPy array to list ? - GeeksforGeeks
We can convert the Numpy array to the list by 2 different methods, we can have a list of data elements that is...
