
hstack and vstack produce very inefficient jaxpr and jit slowly; possible fix with reshape?


hstack is very inefficient for tensors, as it produces jaxpr code whose length is proportional to the size of the traced array. Compare the following jaxpr:


{ lambda  ; a b c.
  let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] a
      e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      f = concatenate[ dimension=0 ] d e
      g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] c
      i = concatenate[ dimension=0 ] g h
      j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] f
      k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] i
      l = concatenate[ dimension=0 ] j k
      m = slice[ limit_indices=(1, 2, 2, 2, 3, 3)
                 start_indices=(0, 0, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1, 1) ] l
      n = squeeze[ dimensions=(0,) ] m
      o = slice[ limit_indices=(2, 2, 2, 2, 3, 3)
                 start_indices=(1, 0, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1, 1) ] l
      p = squeeze[ dimensions=(0,) ] o
      q = concatenate[ dimension=1 ] n p
      r = slice[ limit_indices=(1, 4, 2, 3, 3)
                 start_indices=(0, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1) ] q
      s = squeeze[ dimensions=(0,) ] r
      t = slice[ limit_indices=(2, 4, 2, 3, 3)
                 start_indices=(1, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1) ] q
      u = squeeze[ dimensions=(0,) ] t
      v = concatenate[ dimension=1 ] s u
      w = slice[ limit_indices=(1, 4, 3, 3)
                 start_indices=(0, 0, 0, 0)
                 strides=(1, 1, 1, 1) ] v
      x = squeeze[ dimensions=(0,) ] w
      y = slice[ limit_indices=(2, 4, 3, 3)
                 start_indices=(1, 0, 0, 0)
                 strides=(1, 1, 1, 1) ] v
      z = squeeze[ dimensions=(0,) ] y
      ba = slice[ limit_indices=(3, 4, 3, 3)
                  start_indices=(2, 0, 0, 0)
                  strides=(1, 1, 1, 1) ] v
      bb = squeeze[ dimensions=(0,) ] ba
      bc = slice[ limit_indices=(4, 4, 3, 3)
                  start_indices=(3, 0, 0, 0)
                  strides=(1, 1, 1, 1) ] v
      bd = squeeze[ dimensions=(0,) ] bc
      be = concatenate[ dimension=1 ] x z bb bd
      bf = slice[ limit_indices=(1, 12, 3)
                  start_indices=(0, 0, 0)
                  strides=(1, 1, 1) ] be
      bg = squeeze[ dimensions=(0,) ] bf
      bh = slice[ limit_indices=(2, 12, 3)
                  start_indices=(1, 0, 0)
                  strides=(1, 1, 1) ] be
      bi = squeeze[ dimensions=(0,) ] bh
      bj = slice[ limit_indices=(3, 12, 3)
                  start_indices=(2, 0, 0)
                  strides=(1, 1, 1) ] be
      bk = squeeze[ dimensions=(0,) ] bj
      bl = slice[ limit_indices=(4, 12, 3)
                  start_indices=(3, 0, 0)
                  strides=(1, 1, 1) ] be
      bm = squeeze[ dimensions=(0,) ] bl
      bn = concatenate[ dimension=1 ] bg bi bk bm
  in (bn,) }

A better, equivalent jaxpr can be achieved using jax.lax.reshape:

{ lambda  ; a b c.
  let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] a
      e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      f = concatenate[ dimension=0 ] d e
      g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] c
      i = concatenate[ dimension=0 ] g h
      j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] f
      k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] i
      l = concatenate[ dimension=0 ] j k
      m = reshape[ dimensions=(0, 2, 4, 1, 3, 5)
                   new_sizes=(12, 12) ] l
  in (m,) }

hstack can probably be re-expressed in terms of reshape in general. I’m new to JAX, so maybe there are some negative side effects to such an approach?

Code to reproduce issue:

import jax
import jax.numpy as jnp

n = 2

mAA = 1.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mBB = 10.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mAB = 2.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))

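def stack_hard(AA,AB,BB):
    # Four nested hstack calls, matching the four slice/concatenate
    # rounds visible in the first jaxpr above.
    return jnp.hstack(jnp.hstack(jnp.hstack(jnp.hstack(
                jnp.array([[AA,AB],[AB,BB]])))))

def stack_easy(AA,AB,BB):
    # A single reshape with permuted axes, matching the second jaxpr above.
    return  jax.lax.reshape(
                jnp.array([[AA,AB],[AB,BB]]),
                (6*n,6*n),
                dimensions = (0,2,4,1,3,5))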

# JIT is very slow in case of larger n
# fast_stack = jax.jit(stack_hard)
# fast_stack(mAA,mBB,mAB)
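A minimal sketch for checking the claim above: it confirms that the two stacking approaches give the same matrix and prints the number of top-level jaxpr equations for each. The function bodies are assumptions inferred from the two jaxprs in this issue, and `make_blocks` and `num_equations` are hypothetical helper names. The counts depend on the installed JAX version, so no expected numbers are shown.

```python
import jax
import jax.numpy as jnp


def make_blocks(n):
    # Test data shaped like the arrays in the reproduction script.
    base = jnp.arange(3 * n * 3 * n, dtype=jnp.float32).reshape((n, n, 3, 3))
    return 1.0 * base, 2.0 * base, 10.0 * base  # AA, AB, BB


def stack_hard(AA, AB, BB):
    # Assumed body: four nested hstack calls, as the first jaxpr suggests.
    out = jnp.array([[AA, AB], [AB, BB]])
    for _ in range(4):
        out = jnp.hstack(out)
    return out


def stack_easy(AA, AB, BB):
    # Assumed body: one lax.reshape using the axis order from the second jaxpr.
    blocks = jnp.array([[AA, AB], [AB, BB]])
    n = AA.shape[0]
    return jax.lax.reshape(blocks, (6 * n, 6 * n), dimensions=(0, 2, 4, 1, 3, 5))


def num_equations(fn, *args):
    # Number of top-level equations in the traced jaxpr.
    return len(jax.make_jaxpr(fn)(*args).jaxpr.eqns)


for n in (2, 4):
    AA, AB, BB = make_blocks(n)
    same = bool(jnp.array_equal(stack_hard(AA, AB, BB), stack_easy(AA, AB, BB)))
    print(n, same,
          num_equations(stack_hard, AA, AB, BB),
          num_equations(stack_easy, AA, AB, BB))
```

Counting equations with `jax.make_jaxpr` avoids timing noise and shows directly how much work tracing has to do before compilation starts.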




Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 17 (2 by maintainers)

Top GitHub Comments

RadostW commented, Jun 1, 2021

Oh, I understand better why this happened. Perhaps we can improve just the case where only one jnp array is passed as argument.

I’m pretty sure that in all cases jnp.hstack can be expressed with jax.lax.reshape (note: not jnp.reshape), thanks to its optional dimensions argument.

In case of your example it would be:

>>> import jax.numpy as jnp
>>> import jax
>>> x = jnp.arange(12).reshape(3, 2, 2)
>>> jax.lax.reshape(x,(2,6),dimensions=(1,0,2)) - jnp.hstack(x)
DeviceArray([[0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0]], dtype=int32)
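
The single-array case can be sketched as a small helper. This is only an illustration of the idea in this comment, not code from JAX itself: `hstack_via_reshape` is a hypothetical name, and it assumes the input is one array with at least one dimension.

```python
import jax
import jax.numpy as jnp


def hstack_via_reshape(x):
    """Hypothetical helper: same values as jnp.hstack(tuple(x)) for one array x."""
    if x.ndim <= 2:
        # The pieces x[i] are scalars or 1-D arrays, so hstack joins them end to end.
        return x.reshape(-1)
    m, r, c, *rest = x.shape
    # Bring axis 1 to the front, then merge the old leading axis into axis 2.
    dims = (1, 0) + tuple(range(2, x.ndim))
    return jax.lax.reshape(x, (r, m * c, *rest), dimensions=dims)


x = jnp.arange(12).reshape(3, 2, 2)
print(jnp.array_equal(hstack_via_reshape(x), jnp.hstack(tuple(x))))
```

The reference result uses `tuple(x)` so that the comparison does not depend on how `jnp.hstack` treats a bare array in a given JAX version.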

jakevdp commented, Jun 1, 2021

Hi - the issue here is that the call signature of hstack is that it accepts a single argument, which is a tuple of arrays.

A tuple is a Python concept, not an XLA concept, so when you pass an array to something that expects a tuple, it must be converted into N array objects that are then passed back to XLA.

I’m not sure what we could do to “fix” this – maybe we could raise an error in the case that a single array is passed to hstack, to prevent this sort of silent conversion back to a numpy tuple, and require users to pass tuple(arr) explicitly. It would be less convenient, but it would make more apparent the computational cost implicit in the function’s signature.

What do you think?
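
The cost described here can be made visible with a rough sketch that unpacks the array explicitly via `tuple(arr)` and counts the top-level equations in the traced jaxpr. The helper names `hstack_unpacked` and `num_equations` are hypothetical, and the exact counts depend on the installed JAX version, so none are shown.

```python
import jax
import jax.numpy as jnp


def hstack_unpacked(x):
    # tuple(x) splits x along axis 0 into x.shape[0] separate arrays,
    # which is the conversion described in the comment above.
    return jnp.hstack(tuple(x))


def num_equations(fn, *args):
    # Number of top-level equations in the traced jaxpr.
    return len(jax.make_jaxpr(fn)(*args).jaxpr.eqns)


small = jnp.zeros((4, 3, 3))
large = jnp.zeros((16, 3, 3))
print(num_equations(hstack_unpacked, small))
print(num_equations(hstack_unpacked, large))
```

Each element of the tuple needs its own slicing step during tracing, so the second count should come out larger than the first.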


