Support saving `InferenceData` that has `MultiIndex` coordinates as netcdf
Tell us about it
There are many situations in which it is very convenient to use a `pandas.MultiIndex` as coordinates of an `xarray.DataArray`. The problem is that, at the moment, `xarray` doesn't provide a built-in way to save these indexes in netcdf format. Take for example:
```python
from arviz.tests.helpers import create_model

idata = create_model()
idata.posterior = idata.posterior.stack(sample=["chain", "draw"])
idata.to_netcdf("test.nc")
```
This raises a `NotImplementedError` with the following traceback:
```
NotImplementedError Traceback (most recent call last)
<ipython-input-15-43e455b97609> in <module>
3 idata = create_model()
4 idata.posterior = idata.posterior.stack(sample=["chain", "draw"])
----> 5 idata.to_netcdf("test.nc")
~/anaconda3/lib/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
442 if _compressible_dtype(values.dtype)
443 }
--> 444 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
445 data.close()
446 mode = "a"
~/anaconda3/lib/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
1898 from ..backends.api import to_netcdf
1899
-> 1900 return to_netcdf(
1901 self,
1902 path,
~/anaconda3/lib/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
1070 # TODO: allow this work (setting up the file for writing array data)
1071 # to be parallelized with dask
-> 1072 dump_to_store(
1073 dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
1074 )
~/anaconda3/lib/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
1117 variables, attrs = encoder(variables, attrs)
1118
-> 1119 store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
1120
1121
~/anaconda3/lib/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
259 writer = ArrayWriter()
260
--> 261 variables, attributes = self.encode(variables, attributes)
262
263 self.set_attributes(attributes)
~/anaconda3/lib/python3.9/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
348 # All NetCDF files get CF encoded by default, without this attempting
349 # to write times, for example, would fail.
--> 350 variables, attributes = cf_encoder(variables, attributes)
351 variables = {k: self.encode_variable(v) for k, v in variables.items()}
352 attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}
~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
857 _update_bounds_encoding(variables)
858
--> 859 new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
860
861 # Remove attrs from bounds variables (issue #2921)
~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in <dictcomp>(.0)
857 _update_bounds_encoding(variables)
858
--> 859 new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
860
861 # Remove attrs from bounds variables (issue #2921)
~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
262 A variable which has been encoded as described above.
263 """
--> 264 ensure_not_multiindex(var, name=name)
265
266 for coder in [
~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in ensure_not_multiindex(var, name)
177 def ensure_not_multiindex(var, name=None):
178 if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex):
--> 179 raise NotImplementedError(
180 "variable {!r} is a MultiIndex, which cannot yet be "
181 "serialized to netCDF files "
NotImplementedError: variable 'sample' is a MultiIndex, which cannot yet be serialized to netCDF files (https://github.com/pydata/xarray/issues/1077). Use reset_index() to convert MultiIndex levels into coordinate variables instead.
```
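As the error message itself suggests, a manual workaround is to convert the `MultiIndex` levels back into plain coordinate variables before writing, at the cost of having to rebuild the index by hand after loading:

```python
# Workaround hinted at by the error message: drop the MultiIndex before saving.
idata.posterior = idata.posterior.reset_index("sample")
idata.to_netcdf("test.nc")
```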
Thoughts on implementation
I had a look at the xarray issue mentioned above, and the approach suggested by @dcherian works (at least in the scenario I had to work with a month ago). I think it would be good to incorporate something like that into `arviz.from_netcdf` and `InferenceData.to_netcdf`. The basic idea is to convert the `MultiIndex` into a simple array of integers, the codes of the `MultiIndex`, and to add an attribute stating that the dimension/coordinates were originally a `MultiIndex`. This attribute is also used to keep track of the level values and names of the original `MultiIndex`. The modified data structure can be serialized to netcdf without any problems. The only thing to be aware of is that, when the netcdf file is loaded, some work has to happen to rebuild the `MultiIndex` from the original coordinates. I think this small overhead is worth the benefit of bringing `MultiIndex` support to arviz.
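A minimal sketch of what that encode/decode round trip could look like is below. It is only an illustration under assumptions: the helper names `encode_multiindex`/`decode_to_multiindex` and the `multiindex_levels` attribute key are hypothetical (not existing arviz or xarray API), it is not necessarily the exact encoding from the linked xarray issue, and missing values in the index are not handled.

```python
import numpy as np
import pandas as pd
import xarray as xr


def encode_multiindex(ds: xr.Dataset, dim: str) -> xr.Dataset:
    """Replace the MultiIndex on `dim` with plain integers, keeping the level
    values as ordinary coordinates and recording the level names in an attribute."""
    idx = ds.indexes.get(dim)
    if not isinstance(idx, pd.MultiIndex):
        return ds
    encoded = ds.reset_index(dim)  # level values become plain coordinates along dim
    encoded = encoded.assign_coords({dim: np.arange(idx.size)})
    # hypothetical attribute key flagging that this dimension was a MultiIndex
    encoded[dim].attrs["multiindex_levels"] = ",".join(idx.names)
    return encoded


def decode_to_multiindex(ds: xr.Dataset, dim: str) -> xr.Dataset:
    """Rebuild the MultiIndex from the level coordinates stored alongside `dim`."""
    levels = ds[dim].attrs.get("multiindex_levels")
    if not levels:
        return ds
    return ds.set_index({dim: levels.split(",")})


# Round trip of the stacked posterior group from the example above:
encoded = encode_multiindex(idata.posterior, "sample")
encoded.to_netcdf("posterior.nc")
restored = decode_to_multiindex(xr.open_dataset("posterior.nc"), "sample")
```

Something along these lines, applied per group in `InferenceData.to_netcdf` and reversed in `arviz.from_netcdf`, could handle the round trip.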
If you all agree that this would be valuable, I can write a PR.
Top GitHub Comments
I'm not sure if we want to deviate from the netcdf4 spec.
We could have functionality to transform to and from a MultiIndex with suitable info in attrs, but it wouldn't then be part of the official spec.
I think this would be a great PR to xarray!