How to save Scikit-Learn-Keras Model into a Persistence File (pickle/hd5/json/yaml)

I have the following code, using Keras Scikit-Learn Wrapper:

from keras.models import Sequential
from sklearn import datasets
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn import preprocessing
import pickle
import numpy as np
import json

def classifier(X, y):
    """
    Description of classifier
    """
    NOF_ROW, NOF_COL =  X.shape

    def create_model():
        # create model
        model = Sequential()
        model.add(Dense(12, input_dim=NOF_COL, init='uniform', activation='relu'))
        model.add(Dense(6, init='uniform', activation='relu'))
        model.add(Dense(1, init='uniform', activation='sigmoid'))
        # Compile model
        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        return model

    # evaluate using 10-fold cross validation
    seed = 7
    np.random.seed(seed)
    model = KerasClassifier(build_fn=create_model, nb_epoch=150, batch_size=10, verbose=0)
    return model
    

def main():
    """
    Description of main
    """

    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    X = preprocessing.scale(X)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
    model_tt = classifier(X_train, y_train)
    model_tt.fit(X_train,y_train)

    #--------------------------------------------------
    # This fails
    #-------------------------------------------------- 
    filename = 'finalized_model.sav'
    pickle.dump(model_tt, open(filename, 'wb'))
    # load the model from disk
    loaded_model = pickle.load(open(filename, 'rb'))
    result = loaded_model.score(X_test, y_test)
    print(result)
    
    #--------------------------------------------------
    # This also fails
    #--------------------------------------------------
    # from keras.models import load_model       
    # model_tt.save('test_model.h5')
    

    #--------------------------------------------------
    # This works OK 
    #-------------------------------------------------- 
    # print model_tt.score(X_test, y_test)
    # print model_tt.predict_proba(X_test)
    # print model_tt.predict(X_test)


if __name__ == '__main__':
    main()

As noted in the code, it fails at this line:

pickle.dump(model_tt, open(filename, 'wb'))

With this error:

pickle.PicklingError: Can't pickle <function create_model at 0x101c09320>: it's not found as __main__.create_model

How can I get around it?
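For context, the error can be reproduced without Keras. pickle stores a function by reference to its importable name, and a function defined inside another function has no such name. The sketch below is a minimal illustration using only the standard library; the function bodies are placeholders, not the code from the question, and the exact exception type and message depend on the Python version.

```python
import pickle


def classifier():
    # Nested function, analogous to create_model defined inside classifier().
    def create_model():
        return "model"
    return create_model


nested = classifier()

try:
    pickle.dumps(nested)
    nested_pickled = True
except (pickle.PicklingError, AttributeError) as exc:
    # The exception type and message vary between Python versions.
    nested_pickled = False
    print("pickling failed:", type(exc).__name__)

# The qualified name shows why the lookup fails: it contains "<locals>".
print(nested.__qualname__)
```

Because the wrapper keeps a reference to build_fn, pickling the wrapper runs into the same limitation.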

Issue Analytics

  • State: closed
  • Created: 7 years ago
  • Reactions: 3
  • Comments: 17

Top GitHub Comments

22 reactions
crbrinton commented, Apr 13, 2017

When trying to persist a KerasClassifier (or KerasRegressor) object, the KerasClassifier itself does not have a save method. It is the keras model that is wrapped by the KerasClassifier that can be saved using the save method. However, if you want to end up with a KerasClassifier after re-loading the persisted model, the re-loaded model must be wrapped anew in the KerasClassifier. This can be done by creating a new KerasClassifier object with a build_fn that actually calls the load_model method, such as:

def build_by_loading(self):
    model = load_model('nn_model.h5')
    return model 

So the KerasClassifier to be re-instantiated from a persisted file would be created as follows (for example):

    nn_model = KerasClassifier(build_fn=self.build_by_loading, nb_epoch=10, batch_size=5, verbose=1)

Unfortunately, the KerasClassifier code does not call the build_fn until the ‘fit’ method of the KerasClassifier is called. This would defeat the purpose of persisting the model.

I created a ‘build_only’ method in KerasClassifier that only calls the build_fn, but does not fit the model. This worked for me. I recommend that some means of instantiating a KerasClassifier from a persisted keras model similar to this be included in the next release.
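The idea of using a persisted model without calling fit again can be sketched as a thin wrapper around an already-loaded model. This is not the 'build_only' method described above, whose code is not shown here. PreloadedBinaryClassifier and StubModel are hypothetical names invented for illustration; in practice the wrapped object would presumably be the result of load_model('nn_model.h5'), and the sketch assumes a single sigmoid output per row.

```python
class PreloadedBinaryClassifier:
    """Hypothetical wrapper exposing predict/score around a trained model."""

    def __init__(self, model, threshold=0.5):
        # model: any object whose predict(X) returns rows of [probability].
        self.model = model
        self.threshold = threshold

    def predict(self, X):
        # Turn each single sigmoid probability into a 0/1 class label.
        return [int(row[0] > self.threshold) for row in self.model.predict(X)]

    def score(self, X, y):
        # Mean accuracy over the given samples.
        predictions = self.predict(X)
        correct = sum(int(p == t) for p, t in zip(predictions, y))
        return correct / len(y)


class StubModel:
    """Stand-in for a loaded Keras model (assumption, for demonstration only)."""

    def predict(self, X):
        return [[1.0 if sum(row) > 0 else 0.0] for row in X]


clf = PreloadedBinaryClassifier(StubModel())
print(clf.predict([[1, 2], [-3, 1]]))
print(clf.score([[1, 2], [-3, 1]], [1, 0]))
```

The wrapper never calls fit, so the persisted weights are used as they are. It does not implement the full scikit-learn estimator interface, so it would not work with utilities such as cross_val_score.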

15 reactions
krishnateja614 commented, Nov 4, 2016

Can you try running model_tt.model.save("test_model.h5")? I don't think we can call save directly on the scikit-learn wrapper, but the line above should hopefully do what you want. Let me know.
