numpyro and blackjax samplers producing different results
Bug Description
I followed the use-with-numpyro notebook to get a model that works with numpyro's sampler running on blackjax. The model runs (quickly), but the values it produces are well off. I suspect this is due to a poor choice of step size and mass matrix.
Steps/Code to Reproduce
The code requires external data, so it is best to clone the repo if the problem isn't immediately solvable. The working numpyro code is here and the attempt using blackjax is here. The numpyro model and the blackjax sampling code are also reproduced below.
import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(home_id, away_id, score1_obs=None, score2_obs=None):
    # priors
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 1.0))
    sd_att = numpyro.sample(
        "sd_att",
        dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5)),
    )
    sd_def = numpyro.sample(
        "sd_def",
        dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5)),
    )
    home = numpyro.sample("home", dist.Normal(0.0, 1.0))  # home advantage

    nt = len(np.unique(home_id))

    # team-specific model parameters
    with numpyro.plate("plate_teams", nt):
        attack = numpyro.sample("attack", dist.Normal(0, sd_att))
        defend = numpyro.sample("defend", dist.Normal(0, sd_def))

    # likelihood
    theta1 = jnp.exp(alpha + home + attack[home_id] - defend[away_id])
    theta2 = jnp.exp(alpha + attack[away_id] - defend[home_id])

    with numpyro.plate("data", len(home_id)):
        numpyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        numpyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)
from functools import partial

import jax
from jax import random
from numpyro.infer.util import initialize_model

# blackjax 0.2.x module layout
import blackjax.nuts as nuts
import blackjax.stan_warmup as stan_warmup

rng_key = random.PRNGKey(0)

# translate the model into a log-probability function
init_params, potential_fn_gen, *_ = initialize_model(
    rng_key,
    model,
    model_args=(
        train["Home_id"].values,
        train["Away_id"].values,
        train["score1"].values,
        train["score2"].values,
    ),
    dynamic_args=True,
)

logprob = lambda position: -potential_fn_gen(
    train["Home_id"].values,
    train["Away_id"].values,
    train["score1"].values,
    train["score2"].values,
)(position)

initial_position = init_params.z
initial_state = nuts.new_state(initial_position, logprob)

# run the window adaptation (warmup)
kernel_factory = lambda step_size, inverse_mass_matrix: nuts.kernel(
    logprob, step_size, inverse_mass_matrix
)
last_state, (step_size, inverse_mass_matrix), _ = stan_warmup.run(
    rng_key, kernel_factory, initial_state, 1000
)


@partial(jax.jit, static_argnums=(1, 3))
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos


# build the kernel using the step size and inverse mass matrix
# returned from the window adaptation
kernel = kernel_factory(step_size, inverse_mass_matrix)

# sample from the posterior distribution
states, infos = inference_loop(rng_key, kernel, last_state, 100_000)
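Before interpreting the draws, a quick sanity check on the adaptation output can help rule out the step-size suspicion (a minimal sketch; states.position is a dict keyed by site name, and the draws live in numpyro's unconstrained space):

print("adapted step size:", step_size)
print("mean 'home' draw (unconstrained):", states.position["home"].mean())
print("mean 'sd_att' draw (unconstrained):", states.position["sd_att"].mean())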
Expected Results
As an example, the “home” parameter should be around 0.2-0.3. The “sd_att” and “sd_def” parameters should be constrained by the model to be positive (via FoldedDistribution).
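FoldedDistribution wraps a base distribution in an absolute-value transform, so its support is the positive reals; a minimal sketch to confirm:

import jax
import numpyro.distributions as dist

d = dist.FoldedDistribution(dist.StudentT(3.0, 0.0, 2.5))
draws = d.sample(jax.random.PRNGKey(0), (1000,))
print((draws > 0).all())  # True: every draw is positive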
Actual Results
The parameter values are well off; for example, “sd_att” and “sd_def” come back negative, even though the model constrains them to be positive.
Versions
- BlackJAX 0.2.1
- numpyro 0.7.2
- Python 3.8.0 | packaged by conda-forge | (default, Nov 22 2019, 19:11:19) [Clang 9.0.0 (tags/RELEASE_900/final)]
- Jax 0.2.17
- Jaxlib 0.1.67
Top GitHub Comments
Yay, great news! The values are negative because numpyro transforms the variables so that their values lie between -infty and +infty before sampling (NUTS works much better with values on this unconstrained scale). The logpdf that numpyro returns works with these “unconstrained” variables, so this is what blackjax returns. To get values of the untransformed variables you thus need to apply the inverse of the transformation that numpyro used, which I guess here is the absolute value.
Does that make sense?
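One way to make this concrete is the postprocessing function that initialize_model also returns (the snippet above discards it via *_; this is a sketch assuming it is captured as postprocess_fn_gen, which with dynamic_args=True is a generator like potential_fn_gen):

init_params, potential_fn_gen, postprocess_fn_gen, _ = initialize_model(
    rng_key,
    model,
    model_args=(
        train["Home_id"].values,
        train["Away_id"].values,
        train["score1"].values,
        train["score2"].values,
    ),
    dynamic_args=True,
)

# build the postprocessing function for these model arguments, then map
# every unconstrained posterior draw back to the constrained space
postprocess = postprocess_fn_gen(
    train["Home_id"].values,
    train["Away_id"].values,
    train["score1"].values,
    train["score2"].values,
)
constrained_samples = jax.vmap(postprocess)(states.position)

print(constrained_samples["sd_att"].mean())  # positive, as the model requires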
Hello! I’m just wondering if this is still an issue, or has this been fixed by recent releases?