`jnp.linalg.svd` etc. do not respect `__jax_array__`
A couple of things are going on here. First of all, the following is an example of `jnp.linalg.svd` failing to respect `__jax_array__`.
```python
import jax
import jax.numpy as jnp

class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

with jax.disable_jit():
    jnp.linalg.svd(MyArray())
# TypeError: Value '<__main__.MyArray object at 0x7f5ec0f584c0>' with dtype object is not a valid
# JAX array type. Only arrays of numeric types are supported by JAX.
```
Remove the `disable_jit` and this works.

The reason it works without `disable_jit` is that `jnp.linalg.svd` and friends all have `jax.jit` wrappers, which is what spots the `__jax_array__` and handles things appropriately… unless the JAX arraylike is also a PyTree, in which case they don't. So this also fails (with a different error message this time):
```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class MyArray(NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

jnp.linalg.svd(MyArray())
# ValueError: Argument to singular value decomposition must have ndims >= 2
```
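The PyTree case can be made concrete by looking at what JAX's tree utilities see. In the sketch below, `PlainArrayLike` and `TupleArrayLike` are illustrative stand-ins for the two `MyArray` variants above. A plain object is treated as a single leaf, while the empty `NamedTuple` flattens to no leaves at all, so a `jit` wrapper that (as described above) handles `__jax_array__` on the flattened leaves never gets the chance to call it. The last line shows that an empty tuple converts to a 1-D array of shape `(0,)`, which is consistent with the `ndims >= 2` error; that is an interpretation of the error message, not a trace through JAX internals.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PlainArrayLike:
    # Not a registered pytree, so JAX treats the whole object as one leaf.
    def __jax_array__(self):
        return jnp.array([[1.]])


class TupleArrayLike(NamedTuple):
    # NamedTuples are pytree nodes automatically; this one has no fields.
    def __jax_array__(self):
        return jnp.array([[1.]])


plain_leaves = jax.tree_util.tree_leaves(PlainArrayLike())
tuple_leaves = jax.tree_util.tree_leaves(TupleArrayLike())

print(len(plain_leaves))  # 1: the object itself is the only leaf
print(len(tuple_leaves))  # 0: the empty NamedTuple has no leaves to convert

# Treated as an ordinary (empty) sequence, the tuple becomes a 1-D array.
print(jnp.asarray(()).shape)  # (0,)
```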
So whilst it takes either a `disable_jit` or a PyTree to actually trigger it, I think the fundamental issue here is that `jnp.linalg.svd` and friends do not check for JAX arraylikes.
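One way to "check for JAX arraylikes" would be to do the conversion eagerly, before any `jit` or PyTree machinery runs. The wrapper below is a hypothetical sketch (`coerce_arraylikes` is not a JAX API): it calls `__jax_array__` on every positional argument that defines it and then delegates to the wrapped function, so the result no longer depends on `disable_jit` or on whether the argument also happens to be a PyTree.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


def coerce_arraylikes(fn):
    """Hypothetical wrapper: convert __jax_array__ objects before calling fn."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Check for the protocol first, before any pytree flattening happens.
        args = tuple(
            a.__jax_array__() if hasattr(a, "__jax_array__") else a
            for a in args
        )
        return fn(*args, **kwargs)
    return wrapper


class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])


class MyPytreeArray(NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])


svd = coerce_arraylikes(jnp.linalg.svd)

# The plain arraylike now works even with jit disabled...
with jax.disable_jit():
    _, s, _ = svd(MyArray())

# ...and the NamedTuple arraylike is converted before jit can flatten it.
_, s_tree, _ = svd(MyPytreeArray())
```

Checking `__jax_array__` before PyTree-ness is a design choice: it gives the protocol priority over PyTree flattening, which is only one possible answer to the priority question raised in the discussion.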
@mattjj In favour of not removing `__jax_array__`: it allows writing code that is backend-agnostic using NEP 47. Maybe the API can follow some priorities, e.g. method-specific type (e.g. `callable` for `cg`) > pytree > `__jax_array__`. But what about a pytree of (pytree and jax_array)?