[FEATURE REQUEST] add support for custom layers in `best_model()`
Overview
I built a model in Keras using the functional API. I also use the `keras_contrib` and `keras_radam` libraries to add activations (Swish) and optimizers (RAdam) that are not yet implemented in `keras`. Talos initializes and trains all iterations of the model without issue, but recalling the best model or deploying the model fails with an error from `keras.utils.generic_utils.deserialize_keras_object()`.
The error in question is `ValueError: Unknown layer: Swish`.
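To illustrate why the lookup fails, here is a minimal pure-Python sketch of a name-to-class lookup. It is a simplified stand-in, not the Keras source: `BUILTIN_LAYERS`, `resolve_layer`, and the dummy `Dense` and `Swish` classes are hypothetical names used only for this example.

```python
class Dense:
    """Stand-in for a layer class that ships with the framework."""


class Swish:
    """Stand-in for a layer class defined in a third-party package."""


# Hypothetical registry of classes the loader knows about by default.
BUILTIN_LAYERS = {"Dense": Dense}


def resolve_layer(class_name, custom_objects=None):
    """Map a class name from a saved config back to a Python class."""
    custom_objects = custom_objects or {}
    # Custom entries are consulted first, then the built-in registry.
    cls = custom_objects.get(class_name) or BUILTIN_LAYERS.get(class_name)
    if cls is None:
        raise ValueError("Unknown layer: " + class_name)
    return cls


# Without custom_objects the third-party name cannot be resolved.
try:
    resolve_layer("Swish")
except ValueError as err:
    message = str(err)

# Supplying the mapping makes the same lookup succeed.
resolved = resolve_layer("Swish", custom_objects={"Swish": Swish})
```

The saved JSON only stores the class name, so the loader needs a way to find the class again; for layers outside `keras` that has to come from a `custom_objects` mapping.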
Prerequisites
- My Python version is 3.5 or higher
- I have searched through the Issues for a duplicate
- I’ve tested that my Keras model works as a stand-alone
>>> talos.__version__
'0.6.0'
Expected behavior
`scan_object.best_model(metric="acc")` should result in a new instance of the best performing model.
Actual behavior
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-29-2e96fabca5ec> in <module>
      1 #ta.Deploy(t,"U-nets", metric="acc")
----> 2 scan_object.best_model(metric="acc")

~/.pyenv/versions/miniconda3-4.3.30/envs/tf_gpu/lib/python3.6/site-packages/talos/scan/scan_addon.py in func_best_model(scan_object, metric, asc)
     12     from ..utils.best_model import best_model, activate_model
     13     model_no = best_model(scan_object, metric, asc)
---> 14     out = activate_model(scan_object, model_no)
     15
     16     return out

~/.pyenv/versions/miniconda3-4.3.30/envs/tf_gpu/lib/python3.6/site-packages/talos/utils/best_model.py in activate_model(self, model_id)
     18     '''Loads the model from the json that is stored in the Scan object'''
     19
---> 20     model = model_from_json(self.saved_models[model_id])
     21     model.set_weights(self.saved_weights[model_id])
     22

~/.pyenv/versions/miniconda3-4.3.30/envs/tf_gpu/lib/python3.6/site-packages/keras/engine/saving.py in model_from_json(json_string, custom_objects)
    490     config = json.loads(json_string)
    491     from ..layers import deserialize
--> 492     return deserialize(config, custom_objects=custom_objects)
    493
    494

~/.pyenv/versions/miniconda3-4.3.30/envs/tf_gpu/lib/python3.6/site-packages/keras/layers/__init__.py in deserialize(config, custom_objects)
     53                                     module_objects=globs,
     54                                     custom_objects=custom_objects,
---> 55                                     printable_module_name='layer')

~/.pyenv/versions/miniconda3-4.3.30/envs/tf_gpu/lib/python3.6/site-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    143                 config['config'],
    144                 custom_objects=dict(list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 145                                     list(custom_objects.items())))
    146         with CustomObjectScope(custom_objects):
    147             return cls.from_config(config['config'])

~/.pyenv/versions/miniconda3-4.3.30/envs/tf_gpu/lib/python3.6/site-packages/keras/engine/network.py in from_config(cls, config, custom_objects)
   1020         # First, we create all layers and enqueue nodes to be processed
   1021         for layer_data in config['layers']:
-> 1022             process_layer(layer_data)
   1023         # Then we process nodes in order of layer depth.
   1024         # Nodes that cannot yet be processed (if the inbound node

~/.pyenv/versions/miniconda3-4.3.30/envs/tf_gpu/lib/python3.6/site-packages/keras/engine/network.py in process_layer(layer_data)
   1006
   1007             layer = deserialize_layer(layer_data,
-> 1008                                       custom_objects=custom_objects)
   1009             created_layers[layer_name] = layer
   1010

~/.pyenv/versions/miniconda3-4.3.30/envs/tf_gpu/lib/python3.6/site-packages/keras/layers/__init__.py in deserialize(config, custom_objects)
     53                                     module_objects=globs,
     54                                     custom_objects=custom_objects,
---> 55                                     printable_module_name='layer')

~/.pyenv/versions/miniconda3-4.3.30/envs/tf_gpu/lib/python3.6/site-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    136         if cls is None:
    137             raise ValueError('Unknown ' + printable_module_name +
--> 138                              ': ' + class_name)
    139         if hasattr(cls, 'from_config'):
    140             custom_objects = custom_objects or {}

ValueError: Unknown layer: Swish
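The traceback shows `activate_model` calling `model_from_json` without `custom_objects`. Until Talos supports this directly, a possible workaround is to rebuild the model by hand. The sketch below is not part of Talos: `load_best_model` and its `loader` parameter are hypothetical, and it assumes the `Scan` object exposes `saved_models` and `saved_weights` as seen in the traceback.

```python
def load_best_model(scan_object, model_id, custom_objects, loader=None):
    """Rebuild one model stored on a Talos Scan object, passing custom layers.

    `loader` is a hypothetical hook for swapping in another deserializer;
    by default the Keras `model_from_json` function is used.
    """
    if loader is None:
        # Imported lazily so the helper can be defined without Keras installed.
        from keras.models import model_from_json as loader

    # The JSON string only names the layer classes, so custom classes
    # must be handed to the deserializer explicitly.
    model = loader(scan_object.saved_models[model_id],
                   custom_objects=custom_objects)
    # Weights are stored separately from the architecture.
    model.set_weights(scan_object.saved_weights[model_id])
    return model
```

With the imports from the report in place, a call could look like `load_best_model(scan_object, model_no, {"Swish": Swish})`, where `model_no` is the index of the best model. The optimizer is not part of the model JSON, so `RAdam` should not need to be listed here.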
Model details
MWE
import keras
from keras.models import Model
from keras.layers import Input, Conv2D, BatchNormalization
from keras.layers.advanced_activations import ReLU
import talos as ta

# SHAPE, X and Y stand in for the input shape and the data, which are omitted here.

def u_net(shape, nb_filters=64, conv_size=3, init="glorot_uniform",
          activation=ReLU, output_channels=5):
    i = Input(shape, name="input_layer")
    n = Conv2D(nb_filters, conv_size, padding="same", kernel_initializer=init,
               name="block1_conv1")(i)
    n = activation(name="block1_{}1".format(activation.__name__))(n)
    n = BatchNormalization(name="block1_bn1")(n)
    n = Conv2D(nb_filters, conv_size, padding="same", kernel_initializer=init,
               name="block1_conv2")(n)
    n = activation(name="block1_{}2".format(activation.__name__))(n)
    n = BatchNormalization(name="block1_bn2")(n)
    o = Conv2D(output_channels, 1, activation="softmax", name="conv_out")(n)
    return Model(inputs=i, outputs=o)

def talos_model(x_train, y_train, x_val, y_val, params):
    model = u_net(SHAPE, nb_filters=params["nb_filters"], activation=params["act"])
    # The loss and metric are assumptions; the original snippet did not specify them.
    model.compile(optimizer=params["opt"](lr=1e-4),
                  loss="categorical_crossentropy", metrics=["acc"])
    history = model.fit(x=x_train, y=y_train, validation_data=(x_val, y_val))
    # Talos expects the history object first, then the model.
    return history, model

# p is the parameter dictionary listed in the next section.
scan_object = ta.Scan(x=X, y=Y, model=talos_model, params=p)
parameter dictionary
# fit params
from keras.optimizers import Adam
from keras.layers.advanced_activations import ReLU
from keras_radam import RAdam
from keras_contrib.layers.advanced_activations.swish import Swish

p = {
    "nb_filters": [12, 16, 32],
    "act": [Swish, ReLU],
    "opt": [RAdam, Adam]
}
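Since the parameter dictionary already holds the layer classes, a `custom_objects` mapping could in principle be derived from it. The following is only a sketch of that idea: `custom_objects_from_params` is a hypothetical helper, and the dummy `Swish` and `ReLU` classes stand in for the real ones so the example runs without Keras.

```python
def custom_objects_from_params(params, keys=("act",)):
    """Collect classes listed under the given keys into a name -> class dict."""
    objects = {}
    for key in keys:
        for candidate in params.get(key, []):
            # Only classes are kept; strings such as "relu" need no mapping.
            if isinstance(candidate, type):
                objects[candidate.__name__] = candidate
    return objects


# Dummy classes stand in for the real layers (assumption for this sketch).
class Swish:
    pass


class ReLU:
    pass


example_params = {"nb_filters": [12, 16, 32], "act": [Swish, ReLU]}
custom_objects = custom_objects_from_params(example_params)
```

Keras stores each layer under its class name, so keying the mapping by `__name__` should match what the saved JSON contains. Including a built-in class such as `ReLU` is redundant but harmless.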
- My bug report includes an input model
- My bug report includes a parameter dictionary
- My bug report includes a `Scan()` command
- My bug report question includes a link to a sample of the data
I chose to leave out sample data because it is not relevant to the issue at hand. For the same reason, I chose to create a Minimal Working Example rather than pasting the entire model, which is quite complicated and does not help to locate the issue.
Issue Analytics
- Created 4 years ago
- Comments: 6 (2 by maintainers)
Awesome! I’ll be looking forward to seeing the development, and maybe I’ll pitch in if I manage to find the time between finishing my PhD dissertation and getting my data out of my models 😉
@bjtho08 definitely, that’s a great idea 😃 I’ll change the title of the issue to correspond to this scope.