
HMC fails to 'draw' from standard Gaussian target


blackjax currently fails to target the standard normal distribution with HMC: in my example the empirical mean is close to 0, but the empirical variance is close to 0.5 rather than 1. Similar values arise with NUTS, both with and without Stan warmup.
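
For scale (a quick back-of-the-envelope check of my own, not part of the original report): even before accounting for autocorrelation, the Monte Carlo standard error of the sample variance over 50,000 draws from N(0, 1) is tiny, so a value near 0.5 cannot be sampling noise.

import numpy as np

# For n i.i.d. N(0, 1) draws, Var(sample variance) is approximately 2 / n,
# so the standard error of the variance estimate is sqrt(2 / n).
n = 50_000
se_var = np.sqrt(2.0 / n)
print(se_var)  # ~0.0063 -- an empirical variance near 0.5 is dozens of SEs from 1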

Running

import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.stats  # explicit import; `import jax` alone may not load jax.scipy
import blackjax.hmc as hmc
import matplotlib.pyplot as plt

# Potential energy: the negative log-density of the standard normal target.
potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0])
initial_state = hmc.new_state(initial_position, potential)

# A diagonal inverse mass matrix; this tunes mixing but should not change
# the invariant distribution of the chain.
inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps = 100
step_size = 1e-2

hmc_kernel = hmc.kernel(
    potential,
    step_size=step_size,
    inverse_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps,
)

hmc_kernel = jax.jit(hmc_kernel)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    # One transition of the chain; returning the state twice lets
    # jax.lax.scan both carry it forward and record it.
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)

# states.position has shape (50_000, 1): one row per draw.
samples = states.position.block_until_ready()
print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()

gives

[-0.00325776]
[0.47831735]
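
As a quick cross-check (my addition, not part of the original report): the hand-written potential matches the analytic standard-normal potential x**2 / 2 + log(2 * pi) / 2, so the target itself is specified correctly and the bias has to come from the kernel.

import jax
import jax.numpy as jnp
import jax.scipy.stats

potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()

x = jnp.linspace(-3.0, 3.0, 7)
# -log N(x; 0, 1) = x**2 / 2 + log(2 * pi) / 2
analytic = 0.5 * x**2 + 0.5 * jnp.log(2.0 * jnp.pi)
print(jnp.allclose(jax.vmap(potential)(x.reshape(-1, 1)), analytic))  # True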

It seems the behaviour changed with commit 7be282232d87d93a15d5c5ff07648e2eb6f144c6, in particular with the change in blackjax/inference/proposal.py. Using p_accept = jnp.clip(jnp.exp(proposal.weight), a_max=1) instead gives, with the above code:

[-0.00980166]
[0.9992622]
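
For context, a minimal sketch (with assumed names, not blackjax's actual internals) of the Metropolis step this line implements: proposal.weight plays the role of the log acceptance ratio, and the transition accepts with probability min(1, exp(weight)), which is exactly what the clip computes.

import jax
import jax.numpy as jnp

def accept_reject(rng_key, log_accept_ratio):
    # alpha = min(1, exp(log ratio)); for HMC the log ratio is the difference
    # in Hamiltonian energy between the current and the proposed state.
    p_accept = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.0)
    do_accept = jax.random.uniform(rng_key) < p_accept
    return do_accept, p_accept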

Similarly, with the parent commit (i.e., c6f75e91217f3336c6b0cea33d53efa5c8c843b5):

import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.stats
import blackjax.hmc as hmc
import matplotlib.pyplot as plt

potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()
initial_position = np.array([1.0])
initial_state = hmc.new_state(initial_position, potential)

inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps = 100
step_size = 1e-2

# At this commit the kernel still takes its parameters bundled in an
# HMCParameters tuple rather than as keyword arguments.
params = hmc.HMCParameters(
    step_size=step_size,
    inv_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps,
)
hmc_kernel = hmc.kernel(potential, params)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)

samples = states.position.block_until_ready()
print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()

results in

[-0.00980166]
[0.9992622]
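
So the parent commit reproduces the expected unit variance, which points squarely at the acceptance-probability change in 7be2822 as the source of the regression.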

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 10 (10 by maintainers)

Top GitHub Comments

1 reaction
wiep commented, Jun 9, 2021

thank you very much for the fix ❤️

1 reaction
hriebl commented, Jun 8, 2021

Thank you for your work, @rlouf!


