sp.linalg.solve SEGFAULT when matrix total size is > int32.max
The code below works for a matrix of size n=46340:
import time

import jax.numpy as np
import jax.random as random
import jax.scipy.linalg as linalg


def get_data(n, c):
    # Build a symmetric positive-definite matrix A and a right-hand side b.
    A = random.normal(random.PRNGKey(1), (n, n)).astype(np.float32)
    A = A @ A.T + np.identity(n, np.float32)
    b = random.normal(random.PRNGKey(1), (n, c)).astype(np.float32)
    return A, b


def solve(A, b):
    x = linalg.solve(A,
                     b,
                     sym_pos=True,
                     overwrite_a=True,
                     overwrite_b=True,
                     check_finite=False)
    return x


A, b = get_data(n=46340, c=10)
start = time.time()
x = solve(A, b)
x.block_until_ready()
print('shapes = %s, %s,' % (str(A.shape), str(b.shape)),
      'time = %s,' % (time.time() - start),
      'max error = %s.' % np.max(np.abs(A @ x - b)))
But not for n=46341:
/Users/romann/.conda/envs/conda/bin/python /Users/romann/PycharmProjects/untitled2/jax_main.py
/Users/romann/.conda/envs/conda/lib/python3.7/site-packages/jax/lib/xla_bridge.py:114: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)
The NumPy version works fine for n=50000. Note that the square root of 2147483647 (the int32 maximum) is 46340.95.
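The threshold arithmetic can be checked with a short standard-library sketch. It only illustrates why n=46340 is the last size whose element count fits in a signed 32-bit integer. The helper names `fits_in_int32` and `as_int32` are hypothetical, and the idea that an int32 element count or index wraps somewhere in the solver is an assumption drawn from the numbers above, not a confirmed diagnosis.

```python
import ctypes
import math

INT32_MAX = 2**31 - 1  # 2147483647


def fits_in_int32(n):
    """Return True if an n x n matrix has at most INT32_MAX elements."""
    return n * n <= INT32_MAX


def as_int32(value):
    """Reinterpret a Python int as a wrapping signed 32-bit integer.

    Hypothetical helper: mimics what would happen if the element count
    were stored in an int32 (an assumption, not a confirmed cause).
    """
    return ctypes.c_int32(value).value


# Largest n whose square still fits in int32 (requires Python 3.8+).
largest_ok = math.isqrt(INT32_MAX)

for n in (46340, 46341):
    print(n, n * n, fits_in_int32(n), as_int32(n * n))
```

For n=46340 the element count is 2147395600, which fits. For n=46341 it is 2147488281, which exceeds the int32 maximum and becomes negative once truncated to 32 bits.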
It would be very useful to have it work for larger sizes, since the training set size for many classic datasets is >50,000 and kernel methods require linalg.solve to be called on these matrices.
Thanks!
Issue Analytics
- State:
- Created 4 years ago
- Reactions:3
- Comments:5 (5 by maintainers)
@romanngg I wrote a full description of this potential optimization in https://github.com/google/jax/issues/1747 (with code that should work in simpler cases)
The segfault that was the subject of this issue is long fixed at this point.