Refactor BaseEstimator.get_params for easier use in subclasses
Below is a proposal for a minor refactoring of BaseEstimator.get_params to make it easier to use in subclasses. I would like to hear your opinions 😃
If it sounds like a sensible idea, I would be glad to open a PR.
Issue
Currently, invoking get_params from custom subclasses can produce an incomplete list of parameters, depending on how you define the subclass __init__.
For example, if you have, say, the following class:
from sklearn.base import BaseEstimator

class SomeSKLearnEstimator(BaseEstimator):
    def __init__(self, param1=1, param2='foo'):
        self.param1 = param1
        self.param2 = param2
and want to make a subclass from it, you might want to do something like:
class MySubclass(SomeSKLearnEstimator):
    def __init__(self, myparam=42, **kwargs):
        super().__init__(**kwargs)
        self.myparam = myparam
This, however, produces an incomplete parameter dict when calling get_params(), which can be surprising:
x = MySubclass(myparam=42, param1=2, param2='bar')
print(x.get_params())
# {'myparam': 42}
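The reason for the missing entries can be seen without scikit-learn at all. The sketch below is a simplified stand-in for the signature introspection, not the library's actual code; `Parent`, `Child`, and `named_init_params` are hypothetical names used only for illustration. It shows that names swallowed by `**kwargs` never appear among the explicitly named `__init__` parameters.

```python
import inspect


class Parent:
    def __init__(self, param1=1, param2="foo"):
        self.param1 = param1
        self.param2 = param2


class Child(Parent):
    def __init__(self, myparam=42, **kwargs):
        super().__init__(**kwargs)
        self.myparam = myparam


def named_init_params(cls):
    """Return the explicitly named __init__ parameters of cls, sorted.

    Hypothetical helper: it skips 'self' and the **kwargs catch-all,
    so it only sees names written out in the signature.
    """
    signature = inspect.signature(cls.__init__)
    return sorted(
        p.name
        for p in signature.parameters.values()
        if p.name != "self" and p.kind != p.VAR_KEYWORD
    )


print(named_init_params(Parent))  # ['param1', 'param2']
print(named_init_params(Child))   # ['myparam']
```

Because only the subclass signature is inspected, param1 and param2 are invisible even though they were set on the instance.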
You could work around it by copying the keyword arguments from the parent class and passing them via super():
class MySubclass(SomeSKLearnEstimator):
    def __init__(self, myparam=42, param1=1, param2='foo'):
        # These params can be a lot!
        # Duplicated code.
        super().__init__(param1=param1, param2=param2)
        self.myparam = myparam

x = MySubclass(myparam=42, param1=2, param2='bar')
print(x.get_params())
# {'myparam': 42, 'param1': 2, 'param2': 'bar'}
However, the parameters can be many, and this leads to code duplication.
Proposal
Refactor BaseEstimator.get_params so that client code has access to the parent class’ __init__ and has more control over how to use it together with its own subclass parameters. Specifically:
- Make _get_param_names a static method that accepts a class.
- Extract the core logic of get_params into an auxiliary method, _get_params_from.
This would look like this:
import inspect


class BaseEstimator:
    """Base class for all estimators in scikit-learn

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @staticmethod
    def _get_param_names(klass):
        """Get parameter names for the estimator"""
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(klass.__init__, 'deprecated_original', klass.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = inspect.signature(init)
        # Consider the constructor parameters excluding 'self'
        parameters = [p for p in init_signature.parameters.values()
                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError("scikit-learn estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s with constructor %s doesn't "
                                   " follow this convention."
                                   % (klass, init_signature))
        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])

    def _get_params_from(self, names, deep=True):
        """Get parameters for this estimator from the given names"""
        out = dict()
        for key in names:
            value = getattr(self, key, None)
            if deep and hasattr(value, 'get_params'):
                # Nested estimators contribute '<key>__<subkey>' entries
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        # The source code now also makes the get_params recipe obvious
        names = self._get_param_names(self.__class__)
        return self._get_params_from(names, deep=deep)

    # ... rest of the class is abbreviated ...
This way, client code could do something like this:
class MySubclass(SomeSKLearnEstimator):
    def __init__(self, myparam=42, **kwargs):
        # Just pass **kwargs.
        # No duplicated code.
        super().__init__(**kwargs)
        self.myparam = myparam

    def get_params(self, deep=True):
        # Gotta override get_params. But it shouldn't be hard.
        # params from parent
        parent_names = self._get_param_names(super())
        params = self._get_params_from(parent_names, deep=deep)
        # Add my params
        names = self._get_param_names(self.__class__)
        params.update(self._get_params_from(names, deep=deep))
        # or just simply
        # params['myparam'] = self.myparam
        return params

x = MySubclass(myparam=42, param1=2, param2='bar')
print(x.get_params())
# {'param1': 2, 'param2': 'bar', 'myparam': 42}
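To try the idea without patching scikit-learn, here is a minimal self-contained sketch of the proposal. `MiniBase`, `Parent`, and `Child` are hypothetical stand-ins, not library classes; the helper bodies are trimmed versions of the code above (the deprecation and varargs checks are omitted), and the parent class is named explicitly rather than passed as super(). Treat it as an illustration under those assumptions rather than the final implementation.

```python
import inspect


class MiniBase:
    """Hypothetical stand-in for BaseEstimator with the two proposed helpers."""

    @staticmethod
    def _get_param_names(klass):
        # Simplified: no deprecation unwrapping, no *args check.
        init = klass.__init__
        if init is object.__init__:
            return []
        params = inspect.signature(init).parameters.values()
        return sorted(
            p.name for p in params
            if p.name != "self" and p.kind != p.VAR_KEYWORD
        )

    def _get_params_from(self, names, deep=True):
        out = {}
        for key in names:
            value = getattr(self, key, None)
            if deep and hasattr(value, "get_params"):
                # Nested objects contribute '<key>__<subkey>' entries.
                nested = value.get_params().items()
                out.update((key + "__" + k, v) for k, v in nested)
            out[key] = value
        return out

    def get_params(self, deep=True):
        names = self._get_param_names(self.__class__)
        return self._get_params_from(names, deep=deep)


class Parent(MiniBase):
    def __init__(self, param1=1, param2="foo"):
        self.param1 = param1
        self.param2 = param2


class Child(Parent):
    def __init__(self, myparam=42, **kwargs):
        super().__init__(**kwargs)
        self.myparam = myparam

    def get_params(self, deep=True):
        # Parent parameters first, then the ones declared by Child itself.
        params = self._get_params_from(self._get_param_names(Parent), deep=deep)
        params.update(
            self._get_params_from(self._get_param_names(Child), deep=deep)
        )
        return params


x = Child(myparam=42, param1=2, param2="bar")
print(x.get_params())
# {'param1': 2, 'param2': 'bar', 'myparam': 42}
```

Naming `Parent` and `Child` explicitly keeps the override stable if `Child` is subclassed again, at the cost of hard-coding the hierarchy.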
Motivation
- Personal motivation:
  - In the current project I’m working on, I would like to extend some sklearn classes via cooperative multiple inheritance, where I want to override the __init__ method but still want get_params to get the params from the parent class.
- Also, other people seem to have encountered the same problem:
Top GitHub Comments
I forgot to answer your question above.
Sadly, super().some_classmethod won’t do that. @classmethod is essentially some_classmethod(type(obj), ...), so it will be bound to the object’s class. You can work around it by first unbinding the method, but it’s a bit hacky:
Not that it matters now though; we now have to move _get_param_names out of the class anyway to reuse it in gaussian_process.kernels.Kernel.
I’m working on the PR now.
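The binding behaviour described in this comment can be reproduced with a small sketch. This is an illustration of the general Python mechanism, not the snippet originally posted in the thread; `Parent`, `Child`, and `owner_name` are hypothetical names. It shows that a classmethod reached through super() still receives the instance's own class, and that going through `__func__` lets you pass a different class by hand.

```python
class Parent:
    @classmethod
    def owner_name(cls):
        # Reports which class the classmethod was bound to.
        return cls.__name__


class Child(Parent):
    def via_super(self):
        # super() finds Parent.owner_name, but binds it to type(self).
        return super().owner_name()

    def via_unbound(self):
        # Unbind with __func__ and pass the desired class explicitly.
        return super().owner_name.__func__(Parent)


c = Child()
print(c.via_super())    # Child
print(c.via_unbound())  # Parent
```

Reaching into `__func__` works, but it hard-codes the parent class and depends on the method being a plain classmethod, which is why a static method or module-level function is the tidier option.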
Open a PR. I’d be in support of some abstraction here, particularly if it can remove some duplication between BaseEstimator and gaussian_process.kernels.Kernel … 😃