Using remat with host_callback is broken

See original GitHub issue

This throws an error:

import jax
from jax.experimental import host_callback

def f1(x):
  x = host_callback.id_tap(print, x)
  return x

def f2(x, _):
  f1_r = jax.remat(f1)
  x = f1_r(x)
  return x, 1

def forward(x):
  x = jax.lax.scan(f2, x, None, length=2)
  return x[0]

jax.jit(jax.grad(forward))(1.2)
...
# /usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _outside_call_partial_eval_rule(trace, *args, **params)
    985   primals, tangents = util.split_list(args, [nr_primals])
    986   c_primals_tapped, _ = util.split_list(consts, [nr_primals])
--> 987   assert all([c is not None for c in c_primals_tapped])
    988 
    989   prims, _ = params["arg_treedef"].unflatten(args)

AssertionError: 

It runs fine if the line x = host_callback.id_tap(print, x) is removed, or if no lax.scan is used.
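For comparison, here is a rough sketch of the same scan-plus-remat program written against JAX's newer debugging API, assuming a recent JAX release (jax.experimental.host_callback is deprecated in newer releases). It swaps jax.debug.callback in for host_callback.id_tap and uses jax.checkpoint, of which jax.remat is an alias. The names record and seen are hypothetical helpers for illustration only. This is not the resolution recorded in the issue, and the number of times the callback fires under checkpointing may differ between versions, so only the gradient value should be relied on.

```python
import jax

seen = []  # hypothetical host-side list of values passed to the callback


def record(x):
    # Runs on the host and receives the concrete primal value.
    seen.append(x)


def f1(x):
    # jax.debug.callback stands in for host_callback.id_tap here.
    jax.debug.callback(record, x)
    return x


def f2(carry, _):
    # jax.remat is an alias of jax.checkpoint.
    carry = jax.checkpoint(f1)(carry)
    return carry, None


def forward(x):
    carry, _ = jax.lax.scan(f2, x, None, length=2)
    return carry


g = jax.jit(jax.grad(forward))(1.2)
jax.effects_barrier()  # wait for any pending host callbacks to finish
print(g)
```

Because forward just applies the identity twice, the gradient it computes should be 1.0.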

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 6 (1 by maintainers)

Top GitHub Comments

1 reaction
sbodenstein commented, Mar 1, 2021

Actually, the scan doesn't seem necessary to cause the failure; this also triggers it:

import jax
from jax.experimental import host_callback

@jax.remat
def f(x):
  host_callback.id_tap(print, x)
  return x

jax.grad(f)(2.)  # also fails, without any lax.scan

0 reactions
gnecula commented, Sep 13, 2021

Sadly, I have no progress to report, except that I have reproduced the issue and found that the error arises in the partial eval rule for the outside_call primitive. I wrote that code, but with low confidence, and this issue is proof that there is still something I do not understand.

I can try to dig more into it, but it may be more useful if @mattjj finds some time to take a look also.


Top Results From Across the Web

Source code for jax.experimental.host_callback
This module introduces the host callback functions :func:`call`, ... host callback functions will be executed for each device in the order in which...