suggestion: save the training history in model.train_history_

For example, if we are using verbose and validation:

'''
Demonstration of validation_split
'''
model.fit(X_train, y_train, nb_epoch=3, batch_size=16, validation_split=0.1, show_accuracy=True, verbose=1)
# outputs
'''
Train on 37800 samples, validate on 4200 samples
Epoch 0
37800/37800 [==============================] - 7s - loss: 0.0385 - acc.: 0.7258 - val. loss: 0.0160 - val. acc.: 0.9136
Epoch 1
37800/37800 [==============================] - 8s - loss: 0.0140 - acc.: 0.9265 - val. loss: 0.0109 - val. acc.: 0.9383
Epoch 2
10960/37800 [=======>......................] - ETA: 4s - loss: 0.0109 - acc.: 0.9420
'''

But once fitting has finished, we can only get the final model's performance. So my suggestion is to save the per-epoch performance in a new attribute of the model class, for example:

model.fit(X_train, y_train, nb_epoch=3, batch_size=16, validation_split=0.1, show_accuracy=True, verbose=1)

model.train_history_
# outputs
'''
{
  'epoch': [0, 1, 2],
  'loss': [0.0385, 0.0140, 0.0109],
  'acc': [0.7258, 0.9265, 0.9420],
  'val_loss': [0.0160, 0.0109, 0.0170],
  'val_acc': [0.9136, 0.9383, 0.9400]
}
'''

That way we can analyze the history, for example by plotting loss and val_loss against the epoch number, and choose the best epoch to avoid overfitting.
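
As a sketch of that analysis, assume the history is a dict of per-epoch lists shaped like the proposed train_history_ above. The numbers below are hypothetical, not real training output. Plain Python can then pick the epoch with the lowest val_loss and flag epochs where training loss kept falling while validation loss rose:

```python
def best_epoch(history, monitor="val_loss"):
    """Return (epoch, value) for the epoch where `monitor` is lowest.

    `history` is assumed to be a dict of equal-length per-epoch lists,
    e.g. {'epoch': [...], 'loss': [...], 'val_loss': [...]}.
    """
    values = history[monitor]
    # index of the smallest monitored value (the first one wins on ties)
    idx = min(range(len(values)), key=values.__getitem__)
    epochs = history.get("epoch", list(range(len(values))))
    return epochs[idx], values[idx]


def overfitting_epochs(history):
    """Epochs where training loss decreased but val_loss increased."""
    loss, val = history["loss"], history["val_loss"]
    epochs = history.get("epoch", list(range(len(loss))))
    return [
        epochs[i]
        for i in range(1, len(loss))
        if loss[i] < loss[i - 1] and val[i] > val[i - 1]
    ]


# hypothetical history, for illustration only
history = {
    "epoch": [0, 1, 2, 3],
    "loss": [0.0385, 0.0140, 0.0109, 0.0090],
    "val_loss": [0.0160, 0.0109, 0.0115, 0.0130],
}
print(best_epoch(history))          # (1, 0.0109)
print(overfitting_epochs(history))  # [2, 3]
```

The second helper is one simple heuristic. Plotting both curves shows the same divergence visually.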

Issue Analytics

  • State: closed
  • Created 8 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

4 reactions
ctorrez commented, Apr 12, 2017

Hi, I implemented this wrapper class for tracking the training history.

from collections import defaultdict
import pickle

from keras.models import Sequential
from keras.models import load_model


def _merge_dict(dict_list):
    # Concatenate the per-epoch lists of several History.history dicts.
    dd = defaultdict(list)
    for d in dict_list:
        for key, value in d.items():
            if not hasattr(value, '__iter__'):
                value = (value,)
            dd[key].extend(value)
    return dict(dd)

def save_history(obj, name):
    # Pickle the history dict next to the model file (name + ".pickle").
    try:
        with open(name + ".pickle", "wb") as f:
            pickle.dump(obj, f)
        return True
    except Exception:
        return False

def load_history(name):
    with open(name + ".pickle", "rb") as f:
        return pickle.load(f)

def load_model_w(name):
    # Load a saved Keras model together with its pickled training history.
    model = Sequential_wrapper(load_model(name))
    model.history = load_history(name)
    return model

class Sequential_wrapper(object):
    """Wraps a Keras Sequential model and accumulates its training
    history across fit() calls in self.history."""

    def __init__(self, model=None):
        # Build the default model per instance: a Sequential() default
        # argument would be created once and shared by every wrapper.
        self.model = model if model is not None else Sequential()
        self.history = {}

    def __getattr__(self, name):
        # Delegate anything not defined here (compile, predict, evaluate, ...)
        # to the wrapped model instead of copying methods with exec().
        if name == 'model':
            raise AttributeError(name)
        return getattr(self.model, name)

    def _update_history(self, history):
        if not self.history:
            self.history = dict(history)
        else:
            self.history = _merge_dict([self.history, history])

    def fit(self, x, y, batch_size=32, epochs=10, verbose=1, callbacks=None,
            validation_split=0.0, validation_data=None, shuffle=True,
            class_weight=None, sample_weight=None,
            initial_epoch=0, **kwargs):
        # Keyword arguments keep this working even if the positional
        # order of Model.fit differs between Keras versions.
        h = self.model.fit(x, y, batch_size=batch_size, epochs=epochs,
                           verbose=verbose, callbacks=callbacks,
                           validation_split=validation_split,
                           validation_data=validation_data, shuffle=shuffle,
                           class_weight=class_weight,
                           sample_weight=sample_weight,
                           initial_epoch=initial_epoch, **kwargs)
        self._update_history(h.history)
        return h

    def save(self, filepath, overwrite=True):
        # Save the history (filepath + ".pickle") and the model (filepath).
        save_history(self.history, filepath)
        self.model.save(filepath, overwrite=overwrite)
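
You do not need Keras to follow the wrapper's history bookkeeping. The sketch below mimics what _update_history and _merge_dict do across two fit() calls. It uses hand-written dicts in place of real History.history objects, and the numbers are made up. The merge_histories helper is a hypothetical stand-in and only roughly mirrors _merge_dict (it treats lists and tuples as sequences and anything else as a single value):

```python
from collections import defaultdict


def merge_histories(*histories):
    """Concatenate per-epoch lists from several History.history-style dicts."""
    merged = defaultdict(list)
    for h in histories:
        for key, value in h.items():
            # treat a bare scalar as a one-epoch list
            if not isinstance(value, (list, tuple)):
                value = [value]
            merged[key].extend(value)
    return dict(merged)


# two hypothetical fit() calls: 2 epochs, then 1 more epoch
first = {"loss": [0.40, 0.25], "val_loss": [0.35, 0.30]}
second = {"loss": [0.20], "val_loss": [0.31]}

total = merge_histories(first, second)
print(total["loss"])           # [0.4, 0.25, 0.2]
print(len(total["val_loss"]))  # 3
```

Accumulating like this treats repeated fit() calls as one continuous training run. Resetting the history on each fit() would be the alternative design.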

1 reaction
fyears commented, May 9, 2015

Well, I am just a newcomer to Keras and deep learning. I have only followed the examples and have not used train yet, so I do not know what would happen.

But I have looked at your source code. It seems the design of Keras aims for some compatibility with scikit-learn, right? At least the Sequential() class has .fit(), .describe(), .predict(), and .predict_proba(). In scikit-learn, it is common for fit() to store results such as a .train_history_ as attributes on the object, instead of returning them.
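
To illustrate the scikit-learn convention being referred to, here is a toy estimator that does not need Keras. In scikit-learn, fit() stores learned state in attributes whose names end with an underscore (for example coef_ on linear models, or loss_curve_ on MLPClassifier) and returns self. ToyRegressor and its one-weight gradient-descent model are hypothetical, chosen only so the example runs with the standard library. Each call to fit() replaces train_history_ rather than appending to it:

```python
class ToyRegressor:
    """Fit y ~ w * x by gradient descent, scikit-learn style (hypothetical example)."""

    def __init__(self, lr=0.1, n_epochs=5):
        self.lr = lr
        self.n_epochs = n_epochs

    def fit(self, X, y):
        w = 0.0
        history = {"epoch": [], "loss": []}
        n = len(X)
        for epoch in range(self.n_epochs):
            # mean squared error and its gradient with respect to w
            errors = [w * x - t for x, t in zip(X, y)]
            loss = sum(e * e for e in errors) / n
            grad = 2 * sum(e * x for e, x in zip(errors, X)) / n
            w -= self.lr * grad
            history["epoch"].append(epoch)
            history["loss"].append(loss)
        # fitted state lives in trailing-underscore attributes, replaced on every fit()
        self.coef_ = w
        self.train_history_ = history
        return self

    def predict(self, X):
        return [self.coef_ * x for x in X]


model = ToyRegressor(lr=0.1, n_epochs=20).fit([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(round(model.coef_, 3))  # 2.0
print(len(model.train_history_["loss"]))  # 20
```

Because fit() returns self, the call can be chained as shown. The history is then read back from the attribute rather than from a return value.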

And I do not think fitting several times is a problem. Each fit naturally has its own training history, and it is the developer's responsibility to save the training history before fitting again.
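
One way to follow that advice is to write each fit's history to disk before fitting again. The sketch below uses the standard library's json module, and the file name and dict contents are hypothetical. JSON is used instead of pickle because the history here is just a dict of lists of numbers. The values are converted with float() on the assumption that they may be NumPy scalars, which json cannot serialize directly:

```python
import json
import os
import tempfile


def save_history(history, path):
    """Write a dict of per-epoch lists to `path` as JSON."""
    # convert values (possibly NumPy scalars) to plain floats for json
    clean = {k: [float(v) for v in vals] for k, vals in history.items()}
    with open(path, "w") as f:
        json.dump(clean, f)


def load_history(path):
    with open(path) as f:
        return json.load(f)


run1 = {"loss": [0.4, 0.25], "val_loss": [0.35, 0.3]}  # hypothetical first fit
path = os.path.join(tempfile.mkdtemp(), "history_run1.json")
save_history(run1, path)
# fitting again may replace the in-memory history; run1 stays on disk
print(load_history(path))  # {'loss': [0.4, 0.25], 'val_loss': [0.35, 0.3]}
```

Note that the float() conversion also turns integer entries such as epoch numbers into floats.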

Last but not least, I am using a custom subclass to record the numerical information. val_loss and val_acc work as expected, but loss and acc do not, because the displayed loss and acc are actually averages of something computed in Progbar, and I could not figure out how to get the correct values. Could you provide some help? I think solving this would save time in future development.
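
A likely explanation, assuming the Progbar of that Keras version displays a running mean weighted by the number of samples in each batch: the displayed loss and acc are then sample-weighted averages over the batches seen so far in the epoch, not the value of the last batch. Under that assumption, the epoch-level figure can be reproduced from per-batch values as follows. The batch values are made up, and epoch_average is a hypothetical helper:

```python
def epoch_average(batch_values, batch_sizes):
    """Sample-weighted mean of per-batch metric values.

    Weighting each batch by its size means a short final batch
    counts less than a full one.
    """
    if len(batch_values) != len(batch_sizes):
        raise ValueError("need one size per batch value")
    total = sum(batch_sizes)
    if total == 0:
        raise ValueError("no samples seen")
    return sum(v * n for v, n in zip(batch_values, batch_sizes)) / total


# hypothetical epoch: 100 samples with batch_size=32 -> batches of 32, 32, 32, 4
losses = [0.50, 0.40, 0.30, 0.10]
sizes = [32, 32, 32, 4]
print(round(epoch_average(losses, sizes), 4))  # 0.388
print(losses[-1])                              # 0.1 (last batch only)
```

Recording only the last batch's value (0.1 here) would give a very different number from the epoch average shown on the progress bar.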

It is almost the same as the original Sequential.fit(), but I added a line train_history = [] in the middle and a line that appends to train_history at the end of each epoch. I always use show_accuracy and validation, so I did not put the new code inside the if/else branches.

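import numpy as np
# Keras 0.1-era code: the star import is expected to provide Sequential,
# standardize_y, make_batches and Progbar.
from keras.models import *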
class MySequential(Sequential):
    def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1,
            validation_split=0., validation_data=None, shuffle=True, show_accuracy=False):
        y = standardize_y(y)

        do_validation = False
        if validation_data:
            try:
                X_val, y_val = validation_data
            except:
                raise Exception("Invalid format for validation data; provide a tuple (X_val, y_val).")
            do_validation = True
            y_val = standardize_y(y_val)
            if verbose:
                print("Train on %d samples, validate on %d samples" % (len(y), len(y_val)))
        else:
            if 0 < validation_split < 1:
                # If a validation split size is given (e.g. validation_split=0.2)
                # then split X into smaller X and X_val,
                # and split y into smaller y and y_val.
                do_validation = True
                split_at = int(len(X) * (1 - validation_split))
                (X, X_val) = (X[0:split_at], X[split_at:])
                (y, y_val) = (y[0:split_at], y[split_at:])
                if verbose:
                    print("Train on %d samples, validate on %d samples" % (len(y), len(y_val)))

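        index_array = np.arange(len(X))
        train_history = []
        for epoch in range(nb_epoch):
            if verbose:
                print('Epoch', epoch)
                progbar = Progbar(target=len(X), verbose=verbose)
            if shuffle:
                np.random.shuffle(index_array)

            # Per-epoch sums of each training metric weighted by batch size,
            # so the stored value is the epoch average, not the last batch's value.
            metric_sums = {}
            n_seen = 0

            batches = make_batches(len(X), batch_size)
            for batch_index, (batch_start, batch_end) in enumerate(batches):
                if shuffle:
                    batch_ids = index_array[batch_start:batch_end]
                else:
                    batch_ids = slice(batch_start, batch_end)
                X_batch = X[batch_ids]
                y_batch = y[batch_ids]

                if show_accuracy:
                    loss, acc = self._train_with_acc(X_batch, y_batch)
                    log_values = [('loss', loss), ('acc.', acc)]
                else:
                    loss = self._train(X_batch, y_batch)
                    log_values = [('loss', loss)]

                # accumulate training metrics before validation values are appended
                n_batch = batch_end - batch_start
                for name, value in log_values:
                    metric_sums[name] = metric_sums.get(name, 0.0) + float(value) * n_batch
                n_seen += n_batch

                # validation (once, after the last batch of the epoch)
                if do_validation and (batch_index == len(batches) - 1):
                    if show_accuracy:
                        val_loss, val_acc = self.test(X_val, y_val, accuracy=True)
                        log_values += [('val. loss', val_loss), ('val. acc.', val_acc)]
                    else:
                        val_loss = self.test(X_val, y_val)
                        log_values += [('val. loss', val_loss)]

                # logging
                if verbose:
                    progbar.update(batch_end, log_values)

            # one dict per epoch: averaged training metrics plus validation metrics
            epoch_log = dict((name, total / n_seen) for name, total in metric_sums.items())
            if do_validation:
                epoch_log['val. loss'] = float(val_loss)
                if show_accuracy:
                    epoch_log['val. acc.'] = float(val_acc)
            train_history.append(epoch_log)
        self.train_history_ = train_history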
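
To turn a list of per-epoch records into the dict-of-lists layout proposed at the top of the issue ('epoch', 'loss', 'acc', ...), a small conversion helper is enough. This is a sketch that assumes each record is either a dict or a list of (name, value) pairs like log_values. It renames the progress-bar labels ('acc.', 'val. loss', ...) to the proposed keys through a hypothetical KEY_MAP. The example records reuse the epoch 0 and epoch 1 numbers from the log in the issue:

```python
# hypothetical mapping from progress-bar labels to the proposed keys
KEY_MAP = {"acc.": "acc", "val. loss": "val_loss", "val. acc.": "val_acc"}


def to_history_dict(records):
    """Turn per-epoch records into {'epoch': [...], 'loss': [...], ...}.

    Assumes every epoch reports the same metrics; otherwise the lists
    would no longer line up with the 'epoch' list.
    """
    history = {"epoch": []}
    for epoch, record in enumerate(records):
        history["epoch"].append(epoch)
        # dict() accepts both a mapping and a list of (name, value) pairs
        for name, value in dict(record).items():
            history.setdefault(KEY_MAP.get(name, name), []).append(value)
    return history


records = [
    [("loss", 0.0385), ("acc.", 0.7258), ("val. loss", 0.0160), ("val. acc.", 0.9136)],
    [("loss", 0.0140), ("acc.", 0.9265), ("val. loss", 0.0109), ("val. acc.", 0.9383)],
]
print(to_history_dict(records))
# {'epoch': [0, 1], 'loss': [0.0385, 0.014], 'acc': [0.7258, 0.9265], 'val_loss': [0.016, 0.0109], 'val_acc': [0.9136, 0.9383]}
```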

Top Results From Across the Web

python - keras: how to save the training history attribute of the ...
What I use is the following: with open('/trainHistoryDict', 'wb') as file_pi: pickle.dump(history.history, file_pi). In this way I save the ...

Display Deep Learning Model Training History in Keras
Access Model Training History in Keras. It records training metrics for each epoch. This includes the loss and the accuracy (for classification ...

keras: how to save the training history - Intellipaat Community
You can save the model history by this: with open('/trainHistoryDict', 'wb') as file_pi: pickle.dump(history.history, file_pi). Happy Learning.

Effective Model Saving and Resuming Training in PyTorch
This blog post explores how to do proper model saving in PyTorch framework that helps in resuming training later on.

4. Model Training Patterns - Machine Learning Design ...
Then take the same network and train it on the full training dataset. ... Models like recurrent neural networks incorporate history of previous...
