How to use `named_call` in jitted functions

I’m trying to profile my code and am using named_call to annotate some functions, but the names do not show up in the captured trace when the annotated functions are called inside jitted functions.

Here is my test script (test_named_call.py):

import jax
from jax import numpy as jnp

@jax.named_call
def foo(x, y):
    return (x + y) / 2.

@jax.jit
def bar(a):
    def foo2(x, y):
        return foo(x, y), None

    out, _ = jax.lax.scan(foo2, 0., a)
    return out

a = jnp.array([1., 2., 3., 4., 5.])

jax.profiler.start_trace('/tmp/tensorboard')
with jax.profiler.StepTraceAnnotation('step', step_num=0): # JIT warm-up
    out = bar(a)
with jax.profiler.StepTraceAnnotation('step', step_num=1):
    out = bar(a)
out.block_until_ready()
jax.profiler.stop_trace()

My environment: Ubuntu 20.04, Python 3.8.5, CUDA 11.3, jax 0.2.14, jaxlib 0.1.67+cuda111, tensorflow 2.5.0, tbp-nightly 2.5.0a20210511

[Screenshot: overview of the captured trace]

Above is an overview of the captured trace. Step 1 comes right after step 0 and takes very little time. I can find the name foo among the functions in step 0, but not in step 1.
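As a side note on the step timings: JAX dispatches device work asynchronously, so a step annotation can close before the computation it launched has finished. Below is a minimal sketch of the same two steps with `block_until_ready()` moved inside each annotated block, so that each step covers its own device work. It assumes a recent JAX release, and it does not by itself make the name foo appear in step 1.

```python
import jax
from jax import numpy as jnp


@jax.named_call
def foo(x, y):
    return (x + y) / 2.


@jax.jit
def bar(a):
    def body(carry, x):
        # The carry is the running value; scan feeds in one element of `a` per step.
        return foo(carry, x), None

    out, _ = jax.lax.scan(body, 0., a)
    return out


a = jnp.array([1., 2., 3., 4., 5.])

# Step 0 includes tracing and compilation; step 1 reuses the compiled function.
for step in range(2):
    with jax.profiler.StepTraceAnnotation('step', step_num=step):
        out = bar(a)
        # Wait for the device inside the annotation so the step spans the real work.
        out.block_until_ready()

print(float(out))
```

To capture a trace, wrap the loop in `jax.profiler.start_trace(...)` and `jax.profiler.stop_trace()` as in the script above.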

[Screenshot: zoomed-in view of step 1]

Above is a zoomed-in view of step 1. The operations have only generic names like ‘fusion’ or ‘Memcpy’. Because there are five repeated operations, I can guess that they are the scan loop. But in general it’s really hard to associate those operations with lines of Python code.

Also, @jekbradbury mentioned that the bottom pane should show some information such as ‘source’. Is it available?
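One way to check whether a name survives jit compilation, independently of the trace viewer, is to search the compiled HLO text for it. The sketch below is only an illustration under assumptions: it uses `jax.named_scope` and the `lower(...).compile().as_text()` inspection path, both of which come from JAX releases newer than the 0.2.14 listed above, and whether the name actually appears in the HLO metadata may depend on the JAX version and backend.

```python
import jax
from jax import numpy as jnp


def foo(x, y):
    # named_scope adds "foo" to the name stack while this body is traced.
    with jax.named_scope("foo"):
        return (x + y) / 2.


@jax.jit
def bar(a):
    def body(carry, x):
        return foo(carry, x), None

    out, _ = jax.lax.scan(body, 0., a)
    return out


a = jnp.array([1., 2., 3., 4., 5.])
out = bar(a)

# Lower and compile for this input, then fetch the optimized HLO as text.
# as_text() may return None on some versions, hence the fallback to "".
hlo_text = bar.lower(a).compile().as_text() or ""

# Collect the HLO lines that mention the scope name, if any.
lines_with_foo = [line for line in hlo_text.splitlines() if "foo" in line]
print(float(out))
print(len(lines_with_foo), "HLO lines mention 'foo'")
```

If the count is zero, the name was dropped before or during compilation; if it is nonzero, the name reached the compiler and any remaining gap is on the profiler or viewer side.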

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 12 (4 by maintainers)

Top GitHub Comments

enolan commented, Aug 27, 2021 (2 reactions)

Any news on this? It’s extremely difficult to debug performance when everything is named custom-call or fusion.

cagrikymk commented, Nov 19, 2022 (0 reactions)

@sharadmv Is there any profiler that works well with JAX?
