Predator-prey model example
See original GitHub issue.
Hi, I would like to send a PR to BlackJAX with this example from NumPyro, without using NumPyro's `initialize_model` function. However, with some seeds the window adaptation outputs kernels whose `step_size` and `inverse_mass_matrix` values are close to zero. I don't know exactly what I am doing wrong. The original example doesn't seem to do any reparametrization either. Any input would be appreciated 😃
import jax
jax.config.update('jax_platform_name', 'cpu')
from collections import namedtuple
import jax.numpy as jnp
from jax.experimental.ode import odeint
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
# Container for one point in parameter space; field order is the order in
# which priors are sampled and log-densities are summed elsewhere.
params = namedtuple(
    'params',
    ('uinit', 'vinit', 'alpha', 'betta', 'gamma', 'delta', 'u_sigma', 'v_sigma'),
)

# Initial populations of u and v: log-normal, centred (in log space) on 10.
uinit_prior = tfd.LogNormal(loc=jnp.log(10.0), scale=1.0)
vinit_prior = tfd.LogNormal(loc=jnp.log(10.0), scale=1.0)

# Rate parameters of the ODE: normals truncated to [0, inf) so that every
# rate stays non-negative.
alpha_prior = tfd.TruncatedNormal(loc=1.00, scale=0.50, low=0.0, high=jnp.inf)
betta_prior = tfd.TruncatedNormal(loc=0.05, scale=0.05, low=0.0, high=jnp.inf)
gamma_prior = tfd.TruncatedNormal(loc=1.00, scale=0.50, low=0.0, high=jnp.inf)
delta_prior = tfd.TruncatedNormal(loc=0.05, scale=0.05, low=0.0, high=jnp.inf)

# Scales of the log-normal observation noise for u and v (strictly positive).
u_sigma_prior = tfd.LogNormal(loc=-1, scale=1)
v_sigma_prior = tfd.LogNormal(loc=-1, scale=1)
def lotka_volterra_ODE(z, t, alpha, betta, gamma, delta):
    """Right-hand side of the Lotka-Volterra predator-prey equations.

    The argument order (state, time, *extra_args) is the one expected by
    ``jax.experimental.ode.odeint``.  ``t`` is not used: the system is
    autonomous.

    Args:
        z: State vector; ``z[0]`` is the prey population ``u`` and ``z[1]``
            the predator population ``v``.
        t: Current time (unused).
        alpha: Growth rate of the prey in the absence of predators.
        betta: Rate at which predators reduce the prey growth rate.
        gamma: Death rate of the predators in the absence of prey.
        delta: Rate at which prey increase the predator growth rate.

    Returns:
        ``jnp.stack([du/dt, dv/dt])``.
    """
    prey, predator = z[0], z[1]
    prey_rate = prey * (alpha - betta * predator)
    predator_rate = predator * (delta * prey - gamma)
    return jnp.stack([prey_rate, predator_rate])
def Model(ODE, rtol=1e-6, atol=1e-5, max_steps=1000):
    """Wrap ``ODE`` in a solver with fixed integration settings.

    Args:
        ODE: Right-hand side ``f(z, t, alpha, betta, gamma, delta)`` in the
            form accepted by ``jax.experimental.ode.odeint``.
        rtol: Relative tolerance forwarded to ``odeint``.
        atol: Absolute tolerance forwarded to ``odeint``.
        max_steps: Forwarded to ``odeint`` as ``mxstep``.

    Returns:
        A function ``apply(z_init, year_args, alpha, betta, gamma, delta)``
        returning the ``odeint`` solution started from ``z_init`` and
        evaluated at the times in ``year_args``.
    """
    solver_options = dict(rtol=rtol, atol=atol, mxstep=max_steps)

    def apply(z_init, year_args, alpha, betta, gamma, delta):
        rates = (alpha, betta, gamma, delta)
        return odeint(ODE, z_init, year_args, *rates, **solver_options)

    return apply
def target_log_prob(model, data):
    """Build the unnormalized log joint density of the predator-prey model.

    Args:
        model: ODE solver built by ``Model``; called as
            ``model(z_init, year_args, alpha, betta, gamma, delta)``.
        data: Observed populations, one row per time point, compared
            element-wise with the ODE solution (presumably shape
            ``(len(data), 2)`` -- TODO confirm against the dataset).

    Returns:
        A function ``apply(params) -> scalar``.  It returns the sum of every
        prior log-density plus the log-normal log-likelihood of ``data``
        around the ODE trajectory.  ``params`` is a ``params`` namedtuple in
        the original, positive parameter space.  Wherever that sum is
        ``nan``, ``-inf`` is returned instead.  This happens, for example,
        when a sampler proposes a non-positive value, so that a logarithm of
        a negative number is taken.  The point is then reported as having
        zero probability.
    """
    # Observation times 0, 1, 2, ...: one unit apart, one per row of `data`.
    year_args = jnp.arange(len(data), dtype=jnp.float32)

    def apply(params):
        log_prior = (
            uinit_prior.log_prob(params.uinit)
            + vinit_prior.log_prob(params.vinit)
            + alpha_prior.log_prob(params.alpha)
            + betta_prior.log_prob(params.betta)
            + gamma_prior.log_prob(params.gamma)
            + delta_prior.log_prob(params.delta)
            + u_sigma_prior.log_prob(params.u_sigma)
            + v_sigma_prior.log_prob(params.v_sigma)
        )

        zinit = jnp.array([params.uinit, params.vinit])
        rates = (params.alpha, params.betta, params.gamma, params.delta)
        z = model(zinit, year_args, *rates)

        # Log-normal observation noise around the ODE trajectory.
        # jnp.log(z) is nan wherever the trajectory is negative.
        sigmas = jnp.array([params.u_sigma, params.v_sigma])
        log_likelihood = tfd.LogNormal(jnp.log(z), sigmas).log_prob(data).sum()

        log_joint = log_prior + log_likelihood
        # nan compares false against any threshold, so a nan density can slip
        # through checks of the form `value > threshold`.  Outside the support
        # the correct log-density is -inf, which samplers treat as zero mass.
        return jnp.where(jnp.isnan(log_joint), -jnp.inf, log_joint)

    return apply
def sample(key):
    """Draw one parameter set from the priors (e.g. as a starting point).

    ``key`` is split once into one subkey per prior, and the subkeys are
    handed out in ``params`` field order.  No two parameters share a key, so
    the draws are independent.

    Args:
        key: JAX PRNG key.

    Returns:
        A ``params`` namedtuple holding one prior draw per field.
    """
    priors = (uinit_prior, vinit_prior, alpha_prior, betta_prior,
              gamma_prior, delta_prior, u_sigma_prior, v_sigma_prior)
    subkeys = jax.random.split(key, len(priors))
    draws = [prior.sample(seed=subkey) for prior, subkey in zip(priors, subkeys)]
    return params(*draws)
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run ``kernel`` for ``num_samples`` steps and record every step.

    Args:
        rng_key: PRNG key, split into one key per step.
        kernel: Transition ``kernel(key, state) -> (new_state, info)``.
        initial_state: State the chain starts from.
        num_samples: Number of transitions to run.

    Returns:
        ``(states, infos)``: the per-step outputs of ``kernel``, stacked
        along a new leading axis of length ``num_samples``.
    """
    def _transition(current, step_key):
        following, step_info = kernel(step_key, current)
        # The carry moves the chain forward; the second element is recorded.
        return following, (following, step_info)

    step_keys = jax.random.split(rng_key, num_samples)
    _, history = jax.lax.scan(jax.jit(_transition), initial_state, step_keys)
    states, infos = history
    return states, infos
if __name__ == "__main__":
    from numpyro.examples.datasets import LYNXHARE, load_dataset
    import blackjax

    fetch = load_dataset(LYNXHARE, shuffle=False)[1]
    year, data = fetch()

    key = jax.random.PRNGKey(1)
    key_init, key_warm, key_loop, key_pred = jax.random.split(key, 4)
    warmup_steps = 1000
    num_samples = 1000

    model = Model(lotka_volterra_ODE)
    constrained_log_prob = target_log_prob(model, data)

    # Every parameter of this model is strictly positive (log-normal or
    # truncated-at-zero priors), but NUTS moves freely over R^d.  Sampling the
    # positive parameters directly lets trajectories leave the support, which
    # degrades step-size / mass-matrix adaptation.  Instead, sample
    # theta = log(params) and map back with exp.  Adding the log-Jacobian of
    # that map, sum(theta), keeps the target distribution unchanged.
    # NumPyro's `initialize_model` performs this change of variables
    # automatically, which is why the NumPyro example needs no explicit
    # reparametrization.
    def log_prob(unconstrained):
        positive = jax.tree_util.tree_map(jnp.exp, unconstrained)
        log_det_jacobian = sum(jax.tree_util.tree_leaves(unconstrained))
        return constrained_log_prob(positive) + log_det_jacobian

    position = jax.tree_util.tree_map(jnp.log, sample(key_init))

    adapt = blackjax.window_adaptation(
        algorithm=blackjax.nuts,
        logprob_fn=log_prob,
        is_mass_matrix_diagonal=False,
        initial_step_size=1.0,
        progress_bar=True)
    state, kernel, kernel_params = adapt.run(key_warm, position, warmup_steps)
    print('Kernel params', kernel_params)

    states, infos = inference_loop(key_loop, kernel, state, num_samples)
    # The chain lives in log space; map it back to the positive parameters.
    samples = jax.tree_util.tree_map(jnp.exp, states.position)
Issue Analytics
- State:
- Created 10 months ago
- Reactions:1
- Comments:5 (5 by maintainers)
Top Results From Across the Web
Predator-Prey Models - Duke Mathematics Department
Predator-Prey Models. Part 1: Background: Canadian Lynx and Snowshoe Hares · lions and gazelles, · birds and insects, · pandas and eucalyptus trees,...
Read more >Predator-prey model - Scholarpedia
Predator-prey models are arguably the building blocks of bio- and ecosystems, as biomasses are grown out of their resource masses.
Read more >Predator-Prey Models - CS UNM
A computational artifact that captures essential components and interactions (I.e. a computer program). • Encodes a theory about relevant mechanisms:.
Read more >Predator-Prey Model - YouTube
... https://learncheme.com/A simple model of the interaction between predator and prey that is set up very similarly to a kinetics model o.
Read more >Predator-Prey Model (Lotka-Volterra) Overview and Steady ...
Hi everyone! This video gives a brief overview of the Lotka-Volterra Predator - Prey model, and how to solve for the steady states....
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start Free
Top Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
Np. Looking forward to the PR 😃
It was this line: `v_sigma = v_sigma_prior.sample(seed=keys[7])`. I had used `keys[6]` there, repeating the key already used for `u_sigma`, but I fixed it while posting this. Yes, the posterior predictive — I will update this issue in case I encounter another problem! Thank you.