`_bias_correction` forces recompilation of `integer_pow`
This might be related to #197, but I don’t have enough knowledge to confirm it.
In short, while using `optax` with `flax.linen`, `optimizer.update` forces JIT recompilation of the `integer_pow` JAX primitive.
I get these results on a GPU backend.
```python
from functools import partial
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from jax.config import config
import os

config.update("jax_log_compiles", 1)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"


def biascorrection_recompilation_test():
    class M(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(1)(x)
            return x

    m = M()
    o = optax.adam(0.025, b1=int(1))
    input_array = jnp.ones((8,), dtype=float)
    params = m.init(jax.random.PRNGKey(0), input_array)
    opt_state = o.init(params)
    m_forward = jax.vmap(m.apply, in_axes=(None, 0))

    @jax.jit
    @partial(jax.value_and_grad, has_aux=True)
    def loss_fn(params, x, y):
        params_m = params
        y_m = m_forward(params_m, x)
        loss = jnp.mean(y_m - y ** 2)
        return loss, ()

    for i in range(1000):
        print(i)
        inputs = jax.random.normal(jax.random.PRNGKey(i), (32, 8))
        outputs = jax.random.normal(jax.random.PRNGKey(-i), (32, 8))
        (loss, aux), grads = loss_fn(params, inputs, outputs)
        updates, opt_state = o.update(grads, opt_state)
        params = optax.apply_updates(params, updates)


if __name__ == "__main__":
    biascorrection_recompilation_test()
```
At each `i`:
```
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.00016188621520996094 sec
WARNING:absl:Compiling prim_fun (140402616089792 for args (ShapedArray(float32[], weak_type=True),).
WARNING:absl:Finished XLA compilation of integer_pow in 0.02527928352355957 sec
```
where `prim_fun` refers to the `integer_pow` op at this line:
https://github.com/deepmind/optax/blob/b4aa6657bbf79985279dea76eaf6d53b25d7e8d9/optax/_src/transform.py#L105
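If I understand the mechanism correctly, the exponent of `integer_pow` is a static parameter of the primitive rather than a traced operand, so each new value of the optimizer's step count produces a computation that has never been compiled before. A minimal sketch isolating just the primitive (not the optax code path itself):

```python
import jax.numpy as jnp
from jax import lax
from jax.config import config

config.update("jax_log_compiles", 1)

x = jnp.asarray(0.9, dtype=jnp.float32)
for step in range(1, 6):
    # Each distinct `step` bakes a new static `y` parameter into the
    # primitive, so the compilation cache never hits and a fresh
    # compilation is logged on every iteration.
    lax.integer_pow(x, step)
```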
#329 does remove the issue by replacing the call to `integer_pow` with `pow`, but I am not sure this is the direction you would like to take.
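For reference, my reading of that direction (a sketch under my own assumptions, not the exact change in #329) is to keep the exponent as an array operand, e.g. via `jnp.power`, so the compiled computation is reused when only the count's value changes:

```python
import jax.numpy as jnp
from jax.config import config

config.update("jax_log_compiles", 1)

decay = jnp.asarray(0.9, dtype=jnp.float32)
for step in range(1, 6):
    count = jnp.asarray(step, dtype=jnp.float32)
    # The exponent is an ordinary traced operand here, so the argument
    # shapes/dtypes are identical on every iteration and only the first
    # step triggers a compilation.
    bias_correction = 1 - jnp.power(decay, count)
```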
Indeed, the docs state that the exponent in `integer_pow` is a fixed integer.

Hi @hbq1, so sorry about this – the new avenue makes much more sense. Thanks very much for pinging me!
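As a small illustration of that point, the two lowerings can be compared with `jax.make_jaxpr`:

```python
import jax
import jax.numpy as jnp

# Python int exponent: lowers to integer_pow with the exponent baked into
# the primitive's parameters (appears as `integer_pow[y=3]` in the jaxpr).
print(jax.make_jaxpr(lambda x: x ** 3)(jnp.float32(2.0)))

# Array exponent: lowers to the `pow` primitive with the exponent as a
# second traced operand, so its value never enters the cache key.
print(jax.make_jaxpr(lambda x, y: x ** y)(jnp.float32(2.0), jnp.float32(3.0)))
```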