[Question] Efficiently sampling from the binomial distribution
I currently have a naive binomial sampler that sums Bernoulli trials (uniform draws compared against p, the same construction jax.random.bernoulli uses under the hood):
import jax.numpy as np
from jax import random

def binomial(key, p, n=1, shape=()):
    # Broadcast p and n against each other (stands in for numpyro's _promote_shapes).
    p, n = np.broadcast_arrays(np.asarray(p), np.asarray(n))
    shape = shape or p.shape
    # The number of uniforms per sample depends on the *value* of n.
    n_max = int(np.max(n))
    uniforms = random.uniform(key, shape + (n_max,))
    n = np.expand_dims(n, axis=-1)
    p = np.expand_dims(p, axis=-1)
    # Keep only the first n trials in each row.
    mask = (np.arange(n_max) < n).astype(uniforms.dtype)
    return np.sum(mask * (uniforms < p), axis=-1)
This works, but its biggest drawback is that it is not jittable, because the size argument passed to random.uniform depends on the value of n (see the error trace below). It is also wasteful: every sample draws n_max uniform floats and multiplies them element-wise by mask, even when its own n is much smaller. I was wondering if there is a more efficient way to sample from the binomial distribution using existing primitives. I am new to JAX, so any insights on improving this implementation would be really helpful.
Error trace:
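One way to avoid both problems is a sampler whose loop length is decided at run time by lax.while_loop instead of by an array shape. The sketch below swaps in a different technique from the question's: inverse-CDF (sequential search) sampling, which draws a single uniform per sample and stays jittable even when n is traced. The names binomial_inversion and _binomial_inversion_scalar are hypothetical, not JAX APIs. It assumes moderate n: the starting probability (1 - q)**n underflows in float32 once n reaches roughly a hundred trials with p near 0.5, and at that point a rejection-based sampler would be needed instead.

```python
import jax
import jax.numpy as jnp
from jax import lax, random


def _binomial_inversion_scalar(key, n, p):
    """Hypothetical helper: one Binomial(n, p) draw by walking the CDF."""
    # Symmetry: sample with q = min(p, 1 - p) so the step ratio q / (1 - q) <= 1,
    # then map back with n - k when p > 0.5.
    flip = p > 0.5
    q = jnp.where(flip, 1.0 - p, p)
    u = random.uniform(key)
    n_f = n.astype(u.dtype)

    pmf0 = (1.0 - q) ** n_f  # P(K = 0); underflows for large n (see caveat above)
    ratio = q / (1.0 - q)

    def cond(state):
        k, _, cdf = state
        # Stop once the CDF passes u, or when k hits n (guards against rounding).
        return (u > cdf) & (k < n)

    def body(state):
        k, pmf, cdf = state
        kf = k.astype(u.dtype)
        # P(K = k + 1) = P(K = k) * (n - k) / (k + 1) * q / (1 - q)
        pmf = pmf * (n_f - kf) / (kf + 1.0) * ratio
        return k + 1, pmf, cdf + pmf

    k, _, _ = lax.while_loop(cond, body, (jnp.zeros_like(n), pmf0, pmf0))
    return jnp.where(flip, n - k, k)


@jax.jit
def binomial_inversion(key, n, p):
    """Hypothetical batched sampler: n and p may be traced arrays of any shape."""
    n, p = jnp.broadcast_arrays(jnp.asarray(n, jnp.int32), jnp.asarray(p, jnp.float32))
    keys = random.split(key, n.size)  # n.size is static under jit
    samples = jax.vmap(_binomial_inversion_scalar)(keys, n.ravel(), p.ravel())
    return samples.reshape(n.shape)
```

Under vmap the while_loop keeps iterating until every element has finished, so the cost of a batch is set by its largest sample rather than by a fixed n_max padding.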
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../numpyro/distributions/distribution.py:47: in rvs
return _rvs(self, *args, **kwargs)
../numpyro/distributions/distribution.py:36: in _rvs
vals = _instance._rvs(*args)
../numpyro/distributions/discrete.py:22: in _rvs
return binomial(self._random_state, p, n, shape)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/api.py:107: in f_jitted
jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:541: in call_bind
ans = primitive.impl(f, *args, **kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:436: in xla_call_impl
compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:146: in memoized_fun
ans = call(f, *args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:448: in xla_callable
jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:86: in call_wrapped
ans = self.f(*args, **self.kwargs)
../numpyro/distributions/util.py:285: in binomial
uniforms = random.uniform(key, shape + (n_max,))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/api.py:107: in f_jitted
jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:544: in call_bind
ans = full_lower(top_trace.process_call(primitive, f, tracers, kwargs))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/partial_eval.py:85: in process_call
out_pv_const, consts = call_primitive.bind(fun, *in_consts, **params)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:541: in call_bind
ans = primitive.impl(f, *args, **kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:436: in xla_call_impl
compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:146: in memoized_fun
ans = call(f, *args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:448: in xla_callable
jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:86: in call_wrapped
ans = self.f(*args, **self.kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/random.py:240: in uniform
bits = _random_bits(key, nbits, shape)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/random.py:199: in _random_bits
if max_count >= onp.iinfo(onp.uint32).max:
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:256: in __bool__
def __bool__(self): return self.aval._bool(self)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = ShapedArray(bool[])
args = (Traced<ShapedArray(bool[]):JaxprTrace(level=-1/2)>,)
def error(self, *args):
> raise TypeError(concretization_err_msg(fun))
E TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/abstract_arrays.py:38: TypeError
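The root cause can be reproduced in isolation. In this minimal sketch (draw_padded_uniforms is a hypothetical name), the output length depends on the value of n; eager execution sees a concrete value, but under jit n is an abstract tracer. The example forces the conversion with int(), which in current JAX versions raises a TypeError subclass under jit; the exact message differs from the older traceback above.

```python
import jax
import jax.numpy as jnp
from jax import random


def draw_padded_uniforms(key, n):
    # The output length depends on the *value* of n, not only on its shape.
    n_max = int(jnp.max(n))
    return random.uniform(key, (n_max,))


key = random.PRNGKey(0)
n = jnp.array([2, 5, 3])

print(draw_padded_uniforms(key, n).shape)  # (5,): eager mode sees concrete values

try:
    jax.jit(draw_padded_uniforms)(key, n)
    jit_failed = False
except TypeError:
    # Under jit, n is an abstract tracer, so int() cannot produce a Python integer.
    jit_failed = True
print(jit_failed)  # True
```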
Created 5 years ago.
@hawkinsp It would be nice to have the functions jax.random.binomial and jax.random.multinomial, like numpy.random.binomial and numpy.random.multinomial. Are there any current plans to add them?

Thanks for the suggestion. I don't think that would work as is, because the size argument to random.uniform itself isn't static: it depends on n_max = np.max(n). But if we have both n and shape as static_args, it should work fine! This might not be ideal as a default decorator, though: in cases where n can take many possible values, e.g. if it's coming from a data minibatch, we would probably end up spending a lot of time re-jitting and not reclaim any performance benefit in return. Another, not ideal, strategy might be to set n_max to some high value, or take it in as a static function argument, so as not to force recompilation as often.
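A minimal sketch of that last strategy, assuming n_max is passed as a static argument through jax.jit's static_argnames and that the caller guarantees n_max >= max(n) (larger counts would be silently truncated). The name binomial_masked is hypothetical, not a JAX or NumPyro API.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import random


@partial(jax.jit, static_argnames=("n_max", "shape"))
def binomial_masked(key, p, n, n_max, shape=()):
    # p and n can be traced; only n_max and shape must be static (hashable).
    p, n = jnp.broadcast_arrays(jnp.asarray(p, jnp.float32), jnp.asarray(n, jnp.int32))
    shape = shape or p.shape
    # Fixed-length draw: n_max trials per sample, known at trace time.
    uniforms = random.uniform(key, shape + (n_max,))
    # Only the first n trials of each sample count; assumes n <= n_max.
    mask = jnp.arange(n_max) < n[..., None]
    successes = (uniforms < p[..., None]) & mask
    return successes.sum(axis=-1)
```

Calling it again with different values in n but the same n_max and shape reuses the compiled function; each new n_max or shape triggers one recompilation, which is the trade-off described above.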