4x slowdown in evaluation of RBM (Flax.linen vs jax.experimental.stax)
Problem you have encountered:
I compare a simple implementation of an RBM written with flax against a similar implementation written with jax.experimental.stax. See this gist notebook.
The two produce the same jaxpr when traced, so I would expect comparable performance (minus dispatch cost and the time taken to flatten/unflatten the inputs and outputs), but that is not the case: flax is at a 4x disadvantage.
Essentially, the two implementations are
stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)
and
from typing import Any

import flax.linen as nn
import jax.numpy as jnp
import numpy as np


class FlaxRBM(nn.Module):
    dtype: Any = np.float32
    alpha: int = 1
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            name="Dense",
            features=self.alpha * x.shape[-1],
            dtype=self.dtype,
            use_bias=self.use_bias,
        )(x)
        x = nn.activation.sigmoid(x)
        return jnp.sum(x, axis=-1)
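The custom SumLayer on the stax side and the parameter setup are not shown in the excerpt. Below is a minimal sketch of how they might look; the layer definition, the shapes, and the helper names (stax_init, stax_apply) are assumptions for illustration, not the exact gist code.

import jax
import jax.numpy as jnp
from jax.experimental import stax

# Assumed custom stax layer that sums over the feature axis, written as an
# (init_fun, apply_fun) pair so it can be passed to stax.serial like stax.Sigmoid.
def _sum_init(rng, input_shape):
    return input_shape[:-1], ()   # output shape, no parameters
def _sum_apply(params, x, **kwargs):
    return jnp.sum(x, axis=-1)
SumLayer = (_sum_init, _sum_apply)

alpha, L = 6, 30                  # assumed sizes (alpha * L = 180, as in the jaxprs below)
stax_init, stax_apply = stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)

rng = jax.random.PRNGKey(0)
x = jnp.ones((1, L))
_, j_w = stax_init(rng, x.shape)  # stax parameters
f_ma = FlaxRBM(alpha=alpha)
f_w = f_ma.init(rng, x)           # flax variables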
What I observe is peculiar:
# alpha=1
# Input shape (1,1)
# jax
63.3 µs ± 247 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# flax
252 µs ± 7.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
which would suggest that flax has roughly 4x the dispatch cost of jax (weird… but ok).
Still, if I increase the size:
# alpha=3
# Input shape (32,20)
# jax
69.5 µs ± 4.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# flax
280 µs ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
At this size the dispatch cost should finally be overcome, yet the flax runtime increases as well?
# alpha=6
# Input shape (128,30)
# jax
116 µs ± 8.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# flax
407 µs ± 83.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
???
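The benchmark harness itself is in the linked gist and not reproduced above; the output format looks like IPython %timeit. A minimal sketch of an equivalent measurement, assuming both apply functions are jitted and the result is blocked on so that only dispatch plus execution is timed (names reuse the sketch above):

import timeit
import jax

# Assumed harness: f_ma, f_w, stax_apply, j_w, x are as defined in the sketch above.
flax_apply = jax.jit(f_ma.apply)
stax_apply_jit = jax.jit(stax_apply)

# Warm up once so compilation is excluded from the measurement.
flax_apply(f_w, x).block_until_ready()
stax_apply_jit(j_w, x).block_until_ready()

n = 10_000
t_flax = timeit.timeit(lambda: flax_apply(f_w, x).block_until_ready(), number=n)
t_stax = timeit.timeit(lambda: stax_apply_jit(j_w, x).block_until_ready(), number=n)
print(f"flax: {t_flax / n * 1e6:.1f} µs/call, stax: {t_stax / n * 1e6:.1f} µs/call")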
I also checked that the two produce the same jaxpr, which is indeed the case:
jax.make_jaxpr(j_ma.apply)(j_w, x)
{ lambda ; a b c.
let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
precision=None ] c a
e = broadcast_in_dim[ broadcast_dimensions=(1,)
shape=(1, 180) ] b
f = add d e
g = sign f
h = mul f g
i = mul h -2.0
j = exp i
k = add j 1.0
l = log k
m = add h l
n = log 2.0
o = sub m n
p = reduce_sum[ axes=(1,) ] o
in (p,) }
jax.make_jaxpr(f_ma.apply)(f_w, x)
{ lambda ; a b c.
let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
precision=None ] c b
e = broadcast_in_dim[ broadcast_dimensions=(1,)
shape=(1, 180) ] a
f = add d e
g = sign f
h = mul f g
i = mul h -2.0
j = exp i
k = add j 1.0
l = log k
m = add h l
n = log 2.0
o = sub m n
p = reduce_sum[ axes=(1,) ] o
in (p,) }
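Reading the two jaxprs side by side, they contain the same sequence of primitives and differ only in which positional input carries the kernel and which the bias (a and b are swapped). A quick, assumed way to verify that the primitive sequences match programmatically, using the same j_ma, j_w, f_ma, f_w, x as in the calls above:

import jax

stax_eqns = [eqn.primitive.name for eqn in jax.make_jaxpr(j_ma.apply)(j_w, x).jaxpr.eqns]
flax_eqns = [eqn.primitive.name for eqn in jax.make_jaxpr(f_ma.apply)(f_w, x).jaxpr.eqns]
assert stax_eqns == flax_eqns  # dot_general, broadcast_in_dim, add, sign, mul, ...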
I have jax==v0.2.8 and flax==v0.3.0.
Top GitHub Comments
In time, all wounds will be healed.
I wanted to re-check the results here given how much work has gone on in JAX this year concerning dispatch overheads.
Across the board, JAX dispatch times are down a lot. The overhead of using flax vs stax is now extremely slight in dispatch-sensitive tests (lots of tiny launches) and non-existent on any reasonably large computation.
On my MacBook, small setting: 6.81 µs ± 198 ns STAX vs 7.81 µs ± 195 ns FLAX (old numbers were 65 µs vs 255 µs!). That is ~1.14x now vs ~4x before, and a 10x absolute improvement in dispatch!
On my MacBook, larger setting: 304 µs ± 23 µs STAX vs 308 µs ± 42.8 µs FLAX (old numbers were 410 µs vs 603 µs!). That is ~1x now vs ~1.5x before.
I reran all the benchmarks on Colab CPU and on my MacBook CPU; results here: https://colab.research.google.com/drive/1mPa51aFSK_NOSOi3USrLu-GnqOlh-LDX?usp=sharing
So I believe for all practical purposes this issue is resolved, and will mark this issue closed. Please feel free to re-open if you have further concerns here though!
Blocked by google/jax#5485