xla in pmap fails (i.e. jit-of-pmap or lax.scan with collectives)

The parallel XLA interpreter (pxla) currently doesn't properly support nested jit compilation. A practical example is trying to use psum from within scan:

```python
from functools import partial
import numpy as np
from jax import lax, pmap

pmap(partial(lax.scan, lambda c, x: (c + lax.psum(x, "i"), c), 0.),
     axis_name="i")(np.ones((8, 4)))
```

Scan needs to compile the body of the loop using XLA, which fails because psum is only defined in the context of the pxla interpreter.
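For reference, this is what the repro is meant to compute. The plain NumPy sketch below (the helper name `reference` is hypothetical) simulates pmap, scan, and psum by hand: each of the 8 rows stands in for one device, and at every scan step psum adds up that step's element across all rows.

```python
import numpy as np

def reference(xs):
    """Hand-simulated pmap(scan(...psum...)) for xs of shape (devices, steps)."""
    n_dev, n_steps = xs.shape
    carry = np.zeros(n_dev)           # scan's initial carry 0. on every device
    ys = np.zeros((n_dev, n_steps))   # per-step outputs, stacked like scan does
    for t in range(n_steps):
        total = xs[:, t].sum()        # lax.psum(x, "i") at step t
        ys[:, t] = carry              # the body emits the old carry ...
        carry = carry + total         # ... and returns c + psum(x) as the new one
    return carry, ys

carry, ys = reference(np.ones((8, 4)))
```

For the all-ones input, every device ends with a carry of 32.0 and emits [0, 8, 16, 24], because each psum contributes 8.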

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 1
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

mattjj commented, Jun 9, 2019 (1 reaction)

Fixed the issue @fehiepsi raised in #832.

mattjj commented, Jun 8, 2019 (1 reaction)

Thanks for raising this.

I tweaked the issue title to use the canonical “jit-of-pmap” name. This is actually a case of jit-of-pmap because lax.scan does something like jit in its implementation, which is why it doesn’t know about parallel collectives.
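To make the pattern concrete, here is a minimal sketch in which a jit-compiled function containing a collective is called from inside pmap, which is essentially what lax.scan's compiled body did in the failing repro. It assumes a recent JAX release in which this combination is supported (the issue is closed); `body` is a hypothetical name, and the axis name "i" follows the issue's example. It runs on any number of devices; a CPU-only install usually exposes just one.

```python
import jax
import jax.numpy as jnp
from jax import lax

def body(x):
    # Compiled by jax.jit, yet it refers to the enclosing pmap axis "i".
    return lax.psum(x, axis_name="i")

n = jax.local_device_count()
xs = jnp.arange(n, dtype=jnp.float32)  # one scalar per device

# jit nested inside pmap: the jit-compiled function has to understand
# the psum collective over the surrounding pmap axis.
out = jax.pmap(jax.jit(body), axis_name="i")(xs)
```

Every entry of `out` holds the same value, the sum of `xs` across all devices.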

jit can be thought of as a special case of pmap (mapping over a nonexistent singleton axis), and if we implemented it that way, this would all work automatically. But at the moment, because the jit implementation predates pmap, it needs to learn how to handle parallel primitives like psum (including pmap itself).
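The singleton-axis view of jit can be checked numerically. The sketch below assumes a recent JAX release; `f` and `g` are hypothetical example functions. It compares `jax.jit(f)` with `jax.pmap(f)` applied over an added leading axis of size 1, and shows that psum over such an axis just returns its input. It only illustrates the equivalence described in the comment and is not how JAX actually implements jit.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    return jnp.sin(x) * 2.0 + x

def g(x):
    # Over a singleton axis, psum adds up exactly one value: the identity.
    return lax.psum(x, axis_name="i")

x = jnp.arange(4.0)

# jit versus pmap over an added leading axis of size 1 (dropped again by [0]).
via_jit = jax.jit(f)(x)
via_pmap = jax.pmap(f)(x[None])[0]

# Under the singleton-axis view, the collective is well defined.
via_pmap_psum = jax.pmap(g, axis_name="i")(x[None])[0]
```

Because a size-1 axis needs only one device, this runs on any machine, which is part of why treating jit as pmap would make collectives like psum work without special handling.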

Top Results From Across the Web

pmap inside scan triggers ragged assertion. · Issue #2018
Works fine with vmap. ...

Issue with jax.lax.scan - python
I am unsure how to fix the jax.lax.scan. The error that keeps popping up is missing the required XS. When I put a...

jax.lax.scan - JAX documentation - Read the Docs
Scan a function over leading array axes while carrying along state. ... constructs in an jit() function are unrolled, leading to large XLA...
