
[Feature Request] Early stop the training if there is no improvement (no new best model) after consecutive evaluations


🚀 Feature

Create a new callback that allows stopping the training once the last consecutive evaluations have not found any new best model.

Motivation

I’m working on a problem in which I have to experiment with different scenarios for the same environment. It is costly to find the best maximum number of timesteps for each scenario, and if I use the same number for all scenarios, it can be insufficient for some and more than necessary for others.

So it would be interesting to run the experiments with a maximum budget (total_timesteps) sized for the worst-case scenario, but be able to stop the training early in scenarios where learning has stabilized (no improvement after many evaluations). This approach would save time in the experiments without jeopardizing any scenario.

Pitch

The idea is to have a new callback, used together with EvalCallback, that allows stopping the training before reaching the total_timesteps specified in the learn method.

This callback would have two parameters:

  • a: Maximum number of consecutive evaluations without a new best model.
  • b: Number of evaluations to wait before starting to count evaluations without improvement.

After the first b evaluations, the callback would start counting consecutive evaluations without improvement. If this count becomes greater than a, the training would be stopped.
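
To make the rule concrete, below is a minimal standalone sketch of the counting logic, under the assumption that the two parameters are named max_no_improvement_evals (a) and min_evals (b), as in the implementation shown in the Alternatives section. The class name NoImprovementCounter is only for illustration and is not part of any proposed API.

class NoImprovementCounter:
    """Tracks consecutive evaluations without a new best mean reward."""

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0):
        self.max_no_improvement_evals = max_no_improvement_evals  # parameter "a"
        self.min_evals = min_evals  # parameter "b"
        self.last_best_mean_reward = float("-inf")
        self.no_improvement_evals = 0
        self.n_evals = 0

    def update(self, best_mean_reward: float) -> bool:
        """Record one evaluation result; return False when training should stop."""
        self.n_evals += 1
        continue_training = True
        if self.n_evals > self.min_evals:  # grace period of b evaluations
            if best_mean_reward > self.last_best_mean_reward:
                self.no_improvement_evals = 0  # new best model found, reset the counter
            else:
                self.no_improvement_evals += 1
                if self.no_improvement_evals > self.max_no_improvement_evals:
                    continue_training = False  # patience of a evaluations exceeded, stop
        self.last_best_mean_reward = best_mean_reward
        return continue_training

For example, with a = 3 and b = 5, training would always continue through the first 5 evaluations and would stop as soon as 4 evaluations in a row fail to improve on the best mean reward.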

Alternatives

I have implemented this feature in my project by creating an extended version of EvalCallback and the proposed callback StopTrainingOnNoModelImprovement.

from typing import Optional, Union

import gym
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.vec_env import VecEnv


class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

    It is possible to define a minimum number of evaluations before starting to count evaluations without improvement.

    It must be used with the ``ExtendedEvalCallback``.

    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
    :param min_evals: Number of evaluations before starting to count evaluations without improvement.
    :param verbose: Verbosity level.
    """

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
        super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose)
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        self.last_best_mean_reward = -np.inf
        self.no_improvement_evals = 0

    def _on_step(self) -> bool:
        assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used " "with an ``ExtendedEvalCallback``"
        
        continue_training = True
        
        if self.n_calls > self.min_evals:
            if self.parent.best_mean_reward > self.last_best_mean_reward:
                self.no_improvement_evals = 0                
            else:
                self.no_improvement_evals += 1
                if self.no_improvement_evals > self.max_no_improvement_evals:
                    continue_training = False        
        
        self.last_best_mean_reward = self.parent.best_mean_reward
                
        if self.verbose > 0 and not continue_training:
            print(
                f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
            )
        
        return continue_training


class ExtendedEvalCallback(EvalCallback):
    """
    Extends ``EvalCallback`` by adding a new child callback, called after each evaluation step.
    """
    def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
        best_model_save_path: Optional[str] = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ):
        super(ExtendedEvalCallback, self).__init__(
            eval_env,
            callback_on_new_best=callback_on_new_best,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            log_path=log_path,
            best_model_save_path=best_model_save_path,
            deterministic=deterministic,
            render=render,
            verbose=verbose,
            warn=warn)

        self.callback_after_eval = callback_after_eval
        # Give access to the parent
        if self.callback_after_eval is not None:
            self.callback_after_eval.parent = self
    
    def _init_callback(self) -> None:
        super(ExtendedEvalCallback, self)._init_callback()
        if self.callback_after_eval is not None:
            self.callback_after_eval.init_callback(self.model)
    
    def _on_step(self) -> bool:
        continue_training = super(ExtendedEvalCallback, self)._on_step()
        
        if continue_training:
            if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
                # Trigger callback if needed
                if self.callback_after_eval is not None:
                    return self.callback_after_eval.on_step()
        return continue_training
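
For reference, here is a hedged usage sketch showing how these two callbacks could be wired together in a Stable-Baselines3 training script. The environment id, algorithm, and hyperparameter values are placeholders chosen for illustration, and the script assumes the two classes above are defined in (or imported into) the same module.

import gym

from stable_baselines3 import PPO

eval_env = gym.make("CartPole-v1")

# Stop after 4 consecutive evaluations without a new best model,
# but only start counting after the first 10 evaluations.
stop_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=4, min_evals=10, verbose=1)

eval_callback = ExtendedEvalCallback(
    eval_env,
    callback_after_eval=stop_callback,
    eval_freq=1000,
    n_eval_episodes=5,
    verbose=1,
)

model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
# total_timesteps acts as the worst-case budget; training may stop earlier
# if the stop callback returns False after an evaluation.
model.learn(total_timesteps=100_000, callback=eval_callback)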

Additional context

If you think this can be useful to a wider audience, I could open a PR to include this feature in the library. But if it is too specific, at least the code posted here may be useful for other people.

If I open a PR, ideally EvalCallback itself would be changed instead of creating an extended version. But it would probably be necessary to discuss some design issues, as I’m not sure I have used the best approach to handle two child callbacks in EvalCallback.

Checklist

  • I have checked that there is no similar issue in the repo (required)

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 8 (6 by maintainers)

Top GitHub Comments

1 reaction
caburu commented, Jul 20, 2021

Are you still planning to do a PR?

Yes, sorry for the delay. I’m struggling with an article deadline and my thesis. I believe I can do it next month.

1 reaction
chenyingong commented, Apr 6, 2021

Very glad to hear that you will add this new feature. I actually also created a similar early-stopping callback for my project (the problem I am dealing with usually requires more than 10,000,000 training steps, so early stopping should be quite helpful). Since I am new to deep learning, I will definitely use your implementation rather than mine. Looking forward to that!

