
BCOO matrix constructor fails within a `jax.lax.fori_loop`


I’m implementing a power iteration routine in JAX using sparse BCOO matrices.

I get the following error when using `jax.lax.fori_loop`:

ValueError: Invalid sparse representation: got indices.shape=(), data.shape=(), sparse_shape=(10,)

The error seems to stem from this line:

```python
batched_M = jsparse.BCOO.fromdense(M.todense(), n_batch=1)
```

which I need in order to perform a sparse mat-vec multiply (as a batched vec-vec dot) for now, until #4710 is resolved.
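A minimal sketch of that batched workaround, assuming a recent JAX release. `batched_matvec` is a hypothetical helper, not code from the issue or from JAX. It passes `nse` explicitly to `BCOO.fromdense`, so the shapes of the batched `data` and `indices` arrays are known at trace time, which code running inside `lax.fori_loop` or `jit` needs. `nse_per_row` must be at least the largest number of nonzeros in any row, otherwise stored entries may be lost. This illustrates keeping shapes static; it is not a confirmed fix for the error above.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse as jsparse


def batched_matvec(M_dense, v, nse_per_row):
  """Hypothetical helper: sparse mat-vec as a batch of row-wise dot products."""
  # n_batch=1 gives one batch entry per row; passing nse explicitly fixes the
  # shapes of data (n_rows, nse) and indices (n_rows, nse, 1) at trace time.
  batched_M = jsparse.BCOO.fromdense(M_dense, nse=nse_per_row, n_batch=1)
  cols = batched_M.indices[..., 0]  # column index of each stored entry
  # Padding slots may hold out-of-bound column indices; read those as 0.
  v_at_cols = v.at[cols].get(mode="fill", fill_value=0)
  return jnp.sum(batched_M.data * v_at_cols, axis=-1)


M_dense = jnp.array([[2.0, 0.0, 1.0],
                     [0.0, 3.0, 0.0],
                     [1.0, 0.0, 4.0]])


def body(i, v):
  w = batched_matvec(M_dense, v, 2)  # no row has more than 2 nonzeros
  return w / jnp.linalg.norm(w)      # normalized power-iteration step


v = lax.fori_loop(0, 50, body, jnp.ones(3))
```

The loop carries only the dense vector `v`; the batched matrix is rebuilt on each iteration with the same static `nse`, so every iteration produces arrays of identical shape.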

In my routine, I fill the power iteration matrix incrementally, following the sparsity pattern of the vector. In the reproduction code, it is filled from a known dense matrix for simplicity.

Minimal reproduction code: https://colab.research.google.com/drive/11hTyCoX30e054MI5zKV0lPaE9666YmUW?usp=sharing
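For comparison, a minimal power-iteration sketch, assuming a recent JAX release in which a BCOO matrix can be multiplied by a dense vector with `@` (as the `jax.experimental.sparse` documentation shows). `power_iteration` is a hypothetical name, and the diagonal test matrix stands in for the reporter's data, which is only in the Colab notebook. The matrix is closed over rather than carried, so the loop state is a single dense vector of fixed shape.

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse as jsparse


def power_iteration(M, v0, num_iters):
  """Approximate the dominant eigenvector of the sparse matrix M."""
  def body(i, v):
    w = M @ v                      # BCOO @ dense vector -> dense vector
    return w / jnp.linalg.norm(w)  # renormalize so values stay bounded
  # Only the dense vector is carried, so its shape and dtype never change.
  return lax.fori_loop(0, num_iters, body, v0 / jnp.linalg.norm(v0))


# Diagonal stand-in matrix with eigenvalues 1, 2, ..., 10.
M = jsparse.BCOO.fromdense(jnp.diag(jnp.arange(1.0, 11.0)))
v = power_iteration(M, jnp.ones(10), 200)
eigenvalue = v @ (M @ v)  # Rayleigh quotient of the unit vector v
```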


Top GitHub Comments

jakevdp commented, Sep 22, 2021

Yeah, that looks good. The larger question here: can we find a good way to provide `sum_with_nse`-type functionality more easily? So far I’ve left the dedupe functionality private and undocumented, but it’s obviously useful in some cases.
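The `sum_with_nse` helper discussed here is not shown on this page. Below is a hypothetical reconstruction, assuming a recent JAX release that exports `jsparse.bcoo_sum_duplicates` with an `nse` argument: concatenate the two representations, then merge duplicate indices down to a fixed, static `nse`. It handles only unbatched matrices of identical shape, and `nse` must cover every distinct stored index.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse as jsparse


def sum_with_nse(A, B, nse):
  """Hypothetical helper: A + B for unbatched BCOO matrices, with a fixed nse."""
  # Concatenating the representations gives A.nse + B.nse stored entries,
  # some of which may share an index.
  combined = jsparse.BCOO(
      (jnp.concatenate([A.data, B.data]),
       jnp.concatenate([A.indices, B.indices])),
      shape=A.shape)
  # Merge duplicate indices; a static nse keeps the output shape fixed.
  return jsparse.bcoo_sum_duplicates(combined, nse=nse)


A = jsparse.BCOO.fromdense(jnp.diag(jnp.array([1.0, 2.0, 3.0])))
B = jsparse.BCOO.fromdense(jnp.diag(jnp.array([4.0, 5.0, 6.0])))
C = jax.jit(sum_with_nse, static_argnums=2)(A, B, 3)
```

Because `nse` is static, the shapes of the output's `data` and `indices` do not depend on the values being added.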

jakevdp commented, Sep 22, 2021

It’s not the sum per se; the problem is that the function passed to `fori_loop` must have outputs that match its inputs. Sparsity aside, here’s an example of a success and a failure:

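```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5)

def f_good(i, x):
  # Returns a carry with the same shape (5,) and dtype as the input.
  return x

lax.fori_loop(0, 5, f_good, x)  # succeeds

def f_bad(i, x):
  # Returns a carry of shape (6,): the loop state changes shape.
  return jnp.append(x, 1)

lax.fori_loop(0, 5, f_bad, x)  # fails
```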

Because a BCOO matrix is stored as dense `data` and `indices` arrays whose size along one axis is nse (the number of specified elements), the nse of the loop’s inputs and outputs must match. If you use a function that changes the nse, such as adding two matrices with non-shared sparsity patterns, your loop will fail. If you use a function that does not change nse (or any other matrix attribute that affects the shapes and dtypes of the representation), your loop will succeed. Your power function above succeeds for this reason.
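The same rule applied to a BCOO carry, as a minimal sketch assuming a recent JAX release. `halve` and `duplicate` are hypothetical loop bodies. `halve` rescales the stored values and keeps nse at 3, so the loop runs; `duplicate` concatenates the representation with itself, mimicking how adding matrices with different sparsity patterns grows nse, so the carry changes shape and `fori_loop` raises a `TypeError`. The matrix is built directly from `(data, indices)` so that the input and the rebuilt outputs share the same default `indices_sorted`/`unique_indices` flags, which in recent versions are also part of the carry's pytree structure.

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse as jsparse

# A 3x3 BCOO matrix with nse=3, built directly from (data, indices).
data = jnp.array([1.0, 2.0, 3.0])
indices = jnp.array([[0, 0], [1, 2], [2, 1]])
M = jsparse.BCOO((data, indices), shape=(3, 3))


def halve(i, M):
  # Rescales the stored values only: nse and the index array are unchanged.
  return jsparse.BCOO((M.data * 0.5, M.indices), shape=M.shape)


def duplicate(i, M):
  # Concatenates the representation with itself: nse grows from 3 to 6,
  # so the loop carry no longer matches the input.
  return jsparse.BCOO(
      (jnp.concatenate([M.data, M.data]),
       jnp.concatenate([M.indices, M.indices])),
      shape=M.shape)


halved = lax.fori_loop(0, 3, halve, M)  # succeeds: nse stays 3

try:
  lax.fori_loop(0, 3, duplicate, M)  # fails: nse changes from 3 to 6
  loop_failed = False
except TypeError:
  loop_failed = True
```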
