[FEATURE-REQUEST] Support storing JAX arrays



Description
If Vaex could store a JAX array (similarly to NumPy or Arrow arrays), heavy processing in Vaex could become significantly faster.

Take a cosine similarity as an example:

```python
import vaex
import numpy as np
from jax import jit
import jax.numpy as jnp

# 500,000 rows of 1,000-dim float64 embeddings (about 4 GB in memory)
df = vaex.from_arrays(emb=np.random.rand(500_000, 1_000))
data = df[:100_000].emb.values  # first 100,000 rows as a NumPy array
cos_vec = np.random.rand(1, 1_000)

def cosine_np(a, b):
    # Row-wise cosine similarity of each row of `a` against the single row of `b`
    return np.divide(
        np.sum(a * b, axis=1),
        np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1),
    )

def cosine_jax(a, b):
    # Same computation written with jax.numpy
    return jnp.divide(
        jnp.sum(a * b, axis=1),
        jnp.linalg.norm(a, axis=1) * jnp.linalg.norm(b, axis=1),
    )

cos_jax_jit = jit(cosine_jax)
```
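The original benchmark screenshot is not preserved, so here is a minimal sketch of how such a comparison could be reproduced. It assumes float32 inputs (JAX's default precision), smaller array sizes than above, and best-of-5 wall-clock timing; the helper names best_time and run_benchmark are hypothetical. Because JAX dispatches work asynchronously, the JAX result has to be waited on with block_until_ready() before the clock stops, and a warm-up call keeps compilation time out of the measurement.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def cosine_np(a, b):
    # Row-wise cosine similarity between each row of `a` and the single row of `b`.
    return np.sum(a * b, axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))


@jax.jit
def cosine_jax(a, b):
    # Same computation with jax.numpy; jit compiles it with XLA on the first call.
    return jnp.sum(a * b, axis=1) / (jnp.linalg.norm(a, axis=1) * jnp.linalg.norm(b, axis=1))


def best_time(fn, *args, repeats=5):
    """Hypothetical helper: fastest wall-clock time (seconds) over `repeats` calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(*args)
        # JAX returns before the work is done; wait for it before stopping the clock.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


def run_benchmark(n_rows=10_000, dim=1_000, seed=0):
    rng = np.random.default_rng(seed)
    data = rng.random((n_rows, dim), dtype=np.float32)
    cos_vec = rng.random((1, dim), dtype=np.float32)

    # Put the inputs into JAX arrays once, so only the computation is timed.
    data_jax, cos_vec_jax = jnp.asarray(data), jnp.asarray(cos_vec)
    cosine_jax(data_jax, cos_vec_jax).block_until_ready()  # warm-up: triggers compilation

    return {
        "numpy": best_time(cosine_np, data, cos_vec),
        "jax_jit": best_time(cosine_jax, data_jax, cos_vec_jax),
    }
```

Calling run_benchmark() returns a dict with the best NumPy and jitted-JAX timings; the actual numbers depend on the hardware and JAX backend.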

Using JAX gives an enormous improvement. The problem is that converting between NumPy and JAX arrays has overhead, and this causes slowdowns. If Vaex could store and return JAX arrays directly, this could be a huge performance boost.
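To see where that overhead sits, the round trip can be split into its three steps: NumPy to JAX, the computation, and JAX back to NumPy. This is a sketch under assumptions: float32 random data, a single timed run per step, and a hypothetical helper name conversion_breakdown. On CPU, JAX may avoid some copies, so the split will differ between backends.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def cosine_jax(a, b):
    # Row-wise cosine similarity, compiled with XLA.
    return jnp.sum(a * b, axis=1) / (jnp.linalg.norm(a, axis=1) * jnp.linalg.norm(b, axis=1))


def timed(fn):
    """Run fn() once and return (result, elapsed seconds)."""
    start = time.perf_counter()
    out = fn()
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()  # wait for JAX's asynchronous dispatch to finish
    return out, time.perf_counter() - start


def conversion_breakdown(n_rows=10_000, dim=1_000, seed=0):
    rng = np.random.default_rng(seed)
    data = rng.random((n_rows, dim), dtype=np.float32)
    cos_vec = rng.random((1, dim), dtype=np.float32)
    cosine_jax(jnp.asarray(data), jnp.asarray(cos_vec)).block_until_ready()  # compile once

    # 1. NumPy -> JAX: turn the input into a JAX array.
    data_jax, t_to_jax = timed(lambda: jnp.asarray(data))
    cos_vec_jax = jnp.asarray(cos_vec)
    # 2. The computation itself, entirely on JAX arrays.
    result_jax, t_compute = timed(lambda: cosine_jax(data_jax, cos_vec_jax))
    # 3. JAX -> NumPy: bring the result back as a NumPy array.
    result_np, t_to_numpy = timed(lambda: np.asarray(result_jax))

    return {"to_jax": t_to_jax, "compute": t_compute, "to_numpy": t_to_numpy}, result_np
```

If the feature existed, steps 1 and 3 would only be paid when a user actually asks for NumPy output.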

jax.numpy has nearly identical syntax to NumPy. Perhaps a first step could be basic support that simply stores the raw JAX array? That would let users write functions with vaex.register_function using jnp instead of np, keeping everything in JAX and preserving the performance gain.
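As an illustration of how close the two APIs are, the same function body can be written once against an "array module" parameter and run with either np or jnp. This is a common pattern, not a Vaex feature; the parameter name xp and the function names are assumptions for the sketch. With jnp, the result stays a JAX array until it is explicitly converted.

```python
import jax
import jax.numpy as jnp
import numpy as np


def cosine_sim(xp, a, b):
    """Row-wise cosine similarity written once against an array module `xp`.

    Passing `np` runs it with NumPy; passing `jnp` runs the identical code with JAX.
    """
    return xp.sum(a * b, axis=1) / (xp.linalg.norm(a, axis=1) * xp.linalg.norm(b, axis=1))


# Bind the module before jitting, so only the array arguments are traced.
cosine_sim_jax = jax.jit(lambda a, b: cosine_sim(jnp, a, b))

rng = np.random.default_rng(0)
emb = rng.random((4, 8), dtype=np.float32)
query = rng.random((1, 8), dtype=np.float32)

sim_np = cosine_sim(np, emb, query)                              # a NumPy ndarray
sim_jax = cosine_sim_jax(jnp.asarray(emb), jnp.asarray(query))   # stays a JAX array
```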

Additional context
https://github.com/google/jax#installation
https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022#Hessians

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
Ben-Epstein commented, Jul 1, 2022

Yes, if it were done lazily, I think much more time would be saved.

Sorry, I was missing important code in my example!

```python
import jax.numpy as jnp
from jax import jit
import vaex
import numpy as np

df = vaex.from_arrays(emb=np.random.rand(1_000_000, 100))

@vaex.register_function()
def orig_cosine_sim(a: np.ndarray, cos_vec: np.ndarray) -> np.ndarray:
    # Pure NumPy baseline
    return np.divide(
        np.sum(a * cos_vec, axis=1),
        np.linalg.norm(a, axis=1) * np.linalg.norm(cos_vec, axis=1),
    )


def cosine_sim(a: jnp.ndarray, cos_vec: jnp.ndarray) -> jnp.ndarray:
    # Same computation on JAX arrays
    return jnp.divide(
        jnp.sum(a * cos_vec, axis=1),
        jnp.linalg.norm(a, axis=1) * jnp.linalg.norm(cos_vec, axis=1),
    )


@vaex.register_function()
def vaex_jax_func(a, b):
    # NumPy -> JAX, compute, then JAX -> NumPy. np.asarray does the same conversion
    # as the older DeviceArray.to_py(), which current JAX versions no longer provide.
    return np.asarray(cosine_sim(jnp.array(a), jnp.array(b)))


jax_cosine_sim = jit(cosine_sim)

@vaex.register_function()
def vaex_jax_jit_func(a, b):
    # Same round trip, but with the jit-compiled function
    return np.asarray(jax_cosine_sim(jnp.array(a), jnp.array(b)))
```
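The hand-written wrappers above all follow one pattern, which could be factored into a decorator. This is a minimal sketch; the decorator name numpy_in_numpy_out is hypothetical and not part of Vaex or JAX. In principle it could be stacked under @vaex.register_function() like the functions above, but that combination is an untested assumption here.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def numpy_in_numpy_out(jax_fn):
    """Hypothetical helper: wrap a JAX function so it accepts and returns NumPy arrays.

    It does what the registered functions above do by hand:
    NumPy -> JAX on the way in, JAX -> NumPy on the way out.
    """
    compiled = jax.jit(jax_fn)

    @functools.wraps(jax_fn)  # keep the original name, e.g. for registration
    def wrapper(*arrays):
        result = compiled(*(jnp.asarray(x) for x in arrays))
        # The JAX -> NumPy conversion that the issue identifies as the main slowdown.
        return np.asarray(result)

    return wrapper


@numpy_in_numpy_out
def cosine_sim(a, cos_vec):
    # Row-wise cosine similarity, written purely with jax.numpy
    return jnp.sum(a * cos_vec, axis=1) / (
        jnp.linalg.norm(a, axis=1) * jnp.linalg.norm(cos_vec, axis=1)
    )
```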

The major slowdown here is the to_py() call, i.e. converting the result back to NumPy. If Vaex could take care of that, we'd see much better performance.
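Vaex evaluates expressions in chunks, so with the wrappers above the JAX -> NumPy conversion happens once per chunk. The sketch below emulates that in plain NumPy/JAX and contrasts it with keeping chunk results as JAX arrays and converting once at the end, which is roughly what JAX-aware storage could allow. The chunking loop and function names are assumptions, not Vaex internals, and whether fewer conversions save time in practice depends on chunk size and backend.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def cosine_jax(a, b):
    # Row-wise cosine similarity; a smaller final chunk triggers one extra compilation.
    return jnp.sum(a * b, axis=1) / (jnp.linalg.norm(a, axis=1) * jnp.linalg.norm(b, axis=1))


def evaluate_per_chunk(data, cos_vec, chunk_size):
    """Convert every chunk's result back to NumPy, as the registered functions do today."""
    b = jnp.asarray(cos_vec)
    parts = [
        np.asarray(cosine_jax(jnp.asarray(data[i:i + chunk_size]), b))  # one conversion per chunk
        for i in range(0, len(data), chunk_size)
    ]
    return np.concatenate(parts)


def evaluate_keep_in_jax(data, cos_vec, chunk_size):
    """Keep chunk results as JAX arrays and convert to NumPy only once at the end."""
    b = jnp.asarray(cos_vec)
    parts = [
        cosine_jax(jnp.asarray(data[i:i + chunk_size]), b)  # results stay JAX arrays
        for i in range(0, len(data), chunk_size)
    ]
    return np.asarray(jnp.concatenate(parts))  # a single JAX -> NumPy conversion
```

Both functions produce the same values; they differ only in how many times results cross from JAX back to NumPy.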

0 reactions
JovanVeljanoski commented, Jul 1, 2022

Hey, thanks for the example!

So, if I understand correctly, the Vaex part consists of converting the NumPy arrays to JAX arrays on the fly?
