
jax.lax.custom_linear_solve / jax.lax.custom_root doesn't respect jax.disable_jit


After #9938, the documentation says:

For debugging it is useful to have a mechanism that disables jit() everywhere in a dynamic context. Note that this not only disables explicit uses of jit by the user, but will also remove any implicit JIT compilation used by the JAX library: this includes implicit JIT computation of body and cond functions passed to higher-level primitives like scan() and while_loop(), JIT used in implementations of jax.numpy functions, and any other case where jit is used within an API’s implementation.
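The quoted behavior can be illustrated with a small sketch. It assumes only the public jax.jit and jax.disable_jit APIs. The is_tracer helper is a hypothetical heuristic introduced here (not a JAX API) that checks whether a value's class inherits from a base class named "Tracer". The sketch records whether the body of a jitted function sees a tracer (it is being traced) or a concrete array (it runs eagerly because jit is disabled).

```python
import jax
import jax.numpy as jnp


def is_tracer(x):
    # Hypothetical heuristic: JAX tracer classes inherit from a base class named "Tracer".
    return any(cls.__name__ == "Tracer" for cls in type(x).__mro__)


seen = []  # one entry per Python-level execution of the body: was x a tracer?


@jax.jit
def double(x):
    seen.append(is_tracer(x))
    return x * 2


y_jit = double(jnp.array(1.0))        # first call under jit: the body is traced
with jax.disable_jit():
    y_eager = double(jnp.array(1.0))  # jit disabled: the body runs eagerly on a concrete array

print(seen)  # [True, False]
```

The issue below is about the "implicit JIT compilation" part of the quote: higher-level primitives are expected to behave like the second call here, not the first.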

However, while working on https://github.com/google/jax/issues/9714#issuecomment-1077319705, I found that jaxprs are still generated inside a with jax.disable_jit(): block. I checked the source of jax.lax.custom_linear_solve and found that it doesn't check config.jax_disable_jit. jax.lax.custom_root has the same problem.
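A minimal reproduction sketch, assuming only the public jax.lax.custom_linear_solve(matvec, b, solve, ...) and jax.lax.custom_root(f, initial_guess, solve, tangent_solve) signatures. The helpers is_tracer, dense_solve, newton, and scalar_tangent_solve are hypothetical names introduced for illustration; is_tracer is a class-name heuristic, not a JAX API. Whether the callbacks actually receive tracers depends on the JAX version, so the sketch only prints those flags rather than asserting them.

```python
import jax
import jax.numpy as jnp


def is_tracer(x):
    # Hypothetical heuristic: JAX tracer classes inherit from a base class named "Tracer".
    return any(cls.__name__ == "Tracer" for cls in type(x).__mro__)


A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
matvec_saw_tracer = []
f_saw_tracer = []


def matvec(v):
    matvec_saw_tracer.append(is_tracer(v))
    return A @ v


def dense_solve(mv, rhs):
    # Illustrative solver: ignores mv and solves directly against the dense matrix A.
    return jnp.linalg.solve(A, rhs)


def f(x):
    # Scalar function whose positive root is sqrt(2).
    f_saw_tracer.append(is_tracer(x))
    return x ** 2 - 2.0


def newton(func, x0):
    # Fixed iteration count, so no Python control flow depends on traced values.
    x = x0
    for _ in range(20):
        x = x - func(x) / jax.grad(func)(x)
    return x


def scalar_tangent_solve(g, y):
    # g is linear and scalar, g(t) = c * t, so g(t) = y is solved by t = y / g(1.0).
    return y / g(1.0)


with jax.disable_jit():
    x = jax.lax.custom_linear_solve(matvec, b, dense_solve, symmetric=True)
    root = jax.lax.custom_root(f, jnp.array(1.0), newton, scalar_tangent_solve)

print(x)     # approximately [0.2, 0.6]
print(root)  # approximately 1.4142135
# On JAX versions affected by this issue, these flags are expected to be True
# even though jit is disabled, because the callbacks are traced into jaxprs.
print(any(matvec_saw_tracer), any(f_saw_tracer))
```

The results themselves are correct either way; the point of the printed flags is that matvec and f run on tracers (i.e. get staged into jaxprs) rather than on concrete values, which is what the disable_jit documentation says should not happen.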

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 9 (9 by maintainers)

Top GitHub Comments

shoyer commented, Mar 27, 2022 (2 reactions)

@nicholasjng Contributions here would be very welcome! I'm not 100% sure it will be possible to remove closure conversion from custom_linear_solve, because we currently need to implement it with a JAX primitive (until https://github.com/google/jax/issues/9129, implementing custom_transpose, is finished; CC @froystig). It might be worth holding off on any API changes until we can switch both custom_linear_solve and custom_root at once, but in any case, starting to investigate how removing closure conversion would work would certainly be valuable.

nicholasjng commented, Mar 26, 2022 (0 reactions)

Just wanted to chime in that I'm happy to take a look at these. I've had some fun adding auxiliary arguments to custom_{root,linear_solve} in the past and would be happy to continue.


