General curve fitting method
See original GitHub issue
Xarray should have a general curve-fitting function as part of its main API.
Motivation
Yesterday I wanted to fit a simple decaying exponential function to the data in a DataArray and realised there currently isn’t an immediate way to do this in xarray. You have to either pull out the .values (losing the power of dask), or use apply_ufunc (complicated).
This is an incredibly common, domain-agnostic task, so although I don’t think we should support various kinds of unusual optimisation procedures (which could always go in an extension package instead), I think a basic fitting method is within scope for the main library. There are SO questions asking how to achieve this.
We already have .polyfit and polyval anyway, which are more specific. (@AndrewWilliams3142 and @aulemahal I expect you will have thoughts on how to implement this generally.)
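For context, the .values workaround mentioned above might look roughly like the sketch below. The synthetic data, the initial guess p0, and the positional-parameter form of exponential_decay are assumptions made for illustration; they are not taken from the issue.

```python
import numpy as np
import xarray as xr
from scipy.optimize import curve_fit


def exponential_decay(x, A, L):
    # curve_fit expects the form f(xdata, *params)
    return A * np.exp(-x / L)


# Synthetic, noise-free data (assumed for this sketch): A=3, L=2
x = np.linspace(0.0, 10.0, 50)
da = xr.DataArray(3.0 * np.exp(-x / 2.0), coords={"x": x}, dims="x")

# Dropping to NumPy: .values loads everything into memory,
# so any dask-backed laziness is lost at this point.
popt, pcov = curve_fit(exponential_decay, da["x"].values, da.values, p0=[1.0, 1.0])
A_fit, L_fit = popt
```

The result is a bare NumPy array of parameters, so dimension names and coordinates have to be re-attached by hand, which is the inconvenience this issue is about.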
Proposed syntax
I want something like this to work:
def exponential_decay(xdata, A=10, L=5):
    return A*np.exp(-xdata/L)
# returns a dataset containing the optimised values of each parameter
fitted_params = da.fit(exponential_decay)
fitted_line = exponential_decay(da.x, A=fitted_params['A'], L=fitted_params['L'])
# Compare
# ax is assumed to be an existing matplotlib Axes
da.plot(ax=ax)
fitted_line.plot(ax=ax)
It would also be nice to be able to fit in multiple dimensions. That means both for example fitting a 2D function to 2D data:
def hat(xdata, ydata, h=2, r0=1):
    r = xdata**2 + ydata**2
    return h*np.exp(-r/r0)
fitted_params = da.fit(hat)
fitted_hat = hat(da.x, da.y, h=fitted_params['h'], r0=fitted_params['r0'])
but also repeatedly fitting a 1D function to 2D data:
# da now has a y dimension too
fitted_params = da.fit(exponential_decay, fit_along=['x'])
# As fitted_params now has y-dependence, broadcasting means fitted_lines does too
fitted_lines = exponential_decay(da.x, A=fitted_params.A, L=fitted_params.L)
The latter would be useful for fitting the same curve to multiple model runs, but means we need some kind of fit_along or dim argument, which would default to all dims.
So the method docstring would end up like
def fit(self, f, fit_along=None, skipna=None, full=False, cov=False):
    """
    Fits the function f to the DataArray.

    Expects the function f to have a signature like
    `result = f(*coords, **params)`
    for example
    `result_da = f(da.xcoord, da.ycoord, da.zcoord, A=5, B=None)`
    The names of the `**params` kwargs will be used to name the output variables.

    Returns
    -------
    fit_results - A single dataset which contains the variables (for each parameter in the fitting function):
        `param1`
            The optimised fit coefficients for parameter one.
        `param1_residuals`
            The residuals of the fit for parameter one.
        ...
    """
Questions
- Should it wrap scipy.optimize.curve_fit, or reimplement it?
  Wrapping it is simpler, but as it just calls least_squares under the hood, reimplementing it would mean we could use the dask-powered version of least_squares (like da.polyfit does).
- What form should we expect the curve-defining function to come in?
  scipy.optimize.curve_fit expects the curve to act as ydata = f(xdata, *params) + eps, but in xarray xdata could be one or multiple coords or dims, not necessarily a single array. Might it work to require a signature like result_da = f(da.xcoord, da.ycoord, da.zcoord, ..., **params)? Then the .fit method would work out how many coords to pass to f based on the dimension of the da and the fit_along argument. But then the order of coord arguments in the signature of f would matter, which doesn’t seem very xarray-like.
- Is it okay to inspect parameters of the curve-defining function?
  If we tell the user the curve-defining function has to have a signature like da = func(*coords, **params), then we could read the names of the parameters by inspecting the function kwargs. Is that a good idea or might it end up being unreliable? Is the inspect standard library module the right thing to use for that? This could also be used to provide default guesses for the fitting parameters.
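To make the last two questions concrete, here is a minimal sketch of how inspect.signature could separate coordinate arguments from fitting parameters, and how a function written as f(*coords, **params) could be adapted to the f(xdata, *params) form that curve_fit expects. The helper names split_signature and as_positional are hypothetical, and the rule "arguments without a default are coords, arguments with a default are parameters" is an assumption of this sketch rather than anything decided in the issue.

```python
import inspect
import math


def split_signature(func):
    """Return (coord_names, param_defaults) for func(*coords, **params).

    Assumption: arguments without a default are coordinates, and
    arguments with a default are fitting parameters (the default
    doubles as the initial guess). *args/**kwargs are not handled.
    """
    coord_names = []
    param_defaults = {}
    for name, p in inspect.signature(func).parameters.items():
        if p.default is inspect.Parameter.empty:
            coord_names.append(name)
        else:
            param_defaults[name] = p.default
    return coord_names, param_defaults


def as_positional(func, param_names):
    """Adapt func(*coords, **params) to the form wrapped(coords, *values)."""
    def wrapped(coords, *values):
        # coords is a tuple of coordinate values; values follow param_names order
        return func(*coords, **dict(zip(param_names, values)))
    return wrapped


def hat(xdata, ydata, h=2, r0=1):
    r = xdata**2 + ydata**2
    return h * math.exp(-r / r0)


coord_names, param_defaults = split_signature(hat)
hat_positional = as_positional(hat, list(param_defaults))
peak = hat_positional((0.0, 0.0), 3.0, 1.0)  # same as hat(0.0, 0.0, h=3.0, r0=1.0)
```

Because the parameter order is read from the signature, an explicit argument for overriding the names would still be useful when the function is a lambda, a functools.partial, or otherwise hard to introspect.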
Issue Analytics
- State:
- Created 3 years ago
- Reactions: 4
- Comments:9 (9 by maintainers)
I needed this functionality for a project, and piggy-backing off the last couple of comments decided the curve_fit wrapped by apply_ufunc approach works quite well. I put together a PR in #4849. Any feedback welcome!

+1 for just wrapping the existing functionality in SciPy for now. If we want a version of curve_fit that supports dask, I would suggest implementing curve_fit with dask first, and then using that from xarray.

I am OK with using inspect from the standard library for determining default parameter names. inspect.signature is reasonably robust. But there should definitely be an optional argument for setting parameter names explicitly.
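As an illustration of the general idea of wrapping curve_fit with apply_ufunc, the sketch below repeatedly fits a 1D curve along x for every value of y. It is not the implementation from the PR; the helper _fit_1d, the synthetic data, the initial guess, and the "param" dimension name are all assumptions made for this example.

```python
import numpy as np
import xarray as xr
from scipy.optimize import curve_fit


def exponential_decay(x, A, L):
    return A * np.exp(-x / L)


def _fit_1d(y, x):
    # Hypothetical helper: y and x arrive as 1D NumPy arrays along the core dim
    popt, _ = curve_fit(exponential_decay, x, y, p0=[1.0, 1.0])
    return popt


# Synthetic 2D data (assumed): three decays with amplitudes 1, 2, 3 and L=2
x = np.linspace(0.0, 10.0, 50)
amplitudes = xr.DataArray([1.0, 2.0, 3.0], dims="y")
da = amplitudes * np.exp(-xr.DataArray(x, coords={"x": x}, dims="x") / 2.0)

popt = xr.apply_ufunc(
    _fit_1d,
    da,
    da["x"],
    input_core_dims=[["x"], ["x"]],   # fit along x
    output_core_dims=[["param"]],     # one output value per fitted parameter
    vectorize=True,                   # loop over the remaining dims (here y)
)

# Name the parameters and split them into one variable each
popt = popt.assign_coords(param=["A", "L"])
fitted_params = popt.to_dataset(dim="param")
```

Here fitted_params["A"] and fitted_params["L"] keep the y dimension, so broadcasting them against da.x reproduces the fitted_lines pattern from the proposed syntax.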