plot_trace lines is unclear and it may yield unexpected results
Describe the bug
The argument lines for the function plot_trace can give unexpected results. Moreover, the documentation is a bit nebulous.
To Reproduce
A toy example is defined:
import pymc3 as pm
import arviz as az
import numpy as np
# fake data
mu_real = 0
sigma_real = 1
n_samples = 150
Y = np.random.normal(loc=mu_real, scale=sigma_real, size=n_samples)
with pm.Model() as model:
    mu = pm.Normal('mu', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=10)
    likelihood = pm.Normal('likelihood', mu=mu, sigma=sigma, observed=Y)
    trace = pm.sample()
As per the documentation, the argument lines accepts a tuple in the form (var_name, {'coord': selection}, [line, positions]). So, the command
az.plot_trace(trace, lines=(('mu', {}, mu_real),))
yields the correct result.
I can also pass a list of tuples, a list of lists, or a list mixing tuples and lists, and it works fine:
az.plot_trace(trace, lines=[('mu', {}, mu_real)]) # list of tuples
az.plot_trace(trace, lines=[['mu', {}, mu_real]]) # list of lists
az.plot_trace(trace, lines=[['mu', {}, mu_real], ('sigma', {}, sigma_real)]) # list of lists and tuples
However, I cannot pass a single tuple or list, because I get a KeyError: 0:
az.plot_trace(trace, lines=(['mu', {}, mu_real]))
az.plot_trace(trace, lines=(('mu', {}, mu_real)))
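A likely explanation for these two failures lies in Python syntax rather than in ArviZ: parentheses without a trailing comma only group an expression, so neither call passes a nested container. The sketch below uses only built-in types (no ArviZ call is involved) to show what each argument actually is; that plot_trace then tries to unpack the string 'mu' as if it were a tuple is an assumption about the cause of the KeyError.

```python
mu_real = 0

# Parentheses alone only group: this is still a single 3-tuple.
single = (("mu", {}, mu_real))
# The trailing comma makes a 1-tuple that holds a 3-tuple.
nested = (("mu", {}, mu_real),)
# Parentheses around a list leave it a plain 3-element list.
as_list = (["mu", {}, mu_real])

print(len(single), type(single[0]).__name__)   # 3 str
print(len(nested), type(nested[0]).__name__)   # 1 tuple
print(type(as_list).__name__)                  # list
```

So the first element seen by the function is the string 'mu' rather than a (var_name, selection, values) triple, which matches the working call with the trailing comma shown earlier.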
Also, I can pass a variable or coordinate name that does not exist and ArviZ will not complain, but no lines will be plotted (here I would expect a warning):
az.plot_trace(trace, lines=[('hey', {}, mu_real)])
az.plot_trace(trace, lines=[('mu', {'hey'}, mu_real)])
The weird behavior happens when I pass a string:
az.plot_trace(trace, lines=[('mu', {}, 'hey')])
Expected behavior
The documentation could be improved and the function could check its inputs. In addition to what is described above, the placeholder [line, positions] in (var_name, {'coord': selection}, [line, positions]) should be something like [line_positions]; otherwise one may think (like myself 😃) that two values should be inserted (one for line and one for positions).
Additional context
I am using Win10 and a fresh conda environment with PyMC3 and ArviZ from master.
Possibly related https://github.com/pymc-devs/pymc3/issues/3495, https://github.com/pymc-devs/pymc3/issues/3497
Issue Analytics
- State:
- Created 4 years ago
- Reactions:1
- Comments:5 (4 by maintainers)
Okay, I’d go with that. Also, I’d add a warning for the case of a variable or coordinate name that does not exist (as pointed out in the issue description).
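As a rough illustration of such a warning, the sketch below checks each entry of lines against the set of variable names present in the data. The helper name warn_on_unknown_lines and its known_vars parameter are hypothetical and not part of ArviZ; coordinate names are not checked here, and a real implementation would need to handle them as well.

```python
import warnings


def warn_on_unknown_lines(lines, known_vars):
    """Hypothetical helper: warn for each `lines` entry whose variable is unknown.

    `lines` is an iterable of (var_name, selection, values) triples and
    `known_vars` is a collection of the variable names found in the data.
    Returns the list of unknown variable names.
    """
    unknown = [
        var_name
        for var_name, _selection, _values in lines
        if var_name not in known_vars
    ]
    for var_name in unknown:
        warnings.warn(
            f"lines refers to unknown variable {var_name!r}; no line will be drawn",
            UserWarning,
        )
    return unknown


# 'hey' is not a model variable, so a UserWarning is emitted for it.
warn_on_unknown_lines([("hey", {}, 0), ("mu", {}, 0)], {"mu", "sigma"})
```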
Sounds good. One option for checking the list is to check the dtype of line_values after the atleast_1d call to make sure the array contains numeric values. One possibility is to follow this SO answer.
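A minimal sketch of that dtype check could look like the following. The function name validate_line_values is hypothetical, and using np.issubdtype against np.number is one assumed way of testing for numeric values, not necessarily the approach in the linked answer.

```python
import numpy as np


def validate_line_values(values):
    """Hypothetical check: coerce to a 1-D array and reject non-numeric entries."""
    line_values = np.atleast_1d(values)
    # Strings become a unicode dtype and mixed inputs become unicode or object,
    # none of which are sub-dtypes of np.number.
    if not np.issubdtype(line_values.dtype, np.number):
        raise ValueError(
            f"line positions must be numeric, got dtype {line_values.dtype}"
        )
    return line_values


validate_line_values(0)           # scalar is accepted and wrapped in a 1-D array
validate_line_values([0.5, 1])    # list of numbers is accepted
# validate_line_values("hey")     # would raise ValueError
```

With a check like this, the string case from the issue description would fail early with an explicit message instead of producing an odd plot.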