core.call_bind aggressively raises args to top trace
```python
import jax

def f(a_bool, y):
    if a_bool:
        return y + 1
    else:
        return y

# a_bool is marked static for jit, yet this call fails (see below).
jax.jit(jax.remat(f), static_argnums=0)(True, 1)
```
Results in:
```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(bool[], weak_type=True):JaxprTrace(level=-1/2)>
```
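For comparison, the same function works under `jit` alone, which suggests the error comes from wrapping it in `remat` rather than from `static_argnums` itself. A minimal check (the name `jitted` is only for illustration):

```python
import jax

def f(a_bool, y):
    if a_bool:
        return y + 1
    else:
        return y

# Without remat, a_bool stays a plain Python bool inside the traced
# function, so the Python `if` can be evaluated.
jitted = jax.jit(f, static_argnums=0)

print(jitted(True, 1))   # evaluates to 2
print(jitted(False, 1))  # evaluates to 1
```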
I think this error arises from the `full_raise` occurring here when processing `remat_call_p`, which raises the arguments to the `JaxprTrace` even though we told `jit` we don't want that!
https://github.com/google/jax/blob/77901e9fa71f5b23066c70132a983ae57f655b39/jax/core.py#L1001
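To make the suspected mechanism concrete, here is a deliberately simplified toy model. It is not JAX's real trace machinery: `ToyTracer`, `raise_all`, `raise_tracers_only` and `call` are hypothetical names. It contrasts lifting every argument to the top trace (the behaviour described above) with one possible alternative that lifts only values that are already tracers.

```python
from dataclasses import dataclass


@dataclass
class ToyTracer:
    """Hypothetical stand-in for an abstract value: only its trace level is known."""
    level: int

    def __bool__(self):
        # Python control flow cannot branch on an abstract value.
        raise TypeError("abstract value encountered where a concrete bool is expected")


def raise_all(args, top_level):
    # Simplified version of the reported behaviour: every argument,
    # including plain Python constants, becomes abstract at the top level.
    return [ToyTracer(top_level) for _ in args]


def raise_tracers_only(args, top_level):
    # Alternative: lift only existing tracers; constants stay concrete.
    return [ToyTracer(top_level) if isinstance(a, ToyTracer) else a for a in args]


def f(a_bool, y):
    # Branches on a Python bool, like the function in the report.
    return ("y + 1", y) if a_bool else ("y", y)


def call(raise_fn, fn, *args, top_level=1):
    # Toy "call primitive": lift the arguments, then run the callee.
    return fn(*raise_fn(list(args), top_level))


print(call(raise_tracers_only, f, True, ToyTracer(0)))
# call(raise_all, f, True, ToyTracer(0)) raises TypeError inside f
```

In this model the static `True` only breaks the branch when the call lifts it along with the genuinely traced `y`.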
The same eager raising also affects user-defined call primitives that use `core.call_bind`, resulting in unnecessary workarounds like this one in Haiku:
https://github.com/deepmind/dm-haiku/blob/49b21f7192dfdb3dc0a49cc097c8d3b0ccabb107/haiku/_src/named_call.py#L101-L109
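One general way to sidestep the problem is to keep non-array arguments out of the call entirely by closing over them before wrapping. The sketch below uses `jax.checkpoint` (also exposed as `jax.remat`) and a hypothetical helper `remat_with_static_flag`; it is an illustration of the closure pattern, not necessarily what the linked Haiku code does.

```python
import functools

import jax


def f(a_bool, y):
    if a_bool:
        return y + 1
    else:
        return y


def remat_with_static_flag(a_bool, y):
    # Bind the Python bool into the function *before* wrapping it, so only
    # `y` is passed through the remat call and a_bool is never traced.
    g = jax.checkpoint(functools.partial(f, a_bool))
    return g(y)


# a_bool is static for jit, so it arrives here as a plain Python bool.
step = jax.jit(remat_with_static_flag, static_argnums=0)

print(step(True, 1))   # evaluates to 2
print(step(False, 1))  # evaluates to 1
```

Because `a_bool` is static, `jit` traces and compiles a separate version for each distinct value it sees.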
Gotcha, `remat` was just a hasty choice on my part to pick a JAX-native call primitive that didn't seem to raise the abstraction level; the main motivation is the linked Haiku `named_call_p` (and my other experimental call primitives). 😃

Thanks for catching this! Your guess was quite good and that's what let us solve this fast.