map_blocks output inference problems
I am excited about using `map_blocks` to overcome a long-standing challenge: calculating climatologies / anomalies with dask arrays. However, I hit what feels like a bug. I don't love how the new `map_blocks` function does this:
> The function will first be run on mocked-up data that looks like `obj` but has sizes 0, to determine properties of the returned object such as dtype, variable names, new dimensions and new indexes (if any).
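To make the mechanism concrete, here is a minimal sketch (not xarray's exact internals) of what a size-0 stand-in looks like and why a function that reduces over the data fails on it. The zero-length slice below is an assumption standing in for the mocked-up object:

```python
import numpy as np
import xarray as xr

# A small dataset, then a zero-size slice along 'time' to mimic the
# mocked-up object map_blocks builds for inference (a sketch, not the
# exact internal mechanism).
ds = xr.Dataset({"tas": ("time", np.arange(3.0))})
meta = ds.isel(time=slice(0, 0))  # sizes: {'time': 0}

raised = False
try:
    # A plain reduction over zero-size data raises ValueError in numpy,
    # so any user function doing this fails during template inference.
    meta["tas"].values.max()
except ValueError:
    raised = True
print(raised)  # True
```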
The problem is that many functions will simply error on size-0 data, as in the example below.
MCVE Code Sample
```python
import xarray as xr

ds = xr.tutorial.load_dataset('rasm').chunk({'y': 20})

def calculate_anomaly(ds):
    # needed to work around xarray's check with zero dimensions
    # if len(ds['time']) == 0:
    #     return ds
    gb = ds.groupby("time.month")
    clim = gb.mean(dim='T')
    return gb - clim

xr.map_blocks(calculate_anomaly, ds)
```
Raises
```
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/dataset.py in _construct_dataarray(self, name)
   1145     try:
-> 1146         variable = self._variables[name]
   1147     except KeyError:
KeyError: 'time.month'

During handling of the above exception, another exception occurred:

AttributeError Traceback (most recent call last)
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/parallel.py in infer_template(func, obj, *args, **kwargs)
     77     try:
---> 78         template = func(*meta_args, **kwargs)
     79     except Exception as e:
<ipython-input-40-d7b2b2978c29> in calculate_anomaly(ds)
      5     #     return ds
----> 6     gb = ds.groupby("time.month")
      7     clim = gb.mean(dim='T')
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/common.py in groupby(self, group, squeeze, restore_coord_dims)
    656         return self._groupby_cls(
--> 657             self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims
    658         )
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/groupby.py in __init__(self, obj, group, squeeze, grouper, bins, restore_coord_dims, cut_kwargs)
    298             )
--> 299         group = obj[group]
    300         if len(group) == 0:
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/dataset.py in __getitem__(self, key)
   1235         if hashable(key):
-> 1236             return self._construct_dataarray(key)
   1237         else:
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/dataset.py in _construct_dataarray(self, name)
   1148         _, name, variable = _get_virtual_variable(
-> 1149             self._variables, name, self._level_coords, self.dims
   1150         )
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/dataset.py in _get_virtual_variable(variables, key, level_vars, dim_sizes)
    157         else:
--> 158             data = getattr(ref_var, var_name).data
    159         virtual_var = Variable(ref_var.dims, data)
AttributeError: 'IndexVariable' object has no attribute 'month'

The above exception was the direct cause of the following exception:

Exception Traceback (most recent call last)
<ipython-input-40-d7b2b2978c29> in <module>
      8     return gb - clim
      9
---> 10 xr.map_blocks(calculate_anomaly, ds)
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/parallel.py in map_blocks(func, obj, args, kwargs)
    203     input_chunks = dataset.chunks
    204
--> 205     template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs)
    206     if isinstance(template, DataArray):
    207         result_is_array = True
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/parallel.py in infer_template(func, obj, *args, **kwargs)
     80         raise Exception(
     81             "Cannot infer object returned from running user provided function."
---> 82         ) from e
     83
     84     if not isinstance(template, (Dataset, DataArray)):
Exception: Cannot infer object returned from running user provided function.
```
Problem Description
We should try to imitate what dask does in `map_blocks`: https://docs.dask.org/en/latest/array-api.html#dask.array.map_blocks
Specifically:
- We should allow the user to override the checks by explicitly specifying output dtype and shape
- Maybe the check should be run on small, rather than zero-size, test data
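For comparison, `dask.array.map_blocks` already accepts such overrides: declaring the output `dtype` (and `meta`, when the block structure changes) means dask does not have to rely solely on running the function on dummy blocks. A small sketch:

```python
import numpy as np
import dask.array as da

x = da.ones((4, 4), chunks=(2, 2))

# Passing dtype tells dask what the output looks like up front, so
# inference does not depend on the function surviving dummy inputs.
y = x.map_blocks(lambda b: b * 2, dtype=np.float64)
print(float(y.sum().compute()))  # 32.0
```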
Output of xr.show_versions()
INSTALLED VERSIONS
commit: None
python: 3.7.3 | packaged by conda-forge | (default, Jul 1 2019, 21:52:21) [GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 4.14.138+
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: en_US.UTF-8
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.5
libnetcdf: 4.6.2

xarray: 0.14.0
pandas: 0.25.3
numpy: 1.17.3
scipy: 1.3.2
netCDF4: 1.5.1.2
pydap: installed
h5netcdf: 0.7.4
h5py: 2.10.0
Nio: None
zarr: 2.3.2
cftime: 1.0.4.2
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: 1.0.25
cfgrib: None
iris: 2.2.0
bottleneck: 1.3.0
dask: 2.7.0
distributed: 2.7.0
matplotlib: 3.1.2
cartopy: 0.17.0
seaborn: 0.9.0
numbagg: None
setuptools: 41.6.0.post20191101
pip: 19.3.1
conda: None
pytest: 5.3.1
IPython: 7.9.0
sphinx: None
Issue Analytics
- State:
- Created 4 years ago
- Reactions: 2
- Comments: 6 (6 by maintainers)
Top GitHub Comments
With #3816, this becomes
@rabernat How does this look to you?
This is why I didn't do it for the first pass.