HMC fails to 'draw' from standard Gaussian target
blackjax currently fails to target the standard normal distribution with HMC. In my example, the empirical mean is close to 0 but the empirical variance is close to 0.5. Similar values arise when employing NUTS with and without Stan warmup.
running
import numpy as np
import jax
import jax.numpy as jnp
import blackjax.hmc as hmc
import matplotlib.pyplot as plt

potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()

initial_position = np.array([1.0,])
initial_state = hmc.new_state(initial_position, potential)
initial_state

inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps = 100
step_size = 1e-2

hmc_kernel = hmc.kernel(
    potential,
    step_size=step_size,
    inverse_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps
)
hmc_kernel = jax.jit(hmc_kernel)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states

rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)
samples = states.position.block_until_ready()

print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()
gives
[-0.00325776]
[0.47831735]
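To make "close to 0.5" checkable rather than eyeballed, here is a rough sketch of a moment check against the standard normal. It is not part of blackjax; the function name, the `n_sigma` threshold and the `ess_fraction` knob (a crude stand-in for the effective sample size of correlated MCMC draws) are all made up for illustration. It uses only the standard library and expects a flat sequence of floats, so the array above would have to be flattened first.

```python
import math
import random


def moments_match_standard_normal(samples, n_sigma=5.0, ess_fraction=1.0):
    """Return (mean, var, ok) for a flat sequence of draws.

    Hypothetical helper: `ok` is True when the empirical mean and variance
    lie within `n_sigma` standard errors of 0 and 1. `ess_fraction` shrinks
    the sample size to account for autocorrelation (assumed, not estimated).
    """
    n = len(samples)
    mean = sum(samples) / n
    var = sum((x - mean) ** 2 for x in samples) / n
    ess = n * ess_fraction
    mean_tol = n_sigma / math.sqrt(ess)        # std error of the mean for N(0, 1)
    var_tol = n_sigma * math.sqrt(2.0 / ess)   # std error of the variance for N(0, 1)
    ok = abs(mean) <= mean_tol and abs(var - 1.0) <= var_tol
    return mean, var, ok


rng = random.Random(0)
good = [rng.gauss(0.0, 1.0) for _ in range(50_000)]
# Rescale to variance ~0.5 to mimic the output reported above.
bad = [x * math.sqrt(0.5) for x in good]

good_mean, good_var, good_ok = moments_match_standard_normal(good)
bad_mean, bad_var, bad_ok = moments_match_standard_normal(bad)
print(good_ok, bad_ok)
```

With 50_000 draws, a variance near 0.5 is far outside any plausible Monte Carlo error, so the discrepancy is not sampling noise.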
It seems the behaviour changed with 7be282232d87d93a15d5c5ff07648e2eb6f144c6, in particular with the change in blackjax/inference/proposal.py.
using p_accept = jnp.clip(jnp.exp(proposal.weight), a_max=1)
gives with the above code
[-0.00980166]
[0.9992622]
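For reference, the clipped expression is the usual Metropolis-Hastings acceptance probability min(1, exp(log_weight)). The sketch below shows that rule in isolation. It is a plain random-walk Metropolis sampler written with the standard library only, not HMC and not blackjax's implementation; every name in it is made up for illustration.

```python
import math
import random


def accept_probability(log_weight):
    """min(1, exp(log_weight)); clipping in log space avoids overflow in exp."""
    return math.exp(min(0.0, log_weight))


def log_density(x):
    """Standard normal log-density up to an additive constant."""
    return -0.5 * x * x


def random_walk_metropolis(num_samples, step=2.0, seed=0):
    """Illustrative sampler (assumed setup): Gaussian proposals, MH accept/reject."""
    rng = random.Random(seed)
    x = 1.0
    draws = []
    for _ in range(num_samples):
        proposal = x + rng.gauss(0.0, step)
        log_weight = log_density(proposal) - log_density(x)
        if rng.random() < accept_probability(log_weight):
            x = proposal
        draws.append(x)
    return draws


draws = random_walk_metropolis(50_000)
mean = sum(draws) / len(draws)
var = sum((d - mean) ** 2 for d in draws) / len(draws)
print(mean, var)
```

With this acceptance rule the chain targets the standard normal, so the empirical variance should come out near 1, in line with the output above.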
Similarly, with the parent commit (i.e., c6f75e91217f3336c6b0cea33d53efa5c8c843b5)
import numpy as np
import jax
import jax.numpy as jnp
import blackjax.hmc as hmc
import matplotlib.pyplot as plt

potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()

initial_position = np.array([1.0,])
initial_state = hmc.new_state(initial_position, potential)
initial_state

inv_mass_matrix = 0.1 * jnp.ones_like(initial_position)
num_integration_steps = 100
step_size = 1e-2

params = hmc.HMCParameters(
    step_size=step_size,
    inv_mass_matrix=inv_mass_matrix,
    num_integration_steps=num_integration_steps
)
hmc_kernel = hmc.kernel(potential, params)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states

rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 50_000)
samples = states.position.block_until_ready()

print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))
plt.plot(samples)
plt.show()
results in
[-0.00980166]
[0.9992622]
Issue Analytics
- State:
- Created 2 years ago
- Comments:10 (10 by maintainers)
Top GitHub Comments
thank you very much for the fix ❤️
Thank you for your work, @rlouf!