Add arviz plots for MCMC
Hello Team, I’m working with Dr. @murphyk, Prof. @nipunbatra, and lab colleague @patel-zeel.
We noticed that to visualize trace plots using states, we need to write manual code like the following:
import matplotlib.pyplot as plt

burnin = 300
fig, ax = plt.subplots(1, 3, figsize=(12, 2))
for i, axi in enumerate(ax):
    # one panel per weight, with the burn-in cutoff marked in red
    axi.plot(states.position[:, i])
    axi.set_title(f"$w_{i}$")
    axi.axvline(x=burnin, c="tab:red")
plt.show()
We’re proposing to use arviz plots for the same and to add a helper function that converts states to a trace (which arviz can take as input). I have used create_trace_from_states() in this notebook and also made some arviz trace plots and rank plots.
Please find a minimal implementation of create_trace_from_states below:
import arviz as az
import jax.numpy as jnp


def create_trace_from_states(states, info, burn_in=0):
    """
    input: states & info which are returned by the blackjax kernel
    output: a trace which can be passed directly to arviz
    """
    samples = {}
    for param in states.position.keys():
        # NOTE: the number of dimensions is used to tell single-chain from
        # multi-chain output, so a 1-D array is read as one chain of a scalar
        ndims = len(states.position[param].shape)
        if ndims == 1:
            # states have only a single chain
            divergence = info.is_divergent[burn_in:]
            samples[param] = states.position[param][burn_in:]
        elif ndims > 1:
            # states have states.position[param].shape[1] chains
            # blackjax states format: n_samples × n_chains × dims0 × dim1 × ..
            # arviz.convert_to_inference_data() requires: n_chains × n_samples × dims0 × dim1 × ..
            # so we swap n_samples and n_chains
            samples[param] = jnp.swapaxes(states.position[param][burn_in:], 0, 1)
            # get divergences
            divergence = jnp.swapaxes(info.is_divergent[burn_in:], 0, 1)
    trace_posterior = az.convert_to_inference_data(samples)
    trace_sample_stats = az.convert_to_inference_data({"diverging": divergence}, group="sample_stats")
    trace = az.concat(trace_posterior, trace_sample_stats)
    return trace, samples
If this looks good, I am also happy to add arviz plots using the above function in any of the existing tutorials (for example in Bayesian Logistic Regression?)
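To make the axis convention in the helper concrete, here is a small sketch of the reordering step on its own. It uses NumPy arrays as stand-ins for blackjax output. The `States` and `Info` namedtuples, the parameter names `w` and `sigma`, and the sizes are all hypothetical and only mimic the fields the helper reads.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-ins for the objects a blackjax kernel returns;
# only the fields read by create_trace_from_states are modelled here.
States = namedtuple("States", ["position"])
Info = namedtuple("Info", ["is_divergent"])

n_samples, n_chains, burn_in = 500, 4, 100
rng = np.random.default_rng(0)

states = States(position={
    "w": rng.normal(size=(n_samples, n_chains, 3)),   # vector parameter
    "sigma": rng.normal(size=(n_samples, n_chains)),  # scalar parameter
})
info = Info(is_divergent=np.zeros((n_samples, n_chains), dtype=bool))

# Drop the burn-in draws, then move the chain axis to the front:
# (n_samples, n_chains, ...) -> (n_chains, n_samples, ...)
samples = {
    name: np.swapaxes(value[burn_in:], 0, 1)
    for name, value in states.position.items()
}
diverging = np.swapaxes(info.is_divergent[burn_in:], 0, 1)

print({name: value.shape for name, value in samples.items()})
# {'w': (4, 400, 3), 'sigma': (4, 400)}
print(diverging.shape)
# (4, 400)
```

With arviz installed, a dictionary in this chain-first layout is the kind of input that `az.convert_to_inference_data` accepts, and the resulting object can then be passed to `az.plot_trace` or `az.plot_rank`.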
Top GitHub Comments
No worries, school takes at least as much time as 1.5 full-time jobs, totally understand. Good luck!
Something along those lines, but with some additional filtering on states so that we only return the most relevant sampling diagnostics to the end user. I think we should aim to have an API similar to numpyro, with options to: