`vmap` doesn't handle non-JAX-types in non-vmap'd outputs
Not sure if this counts as a bug or a feature request.
import jax

def make1(x):
    return [x, 0]

def make2(x):
    return [x, object()]

jax.vmap(make1, out_axes=[0, None])(jax.numpy.zeros(2))
# [DeviceArray([0., 0.], dtype=float32), 0]

jax.vmap(make2, out_axes=[0, None])(jax.numpy.zeros(2))
# TypeError: Output from batched function <object object at 0x7f372dbf9860> with type
# <class 'object'> is not a valid JAX type
If we’re not vmap’ing the value anyway, it’d be nice if non-JAX types could be valid outputs from vmap’d functions.
In contrast, the in_axes equivalent does work:
jax.vmap(lambda x, y: x, in_axes=(0, None))(jax.numpy.zeros(2), object())
# DeviceArray([0., 0.], dtype=float32)
Issue Analytics
- State:
- Created 2 years ago
- Comments:9 (6 by maintainers)
I think this is by design: functions passed to vmap are traced, and therefore their outputs are required to be valid JAX types, which can be a normal numeric scalar or array, or a custom type that you have registered.
The in_axes version works because the tracing in vmap is only with respect to batched arguments; note, however, that if you wrap the function in jit (which traces with respect to all arguments) you’ll get the same error there as well.

That makes a lot of sense. Thanks for the response!
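
For anyone landing here later, a minimal sketch of the two points above, not an official recommendation. The names array_part and first, and the restructured make2, are hypothetical and chosen for illustration. It assumes the non-JAX value does not depend on the batched computation, so it can be attached after vmap returns; it also shows that jit rejects a plain object() argument, as described in the reply.

```python
import jax
import jax.numpy as jnp


def array_part(x):
    # Only the JAX-typed value is returned from the vmap'd function.
    return x


def make2(xs):
    # Batch the array-valued part with vmap ...
    batched = jax.vmap(array_part)(xs)
    # ... and attach the non-JAX value outside vmap, where nothing is traced.
    return [batched, object()]


out = make2(jnp.zeros(2))


def first(x, y):
    return x


# jit traces with respect to all arguments, so a plain object() argument
# is rejected with a TypeError.
try:
    jax.jit(first)(jnp.zeros(2), object())
    jit_error = None
except TypeError as err:
    jit_error = err
```

If the extra value really is data that should flow through JAX transformations, registering its type as a pytree node (the "custom type that you have registered" mentioned above) is the other route.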