Incorrect results when sampling from the prior
While going through Statistical Rethinking I wanted to run a prior-predictive simulation, but the results did not match the textbook example; see below.
What's more, I played with some other synthetic examples, and they also give unintuitive results; see further down.
Examples
Example from Statistical Rethinking
Code
```python
import seaborn as sns
import matplotlib.pyplot as plt
import jax

import mcx
from mcx import distributions as dist
from mcx import sample_joint


# Height model from Statistical Rethinking;
# <~ is mcx's "sample from" operator inside @mcx.model functions.
@mcx.model
def model():
    μ <~ dist.Normal(178, 20)
    σ <~ dist.Uniform(0, 50)
    h <~ dist.Normal(μ, σ)
    return h


rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=model,
    model_args=(),
    num_samples=10_000,
)

# Plot the marginal densities of the three sampled variables.
fig, axes = plt.subplots(2, 2, figsize=(7, 5), dpi=128)
axes = axes.reshape(-1)
sns.kdeplot(prior_predictive["μ"], ax=axes[0])
sns.kdeplot(prior_predictive["σ"], ax=axes[1])
sns.kdeplot(prior_predictive["h"], ax=axes[2])
plt.tight_layout()
```
Result
(plot not preserved)
Expected
(plot not preserved)
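For comparison, here is a minimal sketch of the same prior-predictive simulation written directly against `jax.random`, with no mcx involved (my own reimplementation, not part of the original report):

```python
import jax

num_samples = 10_000
rng_key = jax.random.PRNGKey(0)
k_mu, k_sigma, k_h = jax.random.split(rng_key, 3)

# Ancestral sampling: draw the parents first, then the child.
mu = 178.0 + 20.0 * jax.random.normal(k_mu, (num_samples,))                   # μ ~ Normal(178, 20)
sigma = jax.random.uniform(k_sigma, (num_samples,), minval=0.0, maxval=50.0)  # σ ~ Uniform(0, 50)
h = mu + sigma * jax.random.normal(k_h, (num_samples,))                       # h ~ Normal(μ, σ)
```

Plotted with the same `sns.kdeplot` calls, this gives a broad density for h centered around 178, which is what the book shows.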
Synthetic example 1
In this example I sample an offset from Uniform(0, 1), then sample from Uniform(12 - offset, 12 + offset). Since the offset is at most 1, I expect my samples to be distributed in the range [11, 13], but I get samples in the range [-15, 15].
Code
```python
import seaborn as sns
import matplotlib.pyplot as plt
import jax

import mcx
from mcx import distributions as dist
from mcx import sample_joint


@mcx.model
def example_1():
    center = 12
    offset <~ dist.Uniform(0, 1)
    low = center - offset
    high = center + offset
    outcome <~ dist.Uniform(low, high)


rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=example_1,
    model_args=(),
    num_samples=10_000,
)

ax = sns.kdeplot(prior_predictive["outcome"])
ax.set_title("Outcome")
```
Result
(plot not preserved)
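As a sanity check, the same generative process written directly with `jax.random` (again my own sketch, not mcx's output) keeps every draw inside [11, 13]:

```python
import jax

rng_key = jax.random.PRNGKey(0)
k_offset, k_outcome = jax.random.split(rng_key)

offset = jax.random.uniform(k_offset, (10_000,))  # offset ~ Uniform(0, 1)
# minval/maxval broadcast against the shape, so each draw gets its own bounds.
outcome = jax.random.uniform(
    k_outcome, (10_000,),
    minval=12.0 - offset, maxval=12.0 + offset,
)

print(float(outcome.min()), float(outcome.max()))  # both within [11, 13]
```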
Synthetic example 2
This is the same example as above, but the center variable is passed as an argument instead of being hardcoded. The results are different (although still not in the range [11, 13]).
Code
```python
import seaborn as sns
import matplotlib.pyplot as plt
import jax

import mcx
from mcx import distributions as dist
from mcx import sample_joint


@mcx.model
def example_2(center):
    offset <~ dist.Uniform(0, 1)
    low = center - offset
    high = center + offset
    outcome <~ dist.Uniform(low, high)


rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=example_2,
    model_args=(12,),
    num_samples=10_000,
)

ax = sns.kdeplot(prior_predictive["outcome"])
ax.set_title("Outcome")
```
Result
(plot not preserved)
Expectation
For examples 1 and 2, here's what I'd expect to get:
(plot not preserved)
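To be concrete, here's a sketch of how the expected plot can be produced with plain `jax.random` and seaborn (my own code, assuming only the model as written above):

```python
import seaborn as sns
import matplotlib.pyplot as plt
import jax

rng_key = jax.random.PRNGKey(0)
k_offset, k_outcome = jax.random.split(rng_key)

center = 12.0
offset = jax.random.uniform(k_offset, (10_000,))  # offset ~ Uniform(0, 1)
outcome = jax.random.uniform(                     # outcome ~ Uniform(center - offset, center + offset)
    k_outcome, (10_000,),
    minval=center - offset, maxval=center + offset,
)

ax = sns.kdeplot(outcome)  # a symmetric bump on [11, 13], peaked at 12
ax.set_title("Outcome")
plt.show()
```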
Environment
```
Linux-5.8.0-44-generic-x86_64-with-glibc2.10
Python 3.8.5 (default, Sep 4 2020, 07:30:14) [GCC 7.3.0]
JAX 0.2.8
NetworkX 2.5
JAXlib 0.1.58
mcx 2a2b94801e68d94d86826863eeee80f0b84c390d
```
Comments
Thanks for the update! Looking forward to the NUTS sampler as well.
I’ve decided to first go through the theory, and then make a second pass implementing the examples.
I’ve just finished the book, so I’m going back to the code, which hopefully should go faster now.
There were a few places in the book where some advanced Stan features were used. I'm a bit worried about those, but we'll see how it goes.
Great! If you remember which ones, don't hesitate to open issues now.