SVD: Very Memory Inefficient
The following operation runs easily in both TensorFlow and PyTorch (GPU), but causes an out-of-memory (OOM) error in JAX. When it runs in PyTorch, GPU memory usage is only ~2 GB.
Minimal reproduction:
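```python
import jax.numpy as jnp
import jax.random as jrandom

# Tall-skinny float32 matrix: 300000 x 50 (~57 MiB).
A = jrandom.normal(jrandom.PRNGKey(0), [300000, 50])
print(A.shape)
# Reduced SVD: U should only be 300000 x 50, yet this call runs out of memory.
jnp.linalg.svd(A, full_matrices=False)
```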
Full error message and traceback:
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_57631/2574565684.py in <module>
3 A = jrandom.normal(jrandom.PRNGKey(0), [300000, 50])
4 print(A.shape)
----> 5 jnp.linalg.svd(A, full_matrices=False)
[... skipping hidden 6 frame]
/opt/conda/lib/python3.7/site-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, output_buffer_counts, result_handlers, kept_var_idx, *args)
442 input_bufs = util.flatten(
443 device_put(x, device) for i, x in enumerate(args) if i in kept_var_idx)
--> 444 out_bufs = compiled.execute(input_bufs)
445 check_special(name, out_bufs)
446 if output_buffer_counts is None:
RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 360336162820 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 57.22MiB
constant allocation: 0B
maybe_live_out allocation: 57.23MiB
preallocated temp allocation: 335.59GiB
preallocated temp fragmentation: 112B (0.00%)
total allocation: 335.70GiB
total fragmentation: 10.0KiB (0.00%)
Peak buffers:
Buffer 1:
Size: 335.28GiB
XLA Label: custom-call
Shape: f32[300000,300000]
==========================
Buffer 2:
Size: 263.36MiB
XLA Label: custom-call
Shape: f32[69038144]
==========================
Buffer 3:
Size: 57.22MiB
XLA Label: custom-call
Shape: f32[300000,50]
==========================
Buffer 4:
Size: 57.22MiB
Entry Parameter Subshape: f32[300000,50]
==========================
Buffer 5:
Size: 57.22MiB
Operator: op_name="jit(svd)/jit(main)/svd[full_matrices=False compute_uv=True]" source_file="/tmp/ipykernel_57631/2574565684.py" source_line=5
XLA Label: fusion
Shape: f32[300000,50]
==========================
Buffer 6:
Size: 9.8KiB
XLA Label: custom-call
Shape: f32[50,50]
==========================
Buffer 7:
Size: 9.8KiB
Operator: op_name="jit(svd)/jit(main)/svd[full_matrices=False compute_uv=True]" source_file="/tmp/ipykernel_57631/2574565684.py" source_line=5
XLA Label: fusion
Shape: f32[50,50]
==========================
Buffer 8:
Size: 200B
Operator: op_name="jit(svd)/jit(main)/svd[full_matrices=False compute_uv=True]" source_file="/tmp/ipykernel_57631/2574565684.py" source_line=5
XLA Label: fusion
Shape: f32[50]
==========================
Buffer 9:
Size: 48B
XLA Label: custom-call
Shape: (f32[300000,50], f32[50], f32[300000,300000], f32[50,50], s32[], /*index=5*/f32[69038144])
==========================
Buffer 10:
Size: 24B
XLA Label: tuple
Shape: (f32[300000,50], f32[50], f32[50,50])
==========================
Buffer 11:
Size: 4B
XLA Label: custom-call
Shape: s32[]
==========================
```
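The size of the dominant buffer can be checked with a little arithmetic. Buffer 1 is a dense `f32[300000,300000]` output of the SVD custom call (4 bytes per element). By contrast, a thin `U` of shape `f32[300000,50]`, which is all `full_matrices=False` should require, is tiny. This suggests that the custom call materializes a full m × m `U` even though only the reduced factors are returned (Buffer 10). The sketch below only reproduces the sizes reported in the log. It assumes nothing about JAX internals.

```python
# Sizes of dense float32 buffers at 4 bytes per element.
BYTES_PER_F32 = 4
m, n = 300_000, 50

full_u_bytes = m * m * BYTES_PER_F32  # f32[300000,300000], Buffer 1 in the log
thin_u_bytes = m * n * BYTES_PER_F32  # f32[300000,50], the thin U that full_matrices=False needs

GiB = 2 ** 30
MiB = 2 ** 20
print(f"{full_u_bytes / GiB:.2f} GiB")  # 335.28 GiB
print(f"{thin_u_bytes / MiB:.2f} MiB")  # 57.22 MiB
print(full_u_bytes // thin_u_bytes)     # 6000, i.e. m / n times larger
```

The ratio is simply m / n. The waste therefore grows with how tall the matrix is, which is why this shape fails while smaller or squarer inputs do not.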
#9760 is now merged, which should have fixed this problem.
No need to close this, there’s certainly something for us to fix here…
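On a JAX version without the fix from #9760, one possible workaround for tall-skinny inputs (m ≫ n) is to avoid calling the SVD on the large matrix at all. The approach is to take an economy QR factorization A = QR, compute the SVD of the small n × n factor R, and rotate back. This is a standard technique, not the change made in #9760. It also assumes that `jnp.linalg.qr` in `"reduced"` mode does not itself allocate an m × m buffer on your backend. The helper name `thin_svd_via_qr` is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def thin_svd_via_qr(a):
    """Hypothetical helper: reduced SVD of a tall matrix a (m >= n) via QR + small SVD."""
    # Economy QR: q is (m, n) with orthonormal columns, r is (n, n) upper triangular.
    q, r = jnp.linalg.qr(a, mode="reduced")
    # SVD of the small n x n factor only; no m x m matrix is requested here.
    u_r, s, vh = jnp.linalg.svd(r, full_matrices=False)
    # a = q @ r = (q @ u_r) @ diag(s) @ vh, and q @ u_r still has orthonormal columns.
    return q @ u_r, s, vh


# For the shape in this issue, the outputs would be (300000, 50), (50,) and (50, 50).
a = jax.random.normal(jax.random.PRNGKey(0), (200, 5))
u, s, vh = thin_svd_via_qr(a)
print(u.shape, s.shape, vh.shape)
```

Unlike forming AᵀA and taking its eigendecomposition, the QR route does not square the condition number of A. It is therefore a reasonable sketch to try, but it should be checked against `jnp.linalg.svd` on smaller inputs before relying on it.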