xla in pmap fails (i.e. jit-of-pmap or lax.scan with collectives)
The parallel XLA interpreter currently doesn’t properly support nested jit compilation. A practical example of this issue is trying to use psum from within scan:
pmap(partial(lax.scan, lambda c, x: (c + lax.psum(x, "i"), c), 0.), axis_name="i")(np.ones((8, 4)))
Scan needs to compile the body of the loop with XLA, which fails because psum is only defined in the context of the pxla interpreter.
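At the time of this report the pattern above failed; in current JAX releases, collectives such as psum inside a lax.scan body under pmap are supported. A minimal sketch of the same computation, assuming a recent JAX install: the leading axis is sized with jax.local_device_count() so it also runs on a single-CPU machine (the original used 8 devices), and the names `n`, `body`, and `per_device` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative sizing: one row per local device (1 on a plain CPU machine).
n = jax.local_device_count()

def body(carry, x):
    # Collective over the pmapped axis "i": sums x across all devices.
    total = lax.psum(x, axis_name="i")
    # Return the new carry, and emit the old carry as this step's output.
    return carry + total, carry

def per_device(xs):
    # lax.scan traces and compiles `body`; the psum inside it must still
    # resolve the axis name "i" that the enclosing pmap binds.
    return lax.scan(body, jnp.zeros((), xs.dtype), xs)

final, history = jax.pmap(per_device, axis_name="i")(jnp.ones((n, 4)))
# final:   shape (n,),   every entry equals 4 * n
# history: shape (n, 4), every row equals [0, n, 2n, 3n]
```

Each step adds psum(1.0) = n to the carry, so after four steps every device holds 4 * n.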
Fixed the issue @fehiepsi raised in #832.
Thanks for raising this.
I tweaked the issue title to use the canonical “jit-of-pmap” name. This is actually a case of jit-of-pmap because lax.scan does something like jit in its implementation, which is why it doesn’t know about parallel collectives. jit can be thought of as a special case of pmap (mapping over a non-existent singleton axis), and if we implemented it that way, this would all work automatically. But at the moment, because the jit implementation predates pmap, it needs to learn how to handle parallel primitives like psum (including pmap itself).
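The idea that jit is pmap over a singleton axis can be illustrated directly: add a leading axis of size 1, pmap over it, and strip it off again. Over a one-member axis, psum has a single participant and returns its input unchanged, so the result matches a plain jit of the same computation without the collective. This is a sketch of the equivalence only, assuming a recent JAX install, and not how JAX actually implements jit; `body`, `via_pmap`, and `via_jit` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import lax

def body(x):
    # With an axis of size 1, psum reduces over one participant,
    # so it returns x unchanged.
    return lax.psum(x, axis_name="i") * 2.0

x = jnp.arange(3.0)

# "jit as pmap over a singleton axis": map over a leading axis of size 1.
via_pmap = jax.pmap(body, axis_name="i")(x[None])[0]

# Plain jit of the same computation, with the collective dropped because
# there is no mapped axis for it to refer to.
via_jit = jax.jit(lambda v: v * 2.0)(x)
```

Both paths compute [0., 2., 4.] here.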