Poor optimization of slicing operations on netCDF-backed xarray datasets
We have encountered a performance issue in xarray (pydata/xarray#1396) which appears to be a dask optimization issue at its core. In correspondence on the xarray mailing list, @mrocklin recommended I open an issue here.
The scenario is the following: I have a collection of large netCDF files on disk. Each file can be thought of as representing geospatial data: the first dimension is time, and the higher dimensions are spatial dimensions (e.g. x and y). I want to open the data as a single xarray dataset concatenated along the time dimension, select a single point in space, and lazily load the data for that point into memory, creating a timeseries. (This is an extremely common workflow in climate science and related fields.) xarray handles this by opening the files, wrapping the data in dask arrays, and then concatenating those arrays together, under the hood calling dask.array.concatenate.
The problem is that the entire array from each file is being read eagerly before the slicing operation. This incurs a huge performance and memory penalty for what should naively be a very cheap operation. For very large datasets, it leads to memory overruns and total breakdown of the workflow.
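To make the intended lazy behaviour concrete, here is a minimal pure-dask sketch of the pattern xarray builds under the hood: one dask array per file, concatenated along the time axis, then sliced at a single spatial point. The shapes and names below are illustrative only and far smaller than the real data.
import numpy as np
import dask.array as da
# One dask array per "file" (tiny shapes here just to keep the sketch cheap).
per_file = [da.from_array(np.random.rand(12, 100, 100), chunks=(12, 100, 100))
            for _ in range(3)]
combined = da.concatenate(per_file, axis=0)   # time axis grows to length 36
point_ts = combined[:, 20, 30]                # should only touch one column of each block
print(point_ts.compute().shape)               # (36,)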
Below is a reproducible example which requires xarray. A more detailed version of this example can be found in this notebook.
from glob import glob
import numpy as np
import xarray as xr
from dask.diagnostics import ProgressBar
### Create some test data
# This test data is typical of many climate datasets. Each file has 12 entries in the time
# dimension, representing e.g. each month of the year. There are 10 separate files,
# representing e.g. consecutive years. The code below writes about 9GB of data to disk.
# Change data_dir to a place where you can put 9GB of data
data_dir = '/data/scratch/rpa/xarray_test_data/'
nfiles = 10
nt = 12
for n in range(nfiles):
    data = np.random.rand(nt, 1000, 10000)
    time = (n * nt) + np.arange(nt)
    da = xr.DataArray(data, dims=['time', 'y', 'x'],
                      coords={'time': time})
    da.to_dataset(name='data').to_netcdf(data_dir + 'test_data.%03d.nc' % n)
### open the data via xarray
all_files = glob(data_dir + '*.nc')
ds = xr.open_mfdataset(all_files)
### Time loading the whole dataset into memory
# As a point of reference, let's see how long it takes to read all the data from disk and load it
# into memory. (We make a deep copy of the dataset first in order to preserve the original
# chunked dataset for later use.)
ds_copy = ds.copy(deep=True)
with ProgressBar():
    # on a high-performance workstation, this takes 15-20s
    %time _ = ds_copy.load()
### Time loading a single point
# This is the crux issue. We now want to just extract a timeseries of single point in space.
# If the loading is lazy, it should go much faster and use much less memory compared to
# the example above.
y, x = 200, 300
with ProgressBar():
    %time ts = ds.data[:, y, x].load()
The extraction of a single point takes about half the time of loading the whole dataset! Furthermore, by monitoring memory usage via top, I observe that resident memory increases sharply as the command executes. If I just loop through all_files manually and extract the point (see the sketch below), it goes 500-1000 times faster and uses a tiny fraction of the memory. I explore this, and many other permutations (including bypassing xarray and calling dask.array.concatenate directly on the underlying dask arrays of each dataset), in my notebook. I believe this is a dask issue, not an xarray issue.
Here is the partial dot graph for ds.data[:, y, x].dask:
Here is the partial dot graph for result._optimize(result.dask, result._keys()), where result = ds.data[:, y, x]:
The root cause of the issue, in the words of @mrocklin, is that “it looks like there is a getarray followed by a getitem. We’ll have to dive into why these two weren’t fused into a single getarray/getitem call.”
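One way to see whether that fusion happens is to compare the raw and optimized task graphs. Below is a minimal sketch, assuming a recent dask and reusing the ds, y and x objects from the example above; the variable names are illustrative:
import dask
ts_lazy = ds.data[:, y, x]    # lazy single-point selection (an xarray DataArray)
darr = ts_lazy.data           # the underlying dask array
# Compare task counts before and after graph optimization; if the getter and
# getitem tasks fuse, the optimized graph should be much smaller.
print('tasks before optimization:', len(dict(darr.__dask_graph__())))
(opt,) = dask.optimize(darr)
print('tasks after optimization: ', len(dict(opt.__dask_graph__())))
# Render both graphs for visual comparison (requires graphviz).
darr.visualize(filename='unoptimized.svg', optimize_graph=False)
darr.visualize(filename='optimized.svg', optimize_graph=True)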
>>> print(xarray.__version__)
0.9.5-9-gd5c7e06
>>> print(dask.__version__)
0.14.3+19.gbcd0426
Top GitHub Comments
This has nothing to do with xarray, it’s an issue with the internal optimizations of dask.
See #2364.