Moving data between CPU/GPU
Apologies if this is a silly question, but I couldn't find anything in the documentation about this.
I have a pipeline where the first N steps need to run on the GPU: they batch efficiently, and there's fast cuDNN code for them. The last step is something like an outer product, where the resulting matrix won't fit in GPU memory but will fit in CPU memory.
Currently I cast the output of my JAX pipeline that runs on the GPU to a regular NumPy array and use a standard NumPy dot product. This works, but it seems somewhat kludgey. More generally, any time I work with this big object I need to make sure I use `onp` instead of `np`, or I get a GPU out-of-memory error.
A previous iteration of this code, written in PyTorch, would run the GPU code, call `.numpy()`, and proceed with the regular CPU portion of the pipeline.
For what it's worth, `jax.devices()` outputs only one device (the GPU).
Is this the best way to go about this right now?
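One standard way to do the hand-off described above is `jax.device_get`, which copies an array into host memory as a plain NumPy `ndarray`. A minimal sketch (the small placeholder array stands in for the real GPU pipeline output, which is an assumption here):

```python
import jax
import jax.numpy as jnp
import numpy as onp  # plain NumPy, imported separately from jax.numpy

# Hypothetical stand-in for the GPU pipeline's output; the real array
# would come from the batched cuDNN-backed steps.
x = jnp.ones((4, 3))

# jax.device_get copies the array into host memory and returns a
# regular onp.ndarray, so everything downstream allocates CPU memory
# rather than device memory.
x_host = jax.device_get(x)

# The final outer-product-like step can then run in plain NumPy.
result = onp.dot(x_host, x_host.T)
```

`onp.asarray(x)` achieves much the same thing; either way, the big result matrix is allocated by NumPy on the host rather than by JAX on the GPU.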
Issue Analytics
- State:
- Created 3 years ago
- Reactions: 3
- Comments: 9
I think this piece of code should appear in the FAQ, to clarify how to actually find all devices of interest:
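The snippet itself is not preserved here, but a sketch of listing devices per backend might look like this (the variable name is illustrative):

```python
import jax

# jax.devices() with no argument lists only the default backend's
# devices (e.g. just the GPU when one is present). Passing a backend
# name lists that platform's devices explicitly.
cpu_devices = jax.devices("cpu")  # the CPU backend is always available
print(cpu_devices)
```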
Before reading this I was under the impression that `jax.devices()` should return all available devices; I didn't even realize it could accept an argument.

I just tried it; I had to add `jax.device_put`. It prints:
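The original code and output are not preserved, but a hedged sketch of pinning an array to the CPU with `jax.device_put` might look like:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0)

# Without an explicit device, arrays live on the default backend
# (the GPU when one is present). device_put pins the array to a
# chosen device instead.
x_cpu = jax.device_put(x, jax.devices("cpu")[0])

# On recent JAX versions, jax.Array.devices() reports placement.
print(x_cpu.devices())
```

Combined with `jax.devices("cpu")`, this keeps the large final array on the host while the earlier pipeline steps stay on the GPU.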