Improved error messages and documentation regarding what can/cannot be `jit`ed
tl;dr
Better error messages and documentation for what can and cannot be `jit`ed would be great. Current behavior is “black box.” See https://github.com/google/jax/issues/953#issuecomment-507119726.
See also:
Original issue
I have a particular function that I believe should be `jit`-able, but I’m getting errors when I wrap it with `jit`. I have a snippet of code that looks like this:
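from typing import Any, Callable, NamedTuple, TypeVar

from jax import jit
from jax.experimental import optimizers  # later JAX releases moved this to jax.example_libraries.optimizers

OptState = TypeVar("OptState")

class Optimizer(NamedTuple):
    init: Callable[[Any], OptState]
    update: Callable[[int, Any, OptState], OptState]
    get: Callable[[OptState], Any]

def ddpg_episode(
    optimizer: Optimizer,
    ...
) -> LoopState:
    ...

# optimizers.adam returns an (init, update, get_params) triple of functions
optimizer = Optimizer(*optimizers.adam(step_size=1e-3))
jit(ddpg_episode)(
    optimizer,
    ...
)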
but I get an error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in run_module(mod_name, init_globals, run_name, alter_sys)
203 run_name = mod_name
204 if alter_sys:
--> 205 return _run_module_code(code, init_globals, run_name, mod_spec)
206 else:
207 # Leave the sys module alone
/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_module_code(code, init_globals, mod_name, mod_spec, pkg_name, script_name)
94 mod_globals = temp_module.module.__dict__
95 _run_code(code, mod_globals, init_globals,
---> 96 mod_name, mod_spec, pkg_name, script_name)
97 # Copy the globals of the temporary module, as they
98 # may be cleared when the temporary module goes away
/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_code(code, run_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
83 __package__ = pkg_name,
84 __spec__ = mod_spec)
---> 85 exec(code, run_globals)
86 return run_globals
87
~/dev/research/research/estop/ddpg_pendulum.py in <module>()
73 critic,
74 episode_length,
---> 75 noise,
76 )
77 print(f"Episode {epsiode}, reward = {reward}")
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
119 f, dyn_args = _argnums_partial(f, dyn_argnums, args)
120 args_flat, in_tree = tree_flatten((dyn_args, kwargs))
--> 121 _check_args(args_flat)
122 flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
123 out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in _check_args(args)
944 if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
945 raise TypeError("Argument '{}' of type {} is not a valid JAX type"
--> 946 .format(arg, type(arg)))
947
948 def _valid_jaxtype(arg):
TypeError: Argument '<function adam.<locals>.init at 0x11b176c80>' of type <class 'function'> is not a valid JAX type
It’s not clear to me what this error message is trying to communicate. What exactly is a valid JAX type? And why is this particular function rejected when there are plenty of `jit` examples that include functions as arguments?
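The function named in the error, `adam.<locals>.init`, is one of the fields of the `Optimizer` NamedTuple rather than the NamedTuple itself. As the traceback shows, `jit` calls `tree_flatten` on its arguments, and NamedTuples are pytree containers, so their fields become leaves that are then type-checked one by one. A minimal sketch of that flattening, assuming a recent JAX with `jax.tree_util`, using a hypothetical stand-in `Optimizer` with toy functions in place of `optimizers.adam`:

```python
from typing import Any, Callable, NamedTuple

import jax


# Hypothetical stand-in for the issue's Optimizer: every field is a plain function.
class Optimizer(NamedTuple):
    init: Callable[[Any], Any]
    update: Callable[[int, Any, Any], Any]
    get: Callable[[Any], Any]


def toy_init(params):
    return params


def toy_update(step, grads, state):
    return state - 1e-3 * grads


def toy_get(state):
    return state


opt = Optimizer(toy_init, toy_update, toy_get)

# NamedTuples are pytree nodes, so flattening descends into their fields.
# These leaves are what jit's argument check inspects, one at a time.
leaves, treedef = jax.tree_util.tree_flatten(opt)
print(all(callable(leaf) for leaf in leaves))  # True
```

So the check never sees "an `Optimizer`"; it sees three bare functions, and the first one it reaches is reported.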
Issue Analytics
- Created 4 years ago
- Comments: 7 (6 by maintainers)
+1 to the suggestion to clearly define “valid JAX type”, using that exact phrase so the definition shows up in search results. I currently can’t find such a definition when Googling “valid JAX type”.
I did find this documentation, which looks helpful for errors caused by passing invalid JAX types as inputs.
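Until such a definition exists, a rough way to find offending arguments is to flatten them and look for leaves that are not array-like. This is a hypothetical helper, not part of JAX, and its notion of “valid” is an assumption approximating the maintainers’ description below (arrays, scalars, and standard containers of them). It assumes a recent JAX release that exposes `jax.Array`:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumed approximation of "valid JAX type" for jit arguments (not JAX's own
# check): JAX arrays, NumPy arrays and scalars, and Python numeric scalars.
ARRAY_LIKE = (jax.Array, np.ndarray, np.generic, bool, int, float, complex)


def invalid_jit_leaves(tree):
    """Return the pytree leaves that are not array-like under the rule above."""
    return [leaf for leaf in jax.tree_util.tree_leaves(tree)
            if not isinstance(leaf, ARRAY_LIKE)]


good_args = {"w": jnp.ones(3), "b": 0.5, "steps": (1, 2)}
bad_args = {"w": jnp.ones(3), "fns": (len, print)}

print(invalid_jit_leaves(good_args))  # []
print(invalid_jit_leaves(bad_args))   # the two builtins, len and print
```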
Thanks for raising this and the detailed notes.
Regarding the latter error message, have you already read the How it works and What’s supported sections of the readme, and the Gotchas notebook, especially the Control Flow section? If so, it’d be useful to think through together how they could be improved (i.e. what they’re missing), and if not it’d be useful to figure out how to make them more discoverable!
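The latter error message isn’t reproduced in this thread. Assuming it is the common control-flow case that the Gotchas notebook covers, a minimal sketch: a Python `if` on a traced value fails under `jit` because tracing has no concrete boolean, while `jnp.where` (or `jax.lax.cond`) expresses the branch as part of the traced computation. The function names are hypothetical:

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_python_if(x):
    # Python `if` needs a concrete bool, but under jit `x` is an abstract tracer.
    if x > 0:
        return x
    return 0.0


@jax.jit
def relu_where(x):
    # jnp.where selects elementwise inside the traced computation, so it works under jit.
    return jnp.where(x > 0, x, 0.0)


print(float(relu_where(jnp.array(-2.0))))  # 0.0

failed = False
try:
    relu_python_if(jnp.array(-2.0))
except jax.errors.ConcretizationTypeError:
    # Newer JAX raises TracerBoolConversionError, a subclass of this error.
    failed = True
```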
I also remember we wrote a bit more in a comment on #196, which I thought we had linked from the readme, but apparently we haven’t.
Regarding the former error: functions actually can’t be arguments to `jit`-ed functions. As the `jit` docstring says, “Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.” Other than adding more information to the “valid jaxtype” error message, do you have suggestions for how to improve the documentation?
(Valid jaxtypes are in practice numpy.ndarray and our effective subclasses. There’s also a tuple type, but user code doesn’t use it. The functions in api.py can accept pytrees of jaxtypes, which in practice means tuple/list/dict trees of arrays.)
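A minimal sketch of the rule above and two common ways around it, assuming a recent JAX. `apply_step` and `sgd_update` are hypothetical stand-ins for `ddpg_episode` and an optimizer function: a dict of arrays is accepted as a traced argument, a function is not, and the function can instead be marked static (static arguments must be hashable, and plain functions hash by identity) or closed over so that only arrays are traced.

```python
import functools

import jax
import jax.numpy as jnp


def sgd_update(params):
    # Hypothetical update rule: a pytree (dict) of arrays in and out.
    return {k: v - 0.1 for k, v in params.items()}


def apply_step(update_fn, params):
    # Hypothetical stand-in for ddpg_episode: takes a callable plus arrays.
    return update_fn(params)


params = {"w": jnp.ones(3), "b": jnp.zeros(())}

# A function passed as a regular (traced) argument fails jit's argument check.
rejected = False
try:
    jax.jit(apply_step)(sgd_update, params)
except TypeError:
    rejected = True

# Option 1: mark the callable as static so jit does not try to trace it.
step_static = jax.jit(apply_step, static_argnums=0)
out_static = step_static(sgd_update, params)

# Option 2: close over the callable so only the pytree of arrays is traced.
step_closed = jax.jit(functools.partial(apply_step, sgd_update))
out_closed = step_closed(params)
```

Under the same assumptions, the issue’s call could become `jit(ddpg_episode, static_argnums=0)(optimizer, ...)`, since a NamedTuple whose fields are functions is hashable. Each new static value triggers a fresh compilation, so closing over the optimizer is the simpler choice when it never changes.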