
Incorrect results when sampling from the prior


While going through Statistical Rethinking, I wanted to run a prior-predictive simulation, but the results did not match the textbook example (see below).

What’s more, I tried some other synthetic examples, and they also give unintuitive results (see further down).

Examples

Example from Statistical Rethinking

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def model():
    μ <~ dist.Normal(178, 20)
    σ <~ dist.Uniform(0, 50)
    h <~ dist.Normal(μ, σ)
    
    return h

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=model, 
    model_args=(), 
    num_samples=10_000
)

fig, axes = plt.subplots(2, 2, figsize=(7, 5), dpi=128)
axes = axes.reshape(-1)

sns.kdeplot(prior_predictive["μ"], ax=axes[0])
sns.kdeplot(prior_predictive["σ"], ax=axes[1])
sns.kdeplot(prior_predictive["h"], ax=axes[2])

plt.tight_layout()

Result

[Image: KDE plots of the sampled μ, σ and h]

Expected

[Image: the corresponding prior-predictive plots from the textbook]
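For comparison, here is a minimal sketch of the same prior-predictive simulation done by plain ancestral sampling outside mcx: draw μ, then σ, then h given both. It uses Python's standard-library random module rather than JAX so that it has no dependencies, sample_prior_predictive is a hypothetical helper name, and it assumes the second argument of dist.Normal is a standard deviation, as with dnorm(178, 20) in the book.

```python
import math
import random
import statistics


def sample_prior_predictive(num_samples, seed=0):
    """Ancestral sampling of mu ~ N(178, 20), sigma ~ U(0, 50), h ~ N(mu, sigma).

    Hypothetical helper; assumes the Normal scale parameter is a standard deviation.
    """
    rng = random.Random(seed)
    draws = {"mu": [], "sigma": [], "h": []}
    for _ in range(num_samples):
        mu = rng.gauss(178, 20)      # prior on the mean height
        sigma = rng.uniform(0, 50)   # prior on the spread
        h = rng.gauss(mu, sigma)     # height given mu and sigma
        draws["mu"].append(mu)
        draws["sigma"].append(sigma)
        draws["h"].append(h)
    return draws


draws = sample_prior_predictive(10_000)
# Analytic standard deviation of h: sqrt(Var(mu) + E[sigma^2]) = sqrt(20**2 + 50**2 / 3)
expected_sd = math.sqrt(20**2 + 50**2 / 3)
print(statistics.mean(draws["h"]), statistics.stdev(draws["h"]), expected_sd)
```

Since h = μ + σz with z standard normal and independent of μ and σ, its prior mean is 178 and its variance is 20² + E[σ²] = 400 + 50²/3 ≈ 1233, i.e. a standard deviation of about 35.1. A KDE of h from a correct sampler should be centred near 178 with roughly that spread.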

Synthetic example 1

In this example, I sample an offset from Uniform(0, 1) and then sample the outcome from Uniform(12 - offset, 12 + offset). So I expect my samples to lie in the range [11, 13], but I get samples in the range [-15, 15].

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def example_1():
    
    center = 12
    offset <~ dist.Uniform(0, 1)
    
    low = (center - offset)
    high = (center + offset)
    
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=example_1, 
    model_args=(), 
    num_samples=10_000
)


ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");

Result

[Image: KDE plot titled "Outcome"]
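The same kind of cross-check for example 1, again a sketch using the standard-library random module, with a hypothetical helper name (sample_example_1) and assuming dist.Uniform(low, high) is the continuous uniform on [low, high]. Every draw should land in [11, 13]:

```python
import math
import random


def sample_example_1(num_samples, seed=0):
    """Ancestral sampling of offset ~ U(0, 1), outcome ~ U(12 - offset, 12 + offset)."""
    rng = random.Random(seed)
    center = 12  # hard-coded, as in example_1
    outcomes = []
    for _ in range(num_samples):
        offset = rng.uniform(0, 1)
        low, high = center - offset, center + offset
        outcomes.append(rng.uniform(low, high))
    return outcomes


outcomes = sample_example_1(10_000)
print(min(outcomes), max(outcomes))  # both should lie inside [11, 13]

# Fraction of draws within 0.5 of the center; analytically 0.5 * (1 + ln 2), about 0.847
near = sum(abs(x - 12) < 0.5 for x in outcomes) / len(outcomes)
expected_near = 0.5 * (1 + math.log(2))
print(near, expected_near)
```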

Synthetic example 2

This is the same example as above, but the center variable is passed as an argument instead of being hard-coded, and the results are different (although still not in the range [11, 13]).

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def example_2(center):
    
    offset <~ dist.Uniform(0, 1)
    
    low = (center - offset)
    high = (center + offset)
    
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=example_2, 
    model_args=(12, ), 
    num_samples=10_000
)


ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");

Result

[Image: KDE plot titled "Outcome"]
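For example 2, the only change is that center comes from the caller. Under plain ancestral sampling that can only shift the distribution: with the same seed, the draws for center = 12 are the draws for center = 0 moved by 12. A sketch with the standard-library random module and a hypothetical helper name (sample_example_2):

```python
import math
import random


def sample_example_2(center, num_samples, seed=0):
    """Same process as example_1, but with center supplied by the caller."""
    rng = random.Random(seed)
    outcomes = []
    for _ in range(num_samples):
        offset = rng.uniform(0, 1)
        outcomes.append(rng.uniform(center - offset, center + offset))
    return outcomes


# Same seed, so both runs consume the same random numbers;
# changing center should only translate the draws.
at_12 = sample_example_2(12, 10_000)
at_0 = sample_example_2(0, 10_000)
shift_ok = all(math.isclose(a, b + 12, abs_tol=1e-9) for a, b in zip(at_12, at_0))
print(min(at_12), max(at_12), shift_ok)
```

So passing center = 12 as an argument should give exactly the same support, [11, 13], as hard-coding it.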

Expectation

For examples 1 and 2, here’s what I’d expect to get:

[Image: expected KDE plot of outcome]
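The expected shape can also be worked out by hand. Assuming dist.Uniform(low, high) is the continuous uniform density on [low, high], marginalizing out the offset o gives:

```latex
\begin{aligned}
o &\sim \mathrm{Uniform}(0, 1), \qquad x \mid o \sim \mathrm{Uniform}(12 - o,\, 12 + o) \\
p(x) &= \int_0^1 \frac{\mathbf{1}\{|x - 12| < o\}}{2o}\, do
      = \int_{|x - 12|}^{1} \frac{do}{2o}
      = -\tfrac{1}{2}\ln|x - 12|, \qquad 11 < x < 13 \\
P(|x - 12| < d) &= -\int_0^d \ln u \, du = d\,(1 - \ln d), \qquad 0 < d \le 1
\end{aligned}
```

So the outcome density is zero outside [11, 13] and has an integrable logarithmic peak at 12 (setting d = 1 confirms it integrates to 1). With center = 12 passed as an argument, example 2 has exactly the same distribution.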

Environment

Linux-5.8.0-44-generic-x86_64-with-glibc2.10
Python 3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0]
JAX 0.2.8
NetworkX 2.5
JAXlib 0.1.58
mcx 2a2b94801e68d94d86826863eeee80f0b84c390d

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 13 (7 by maintainers)

Top GitHub Comments

1 reaction
elanmart commented, May 5, 2021

Thanks for the update! Looking forward to the NUTS sampler as well.

I’ve decided to first go through the theory, and then make a second pass implementing the examples.

I’ve just finished the book, so I’m going back to the code, which hopefully should go faster now.

There were a few places in the book where some advanced Stan features were used. I’m a bit worried about those, but we’ll see how it goes.

0 reactions
rlouf commented, May 6, 2021

Great! If you remember which ones, don’t hesitate to open issues now.
