pmap complains about shapes even when given bad types
Here’s an example of an error message that seems to obscure the most important and actionable issue:
import jax
import jax.numpy as jnp

class A:
    @jax.pmap
    def f(self, x):
        return x * 2

a = A()
y = a.f(jnp.array([1]))  # raises the ValueError below
ValueError: pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
It’d be better here for the error to mention that A is not a mappable type!
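As a rough illustration of the message being asked for, here is a user-side sketch. The `check_mappable` helper is hypothetical and not part of JAX. It walks the pytree leaves of each argument before calling `pmap` and names the offending type. It assumes the default `in_axes=0`. An instance of an unregistered class such as `A` is treated as a single leaf, which is why `pmap` ends up reporting a rank-0 shape instead of a type problem.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Types we treat as array-like leaves that pmap can map over (an assumption
# for this sketch; JAX's own validity checks are more detailed).
_ARRAY_LIKE = (jax.Array, np.ndarray, np.generic, int, float, complex, bool)


def check_mappable(args, in_axis=0):
    """Hypothetical pre-flight check run before calling a pmapped function.

    Raises TypeError naming the unmappable type, or ValueError if an
    array leaf has no axis `in_axis` to map over. Returns None otherwise.
    """
    for i, arg in enumerate(args):
        for leaf in jax.tree_util.tree_leaves(arg):
            if not isinstance(leaf, _ARRAY_LIKE):
                raise TypeError(
                    f"argument {i} contains a leaf of type {type(leaf).__name__}, "
                    "which pmap cannot map over; pass it via "
                    "static_broadcasted_argnums or register it as a pytree")
            if np.ndim(leaf) <= in_axis:
                raise ValueError(
                    f"argument {i} has shape {np.shape(leaf)}; mapping along "
                    f"axis {in_axis} needs rank >= {in_axis + 1}")


class A:
    pass


# The bound method in the issue effectively passes (a, x) to pmap.
a = A()
try:
    check_mappable((a, jnp.array([1])))
    message = None
except TypeError as err:
    message = str(err)
```

Here `message` names `A` directly instead of reporting a shape of `()`. `static_broadcasted_argnums` is a real `jax.pmap` parameter. Marking an object as static requires it to be hashable, and plain Python objects hash by identity.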
I just want to point out that, regardless of what @dwromero's actual problem is, the fact that the error message gives no useful hint about the cause is exactly why this issue was filed, and why it should be fixed.
In case it helps anyone: my problem was that I forgot to replicate the state across devices. You can do that with

import flax.jax_utils
state = flax.jax_utils.replicate(state)
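`flax.jax_utils.replicate` copies every leaf of `state` onto each local device. This gives each leaf a leading axis whose size equals the device count. The sketch below imitates only that shape effect with plain JAX, to show why unreplicated state triggers the rank-0 error. It is not Flax's implementation. The toy `state` dict, `step`, and the `replicate` helper are all hypothetical.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# Hypothetical toy "state": a single scalar parameter with no device axis.
state = {"scale": jnp.array(2.0)}


@jax.pmap
def step(state, x):
    # Inside pmap, each device sees one slice along axis 0 of every input.
    return x * state["scale"]


x = jnp.arange(n, dtype=jnp.float32)  # one element per device

# step(state, x) should fail here: state["scale"] has shape (), so there is
# no axis 0 for pmap to map over (the rank-0 error quoted in the issue).


def replicate(tree, n_devices):
    # Give every leaf a leading axis of size n_devices (shape effect only;
    # Flax's replicate also places each copy on its own device).
    return jax.tree_util.tree_map(
        lambda leaf: jnp.broadcast_to(leaf, (n_devices,) + leaf.shape), tree)


replicated = replicate(state, n)
y = step(replicated, x)  # each device multiplies its element by 2.0
```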