from_tfp with multiple chains?
How should from_tfp be used with multiple chains? It looks like it only supports one chain, and it treats the different chains as part of the variable (the var_0_dim_0 dimension in the output below).
Example
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_probability.python.edward2 as ed
import numpy as np
import arviz as az

dtype = np.float32

def unnormalized_log_prob(x):
    return -x**2.

# The log prob returns one value per element of the (20,) state,
# so TFP runs 20 independent chains in a single batch.
samples_gauss, _ = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=np.ones(20, dtype=dtype),
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        unnormalized_log_prob,
        step_size=1.0,
        num_leapfrog_steps=3),
    num_burnin_steps=500,
    parallel_iterations=4)

# TF1 graph execution; samples_gauss_ has shape (1000, 20) = (draw, chain)
with tf.Session() as sess:
    [samples_gauss_, ] = sess.run([samples_gauss, ])

az.from_tfp([samples_gauss_]).posterior
# It is expected to be a list with a variable (array) at each position
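The mismatch comes down to axis order: TFP returns each state as (draw, chain, *event_shape) when a batch of chains is run, while ArviZ reads the first two axes as (chain, draw). A minimal NumPy sketch of the reordering, assuming that TFP layout; tfp_states_to_posterior is a hypothetical helper, not part of ArviZ or TFP, and the var_0, var_1, ... names only mimic what from_tfp produced above.

```python
import numpy as np


def tfp_states_to_posterior(states, chain_axis=1):
    """Hypothetical helper (not ArviZ/TFP API).

    Assumes each array in `states` has shape (draw, chain, *event_shape),
    with the chain axis at position `chain_axis`. Returns a dict of arrays
    shaped (chain, draw, *event_shape), the layout ArviZ expects.
    """
    posterior = {}
    for i, state in enumerate(states):
        state = np.asarray(state)
        # Move the chain axis to the front; the draw axis shifts to axis 1.
        posterior[f"var_{i}"] = np.moveaxis(state, chain_axis, 0)
    return posterior


# Stand-in for samples_gauss_: 1000 draws from 20 independent chains.
rng = np.random.default_rng(0)
fake_samples = rng.normal(size=(1000, 20)).astype(np.float32)
posterior = tfp_states_to_posterior([fake_samples])
print(posterior["var_0"].shape)  # (20, 1000)
```

The resulting dict could then be passed to az.from_dict(posterior=posterior), which interprets the leading two axes as chain and draw.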
Output
<xarray.Dataset>
Dimensions: (chain: 1, draw: 1000, var_0_dim_0: 20)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
* var_0_dim_0 (var_0_dim_0) int64 0 1 2 3 4 5 6 7 ... 12 13 14 15 16 17 18 19
Data variables:
var_0 (chain, draw, var_0_dim_0) float32 0.38988703 ... -0.04516393
Attributes:
...
Workaround
My array has shape (ndraws, nchains), while convert_to_inference_data expects (nchains, ndraws, ...), so simply transposing the array lets me call:
from_array = az.convert_to_inference_data(samples_gauss_.T)
az.plot_trace(from_array);
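Transposing with .T works here only because the array is 2-D: .T reverses every axis, so for a state with extra dimensions, e.g. (ndraws, nchains, k), it would also move the event axis to the front. A small sketch, using hypothetical zero-filled samples, showing that swapping only the first two axes keeps the event dimension in place:

```python
import numpy as np

ndraws, nchains, k = 1000, 4, 3
# Hypothetical samples laid out as (draw, chain, event), like TFP output
# for a batch of chains whose state has length 3.
samples = np.zeros((ndraws, nchains, k), dtype=np.float32)

reversed_all = samples.T                  # reverses every axis
chain_first = np.swapaxes(samples, 0, 1)  # swaps only draw and chain

print(reversed_all.shape)  # (3, 4, 1000): not (chain, draw, ...)
print(chain_first.shape)   # (4, 1000, 3): (chain, draw, event)
```

chain_first can then be passed to az.convert_to_inference_data in the same way as in the workaround.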
Output
As can be seen, this works properly.
I don't know whether I am missing some way to use the coords or dims parameters to sort this out, or whether this is an implementation detail.
Issue Analytics
- State:
- Created 5 years ago
- Comments:11 (11 by maintainers)
I am inclined not to support this; I haven't really seen a valid use case.
Should we support multiple chain dimensions? We could stack them before converting to a dataset.
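One way to stack several chain dimensions before conversion, sketched in NumPy under the assumption that samples are laid out as (draw, *chain_dims, *event_shape); stack_chain_dims is a hypothetical helper for illustration, not an existing ArviZ function.

```python
import numpy as np


def stack_chain_dims(samples, n_chain_dims):
    """Hypothetical helper: collapse several chain dimensions into one.

    Assumes `samples` has shape (draw, *chain_dims, *event_shape), where
    the `n_chain_dims` axes right after the draw axis index chains.
    Returns an array of shape (prod(chain_dims), draw, *event_shape).
    """
    samples = np.asarray(samples)
    ndraws = samples.shape[0]
    # Move the draw axis behind the chain axes: (*chain_dims, draw, *event).
    moved = np.moveaxis(samples, 0, n_chain_dims)
    # Flatten the leading chain axes into a single chain axis (C order).
    return moved.reshape((-1, ndraws) + samples.shape[1 + n_chain_dims:])


# 10 draws, a 2 x 3 grid of chains, and a length-4 state per chain.
samples = np.arange(10 * 2 * 3 * 4).reshape(10, 2, 3, 4)
stacked = stack_chain_dims(samples, n_chain_dims=2)
print(stacked.shape)  # (6, 10, 4)
```

With C-order flattening, chain (i, j) of the original grid ends up at index i * 3 + j of the stacked chain axis, so the mapping back to the original chain coordinates stays recoverable.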