[FEATURE-REQUEST] Support storing JAX arrays
Description: If Vaex could store JAX arrays (the way it stores NumPy or Arrow arrays), heavy processing in Vaex could become significantly faster.
Take cosine similarity, for example:
import vaex
import numpy as np
from jax import grad, jit, vmap, pmap
import jax.numpy as jnp

df = vaex.from_arrays(emb=np.random.rand(500_000, 1_000))
data = df[:100_000].emb.values
cos_vec = np.random.rand(1, 1_000)

def cosine_np(a, b):
    return np.divide(
        np.sum(a * b, axis=1),
        np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1),
    )

def cosine_jax(a, b):
    return jnp.divide(
        jnp.sum(a * b, axis=1),
        jnp.linalg.norm(a, axis=1) * jnp.linalg.norm(b, axis=1),
    )

cos_jax_jit = jit(cosine_jax)
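As a quick sanity check of the formula used in both functions, here is a tiny hand-checkable case using only the NumPy version. The input values are made up for illustration and are not from the original benchmark.

```python
import numpy as np


def cosine_np(a, b):
    # Row-wise cosine similarity of each row of `a` against the single row `b`.
    return np.divide(
        np.sum(a * b, axis=1),
        np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1),
    )


# Illustrative inputs: three 2-D vectors compared against the x-axis unit vector.
a = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
b = np.array([[1.0, 0.0]])

sims = cosine_np(a, b)
print(sims)  # approximately [1.0, 0.0, 0.7071]
```

The same vector, an orthogonal vector, and a 45-degree vector give 1, 0, and 1/sqrt(2) respectively, which is what cosine similarity should return.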

The jitted JAX version is an enormous improvement. The catch is that converting between NumPy and JAX arrays has overhead, and this causes slowdowns. If Vaex could store and return JAX arrays directly, this could be a huge performance boost.
jax.numpy has nearly identical syntax to numpy. Maybe this could be basic support, storing the raw jax array? That would let users do vaex.register_function and utilize jnp instead of np while maintaining the huge performance, keeping everything in jax land.
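To make the conversion overhead concrete, here is a minimal sketch of the round trip that has to happen today when the data lives in NumPy but the computation runs in JAX. It is only an illustration under stated assumptions: `cosine_via_jax` is a hypothetical helper name (not part of Vaex or JAX), the array sizes are small placeholders, and JAX is assumed to be running with its default settings.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def cosine_jax(a, b):
    # Row-wise cosine similarity of each row of `a` against the single row `b`.
    return jnp.divide(
        jnp.sum(a * b, axis=1),
        jnp.linalg.norm(a, axis=1) * jnp.linalg.norm(b, axis=1),
    )


def cosine_via_jax(a_np, b_np):
    # Hypothetical helper: NumPy in, NumPy out, JAX in the middle.
    a = jnp.asarray(a_np)   # NumPy -> JAX (with default settings float64 becomes float32)
    b = jnp.asarray(b_np)
    out = cosine_jax(a, b)  # jitted computation on JAX arrays
    return np.asarray(out)  # JAX -> NumPy, waits for the result to be ready


# Small placeholder data; the issue uses much larger arrays.
rng = np.random.default_rng(0)
emb = rng.random((1_000, 64))
query = rng.random((1, 64))

result = cosine_via_jax(emb, query)
```

The two conversions in `cosine_via_jax` are the part that storing JAX arrays natively would avoid; the jitted function itself stays the same.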
Additional context: https://github.com/google/jax#installation https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022#Hessians
Issue Analytics
- Created a year ago
- Comments:5 (5 by maintainers)
Yes, if it were done lazily, then I think much more time would be saved.
Sorry, I was missing important code in my example!
The major slowdown here is the to_py step. If vaex could take care of that, we'd see much, much better performance.
Hey, thanks for the example!
So if I understand it correctly, the vaex part is done by converting the numpy arrays to jax arrays on the fly?