How to pass callbacks to scikit_learn wrappers (e.g. KerasClassifier)
I want to use the EarlyStopping and TensorBoard callbacks with the KerasClassifier scikit_learn wrapper. Normally, when not using scikit_learn wrappers, I pass the callbacks to the fit function as outlined in the documentation. However, when using scikit_learn wrappers, this function is a method of KerasClassifier. The documentation mentions that sk_params can contain arguments to the fit method (among others), but I am unable to figure out how to use sk_params to pass callbacks to the fit function inside the KerasClassifier class.

My code looks like this (excluding the code that loads data into x and encoded_y, for brevity):
```python
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.wrappers.scikit_learn import KerasClassifier
from keras.layers import BatchNormalization
from keras.callbacks import EarlyStopping, TensorBoard
from sklearn.model_selection import cross_val_score, StratifiedKFold

# model architecture
def DNN():
    model = Sequential()
    model.add(Dense(512, input_dim=x.shape[1], init='normal', activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(32, init='normal', activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(1, init='normal', activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adagrad', metrics=['accuracy'])
    return model

# fix random seed for reproducibility
seed = 8
classifier = KerasClassifier(build_fn=DNN, nb_epoch=32, batch_size=8, verbose=1)
kfold = StratifiedKFold(n_splits=2, shuffle=True, random_state=seed)
results = cross_val_score(classifier, x, encoded_y, cv=kfold, verbose=1)
print("Result: %.2f%% (%.2f%%)" % (results.mean()*100, results.std()*100))
```
Does anyone know how I should use EarlyStopping and TensorBoard with this setup?
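Roughly speaking, the wrapper stores every extra keyword argument in sk_params and, at fit time, forwards each one to whichever function accepts a parameter of that name: build_fn or the model's fit. The stdlib-only sketch below imitates that routing so it runs without Keras; ToyWrapper, toy_build_fn, toy_fit and filter_params are hypothetical stand-ins, not Keras names, and a string stands in for a callback object.

```python
import inspect


def filter_params(fn, params):
    """Keep only the entries of `params` that `fn` accepts as keyword arguments."""
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in params.items() if name in accepted}


def toy_build_fn(hidden_units=32):
    # Hypothetical stand-in for build_fn: returns a description, not a Keras model.
    return {"hidden_units": hidden_units}


def toy_fit(x, y, epochs=1, batch_size=32, callbacks=None):
    # Hypothetical stand-in for Model.fit: reports which arguments it received.
    return {"epochs": epochs, "batch_size": batch_size, "callbacks": callbacks}


class ToyWrapper:
    def __init__(self, build_fn, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def fit(self, x, y):
        # Each sk_param goes to the function whose signature names it.
        self.model = self.build_fn(**filter_params(self.build_fn, self.sk_params))
        return toy_fit(x, y, **filter_params(toy_fit, self.sk_params))


wrapper = ToyWrapper(toy_build_fn, hidden_units=64, epochs=32, callbacks=["early_stopping"])
received = wrapper.fit(x=[[0.0]], y=[0])
print(received)  # {'epochs': 32, 'batch_size': 32, 'callbacks': ['early_stopping']}
```

Under this scheme a callbacks entry in sk_params does reach fit; the trouble reported further down in the thread comes from scikit-learn cloning the wrapper, not from the routing itself.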
Issue analytics: created 7 years ago; 11 reactions; 26 comments (1 by maintainers).
I'm simply using the first method listed in the documentation: GridSearchCV's fit_params argument passes a dictionary of keyword arguments to the fit method of the base estimator, KerasClassifier in this case.
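The fit_params route can be sketched the same way. This assumes, as the Keras wrapper's fit appears to do, that keyword arguments given to fit are merged over the constructor's sk_params. toy_search is a hypothetical stand-in for GridSearchCV: it rebuilds a fresh estimator from get_params() for each candidate and forwards the fit keyword arguments untouched, so the callbacks never go through cloning. (In later scikit-learn releases such arguments are passed as keyword arguments to GridSearchCV.fit rather than through a fit_params constructor argument.)

```python
import inspect


def toy_fit(x, y, epochs=1, callbacks=None):
    # Hypothetical stand-in for Model.fit: reports the arguments it received.
    return {"epochs": epochs, "callbacks": callbacks}


class ToyWrapper:
    def __init__(self, **sk_params):
        self.sk_params = sk_params

    def get_params(self):
        return dict(self.sk_params)

    def fit(self, x, y, **kwargs):
        # Start from the sk_params that fit accepts, then let fit-time kwargs win.
        accepted = inspect.signature(toy_fit).parameters
        fit_args = {k: v for k, v in self.sk_params.items() if k in accepted}
        fit_args.update(kwargs)
        self.received_ = toy_fit(x, y, **fit_args)
        return self


def toy_search(estimator, param_grid, x, y, **fit_params):
    """Fit a fresh estimator per candidate, forwarding fit_params unchanged."""
    results = []
    for candidate in param_grid:
        params = estimator.get_params()
        params.update(candidate)
        fresh = type(estimator)(**params)  # callbacks are not constructor params here
        results.append(fresh.fit(x, y, **fit_params).received_)
    return results


early_stopping = object()  # stand-in for an EarlyStopping instance
results = toy_search(ToyWrapper(epochs=10), [{"epochs": 5}, {"epochs": 20}],
                     x=[[0.0]], y=[0], callbacks=[early_stopping])
```

In this sketch every fit receives the very same callback instance, so any state a callback keeps between runs is shared across candidates.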
Upon further investigation, it looks like when the callback is passed through sk_params and the estimator is then cloned by GridSearchCV, two different instances of the callback are created. The sanity check in clone therefore fails when it compares the original callback instance with the cloned one (line 117 in sklearn\base.py).
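The failure can be reproduced in miniature. toy_clone is a simplified, hypothetical version of the sanity check in scikit-learn's clone(): it copies every constructor parameter, builds a new object, and requires get_params() on the new object to return exactly the objects it was given. A wrapper whose get_params() deep-copies sk_params fails that check for the callbacks list, while one that keeps callbacks as its own attribute and returns it as-is (the idea behind the patch described below) passes. DeepCopyingWrapper, CallbackAwareWrapper and DummyCallback are illustrative names.

```python
import copy


def toy_clone(estimator):
    """Simplified imitation of the parameter check in scikit-learn's clone()."""
    # Copy each constructor parameter (non-estimator values are deep-copied).
    new_params = {k: copy.deepcopy(v) for k, v in estimator.get_params().items()}
    new_object = type(estimator)(**new_params)
    # The new object must hand back exactly the objects it was constructed with.
    returned = new_object.get_params()
    for name, value in new_params.items():
        if returned[name] is not value:
            raise RuntimeError(f"Cannot clone: parameter {name!r} was modified")
    return new_object


class DeepCopyingWrapper:
    """get_params() deep-copies sk_params, so the callbacks list is a new object."""

    def __init__(self, **sk_params):
        self.sk_params = sk_params

    def get_params(self):
        return copy.deepcopy(self.sk_params)


class CallbackAwareWrapper:
    """Callbacks are stored separately and returned without copying."""

    def __init__(self, callbacks=None, **sk_params):
        self.callbacks = callbacks
        self.sk_params = sk_params

    def get_params(self):
        params = copy.deepcopy(self.sk_params)
        params["callbacks"] = self.callbacks
        return params


class DummyCallback:
    pass


original = DummyCallback()
try:
    toy_clone(DeepCopyingWrapper(epochs=3, callbacks=[original]))
    outcome = "cloned"
except RuntimeError as err:
    outcome = str(err)

patched = toy_clone(CallbackAwareWrapper(epochs=3, callbacks=[original]))
print(outcome)  # Cannot clone: parameter 'callbacks' was modified
```

The patched clone also ends up with its own deep-copied callback, which is why each fit then sees a different callback instance.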
In an attempt to allow callbacks to be passed through sk_params, I added callbacks to the constructor of BaseWrapper like so:

```python
def __init__(self, build_fn=None, callbacks=[], **sk_params):
    self.build_fn = build_fn
    self.callbacks = callbacks
    self.sk_params = sk_params
    self.sk_params['callbacks'] = callbacks
    self.check_params(sk_params)
```

Then I added self.callbacks to the get_params method of the same class:

```python
def get_params(self, deep=True):
    res = copy.deepcopy(self.sk_params)
    # Return the stored callbacks object itself (not the deep copy) so clone()'s check passes
    res.update({'build_fn': self.build_fn, 'callbacks': self.callbacks})
    return res
```
After testing my model both with and without callbacks, no more errors arise and my callbacks work as intended. I don't think this deserves a PR, because it is already possible to pass in callbacks, and the unexpected behavior of ModelCheckpoint comes from that callback's own implementation.
@Hudler, I believe this change allows a different instance of ModelCheckpoint to be passed to each fit call, but since every instance would use the same file name for saving checkpoints, it may end up overwriting the file at the beginning of the following fit call.
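One way around the overwriting, assuming you are willing to replace cross_val_score with an explicit loop over the folds, is to create a new callback inside the loop with a fold-specific file name. In a real script the folds would come from StratifiedKFold.split and fit_fn would call classifier.fit(..., callbacks=[...]); here RecordingCheckpoint, run_folds and fake_fit are hypothetical stand-ins so the pattern runs on its own.

```python
class RecordingCheckpoint:
    """Hypothetical stand-in for ModelCheckpoint: only remembers its file path."""

    def __init__(self, filepath):
        self.filepath = filepath


def run_folds(fit_fn, folds, path_template):
    """Fit once per fold, giving each fit a fresh callback with its own file path."""
    used_paths = []
    for fold, (train_idx, test_idx) in enumerate(folds):
        checkpoint = RecordingCheckpoint(path_template.format(fold=fold))
        fit_fn(train_idx, test_idx, callbacks=[checkpoint])
        used_paths.append(checkpoint.filepath)
    return used_paths


seen = []


def fake_fit(train_idx, test_idx, callbacks):
    # A real version would fit on train_idx and evaluate on test_idx.
    seen.append(callbacks[0])


folds = [([0, 1], [2, 3]), ([2, 3], [0, 1])]
paths = run_folds(fake_fit, folds, "weights_fold{fold}.h5")
print(paths)  # ['weights_fold0.h5', 'weights_fold1.h5']
```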
I have a similar issue (with the same cause). I have been trying to plot the learning curves of a CNN (learning_curve never accepted the fit_params argument in the first place), with a number of epochs that depends on the number of training examples. This matters because training for the same number of epochs on 1/5 of the data will overfit badly.
I am trying to use EarlyStopping as a callback, but I get the following error:
What confuses me is that the wrapper's documentation says these parameters are passed to the .fit method, so I can't see why it's not working.
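For the learning-curve case, early stopping lets the number of epochs adapt to the size of the training set instead of being fixed in advance. The sketch below shows the basic patience rule that Keras's EarlyStopping applies when monitoring a loss: stop once it has not improved by more than min_delta for patience consecutive epochs. It is a stdlib imitation, not the Keras implementation; epochs_until_stop is a hypothetical name and the two loss curves are made up for illustration.

```python
def epochs_until_stop(val_losses, patience, min_delta=0.0):
    """Return how many epochs run before a patience-based early stop.

    Training stops once the loss has failed to improve on the best value
    seen so far by more than `min_delta` for `patience` consecutive epochs.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss  # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses)  # never triggered: all epochs run


small_data = [0.9, 0.7, 0.72, 0.75, 0.8, 0.85]  # starts overfitting early
large_data = [0.9, 0.7, 0.6, 0.55, 0.52, 0.5]   # keeps improving
print(epochs_until_stop(small_data, patience=2), epochs_until_stop(large_data, patience=2))  # 4 6
```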
@anselal I am not sure I understood your solution correctly, but how can you pass arguments directly to the fit method without using fit_params or passing them to the KerasClassifier? Thanks!