`_bias_correction` forces recompilation of `integer_pow`

This might be related to #197, but I don’t have enough knowledge to confirm it.

In short, when using optax with flax.linen, `optimizer.update` forces JIT recompilation of the `integer_pow` JAX primitive on every step. I get these results on a GPU backend.

import os

# Configure XLA's C++ log level before JAX is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

from functools import partial

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn

# Log every compilation so that per-step recompilations are visible.
jax.config.update("jax_log_compiles", True)


def biascorrection_recompilation_test():
    class M(nn.Module):
        @nn.compact
        def __call__(self, x):
            return nn.Dense(1)(x)

    m = M()
    # b1=int(1) is kept as in the original report.
    o = optax.adam(0.025, b1=int(1))

    input_array = jnp.ones((8,))
    params = m.init(jax.random.PRNGKey(0), input_array)
    opt_state = o.init(params)

    # Batch the single-example forward pass over the leading axis of x.
    m_forward = jax.vmap(m.apply, in_axes=(None, 0))

    @jax.jit
    @partial(jax.value_and_grad, has_aux=True)
    def loss_fn(params, x, y):
        y_m = m_forward(params, x)
        loss = jnp.mean((y_m - y) ** 2)  # mean squared error
        return loss, ()

    for i in range(1000):
        print(i)
        key_in, key_out = jax.random.split(jax.random.PRNGKey(i))
        inputs = jax.random.normal(key_in, (32, 8))
        outputs = jax.random.normal(key_out, (32, 1))  # one target per example
        (loss, aux), grads = loss_fn(params, inputs, outputs)
        # o.update runs eagerly (outside jit); this is where the
        # integer_pow recompilation is reported.
        updates, opt_state = o.update(grads, opt_state)
        params = optax.apply_updates(params, updates)


if __name__ == "__main__":
    biascorrection_recompilation_test()

At every iteration `i`, the following is logged:

WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.00016188621520996094 sec
WARNING:absl:Compiling prim_fun (140402616089792 for args (ShapedArray(float32[], weak_type=True),).
WARNING:absl:Finished XLA compilation of integer_pow in 0.02527928352355957 sec

Here `prim_fun` refers to the `integer_pow` op at this line: https://github.com/deepmind/optax/blob/b4aa6657bbf79985279dea76eaf6d53b25d7e8d9/optax/_src/transform.py#L105
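For context, and assuming `_bias_correction` follows the standard Adam formulation, each moment estimate is rescaled by a factor that depends on the step count $t$ (optax's `count`), with $\beta_1$ and $\beta_2$ the two decay rates:

$$
\hat{m}_t = \frac{m_t}{1 - \beta_1^{t}}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^{t}}
$$

So `decay**count` evaluates $\beta^{t}$ with an exponent that grows by one on every call to `o.update`.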

#329 removes the issue by replacing the call to `integer_pow` with `pow`, but I am not sure this is the direction you would like to take.
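Why swapping `integer_pow` for `pow` helps can be sketched as follows. Our reading, which the thread does not spell out, is that `o.update` runs outside `jit`, so the step count is a concrete array; `jnp.power` then converts it to a Python int and emits `integer_pow` with that int as a fixed exponent, which amounts to a new program at every step. The sketch below mimics this with `static_argnums`; the function names are hypothetical and `0.9` stands in for a decay rate.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Counts how many times each function body is traced (i.e. (re)compiled).
trace_counts = {"static": 0, "traced": 0}


@partial(jax.jit, static_argnums=1)
def bias_correction_static(decay, count):
    # `count` is a Python int here, so `decay ** count` becomes
    # lax.integer_pow with the exponent fixed at trace time.
    trace_counts["static"] += 1
    return 1.0 - decay ** count


@jax.jit
def bias_correction_traced(decay, count):
    # `count` is a traced int32 array, so `decay ** count` becomes
    # lax.pow and a single compiled program serves every step.
    trace_counts["traced"] += 1
    return 1.0 - decay ** count


for step in range(1, 6):
    bias_correction_static(0.9, step)
    bias_correction_traced(0.9, jnp.asarray(step, dtype=jnp.int32))

print(trace_counts)  # {'static': 5, 'traced': 1}
```

With the exponent passed in as a traced array, one compiled program is reused for every step, which is the effect of routing the computation through `pow` instead of `integer_pow`.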

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 14 (7 by maintainers)

Top GitHub Comments

hbq1 commented, Apr 19, 2022 (1 reaction)

Indeed, the docs state that the exponent in integer_pow is a fixed integer.
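This can be checked by printing jaxprs. In the sketch below (the lambdas are illustrative), `lax.integer_pow` records its exponent as a primitive parameter `y`, so exponents 3 and 4 yield different programs, while `jnp.power` with a traced integer exponent emits `pow` and receives the exponent as an ordinary input.

```python
import jax
import jax.numpy as jnp
from jax import lax

# The exponent of lax.integer_pow is a primitive parameter, so it is part of
# the staged program itself rather than an input to it.
jaxpr_cube = jax.make_jaxpr(lambda x: lax.integer_pow(x, 3))(1.0)
jaxpr_fourth = jax.make_jaxpr(lambda x: lax.integer_pow(x, 4))(1.0)
print(jaxpr_cube)    # exponent appears as the parameter y=3
print(jaxpr_fourth)  # exponent appears as the parameter y=4

# With a traced exponent, jnp.power cannot specialize to integer_pow and
# emits `pow`, whose exponent is just another array argument.
jaxpr_traced = jax.make_jaxpr(lambda x, n: jnp.power(x, n))(0.9, jnp.int32(3))
print(jaxpr_traced)
```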

epignatelli commented, Nov 8, 2022 (0 reactions)

Hi @hbq1, so sorry about this – the new avenue makes much more sense. Thanks very much for pinging me!
