
Make lightgbm work with HyperbandSearchCV

See original GitHub issue

These libraries don't seem to work together. I think that supporting or claiming integration with any new ML library should include support for hyperparameter tuning; that's definitely an MVP.

Here's a code and error dump to back up my point:

import dask
import dask.dataframe as dd
from distributed import Client
from dask_ml.model_selection import HyperbandSearchCV
from dask_ml import datasets
import lightgbm as lgb

client = Client('10.118.232.173:8786')

X, y = datasets.make_classification(chunks=50)

model = lgb.DaskLGBMRegressor(client=client)


param_space = {
    'n_estimators': range(100, 200, 50),
    'max_depth': range(3, 6, 2),
    'booster': ('gbtree', 'dart'),
}

search = HyperbandSearchCV(model, param_space, random_state=0, patience=True, verbose=True, test_size=0.05)
search.fit(X, y)

And the error message:

/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=81. Running 8 iterations. For exhaustive searches, use GridSearchCV.
  warnings.warn(
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=34. Running 8 iterations. For exhaustive searches, use GridSearchCV.
  warnings.warn(
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=15. Running 8 iterations. For exhaustive searches, use GridSearchCV.
  warnings.warn(

[CV, bracket=0] For training there are between 47 and 47 examples in each chunk
[CV, bracket=1] For training there are between 47 and 47 examples in each chunk


AttributeError                            Traceback (most recent call last)
<ipython-input-6-450c96a48290> in <module>
     10
     11 search = HyperbandSearchCV(model, param_space, random_state=0, patience=True, verbose=True, test_size=0.05)
---> 12 search.fit(X, y)

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in fit(self, X, y, **fit_params)
    715         client = default_client()
    716         if not client.asynchronous:
--> 717             return client.sync(self._fit, X, y, **fit_params)
    718         return self._fit(X, y, **fit_params)
    719

/opt/conda/lib/python3.8/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    849             return future
    850         else:
--> 851             return sync(
    852                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    853             )

/opt/conda/lib/python3.8/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    352     if error[0]:
    353         typ, exc, tb = error[0]
--> 354         raise exc.with_traceback(tb)
    355     else:
    356         return result[0]

/opt/conda/lib/python3.8/site-packages/distributed/utils.py in f()
    335             if callback_timeout is not None:
    336                 future = asyncio.wait_for(future, callback_timeout)
--> 337             result[0] = yield future
    338         except Exception as exc:
    339             error[0] = sys.exc_info()

/opt/conda/lib/python3.8/site-packages/tornado/gen.py in run(self)
    760
    761                 try:
--> 762                     value = future.result()
    763                 except Exception:
    764                     exc_info = sys.exc_info()

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_hyperband.py in _fit(self, X, y, **fit_params)
    399         _brackets_ids = list(reversed(sorted(SHAs)))
    400
--> 401         _SHAs = await asyncio.gather(
    402             *[SHAs[b]._fit(X, y, **fit_params) for b in _brackets_ids]
    403         )

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _fit(self, X, y, **fit_params)
    661
    662         with context:
--> 663             results = await fit(
    664                 self.estimator,
    665                 self._get_params(),

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
    475     A history of all models scores over time
    476     """
--> 477     return await _fit(
    478         model,
    479         params,

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
    266     # async for future, result in seq:
    267     for _i in itertools.count():
--> 268         metas = await client.gather(new_scores)
    269
    270         if log_delay and _i % int(log_delay) == 0:

/opt/conda/lib/python3.8/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1846                             exc = CancelledError(key)
   1847                         else:
-> 1848                             raise exception.with_traceback(traceback)
   1849                         raise exc
   1850                     if errors == "skip":

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _partial_fit()
    101     if len(X):
    102         model = deepcopy(model)
--> 103         model.partial_fit(X, y, **(fit_params or {}))
    104
    105     meta = dict(meta)

AttributeError: 'DaskLGBMRegressor' object has no attribute 'partial_fit'
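The last frame points at the root cause: dask-ml's incremental searchers drive training through partial_fit, which lightgbm's scikit-learn and Dask estimators do not implement. A quick check, not part of the original report, makes that visible:

import lightgbm as lgb
from sklearn.linear_model import SGDRegressor

# HyperbandSearchCV trains incrementally, so the estimator must expose partial_fit.
print(hasattr(lgb.DaskLGBMRegressor(), "partial_fit"))  # False -> the AttributeError above
print(hasattr(SGDRegressor(), "partial_fit"))            # True  -> usable with Hyperband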

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 7 (5 by maintainers)

Top GitHub Comments

1 reaction
stsievert commented, May 30, 2021

When I moved from incremental hyperparameter optimization with HyperbandSearchCV to passive hyperparameter optimization with RandomizedSearchCV/GridSearchCV, your example worked for me:

from distributed import Client
from dask_ml.model_selection import RandomizedSearchCV
from dask_ml import datasets
import lightgbm as lgb

if __name__ == "__main__":
    X, y = datasets.make_classification(chunks=50)
    model = lgb.LGBMRegressor()
    param_space = {'n_estimators': range(100, 200, 50),
                   'max_depth': range(3, 6, 2)}

    client = Client()
    search = RandomizedSearchCV(model, param_space, n_iter=5)
    search.fit(X, y)
    print(search.best_score_)

{Randomized, Grid}SearchCV has the advantage of not requiring a partial_fit implementation. However, they do require that the entire training dataset fit in memory.
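By contrast, here is a minimal sketch (not from the thread; it assumes sklearn's SGDClassifier, which does implement partial_fit) of the kind of estimator HyperbandSearchCV can drive incrementally:

import numpy as np
from distributed import Client
from dask_ml.model_selection import HyperbandSearchCV
from dask_ml import datasets
from sklearn.linear_model import SGDClassifier

if __name__ == "__main__":
    client = Client()  # local cluster, just for illustration
    X, y = datasets.make_classification(chunks=50)

    # SGDClassifier exposes partial_fit, so Hyperband can feed it one chunk at a time.
    model = SGDClassifier(tol=1e-3)
    param_space = {"alpha": np.logspace(-4, 0, num=10),
                   "average": [True, False]}

    search = HyperbandSearchCV(model, param_space, random_state=0)
    search.fit(X, y, classes=[0, 1])  # fit kwargs are forwarded to partial_fit
    print(search.best_params_)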

I think that supporting or claiming integration with any new ML library should include support for hyperparameter tuning, that’s definitely an MVP.

Where have you seen that claim show up? That should be fixed, I think.

My understanding is that the hyperparameter tuning stuff in dask-ml, like GridSearchCV, expects to be given training data in Dask collections and a model object that only performs local training on local chunks of data.

That’s my understanding too, even for the mentioned HyperbandSearchCV. In that case, at the end of the day model.partial_fit is called with two NumPy arrays (or the chunks of a Dask array): model.partial_fit(X_chunk, y_chunk).
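In other words (an illustration, not code from dask-ml), each training step of the search boils down to something like this:

import numpy as np
from sklearn.linear_model import SGDClassifier

model = SGDClassifier(tol=1e-3)
X_chunk = np.random.rand(50, 20)            # one chunk of the Dask array, now a NumPy array
y_chunk = np.random.randint(0, 2, size=50)

# The first call must see the full list of classes; later calls just get more chunks.
model.partial_fit(X_chunk, y_chunk, classes=[0, 1])
model.partial_fit(X_chunk, y_chunk)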

0 reactions
stsievert commented, May 30, 2021

I presume you’re talking about https://ml.dask.org/hyper-parameter-search.html. Why did you have to read that documentation several times?

