Refactor BaseEstimator.get_params for easier use in subclasses

Below is a proposal for a minor refactoring of BaseEstimator.get_params to make it easier to use in subclasses. I would like to hear your opinions 😃 If it sounds like a sensible idea, I would be glad to open a PR.

Issue

Currently, invoking get_params from custom subclasses can produce an incomplete list of parameters, depending on how you define the subclass __init__.

For example, suppose you have the following class:

from sklearn.base import BaseEstimator


class SomeSKLearnEstimator(BaseEstimator):
    def __init__(self, param1=1, param2='foo'):
        self.param1 = param1
        self.param2 = param2

and want to subclass it. You might write something like:

class MySubclass(SomeSKLearnEstimator):
    def __init__(self, myparam=42, **kwargs):
        super().__init__(**kwargs)
        self.myparam = myparam

This however produces an incomplete parameter dict when calling get_params(), which can be surprising:

x = MySubclass(myparam=42, param1=2, param2='bar')
print(x.get_params())
# {'myparam': 42}
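The truncation happens because get_params only introspects the signature of the class's own __init__, and the **kwargs catch-all is filtered out. The sketch below reproduces that introspection with the standard inspect module; init_param_names is a hypothetical helper written for illustration, not part of sklearn's API. It shows that param1 and param2 are stored on the instance but never appear among the names introspection finds:

```python
import inspect


class SomeSKLearnEstimator:
    # Plain stand-in for the parent class above; sklearn is not imported here.
    def __init__(self, param1=1, param2='foo'):
        self.param1 = param1
        self.param2 = param2


class MySubclass(SomeSKLearnEstimator):
    def __init__(self, myparam=42, **kwargs):
        super().__init__(**kwargs)
        self.myparam = myparam


def init_param_names(klass):
    """Hypothetical helper mimicking the filtering in sklearn's
    _get_param_names: keep named constructor arguments, drop 'self'
    and the **kwargs catch-all."""
    signature = inspect.signature(klass.__init__)
    return sorted(p.name for p in signature.parameters.values()
                  if p.name != 'self' and p.kind != p.VAR_KEYWORD)


x = MySubclass(myparam=42, param1=2, param2='bar')
print(init_param_names(SomeSKLearnEstimator))  # ['param1', 'param2']
print(init_param_names(type(x)))               # ['myparam']
print(x.param1, x.param2)                      # 2 bar
```

The values reach the parent's __init__ through **kwargs, but nothing in MySubclass's own signature names them, so a signature-based get_params cannot see them.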

You could work around this by copying the parent class’s keyword arguments into the subclass signature and passing them explicitly to super().__init__():

class MySubclass(SomeSKLearnEstimator):
    def __init__(self, myparam=42, param1=1, param2='foo'):
        # These params can be a lot!
        # Duplicated code.
        super().__init__(param1=param1, param2=param2)
        self.myparam = myparam


x = MySubclass(myparam=42, param1=2, param2='bar')
print(x.get_params())
# {'myparam': 42, 'param1': 2, 'param2': 'bar'}

However, the parent class may have many parameters, and repeating all of them in every subclass leads to code duplication.

Proposal

Refactor BaseEstimator.get_params so that client code can introspect the parent class’s __init__ and has more control over combining those parameters with its own. Specifically:

  1. Make _get_param_names a static method that accepts a class.
  2. Extract the core logic of get_params into an auxiliary method, _get_params_from.

The refactored class would look like this:

import inspect


class BaseEstimator:
    """Base class for all estimators in scikit-learn
    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @staticmethod
    def _get_param_names(klass):
        """Get parameter names for the estimator"""
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(klass.__init__, 'deprecated_original', klass.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = inspect.signature(init)
        # Consider the constructor parameters excluding 'self' and **kwargs
        parameters = [p for p in init_signature.parameters.values()
                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError("scikit-learn estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s with constructor %s doesn't "
                                   " follow this convention."
                                   % (klass, init_signature))
        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])
    
    def _get_params_from(self, names, deep=True):
        """Get parameters for this estimator from the given names"""
        out = dict()
        for key in names:
            value = getattr(self, key, None)
            if deep and hasattr(value, 'get_params'):
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def get_params(self, deep=True):
        """Get parameters for this estimator.
        Parameters
        ----------
        deep : boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.
        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        # The source code now also makes the get_params recipe obvious
        names = self._get_param_names(self.__class__)
        return self._get_params_from(names, deep=deep)
    
    # ... rest of the class is abbreviated ...

This way, client code could do something like this:

class MySubclass(SomeSKLearnEstimator):
    def __init__(self, myparam=42, **kwargs):
        # Just pass **kwargs.
        # No duplicated code.
        super().__init__(**kwargs)
        self.myparam = myparam
        
    def get_params(self, deep=True):
        # Gotta override get_params. But it shouldn't be hard.
        
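        # params from parent: super().__init__ is the parent's __init__
        # bound to self, so its signature can be introspected directly
        parent_names = self._get_param_names(super())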
        params = self._get_params_from(parent_names, deep=deep)
        
        # Add my params
        names = self._get_param_names(self.__class__)
        params.update(self._get_params_from(names, deep=deep))
        # or just simply
        # params['myparam'] = self.myparam
        
        return params


x = MySubclass(myparam=42, param1=2, param2='bar')
print(x.get_params())
# {'param1': 2, 'param2': 'bar', 'myparam': 42}
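Because the proposed pieces are spread across several snippets, here is a condensed, self-contained sketch that wires them together so the example can be run end to end. ProposedBaseEstimator is a hypothetical stand-in for the refactored BaseEstimator (it is not sklearn's class), with docstrings trimmed and the *args check shortened:

```python
import inspect


class ProposedBaseEstimator:
    """Condensed, hypothetical stand-in for the proposed BaseEstimator."""

    @staticmethod
    def _get_param_names(klass):
        # klass only needs an __init__ attribute that can be introspected
        init = getattr(klass.__init__, 'deprecated_original', klass.__init__)
        if init is object.__init__:
            return []
        params = [p for p in inspect.signature(init).parameters.values()
                  if p.name != 'self' and p.kind != p.VAR_KEYWORD]
        if any(p.kind == p.VAR_POSITIONAL for p in params):
            raise RuntimeError("%r must not use *args in __init__" % (klass,))
        return sorted(p.name for p in params)

    def _get_params_from(self, names, deep=True):
        out = {}
        for key in names:
            value = getattr(self, key, None)
            if deep and hasattr(value, 'get_params'):
                # Nested estimators contribute 'key__subparam' entries
                out.update((key + '__' + k, v)
                           for k, v in value.get_params().items())
            out[key] = value
        return out

    def get_params(self, deep=True):
        names = self._get_param_names(type(self))
        return self._get_params_from(names, deep=deep)


class SomeSKLearnEstimator(ProposedBaseEstimator):
    def __init__(self, param1=1, param2='foo'):
        self.param1 = param1
        self.param2 = param2


class MySubclass(SomeSKLearnEstimator):
    def __init__(self, myparam=42, **kwargs):
        super().__init__(**kwargs)
        self.myparam = myparam

    def get_params(self, deep=True):
        # Parent params first: super().__init__ resolves to the parent's
        # __init__ bound to self, whose signature omits 'self'
        params = self._get_params_from(self._get_param_names(super()),
                                       deep=deep)
        # Then this class's own params (here just 'myparam')
        params.update(self._get_params_from(self._get_param_names(type(self)),
                                            deep=deep))
        return params


x = MySubclass(myparam=42, param1=2, param2='bar')
print(x.get_params())  # {'param1': 2, 'param2': 'bar', 'myparam': 42}
```

Passing super() works only because the static method touches nothing but klass.__init__. Passing the parent class explicitly, as in self._get_param_names(SomeSKLearnEstimator), yields the same names because 'self' is filtered out either way, but it hard-codes the parent instead of following the MRO.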

Motivation

  • Personal motivation:
    • In the current project I’m working on, I would like to extend some sklearn classes via cooperative multiple inheritance, overriding the __init__ method while still having get_params return the parameters of the parent class.
  • Also, other people seem to have encountered the same problem:
    1. here
    2. and here

Issue Analytics

  • State: open
  • Created 4 years ago
  • Reactions: 17
  • Comments: 17 (6 by maintainers)

Top GitHub Comments

1 reaction
alegonz commented, Jun 24, 2019

I forgot to answer your question above.

But isn’t that what super().some_classmethod would do?

Sadly, super().some_classmethod won’t do that. A @classmethod accessed through super() is still bound to the object’s class, i.e. it is essentially some_classmethod(type(obj), ...). You can work around it by first unbinding the method, but it’s a bit hacky:

from inspect import signature


class Foo:
    def __init__(self, x=1, y=2):
        pass
    
    @classmethod
    def some_classmethod(cls):
        print("__init__ is", signature(cls.__init__))
        

class Bar(Foo):
    def __init__(self, a=3, b=4):
        pass
    
    def some_method(self):
        super().some_classmethod()
        
    def hacky_method(self):
        super().some_classmethod.__func__(super())
        

b = Bar()
b.some_method()  # prints: __init__ is (self, a=3, b=4)
b.hacky_method() # prints: __init__ is (x=1, y=2)
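The binding rule can also be checked without the __func__ trick. In this variant of the snippet (init_signature, via_super and via_parent_class are illustrative names), the classmethod returns the signature string instead of printing it: calling it through super() still binds cls to Bar, while naming the parent class binds cls to Foo.

```python
import inspect


class Foo:
    def __init__(self, x=1, y=2):
        pass

    @classmethod
    def init_signature(cls):
        return str(inspect.signature(cls.__init__))


class Bar(Foo):
    def __init__(self, a=3, b=4):
        pass

    def via_super(self):
        # super() changes where the lookup starts, but cls is still type(self)
        return super().init_signature()

    def via_parent_class(self):
        # Naming the parent explicitly binds cls to Foo
        return Foo.init_signature()


b = Bar()
print(b.via_super())         # (self, a=3, b=4)
print(b.via_parent_class())  # (self, x=1, y=2)
```

Naming the parent avoids the hack but hard-codes the class, which does not follow the MRO under cooperative multiple inheritance; a helper that takes the class as an explicit argument sidesteps both problems.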

Not that it matters now, though: we have to move _get_param_names out of the class anyway to reuse it in gaussian_process.kernels.Kernel.

I’m working on the PR now.

1 reaction
jnothman commented, Jun 23, 2019

Open a PR. I’d be in support of some abstraction here, particularly if it can remove some duplication between BaseEstimator and gaussian_process.kernels.Kernel… 😃
