
How to pass callbacks to scikit_learn wrappers (e.g. KerasClassifier)

See original GitHub issue

I want to use the EarlyStopping and TensorBoard callbacks with the KerasClassifier scikit_learn wrapper. Normally, when not using the scikit_learn wrappers, I pass the callbacks to the fit function as outlined in the documentation. When using the wrappers, however, fit is a method of KerasClassifier. The documentation mentions that sk_params can contain arguments to the fit method (among others), but I am unable to figure out how to use sk_params to pass callbacks to the fit function inside the KerasClassifier class.

My code looks like this (excluding code for loading data into x and encoded_y for brevity):

from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.wrappers.scikit_learn import KerasClassifier
from keras.layers import BatchNormalization
from keras.callbacks import EarlyStopping, TensorBoard

from sklearn.model_selection import cross_val_score, StratifiedKFold

# model architecture
def DNN():
    model = Sequential()
    model.add(Dense(512, input_dim=x.shape[1], init='normal', activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(32, init='normal', activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(1, init='normal', activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adagrad', metrics=['accuracy'])
    return model

# fix random seed for reproducibility
seed = 8

classifier = KerasClassifier(build_fn=DNN, nb_epoch=32, batch_size=8, verbose=1)
kfold = StratifiedKFold(n_splits=2, shuffle=True, random_state=seed)
results = cross_val_score(classifier, x, encoded_y, cv=kfold, verbose=1)
print("Result: %.2f%% (%.2f%%)" % (results.mean()*100, results.std()*100))

Does anyone know how I should use EarlyStopping and TensorBoard with this setup?

Issue Analytics

  • State: closed
  • Created: 7 years ago
  • Reactions: 11
  • Comments: 26 (1 by maintainers)

Top GitHub Comments

6 reactions
jfr311 commented on Nov 7, 2016

I’m simply using the first listed method in the documentation:

Values passed to the dictionary arguments of fit, predict, predict_proba, and score methods

GridSearchCV’s fit_params argument is used to pass a dictionary to the fit method of the base estimator, the KerasClassifier in this case.
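The routing described above can be sketched without Keras at all. The class below is a hypothetical stand-in, not the real KerasClassifier: kwargs captured in sk_params at construction time are filtered down to the names the underlying fit accepts and forwarded on every fit call, with anything arriving via fit_params taking precedence.

```python
class ToyWrapper:
    """Illustrative stand-in for KerasClassifier's sk_params plumbing."""

    # Kwargs the underlying model's fit() would accept (illustrative subset).
    FIT_ARGS = ('batch_size', 'epochs', 'callbacks', 'verbose')

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def fit(self, x, y, **fit_params):
        # Forward whatever sk_params holds that fit() understands;
        # fit_params (e.g. from GridSearchCV's fit_params) wins on conflict.
        kwargs = {k: v for k, v in self.sk_params.items() if k in self.FIT_ARGS}
        kwargs.update(fit_params)
        return kwargs  # a real wrapper would call self.model.fit(x, y, **kwargs)


wrapper = ToyWrapper(epochs=32, batch_size=8)
print(wrapper.fit(None, None, callbacks=['early_stopping']))
# → {'epochs': 32, 'batch_size': 8, 'callbacks': ['early_stopping']}
```

This is why callbacks given either at construction or through fit_params reach the underlying fit call; the trouble described below only appears once sklearn tries to clone the estimator.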

Upon further investigation, it looks like when the callback is passed in sk_params and the estimator is then cloned by GridSearchCV, two different instances of the callback are created. Thus the sanity check in clone fails when it compares the original callback instance with the cloned one (line 117 in sklearn\base.py).
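The failure mode can be reproduced with toy stand-ins (neither class below is the real implementation): the wrapper's get_params deep-copied sk_params, so every call handed back fresh callback instances, while sklearn's clone rebuilds the estimator from get_params() and then verifies each parameter with an identity check.

```python
import copy


class ToyCallback:
    """Hypothetical stand-in for a Keras callback."""
    pass


class ToyWrapper:
    """Mimics the old KerasClassifier: params live only in sk_params."""
    def __init__(self, **sk_params):
        self.sk_params = sk_params

    def get_params(self, deep=True):
        # The real wrapper deep-copied sk_params here, so each call
        # returns *new* callback instances.
        return copy.deepcopy(self.sk_params)


cb = ToyCallback()
est = ToyWrapper(callbacks=[cb])

# sklearn's clone rebuilds the estimator from get_params() and then
# checks each parameter for identity ("param1 is param2").
params_before = est.get_params()
rebuilt = ToyWrapper(**params_before)
params_after = rebuilt.get_params()

print(params_before['callbacks'][0] is params_after['callbacks'][0])  # → False
```

The deep copy breaks the identity check, which is the "Cannot clone object" error reported further down.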

In an attempt to allow callbacks to be passed via sk_params, I added callbacks to the constructor of BaseWrapper like so:

def __init__(self, build_fn=None, callbacks=[], **sk_params):
    self.build_fn = build_fn
    self.callbacks = callbacks
    self.sk_params = sk_params
    self.sk_params['callbacks'] = callbacks
    self.check_params(sk_params)

Then I added self.callbacks to the get_params method of the same class:

def get_params(self, deep=True):
    res = copy.deepcopy(self.sk_params)
    res.update({'build_fn': self.build_fn, 'callbacks': self.callbacks})
    return res

After testing with my model using callbacks and not using callbacks, no more errors arise and my callbacks function as intended. I don’t think this deserves a PR simply because it is already possible to pass in callbacks and the unexpected behavior of ModelCheckpoint is due to that callback’s implementation.

@Hudler, I believe this change allows different instances of ModelCheckpoint to be passed to each fit call, but since each instance would have the same file name for saving the checkpoints, it may end up overwriting the file at the beginning of the following fit call.
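Put together, the fix amounts to making callbacks a first-class constructor argument that get_params returns by reference rather than by copy, so sklearn's identity check passes. A toy version of the fixed wrapper (again with hypothetical stand-in classes, not the real code):

```python
import copy


class ToyCallback:
    """Hypothetical stand-in for a Keras callback."""
    pass


class FixedWrapper:
    """Sketch of the fix above: callbacks is a real constructor argument."""
    def __init__(self, build_fn=None, callbacks=None, **sk_params):
        self.build_fn = build_fn
        self.callbacks = callbacks if callbacks is not None else []
        self.sk_params = sk_params
        self.sk_params['callbacks'] = self.callbacks

    def get_params(self, deep=True):
        res = copy.deepcopy(self.sk_params)
        # Return the *same* objects for build_fn and callbacks, not copies,
        # so the identity check in clone holds after reconstruction.
        res.update({'build_fn': self.build_fn, 'callbacks': self.callbacks})
        return res


cb = ToyCallback()
est = FixedWrapper(callbacks=[cb])
rebuilt = FixedWrapper(**est.get_params())
print(rebuilt.get_params()['callbacks'][0] is cb)  # → True
```

The callbacks=None default avoids sharing one mutable list across instances, which the original callbacks=[] default would do.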

5 reactions
littlewine commented on Jul 5, 2018

I have roughly the same issue (caused by the same reasons). I have been trying to plot learning curves for a CNN (learning_curve never accepted the fit_params argument in the first place), with a varying number of epochs depending on the number of training examples. This matters, since training for the same number of epochs on 1/5 of the data will greatly overfit.

I am trying to use EarlyStopping as a callback, but I get the following error:

model = KerasClassifier(build_fn=create_model, batch_size=batch_size, epochs=epochs, verbose=1, callbacks=[EarlyStopping(patience=0)])
learning_curve(model, ...)
>>> Cannot clone object <keras.wrappers.scikit_learn.KerasClassifier object at 0x7f129fae8390>, as the constructor does not seem to set parameter callbacks

What confuses me is that from the documentation of the wrapper, I can see that the params should be passed into the .fit method, so I can’t see why it’s not working.

@anselal I am not sure I understood your solution correctly, but how can you pass arguments directly to the fit method without using fit_params or passing them to the KerasClassifier?

Thanks!


