Estimator tag overwriting and update in _get_tags


I am updating some code in imbalanced-learn to use estimator tags. I was able to add a new entry to _DEFAULT_TAGS and reuse the implementation of _safe_tags.

I have the following use case:

_DEFAULT_TAGS = {'sample_indices': False}

class BaseClass:
    ...
    def _more_tags(self):
        return {'sample_indices': True}

class SpecialClass(BaseClass):
    ...
    def _more_tags(self):
        tags = super()._more_tags()
        tags['sample_indices'] = False
        return tags

All estimators inheriting from BaseClass should have the sample_indices tag set to True, except for one class where I would like to overwrite the tag. In that class, _more_tags calls super(), which is the basis of the trick that I think could solve the following issue.

Because we are overwriting sample_indices, _get_tags fails due to what is currently considered an inconsistent update of the tags dictionary:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-31-5558b1aebfa8> in <module>
----> 1 _safe_tags(xx, 'sample_indices')

~/Documents/packages/scikit-learn/sklearn/utils/estimator_checks.py in _safe_tags(estimator, key)
     68     if hasattr(estimator, "_get_tags"):
     69         if key is not None:
---> 70             return estimator._get_tags().get(key, _DEFAULT_TAGS[key])
     71         tags = estimator._get_tags()
     72         return {key: tags.get(key, _DEFAULT_TAGS[key])

~/Documents/packages/scikit-learn/sklearn/base.py in _get_tags(self)
    320         if hasattr(self, '_more_tags'):
    321             more_tags = self._more_tags()
--> 322             collected_tags = _update_if_consistent(collected_tags, more_tags)
    323         tags = _DEFAULT_TAGS.copy()
    324         tags.update(collected_tags)

~/Documents/packages/scikit-learn/sklearn/base.py in _update_if_consistent(dict1, dict2)
    132         if dict1[key] != dict2[key]:
    133             raise TypeError("Inconsistent values for tag {}: {} != {}".format(
--> 134                 key, dict1[key], dict2[key]
    135             ))
    136     dict1.update(dict2)

TypeError: Inconsistent values for tag sample_indices: True != False
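
To make the failure easier to reproduce outside scikit-learn, here is a minimal standalone sketch of the collection logic visible in the traceback. It is modeled on the lines shown above, not copied from the library: the class name TagsBase is hypothetical, and the loop over the MRO is an assumption based on the context lines of the traceback and diff.

```python
import inspect

_DEFAULT_TAGS = {"sample_indices": False}


def _update_if_consistent(dict1, dict2):
    """Merge dict2 into dict1, refusing to change a value that is already set."""
    common_keys = set(dict1).intersection(dict2)
    for key in common_keys:
        if dict1[key] != dict2[key]:
            raise TypeError(
                "Inconsistent values for tag {}: {} != {}".format(
                    key, dict1[key], dict2[key]
                )
            )
    dict1.update(dict2)
    return dict1


class TagsBase:
    """Hypothetical stand-in for BaseEstimator, reduced to tag collection."""

    def _get_tags(self):
        collected_tags = {}
        # Collect the tags declared by every parent class first.
        for base_class in inspect.getmro(self.__class__):
            if base_class is not self.__class__ and hasattr(base_class, "_more_tags"):
                more_tags = base_class._more_tags(self)
                collected_tags = _update_if_consistent(collected_tags, more_tags)
        # Then merge the tags of the instance's own class.
        if hasattr(self, "_more_tags"):
            more_tags = self._more_tags()
            collected_tags = _update_if_consistent(collected_tags, more_tags)
        tags = _DEFAULT_TAGS.copy()
        tags.update(collected_tags)
        return tags


class BaseClass(TagsBase):
    def _more_tags(self):
        return {"sample_indices": True}


class SpecialClass(BaseClass):
    def _more_tags(self):
        tags = super()._more_tags()
        tags["sample_indices"] = False
        return tags


base_tags = BaseClass()._get_tags()
try:
    SpecialClass()._get_tags()
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(base_tags)
print(error_message)
```

In this sketch, BaseClass resolves to sample_indices=True, while SpecialClass raises because the parent's True has already been collected by the time its own False is merged.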

Without the call to super(), I find this rule quite meaningful, and raising the error is a good thing. However, I think we could lift this rule when super() is called in self._more_tags. Anyone calling super() should be aware that they are overwriting the base-class default tag.

Thus, by introspecting self._more_tags and checking that super() is called (e.g. using 'super' in inspect.getclosurevars(self._more_tags).builtins), we could still allow the tags to be updated.
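
As a rough illustration of that introspection, the sketch below applies the same expression to three toy classes. The class names and the helper references_super are hypothetical; only inspect.getclosurevars and its builtins field come from the standard library.

```python
import inspect


class Base:
    def _more_tags(self):
        return {"sample_indices": True}


class CallsSuper(Base):
    def _more_tags(self):
        tags = super()._more_tags()
        tags["sample_indices"] = False
        return tags


class NoSuper(Base):
    def _more_tags(self):
        return {"sample_indices": False}


def references_super(method):
    """Return True if the method's code refers to the builtin name `super`."""
    return "super" in inspect.getclosurevars(method).builtins


calls_super = references_super(CallsSuper()._more_tags)
no_super = references_super(NoSuper()._more_tags)
base_only = references_super(Base()._more_tags)

print(calls_super, no_super, base_only)
```

Note that the check only detects that the name super is referenced in the method body. It does not verify that the parent's tags are actually used in the returned dictionary.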

@amueller @rth @jnothman Do you think this use case and solution make sense? I still have the option to add a _more_tags method to each class, but that means a lot of duplicated code.

So the changes would be something like the following:

-def _update_if_consistent(dict1, dict2):
+def _update_if_consistent(dict1, dict2, force=False):
     common_keys = set(dict1.keys()).intersection(dict2.keys())
-    for key in common_keys:
-        if dict1[key] != dict2[key]:
-            raise TypeError("Inconsistent values for tag {}: {} != {}".format(
-                key, dict1[key], dict2[key]
-            ))
+    if not force:
+        for key in common_keys:
+            if dict1[key] != dict2[key]:
+                raise TypeError("Inconsistent values for tag {}: {} != {}"
+                                .format(key, dict1[key], dict2[key]))
     dict1.update(dict2)
     return dict1
 
@@ -319,7 +319,10 @@ class BaseEstimator:
                                                        more_tags)
         if hasattr(self, '_more_tags'):
             more_tags = self._more_tags()
-            collected_tags = _update_if_consistent(collected_tags, more_tags)
+            force = 'super' in inspect.getclosurevars(self._more_tags).builtins
+            collected_tags = _update_if_consistent(
+                collected_tags, more_tags, force=force
+            )
         tags = _DEFAULT_TAGS.copy()
         tags.update(collected_tags)
         return tags
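
For readers who want to try the idea without patching scikit-learn, here is a standalone sketch of the diff applied to a reduced tag collector. TagsBase and Derived are hypothetical names, and the MRO loop is an assumption reconstructed from the diff's context lines, so treat this as an approximation and not the library's actual code.

```python
import inspect

_DEFAULT_TAGS = {"sample_indices": False}


def _update_if_consistent(dict1, dict2, force=False):
    """Merge dict2 into dict1; skip the consistency check when force is True."""
    common_keys = set(dict1).intersection(dict2)
    if not force:
        for key in common_keys:
            if dict1[key] != dict2[key]:
                raise TypeError(
                    "Inconsistent values for tag {}: {} != {}".format(
                        key, dict1[key], dict2[key]
                    )
                )
    dict1.update(dict2)
    return dict1


class TagsBase:
    """Hypothetical stand-in for BaseEstimator with the proposed change."""

    def _get_tags(self):
        collected_tags = {}
        # Parent classes are still merged with the strict check.
        for base_class in inspect.getmro(self.__class__):
            if base_class is not self.__class__ and hasattr(base_class, "_more_tags"):
                collected_tags = _update_if_consistent(
                    collected_tags, base_class._more_tags(self)
                )
        if hasattr(self, "_more_tags"):
            # Only the final merge is forced, and only if super is referenced.
            force = "super" in inspect.getclosurevars(self._more_tags).builtins
            collected_tags = _update_if_consistent(
                collected_tags, self._more_tags(), force=force
            )
        tags = _DEFAULT_TAGS.copy()
        tags.update(collected_tags)
        return tags


class BaseClass(TagsBase):
    def _more_tags(self):
        return {"sample_indices": True}


class SpecialClass(BaseClass):
    def _more_tags(self):
        tags = super()._more_tags()
        tags["sample_indices"] = False
        return tags


class Derived(SpecialClass):
    """Inherits the override without redefining _more_tags."""


special_tags = SpecialClass()._get_tags()
try:
    Derived()._get_tags()
    derived_error = None
except TypeError as exc:
    derived_error = str(exc)

print(special_tags)
print(derived_error)
```

In this sketch the override works for SpecialClass itself, but a class that merely inherits from SpecialClass still fails, because SpecialClass and BaseClass are then both merged inside the strict MRO loop. Whether that matters depends on how deep the class hierarchy goes.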

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 31 (31 by maintainers)

Top GitHub Comments

glemaitre commented, Aug 12, 2019 (1 reaction)

we have mixins instead of using inheritance

You should blame @agramfort apparently 😃

https://github.com/scikit-learn/scikit-learn/commit/495fb3acad0306c66830d19c762678ad36e4a753

amueller commented, Jun 11, 2019 (1 reaction)

@rth yeah took me a bit to recover all my reasoning, but it seems to have been surprisingly sound 😉 Maybe we should add comments to the code or to the dev docs on why this is so strange and restrictive?

Is there a good reason for us not to change the order of mixins and make them be actually pythonic? Arguably what we’re doing now isn’t right…
