`jnp.linalg.svd` etc. do not respect `__jax_array__`

A couple of things are going on here. First of all, the following is an example of `jnp.linalg.svd` failing to respect `__jax_array__`.

```python
import jax
import jax.numpy as jnp

# A plain (non-PyTree) array-like that should be converted via __jax_array__.
class MyArray:
    def __jax_array__(self):
        return jnp.array([[1.]])

# With jit disabled, svd receives the raw MyArray object.
with jax.disable_jit():
    jnp.linalg.svd(MyArray())

# TypeError: Value '<__main__.MyArray object at 0x7f5ec0f584c0>' with dtype object is not a valid
# JAX array type. Only arrays of numeric types are supported by JAX.
```

Remove the `disable_jit` and this works.

The reason it works without `disable_jit` is that `jnp.linalg.svd` and friends all have `jax.jit` wrappers, and it is the `jit` wrapper that spots the `__jax_array__` method and handles things appropriately… unless the JAX array-like is also a PyTree, in which case it doesn't. So this also fails (with a different error message this time):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

# Same array-like, but as a field-less NamedTuple it is also a PyTree (with no leaves).
class MyArray(NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])

jnp.linalg.svd(MyArray())
# ValueError: Argument to singular value decomposition must have ndims >= 2
```
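To see why the NamedTuple version loses its `__jax_array__` method, one can flatten it with `jax.tree_util`. JAX registers namedtuples as PyTree nodes, so a field-less NamedTuple flattens to zero leaves. The sketch below assumes, as described above, that the `jit` wrapper flattens its arguments as PyTrees before looking for array-likes. It only inspects the flattening and does not call `svd`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MyArray(NamedTuple):
    # Field-less NamedTuple that also advertises itself as a JAX array-like.
    def __jax_array__(self):
        return jnp.array([[1.]])


x = MyArray()

# JAX treats namedtuples as PyTree nodes, so flattening finds no leaves:
# there is nothing left for an array-like check to convert.
leaves = jax.tree_util.tree_leaves(x)
treedef = jax.tree_util.tree_structure(x)
print(leaves)              # []
print(treedef.num_leaves)  # 0

# The intended 2-D array is only reachable by calling the method directly.
print(x.__jax_array__().shape)  # (1, 1)
```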

So whilst it takes either `disable_jit` or a PyTree to actually trigger it, I think the fundamental issue here is that `jnp.linalg.svd` and friends do not themselves check for JAX array-likes.
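Until the functions themselves check for array-likes, one possible user-side workaround is to do the conversion explicitly before calling `svd`. The helper `as_jax_array` below is hypothetical and not part of JAX. It calls `__jax_array__` when present and then `jnp.asarray`, so both the plain class and the NamedTuple variant reach `svd` as ordinary 2-D arrays, with or without `disable_jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


def as_jax_array(x):
    """Hypothetical helper (not a JAX API): honour __jax_array__ explicitly."""
    if hasattr(x, "__jax_array__"):
        x = x.__jax_array__()
    return jnp.asarray(x)


class PlainArrayLike:
    # Not a PyTree: the case that fails under disable_jit.
    def __jax_array__(self):
        return jnp.array([[1.]])


class TreeArrayLike(NamedTuple):
    # Also a PyTree: the case that fails even with jit enabled.
    def __jax_array__(self):
        return jnp.array([[1.]])


# Converting first means svd only ever sees a plain 2-D array.
with jax.disable_jit():
    s1 = jnp.linalg.svd(as_jax_array(PlainArrayLike()), compute_uv=False)
s2 = jnp.linalg.svd(as_jax_array(TreeArrayLike()), compute_uv=False)

# Each result holds the single singular value of [[1.]].
print(s1, s2)
```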

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 17 (13 by maintainers)

Top GitHub Comments

2 reactions
PhilipVinc commented, Mar 30, 2022

@mattjj In favour of not removing `__jax_array__`: it allows one to write backend-agnostic code using NEP 47.

0 reactions
YouJiacheng commented, Apr 1, 2022

Maybe the API could follow some priority order, e.g. method-specific type (e.g. a callable for `cg`) > pytree > `__jax_array__`. But what about a pytree of (pytree and `jax_array`)?
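The priority order proposed in this comment could be sketched as a small dispatch function. This is purely illustrative: `classify_operand` is a hypothetical name, `callable` stands in for a method-specific type (as for `cg`), and a value counts as a PyTree here when flattening it does not return the value itself as its only leaf. Under this order a NamedTuple with `__jax_array__` is classified as a PyTree, which is the behaviour reported in the issue.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


def classify_operand(x):
    """Hypothetical dispatch: method-specific type > pytree > __jax_array__."""
    # 1. Method-specific type, e.g. a callable linear operator for cg.
    if callable(x):
        return "callable"
    # 2. PyTree container: flattening does not return x itself as the only leaf.
    leaves = jax.tree_util.tree_leaves(x)
    if not (len(leaves) == 1 and leaves[0] is x):
        return "pytree"
    # 3. Array-like that opts in via __jax_array__.
    if hasattr(x, "__jax_array__"):
        return "jax_array"
    # Fallback: treat anything else as an ordinary array input.
    return "array"


class PlainArrayLike:
    def __jax_array__(self):
        return jnp.array([[1.]])


class TreeArrayLike(NamedTuple):
    def __jax_array__(self):
        return jnp.array([[1.]])


print(classify_operand(lambda v: v))         # callable
print(classify_operand(PlainArrayLike()))    # jax_array
print(classify_operand(TreeArrayLike()))     # pytree
print(classify_operand((jnp.ones(2), 3.0)))  # pytree
```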
