
Feature request: disabling/detecting out-of-place `.at[].set()` updates.


I have a program in which XLA is performing a `.at[].set()` update out-of-place rather than in-place. At the moment it’s not clear to me whether this is a bug in XLA*, or a fault in my program that is preventing the optimisation.

What I’d really like is either:

  • a way to disable out-of-place updates (and throw an error that can be debugged); or
  • a way of detecting out-of-place updates (once again, to debug).

AFAIK there’s no way to do this in JAX at the moment. In my head I’m imagining something like an environment variable, similar to the `JAX_DEBUG_NANS` one used to catch NaNs.
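For comparison, the existing NaN check can also be switched on from Python through `jax.config` instead of the environment variable. The sketch below only illustrates that existing behaviour: while the flag is on, a jitted function that produces a NaN raises `FloatingPointError`. JAX has no equivalent flag for out-of-place updates, and the function `f` is purely illustrative.

```python
import jax
import jax.numpy as jnp

# Same effect as setting JAX_DEBUG_NANS=True in the environment
# before JAX is imported.
jax.config.update("jax_debug_nans", True)


@jax.jit
def f(x):
    # log of a negative number produces NaN
    return jnp.log(x)


try:
    f(-1.0)
    raised = False
except FloatingPointError as err:
    raised = True
    print("caught:", type(err).__name__)
finally:
    # Restore the default so later code is not affected.
    jax.config.update("jax_debug_nans", False)
```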

(More generally I would remark that ways to properly introspect/debug/understand the compiled XLA would be really nice.)
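One partial way to look at what XLA actually compiled is JAX's ahead-of-time API: `jax.jit(f).lower(...).compile().as_text()` returns the optimized HLO module as text. The sketch below dumps it for a small `.at[].set()` update. The `update` function, the shapes, and the idea of scanning for `copy` instructions are illustrative assumptions; a copy in the HLO is at best a rough hint that a buffer was duplicated, not a reliable detector of out-of-place updates, and the text format differs between backends and versions.

```python
import jax
import jax.numpy as jnp


def update(state, item):
    # Hypothetical example of the kind of update discussed in the issue.
    return state.at[10].set(item)


state = jnp.zeros((1_000, 64), dtype=jnp.float32)
item = jnp.ones((64,), dtype=jnp.float32)

# Lower and compile ahead of time, then fetch the optimized HLO text.
compiled = jax.jit(update).lower(state, item).compile()
hlo_text = compiled.as_text()

# Heuristic only: collect HLO instructions that look like explicit copies.
copy_lines = [line.strip() for line in hlo_text.splitlines() if "copy(" in line]

print("\n".join(hlo_text.splitlines()[:20]))
print(f"{len(copy_lines)} copy instruction(s) found")
for line in copy_lines:
    print("  ", line)
```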

* I know of at least one example of this being the case for XLA:CPU (#8192), but I don’t think I’m running into that particular bug here.


Top GitHub Comments

jakevdp commented, Sep 1, 2022

When I run your test on a Colab GPU runtime, I see the following:

%%writefile bench.py
import jax
import jax.numpy as jnp
import pytest
from functools import partial
import time


def benchmark_update_func(init_func, update_func, num_iters):
    # To pre-compile the update_func in case it uses jit
    update_func_input = init_func()
    update_func(update_func_input)

    elapsed_time = 0.0
    for _ in range(num_iters):
        update_func_input = init_func()
        jax.block_until_ready(update_func_input)

        start_time = time.time()
        output = update_func(update_func_input)
        jax.block_until_ready(output)
        elapsed_time += time.time() - start_time

    time_per_iter = elapsed_time / num_iters
    print(f'{time_per_iter * 1000.0:.2f}ms per iteration')


def make_item(n):
    return jnp.full((64,), n)


def make_item_batch(n, batch_size):
    return jnp.full((batch_size, 64,), n)

@pytest.mark.parametrize('buffer_size', [1_000, 10_000, 100_000, 1_000_000])
def test_modify_inplace_donate_performance(buffer_size):
    def init_func():
        return make_item_batch(0, buffer_size)

    @partial(jax.jit, donate_argnums=(0,))
    def update_func(state):
        item = make_item(0)
        return state.at[10].set(item)

    benchmark_update_func(init_func, update_func, num_iters=100)
!python -m pytest --durations=5 bench.py
============================= test session starts ==============================
platform linux -- Python 3.7.13, pytest-3.6.4, py-1.11.0, pluggy-0.7.1
rootdir: /content, inifile:
plugins: typeguard-2.7.1
collected 4 items                                                              

bench.py ....                                                            [100%]

=========================== slowest 5 test durations ===========================
1.13s call     bench.py::test_modify_inplace_donate_performance[1000]
0.31s call     bench.py::test_modify_inplace_donate_performance[1000000]
0.18s call     bench.py::test_modify_inplace_donate_performance[100000]
0.17s call     bench.py::test_modify_inplace_donate_performance[10000]
0.00s setup    bench.py::test_modify_inplace_donate_performance[1000]
=========================== 4 passed in 3.04 seconds ===========================

… which looks to me like it’s consistent with updates being in-place when buffer donation is available. What do you think?
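A quick way to see whether a donation was actually honoured is to check whether the donated input was invalidated by the call. The sketch below assumes a JAX version in which arrays expose `is_deleted()`; the `update` function and shapes are illustrative. On a backend that accepts the donation the input should report as deleted; on a backend that cannot use it, JAX instead emits a warning and the input stays alive.

```python
import warnings
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=(0,))
def update(state):
    # Hypothetical update: overwrite row 10 with ones.
    return state.at[10].set(jnp.ones((state.shape[1],), state.dtype))


state = jnp.zeros((1_000, 64), jnp.float32)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    new_state = update(state).block_until_ready()

# True if the runtime consumed the input buffer; do not use `state` after this.
donated = bool(state.is_deleted())
print("input buffer donated:", donated)
for w in caught:
    print("warning:", w.message)
```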

jakevdp commented, Sep 1, 2022

Buffer donation is not implemented on CPU (see https://jax.readthedocs.io/en/latest/faq.html#buffer-donation); when you execute this on CPU you should see warnings that look like:

UserWarning: Some donated buffers were not usable: ShapedArray(int32[10]).
Donation is not implemented for cpu.
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  _module_unique_id = itertools.count()
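Because that warning is an ordinary Python `UserWarning`, the standard `warnings` module can escalate it into an exception. This gives a crude form of the "throw an error" behaviour asked for in the issue, but only for rejected donations, not for out-of-place updates in general. The message prefix below is copied from the warning quoted above and is an assumption that may change between JAX versions; the `update` function is illustrative.

```python
import warnings
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=(0,))
def update(state):
    # Hypothetical update: overwrite row 10 with ones.
    return state.at[10].set(jnp.ones((state.shape[1],), state.dtype))


state = jnp.zeros((1_000, 64), jnp.float32)

with warnings.catch_warnings():
    # Escalate only the "unusable donation" warning; other warnings are untouched.
    warnings.filterwarnings("error", message="Some donated buffers were not usable")
    try:
        new_state = update(state)
        donation_error = None
    except UserWarning as err:
        new_state = None
        donation_error = err

print("donation rejected:", donation_error is not None)
```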

Two instantiated jax arrays (such as the input and output of your test function) cannot share the same memory unless their buffers are donated. So I don’t think this test tells us anything about whether updates are done in-place within JIT.
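Within a single `jit`-compiled computation, intermediate arrays are never visible to Python, so XLA is allowed to reuse their buffers for successive `.at[].set()` updates without any donation. The sketch below keeps a whole sequence of row updates inside one `jax.lax.fori_loop`; the `fill_rows` function is hypothetical. Whether XLA actually performs the updates in place is still decided by its buffer assignment and cannot be observed from this code.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fill_rows(state):
    # Each intermediate `s` exists only inside the compiled computation,
    # so XLA may reuse its buffer for the next `.at[].set()`.
    def body(i, s):
        return s.at[i].set(jnp.full((s.shape[1],), i, s.dtype))

    return jax.lax.fori_loop(0, state.shape[0], body, state)


out = fill_rows(jnp.zeros((8, 4), jnp.int32))
print(out)
```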
