How to use `named_call` in jitted functions
I’m trying to profile my code and use `named_call` to annotate some functions, but the names do not show up in the captured trace if they’re inside jitted functions.
Here is my test script (`test_named_call.py`):
import jax
from jax import numpy as jnp

@jax.named_call
def foo(x, y):
    return (x + y) / 2.

@jax.jit
def bar(a):
    def foo2(x, y):
        return foo(x, y), None
    out, _ = jax.lax.scan(foo2, 0., a)
    return out

a = jnp.array([1., 2., 3., 4., 5.])

jax.profiler.start_trace('/tmp/tensorboard')
with jax.profiler.StepTraceAnnotation('step', step_num=0):  # JIT warm-up
    out = bar(a)
with jax.profiler.StepTraceAnnotation('step', step_num=1):
    out = bar(a)
out.block_until_ready()
jax.profiler.stop_trace()
My environment: Ubuntu 20.04, Python 3.8.5, CUDA 11.3, jax 0.2.14, jaxlib 0.1.67+cuda111, tensorflow 2.5.0, tbp-nightly 2.5.0a20210511
Above is an overview of the captured trace. `step 1` takes a very short time after `step 0`. I can find the name `foo` in that bunch of functions in `step 0`, but not in `step 1`.
Above is a zoomed-in view of `step 1`. The operations have only generic names like ‘fusion’ or ‘Memcpy’. Because there are 5 repeated operations, I can guess it’s the scan loop. But in general it’s really hard to associate those operations with Python lines.
Also, @jekbradbury mentioned that in the bottom plane there should be some information like ‘source’. Is it available?
Top GitHub Comments
Any news on this? It’s extremely difficult to debug performance when everything is named `custom-call` or `fusion`.
@sharadmv Is there any profiler that works well with JAX?
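One thing that may be worth trying, as a sketch rather than a confirmed fix: `jax.named_scope` is a context manager in newer JAX releases (not the 0.2.14 used above) that adds a name to the name stack for the operations traced inside it. The snippet below reuses `foo` and `bar` from the test script with the scope placed inside `foo`. The scan arguments are renamed `carry` and `x` for readability, and the initial carry is built with `jnp.zeros`. Both are illustrative choices, not part of the original script. Whether the name then appears on the device ops in the trace viewer depends on the JAX and profiler versions and is assumed here, not verified.

```python
import jax
from jax import numpy as jnp


def foo(x, y):
    # Assumption: the installed JAX version provides jax.named_scope.
    # Operations traced inside this block get "foo" added to their name stack.
    with jax.named_scope("foo"):
        return (x + y) / 2.0


@jax.jit
def bar(a):
    def foo2(carry, x):
        # The scan carry is the running value; each step averages it with x.
        return foo(carry, x), None

    init = jnp.zeros((), dtype=a.dtype)
    out, _ = jax.lax.scan(foo2, init, a)
    return out


a = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
out = bar(a)
out.block_until_ready()
result = float(out)
```

To check the effect, wrap the calls to `bar(a)` in `jax.profiler.start_trace` / `jax.profiler.stop_trace` as in the test script and look for `foo` in the op names of the second step.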