"logit" parameter of jax.random.categorical is misnamed, actually a log probability
summary
jax.random.categorical (introduced in #1855) takes a parameter, logits, to specify the categorical distribution to sample from. I believe this parameter is misnamed: it behaves like a log probability, not a logit.
potential resolutions
- rename the parameter to logprob or something similar
- change the implementation and tests to behave like a logit
- change to something less ambiguous, like a probability p
background
Maybe people use "logit" in different ways, but according to Wikipedia, logit(p) = ln(p / (1 - p)). This is also the behavior implemented in scipy.special.logit (and jax.scipy.special.logit).
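For reference, a minimal plain-Python sketch of that definition and its inverse (the sigmoid), mirroring what scipy.special.logit computes:

```python
import math

def logit(p):
    # logit(p) = ln(p / (1 - p)), the log-odds of p
    return math.log(p / (1 - p))

def sigmoid(x):
    # inverse of logit: maps log-odds back to a probability
    return 1.0 / (1.0 + math.exp(-x))

assert logit(0.5) == 0.0                        # even odds -> log-odds of 0
assert abs(sigmoid(logit(0.3)) - 0.3) < 1e-12   # round-trips back to p
```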
example / demo
import jax
import jax.numpy as np
key = jax.random.PRNGKey(seed=1)
p = np.array([0.4, 0.6])
n = 10000
arg = np.log(p / (1 - p))
# or:
# arg = jax.scipy.special.logit(p)
samples = jax.random.categorical(key, logits=arg, shape=(n,))
print(np.unique(samples, return_counts=True))
# (DeviceArray([0, 1], dtype=int32), DeviceArray([3059, 6941], dtype=int32))
arg = np.log(p)
samples = jax.random.categorical(key, logits=arg, shape=(n,))
print(np.unique(samples, return_counts=True))
# (DeviceArray([0, 1], dtype=int32), DeviceArray([3998, 6002], dtype=int32))
As you can see, the counts for log(p) follow the given distribution, but the counts using logit(p) do not.
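The observed counts can be checked by hand: the sampler draws from softmax of whatever argument it is given. A small standalone sketch (plain Python, no JAX) showing what each argument actually encodes:

```python
import math

def softmax(xs):
    # numerically stable softmax: shift by the max before exponentiating
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

p = [0.4, 0.6]

# Passing logit(p) = ln(p / (1 - p)) as the argument:
logit_p = [math.log(q / (1 - q)) for q in p]
print(softmax(logit_p))  # roughly [0.308, 0.692], i.e. [4/13, 9/13] -- not p

# Passing ln(p) as the argument:
log_p = [math.log(q) for q in p]
print(softmax(log_p))    # roughly [0.4, 0.6] -- recovers p
```

The first case matches the observed counts of 3059/6941, and the second matches 3998/6002, confirming that the parameter is interpreted as a log probability.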
Issue Analytics
- State:
- Created 3 years ago
- Comments: 7 (7 by maintainers)
If you like the conventions of TensorFlow Probability, note that you can already use TFP's Categorical with JAX. We probably don't advertise this well enough!

Alternatively, you can convert from probabilities to unnormalized log-probabilities just by using log(), e.g., jax.random.categorical(key, logits=jnp.log(p), shape=(n,)). I would prefer to document either or both of these approaches in jax.random.categorical rather than to add a redundant argument.

Thanks! Yup, I was going to close.

Thanks @shoyer for the FYI about TFP/JAX. Very useful!
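A footnote on why jnp.log(p) needs no normalization: softmax is invariant under adding a constant to every entry, so any unnormalized log-probabilities define the same distribution. A minimal check in plain Python:

```python
import math

def softmax(xs):
    # shift by the max for numerical stability; the shift cancels out
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

p = [0.2, 0.3, 0.5]
log_p = [math.log(q) for q in p]         # normalized log-probabilities
unnormalized = [x + 7.0 for x in log_p]  # shifted by an arbitrary constant

# Both arguments yield the same categorical distribution
assert all(abs(a - b) < 1e-12 for a, b in zip(softmax(log_p), softmax(unnormalized)))
```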