jax.random.split uses extra memory before preallocated memory is used up
I monitored the GPU memory usage via nvidia-smi.
I find that whenever I run
jax.random.split()
JAX always uses more memory, even though the preallocated memory is not used at all. This keeps raising OOM errors, since JAX has already preallocated 90% of GPU memory by default.
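A minimal reproduction of what is being described, as a sketch: it assumes a CUDA-enabled jaxlib and uses the standard XLA_PYTHON_CLIENT_PREALLOCATE flag (which must be set before jax is first imported) so that nvidia-smi reflects actual allocations rather than the 90% preallocation.

```python
import os

# Assumption: disable preallocation so nvidia-smi shows real usage.
# This flag must be set before jax is imported for the first time.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax

key = jax.random.PRNGKey(0)          # first JAX op: initializes the GPU backend
key, subkey = jax.random.split(key)  # the call reported to allocate extra memory
print(key, subkey)
```

Watching nvidia-smi in a second terminal while stepping through this is where the extra allocation from the split call would show up.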
Issue Analytics
- State:
- Created a year ago
- Reactions: 2
- Comments: 6 (1 by maintainers)
I have been experiencing exactly the same issue on Ubuntu 22.04 (graphics driver CUDA 11.7, nvcc CUDA 11.2, cuDNN 8.2).
Doing nothing but assigning a random key in JAX costs 3900 MB on the GPU. And every time you split that key, it costs 3900^(n_splits) MB of GPU memory: an exponential memory blow-up!
I have the same issue. No matter what I set the memory fraction to, this will not run.
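For anyone hitting this while the issue is open, the GPU allocator flags documented by JAX can at least bound the damage; this is a sketch of the documented options (set them in the environment before the JAX process starts):

```shell
# Don't preallocate 90% of the GPU up front; allocate on demand instead.
export XLA_PYTHON_CLIENT_PREALLOCATE=false

# Alternatively, cap preallocation at a smaller fraction of total GPU memory.
export XLA_PYTHON_CLIENT_MEM_FRACTION=.50

# Last resort: bypass the caching allocator entirely (slow, but frees
# memory as soon as it is deallocated).
# export XLA_PYTHON_CLIENT_ALLOCATOR=platform
```

Then launch the JAX program from the same shell. Note that, per the comments above, the memory fraction alone may not be enough if split itself over-allocates.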