Loading Best Model from File


The documentation clearly explains the procedure for loading the best model after hyperparameter optimization is complete.

models = tuner.get_best_models(num_models=2)

The metrics and predictions can also be obtained with:

# Evaluate the best model.
loss, accuracy = best_model.evaluate(x_val, y_val)

However, how do you load a previously tuned model from file, and how do you get the best model to make predictions?

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 9 (2 by maintainers)

Top GitHub Comments

9 reactions
ppurwar commented, Aug 2, 2019

I was able to come up with a workaround. Use the following code to import the results in another script:

import os
import json
import kerastuner.engine.hyperparameters as hp_module
import kerastuner.engine.trial as trial_module
import kerastuner.engine.metrics_tracking as metrics_tracking
from kerastuner.abstractions.tensorflow import TENSORFLOW_UTILS as tf_utils
import tensorflow as tf
from tensorflow.keras.models import model_from_json

class SearchResults(object):
    def __init__(self, directory, project_name, objective):
        self.directory = directory
        self.project_name = project_name
        self.objective = objective
        
    def reload(self):
        """Populate `self.trials` and `self.oracle` state."""
        fname = os.path.join(self.directory, self.project_name, 'tuner.json')
        state_data = tf_utils.read_file(fname)
        state = json.loads(state_data)

        self.hyperparameters = hp_module.HyperParameters.from_config(
            state['hyperparameters'])
        self.best_metrics = metrics_tracking.MetricsTracker.from_config(
            state['best_metrics'])
        self.trials = [trial_module.Trial.load(f) for f in state['trials']]
        self.start_time = state['start_time']
    
    def _get_best_trials(self, num_trials=1):
        if not self.best_metrics.exists(self.objective):
            return []
        trials = []
        for x in self.trials:
            if x.score is not None:
                trials.append(x)
        if not trials:
            return []
        direction = self.best_metrics.directions[self.objective]
        sorted_trials = sorted(trials,
                               key=lambda x: x.score,
                               reverse=direction == 'max')
        return sorted_trials[:num_trials]
    
    def get_best_models(self, num_models=1):
        best_trials = self._get_best_trials(num_models)
        models = []
        for trial in best_trials:
            hp = trial.hyperparameters.copy()
            # Get best execution.
            direction = self.best_metrics.directions[self.objective]
            executions = sorted(
                trial.executions,
                key=lambda x: x.per_epoch_metrics.get_best_value(
                    self.objective),
                reverse=direction == 'max')
            
            # Reload best checkpoint.
            ckpt = executions[0].best_checkpoint
            model_graph = ckpt + '-config.json'
            model_wts = ckpt + '-weights.h5'
            with open(model_graph, 'r') as f:
                model = model_from_json(f.read())
            model.load_weights(model_wts)
            models.append(model)
        return models


Example usage:

    res = SearchResults(directory='./multiclass_classifier/training',
                        project_name='search_bs2000',
                        objective='val_accuracy')
    res.reload()
    model = res.get_best_models()[0]
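
The trial-selection step in `_get_best_trials` above can be tried in isolation. The following is a minimal, library-free sketch of that logic only: `ToyTrial` and `best_trials` are hypothetical names introduced for illustration and are not part of Keras Tuner.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyTrial:
    """Hypothetical stand-in for a tuner trial; not a Keras Tuner class."""
    trial_id: str
    score: Optional[float]  # None means the trial never reported a score


def best_trials(trials: List[ToyTrial], direction: str,
                num_trials: int = 1) -> List[ToyTrial]:
    """Return the top `num_trials` scored trials for a 'max' or 'min' objective."""
    # Drop trials without a score, mirroring the `x.score is not None` check.
    scored = [t for t in trials if t.score is not None]
    # Sort descending for a 'max' objective, ascending otherwise.
    ranked = sorted(scored, key=lambda t: t.score, reverse=(direction == "max"))
    return ranked[:num_trials]


trials = [
    ToyTrial("a", 0.81),
    ToyTrial("b", None),
    ToyTrial("c", 0.93),
    ToyTrial("d", 0.87),
]
top_two = best_trials(trials, direction="max", num_trials=2)
lowest = best_trials(trials, direction="min")
print([t.trial_id for t in top_two])  # ['c', 'd']
```

Filtering before sorting matters here: comparing `None` with a float would raise a `TypeError` in Python 3.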

5 reactions
omalleyt12 commented, Nov 5, 2019

@JakeTheWise it can and does.

But usually when doing hyperparameter tuning, you’ll split the data into three sets: train, validation, and test.

You’ll perform the hyperparameter search using the train set to train the model, and the validation set to evaluate hyperparameter performance.

Then you evaluate the generalization ability on the test set with either:

  1. The best model found during the hyperparameter search as-is (as you suggested)
  2. Retraining using the train + validation data with the hyperparameters found during the search

get_best_models does (1).

Since more data is almost always better, (2) is likely to give you better performance on the test set (and in production), but requires additional training time.

The idea of get_best_models is just to be a convenient way to access the models that were trained during the search.
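
The two options could be sketched roughly as follows. This assumes a later Keras Tuner release (the `keras_tuner` package), not necessarily the 2019 version discussed in this thread. The helper names `reopen_tuner`, `best_model_as_is`, and `retrain_with_best_hyperparameters` are hypothetical, and `RandomSearch` is only an example: the tuner class and arguments should match whatever was used for the original search.

```python
def reopen_tuner(build_model, directory, project_name,
                 objective="val_accuracy", max_trials=10):
    """Re-create a tuner over an existing results directory.

    Assumption: with `overwrite=False`, the tuner picks up the trials
    already stored under `directory/project_name`.
    """
    # Imported lazily so this sketch can be read without TensorFlow installed.
    import keras_tuner as kt

    return kt.RandomSearch(
        build_model,            # same model-building function as the search
        objective=objective,
        max_trials=max_trials,  # placeholder; ideally the original value
        directory=directory,
        project_name=project_name,
        overwrite=False,
    )


def best_model_as_is(tuner):
    """Option (1): the best model exactly as trained during the search."""
    return tuner.get_best_models(num_models=1)[0]


def retrain_with_best_hyperparameters(tuner, x_full, y_full, epochs):
    """Option (2): rebuild from the best hyperparameters, fit on train + validation."""
    best_hp = tuner.get_best_hyperparameters(num_trials=1)[0]
    model = tuner.hypermodel.build(best_hp)
    # x_full / y_full are assumed to be train and validation data combined.
    model.fit(x_full, y_full, epochs=epochs)
    return model
```

Under these assumptions, calling `reopen_tuner(...)` in a fresh script and then `best_model_as_is(tuner)` should return a model ready for `predict` or `evaluate`, while the second helper trades extra training time for the additional data.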


