JAX Profiling in Colab
Hi 👋 What do you think about having a JAX Profiling in Colab notebook as an add-on to the current JAX Profiling guide (https://jax.readthedocs.io/en/latest/profiling.html)?
Also, do you know how and where the profile data gets written in the temporary Colab Compute Engine instance’s Linux file system, so that we can point TensorBoard’s --logdir flag at it? For example, with TF you set this directory when instantiating the TensorBoard callback.
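A minimal sketch of one way to get a log directory you know in advance: write the trace yourself instead of relying on the server-side capture. It assumes a JAX release that provides the `jax.profiler.trace` context manager (likely newer than the one this issue was written against); the `/tmp/jax-trace` path and the `step` function are arbitrary choices for illustration, not something JAX or Colab prescribes.

```python
import glob
import os

import jax
import jax.numpy as jnp

# Arbitrary directory chosen for this sketch; any writable path works.
LOG_DIR = "/tmp/jax-trace"


@jax.jit
def step(x):
    # Hypothetical workload: a matmul plus a nonlinearity so the trace has device work in it.
    return jnp.tanh(x @ x)


x = jnp.ones((512, 512))
step(x).block_until_ready()  # compile outside the traced region

# Everything executed inside this block is recorded and written under LOG_DIR.
with jax.profiler.trace(LOG_DIR):
    for _ in range(10):
        x = step(x)
    x.block_until_ready()  # wait for the asynchronous work before the trace stops

# List every file the profiler wrote under LOG_DIR.
trace_files = sorted(
    path
    for path in glob.glob(os.path.join(LOG_DIR, "**", "*"), recursive=True)
    if os.path.isfile(path)
)
for path in trace_files:
    print(path)
```

The files should land in a `plugins/profile/<run>/` subdirectory of `LOG_DIR`, which is the layout the TensorBoard profile plugin reads, so `%tensorboard --logdir=/tmp/jax-trace` should pick them up.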
To give you an idea of what a Colab guide would look like:
- Upgrade TensorFlow and the TensorBoard profiler plugin to the latest versions:
!pip install --upgrade tensorflow tensorboard_plugin_profile
- Load the TensorBoard notebook extension:
%load_ext tensorboard
- Import JAX and supporting APIs, including jax.profiler:
import jax
import jax.profiler
import jax.numpy as jnp
import jax.random
- Launch a profiling server on port 1234 that the TensorBoard instance can connect to:
server = jax.profiler.start_server(port=1234)
(In the non-Colab JAX profiling instructions, this step is similar to step 2: import jax.profiler and call jax.profiler.start_server(9999).)
- Run some JAX code whose trace you want to capture:
# Your JAX code
...
- Start TensorBoard:
%tensorboard --logdir=/tmp/{FOLDER}/
(In the non-Colab JAX profiling instructions, this step is step 1.)
[Note: currently it’s not possible to perform the next steps because the log files cannot be found; the web UI says INACTIVE. See my question at the top of this issue.]
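- Open the TensorBoard web UI. TensorBoard listens on localhost:6006 by default (port 1234 is the profiler server started earlier); in Colab, %tensorboard renders the UI inline in the cell output.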
- In the web UI, select Profile from the drop-down menu in the top right.
- Click the Capture Profile button.
- Enter localhost:1234 in the Profile Service URL field.
- Capture:
- Rerun the cell with the awesome JAX code
- While the cell is running, press Capture and wait for the capture to complete
- On the left-hand side, under Tools, click trace_viewer. (Note: the overview page doesn’t show anything meaningful for JAX at the moment.)
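For the "Run some JAX code" step: the Capture button only records what executes during the capture window, so it helps to rerun a workload that keeps the device busy for a while. A minimal sketch follows; the `busy_loop` helper, its 20-second default, and the matrix size are hypothetical choices for illustration, not part of the JAX profiling guide. `jax.profiler.TraceAnnotation` labels each iteration so it is easy to spot in trace_viewer.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical workload: repeated matmul keeps the accelerator (or CPU) busy.
    return jnp.tanh(x @ x)


def busy_loop(seconds=20.0, size=1024):
    """Hypothetical helper: run `step` repeatedly for roughly `seconds` seconds.

    Start the cell, then press Capture in TensorBoard while it is still running.
    Returns the number of iterations completed.
    """
    x = jnp.ones((size, size))
    step(x).block_until_ready()  # compile up front so the capture shows execution, not compilation
    iterations = 0
    deadline = time.monotonic() + seconds
    while time.monotonic() < deadline:
        # Each iteration appears as a named region in trace_viewer.
        with jax.profiler.TraceAnnotation(f"step_{iterations}"):
            x = step(x).block_until_ready()
        iterations += 1
    return iterations
```

Calling `busy_loop()` in the cell you rerun during the capture step gives the profiler something to record; adjust `seconds` to cover however long the capture takes.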
(Issue created 3 years ago.)
@akshay-jaggi
On this note, does anyone have some simple code for getting memory profiling working in Colab? The Go requirement (for installing pprof) seems to complicate things a bit.
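A minimal sketch of capturing a device memory snapshot, assuming `jax.profiler.save_device_memory_profile` as described in JAX's device memory profiling guide. Writing the profile needs nothing beyond JAX; Go only comes in when viewing the file with pprof. The output path, the `allocate_and_profile` helper, and the example arrays are hypothetical choices for illustration.

```python
import os

import jax
import jax.numpy as jnp

PROFILE_PATH = "/tmp/memory.prof"  # arbitrary output path for this sketch


def allocate_and_profile(path=PROFILE_PATH):
    """Hypothetical helper: keep a few arrays alive, then dump a pprof-format memory snapshot."""
    key = jax.random.PRNGKey(0)
    # These buffers are still referenced when the snapshot is taken, so they appear in it.
    a = jax.random.normal(key, (1000, 1000))
    b = jnp.dot(a, a.T)
    b.block_until_ready()
    jax.profiler.save_device_memory_profile(path)
    return os.path.getsize(path)


print(allocate_and_profile(), "bytes written to", PROFILE_PATH)
```

To view the snapshot, the JAX guide runs `pprof --web memory.prof`; downloading the file from Colab and running pprof on a local machine avoids installing Go on the Colab instance.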