Sampling issue with `jax.random.multivariate_normal`

Sampling from a GP gives unexpected results. For example, consider the following code:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import tinygp

N = 1024

fig, ax = plt.subplots(3, 3, figsize=(6.4 * 1.5, 4.8 * 1.5))
for seed in range(9):
    key = jax.random.PRNGKey(seed)
    # Random 1-D inputs, drawn with `key`
    x = 5 * jax.random.normal(key, shape=(N, 1))
    kernel = 4 * tinygp.kernels.ExpSquared(scale=0.5)
    # One GP sample with unit white noise on the diagonal;
    # the same `key` that produced x is reused here
    y = tinygp.GaussianProcess(kernel, x, diag=1.0).sample(key)
    ax.ravel()[seed].scatter(x, y, s=5)
plt.show()
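
One detail of the reproduction worth checking: the same `key` drives both the `jax.random.normal` call that creates `x` and the GP sample. The minimal check below uses only JAX (no tinygp) and assumes the sampler draws its standard normals with `jax.random.normal(key, (N,))`, which is how `jax.random.multivariate_normal` draws them for a length-N mean. It tests whether those draws are just a rescaled copy of the inputs.

```python
import jax
import jax.numpy as jnp

N = 1024
key = jax.random.PRNGKey(0)

# Inputs, drawn exactly as in the reproduction
x = 5 * jax.random.normal(key, shape=(N, 1))

# Standard normals drawn with the *same* key
# (assumption: this is what the sampler draws internally)
z = jax.random.normal(key, shape=(N,))

# If True, the sampler's "noise" is not independent of the inputs
same_draws = bool(jnp.allclose(z, x[:, 0] / 5))
print("z equals x / 5:", same_draws)
```

If the check prints `True`, every sample is a deterministic function of `x`, so the scatter plots are not showing independent GP draws.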

This code generates the following output:

[figure: 3-by-3 grid of scatter plots of y against x, one panel per seed]

If I replace `tinygp.GaussianProcess(kernel, x, diag=1.0).sample(key)` with `jax.random.multivariate_normal(key, jnp.zeros(N), kernel(x, x) + jnp.eye(N))`, I get the same result. But if I switch the sampling method to SVD with `jax.random.multivariate_normal(key, jnp.zeros(N), kernel(x, x) + jnp.eye(N), method='svd')`, I get drastically different results:

[figure: the same 3-by-3 grid of scatter plots, sampled with method='svd']
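
Why two methods can disagree for the same key: `jax.random.multivariate_normal` returns `mean + F @ z`, where `z` is a standard-normal vector drawn from the key and `F` is a square-root factor of the covariance (`F @ F.T == cov`). With `method='cholesky'`, `F` is the lower-triangular Cholesky factor; with `method='svd'`, it is `u * sqrt(s)`. Both are valid factors, so both give correctly distributed samples when `z` is independent of the inputs, but they map the same `z` to different vectors. The sketch below rebuilds both samples by hand. The kernel is written out directly as a hypothetical stand-in for `4 * tinygp.kernels.ExpSquared(scale=0.5)`, assuming the exp(-r²/2) convention, and the match against JAX's internal formula is an assumption checked numerically, not a documented guarantee.

```python
import jax
import jax.numpy as jnp

def exp_squared(x1, x2, amp=4.0, scale=0.5):
    # Hypothetical hand-written kernel: amp * exp(-0.5 * ((x1 - x2) / scale)**2)
    r = (x1[:, None, 0] - x2[None, :, 0]) / scale
    return amp * jnp.exp(-0.5 * r**2)

N = 50  # small N keeps the check quick
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = 5 * jax.random.normal(key_x, shape=(N, 1))
cov = exp_squared(x, x) + jnp.eye(N)  # kernel plus unit diagonal noise
mean = jnp.zeros(N)

# Two valid square-root factors of the same covariance
L = jnp.linalg.cholesky(cov)       # lower triangular
u, s, _ = jnp.linalg.svd(cov)
F = u * jnp.sqrt(s)[None, :]       # column j of u scaled by sqrt(s[j])

# Assumption: both methods consume the same standard normals z
z = jax.random.normal(key_y, shape=(N,))
y_chol = jax.random.multivariate_normal(key_y, mean, cov, method="cholesky")
y_svd = jax.random.multivariate_normal(key_y, mean, cov, method="svd")

print(jnp.allclose(L @ L.T, cov, atol=1e-3), jnp.allclose(F @ F.T, cov, atol=1e-3))
print(jnp.allclose(y_chol, L @ z, atol=1e-3), jnp.allclose(y_svd, F @ z, atol=1e-3))
print("max |y_chol - y_svd|:", jnp.max(jnp.abs(y_chol - y_svd)))
```

Because the factor is not unique, a given key does not pin down one sample across methods; only the distribution is shared.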

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
patel-zeel commented, Mar 30, 2022

Sure, thank you!

1 reaction
dfm commented, Mar 29, 2022

I should say: feel free to open an issue on the JAX repo if you don’t like that, but this isn’t a tinygp question!
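
Whichever factorization is used, the usual JAX practice is to give each random draw its own key, for example by splitting one key with `jax.random.split`. A minimal sketch of the reproduction's sampling step with separate keys for the inputs and the sample, calling `jax.random.multivariate_normal` directly so it runs without tinygp; `exp_squared` and `sample_gp` are hypothetical helpers, not part of tinygp or JAX.

```python
import jax
import jax.numpy as jnp

def exp_squared(x1, x2, amp=4.0, scale=0.5):
    # Hypothetical hand-written squared-exponential kernel (not tinygp's API)
    r = (x1[:, None, 0] - x2[None, :, 0]) / scale
    return amp * jnp.exp(-0.5 * r**2)

def sample_gp(seed, n=256, method="cholesky"):
    # Split one seed into independent keys: one for the inputs, one for the sample
    key_x, key_y = jax.random.split(jax.random.PRNGKey(seed))
    x = 5 * jax.random.normal(key_x, shape=(n, 1))
    cov = exp_squared(x, x) + jnp.eye(n)  # kernel plus unit diagonal noise
    y = jax.random.multivariate_normal(key_y, jnp.zeros(n), cov, method=method)
    return x, y

x, y_chol = sample_gp(0, method="cholesky")
_, y_svd = sample_gp(0, method="svd")
```

With independent keys the two methods still return different vectors for the same seed, but each is a valid draw from the same distribution, so plots made this way should look statistically alike.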
