
4x slowdown in evaluation of RBM (Flax.linen vs jax.experimental.stax)

See original GitHub issue


Problem you have encountered:

I compare a simple implementation of an RBM in flax against a similar implementation in jax.experimental.stax; see this gist notebook.

The two produce the same jaxpr when traced, so I would expect comparable performance (minus dispatch cost and the time taken to flatten/unflatten the inputs and outputs), but that is not the case: flax is at a roughly 4x disadvantage.

Essentially, the two implementations are

stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)

and

from typing import Any

import numpy as np
import jax.numpy as jnp
import flax.linen as nn


class FlaxRBM(nn.Module):
    dtype: Any = np.float32
    alpha: int = 1
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            name="Dense",
            features=self.alpha * x.shape[-1],
            dtype=self.dtype,
            use_bias=self.use_bias,
        )(x)
        x = nn.activation.sigmoid(x)
        return jnp.sum(x, axis=-1)
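
For reference, here is a runnable sketch of the stax side. The SumLayer definition below is an assumption (a parameter-free layer that sums over the feature axis, mirroring what the flax module does); the gist's actual definition may differ slightly:

import jax
import jax.numpy as jnp
from jax.experimental import stax  # moved to jax.example_libraries.stax in newer JAX releases

# A stax layer is an (init_fun, apply_fun) pair; this one has no parameters
# and simply sums over the feature axis.
def _sum_init(rng, input_shape):
    return input_shape[:-1], ()

def _sum_apply(params, inputs, **kwargs):
    return jnp.sum(inputs, axis=-1)

SumLayer = (_sum_init, _sum_apply)

L, alpha = 20, 3
init_fun, apply_fun = stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)

rng = jax.random.PRNGKey(0)
_, params = init_fun(rng, (-1, L))
x = jnp.ones((32, L))
y = jax.jit(apply_fun)(params, x)  # shape (32,)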

What I observe is peculiar:

# alpha=1
# Input shape (1,1) 

# jax
63.3 µs ± 247 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# flax
252 µs ± 7.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

which would suggest that flax has roughly 4x the dispatch cost of jax (weird… but ok).
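
For context, per-call numbers like these are obtained by jitting the function, compiling it once outside the timed loop, and calling block_until_ready() so that JAX's asynchronous dispatch actually completes inside the timing. A minimal sketch (a hypothetical tiny function, not the gist's models):

import timeit
import jax
import jax.numpy as jnp

@jax.jit
def tiny(x):
    return jnp.sum(jax.nn.sigmoid(x), axis=-1)

x = jnp.ones((1, 1))
tiny(x).block_until_ready()  # warm-up: compilation happens here, outside the timed loop

n = 10_000
t = timeit.timeit(lambda: tiny(x).block_until_ready(), number=n)
print(f"{t / n * 1e6:.2f} µs per call")  # on a (1, 1) input this is essentially pure dispatch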

Still, if I increase the size:

# alpha=3
# Input shape (32,20) 

# jax
69.5 µs ± 4.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# flax
280 µs ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

At this size the dispatch cost should finally be amortized, yet the flax runtime increases as well?

# alpha=6
# Input shape (128,30) 

# jax
116 µs ± 8.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# flax
407 µs ± 83.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

???

I also checked that the two produce the same jaxpr, which is indeed the case:

jax.make_jaxpr(j_ma.apply)(j_w, x)
{ lambda  ; a b c.
  let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None ] c a
      e = broadcast_in_dim[ broadcast_dimensions=(1,)
                            shape=(1, 180) ] b
      f = add d e
      g = sign f
      h = mul f g
      i = mul h -2.0
      j = exp i
      k = add j 1.0
      l = log k
      m = add h l
      n = log 2.0
      o = sub m n
      p = reduce_sum[ axes=(1,) ] o
  in (p,) }
jax.make_jaxpr(f_ma.apply)(f_w, x)
{ lambda  ; a b c.
  let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None ] c b
      e = broadcast_in_dim[ broadcast_dimensions=(1,)
                            shape=(1, 180) ] a
      f = add d e
      g = sign f
      h = mul f g
      i = mul h -2.0
      j = exp i
      k = add j 1.0
      l = log k
      m = add h l
      n = log 2.0
      o = sub m n
      p = reduce_sum[ axes=(1,) ] o
  in (p,) }

I have jax==0.2.8 and flax==0.3.0.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 17 (14 by maintainers)

Top GitHub Comments

2 reactions
levskaya commented, Sep 7, 2021

In time, all wounds are healed ("Tempore omnia vulnera sanabuntur").

I wanted to re-check the results here given how much work has gone on in JAX this year concerning dispatch overheads.

Across the board, JAX dispatch times are down a lot. The overhead of using flax vs stax is now extremely slight in dispatch-sensitive tests (lots of tiny launches) and non-existent on any reasonably large computation.

on my macbook, small setting:

  • 6.81 µs ± 198 ns STAX vs 7.81 µs ± 195 ns FLAX (old numbers were 65µs vs 255µs!)
  • ~1.14x now vs ~4x, and a 10x absolute improvement in dispatch!

on my macbook, larger setting:

  • 304 µs ± 23 µs STAX vs 308 µs ± 42.8 µs FLAX (old numbers were 410µs vs 603µs!)
  • ~1.0x now vs ~1.5x
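
For reference, a sketch of the kind of dispatch-sensitive comparison described above (an assumed setup with a single Dense layer, not the actual notebook; the real runs are in the colab linked below):

import timeit
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.experimental import stax  # jax.example_libraries.stax in newer JAX releases

# A single Dense layer is enough to compare per-call dispatch overhead.
s_init, s_apply = stax.Dense(1)
f_model = nn.Dense(features=1)

x = jnp.ones((1, 1))  # tiny input: the timing is dominated by dispatch, not compute
rng = jax.random.PRNGKey(0)
_, s_params = s_init(rng, (-1, 1))
f_params = f_model.init(rng, x)

s_fn = jax.jit(s_apply)
f_fn = jax.jit(f_model.apply)
s_fn(s_params, x).block_until_ready()  # warm-up / compile
f_fn(f_params, x).block_until_ready()

n = 10_000
for name, fn, params in [("stax", s_fn, s_params), ("flax", f_fn, f_params)]:
    t = timeit.timeit(lambda: fn(params, x).block_until_ready(), number=n)
    print(f"{name}: {t / n * 1e6:.2f} µs per call")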

I reran all the benchmarks on colab cpu and my macbook cpu, results here: https://colab.research.google.com/drive/1mPa51aFSK_NOSOi3USrLu-GnqOlh-LDX?usp=sharing

So I believe that, for all practical purposes, this is resolved, and I will mark the issue closed. Please feel free to re-open if you have further concerns here though!

0 reactions
marcvanzee commented, Feb 2, 2021

Blocked by google/jax#5485
