Cannot pass keyword parameters to build function through wrapper
Hi @adriangb,
Thanks for all your work on this so far. I was frustrated to find that there seems to be no active support for the Scikit-Learn wrappers and, without any convenient and obvious alternatives for fine-tuning hyperparameters, I’m relieved you’re keeping things going.
I’m having issues passing any parameters through the wrapper to the model-building function. Since this is pretty much what SciKeras is for, I’m almost convinced I’m missing something, but for the life of me I can’t find it.
Originally I tried to run a randomized search of hyperparameters before realising it wasn’t explicitly supported by SciKeras, at least in the documentation. I tried again with a grid search but, after changing how the learning rate was searched for, ended up with the same errors. This happens regardless of what parameters I try to pass through.
import numpy as np
from tensorflow import keras
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from scipy.stats import reciprocal
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
import scikeras.wrappers
# - California housing data, for illustration -
housing = fetch_california_housing()
X_train_full, X_test, y_train_full, y_test = train_test_split(
    housing.data, housing.target)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train_full, y_train_full)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_valid = scaler.transform(X_valid)  # transform only: reuse the fit from the training set
def build_model(n_hidden=1, n_neurons=30, learning_rate=3e-3, input_shape=[8]):
    model = keras.models.Sequential()
    model.add(keras.layers.InputLayer(input_shape=input_shape))
    for layer in range(n_hidden):
        model.add(keras.layers.Dense(n_neurons, activation="relu"))
    model.add(keras.layers.Dense(1))
    optimizer = keras.optimizers.SGD(lr=learning_rate)
    model.compile(loss="mse", optimizer=optimizer)
    return model
keras_reg = scikeras.wrappers.KerasRegressor(build_model())
param_distribs = {
    "n_hidden": [0, 1, 2, 3],
    "n_neurons": np.arange(1, 100, 10),
    # "learning_rate": reciprocal(3e-4, 3e-2),
    "learning_rate": np.arange(3e-4, 3e-2, 0.3 * (1e-2 - 1e-4)),
}
# - Previously attempted randomized search -
# rnd_search_cv = RandomizedSearchCV(keras_reg, param_distribs, n_iter=10, cv=3, verbose=2)
# rnd_search_cv.fit(X_train, y_train, epochs=100,
#                   validation_data=(X_valid, y_valid),
#                   callbacks=[keras.callbacks.EarlyStopping(patience=10)])
grid_search_cv = GridSearchCV(keras_reg, param_distribs, cv=3, verbose=2)
grid_search_cv.fit(X_train, y_train, epochs=100,
                   validation_data=(X_valid, y_valid),
                   callbacks=[keras.callbacks.EarlyStopping(patience=10)])
Here’s my full traceback:
  File "C:/Users/Michael/PycharmProjects/HOML/scikeras_error.py", line 52, in <module>
    grid_search_cv.fit(X_train, y_train, epochs=100,
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\sklearn\utils\validation.py", line 72, in inner_f
    return f(**kwargs)
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\sklearn\model_selection\_search.py", line 736, in fit
    self._run_search(evaluate_candidates)
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\sklearn\model_selection\_search.py", line 1188, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\sklearn\model_selection\_search.py", line 708, in evaluate_candidates
    out = parallel(delayed(_fit_and_score)(clone(base_estimator),
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\joblib\parallel.py", line 1029, in __call__
    if self.dispatch_one_batch(iterator):
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\joblib\parallel.py", line 847, in dispatch_one_batch
    self._dispatch(tasks)
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\joblib\parallel.py", line 765, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\joblib\_parallel_backends.py", line 208, in apply_async
    result = ImmediateResult(func)
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\joblib\_parallel_backends.py", line 572, in __init__
    self.results = batch()
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\joblib\parallel.py", line 252, in __call__
    return [func(*args, **kwargs)
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\joblib\parallel.py", line 252, in <listcomp>
    return [func(*args, **kwargs)
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\sklearn\model_selection\_validation.py", line 520, in _fit_and_score
    estimator = estimator.set_params(**cloned_parameters)
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\scikeras\wrappers.py", line 819, in set_params
    return super().set_params(**passthrough)
  File "C:\Users\Michael\PycharmProjects\HOML\venv\lib\site-packages\sklearn\base.py", line 249, in set_params
    raise ValueError('Invalid parameter %s for estimator %s. '
ValueError: Invalid parameter learning_rate for estimator KerasRegressor(
model=<tensorflow.python.keras.engine.sequential.Sequential object at 0x0000023D9E8AB7F0>
build_fn=None
warm_start=False
random_state=None
optimizer=rmsprop
loss=None
metrics=None
batch_size=None
verbose=1
callbacks=None
validation_split=0.0
shuffle=True
run_eagerly=False
epochs=1
). Check the list of available parameters with `estimator.get_params().keys()`.
Process finished with exit code 1
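For readers landing here, a dependency-free sketch of why scikit-learn raises this error may help. Grid search sets candidate values via `set_params()`, which only accepts keys that already appear in `get_params()`; a wrapper handed an *already-built* model exposes no build-function arguments to tune, so `learning_rate` is rejected. The classes below are hypothetical stand-ins written just to mimic that contract, not SciKeras internals:

```python
# Sketch of scikit-learn's set_params()/get_params() contract.
# WrapperOverInstance mimics KerasRegressor(build_model()): the model is
# already constructed, so no build kwargs are tunable. WrapperOverFactory
# mimics passing the callable, whose kwargs become tunable parameters.

class WrapperOverInstance:
    """Wraps a finished model: get_params() exposes nothing to route."""
    def get_params(self):
        return {"model": object()}

    def set_params(self, **params):
        valid = self.get_params()
        for key in params:
            if key not in valid:
                # Same failure mode as the traceback above.
                raise ValueError(f"Invalid parameter {key}")
        return self


class WrapperOverFactory:
    """Wraps a model *factory*: its kwargs show up in get_params()."""
    def __init__(self, factory, **build_kwargs):
        self.factory = factory
        self.build_kwargs = build_kwargs

    def get_params(self):
        return {"model": self.factory, **self.build_kwargs}

    def set_params(self, **params):
        self.build_kwargs.update(params)  # accepted: keys are known
        return self


def build_model(n_hidden=1, learning_rate=3e-3):
    return ("model", n_hidden, learning_rate)  # stand-in for a Keras model


bad = WrapperOverInstance()
try:
    bad.set_params(learning_rate=1e-2)
except ValueError as exc:
    print(exc)  # Invalid parameter learning_rate

good = WrapperOverFactory(build_model, n_hidden=1, learning_rate=3e-3)
good.set_params(learning_rate=1e-2)
print(good.get_params()["learning_rate"])
```

This is why the error message points at `estimator.get_params().keys()`: whatever the search varies must already be listed there.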
As I said, I’m almost certain this is user error, but any light you could shed on the issue would be really appreciated.
Thanks!
Edit: I’m aware of issue #70 but am unsure exactly how related my issue is. Thanks again!
Issue Analytics
- Created 3 years ago
- Comments: 6 (3 by maintainers)
@adriangb Thanks for the implementation notes! I tried them out and they worked a charm. They’ll definitely come in useful.
I already merged #145 which implements a friendlier error. Hopefully that helps the next person who runs into this!
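For future readers, here is a hedged sketch of what the fix looks like, assuming current SciKeras parameter routing (where `model__`-prefixed names are forwarded to the build function): pass `build_model` itself, not `build_model()`, so the wrapper can rebuild the model for each candidate. The wrapper and search calls are shown as comments since they need a live TensorFlow install; the grid itself is plain Python:

```python
# Hedged sketch, not run against TensorFlow here:
#
#   keras_reg = scikeras.wrappers.KerasRegressor(model=build_model, loss="mse")
#   grid_search_cv = GridSearchCV(keras_reg, param_distribs, cv=3, verbose=2)
#
# Build-function arguments are addressed with the "model__" prefix so
# scikit-learn can route them through set_params() to build_model().
param_distribs = {
    "model__n_hidden": [0, 1, 2, 3],
    "model__n_neurons": list(range(1, 100, 10)),
    "model__learning_rate": [3e-4, 3e-3, 3e-2],
}

# Every searched key targets the build function via the routing prefix.
assert all(key.startswith("model__") for key in param_distribs)
```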
As promised, here are a couple of other notes on the implementation. Namely, I wanted to highlight: