
How to `vmap` over a batch of trees or dictionaries?


Issue

I want to `vmap` a function over a batch of trees (pytrees) or dictionaries, but I am not sure what the appropriate way to pass such a batch is (see the example below).

Related

#2367 dealt with passing dictionaries to vmapped functions that do not map over the dictionary entries, and #3161 dealt with vmapping over the entries of a single dictionary. This issue is different: it is about vmapping over a batch of dictionaries.
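For contrast, the #3161 situation (mapping over some entries of a single dictionary) can be written by giving `in_axes` a dict that mirrors the argument's structure. This is a minimal sketch with made-up values; it assumes a JAX version whose `vmap` accepts dict-valued `in_axes`, as the `jax.vmap` documentation describes.

```python
import jax
import jax.numpy as jnp

def f(x, params):
    return params['a'] * x + params['b']

x = jnp.array([2, 2])
params = {'a': jnp.array([3, 5]), 'b': 4}  # only 'a' carries a batch axis

# in_axes mirrors the arguments: map x and params['a'] over axis 0,
# and pass params['b'] unchanged to every call (None = not mapped).
out = jax.vmap(f, in_axes=(0, {'a': 0, 'b': None}))(x, params)
print(out)  # [10 14]
```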

Background

I want to vmap the execution of a Haiku neural network over a batch of different parameter sets. The example below runs into the same issue.
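A rough pure-JAX sketch of the Background use case follows. It does not use Haiku; `init_params` and `apply` are hypothetical stand-ins for a network's init and apply functions. What it illustrates is that a batch of parameter sets can be represented as one dict whose leaves share a leading batch axis, which is exactly what vmapping the initializer over a batch of PRNG keys produces.

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim=3, out_dim=2):
    # Hypothetical single dense layer, stored as a dict pytree.
    w_key, b_key = jax.random.split(key)
    return {
        'w': jax.random.normal(w_key, (in_dim, out_dim)),
        'b': jax.random.normal(b_key, (out_dim,)),
    }

def apply(params, x):
    return jnp.tanh(x @ params['w'] + params['b'])

num_models = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_models)

# vmap over keys returns ONE dict whose leaves gain a leading axis of size 4.
batched_params = jax.vmap(init_params)(keys)
print(jax.tree_util.tree_map(lambda a: a.shape, batched_params))

x = jnp.ones((5, 3))  # one input batch shared by every model
# Map over the parameter axis (0) and broadcast x (None) to every model.
outputs = jax.vmap(apply, in_axes=(0, None))(batched_params, x)
print(outputs.shape)  # (4, 5, 2)
```

Haiku's transformed `apply` also takes `params` as its first argument, so the same `in_axes=(0, ...)` pattern should carry over, although that is not tested here.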

Minimal example

Cases 0 and 1 work as expected; cases 2 and 3 do not.

```python
from jax import vmap
import jax.numpy as jnp

def f(x, params):
  return params['a']*x + params['b']

print('\nCase 0:')  # single call, no vmap
print(f(2, dict(a=3, b=4)))

print('\nCase 1:')  # map over x only; params is shared (in_axes None)
print(vmap(f, in_axes=(0, None))(
    jnp.array([2, 2]),
    dict(a=3, b=4)
))

print('\nCase 2:')  # batch of params as a Python list of dicts
try:
  print(vmap(f, in_axes=(0, 0))(
      jnp.array([2, 2]),
      [dict(a=3, b=4), dict(a=3, b=4)]
  ))
except Exception as e:
  print(e)

print('\nCase 3:')  # batch of params as a JAX array of dicts
try:
  print(vmap(f, in_axes=(0, 0))(
      jnp.array([2, 2]),
      jnp.array([dict(a=3, b=4), dict(a=3, b=4)])
  ))
except Exception as e:
  print(e)
```

Output

```
Case 0:
10

Case 1:
[10 10]

Case 2:
vmap got arg 1 of rank 0 but axis to be mapped 0. The tree of ranks is:
((1, [{'a': 0, 'b': 0}, {'a': 0, 'b': 0}]), {})

Case 3:
JAX only supports number and bool dtypes, got dtype object in array
```
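Why cases 2 and 3 fail, and one way around it (a sketch, not code taken from the thread): `vmap` treats a Python list of dicts as a single pytree whose leaves are the four scalars, so `in_axes=0` asks for axis 0 of rank-0 values, which is what the "tree of ranks" in the Case 2 error shows. Case 3 fails earlier, because JAX arrays only hold numeric and boolean dtypes and cannot store dicts. `vmap` expects the batch axis inside the leaves instead: one dict of arrays rather than a list (or array) of dicts.

```python
import jax
import jax.numpy as jnp

def f(x, params):
    return params['a'] * x + params['b']

x = jnp.array([2, 2])

# Case 2's list of dicts flattens to four scalar (rank-0) leaves.
leaves = jax.tree_util.tree_leaves([dict(a=3, b=4), dict(a=3, b=4)])
print(leaves)  # [3, 4, 3, 4]

# Pass ONE dict whose leaves carry the batch axis ("struct of arrays").
batched = dict(a=jnp.array([3, 3]), b=jnp.array([4, 4]))
out = jax.vmap(f, in_axes=(0, 0))(x, batched)
print(out)  # [10 10]
```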

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

tomhennigan commented, Aug 1, 2021 (1 reaction)

Hi @cisprague, utilities like `jax.tree_map(...)` often appear in user code, so I would personally (although I am not a JAX core team member) consider them part of the public API. I suspect that the docs discussing how pytrees work internally (e.g. the treedef) are what is intended to be an “implementation detail”, which you only need to consider when developing extensions (new tree types).

I’d be happy for this function to exist in a JAX library to avoid copy-pasting, but I’m not sure Haiku is the right place: the function is not specific to Haiku and would benefit users of other JAX libraries too. I’m not aware of any libraries of JAX recipes/snippets, but this would be at home in one.
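The thread's Colab is not reproduced on this page, so the helper below is only a guess at the kind of function being discussed; `stack_trees` and `unstack_trees` are hypothetical names. It calls `jax.tree_util.tree_map` with several trees at once to turn a list of identically structured dicts into one dict of stacked arrays, plus the reverse for recovering individual parameter sets. Recent JAX releases deprecate the `jax.tree_map` alias in favour of `jax.tree_util.tree_map`, which is what the sketch uses.

```python
import jax
import jax.numpy as jnp

def stack_trees(trees):
    # Hypothetical helper: list of same-structure pytrees -> one pytree
    # whose leaves are stacked along a new leading axis.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack([jnp.asarray(l) for l in leaves]), *trees)

def unstack_trees(tree):
    # Hypothetical inverse: split the leading axis back into a list of pytrees.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n = leaves[0].shape[0]
    return [jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
            for i in range(n)]

def f(x, params):
    return params['a'] * x + params['b']

param_list = [dict(a=3, b=4), dict(a=5, b=1)]
batched = stack_trees(param_list)  # {'a': [3, 5], 'b': [4, 1]}
out = jax.vmap(f)(jnp.array([2, 2]), batched)
print(out)  # [10 11]
print(unstack_trees(batched)[1])
```

Stacking requires every dict in the list to have the same keys and compatible leaf shapes; `tree_map` raises an error if the tree structures differ.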

cisprague commented, Aug 1, 2021 (0 reactions)

@tomhennigan thanks for the example and Colab notebook! This seems like a reasonable solution. Maybe this could be a new end-user feature in Haiku, since the pytree API is intended for JAX-internal users.
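For containers other than dicts, lists, and tuples, the pytree registration API that the comments refer to ("new tree types") is what lets `vmap` see inside them. A minimal sketch, assuming a hypothetical `Params` dataclass registered with `jax.tree_util.register_pytree_node_class`: once registered, `vmap` maps over its fields just as it does over dict entries. Plain dicts are already pytrees, so the original problem needs no registration.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
@dataclass
class Params:
    # Hypothetical parameter container; a and b may carry a batch axis.
    a: jnp.ndarray
    b: jnp.ndarray

    def tree_flatten(self):
        # (children, aux_data): the children are the leaves vmap maps over.
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

def f(x, params):
    return params.a * x + params.b

batched = Params(a=jnp.array([3, 5]), b=jnp.array([4, 1]))
out = jax.vmap(f)(jnp.array([2, 2]), batched)
print(out)  # [10 11]
```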
