Sparse covariance matrix fails
Hi! I am reproducing the tutorial for counts clustering.
Requesting the covariance matrix in full (dense) form with sparse=False:
mu, cov = jc.angular_cl.gaussian_cl_covariance_and_mean(cosmo, ell, probes, sparse=False)
raises the error
TypeError: _transpose() got an unexpected keyword argument 'axes'
which comes from this line in angular_cl.py:
cov_mat = cov_mat.transpose(axes=(0, 2, 1, 3)).reshape((n_ell * n_cls, n_ell * n_cls))
I changed that source line to
cov_mat = np.transpose(cov_mat, axes=(0, 2, 1, 3)).reshape((n_ell * n_cls, n_ell * n_cls))
and the error went away. I am not sure, but it is probably related to newer versions of jax.numpy.
The returned dense matrix is equal to jc.sparse.to_dense(cov_sparse).
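Here is a minimal sketch of why the fix works. It assumes the sparse covariance is stored as an (n_cls, n_cls, n_ell) array of diagonal blocks, which is our reading of the format jc.sparse.to_dense expands. The helpers dense_from_sparse and dense_reference are hypothetical and not part of jax-cosmo.

The sketch shows two things. The function form jnp.transpose accepts the axes keyword, while the array method takes the axes positionally. Transposing and then reshaping also gives the same block matrix as assembling it block by block.

```python
import jax
import jax.numpy as jnp

def dense_from_sparse(cov_sparse):
    """Hypothetical helper: expand (n_cls, n_cls, n_ell) diagonal blocks
    into a dense (n_cls * n_ell, n_cls * n_ell) matrix."""
    n_cls, _, n_ell = cov_sparse.shape
    # Turn each length-n_ell vector into an (n_ell, n_ell) diagonal block.
    blocks = jax.vmap(jax.vmap(jnp.diag))(cov_sparse)  # (n_cls, n_cls, n_ell, n_ell)
    # The function form accepts `axes=`; the array method takes the axes
    # positionally instead, e.g. blocks.transpose(0, 2, 1, 3).
    blocks = jnp.transpose(blocks, axes=(0, 2, 1, 3))  # (n_cls, n_ell, n_cls, n_ell)
    return blocks.reshape((n_cls * n_ell, n_cls * n_ell))

def dense_reference(cov_sparse):
    """Hypothetical helper: the same matrix assembled block by block."""
    n_cls = cov_sparse.shape[0]
    return jnp.block([[jnp.diag(cov_sparse[i, j]) for j in range(n_cls)]
                      for i in range(n_cls)])

cov_sparse = jnp.arange(2 * 2 * 3, dtype=jnp.float32).reshape((2, 2, 3))
dense = dense_from_sparse(cov_sparse)
print(dense.shape)                                              # (6, 6)
print(bool(jnp.allclose(dense, dense_reference(cov_sparse))))  # True
```

Block (i, j) of the dense matrix is the diagonal matrix built from cov_sparse[i, j]. Moving the ell axes next to their cl axes before reshaping is what produces that layout.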
Package versions are:
JAX version: 0.2.18
jax-cosmo version: 0.1rc7
By the way: when I run the tutorial, the Cl calculations take ~40 s with jax-cosmo, compared with 0.2 s with CCL, for the 1x1, 1x2 and 2x2 cross-correlations. What am I doing wrong?
Thanks.
Top GitHub Comments
Not a lame question 😃 That's a great question! Well, it has to do with how JAX works. When you jit a function, you "compile" it for a specific shape and type of input, and JAX also needs those inputs to be something it understands internally, i.e. simple arrays. So if your function takes as input some complicated object that has no translation into a simple representation JAX understands, jitting it will fail. It is fine, however, to use that function inside a larger function with simple inputs and outputs. Jitting also takes time. For instance, if you apply your Cl function to arrays of ℓ of different sizes, JAX will re-jit the code every time, because each new input size looks to JAX like a different function.
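Here is a small sketch of the re-jitting behaviour described above. toy_cl is a hypothetical stand-in for a Cl computation, not a jax-cosmo function. The Python-side list append runs only while JAX traces the function, so the list records each (re)compilation.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the input shape every time JAX (re)traces the function

@jax.jit
def toy_cl(ell):
    # Hypothetical stand-in for a Cl computation. This Python-side append runs
    # only during tracing, not when a cached compiled version is reused.
    trace_log.append(ell.shape)
    return 1.0 / (ell * (ell + 1.0))

toy_cl(jnp.linspace(10.0, 100.0, 50))   # first call, shape (50,): traced and compiled
toy_cl(jnp.linspace(20.0, 200.0, 50))   # same shape and dtype: cached version reused
toy_cl(jnp.linspace(10.0, 100.0, 80))   # new shape (80,): traced and compiled again
print(trace_log)  # [(50,), (80,)]
```

Different ℓ values with the same array size reuse the compiled version. Only the change of size triggers a second compilation.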
Looong story short, I think the best practice is to leave it to the end user to decide where and what to jit, so there is minimal or no jitting at all in the library itself.
In practice, if you want to run an MCMC, for example, you would jit only the full function that computes the likelihood for a given set of cosmological parameters.
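Here is a minimal sketch of jitting only the outer likelihood. ToyCosmology, toy_cl and log_likelihood are hypothetical names, not jax-cosmo's API, and the model is a toy. The object is built inside the jitted function, so only plain arrays cross the jit boundary. Jitting toy_cl directly on the object raises a TypeError, because JAX cannot interpret an unregistered Python object as an array.

```python
import jax
import jax.numpy as jnp

class ToyCosmology:
    """Hypothetical stand-in for a cosmology object (not a JAX pytree)."""
    def __init__(self, omega_c, sigma8):
        self.omega_c = omega_c
        self.sigma8 = sigma8

def toy_cl(cosmo, ell):
    # Hypothetical toy model; deliberately not jitted on its own.
    return cosmo.sigma8**2 * cosmo.omega_c / (ell * (ell + 1.0))

@jax.jit
def log_likelihood(params, ell, data, sigma):
    # Only plain arrays cross the jit boundary; the object is built inside.
    cosmo = ToyCosmology(omega_c=params[0], sigma8=params[1])
    model = toy_cl(cosmo, ell)
    return -0.5 * jnp.sum(((data - model) / sigma) ** 2)

ell = jnp.arange(10.0, 20.0)
params_true = jnp.array([0.25, 0.8])
data = toy_cl(ToyCosmology(0.25, 0.8), ell)  # noiseless mock data

print(log_likelihood(params_true, ell, data, 1e-4))            # close to 0 at the true parameters
print(jax.grad(log_likelihood)(params_true, ell, data, 1e-4))  # gradient w.r.t. params

# Jitting the inner function directly fails: the object is not an array type.
try:
    jax.jit(toy_cl)(ToyCosmology(0.25, 0.8), ell)
except TypeError as err:
    print("TypeError:", err)
```

An MCMC sampler would then call log_likelihood (or its gradient) repeatedly with parameter arrays of a fixed shape. In that case the function compiles once and is reused.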
Hi @EiffL, I also noticed a typo in the likelihood module: instead of include_logdet, you should use ignore_logdet as the name of the parameter in the function gaussian_log_likelihood.