How to `vmap` over a batch of trees or dictionaries?
I want to `vmap` a function over a batch of trees or dictionaries. However, I am not sure what the appropriate way to pass such a batch is (see the example below).
Related
#2367 dealt with passing dictionaries to `vmap`ed functions that do not `vmap` over dictionary entries.
#3161 dealt with `vmap`ing over dictionary entries.
This issue is different because it deals with `vmap`ing over a batch of dictionaries.
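For contrast with the batch-of-dictionaries case, here is a minimal sketch of mapping over individual dictionary entries, roughly the situation #3161 is about. It assumes `in_axes` accepts a pytree prefix of the arguments (a dict with the same keys), which is how `jax.vmap` documents `in_axes`; the function `f` is the one from the example further down, and the parameter values are made up for illustration.

```python
import jax
import jax.numpy as jnp

def f(x, params):
    return params['a'] * x + params['b']

# Map over x and over params['a']; params['b'] is shared (None = not mapped).
params = {'a': jnp.array([1.0, 2.0, 3.0]), 'b': 4.0}
out = jax.vmap(f, in_axes=(0, {'a': 0, 'b': None}))(jnp.array([2.0, 2.0, 2.0]), params)
print(out.tolist())  # [6.0, 8.0, 10.0]
```

Here the "batch" lives inside one dictionary, entry by entry, rather than being a list of separate dictionaries.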
Background
I want to `vmap` a Haiku NN execution over a batch of different parameter sets.
The minimal example below reproduces the same issue with a plain function instead of a Haiku network.
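A minimal sketch of the Background use case, without Haiku: the `init`/`apply` pair below is a hypothetical one-layer network standing in for a transformed Haiku module. The idea it illustrates is that `vmap`ing `init` over a batch of PRNG keys produces the parameters already in batched form (one dict whose arrays carry a leading axis), which `vmap` over `apply` can consume directly. I would expect the same pattern to carry over to `hk.transform`'s `init`/`apply`, but this sketch does not depend on Haiku.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-layer "network" (stand-in for a Haiku transformed function).
def init(key, in_dim=3, out_dim=2):
    w_key, b_key = jax.random.split(key)
    return {'w': jax.random.normal(w_key, (in_dim, out_dim)),
            'b': jax.random.normal(b_key, (out_dim,))}

def apply(params, x):
    return jnp.tanh(x @ params['w'] + params['b'])

keys = jax.random.split(jax.random.PRNGKey(0), 5)  # 5 independent parameter sets
batched_params = jax.vmap(init)(keys)  # one dict; every leaf has leading axis 5
x = jnp.ones((4, 3))                   # a single input batch shared by all sets

# Map over the parameter axis only; x is broadcast to every parameter set.
outs = jax.vmap(apply, in_axes=(0, None))(batched_params, x)
print(batched_params['w'].shape, outs.shape)  # (5, 3, 2) (5, 4, 2)
```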
Minimal example
Cases 0 and 1 work as expected; cases 2 and 3 do not.
```python
from jax import vmap
import jax.numpy as np

def f(x, params):
    return params['a']*x + params['b']

print('\nCase 0:')
print(f(2, dict(a=3, b=4)))

print('\nCase 1:')
print(vmap(f, in_axes=(0, None))(
    np.array([2, 2]),
    dict(a=3, b=4)
))

print('\nCase 2:')
try:
    print(vmap(f, in_axes=(0, 0))(
        np.array([2, 2]),
        [dict(a=3, b=4), dict(a=3, b=4)]
    ))
except Exception as e:
    print(e)

print('\nCase 3:')
try:
    print(vmap(f, in_axes=(0, 0))(
        np.array([2, 2]),
        np.array([dict(a=3, b=4), dict(a=3, b=4)])
    ))
except Exception as e:
    print(e)
```
Output
```
Case 0:
10

Case 1:
[10 10]

Case 2:
vmap got arg 1 of rank 0 but axis to be mapped 0. The tree of ranks is:
((1, [{'a': 0, 'b': 0}, {'a': 0, 'b': 0}]), {})

Case 3:
JAX only supports number and bool dtypes, got dtype object in array
```
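A sketch of why cases 2 and 3 fail and one way around it. `vmap` maps over the *leaves* of its arguments. In case 2, the list is itself a pytree, so `in_axes=0` is applied to every leaf inside it, and each leaf is a rank-0 scalar, hence "arg 1 of rank 0". In case 3, a NumPy object array is not a valid JAX input at all. The usual fix is "struct of arrays": pass one dictionary whose leaves carry the batch axis, either built directly or by stacking corresponding leaves of the list with `jax.tree_util.tree_map` (which accepts several trees of the same structure).

```python
import jax
import jax.numpy as jnp

def f(x, params):
    return params['a'] * x + params['b']

# A batch of dicts passed as ONE dict whose leaves have a leading batch axis.
batched = {'a': jnp.array([3, 3]), 'b': jnp.array([4, 4])}
out_direct = jax.vmap(f, in_axes=(0, 0))(jnp.array([2, 2]), batched)
print(out_direct.tolist())  # [10, 10]

# Converting an existing list of dicts: stack matching leaves across the list.
dicts = [dict(a=3, b=4), dict(a=3, b=4)]
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *dicts)
out_stacked = jax.vmap(f)(jnp.array([2, 2]), stacked)  # default in_axes=0
print(out_stacked.tolist())  # [10, 10]
```

All dictionaries in the list must share the same keys and leaf shapes for the stacking step to work.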
Issue Analytics
- Created 2 years ago
- Comments: 6 (2 by maintainers)
Top GitHub Comments
Hi @cisprague, utilities like `jax.tree_map(..)` often appear in user code, so I would personally (although I am not a JAX core team member) consider this part of the public API. I suspect the docs discussing how pytrees work (e.g. the treedef, etc.) are what is intended as "implementation details", which you should only need if you are developing extensions (new tree types). I'd be happy for this function to exist in a JAX library, though, to avoid copy-pasting, but I'm not sure Haiku is the right place (this function is not unique to Haiku and would benefit users of other JAX libraries). I'm not aware of any libraries with JAX recipes/snippets, but this would be at home there.
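To make the "reusable function" idea concrete, here is a hypothetical pair of helpers, `tree_stack` and `tree_unstack` (names invented for this sketch, not part of JAX or Haiku), built only on public `jax.tree_util` calls. `tree_stack` turns a list of same-structured pytrees into one batched pytree suitable for `vmap`; `tree_unstack` splits the leading axis back into a list.

```python
import jax
import jax.numpy as jnp

def tree_stack(trees):
    """Hypothetical helper: list of same-structured pytrees -> one pytree of stacked leaves."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)

def tree_unstack(tree):
    """Hypothetical helper: inverse of tree_stack, splitting the leading axis into a list."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n = leaves[0].shape[0]
    return [jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
            for i in range(n)]

trees = [{'w': jnp.full((2,), float(i)), 'b': float(i)} for i in range(3)]
stacked = tree_stack(trees)
print(stacked['w'].shape, stacked['b'].shape)  # (3, 2) (3,)

restored = tree_unstack(stacked)
print(len(restored), float(restored[2]['b']))  # 3 2.0
```

Using only `tree_map`, `tree_flatten`, and `tree_unflatten` keeps the helpers independent of any particular library's tree types, which matches the point that this is not specific to Haiku.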
@tomhennigan thanks for the example and Colab notebook! This seems like a reasonable solution. Maybe this could be a new end-user feature in Haiku, since pytrees is intended for JAX internal users.