`vmap` doesn't handle non-JAX-types in non-vmap'd outputs

Not sure if this counts as a bug or a feature request.

```python
import jax

def make1(x):
    return [x, 0]

def make2(x):
    return [x, object()]

jax.vmap(make1, out_axes=[0, None])(jax.numpy.zeros(2))
# [DeviceArray([0., 0.], dtype=float32), 0]
jax.vmap(make2, out_axes=[0, None])(jax.numpy.zeros(2))
# TypeError: Output from batched function <object object at 0x7f372dbf9860> with type
# <class 'object'> is not a valid JAX type
```

If we’re not vmap’ing the value anyway, it’d be nice if non-JAX types could be valid outputs from vmap’d functions.

In contrast, the `in_axes` equivalent does work:

```python
jax.vmap(lambda x, y: x, in_axes=(0, None))(jax.numpy.zeros(2), object())
# DeviceArray([0., 0.], dtype=float32)
```
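One possible workaround for the request above, sketched under stated assumptions: a hypothetical helper `vmap_static_out` (not part of JAX) whose wrapped function returns a `(batched, static)` pair. The static part is captured in a Python list while `vmap` traces the function, so it never has to be a JAX type. This relies on `vmap` calling the function once per invocation, and it is only safe when the static value does not depend on the batched inputs; otherwise a tracer would leak out.

```python
import jax
import jax.numpy as jnp


def vmap_static_out(fun):
    """Hypothetical helper (not a JAX API): vmap `fun`, which must return a
    (batched, static) pair; `static` may be any Python object and is returned
    once, unbatched. Assumes `static` does not depend on the batched inputs."""
    def wrapped(*args):
        captured = []

        def inner(*inner_args):
            batched, static = fun(*inner_args)
            # Python side effect: runs while vmap traces `inner`, not per example.
            captured.append(static)
            return batched

        batched = jax.vmap(inner)(*args)
        return batched, captured[0]

    return wrapped


def make2(x):
    # Same outputs as the issue's `make2`, returned as a (batched, static) pair.
    return x, object()


arr, obj = vmap_static_out(make2)(jnp.zeros(2))
```

The design choice here is to keep the non-JAX value entirely outside JAX's view rather than asking `vmap` to accept it as an output.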

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 9 (6 by maintainers)

Top GitHub Comments

jakevdp commented, Aug 12, 2021 (1 reaction)

I think this is by design: functions passed to vmap are traced, and therefore their outputs are required to be valid JAX types, which can be a normal numeric scalar or array, or a custom type that you have registered.
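jakevdp's mention of registered custom types suggests another route: put the non-JAX value in the auxiliary (static) data of a registered pytree node, so it becomes part of the tree structure rather than a traced leaf. A minimal sketch, assuming a hypothetical `WithStatic` container (not a JAX class); JAX expects auxiliary data to be hashable and comparable, which a plain `object()` is (by identity).

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class WithStatic:
    """Hypothetical container: `value` is a pytree leaf (traced and batched);
    `meta` is arbitrary Python data kept in the treedef's auxiliary data."""

    def __init__(self, value, meta):
        self.value = value
        self.meta = meta

    def tree_flatten(self):
        # (children, aux_data): only the children are traced by vmap.
        return (self.value,), self.meta

    @classmethod
    def tree_unflatten(cls, meta, children):
        return cls(children[0], meta)


sentinel = object()


def make3(x):
    # Like the issue's `make2`, but the object travels as static metadata.
    return WithStatic(x, sentinel)


out = jax.vmap(make3)(jnp.zeros(2))
```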

The `in_axes` version works because the tracing in vmap is only with respect to batched arguments; note, however, that if you wrap the function in `jit` (which traces with respect to all arguments), you’ll get the same error there as well.
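The `jit` behaviour described above can be checked directly. This sketch uses a hypothetical function `first` that ignores its second argument. It shows the `TypeError` raised when an unmarked `object()` is passed to a jitted function, and how `jit`'s `static_argnums` parameter keeps that argument out of tracing (static arguments must be hashable).

```python
import jax
import jax.numpy as jnp


def first(x, y):
    # Hypothetical function: ignores `y`, like the issue's in_axes example.
    return x


obj = object()

# jit traces every argument, so a plain `object()` is rejected.
jit_rejected = False
try:
    jax.jit(first)(jnp.zeros(2), obj)
except TypeError:
    jit_rejected = True

# Marking `y` static (it must be hashable) keeps it out of tracing.
out = jax.jit(first, static_argnums=1)(jnp.zeros(2), obj)
```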

patrick-kidger commented, Aug 13, 2021 (0 reactions)

That makes a lot of sense. Thanks for the response!
