Using remat with host_callback is broken
This throws an error:
import jax
from jax.experimental import host_callback

def f1(x):
    x = host_callback.id_tap(print, x)
    return x

def f2(x, _):
    f1_r = jax.remat(f1)
    x = f1_r(x)
    return x, 1

def forward(x):
    x = jax.lax.scan(f2, x, None, length=2)
    return x[0]

jax.jit(jax.grad(forward))(1.2)
...
# /usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _outside_call_partial_eval_rule(trace, *args, **params)
985 primals, tangents = util.split_list(args, [nr_primals])
986 c_primals_tapped, _ = util.split_list(consts, [nr_primals])
--> 987 assert all([c is not None for c in c_primals_tapped])
988
989 prims, _ = params["arg_treedef"].unflatten(args)
AssertionError:
It runs fine if the line x = host_callback.id_tap(print, x) is removed, or if no lax.scan is used.
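For comparison, here is a minimal sketch of the same remat-inside-scan pattern, assuming a recent JAX release. Substituting jax.debug.print for host_callback.id_tap(print, x) is an assumption of this sketch: host_callback was deprecated in later releases in favor of newer callback APIs such as jax.debug.callback, and this version never touches the outside_call primitive, so it says nothing about whether the bug above was fixed. Because f1 is the identity, the gradient of forward should be 1.0.

```python
import jax

def f1(x):
    # Assumed stand-in for host_callback.id_tap(print, x):
    # jax.debug.print returns nothing, so x passes through unchanged.
    jax.debug.print("tap: {}", x)
    return x

def f2(x, _):
    # jax.checkpoint is the function that jax.remat aliases: values inside
    # f1 are recomputed during the backward pass instead of being stored.
    x = jax.checkpoint(f1)(x)
    return x, 1

def forward(x):
    # scan returns (final_carry, stacked_ys); only the carry is differentiated.
    carry, _ = jax.lax.scan(f2, x, None, length=2)
    return carry

grad_fn = jax.jit(jax.grad(forward))
g = grad_fn(1.2)
print(g)
```

Note that with jax.checkpoint the callback may fire again when the checkpointed body is recomputed in the backward pass, so the number of printed lines is not a reliable count of forward steps.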
Top GitHub Comments
Actually, the scan doesn't seem necessary to cause the failure; the error also occurs without it.

Sadly, I have no progress to report, except that I have reproduced the issue and seen that the error arises in the partial eval rule for the outside_call primitive. I wrote that code, but with low confidence, and this issue is proof that there is still something I do not understand.
I can try to dig more into it, but it may be more useful if @mattjj finds some time to take a look also.
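One way to look at the scan-free case from the first comment is to print the jaxpr of the gradient and see where the callback equation lands relative to the checkpointed body. The sketch below is hypothetical and is not the commenter's own snippet: it again assumes jax.debug.print as a stand-in for id_tap, so the primitive visible in the jaxpr is debug_callback rather than outside_call. It only illustrates the inspection technique.

```python
import jax

def f1(x):
    # Assumed stand-in for host_callback.id_tap(print, x).
    jax.debug.print("tap: {}", x)
    return x

# Scan-free variant: checkpoint (remat) + grad + jit applied directly to f1.
f1_r = jax.checkpoint(f1)
g = jax.jit(jax.grad(f1_r))(1.2)

# Trace the gradient computation without running it. Nested jaxprs (such as
# the checkpointed body) are printed inline, so the callback equation can be
# seen in the context of the forward and backward computations.
jaxpr = jax.make_jaxpr(jax.grad(f1_r))(1.2)
print(jaxpr)
```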