
Moving data between CPU/GPU

See original GitHub issue

Apologies if this is a silly question, but I couldn’t find anything in the documentation about this.

So I have a pipeline where the first N steps need to be executed on the GPU, since they batch efficiently and there's fast cuDNN code for them.

The last step is something like an outer product, where the resulting matrix won't fit in GPU memory but will fit in my CPU memory.

Currently I am casting the output of my JAX pipeline that runs on the GPU to a regular NumPy array and using a standard NumPy dot product, which works but seems somewhat kludgy. More generally, any time I work with this big object I need to make sure I use onp instead of np, or I get a GPU out-of-memory error.

The previous iteration of my code, written in PyTorch, would run the GPU code, call .numpy(), and then proceed with the regular CPU portion of the pipeline.

For what it's worth, jax.devices() outputs only one device (the GPU).

Is this the best way to go about this right now?
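For concreteness, the PyTorch-style .numpy() handoff has a direct JAX analogue. This is a minimal sketch, not from the issue itself, assuming the question's `import numpy as onp` convention: `jax.device_get` copies a device array into host memory as a plain NumPy array, after which any NumPy operation on it stays on the CPU.

```python
import jax
import jax.numpy as jnp
import numpy as onp  # "original" NumPy, following the question's naming

# The batched part of the pipeline runs under jit (on the GPU when one is present).
x = jax.jit(lambda a: a * 2.0)(jnp.arange(4.0))

# Copy the result to host memory; the returned object is a plain onp.ndarray,
# so the subsequent outer product allocates in CPU RAM, not on the device.
x_host = jax.device_get(x)
big = onp.outer(x_host, x_host)
```

Calling `onp.asarray(x)` on a JAX array performs the same device-to-host copy implicitly.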

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 3
  • Comments: 9

Top GitHub Comments

8 reactions
cgarciae commented, Nov 16, 2020

I think this piece of code should appear in the FAQ to clarify how to actually find all devices of interest:

cpus = jax.devices("cpu")
gpus = jax.devices("gpu")

Before reading this I was under the impression that jax.devices() should return all available devices; I didn't even realize it could accept an argument.

8 reactions
gnecula commented, Apr 27, 2020

I just tried it, I had to add jax.device_put:

cpus = jax.devices("cpu")
gpus = jax.devices("gpu")

x = jax.jit(lambda x: x * 2., device=gpus[0])(1.)
print(x, x.device_buffer.device())
z = jax.device_put(x, cpus[0])
print(z, z.device_buffer.device())
y = jax.jit(lambda x: x * 11., device=cpus[0])(z)
print(y, y.device_buffer.device())

prints

2.0 gpu:0
2.0 cpu:0
22.0 cpu:0
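The round trip above can be exercised on any machine by targeting the CPU backend explicitly. A sketch assuming only that a CPU device exists (`jax.devices("gpu")` raises on a machine without one); note that in recent JAX versions the deprecated `x.device_buffer.device()` introspection is replaced by `x.devices()`:

```python
import jax

cpus = jax.devices("cpu")  # the CPU backend is always available

# Compute a value (on the default backend), then pin it to the first CPU device.
x = jax.jit(lambda v: v * 2.0)(1.0)
z = jax.device_put(x, cpus[0])

# z now lives in host memory; its device set reports the CPU platform.
print(float(z), next(iter(z.devices())).platform)
```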
