jax.random.split uses extra memory before preallocated memory is used up


I monitored GPU memory usage with nvidia-smi. Whenever I run

jax.random.split()

JAX allocates additional memory, even when none of the preallocated memory has been used yet. This keeps causing OOM errors, because JAX has already preallocated 90% of GPU memory by default.
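A minimal sketch of how the report above might be reproduced, assuming a recent JAX install. The `report` helper is hypothetical, and `memory_stats()` reflects only what JAX's own allocator tracks, so its numbers will not necessarily match the process-level usage that nvidia-smi shows; on backends that expose no stats (such as CPU) it may return nothing.

```python
import jax


def report(label):
    """Hypothetical helper: print the allocator's view of memory in use."""
    # memory_stats() may return None on backends without allocator stats (e.g. CPU).
    stats = jax.local_devices()[0].memory_stats()
    in_use = stats.get("bytes_in_use") if stats else None
    print(f"{label}: bytes_in_use={in_use}")
    return in_use


# Creating a key is the first JAX operation, so preallocation happens here.
key = jax.random.PRNGKey(0)
report("after PRNGKey")

# Split the key into 4 subkeys and wait for the computation to finish.
subkeys = jax.random.split(key, 4)
subkeys.block_until_ready()
report("after split")
```

Running this while watching nvidia-smi in a second terminal gives both views side by side: the allocator's own accounting and the total memory held by the process.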

Issue Analytics

  • State: open
  • Created a year ago
  • Reactions: 2
  • Comments: 6 (1 by maintainers)

Top GitHub Comments

2 reactions
onurdanaci commented, Aug 31, 2022

I have been experiencing exactly the same issue on Ubuntu 22.04, with graphics card driver CUDA 11.7, nvcc CUDA version 11.2, and cuDNN 8.2.

Merely creating a random key in JAX costs 3900 MB of GPU memory. And every time you split that key, it costs 3900^(n_splits) of GPU memory: an exponential memory blow-up!


0 reactions
SamTov commented, Dec 2, 2022

I have the same issue. No matter what I set the memory fraction to, this will not run.
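For reference, the memory fraction mentioned here is normally controlled through environment variables described in the JAX GPU memory allocation documentation. The sketch below shows how they are typically set; the value 0.5 is only an illustrative assumption, and nothing in this thread confirms that either setting resolves the issue.

```python
import os

# These variables must be set before JAX initializes its GPU backend,
# so they are assigned before the import.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand instead

# Alternative: keep preallocation but cap the fraction of GPU memory.
# The value 0.5 is an arbitrary example, not a recommendation.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)  # the call reported in this issue
print(key.shape, subkey.shape)
```

The same variables can also be exported in the shell before launching Python, which avoids depending on import order inside the script.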


Top Results From Across the Web

GPU memory allocation - JAX documentation - Read the Docs
JAX will preallocate 90% of the total GPU memory when the first JAX operation is ... GPU memory as needed, potentially decreasing the...
Twitter \ Tweets with replies by (Onur) (Danaci) ≥ i〈〚Onur ...
3900 MB of GPU memory just for a single random key lol. ... jax.random.split uses extra memory before preallocated memory is used up...
Help Needed: Hierarchical Model with crossed structure
I suspect I have messed something up with specifying the random ... I can do to work around this apart from getting a...
Using tf.data.Datasets for loading - The Aquila consortium
PRNGKey(0) rng, data_key = jax.random.split(rng) data_keys ... However, if the data is too large to fit in memory then we can use the...
Populating an array in do loop, allocating memory or not
The approach a={} and then setting a[[i]]=number is not possible. One could use Append but it should be avoided for the same reasons...
