Sampling issue with `jax.random.multivariate_normal`
Sampling from a GP is giving unexpected results. For example, consider the following code:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import tinygp

N = 1024
fig, ax = plt.subplots(3, 3, figsize=(6.4 * 1.5, 4.8 * 1.5))
for seed in range(9):
    key = jax.random.PRNGKey(seed)
    x = 5 * jax.random.normal(key, shape=(N, 1))
    kernel = 4 * tinygp.kernels.ExpSquared(scale=0.5)
    y = tinygp.GaussianProcess(kernel, x, diag=1.0).sample(key)
    ax.ravel()[seed].scatter(x, y, s=5);
This code generates the following output:

If I replace `tinygp.GaussianProcess(kernel, x, diag=1.0).sample(key)` with `jax.random.multivariate_normal(key, jnp.zeros(N), kernel(x, x)+jnp.eye(N))`, I get the same result. But if I change the sampling method to `jax.random.multivariate_normal(key, jnp.zeros(N), kernel(x, x)+jnp.eye(N), method='svd')`, I get drastically different results:
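One thing that may be worth ruling out when reproducing this: the snippet passes the same `key` both to `jax.random.normal` (to draw `x`) and to the sampler (to draw `y`), and the JAX documentation advises against reusing a key for more than one draw. The following is only a minimal sketch, not a confirmed diagnosis. It assumes a hand-written `exp_squared` function as a stand-in for the `tinygp` kernel (that helper and the smaller `N` are illustrative, not from the original report), splits the key so the inputs and the sample use independent randomness, and then draws with both the default method and `method="svd"` so the two can be compared.

```python
import jax
import jax.numpy as jnp


def exp_squared(x1, x2, amplitude=4.0, scale=0.5):
    """Hypothetical stand-in for `4 * tinygp.kernels.ExpSquared(scale=0.5)`."""
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return amplitude * jnp.exp(-0.5 * sq_dist / scale**2)


N = 64  # smaller than the 1024 in the report, to keep the sketch quick

# Split the key so x and y do not share the same random stream.
key = jax.random.PRNGKey(0)
key_x, key_y = jax.random.split(key)

x = 5 * jax.random.normal(key_x, shape=(N,))
cov = exp_squared(x, x) + jnp.eye(N)  # kernel matrix plus unit diagonal

# Same key and covariance, different factorization of the covariance.
y_chol = jax.random.multivariate_normal(key_y, jnp.zeros(N), cov)
y_svd = jax.random.multivariate_normal(key_y, jnp.zeros(N), cov, method="svd")
```

The Cholesky and SVD routes factor the covariance differently, so `y_chol` and `y_svd` need not match value for value even with the same key. What matters for comparing them is whether scatter plots of `x` against each sample look like draws from the same process once the keys are independent.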

Issue Analytics
- State:
- Created a year ago
- Comments:5 (3 by maintainers)

Sure, thank you!
I should say: feel free to open an issue on the JAX repo if you don’t like that, but this isn’t a `tinygp` question!