Iterate over dynamic slices
Hi there,
Excellent work here!
Our team is trying to construct a JAX implementation of a matrix factorization which requires us to iterate over slices whose dimensions depend upon the iteration counter. For example, suppose we need to operate upon each column of some matrix's lower triangle, as in the Householder QR algorithm, along the lines of:
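import jax

@jax.jit
def operate_on_lower_triangle(A):
    # vector_function and apply_transform are placeholders for the real computation.
    for j in range(A.shape[1]):  # Python loop: statically unrolled under jit
        A_slice = A[j:, j]  # *
        v = vector_function(A_slice)
        A = A.at[j:, j].set(apply_transform(A, v))
    return A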
This actually works, but since the for loop gets statically unrolled, compilation can get really slow, especially in the more complex example we are working with. The cure for this is to use one of the lax control flow primitives. But this has the consequence that 'j' becomes an abstract tracer, so that the line marked * becomes problematic due to the dynamic shape of the slice.
Is there a workaround for this? It seems like there must be, since QR has a TPU implementation.
Issue Analytics
- Created 4 years ago
- Comments: 8 (6 by maintainers)
Our current solution for this kind of compilation time problem is to at least partially roll up either inner or outer loops.
The QR implementation JAX uses is in the XLA client library, in C++: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/lib/qr.h
You can do the same sort of thing in Python too, e.g., the LU decomposition here: https://github.com/google/jax/blob/master/jax/lax_linalg.py#L452
You can use lax.dynamic_update_slice or lax.scatter if you need more control over updates; ultimately that's what the indexed update lowers to. I think it's an open question how well these implementations can perform without more manual tuning, but they certainly suffice for many things.
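To make the lax.dynamic_update_slice suggestion concrete, here is a minimal sketch, not taken from the linked implementations. It assumes the update has a fixed, static shape and only its position depends on the loop counter. The function name write_diagonal_blocks and the block layout are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def write_diagonal_blocks(A, block):
    """Write `block` (static shape b x b) at each diagonal position of A.

    Hypothetical example: the update shape is static, only the start
    index depends on the (traced) loop counter j.
    """
    b = block.shape[0]
    num_blocks = A.shape[0] // b  # static Python int, known at trace time

    def body(j, A):
        start = j * b  # traced integer: allowed as a start index
        return lax.dynamic_update_slice(A, block, (start, start))

    # fori_loop keeps the loop rolled, so the body is traced only once.
    return lax.fori_loop(0, num_blocks, body, A)


A = jnp.zeros((4, 4))
block = jnp.ones((2, 2))
result = write_diagonal_blocks(A, block)
```

The key constraint is that `block` has the same shape on every iteration; a slice whose length shrinks with j cannot be expressed this way and needs padding or masking instead.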
Hope that helps!
Yes, you still do the computation on the masked elements. But depending on the platform (particularly for accelerators like TPUs) this may even be faster than trying to do dynamic slicing, regardless of compilation times.
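The masking idea can be sketched roughly as follows. This is an illustrative assumption about how one might structure it, not the code used by JAX's QR or LU. Instead of slicing A[j:, j], the loop body reads the whole column (static shape) and masks out the rows above the diagonal. The per-column operation here (normalizing the sub-diagonal part) is a hypothetical stand-in for vector_function and apply_transform from the question.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def normalize_lower_triangle_columns(A):
    """Normalize the on-and-below-diagonal part of each column of A.

    Hypothetical stand-in for the per-column work in the question;
    rows above the diagonal are left unchanged.
    """
    m, n = A.shape
    rows = jnp.arange(m)

    def body(j, A):
        col = A[:, j]                    # full column: shape (m,) is static
        mask = rows >= j                 # True on and below the diagonal
        v = jnp.where(mask, col, 0.0)    # masked stand-in for A[j:, j]
        norm = jnp.sqrt(jnp.sum(v * v))
        safe_norm = jnp.where(norm > 0, norm, 1.0)  # avoid division by zero
        new_col = jnp.where(mask, v / safe_norm, col)
        return A.at[:, j].set(new_col)

    # The loop stays rolled, so compile time does not grow with n.
    return lax.fori_loop(0, n, body, A)


A = jnp.array([[3.0, 1.0], [4.0, 2.0]])
result = normalize_lower_triangle_columns(A)
```

The trade-off matches the comment above: every iteration touches all m rows, including the masked ones, in exchange for static shapes and a loop that is traced once.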