
Predator-prey model example


Hi, I would like to send a PR to blackjax with this example from numpyro, without using numpyro's initialize_model function. However, with some seeds the window adaptation returns kernel parameters whose step_size and inverse_mass_matrix values are close to zero. I don't know exactly what I am doing wrong; the original example doesn't seem to do any reparametrization either. Any input would be appreciated 😃
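A minimal sketch of one plausible mechanism behind a collapsing step size. This is an assumption, not the diagnosis reached in the thread: every prior in the model has support only on (0, ∞), while NUTS moves freely on the real line, so a leapfrog step that crosses zero lands where the log density is -inf. The acceptance probability of such a step is zero, which pushes dual-averaging step-size adaptation toward ever smaller steps. The toy target (a 1-D LogNormal(0, 1)) and the hand-written leapfrog function are hypothetical illustrations, not blackjax code.

```python
import jax
import jax.numpy as jnp


def logdensity(theta):
    """LogNormal(0, 1) log density; -inf outside the support (0, inf)."""
    # Substitute a safe value where theta <= 0 so log() and its gradient stay finite.
    safe = jnp.where(theta > 0, theta, 1.0)
    log_safe = jnp.log(safe)
    logpdf = -log_safe - 0.5 * jnp.log(2 * jnp.pi) - 0.5 * log_safe**2
    return jnp.where(theta > 0, logpdf, -jnp.inf)


def leapfrog(theta, momentum, step_size):
    """One leapfrog step of Hamiltonian dynamics with unit mass."""
    grad_fn = jax.grad(logdensity)
    momentum = momentum + 0.5 * step_size * grad_fn(theta)
    theta = theta + step_size * momentum
    momentum = momentum + 0.5 * step_size * grad_fn(theta)
    return theta, momentum


# Same start and momentum; only the step size differs.
theta_big, _ = leapfrog(0.5, -1.0, step_size=1.0)
theta_small, _ = leapfrog(0.5, -1.0, step_size=0.1)
print(theta_big, logdensity(theta_big))      # crosses zero: log density is -inf
print(theta_small, logdensity(theta_small))  # stays positive: finite log density
```

Only the smaller step keeps the trajectory inside the support, which is the direction in which the adaptation would keep moving the step size.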

import jax
jax.config.update('jax_platform_name', 'cpu')
from collections import namedtuple
import jax.numpy as jnp
from jax.experimental.ode import odeint
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

params = namedtuple('params', ['uinit', 'vinit', 'alpha', 'betta',
                               'gamma', 'delta', 'u_sigma', 'v_sigma'])

uinit_prior = tfd.LogNormal(jnp.log(10.0), 1.0)
vinit_prior = tfd.LogNormal(jnp.log(10.0), 1.0)
alpha_prior = tfd.TruncatedNormal(1.00, 0.50, 0.0, jnp.inf)
betta_prior = tfd.TruncatedNormal(0.05, 0.05, 0.0, jnp.inf)
gamma_prior = tfd.TruncatedNormal(1.00, 0.50, 0.0, jnp.inf)
delta_prior = tfd.TruncatedNormal(0.05, 0.05, 0.0, jnp.inf)
u_sigma_prior = tfd.LogNormal(-1, 1)
v_sigma_prior = tfd.LogNormal(-1, 1)


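# Lotka-Volterra dynamics: u is the prey population, v the predator population.
# odeint calls this as func(z, t, *args).
def lotka_volterra_ODE(z, t, alpha, betta, gamma, delta):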
    u, v = z[0], z[1]
    du_dt = (+alpha - betta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])


def Model(ODE, rtol=1e-6, atol=1e-5, max_steps=1000):
    def apply(z_init, year_args, alpha, betta, gamma, delta):
        z = odeint(ODE, z_init, year_args, alpha, betta, gamma, delta,
                   rtol=rtol, atol=atol, mxstep=max_steps)
        return z
    return apply


def target_log_prob(model, data):
    year_args = jnp.arange(len(data), dtype=jnp.float32)

    def apply(params):
        uinit_log_prob = uinit_prior.log_prob(params.uinit)
        vinit_log_prob = vinit_prior.log_prob(params.vinit)
        alpha_log_prob = alpha_prior.log_prob(params.alpha)
        betta_log_prob = betta_prior.log_prob(params.betta)
        gamma_log_prob = gamma_prior.log_prob(params.gamma)
        delta_log_prob = delta_prior.log_prob(params.delta)
        u_sigma_log_prob = u_sigma_prior.log_prob(params.u_sigma)
        v_sigma_log_prob = v_sigma_prior.log_prob(params.v_sigma)
        zinit = jnp.array([params.uinit, params.vinit])
        args = (params.alpha, params.betta, params.gamma, params.delta)
        z = model(zinit, year_args, *args)
        # Log-normal observation noise around the ODE solution, one sigma per species.
        sigmas = jnp.array([params.u_sigma, params.v_sigma])
        log_likelihood = tfd.LogNormal(jnp.log(z), sigmas).log_prob(data)
        return (uinit_log_prob + vinit_log_prob + alpha_log_prob +
                betta_log_prob + gamma_log_prob + delta_log_prob +
                u_sigma_log_prob + v_sigma_log_prob + log_likelihood.sum())
    return apply


def sample(key):
    # Draw an initial position from the priors, using a separate key per parameter.
    keys = jax.random.split(key, 8)
    uinit = uinit_prior.sample(seed=keys[0])
    vinit = vinit_prior.sample(seed=keys[1])
    alpha = alpha_prior.sample(seed=keys[2])
    betta = betta_prior.sample(seed=keys[3])
    gamma = gamma_prior.sample(seed=keys[4])
    delta = delta_prior.sample(seed=keys[5])
    u_sigma = u_sigma_prior.sample(seed=keys[6])
    v_sigma = v_sigma_prior.sample(seed=keys[7])
    return params(uinit, vinit, alpha, betta, gamma, delta, u_sigma, v_sigma)


def inference_loop(rng_key, kernel, initial_state, num_samples):

    @jax.jit
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos


if __name__ == "__main__":
    from numpyro.examples.datasets import LYNXHARE, load_dataset
    import blackjax

    fetch = load_dataset(LYNXHARE, shuffle=False)[1]
    year, data = fetch()

    key = jax.random.PRNGKey(1)
    key_init, key_warm, key_loop, key_pred = jax.random.split(key, 4)
    warmup_steps = 1000
    num_samples = 1000

    model = Model(lotka_volterra_ODE)
    log_prob = target_log_prob(model, data)
    position = sample(key_init)
    adapt = blackjax.window_adaptation(
        algorithm=blackjax.nuts,
        logprob_fn=log_prob,
        is_mass_matrix_diagonal=False,
        initial_step_size=1.0,
        progress_bar=True)

    state, kernel, kernel_params = adapt.run(key_warm, position, warmup_steps)
    print('Kernel params', kernel_params)
    states, infos = inference_loop(key_loop, kernel, state, num_samples)
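If the positivity constraint does turn out to matter, one common remedy is to sample in log space. As far as we know, numpyro's NUTS does this automatically by transforming constrained sites to an unconstrained space, which would explain why the original example needs no explicit reparametrization. The sketch below wraps a log density over positive parameters so that it accepts log-parameters, adding the log-Jacobian of theta = exp(phi). make_unconstrained is a hypothetical helper, not a blackjax or numpyro API, and it assumes every leaf of the position has support (0, ∞), which holds for all eight priors in the script above.

```python
import jax
import jax.numpy as jnp


def make_unconstrained(logdensity_fn):
    """Hypothetical helper: turn a log density over positive parameters into
    one over their logarithms, phi = log(theta)."""

    def unconstrained_logdensity(log_position):
        # Map every unconstrained leaf phi back to theta = exp(phi) > 0.
        position = jax.tree_util.tree_map(jnp.exp, log_position)
        # Change of variables: log|d theta / d phi| = phi for theta = exp(phi),
        # summed over all elements of all leaves.
        log_det_jacobian = sum(
            jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(log_position)
        )
        return logdensity_fn(position) + log_det_jacobian

    return unconstrained_logdensity


def lognormal_logpdf(theta):
    """LogNormal(0, 1) log density for theta > 0 (toy target)."""
    log_theta = jnp.log(theta)
    return -log_theta - 0.5 * jnp.log(2 * jnp.pi) - 0.5 * log_theta**2


# Sanity check: a LogNormal(0, 1) density on theta is a standard normal on phi.
unconstrained = make_unconstrained(lognormal_logpdf)
phi = 0.3
standard_normal = -0.5 * jnp.log(2 * jnp.pi) - 0.5 * phi**2
print(unconstrained(phi), standard_normal)
```

With the script above, one would pass make_unconstrained(log_prob) as logprob_fn, start adaptation from jax.tree_util.tree_map(jnp.log, position), and map the draws back with jax.tree_util.tree_map(jnp.exp, states.position), assuming the returned states expose a position field as blackjax's HMC states do. TFP's bijectors (for example tfp.bijectors.Exp or tfp.bijectors.Softplus) are a ready-made alternative to the hand-written exp/log pair.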

Issue Analytics

  • State: closed
  • Created 10 months ago
  • Reactions: 1
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

rlouf commented, Nov 20, 2022 (1 reaction)

Np. Looking forward to the PR 😃

oarriaga commented, Nov 20, 2022 (0 reactions)

It was this line: v_sigma = v_sigma_prior.sample(seed=keys[7]). I had used keys[6] there, repeating the key already used for u_sigma, but I fixed it while posting this. And yes, the posterior predictive. I will update this issue if I encounter another problem. Thank you!
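The key-reuse bug described in this comment follows from JAX's functional PRNG: the same key always produces the same draw, so passing keys[6] to both sigma priors (which are identical LogNormal(-1, 1) distributions) would have given u_sigma and v_sigma the same initial value. A minimal illustration, using jax.random.normal in place of the TFP priors:

```python
import jax

key = jax.random.PRNGKey(1)
keys = jax.random.split(key, 8)

# Reusing one key: both draws are identical, as with u_sigma and v_sigma
# when both were sampled with keys[6].
u_draw = jax.random.normal(keys[6])
v_draw_reused = jax.random.normal(keys[6])

# One key per parameter: the draws are independent.
v_draw_fixed = jax.random.normal(keys[7])

print(u_draw == v_draw_reused)  # True
print(u_draw == v_draw_fixed)
```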
