
SVD: Very Memory Inefficient

See original GitHub issue

The following operation runs easily in both TensorFlow and PyTorch on GPU but causes an OOM in JAX. Under PyTorch, GPU memory usage is only ~2 GB.
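
For scale, a quick back-of-the-envelope (my own arithmetic, not from the issue): the thin factors of a 300000 × 50 float32 matrix add up to roughly 57 MiB, while a full square 300000 × 300000 U is ~335 GiB, which matches the peak buffer in the trace below.

m, n = 300000, 50
bytes_per_f32 = 4
thin = (m * n + n + n * n) * bytes_per_f32   # U (m x n), s (n), V^T (n x n)
full_u = m * m * bytes_per_f32               # a full square U (m x m)
print(thin / 2**20, "MiB")                   # ~57.2 MiB
print(full_u / 2**30, "GiB")                 # ~335.3 GiB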

A complete example that reproduces the bug:

import jax.numpy as jnp
import jax.random as jrandom
A = jrandom.normal(jrandom.PRNGKey(0), [300000, 50])
print(A.shape)
jnp.linalg.svd(A, full_matrices=False)

Full error message and traceback:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_57631/2574565684.py in <module>
      3 A = jrandom.normal(jrandom.PRNGKey(0), [300000, 50])
      4 print(A.shape)
----> 5 jnp.linalg.svd(A, full_matrices=False)

    [... skipping hidden 6 frame]

/opt/conda/lib/python3.7/site-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, output_buffer_counts, result_handlers, kept_var_idx, *args)
    442   input_bufs = util.flatten(
    443       device_put(x, device) for i, x in enumerate(args) if i in kept_var_idx)
--> 444   out_bufs = compiled.execute(input_bufs)
    445   check_special(name, out_bufs)
    446   if output_buffer_counts is None:

RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 360336162820 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   57.22MiB
              constant allocation:         0B
        maybe_live_out allocation:   57.23MiB
     preallocated temp allocation:  335.59GiB
  preallocated temp fragmentation:       112B (0.00%)
                 total allocation:  335.70GiB
              total fragmentation:    10.0KiB (0.00%)
Peak buffers:
	Buffer 1:
		Size: 335.28GiB
		XLA Label: custom-call
		Shape: f32[300000,300000]
		==========================

	Buffer 2:
		Size: 263.36MiB
		XLA Label: custom-call
		Shape: f32[69038144]
		==========================

	Buffer 3:
		Size: 57.22MiB
		XLA Label: custom-call
		Shape: f32[300000,50]
		==========================

	Buffer 4:
		Size: 57.22MiB
		Entry Parameter Subshape: f32[300000,50]
		==========================

	Buffer 5:
		Size: 57.22MiB
		Operator: op_name="jit(svd)/jit(main)/svd[full_matrices=False compute_uv=True]" source_file="/tmp/ipykernel_57631/2574565684.py" source_line=5
		XLA Label: fusion
		Shape: f32[300000,50]
		==========================

	Buffer 6:
		Size: 9.8KiB
		XLA Label: custom-call
		Shape: f32[50,50]
		==========================

	Buffer 7:
		Size: 9.8KiB
		Operator: op_name="jit(svd)/jit(main)/svd[full_matrices=False compute_uv=True]" source_file="/tmp/ipykernel_57631/2574565684.py" source_line=5
		XLA Label: fusion
		Shape: f32[50,50]
		==========================

	Buffer 8:
		Size: 200B
		Operator: op_name="jit(svd)/jit(main)/svd[full_matrices=False compute_uv=True]" source_file="/tmp/ipykernel_57631/2574565684.py" source_line=5
		XLA Label: fusion
		Shape: f32[50]
		==========================

	Buffer 9:
		Size: 48B
		XLA Label: custom-call
		Shape: (f32[300000,50], f32[50], f32[300000,300000], f32[50,50], s32[], /*index=5*/f32[69038144])
		==========================

	Buffer 10:
		Size: 24B
		XLA Label: tuple
		Shape: (f32[300000,50], f32[50], f32[50,50])
		==========================

	Buffer 11:
		Size: 4B
		XLA Label: custom-call
		Shape: s32[]
		==========================
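
The peak-buffer list makes the problem visible: Buffer 1 is a 335.28 GiB f32[300000,300000] custom-call output, i.e. the GPU SVD kernel materializes a full square U even though full_matrices=False only needs a 300000 × 50 one. Until the upstream fix, a standard workaround for tall, skinny matrices is to build the thin SVD from the small n × n Gram matrix. A minimal sketch under that assumption (thin_svd_tall is my own helper, not a JAX API; forming A.T @ A roughly squares the condition number, so accuracy degrades for small singular values):

import jax.numpy as jnp
import jax.random as jrandom

def thin_svd_tall(A):
    # Hypothetical helper: eigendecompose the n x n Gram matrix instead of
    # calling the GPU SVD kernel on the m x n matrix; only O(n^2) extra
    # memory is needed beyond A itself.
    w, V = jnp.linalg.eigh(A.T @ A)          # eigenvalues, ascending order
    w = jnp.maximum(w[::-1], 0.0)            # descending, clamp tiny negatives
    V = V[:, ::-1]
    s = jnp.sqrt(w)                          # singular values
    U = (A @ V) / s                          # left singular vectors, (m, n)
    return U, s, V.T

A = jrandom.normal(jrandom.PRNGKey(0), [300000, 50])
U, s, Vt = thin_svd_tall(A)
print(U.shape, s.shape, Vt.shape)            # (300000, 50) (50,) (50, 50)

Another option is simply to run this one decomposition on CPU, where the LAPACK path does not allocate the square U.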

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 7 (4 by maintainers)

Top GitHub Comments

1 reaction
hawkinsp commented, Mar 4, 2022

#9760 is now merged, which should have fixed this problem.

1 reaction
hawkinsp commented, Mar 3, 2022

No need to close this, there’s certainly something for us to fix here…

Read more comments on GitHub.

Top Results From Across the Web

  • Memory efficient implementations of partial Singular Value ...: If I just call the svd routine from the numpy.linalg module in Python for a random matrix of this size, I run into...
  • Applying SVD throws a Memory Error instantaneously?: Yes, the full_matrices parameter to scipy.linalg.svd is important: your input is highly rank-deficient (rank max 3,241), so you don't want ...
  • A Framework for Out of Memory SVD Algorithms - ICL UTK: When the matrix A is too large and does not fit in-memory, our goal is to design efficient algorithms to perform the computation...
  • FameSVD: Fast and Memory-efficient Singular Value ... - DeepAI: Due to the direct connection with PCA, Singular Value Decomposition (SVD) is one of the most well-known algorithms for low-rank ...
  • Distributed Out-of-Memory SVD on CPU/GPU Architectures: Various implementations of SVD have been proposed, but most only estimate the singular values as an estimation of the singular vectors which can ...
