
Trying to understand the Haiku approach


There is no discussion tab for Haiku so I am raising this question here. My apologies in advance if this isn’t appropriate. Suggestions of where I might ask would be much appreciated.

The MNIST example (dm-haiku/examples/mnist.py) uses a coding style that seems to be “the way to do things” in Haiku, but it looks strange to someone with an object-oriented background. I have seen the same pattern in the training code for DeepMind’s PerceiverIO (recently open-sourced on GitHub), so I assume it’s accepted practice in the Haiku/JAX community.

In the mnist example, this is the function that creates the network:

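from typing import Mapping

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

# Assumed batch type: a dict of NumPy arrays keyed by "image" (and "label").
Batch = Mapping[str, np.ndarray]

def net_fn(batch: Batch) -> jnp.ndarray:
  """Standard LeNet-300-100 MLP network."""
  x = batch["image"].astype(jnp.float32) / 255.  # uint8 pixels -> [0, 1]
  mlp = hk.Sequential([
      hk.Flatten(),
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  return mlp(x)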

This function gets transformed (with hk.transform), and the resulting pure init() and apply() functions are then used, as you would expect.

What I, as an OO programmer, find very difficult to understand is this: every time you call apply() to get new outputs, it appears to execute the code that instantiates the Sequential, Flatten, and Linear classes, creating new objects. This seems like completely unnecessary duplication: new objects are created on the heap and later garbage collected, which could happen millions of times over a long training run. With TensorFlow/Keras, you would create these objects once and then call them with each new set of inputs. In this simple example that is probably not a concern, but with a complex network (for example, oh I don’t know, say AlphaFold), would you want to do that every time?

I realise I am not a functional programmer, but I know that functional programming favours immutability, so you would expect new copies to be created every time values change. However, I thought the whole point of the Haiku/JAX approach was to remove state from the network objects with transform, so these objects never actually change. Yet the code recreates the objects on each call to apply(), and they are identical to the ones created the previous time.

Have I misunderstood what’s happening here? Am I worrying unnecessarily? Is this repetitive instantiation a compromise between functional JAX and non-functional Python?

Thank you very much for helping an object-oriented guy improve his functional understanding.

Issue Analytics

  • State: open
  • Created 2 years ago
  • Reactions:3
  • Comments:5

Top GitHub Comments

tomhennigan commented on Sep 15, 2021 (5 reactions)

Hi @juliangall, as @avital points out, getting good performance in JAX requires users to JIT-compile blocks of code.

When you JIT a program in JAX, it keeps a log of all the operations (e.g. matrix multiply, add, etc.) that are performed on your inputs. The next time you call the jitted function (with inputs of the same shapes and dtypes), it can replay (an optimized version of) those operations without needing to call into your Python code. This only works because JAX requires JIT-compiled functions to be pure (i.e. a function of only their inputs, with no side effects). You can find more about JIT compilation in this guide: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
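A small sketch of the “record once, replay later” behaviour, using plain JAX and a deliberately impure function (the global `scale` is a hypothetical example, not from the thread). Because the function reads a Python global, the value seen during the first trace is baked into the recorded computation, which illustrates why JIT-compiled functions need to be pure.

```python
import jax

scale = 2.0  # a Python global read inside the function: this makes f impure

@jax.jit
def f(x):
    # Python executes this body only while tracing; the recorded ops are cached.
    return x * scale

first = f(1.0)   # traces f, recording "multiply by 2.0"
scale = 3.0      # changing the global does not trigger a re-trace
second = f(1.0)  # same input shape/dtype: replays the cached computation
print(first, second)
```

Both calls return 2.0. Passing `scale` in as an argument instead would make `f` pure, and the second call would then see the new value.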

The goal in Haiku is to make it easy to define a neural network and to get pure functions that (1) initialise the params/state and (2) apply that network at some params/state/input. You can easily JIT compile these functions.

We force modules to be created inside transformed functions because we give all top-level modules unique names. If we allowed modules to be created outside a transformed function, we would need global state to keep track of previously used module names (to avoid reusing them in your program). Our experience with TensorFlow 1 was that this global state (for names) was hard to reason about.

There is a bit more detail in this guide: https://dm-haiku.readthedocs.io/en/latest/notebooks/build_your_own_haiku.html

There are other OO neural network libraries in the JAX space: @avital is a lead on Flax, which provides an alternative OO design that you may want to take a look at.

> Have I misunderstood what’s happening here? Am I worrying unnecessarily? Is this repetitive instantiation a compromise between functional JAX and non-functional Python?

You haven’t misunderstood: these objects are recreated every time you call the transformed function. However, inside a @jax.jit this doesn’t matter, because your Python code is not actually re-run after the first call has traced it. This is what all JAX users (OO or purely functional) rely on for good performance with JAX.

> In this simple example, that is probably not a concern, but with a complex network (for example, oh I don’t know, say AlphaFold), would you want to do that every time?

AlphaFold takes precisely this approach and achieves good performance: https://github.com/deepmind/alphafold/blob/1e216f93f06aa04aa699562f504db1d02c3b704c/alphafold/model/model.py#L57-L66

juliangall commented on Sep 15, 2021 (2 reactions)

@avital and @tomhennigan, thanks so much for taking the time to answer my question. You have explained very well how JIT compiles functions so that they can be replayed without further Python overhead, and I understand that now. However, it leaves me with one more question.

When I think about what JIT does, I understand that it keeps a log of all the operations. I can see how this could work for a Python function. What does jitted code do, though, with an object instantiation? When I think about this in a typical OO environment, instantiation of an object involves allocating space on the heap, keeping track of references, garbage collection etc. That happens even if the object has no instance variables. It happens whether the language is interpreted or compiled.

If you are telling me JIT compilations are different from compiled languages and all object instantiations are “remembered” in the compiled code (rather than on the heap), I guess that answers my question. However, it blows my mind and I’m very impressed!

I am wondering if my problem here is with the word “compiled”. In the JIT sense, it seems to mean that the code is executed and remembered, so that it can be replayed with different data. If that’s the case, everything may be becoming clear.

Thanks again.


