jax.lax.custom_linear_solve / jax.lax.custom_root doesn't respect jax.disable_jit
After #9938, the documentation says:
For debugging it is useful to have a mechanism that disables jit() everywhere in a dynamic context. Note that this not only disables explicit uses of jit by the user, but will also remove any implicit JIT compilation used by the JAX library: this includes implicit JIT computation of body and cond functions passed to higher-level primitives like scan() and while_loop(), JIT used in implementations of jax.numpy functions, and any other case where jit is used within an API’s implementation.
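Here is a minimal sketch of the documented behavior. The names double, cond_fun and body_fun are made up for illustration. Inside jax.disable_jit(), a jit-decorated function runs op by op, and lax.while_loop runs as an ordinary Python loop. So the functions below should record concrete array types rather than tracer types. The exact type names vary between JAX versions, so the sketch only prints them.

```python
import jax
import jax.numpy as jnp

calls = []  # (function name, type name of its argument), in call order

@jax.jit
def double(x):
    # Under jit, x would be an abstract tracer; with jit disabled it is a concrete array.
    calls.append(("double", type(x).__name__))
    return 2 * x

def cond_fun(c):
    return c < 10

def body_fun(c):
    calls.append(("body", type(c).__name__))
    return c + 3

with jax.disable_jit():
    y = double(jnp.arange(3.0))  # executes the Python body directly, op by op
    # With jit disabled, while_loop iterates in Python: 0 -> 3 -> 6 -> 9 -> 12.
    out = jax.lax.while_loop(cond_fun, body_fun, 0)

print(y, out)
print(calls)
```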
However, while working on https://github.com/google/jax/issues/9714#issuecomment-1077319705, I found that jaxprs are still generated inside a with jax.disable_jit(): block.
I checked the source of jax.lax.custom_linear_solve and found that it doesn't check config.jax_disable_jit.
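Below is a minimal reproduction sketch. It uses a hypothetical diagonal operator (diag, matvec and solve are invented for illustration). The sketch records the type name of whatever matvec receives. If tracer type names such as "...Tracer" appear even though the call runs under jax.disable_jit(), then matvec was traced into a jaxpr. Whether this still happens depends on the JAX version, so no particular output is asserted for seen_types.

```python
import jax
import jax.numpy as jnp

seen_types = []  # type names of the values passed to matvec

# Hypothetical diagonal operator A = diag(2, 4, 8).
diag = jnp.array([2.0, 4.0, 8.0])

def matvec(x):
    # Concrete arrays here mean eager execution; tracers mean matvec was traced.
    seen_types.append(type(x).__name__)
    return diag * x

def solve(matvec_fn, b):
    # The operator is diagonal, so its inverse is elementwise division.
    return b / diag

b = jnp.array([2.0, 8.0, 24.0])

with jax.disable_jit():
    x = jax.lax.custom_linear_solve(matvec, b, solve)

print(x)           # solution of diag * x = b
print(seen_types)  # inspect for tracer types despite disable_jit
```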
jax.lax.custom_root has the same problem.
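The same kind of check can be sketched for jax.lax.custom_root. This example finds sqrt(2) as the root of f(x) = x**2 - 2. The helpers f, newton_solve and tangent_solve are hypothetical, and newton_solve hardcodes the derivative 2x of this particular f. The tangent_solve form y / g(1.0) follows the custom_root docstring's suggestion for scalar problems. As above, tracer type names in seen_types would indicate that f was traced despite jax.disable_jit().

```python
import jax
import jax.numpy as jnp

seen_types = []  # type names of the values passed to f

def f(x):
    seen_types.append(type(x).__name__)
    return x ** 2 - 2.0

def newton_solve(fn, x0):
    # Fixed number of Newton steps using the analytic derivative 2x of this f;
    # plain Python control flow, so it works on concrete arrays and tracers alike.
    x = x0
    for _ in range(20):
        x = x - fn(x) / (2.0 * x)
    return x

def tangent_solve(g, y):
    # g is linear and scalar here, so g(t) = g(1.0) * t.
    return y / g(1.0)

with jax.disable_jit():
    root = jax.lax.custom_root(f, jnp.array(1.0), newton_solve, tangent_solve)

print(root)
print(seen_types)
```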
@nicholasjng Contributions here would be very welcome! I'm not 100% sure it will be possible to remove closure conversion from custom_linear_solve, because we currently need to implement it with a JAX primitive (until https://github.com/google/jax/issues/9129, implementing custom_transpose, is finished; CC @froystig). It might be worth holding off on any API changes until we can switch both custom_linear_solve and custom_root at once. In any case, starting to investigate how removing closure conversion would work would certainly be valuable.

Just wanted to chime in that I'm happy to take a look at these. I've had some fun adding auxiliary arguments to custom_{root,linear_solve} in the past and would be happy to continue.
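The auxiliary support mentioned above presumably refers to the has_aux option of custom_root and custom_linear_solve. Here is a minimal sketch with a hypothetical diagonal operator. The solve function returns a residual norm as auxiliary data alongside the solution. With has_aux=True, custom_linear_solve should return the same (solution, aux) structure that solve returns. The names diag, matvec and solve_with_residual are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical diagonal operator A = diag(3, 5), for illustration only.
diag = jnp.array([3.0, 5.0])

def matvec(x):
    return diag * x

def solve_with_residual(matvec_fn, b):
    # Solve the diagonal system, then report ||A x - b|| as auxiliary data.
    x = b / diag
    residual = jnp.linalg.norm(matvec_fn(x) - b)
    return x, residual

b = jnp.array([6.0, 10.0])
x, residual = jax.lax.custom_linear_solve(
    matvec, b, solve_with_residual, has_aux=True)
print(x, residual)
```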