JAX Profiling in Colab

Hi 👋 What do you think about having a JAX Profiling in Colab notebook as an add-on to the current JAX Profiling guide (https://jax.readthedocs.io/en/latest/profiling.html)?

Also, do you know how and where the data gets captured in the temporary Colab Compute Engine instance’s Linux folder hierarchy, so that we can point TensorBoard’s --logdir flag to it for reading? For example, with TF, you set this directory when instantiating a callback.
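
One way to check whether a capture has landed in a given directory is sketched below. It assumes the TensorBoard profile plugin's usual layout of `<logdir>/plugins/profile/<run>/`. The helper name `find_profile_runs` and the demo file name `host.xplane.pb` are hypothetical, and the demo builds a fake directory instead of reading a real capture.

```python
import tempfile
from pathlib import Path


def find_profile_runs(logdir):
    """Return {run_name: [file names]} for profile runs found under logdir.

    Assumes the <logdir>/plugins/profile/<run>/ layout; returns an empty
    dict if that directory does not exist.
    """
    profile_root = Path(logdir) / "plugins" / "profile"
    if not profile_root.is_dir():
        return {}
    runs = {}
    for run_dir in sorted(p for p in profile_root.iterdir() if p.is_dir()):
        runs[run_dir.name] = sorted(f.name for f in run_dir.iterdir() if f.is_file())
    return runs


# Demo on a fake log directory (hypothetical run and file names).
with tempfile.TemporaryDirectory() as demo_logdir:
    run = Path(demo_logdir) / "plugins" / "profile" / "demo_run"
    run.mkdir(parents=True)
    (run / "host.xplane.pb").write_bytes(b"")
    demo_result = find_profile_runs(demo_logdir)
    print(demo_result)
```

If the helper returns an empty dict for the directory passed to --logdir, that would be consistent with TensorBoard finding no profile data there.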


To give you an idea of what a Colab guide would look like:

  1. Upgrade TensorFlow and the TensorBoard profiler plugin to the latest versions:
!pip install --upgrade tensorflow tensorboard_plugin_profile
  2. Load the TensorBoard notebook extension:
%load_ext tensorboard
  3. Import JAX and supporting APIs, including jax.profiler:
import jax
import jax.profiler
import jax.numpy as jnp
import jax.random
  4. Start a profiling server on port 1234 that the TensorBoard instance can connect to:
server = jax.profiler.start_server(port=1234)

(In the non-Colab JAX profiling instructions, this corresponds to step 2: import jax.profiler and call jax.profiler.start_server(9999).)

  5. Run some JAX code. This is the code whose trace you want to capture.
# Your JAX code
...
  6. Start a TensorBoard server:
%tensorboard --logdir=/tmp/{FOLDER}/

(In the non-Colab JAX profiling instructions, this is step 1.)

[Note: currently, it’s not possible to perform the next steps, as the log files cannot be found - the web UI says INACTIVE - see my question at the top of this “Issue”.]

  7. Load TensorBoard at localhost:1234:
  • In the web UI, select Profile from the drop-down menu in the top right
  • Click the Capture Profile button
  • Enter localhost:1234 in the Profile Service URL field
  8. Capture:
  • Rerun the cell with the awesome JAX code
  • While the cell is running, press Capture and wait for the capture to complete
  • On the left-hand side, under Tools, click trace_viewer (Note: the overview doesn’t show anything meaningful at the moment for JAX)
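
The steps above depend on the Capture Profile button reaching the profiler server. As an alternative sketch, and not a fix for the INACTIVE state noted above, a trace can be written straight to a directory with the `jax.profiler.trace` context manager and then read through --logdir. The log directory (a fresh temporary directory), the function name `matmul_chain`, and the array size are assumptions chosen for illustration.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import jax.profiler

# Assumed location: a fresh temp directory, so old captures do not mix in.
LOG_DIR = tempfile.mkdtemp(prefix="jax-trace-")


@jax.jit
def matmul_chain(x):
    # A small stand-in workload; replace with the code you want to profile.
    for _ in range(5):
        x = jnp.tanh(x @ x)
    return x


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 256))

# Everything executed inside the context manager is recorded to LOG_DIR.
with jax.profiler.trace(LOG_DIR):
    result = matmul_chain(x)
    # JAX dispatches asynchronously; wait so the work lands inside the trace.
    result.block_until_ready()

# List whatever the profiler wrote, to confirm where the data ended up.
trace_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(LOG_DIR)
    for name in names
]
print(LOG_DIR)
print(trace_files)
```

In a notebook, the printed directory could then be passed to `%tensorboard --logdir`. Whether the Profile tab picks it up still depends on the installed tensorboard_plugin_profile version.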

Issue Analytics

  • State: open
  • Created 3 years ago
  • Reactions: 3
  • Comments: 7 (2 by maintainers)

Top GitHub Comments

2 reactions
sholtodouglas commented, Dec 22, 2021

@akshay-jaggi

# Install Go from the longsleep backports PPA
!add-apt-repository ppa:longsleep/golang-backports -y
!apt update
!apt install -y golang-go
%env GOPATH=/root/go

# Install Graphviz (pprof needs it to render images) and pprof itself
!apt-get install -y graphviz gv
!go install github.com/google/pprof@latest

# Do stuff / profile as per guide

# This will save the output to a PNG
!go tool pprof -png memory.prof
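
For the "profile as per guide" step above, a minimal sketch of producing memory.prof might look like the following. It uses `jax.profiler.save_device_memory_profile`. The workload and array size are placeholders that are not part of the original comment, and the file name is chosen only to match the pprof command above.

```python
import os

import jax
import jax.numpy as jnp
import jax.profiler

# File name matches the one passed to pprof above; it is written to the
# current working directory.
PROFILE_PATH = "memory.prof"

# Placeholder workload: keep some arrays alive so they show up in the profile.
x = jnp.ones((1024, 1024))
y = (x @ x).block_until_ready()

# Write a pprof-format snapshot of current device memory usage.
jax.profiler.save_device_memory_profile(PROFILE_PATH)

print(PROFILE_PATH, os.path.getsize(PROFILE_PATH), "bytes")
```

The snapshot reflects memory in use at the moment of the call, so it should be taken while the arrays of interest are still referenced.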
1 reaction
akshay-jaggi commented, May 10, 2021

On this note, does anyone have some simple code for getting memory profiling working in Colab? The Go requirement seems to complicate things a bit.
