hstack and vstack produce very inefficient jaxpr and jit slowly; possible fix with reshape?
See original GitHub issue.

hstack is very inefficient for tensors, as it produces jaxpr code whose length is proportional to the size of the traced array.
Compare:
{ lambda ; a b c.
let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] a
e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
f = concatenate[ dimension=0 ] d e
g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] c
i = concatenate[ dimension=0 ] g h
j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] f
k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] i
l = concatenate[ dimension=0 ] j k
m = slice[ limit_indices=(1, 2, 2, 2, 3, 3)
start_indices=(0, 0, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1, 1) ] l
n = squeeze[ dimensions=(0,) ] m
o = slice[ limit_indices=(2, 2, 2, 2, 3, 3)
start_indices=(1, 0, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1, 1) ] l
p = squeeze[ dimensions=(0,) ] o
q = concatenate[ dimension=1 ] n p
r = slice[ limit_indices=(1, 4, 2, 3, 3)
start_indices=(0, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1) ] q
s = squeeze[ dimensions=(0,) ] r
t = slice[ limit_indices=(2, 4, 2, 3, 3)
start_indices=(1, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1) ] q
u = squeeze[ dimensions=(0,) ] t
v = concatenate[ dimension=1 ] s u
w = slice[ limit_indices=(1, 4, 3, 3)
start_indices=(0, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
x = squeeze[ dimensions=(0,) ] w
y = slice[ limit_indices=(2, 4, 3, 3)
start_indices=(1, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
z = squeeze[ dimensions=(0,) ] y
ba = slice[ limit_indices=(3, 4, 3, 3)
start_indices=(2, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
bb = squeeze[ dimensions=(0,) ] ba
bc = slice[ limit_indices=(4, 4, 3, 3)
start_indices=(3, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
bd = squeeze[ dimensions=(0,) ] bc
be = concatenate[ dimension=1 ] x z bb bd
bf = slice[ limit_indices=(1, 12, 3)
start_indices=(0, 0, 0)
strides=(1, 1, 1) ] be
bg = squeeze[ dimensions=(0,) ] bf
bh = slice[ limit_indices=(2, 12, 3)
start_indices=(1, 0, 0)
strides=(1, 1, 1) ] be
bi = squeeze[ dimensions=(0,) ] bh
bj = slice[ limit_indices=(3, 12, 3)
start_indices=(2, 0, 0)
strides=(1, 1, 1) ] be
bk = squeeze[ dimensions=(0,) ] bj
bl = slice[ limit_indices=(4, 12, 3)
start_indices=(3, 0, 0)
strides=(1, 1, 1) ] be
bm = squeeze[ dimensions=(0,) ] bl
bn = concatenate[ dimension=1 ] bg bi bk bm
in (bn,) }
to better, equivalent code that can be achieved with jax.lax.reshape (using its optional dimensions argument):
{ lambda ; a b c.
let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] a
e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
f = concatenate[ dimension=0 ] d e
g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] c
i = concatenate[ dimension=0 ] g h
j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] f
k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] i
l = concatenate[ dimension=0 ] j k
m = reshape[ dimensions=(0, 2, 4, 1, 3, 5)
new_sizes=(12, 12) ] l
in (m,) }
Probably hstack can be re-expressed in terms of reshape in general. I'm new to JAX, so maybe there are some negative side effects to such an approach?
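For context (a minimal sketch, not part of the original report): jax.lax.reshape differs from jnp.reshape in that its optional dimensions argument permutes the axes before reshaping, which is exactly what flattening a block matrix needs. The block layout below is hypothetical but mirrors the one used in the reproduction that follows.

import jax.numpy as jnp
from jax import lax

n = 2
# Hypothetical (2, 2, n, n, 3, 3) block layout, as produced by
# jnp.array([[AA, AB], [AB, BB]]) in the reproduction below.
x = jnp.arange(4.0 * n * n * 9).reshape((2, 2, n, n, 3, 3))

# lax.reshape with a `dimensions` permutation reorders the axes first and
# then reshapes, i.e. it is equivalent to an explicit transpose + reshape.
a = lax.reshape(x, (6 * n, 6 * n), dimensions=(0, 2, 4, 1, 3, 5))
b = jnp.transpose(x, (0, 2, 4, 1, 3, 5)).reshape((6 * n, 6 * n))
assert (a == b).all()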
Code to reproduce the issue:
import jax
import jax.numpy as jnp
n = 2
mAA = 1.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mBB = 10.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mAB = 2.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
def stack_hard(AA, AB, BB):
    return jnp.hstack(
        jnp.hstack(
            jnp.hstack(
                jnp.hstack(
                    jnp.array(
                        [[AA, AB], [AB, BB]]
                    )
                )
            )
        )
    )

def stack_easy(AA, AB, BB):
    return jax.lax.reshape(
        jnp.array([[AA, AB], [AB, BB]]),
        (6*n, 6*n),
        dimensions=(0, 2, 4, 1, 3, 5)
    )
# JIT compilation is very slow for larger n
# fast_stack = jax.jit(stack_hard)
# fast_stack(mAA, mAB, mBB)
print('===========================')
print(
jax.make_jaxpr(stack_hard)(mAA,mAB,mBB)
)
print('===========================')
print(
jax.make_jaxpr(stack_easy)(mAA,mAB,mBB)
)
print(stack_easy(mAA,mAB,mBB))
print(stack_hard(mAA,mAB,mBB))
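One way to see the compile-time gap (a rough sketch continuing the script above, not part of the original report; absolute timings will vary by machine) is to time the first jitted call of each version for a larger n:

import time

n = 8  # larger block count; tracing/compiling stack_hard grows quickly with n
mAA = 1.0 * jnp.arange(3*n*3*n).reshape((n, n, 3, 3))
mBB = 10.0 * jnp.arange(3*n*3*n).reshape((n, n, 3, 3))
mAB = 2.0 * jnp.arange(3*n*3*n).reshape((n, n, 3, 3))

for f in (stack_easy, stack_hard):
    t0 = time.perf_counter()
    # first call includes tracing + compilation + execution
    jax.jit(f)(mAA, mAB, mBB).block_until_ready()
    print(f.__name__, 'first call:', time.perf_counter() - t0, 's')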
Top GitHub Comments
Oh, I understand better why this happened. Perhaps we can improve just the case where only one jnp array is passed as the argument. I'm pretty sure that in all cases jnp.hstack can be expressed with jax.lax.reshape (note: not jnp.reshape) thanks to its cool optional dimensions argument. In the case of your example it would be:
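(The snippet that followed this comment is not preserved in this mirror; the reconstruction below is a guess based on the stack_easy function from the reproduction above, reusing its arrays and n.)

# Hypothetical reconstruction of the suggested replacement
jax.lax.reshape(
    jnp.array([[mAA, mAB], [mAB, mBB]]),  # shape (2, 2, n, n, 3, 3)
    (6 * n, 6 * n),
    dimensions=(0, 2, 4, 1, 3, 5),
)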
Hi - the issue here is that the call signature of hstack accepts a single argument, which is a tuple of arrays. A tuple is a Python concept, not an XLA concept, so when you pass an array to something that expects a tuple, it must be converted into N separate array objects that are then passed back to XLA. I'm not sure what we could do to "fix" this – maybe we could raise an error when a single array is passed to hstack, to prevent this sort of silent conversion back into a tuple of arrays, and require users to pass tuple(arr) explicitly. It would be less convenient, but it would make the computational cost implicit in the function's signature more apparent. What do you think?
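A small illustration of that unpacking cost (a sketch, not from the original thread): compare the jaxpr for hstack applied to a single traced array with the jaxpr for hstack applied to an explicit tuple of arrays.

import jax
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(4, 2, 3)

# Single array: hstack unpacks it in Python, so the jaxpr contains one
# slice + squeeze pair per element of the leading axis before the concatenate.
print(jax.make_jaxpr(lambda a: jnp.hstack(a))(x))

# Explicit tuple of separate arrays: nothing needs unpacking, so the jaxpr
# is just a single concatenate along axis 1.
print(jax.make_jaxpr(lambda a, b: jnp.hstack((a, b)))(x[:2], x[2:]))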