Batching rules not implemented for host_callback.call

Description

I’m trying to vmap a function that solves linear programs (related: #12827) using cvxopt.solvers.lp wrapped inside a host_callback.call. Consider the following example:

from cvxopt import matrix, solvers
from numpy import asarray
from jax import numpy as jnp, jit, ShapeDtypeStruct, vmap, lax
from jax.experimental.host_callback import call

solvers.options['glpk'] = {'msg_lev': 'GLP_MSG_OFF'}

def lp_helper(c, G, h, A=None, b=None):
    result = solvers.lp(
        matrix(asarray(c, dtype='double')),
        matrix(asarray(G, dtype='double')),
        matrix(asarray(h, dtype='double')),
        None if A is None else matrix(asarray(A, dtype='double')),
        None if b is None else matrix(asarray(b, dtype='double')),
        solver='glpk',
    )
    if result['status'] == 'optimal':
        return {
            'x': jnp.array(result['x'], dtype=float)[:, 0],
            'z': jnp.array(result['z'], dtype=float)[:, 0],
        }
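    else:
        # Fail loudly: continuing past breakpoint() would return None,
        # which does not match result_shape.
        raise RuntimeError(f"LP solver status: {result['status']}")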

def lp(*args):
    m, n = args[1].shape
    return call(
        lambda args: lp_helper(*args),
        args,
        result_shape={
            'x': ShapeDtypeStruct([n], float),
            'z': ShapeDtypeStruct([m], float),
        },
    )

def main():
    lp_jit = jit(lp)

    G = jnp.array([
        [-1, 0],
        [0, -1],
        [2, 3],
    ])
    h = jnp.array([0, 0, 1])
    cs = jnp.array([
        [-1, 0],
        [0, -1],
        [1, 1],
    ])

    xs = jnp.stack(list(map(lambda c: lp_jit(c, G, h)['x'], cs)))
    print(xs)

    xs = lax.map(lambda c: lp_jit(c, G, h)['x'], cs)
    print(xs)

    try:
        xs = vmap(lambda c: lp_jit(c, G, h)['x'])(cs)
        print(xs)
    except NotImplementedError as e:
        print(e)

if __name__ == '__main__':
    main()

It outputs

[[0.5        0.        ]
 [0.         0.33333334]
 [0.         0.        ]]
[[0.5        0.        ]
 [0.         0.33333334]
 [0.         0.        ]]
batching rules are implemented only for id_tap, not for call.

As you can see, the Python loop and lax.map both return the expected solutions, but vmap raises NotImplementedError.
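For reference, vmap(f)(cs) is expected to return the same array as stacking f(c) over the rows of cs, which is what the loop and lax.map versions produce. The sketch below checks this contract with a hypothetical pure-JAX stand-in for the solver, `solve_small_lp`. It hard-codes the three vertices (0, 0), (0.5, 0) and (0, 1/3) of the feasible region G x <= h from the example (x >= 0, y >= 0, 2x + 3y <= 1). Because every primitive it uses has a batching rule, vmap succeeds here. The failure in the issue comes from host_callback.call lacking such a rule.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for the LP in the issue: the feasible region G x <= h
# (x >= 0, y >= 0, 2x + 3y <= 1) is a triangle with these vertices, and a bounded
# LP attains its minimum at a vertex, so we pick the vertex with the smallest c @ x.
VERTICES = jnp.array([[0.0, 0.0], [0.5, 0.0], [0.0, 1.0 / 3.0]])

def solve_small_lp(c):
    return VERTICES[jnp.argmin(VERTICES @ c)]

cs = jnp.array([[-1, 0], [0, -1], [1, 1]], dtype=jnp.float32)

xs_loop = jnp.stack([solve_small_lp(c) for c in cs])  # one Python call per row
xs_map = lax.map(solve_small_lp, cs)                  # sequential loop inside JAX
xs_vmap = jax.vmap(solve_small_lp)(cs)                # batched via batching rules

print(xs_vmap)
```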

What jax/jaxlib version are you using?

jax 0.3.24, jaxlib 0.3.24

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.10.7, macOS 11.7

NVIDIA GPU info

No response

Issue Analytics

  • State: open
  • Created 10 months ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

2 reactions
carlosgmartin commented, Nov 16, 2022

@sharadmv Thanks. Perhaps a note about this could be added to the doc page?

0 reactions
sharadmv commented, Nov 15, 2022

I agree w/ Patrick. pure_callback seems like the appropriate solution here. host_callback.call doesn’t support batching because it isn’t generally safe to batch callbacks that might have arbitrary side-effects in them.
