"logit" parameter of jax.random.categorical is misnamed, actually a log probability

summary

jax.random.categorical (introduced in #1855) takes a parameter, logits, that specifies the categorical distribution to sample from.

I believe this parameter is misnamed. It behaves like a log probability, not a logit.

potential resolutions

  1. rename the parameter to logprob or something similar
  2. change implementation and tests to behave like a logit
  3. change to something less ambiguous, like a probability p

background

Maybe people use “logit” in different ways, but according to Wikipedia, logit(p) = ln( p / (1 - p) ). This is also the behavior implemented in scipy.special.logit (and jax.scipy.special.logit).
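To make the distinction concrete, here is a minimal sketch using only the Python standard library. The helper names `logit` and `sigmoid` are illustrative and not taken from the issue. It shows that the logit is the inverse of the sigmoid, and that it differs from the plain log probability of the same p.

```python
import math

def logit(p):
    """Log-odds of a probability p in (0, 1): ln(p / (1 - p))."""
    return math.log(p / (1.0 - p))

def sigmoid(x):
    """Inverse of logit: maps a log-odds value back to a probability."""
    return 1.0 / (1.0 + math.exp(-x))

p = 0.4
log_odds = logit(p)      # ln(0.4 / 0.6), about -0.405
log_prob = math.log(p)   # ln(0.4), about -0.916

print(log_odds, log_prob)
print(sigmoid(log_odds))  # recovers p up to floating-point error
```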

example / demo

import jax
import jax.numpy as np

key = jax.random.PRNGKey(seed=1)

p = np.array([0.4, 0.6])
n = 10000

arg = np.log(p / (1 - p))
# or:
# arg = jax.scipy.special.logit(p)
samples = jax.random.categorical(key, logits=arg, shape=(n,))

print(np.unique(samples, return_counts=True))
# (DeviceArray([0, 1], dtype=int32), DeviceArray([3059, 6941], dtype=int32))

arg = np.log(p)
samples = jax.random.categorical(key, logits=arg, shape=(n,))

print(np.unique(samples, return_counts=True))
# (DeviceArray([0, 1], dtype=int32), DeviceArray([3998, 6002], dtype=int32))

As you can see, the counts for log(p) follow the given distribution, but the counts for logit(p) do not.
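The observed counts can be roughly predicted by hand. The sketch below assumes that jax.random.categorical treats its argument as unnormalized log probabilities, that is, that it effectively applies a softmax. This is the behavior the issue describes, not something verified here against the implementation. The `softmax` helper is written out in plain Python for illustration.

```python
import math

def softmax(values):
    """Normalize exp(values) so the results sum to 1."""
    m = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

p = [0.4, 0.6]

# Passing logit(p): exp gives the odds [2/3, 3/2], which normalize to [4/13, 9/13].
from_logit = softmax([math.log(q / (1.0 - q)) for q in p])

# Passing log(p): exp gives p itself, which is already normalized.
from_log = softmax([math.log(q) for q in p])

print(from_logit)  # roughly [0.308, 0.692]
print(from_log)    # roughly [0.4, 0.6]
```

Under that assumption, the predicted probabilities are close to the sampled frequencies in the demo: 3059/10000 and 6941/10000 for logit(p), and 3998/10000 and 6002/10000 for log(p).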

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments:7 (7 by maintainers)

Top GitHub Comments

2 reactions
shoyer commented, Apr 29, 2020

If you like the conventions of TensorFlow Probability, note that you can already use TFP’s Categorical with JAX. We probably don’t advertise this well enough!

Alternatively, you can convert from probabilities to unnormalized log-probabilities just by using log(), e.g., jax.random.categorical(key, logits=jnp.log(p), shape=(n,)).

I would prefer to document either or both of these approaches in jax.random.categorical rather than to add a redundant argument.
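A short sketch of what "unnormalized" means in the comment above, assuming the sampler normalizes its input with a softmax. The helper and the example values are illustrative. Adding any constant to every log probability, or starting from weights that do not sum to 1, leaves the resulting distribution unchanged.

```python
import math

def softmax(values):
    """Normalize exp(values) so the results sum to 1."""
    m = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

weights = [2.0, 3.0]                      # unnormalized weights, sum is 5
base = [math.log(w) for w in weights]     # unnormalized log probabilities
shifted = [v + 7.5 for v in base]         # same values plus an arbitrary constant

print(softmax(base))     # roughly [0.4, 0.6]
print(softmax(shifted))  # the same probabilities as above
```

This is why log(p) works directly as the logits argument: it is simply the special case where the weights already sum to 1.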

1 reaction
grisaitis commented, Apr 30, 2020

thanks! yup, i was going to close.

thanks @shoyer about the fyi with tfp/jax. very useful!
