
Improved error messages and documentation regarding what can/cannot be `jit`ed


tl;dr

Better error messages and documentation for what can and cannot be jitted would be great. The current behavior is a “black box.” See https://github.com/google/jax/issues/953#issuecomment-507119726.


Original issue

I have a particular function that I believe should be jit-able, but I’m getting errors when I apply jit to it. I have a snippet of code that looks like this:

from typing import Any, Callable, NamedTuple, TypeVar

from jax import jit
from jax.experimental import optimizers

OptState = TypeVar("OptState")

class Optimizer(NamedTuple):
  init: Callable[[Any], OptState]
  update: Callable[[int, Any, OptState], OptState]
  get: Callable[[OptState], Any]

def ddpg_episode(
    optimizer: Optimizer,
    ...
) -> LoopState:
  ...

optimizer = Optimizer(*optimizers.adam(step_size=1e-3))
jit(ddpg_episode)(
    optimizer,
    ...
)

but I get an error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in run_module(mod_name, init_globals, run_name, alter_sys)
    203         run_name = mod_name
    204     if alter_sys:
--> 205         return _run_module_code(code, init_globals, run_name, mod_spec)
    206     else:
    207         # Leave the sys module alone

/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_module_code(code, init_globals, mod_name, mod_spec, pkg_name, script_name)
     94         mod_globals = temp_module.module.__dict__
     95         _run_code(code, mod_globals, init_globals,
---> 96                   mod_name, mod_spec, pkg_name, script_name)
     97     # Copy the globals of the temporary module, as they
     98     # may be cleared when the temporary module goes away

/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_code(code, run_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
     83                        __package__ = pkg_name,
     84                        __spec__ = mod_spec)
---> 85     exec(code, run_globals)
     86     return run_globals
     87 

~/dev/research/research/estop/ddpg_pendulum.py in <module>()
     73       critic,
     74       episode_length,
---> 75       noise,
     76   )
     77   print(f"Episode {epsiode}, reward = {reward}")

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    119     f, dyn_args = _argnums_partial(f, dyn_argnums, args)
    120     args_flat, in_tree = tree_flatten((dyn_args, kwargs))
--> 121     _check_args(args_flat)
    122     flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
    123     out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in _check_args(args)
    944     if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
    945       raise TypeError("Argument '{}' of type {} is not a valid JAX type"
--> 946                       .format(arg, type(arg)))
    947 
    948 def _valid_jaxtype(arg):

TypeError: Argument '<function adam.<locals>.init at 0x11b176c80>' of type <class 'function'> is not a valid JAX type

It’s not clear to me what this error message is trying to communicate. What exactly is a valid JAX type? And why is this particular function rejected when there are plenty of jit examples that take functions as arguments?
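A self-contained sketch that reproduces the same class of failure. The `Optimizer` fields, the lambdas, and the `episode` function below are hypothetical stand-ins, not `optimizers.adam`. The exact wording of the error differs between JAX versions, so the sketch only relies on a `TypeError` being raised when a function ends up as a leaf of a jitted function's arguments.

```python
from typing import Any, Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp


class Optimizer(NamedTuple):
    # Hypothetical stand-in for the optimizer triple in the issue.
    init: Callable[[Any], Any]
    update: Callable[[int, Any, Any], Any]
    get: Callable[[Any], Any]


def episode(optimizer: Optimizer, params):
    # Hypothetical body: only arrays should flow through the traced computation.
    state = optimizer.init(params)
    return optimizer.get(state) * 2.0


opt = Optimizer(init=lambda p: p, update=lambda i, g, s: s, get=lambda s: s)


def try_jit() -> Optional[str]:
    # jit flattens (opt, params) into leaves; the three functions are not arrays.
    try:
        jax.jit(episode)(opt, jnp.ones(3))
    except TypeError as err:
        return str(err)
    return None


message = try_jit()
print(message)
```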

Issue Analytics

  • State: open
  • Created 4 years ago
  • Comments: 7 (6 by maintainers)

Top GitHub Comments

2 reactions
wbradknox commented, Aug 17, 2022

+1 to the suggestion to clearly define “valid JAX type”, using that exact phrase so the definition shows up in search results. I currently cannot find such a definition when Googling “valid JAX type”.

I did find this documentation, which appears helpful with errors related to using invalid JAX types as input.
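Until such a definition exists, a small pre-check can at least point at the offending argument. The sketch below is a hypothetical helper, `find_non_array_leaves`, that walks a pytree with `jax.tree_util.tree_flatten_with_path` (available in recent JAX, roughly 0.4.6 and later) and reports the path of every leaf that is not array-like. The `_ARRAY_LIKE` tuple is an approximation chosen for illustration, not JAX's internal validity check.

```python
from typing import Any, List

import jax
import numpy as np

# Leaf types treated as array-like here (an assumption for illustration;
# bool is covered by int). JAX's real check lives in its internals.
_ARRAY_LIKE = (jax.Array, np.ndarray, np.generic, int, float, complex)


def find_non_array_leaves(tree: Any) -> List[str]:
    """Return pytree paths of leaves that are not array-like (hypothetical helper)."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves_with_paths
        if not isinstance(leaf, _ARRAY_LIKE)
    ]


# A function stored in a dict is a pytree leaf, so it gets reported.
print(find_non_array_leaves({"lr": 1e-3, "fn": print, "w": np.zeros(2)}))
```

Calling it on the arguments before `jit(...)` surfaces the path of the function-valued field (for the issue's code, a field of the `Optimizer` tuple) instead of only its repr.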

1 reaction
mattjj commented, Jul 4, 2019

Thanks for raising this and the detailed notes.

Regarding the latter error message, have you already read the How it works and What’s supported sections of the readme, and the Gotchas notebook, especially the Control Flow section? If so, it’d be useful to think through together how they could be improved (i.e. what they’re missing), and if not it’d be useful to figure out how to make them more discoverable!
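As an illustration of the kind of control-flow gotcha those sections cover, here is a minimal sketch with three versions of a hypothetical `relu`. A Python `if` on a traced value needs a concrete boolean and fails under `jit`; `jax.lax.cond`, `jnp.where`, or marking the argument static via `static_argnums` avoid the problem. The function names are made up for this example.

```python
import jax
import jax.numpy as jnp


def relu_python_if(x):
    # A Python `if` needs a concrete bool; under jit, `x > 0` is an abstract tracer.
    if x > 0:
        return x
    return 0.0 * x


def relu_lax_cond(x):
    # lax.cond traces both branches and selects one at run time.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 0.0 * v, x)


def relu_where(x):
    # Elementwise select; also works for arrays, not just scalars.
    return jnp.where(x > 0, x, 0.0)


def jit_python_if_raises() -> bool:
    # Returns True if jitting the Python-`if` version fails during tracing.
    try:
        jax.jit(relu_python_if)(1.0)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


print(jit_python_if_raises())
print(jax.jit(relu_lax_cond)(-2.0), jax.jit(relu_where)(3.0))
# Marking x static makes it a concrete Python value, at the cost of one
# compilation per distinct value of x.
print(jax.jit(relu_python_if, static_argnums=0)(1.0))
```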

I remember also we wrote a bit more in a comment on #196, which I thought we also linked from the readme but apparently we don’t.

Regarding the former error, actually functions can’t be arguments to jit-ed functions. As it says in the jit docstring, “Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.” Other than adding more information to the “valid jaxtype” error message, do you have suggestions for how to improve the documentation?
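Two common ways around the restriction are sketched below: mark the function-valued argument as static with `static_argnums`, or close over it so only arrays cross the `jit` boundary. The toy `sgd` optimizer and the `step` function are hypothetical and written only for this example; they are not the `jax.experimental.optimizers` module.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class Optimizer(NamedTuple):
    # Hypothetical container mirroring the issue's Optimizer triple.
    init: Callable[[Any], Any]
    update: Callable[[int, Any, Any], Any]
    get: Callable[[Any], Any]


def sgd(step_size: float) -> Optimizer:
    # Toy SGD for illustration only.
    def init(params):
        return params

    def update(i, grads, state):
        return state - step_size * grads

    def get(state):
        return state

    return Optimizer(init, update, get)


def step(optimizer: Optimizer, params, grads):
    state = optimizer.init(params)
    state = optimizer.update(0, grads, state)
    return optimizer.get(state)


opt = sgd(0.1)

# Option 1: treat the optimizer as a compile-time constant. Static arguments
# must be hashable; a NamedTuple of functions is, since functions hash by identity.
step_static = jax.jit(step, static_argnums=0)

# Option 2: close over the optimizer so only arrays are traced.
step_closed = jax.jit(lambda params, grads: step(opt, params, grads))

params, grads = jnp.ones(3), jnp.full(3, 2.0)
print(step_static(opt, params, grads))
print(step_closed(params, grads))
```

With option 1, passing a different `Optimizer` instance (for example, calling `sgd(0.1)` again, which creates new function objects) triggers a fresh compilation, because static arguments are compared by hash and equality.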

(Valid jaxtypes are in practice numpy.ndarray and our effective subclasses. There’s also a tuple type but user code doesn’t use it. The functions in api.py can accept pytrees of jaxtypes, which in practice effectively means tuple/list/dict-trees of arrays.)
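To make “pytrees of jaxtypes” concrete, here is a small sketch that flattens a nested container with `jax.tree_util.tree_flatten`. The container contents and the `total` function are illustrative. Flattening shows the leaves `jit` actually works with; every leaf here is an array or a Python scalar, so the jitted call succeeds.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Nested dict/list/tuple containers are pytree nodes; arrays and scalars are leaves.
args = {"w": jnp.zeros((2, 2)), "b": np.ones(2), "layers": [1.0, (2, 3)]}

leaves, treedef = jax.tree_util.tree_flatten(args)
print(treedef)
print([type(leaf).__name__ for leaf in leaves])


@jax.jit
def total(tree):
    # Inside jit every leaf is a tracer, but the container structure is preserved.
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))


print(total(args))
```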
