Add arviz plots for MCMC

Hello team, I’m working with Dr. @murphyk, Prof. @nipunbatra, and lab colleague @patel-zeel.

We noticed that visualizing trace plots from the sampler states currently requires manual code like the following:

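import matplotlib.pyplot as plt

# Assumes `states` holds the stacked states of a single BlackJAX chain,
# with states.position an array of shape (n_samples, 3).
burnin = 300  # number of burn-in samples, marked with a vertical line
fig, ax = plt.subplots(1, 3, figsize=(12, 2))
for i, axi in enumerate(ax):
    axi.plot(states.position[:, i])  # trace of coordinate w_i
    axi.set_title(f"$w_{i}$")
    axi.axvline(x=burnin, c="tab:red")  # end of burn-in
plt.show()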

We propose using ArviZ plots for this instead, along with a helper function that converts the states into a trace that ArviZ accepts as input. I have used create_trace_from_states() in this notebook and also made some ArviZ trace plots and rank plots.

Please find a minimal implementation of create_trace_from_states below:

import arviz as az
import jax.numpy as jnp


def create_trace_from_states(states, info, burn_in=0):
    """
    input: states & info returned by a BlackJAX sampling loop, where
           states.position is a dict of arrays and info.is_divergent holds
           one flag per sample (and per chain)
    output: a trace that can be passed directly to ArviZ, and the samples dict
    """
    # BlackJAX states format: n_samples × n_chains × dim0 × dim1 × ..
    # (or n_samples × dim0 × dim1 × .. for a single chain)
    # arviz.convert_to_inference_data() requires: n_chains × n_samples × dim0 × dim1 × ..
    # Read the chain layout from info.is_divergent, since a multi-dimensional
    # position can also come from a single chain.
    divergence = info.is_divergent[burn_in:]
    multi_chain = divergence.ndim > 1

    samples = {}
    for param, values in states.position.items():
        values = values[burn_in:]
        if multi_chain:
            # swap n_samples and n_chains
            samples[param] = jnp.swapaxes(values, 0, 1)
        else:
            # single chain: add a leading chain axis of size 1
            samples[param] = values[None, ...]

    if multi_chain:
        divergence = jnp.swapaxes(divergence, 0, 1)
    else:
        divergence = divergence[None, ...]

    trace_posterior = az.convert_to_inference_data(samples)
    trace_sample_stats = az.convert_to_inference_data(
        {"diverging": divergence}, group="sample_stats"
    )
    trace = az.concat(trace_posterior, trace_sample_stats)
    return trace, samples
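A position array of shape (n_samples, 3) is ambiguous: it can be a single chain of a 3-dimensional parameter such as w above, or three chains of a scalar parameter. The minimal NumPy sketch below, assuming info.is_divergent holds one flag per sample and per chain, reads the chain layout from the divergence flags instead of from the position. chain_first is a hypothetical helper, not part of BlackJAX or ArviZ.

```python
import numpy as np


def chain_first(values, is_divergent, burn_in=0):
    """Return `values` in ArviZ's (chain, draw, ...) layout (hypothetical helper).

    Assumes BlackJAX-style input: leading axes (n_samples, n_chains) for a
    multi-chain run, or (n_samples,) for a single chain. The layout is read
    from `is_divergent`, assumed to hold one flag per sample (and per chain).
    """
    values = np.asarray(values)[burn_in:]
    if np.ndim(is_divergent) > 1:
        # multi-chain: swap n_samples and n_chains
        return np.swapaxes(values, 0, 1)
    # single chain: add a leading chain axis of size 1
    return values[np.newaxis, ...]


n_samples = 1000
# Same position shape, two different meanings:
w_one_chain = np.zeros((n_samples, 3))     # 1 chain, 3-dimensional parameter w
b_three_chains = np.zeros((n_samples, 3))  # 3 chains, scalar parameter b

print(chain_first(w_one_chain, np.zeros(n_samples, dtype=bool), burn_in=300).shape)
# (1, 700, 3)
print(chain_first(b_three_chains, np.zeros((n_samples, 3), dtype=bool), burn_in=300).shape)
# (3, 700)
```

Arrays in this layout can be handed to ArviZ (for example through az.convert_to_inference_data) and then plotted with az.plot_trace or az.plot_rank.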

If this looks good, I am also happy to add ArviZ plots using the above function to any of the existing tutorials (for example, Bayesian Logistic Regression).

Issue Analytics

  • State: closed
  • Created a year ago
  • Reactions: 1
  • Comments: 10 (10 by maintainers)

Top GitHub Comments

1 reaction
rlouf commented, Sep 21, 2022

> Sounds great! I wish I could, but these days I’m occupied with my school studies and don’t have time for open source 😦

No worries, school takes at least as much time as 1.5 full-time jobs; I totally understand. Good luck!

1 reaction
junpenglao commented, Jun 15, 2022

Something along those lines, but with some additional filtering on states so that we only return the most relevant sampling diagnostics to the end user. I think we should aim for an API similar to NumPyro’s, with options to:

  • display a progress bar
  • choose between vmap and pmap across chains
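The vmap/pmap option could look roughly like the following JAX sketch. Everything in it is an assumption for illustration: run_chains and its chain_method argument are hypothetical (loosely modeled on the chain_method option of NumPyro's MCMC class), and a toy random-walk Metropolis kernel stands in for a BlackJAX kernel. The progress bar option is not sketched.

```python
from functools import partial

import jax
import jax.numpy as jnp


def logdensity(x):
    # Toy target (assumption): standard normal in len(x) dimensions.
    return -0.5 * jnp.sum(x**2)


def rmh_step(key, x, step_size=0.5):
    """One random-walk Metropolis step; a toy kernel, not BlackJAX's."""
    key_proposal, key_accept = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_proposal, x.shape)
    log_ratio = logdensity(proposal) - logdensity(x)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, proposal, x), accept


def run_chain(key, x0, num_samples):
    """Run a single chain; returns (positions, accepted) with samples first."""

    def one_step(x, step_key):
        x, accept = rmh_step(step_key, x)
        return x, (x, accept)

    keys = jax.random.split(key, num_samples)
    _, (positions, accepted) = jax.lax.scan(one_step, x0, keys)
    return positions, accepted


def run_chains(key, initial_positions, num_samples, chain_method="vmap"):
    """Hypothetical driver: one chain per row of `initial_positions`."""
    keys = jax.random.split(key, initial_positions.shape[0])
    # Bind num_samples so only per-chain arrays are mapped over.
    one_chain = partial(run_chain, num_samples=num_samples)
    if chain_method == "vmap":
        # all chains vectorized on one device
        return jax.vmap(one_chain)(keys, initial_positions)
    if chain_method == "pmap":
        # one chain per device; needs n_chains <= jax.device_count()
        return jax.pmap(one_chain)(keys, initial_positions)
    raise ValueError(f"unknown chain_method: {chain_method!r}")


positions, accepted = run_chains(jax.random.PRNGKey(0), jnp.zeros((4, 2)), 1000)
print(positions.shape, accepted.shape)  # (4, 1000, 2) (4, 1000)
```

Mapping over chains outside the lax.scan, rather than vmapping the kernel inside it, returns arrays with chains first, which already matches the (chain, draw, ...) layout ArviZ expects, so no axis swap is needed.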