Consider raising when a jitted function's output is not PyTree-like
Just as https://github.com/google/jax/issues/2813 argues that inputs to `jit` should always be hashable, since that prevents unnecessary recompilation, another common error with `jit` is returning output that is not PyTree-like (either a PyTree or a type registered using `register_pytree_node`). This has caught me a few times with tracers leaking out of functions, and it makes me apprehensive about decorating functions with `jit`. Is there any use case for returning non-PyTree-like objects? If not, I propose checking that the return value is a PyTree and raising if not.
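A minimal sketch of the proposed check, written as a wrapper rather than as a change to `jax.jit` itself. The names `checked_jit`, `MyClass` and `Opaque` are hypothetical. The wrapper runs the user function inside the traced function, walks the output leaves with `jax.tree_util.tree_leaves`, and raises if any leaf is not an array, a tracer, or a Python/NumPy scalar. A class registered with `register_pytree_node` (here `MyClass`) flattens into array leaves and passes; an unregistered object (here `Opaque`) shows up as a leaf of its own and is rejected.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util

# Leaf types accepted as jit outputs in this sketch. In recent JAX versions,
# tracers are instances of jax.Array, so traced values pass this check too.
_ALLOWED_LEAF_TYPES = (jax.Array, np.ndarray, np.generic, bool, int, float, complex)


def checked_jit(fun, **jit_kwargs):
    """Hypothetical wrapper: jit `fun`, but raise a descriptive TypeError
    if its output contains a leaf that is not an array or a scalar."""
    name = getattr(fun, "__name__", repr(fun))

    def fun_with_check(*args, **kwargs):
        out = fun(*args, **kwargs)
        # Unregistered objects are not flattened any further, so they
        # appear here as leaves of the output PyTree.
        for leaf in tree_util.tree_leaves(out):
            if not isinstance(leaf, _ALLOWED_LEAF_TYPES):
                raise TypeError(
                    f"{name} returned a leaf of type {type(leaf).__name__}; "
                    "return arrays, or register the type with "
                    "jax.tree_util.register_pytree_node."
                )
        return out

    return jax.jit(fun_with_check, **jit_kwargs)


class MyClass:
    """Hypothetical container that IS registered as a PyTree."""

    def __init__(self, value):
        self.value = value


tree_util.register_pytree_node(
    MyClass,
    lambda obj: ((obj.value,), None),           # (children, aux_data)
    lambda _aux, children: MyClass(*children),  # rebuild from children
)


class Opaque:
    """Hypothetical container that is NOT registered as a PyTree."""

    def __init__(self, value):
        self.value = value


@checked_jit
def good(x):
    return MyClass(x * 2.0)


@checked_jit
def bad(x):
    return Opaque(x * 2.0)
```

With this sketch, `good(1.0).value` comes back as a concrete array, while `bad(1.0)` raises the wrapper's TypeError, naming `Opaque`, while the function is being traced.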
Issue Analytics
- Created 3 years ago
- Comments:5 (2 by maintainers)
This is interesting. If I’m following correctly, we already do the same check for function arguments (we throw a “not a JaxType” error on unknown PyTree leaves). It seems like it would be sensible to also error on unknown PyTree leaves in return values, but there might also be valid uses for this (basically, returning “compile-time” Python data).
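The argument-side check mentioned here can be seen with a small sketch; `Opaque` is a hypothetical unregistered class. A PyTree whose leaves are all arrays is accepted, while an unregistered object is treated as a single leaf and rejected with a TypeError. The sketch checks only the exception type, because the wording of the message (“not a JaxType” above) may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


class Opaque:
    """Hypothetical class that is NOT registered as a PyTree."""


@jax.jit
def identity(tree):
    return tree


# Every leaf of this PyTree is an array, so jit accepts it.
ok = identity({"a": jnp.ones(2), "b": (jnp.zeros(3),)})

# An unregistered object becomes a leaf that is not a valid JAX type,
# so jit rejects it while processing the arguments.
try:
    identity(Opaque())
    argument_rejected = False
except TypeError:
    argument_rejected = True
```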
I think I understand. The following code will return an instance of `MyClass` with the tracer inside. No code is jitted, because JAX does not see the returned Tracers. This is confusing.

I think that James is thinking about an example like:
Maybe for such examples we can use a `has_aux` parameter to `jit` to declare that the function being jitted returns an intentional auxiliary result (à la `grad`).
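`jax.jit` itself has no `has_aux` parameter, so the idea can only be sketched here as a wrapper. In the hypothetical `jit_with_aux` below, the function returns `(out, aux)` as with `grad(..., has_aux=True)`, except that `aux` is Python data computed at trace time rather than an array. The sketch assumes `aux` is hashable and routes it through the auxiliary data of a leaf-free PyTree node (`_Static`, also hypothetical), so it travels in the output tree structure instead of as an array leaf, and the same value is returned on cached calls.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class _Static:
    """Hypothetical leaf-free PyTree node: its value lives in the treedef's
    auxiliary data instead of being an array leaf."""

    def __init__(self, value):
        self.value = value


tree_util.register_pytree_node(
    _Static,
    lambda node: ((), node.value),            # no children; value is aux_data
    lambda value, _children: _Static(value),
)


def jit_with_aux(fun, **jit_kwargs):
    """Hypothetical has_aux-style wrapper around jax.jit.

    `fun` must return (out, aux): `out` is a PyTree of arrays and `aux` is
    hashable Python data computed while tracing.
    """

    def wrapped(*args, **kwargs):
        out, aux = fun(*args, **kwargs)
        return out, _Static(aux)

    jitted = jax.jit(wrapped, **jit_kwargs)

    def call(*args, **kwargs):
        out, static = jitted(*args, **kwargs)
        return out, static.value

    return call


@jit_with_aux
def normalize(x):
    # x.shape and x.dtype are known at trace time, so they can be returned
    # as ordinary Python values alongside the array result.
    return x / jnp.sum(x), (x.shape, str(x.dtype))
```

For example, under JAX's default 32-bit mode `normalize(jnp.ones((2, 3)))` returns the normalized array together with `((2, 3), "float32")`; a second call with the same input shape reuses the compiled function and still returns that auxiliary value.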