
[Question] Efficiently sampling from the binomial distribution

See original GitHub issue

I currently have a naive binomial sampler which uses jax.random.bernoulli under the hood:

import jax.numpy as np
from jax import lax, random

# _promote_shapes is a numpyro broadcasting helper defined elsewhere
# in the codebase.
def binomial(key, p, n=1, shape=()):
    p, n = _promote_shapes(p, n)
    shape = shape or lax.broadcast_shapes(np.shape(p), np.shape(n))
    n_max = np.max(n)
    uniforms = random.uniform(key, shape + (n_max,))
    n = np.expand_dims(n, axis=-1)
    p = np.expand_dims(p, axis=-1)
    # Count only the first n of the n_max trials for each element.
    mask = (np.arange(n_max) < n).astype(uniforms.dtype)
    p, uniforms = _promote_shapes(p, uniforms)
    return np.sum(mask * lax.lt(uniforms, p), axis=-1, keepdims=False)

This works, but the biggest drawback is that it is not jittable, due to the dynamic size argument passed to random.uniform (see the error trace below). It is also wasteful in that it draws n_max uniform random floats per element and then does an element-wise multiply with mask. Is there a more efficient way to sample from the binomial distribution using existing primitives? I am new to JAX, so any insights on improving this implementation would be really helpful.

error trace
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../numpyro/distributions/distribution.py:47: in rvs
    return _rvs(self, *args, **kwargs)
../numpyro/distributions/distribution.py:36: in _rvs
    vals = _instance._rvs(*args)
../numpyro/distributions/discrete.py:22: in _rvs
    return binomial(self._random_state, p, n, shape)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/api.py:107: in f_jitted
    jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:541: in call_bind
    ans = primitive.impl(f, *args, **kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:436: in xla_call_impl
    compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:146: in memoized_fun
    ans = call(f, *args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:448: in xla_callable
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:86: in call_wrapped
    ans = self.f(*args, **self.kwargs)
../numpyro/distributions/util.py:285: in binomial
    uniforms = random.uniform(key, shape + (n_max,))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/api.py:107: in f_jitted
    jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:544: in call_bind
    ans = full_lower(top_trace.process_call(primitive, f, tracers, kwargs))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/partial_eval.py:85: in process_call
    out_pv_const, consts = call_primitive.bind(fun, *in_consts, **params)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:541: in call_bind
    ans = primitive.impl(f, *args, **kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:436: in xla_call_impl
    compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:146: in memoized_fun
    ans = call(f, *args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:448: in xla_callable
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:86: in call_wrapped
    ans = self.f(*args, **self.kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/random.py:240: in uniform
    bits = _random_bits(key, nbits, shape)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/random.py:199: in _random_bits
    if max_count >= onp.iinfo(onp.uint32).max:
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:256: in __bool__
    def __bool__(self): return self.aval._bool(self)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = ShapedArray(bool[])
args = (Traced<ShapedArray(bool[]):JaxprTrace(level=-1/2)>,)

    def error(self, *args):
>     raise TypeError(concretization_err_msg(fun))
E     TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/abstract_arrays.py:38: TypeError
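As an aside on the efficiency point above: one way to avoid materializing the full shape + (n_max,) uniform array is to accumulate n Bernoulli draws with lax.scan. This is only a sketch, not numpyro's implementation, and it assumes n is passed statically (known at trace time); it trades the big intermediate array for sequential draws.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

def binomial_scan(key, p, n, shape=()):
    keys = random.split(key, n)  # one subkey per trial

    def body(total, k):
        # One Bernoulli(p) draw of the requested shape per trial.
        draw = random.bernoulli(k, p, shape).astype(jnp.int32)
        return total + draw, None

    total, _ = lax.scan(body, jnp.zeros(shape, jnp.int32), keys)
    return total

key = random.PRNGKey(0)
samples = jax.jit(binomial_scan, static_argnums=(2, 3))(key, 0.5, 8, (4,))
```

This keeps memory at O(shape) instead of O(shape * n_max), at the cost of n sequential PRNG calls, so it only pays off for large n.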

Issue Analytics

  • State: closed
  • Created: 5 years ago
  • Reactions: 1
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

4 reactions
carlosgmartin commented, Oct 9, 2022

@hawkinsp It would be nice to have the functions jax.random.binomial and jax.random.multinomial, like numpy.random.binomial and numpy.random.multinomial. Are there any current plans to add them?
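(For readers landing here later: recent JAX releases do ship jax.random.binomial. A minimal usage sketch, under the assumption that you are on a JAX 0.4-series version where the function exists; older versions do not have it.)

```python
import jax
from jax import random

key = random.PRNGKey(0)
# 10 trials, success probability 0.5, three independent samples.
samples = random.binomial(key, n=10.0, p=0.5, shape=(3,))
```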

2 reactions
neerajprad commented, Mar 9, 2019

The usage would be something like sampler = jit(binomial, static_argnums=(3,)).

Thanks for the suggestion. I don’t think that would work as-is, because the size argument to random.uniform itself isn’t static: it depends on n_max = np.max(n). If we make both n and shape static arguments, though, it should work fine. This might not be ideal as a default decorator, because in cases where n can take many possible values, e.g. if it’s coming from a data minibatch, we would probably spend a lot of time re-jitting without reclaiming any performance benefit in return. Another (not ideal) strategy would be to fix n_max at some high value, or take it as a static function argument, so as not to force recompilation as often.
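The "make both n and shape static" variant discussed above can be sketched as follows (simplified to a scalar n, with no per-element masking, so the names and signature here are illustrative rather than numpyro's actual code):

```python
import jax
import jax.numpy as jnp
from jax import random

def binomial(key, p, n=1, shape=()):
    # With n and shape static, shape + (n,) is concrete at trace time,
    # so random.uniform compiles under jit without the abstract-bool error.
    uniforms = random.uniform(key, shape + (n,))
    return jnp.sum(uniforms < p, axis=-1)

sampler = jax.jit(binomial, static_argnums=(2, 3))
key = random.PRNGKey(0)
samples = sampler(key, 0.3, 10, (5,))  # five draws from Binomial(10, 0.3)
```

Note that every new (n, shape) pair triggers a fresh compilation, which is exactly the recompilation cost the comment above warns about.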


