Batching rules not implemented for host_callback.call
Description
I’m trying to vmap a function that solves linear programs (related: #12827) using cvxopt.solvers.lp wrapped inside a host_callback.call. Consider the following example:
from cvxopt import matrix, solvers
from numpy import asarray
from jax import numpy as jnp, jit, ShapeDtypeStruct, vmap, lax
from jax.experimental.host_callback import call

# Silence GLPK's console output.
solvers.options['glpk'] = {'msg_lev': 'GLP_MSG_OFF'}


def lp_helper(c, G, h, A=None, b=None):
    # Runs on the host: minimize c^T x subject to G x <= h (and A x = b if given).
    result = solvers.lp(
        matrix(asarray(c, dtype='double')),
        matrix(asarray(G, dtype='double')),
        matrix(asarray(h, dtype='double')),
        None if A is None else matrix(asarray(A, dtype='double')),
        None if b is None else matrix(asarray(b, dtype='double')),
        solver='glpk',
    )
    if result['status'] == 'optimal':
        # cvxopt returns column matrices; flatten them to 1-D arrays.
        return {
            'x': jnp.array(result['x'], dtype=float)[:, 0],
            'z': jnp.array(result['z'], dtype=float)[:, 0],
        }
    else:
        # Debug hook: pause if the solver did not reach an optimum.
        breakpoint()


def lp(*args):
    m, n = args[1].shape  # G has m constraints and n variables
    return call(
        lambda args: lp_helper(*args),
        args,
        result_shape={
            'x': ShapeDtypeStruct((n,), float),
            'z': ShapeDtypeStruct((m,), float),
        },
    )


def main():
    lp_jit = jit(lp)
    G = jnp.array([
        [-1, 0],
        [0, -1],
        [2, 3],
    ])
    h = jnp.array([0, 0, 1])
    cs = jnp.array([
        [-1, 0],
        [0, -1],
        [1, 1],
    ])
    # 1. Python loop over the cost vectors: works.
    xs = jnp.stack(list(map(lambda c: lp_jit(c, G, h)['x'], cs)))
    print(xs)
    # 2. lax.map: works.
    xs = lax.map(lambda c: lp_jit(c, G, h)['x'], cs)
    print(xs)
    # 3. vmap: raises NotImplementedError.
    try:
        xs = vmap(lambda c: lp_jit(c, G, h)['x'])(cs)
        print(xs)
    except NotImplementedError as e:
        print(e)


if __name__ == '__main__':
    main()
It outputs:
[[0.5 0. ]
[0. 0.33333334]
[0. 0. ]]
[[0.5 0. ]
[0. 0.33333334]
[0. 0. ]]
batching rules are implemented only for id_tap, not for call.
As you can see, the Python loop and lax.map both return the expected solutions, but vmap raises a NotImplementedError.
What jax/jaxlib version are you using?
jax 0.3.24, jaxlib 0.3.24
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.10.7, macOS 11.7
NVIDIA GPU info
No response
@sharadmv Thanks. Perhaps a note about this could be added to the doc page?
I agree w/ Patrick. pure_callback seems like the appropriate solution here. host_callback.call doesn't support batching because it isn't generally safe to batch callbacks that might have arbitrary side effects in them.