Custom configuration how/where to save the best model with ModelCheckpoint

A user may want to integrate ModelCheckpoint with an experiment-tracking package such as mlflow or polyaxon. In that case, logs, model weights, etc. can be stored in cloud storage, e.g.

exp_tracking.log_artifact(filepath)

By default, ModelCheckpoint saves the model to the provided dirname path. The idea is to provide the flexibility to run custom code when the model is saved, so that it can be stored anywhere we like.
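One way to picture the idea is a save hook that writes the file locally and then hands the path to the tracker. The sketch below is only an illustration using the standard library: `RecordingTracker` and `make_save_hook` are hypothetical names, not part of ignite or of any tracking package, and the only call taken from the issue is `log_artifact(filepath)`.

```python
import os
import pickle
import tempfile


class RecordingTracker:
    """Hypothetical stand-in for an experiment-tracking client."""

    def __init__(self):
        self.artifacts = []

    def log_artifact(self, filepath):
        # A real tracker would upload the file; this one only records the path.
        self.artifacts.append(filepath)


def make_save_hook(tracker):
    """Return a callable that saves an object locally, then reports the file."""

    def save_and_log(obj, filepath):
        with open(filepath, "wb") as f:
            pickle.dump(obj, f)
        tracker.log_artifact(filepath)

    return save_and_log


tracker = RecordingTracker()
save_hook = make_save_hook(tracker)
checkpoint_path = os.path.join(tempfile.mkdtemp(), "best_model.pkl")
save_hook({"weights": [0.1, 0.2]}, checkpoint_path)
```

Keeping the tracker behind a small callable means the checkpoint handler would not need to know which tracking package is in use.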

What do you think?

cc @elanmart

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

1 reaction
vfdev-5 commented, Apr 3, 2019

@Bibonaut thanks for sharing the code! Looks nice! We could consider putting it into the contrib module.

1 reaction
alxlampe commented, Apr 3, 2019

This sounds nice. I also did some hacking and just want to share the code, in case it is useful for any of you. In my case, I needed a custom save method. I didn’t want to use torch.save(), because my own model class is still under development and I want to keep compatibility between all its versions. My save method simply saves the hyperparameters and weights, from which the class is recreated when it is loaded.
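As a rough illustration of that kind of save method, the sketch below stores only plain data (hyperparameters and weights) and rebuilds the object on load. It is an assumption-laden toy, not the commenter's actual code: `TinyModel`, `save_model` and `load_model` are hypothetical names, and plain Python lists stand in for real tensors.

```python
import os
import pickle
import tempfile


class TinyModel:
    """Hypothetical model, fully described by hyperparameters plus weights."""

    def __init__(self, hidden_size, dropout):
        self.hparams = {"hidden_size": hidden_size, "dropout": dropout}
        self.weights = [0.0] * hidden_size


def save_model(model, path):
    # Store plain data only, so the file does not depend on the class layout.
    payload = {"hparams": model.hparams, "weights": model.weights}
    with open(path, "wb") as f:
        pickle.dump(payload, f)


def load_model(path):
    # Recreate the object from its hyperparameters, then restore the weights.
    with open(path, "rb") as f:
        payload = pickle.load(f)
    model = TinyModel(**payload["hparams"])
    model.weights = payload["weights"]
    return model


model = TinyModel(hidden_size=3, dropout=0.1)
model.weights = [0.5, -1.0, 2.0]
path = os.path.join(tempfile.mkdtemp(), "tiny_model.pkl")
save_model(model, path)
restored = load_model(path)
```

Because the file holds only a dict of plain values, a later version of the class can still load it as long as its constructor accepts the same hyperparameters.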

I derived my ModelSaver class from ignite.handlers.ModelCheckpoint and overrode the _internal_save method. There are two more small features: saving the model when an exception is raised and when training is completed. Sorry for the incomplete documentation, but I have little time at the moment.

Long story short… here is my code:

from ignite.engine import Events
import ignite
import os


class ModelSaver(ignite.handlers.ModelCheckpoint):
    """
    Extends ignite.handlers.ModelCheckpoint with an option to provide a custom save method,
    to save the final model after training ends, and to save a model if an exception is raised during training.
    """

    def __init__(self, *args, save_method=None, save_on_exception=True, save_on_completed=True, **kwargs):
        if not isinstance(save_on_exception, bool):
            raise TypeError(
                "Argument save_on_exception must be of type bool, got {} instead.".format(type(save_on_exception)))
        if not isinstance(save_on_completed, bool):
            raise TypeError(
                "Argument save_on_completed must be of type bool, got {} instead.".format(type(save_on_completed)))
        if save_method is not None and not callable(save_method):
            raise TypeError(
                "Argument save_method must be callable accepting two arguments: Model object to save and path")

        self._save_method = save_method
        self._save_on_completed = save_on_completed
        self._save_on_exception = save_on_exception

        super(ModelSaver, self).__init__(*args, **kwargs)

    def _internal_save(self, obj, path):
        if self._save_method is not None:
            self._save_method(obj, path)
        else:
            super(ModelSaver, self)._internal_save(obj, path)

    def _on_exception(self, engine, exception, to_save):
        for name, obj in to_save.items():
            fname = '{}_{}_{}{}.pth'.format(self._fname_prefix, name, self._iteration, "_on_exception")
            path = os.path.join(self._dirname, fname)
            if os.path.exists(path):
                os.remove(path)
            self._save(obj=obj, path=path)

    def _on_completed(self, engine, to_save):
        for name, obj in to_save.items():
            fname = '{}_{}_{}{}.pth'.format(self._fname_prefix, name, self._iteration, "_on_completed")
            path = os.path.join(self._dirname, fname)
            if os.path.exists(path):
                os.remove(path)
            self._save(obj=obj, path=path)

    def attach(self, engine, model_dict):
        """
        Attaches the model saver to an engine object

        Args:
            engine (Engine): engine object
            model_dict (dict): A dict mapping names to objects, e.g. {'mymodel': model}
        """
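        engine.add_event_handler(Events.EPOCH_COMPLETED, self, model_dict)
        # Register the optional handlers only when the corresponding flag is set.
        if self._save_on_completed:
            engine.add_event_handler(Events.COMPLETED, self._on_completed, model_dict)
        if self._save_on_exception:
            engine.add_event_handler(Events.EXCEPTION_RAISED, self._on_exception, model_dict)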
