numpyro and blackjax samplers producing different results

Bug Description

I followed the use-with-numpyro notebook to take a model that works with numpyro's sampler and run it with blackjax instead. The model runs (quickly), but the values it produces are well off. I suspect this is due to a poor choice of step size and mass matrix.

Steps/Code to Reproduce

The code requires external data, so it is best to clone the repo if the problem isn't immediately solvable. The working numpyro code and the blackjax attempt are linked in the original issue; the numpyro model and the blackjax code are also reproduced below.

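from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer.util import initialize_model

# BlackJAX 0.2.x exposes NUTS and the Stan-style warmup as modules
import blackjax.nuts as nuts
import blackjax.stan_warmup as stan_warmup

# `train` is a pandas DataFrame with Home_id, Away_id, score1 and score2 columns,
# loaded from the external data mentioned above (not shown here).


def model(home_id, away_id, score1_obs=None, score2_obs=None):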
    # priors
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 1.0))
    sd_att = numpyro.sample(
        "sd_att",
        dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5)),
    )
    sd_def = numpyro.sample(
        "sd_def",
        dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5)),
    )

    home = numpyro.sample("home", dist.Normal(0.0, 1.0))  # home advantage

    nt = len(np.unique(home_id))

    # team-specific model parameters
    with numpyro.plate("plate_teams", nt):
        attack = numpyro.sample("attack", dist.Normal(0, sd_att))
        defend = numpyro.sample("defend", dist.Normal(0, sd_def))

    # likelihood
    theta1 = jnp.exp(alpha + home + attack[home_id] - defend[away_id])
    theta2 = jnp.exp(alpha + attack[away_id] - defend[home_id])

    with numpyro.plate("data", len(home_id)):
        numpyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        numpyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)


rng_key = random.PRNGKey(0)

# translate the model into a log-probability function
init_params, potential_fn_gen, *_ = initialize_model(
    rng_key,
    model,
    model_args=(
        train["Home_id"].values,
        train["Away_id"].values,
        train["score1"].values,
        train["score2"].values,
    ),
    dynamic_args=True,
)

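# numpyro's potential function is the negative log-density of the
# *unconstrained* parameters, so negate it to get blackjax's logprob
logprob = lambda position: -potential_fn_gen(
    train["Home_id"].values,
    train["Away_id"].values,
    train["score1"].values,
    train["score2"].values,
)(position)

# init_params.z holds the initial position in unconstrained space
initial_position = init_params.z
initial_state = nuts.new_state(initial_position, logprob)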

# run the window adaptation (warmup)
kernel_factory = lambda step_size, inverse_mass_matrix: nuts.kernel(
    logprob, step_size, inverse_mass_matrix
)

last_state, (step_size, inverse_mass_matrix), _ = stan_warmup.run(
    rng_key, kernel_factory, initial_state, 1000
)


@partial(jax.jit, static_argnums=(1, 3))
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos


# Build the kernel using the step size and inverse mass matrix returned from the window adaptation
kernel = kernel_factory(step_size, inverse_mass_matrix)

# Sample from the posterior distribution
states, infos = inference_loop(rng_key, kernel, last_state, 100_000)

Expected Results

For example, the “home” parameter should be around 0.2-0.3. The “sd_att” and “sd_def” parameters should be positive, since the model gives them FoldedDistribution() priors.
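For reference, a folded distribution is the distribution of |X| for X drawn from the base distribution, which is why draws of “sd_att” and “sd_def” should never be negative. A minimal numpy sketch of the folded Student-t prior used in the model; numpy's `standard_t` stands in for numpyro's `StudentT` here, so this only illustrates the folding, not numpyro's internals:

```python
import numpy as np

rng = np.random.default_rng(0)

# Draws from StudentT(df=3, loc=0, scale=2.5), the base of the sd_att / sd_def priors
base_draws = 0.0 + 2.5 * rng.standard_t(df=3.0, size=10_000)

# Folding a distribution means taking |X|, so every folded draw is non-negative
folded_draws = np.abs(base_draws)

print(f"min={folded_draws.min():.4f}, median={np.median(folded_draws):.3f}")
```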

Actual Results

The parameter values are well off. [Screenshot 2021-10-29 at 16:28:58 showing the sampled values]

Versions

BlackJAX 0.2.1
numpyro 0.7.2
Python 3.8.0 | packaged by conda-forge | (default, Nov 22 2019, 19:11:19) [Clang 9.0.0 (tags/RELEASE_900/final)]
Jax 0.2.17
Jaxlib 0.1.67

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 8 (6 by maintainers)

Top GitHub Comments

1 reaction
rlouf commented, Jun 10, 2022

Yay, great news! The values are negative because numpyro transforms the variables so that their values lie between -infinity and +infinity before sampling (NUTS works much better with values like these). The log-density that numpyro returns works with these “unconstrained” variables, so this is what blackjax returns. To get the values of the untransformed variables, you therefore need to apply the inverse of the transformation that numpyro used, which I guess here is the absolute value.

Does that make sense?
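A minimal numpy sketch of that back-transformation. It assumes the draws are a dict of arrays keyed by site name (the same structure as `init_params.z`), and it hard-codes which sites have positive support. numpyro picks each site's transform from its support via `biject_to`; for the positive support of `FoldedDistribution` this is an exponential, so the sampler works with log(x) and the map back is `exp` rather than an absolute value. Sites with real support (`home`, `alpha`, `attack`, `defend`) are left unchanged. The draws below are made-up toy numbers, not results from the issue.

```python
import numpy as np

# Sites whose prior has positive support (FoldedDistribution in the model above).
# Assumption: numpyro maps positive support to the real line with log/exp,
# so the sampler sees z = log(x) and the constrained value is x = exp(z).
POSITIVE_SITES = ("sd_att", "sd_def")


def to_constrained(unconstrained):
    """Map a dict of unconstrained draws back to the model's parameter space."""
    return {
        name: np.exp(values) if name in POSITIVE_SITES else values
        for name, values in unconstrained.items()
    }


# Toy unconstrained draws standing in for the sampler output (made-up numbers)
draws = {
    "home": np.array([0.22, 0.27, 0.25]),
    "sd_att": np.array([-1.6, -1.2, -1.4]),
    "sd_def": np.array([-2.3, -1.9, -2.1]),
}

constrained = to_constrained(draws)
print({name: values.round(3) for name, values in constrained.items()})
```

Rather than hard-coding site names, numpyro's own `postprocess_fn` (the third value returned by `initialize_model`, discarded by `*_` in the snippet above) can, I believe, be called with the model arguments and mapped over the draws with `jax.vmap`, which applies the correct transform for every site.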

1 reaction
jeremiecoullon commented, Jun 9, 2022

Hello! I’m just wondering if this is still an issue, or has this been fixed by recent releases?
