Consider raising when a jitted function's output is not PyTree-like

Just as https://github.com/google/jax/issues/2813 argues that inputs to jit should always be hashable, since that prevents unnecessary recompilation, another common error with jit is returning an output that is not PyTree-like (that is, neither a PyTree of built-in containers nor an instance of a class registered with register_pytree_node). This has caught me a few times when tracers leaked out of functions, and it makes me apprehensive about decorating functions with jit. Is there any use case for returning non-PyTree-like objects? If not, I propose checking that the return value is a PyTree and raising an error if it is not.
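A minimal sketch of the proposed check, assuming it can live in a user-side wrapper rather than inside jax.jit itself. `checked_jit` is a hypothetical name, not a JAX API. It runs while the function is traced, flattens the output with jax.tree_util.tree_leaves (an unregistered object comes back as a single leaf), and raises if any leaf does not look like an array. The "looks like an array" test is an assumption: a duck-typed check for `.shape` and `.dtype` (which arrays and tracers have), plus plain Python scalars.

```python
import functools

import jax
import jax.numpy as jnp


def _is_array_like(leaf):
    # Assumption: treat Python scalars and anything with .shape/.dtype
    # (jax arrays, tracers, numpy arrays) as a valid output leaf.
    if isinstance(leaf, (bool, int, float, complex)):
        return True
    return hasattr(leaf, "shape") and hasattr(leaf, "dtype")


def checked_jit(fun, **jit_kwargs):
    """Hypothetical wrapper: jit `fun`, but raise if its output is not PyTree-like."""

    @functools.wraps(fun)
    def checked(*args, **kwargs):
        out = fun(*args, **kwargs)
        # Unregistered classes are not flattened further, so they show up here as leaves.
        for leaf in jax.tree_util.tree_leaves(out):
            if not _is_array_like(leaf):
                raise TypeError(
                    f"jitted function returned a leaf of type {type(leaf).__name__}; "
                    "return a PyTree of arrays or register the class with "
                    "jax.tree_util.register_pytree_node"
                )
        return out

    return jax.jit(checked, **jit_kwargs)
```

Because the check runs during tracing, it is skipped on calls that hit the compilation cache, so it adds no per-call overhead once a signature has been compiled.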

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
jekbradbury commented, Apr 28, 2020

This is interesting. If I’m following correctly, we already do the same check for function arguments (we throw a “not a JaxType” error on unknown PyTree leaves). It seems like it would be sensible to also error on unknown PyTree leaves in return values, but there also might be valid uses for this (basically returning “compile-time” Python data).
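To illustrate the argument-side check and the registration route the thread refers to, here is a small sketch. `Point` is a hypothetical class made up for the example. Passed unregistered, jit rejects it as an argument with a TypeError; once registered with jax.tree_util.register_pytree_node, its fields become array leaves and it can be both passed in and returned.

```python
import jax


class Point:
    """Hypothetical container used only for this illustration."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def identity(p):
    return p


# Unregistered class as an argument: jit cannot interpret it as an array.
try:
    jax.jit(identity)(Point(1.0, 2.0))
    rejected = False
except TypeError:
    rejected = True

# Register Point: flatten returns (children, aux_data); unflatten rebuilds it.
jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),
    lambda aux_data, children: Point(*children),
)

# Now Point works as both input and output of a jitted function.
shifted = jax.jit(lambda p: Point(p.x + 1.0, p.y * 2.0))(Point(1.0, 2.0))
# shifted.x == 2.0, shifted.y == 4.0
```

The `aux_data` slot returned by the flatten function (here `None`) is where static, "compile-time" Python data can travel alongside the array leaves.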

0 reactions
gnecula commented, Apr 28, 2020

I think I understand. The following code

import jax

class MyClass:
  def __init__(self, x):
    self.x = x

jax.jit(lambda x: MyClass(x + 2))(0)

will return an instance of MyClass with the tracer inside. No code is jitted, because JAX does not see the returned Tracers. This is confusing.

I think that James is thinking about an example like:

jax.jit(lambda x: (x + 5, "hello"))(0)

Maybe for such examples we can add a has_aux parameter to jit to declare that the function being jitted returns an intentional auxiliary result (à la grad).
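For reference, this is how the grad precedent works today: with has_aux=True, jax.grad expects the function to return a pair (value, aux) and hands back (gradient, aux), differentiating only the first element. jax.jit does not take a has_aux argument; the sketch only shows the grad side of the analogy, and `loss_with_aux` is a made-up example function.

```python
import jax


def loss_with_aux(w):
    pred = w * 3.0
    loss = (pred - 1.0) ** 2
    # First element is differentiated; the second is passed through as auxiliary output.
    return loss, {"pred": pred}


grad_w, aux = jax.grad(loss_with_aux, has_aux=True)(2.0)
# d/dw (3w - 1)^2 = 6(3w - 1), so at w = 2: grad_w == 30.0, aux["pred"] == 6.0
```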
