Can't pickle trained model with callback to TensorBoard
Description of the problem
I was excited about SciKeras because it can interface with sklearn and its models can supposedly be pickled. Unfortunately, scikeras.KerasClassifier can't be pickled when both of the following conditions are met:
- the KerasClassifier includes a callback to TensorBoard, and
- it has been trained.
The equivalent neural network from Keras can be pickled without issue.
Minimum, Complete, Verifiable Example
from joblib import dump
# from pickle import dump # causes the same problem
from numpy import random
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from scikeras.wrappers import KerasClassifier
# %% shared data
X = random.random((10, 6))
y = random.randint(2, size=10)
def build_fn():
    """Build sequential neural network."""
    model = Sequential()
    model.add(Dense(30, activation="relu", input_shape=(6, )))
    model.add(Dense(20, activation="relu"))
    model.add(Dense(1, activation="sigmoid"))
    model.compile(
        optimizer="rmsprop",
        loss="binary_crossentropy",
    )
    return model
# %% scikeras classifier [breaks]
clf = KerasClassifier(
    model=build_fn,
    epochs=5,
    validation_split=0.1,
    callbacks=[TensorBoard("testlogs")],  # won't break without this line
)
clf = clf.fit(X, y) # won't break without this line
dump(clf, open("test_scikeras.pkl", "wb")) # raises InvalidArgumentError
# %% same classifier in pure tf.keras [works]
model = build_fn()
model.fit(
    X,
    y,
    epochs=5,
    validation_split=0.1,
    callbacks=[TensorBoard("testlogs")],
)
dump(model, open("test_keras.pkl", "wb")) # works
Stack Trace
The last line of the "# %% scikeras classifier [breaks]" block raises:
Traceback (most recent call last):
File "/home/lukas/Desktop/tensorboard_temp.py", line 52, in <module>
dump(clf, open("test_scikeras.pkl", "wb")) # raises InvalidArgumentError
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 482, in dump
NumpyPickler(filename, protocol=protocol).dump(value)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 487, in dump
self.save(obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
return Pickler.save(self, obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 603, in save
self.save_reduce(obj=obj, *rv)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 717, in save_reduce
save(state)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
return Pickler.save(self, obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
f(self, obj) # Call unbound method with explicit self
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 971, in save_dict
self._batch_setitems(obj.items())
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 997, in _batch_setitems
save(v)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
return Pickler.save(self, obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
f(self, obj) # Call unbound method with explicit self
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 931, in save_list
self._batch_appends(obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 958, in _batch_appends
save(tmp[0])
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
return Pickler.save(self, obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 603, in save
self.save_reduce(obj=obj, *rv)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 717, in save_reduce
save(state)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
return Pickler.save(self, obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
f(self, obj) # Call unbound method with explicit self
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 971, in save_dict
self._batch_setitems(obj.items())
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 997, in _batch_setitems
save(v)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
return Pickler.save(self, obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
f(self, obj) # Call unbound method with explicit self
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 971, in save_dict
self._batch_setitems(obj.items())
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 997, in _batch_setitems
save(v)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
return Pickler.save(self, obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 603, in save
self.save_reduce(obj=obj, *rv)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 717, in save_reduce
save(state)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
return Pickler.save(self, obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
f(self, obj) # Call unbound method with explicit self
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 971, in save_dict
self._batch_setitems(obj.items())
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 997, in _batch_setitems
save(v)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
return Pickler.save(self, obj)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 578, in save
rv = reduce(self.proto)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1000, in __reduce__
return convert_to_tensor, (self._numpy(),)
File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1039, in _numpy
six.raise_from(core._status_to_exception(e.code, e.message), None) # pylint: disable=protected-access
File "<string>", line 3, in raise_from
InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.
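Reading the trace from the bottom up, pickle walks the estimator's state dict, then a list (presumably the stored callbacks), then a callback's state and some nested containers, until it reaches a tf.Tensor whose __reduce__ fails. The pure-Python sketch that follows reproduces that shape under stated assumptions: ToyEstimator, ToyTensorBoard and ResourceHandle are hypothetical stand-ins, not SciKeras or TensorFlow classes, and the sketch assumes the real TensorBoard callback only acquires its unpicklable resource once training starts, which would explain why both conditions are needed.

```python
import pickle


class ResourceHandle:
    """Hypothetical stand-in for a TF resource tensor that refuses to be pickled.

    TensorFlow raises InvalidArgumentError here; TypeError is used as a stand-in.
    """

    def __reduce__(self):
        raise TypeError("Cannot convert a Tensor of dtype resource to a NumPy array.")


class ToyTensorBoard:
    """Hypothetical callback: picklable until training starts."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self.writer = None  # nothing unpicklable before training

    def on_train_begin(self):
        # Assumption: the real callback creates writers/resources at this point.
        self.writer = ResourceHandle()


class ToyEstimator:
    """Hypothetical sklearn-style wrapper that stores constructor params as attributes."""

    def __init__(self, callbacks=None):
        self.callbacks = callbacks  # kept on the estimator, so pickle will visit it

    def fit(self, X, y):
        for cb in self.callbacks or []:
            cb.on_train_begin()
        self.fitted_ = True
        return self


def can_pickle(obj):
    """Return True if pickle.dumps succeeds, False if a TypeError is raised."""
    try:
        pickle.dumps(obj)
        return True
    except TypeError:
        return False


untrained = ToyEstimator(callbacks=[ToyTensorBoard("testlogs")])
trained = ToyEstimator(callbacks=[ToyTensorBoard("testlogs")]).fit(None, None)
no_callback = ToyEstimator().fit(None, None)

print(can_pickle(untrained))    # True: callback exists but holds no resource yet
print(can_pickle(trained))      # False: estimator -> callbacks -> callback -> handle
print(can_pickle(no_callback))  # True: nothing unpicklable is reachable
```

Only the combination of "callback stored on the estimator" and "fit has run" makes the handle reachable from the estimator, mirroring the two conditions listed in the description.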
Versions
- SciKeras 0.3.3
- TensorFlow 2.4.1
- Python 3.9.5
Yeah, this is good motivation to keep that behavior (passing keyword arguments through to Keras.fit).
I think that behavior should be strongly discouraged.
Awesome, I’m glad we found you a solution, even if it's not ideal.
Like I said above, we will probably disable those warnings so that this API will be more straightforward to use going forward.
Your feedback has been very valuable, so thank you for the issue and bearing with me during troubleshooting.
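The workaround discussed above is to pass the callback as a keyword argument to fit (forwarded to Keras) instead of to the constructor, so the estimator never keeps a reference to it. The sketch that follows illustrates that idea in plain Python; ToyEstimator, ToyTensorBoard and ResourceHandle are hypothetical stand-ins, not the SciKeras implementation. In SciKeras itself the corresponding call would presumably be clf.fit(X, y, callbacks=[TensorBoard("testlogs")]), which in the 0.3.x line may emit a deprecation warning about **kwargs.

```python
import pickle


class ResourceHandle:
    """Hypothetical stand-in for an unpicklable TF resource."""

    def __reduce__(self):
        raise TypeError("resource handles cannot be pickled")


class ToyTensorBoard:
    """Hypothetical callback that grabs an unpicklable handle during training."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self.writer = None

    def on_train_begin(self):
        self.writer = ResourceHandle()


class ToyEstimator:
    """Hypothetical wrapper whose fit() uses callbacks without storing them."""

    def __init__(self, epochs=1):
        self.epochs = epochs

    def fit(self, X, y, callbacks=None):
        # Callbacks live only for this call; they are never assigned to self.
        for cb in callbacks or []:
            cb.on_train_begin()
        self.fitted_ = True
        return self


tb = ToyTensorBoard("testlogs")
clf = ToyEstimator(epochs=5).fit(None, None, callbacks=[tb])

# Succeeds: no callback (and thus no resource handle) is reachable from clf.
restored = pickle.loads(pickle.dumps(clf))
print(restored.epochs, restored.fitted_)  # 5 True
print(tb.writer is not None)              # True: the callback still ran
```

The trade-off is that the callback is no longer part of the estimator's parameters, so sklearn utilities that clone the estimator (for example in cross-validation) would not carry it along unless it is passed to fit again.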