
PyTorchLightningPruningCallback not pruning


I’ve only been using PyTorch Lightning and Optuna for a few days (kudos, it literally took less than 30 minutes to implement with Optuna; I’ve been using Hyperopt, and Optuna has a much nicer API).

It seems to me, though, that PyTorchLightningPruningCallback implements on_epoch_end when it should in fact implement on_validation_end; otherwise, the MedianPruner doesn’t appear to get invoked.

Expected behavior

I expect PyTorchLightningPruningCallback to be invoked on each epoch. If you rename PyTorchLightningPruningCallback.on_epoch_end to PyTorchLightningPruningCallback.on_validation_end, the MedianPruner is invoked (a subclass-based version of this workaround is sketched below).

Alternatively, I just removed the early_stop_callback; then pruning does work, but I’m not sure what strategy it’s using. I’m pretty sure it’s not optuna.pruners.MedianPruner, so I guess it’s some default pytorch_lightning pruner?
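
A rough sketch of that rename done as a subclass rather than by editing the integration source, assuming the Optuna 1.5.0 callback performs its pruning check in on_epoch_end(trainer, pl_module) (the class name ValidationEndPruningCallback is made up here):

from optuna.integration import PyTorchLightningPruningCallback


class ValidationEndPruningCallback(PyTorchLightningPruningCallback):
    """Run the pruning check after validation instead of at the end of every epoch."""

    def on_validation_end(self, trainer, pl_module):
        # Reuse the parent's logic: it reports the monitored metric to the trial
        # and raises optuna.exceptions.TrialPruned when the pruner says to stop.
        super().on_epoch_end(trainer, pl_module)

    def on_epoch_end(self, trainer, pl_module):
        # Disable the original hook so the check only runs once, after validation.
        pass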

Environment

  • Optuna version: 1.5.0
  • Python version: 3.7
  • OS: Ubuntu 18.04
  • (Optional) Other libraries and their versions: pytorch-lightning 0.8.1

Additional context (optional)

I simplified my objective (the example below doesn’t log to TensorBoard; I found that after commenting out much of the code it started logging correctly). I’m training a TCN with only 1 batch per epoch. I also modified the example so as to not pass the trial object into my LightningModule:

def objective(trial):
    # Collect validation metrics so the final value can be returned to Optuna.
    metrics_callback = MetricsCallback()
    trainer = pl.Trainer(
        callbacks=[metrics_callback],
        # Passed as early_stop_callback, per the PyTorch Lightning 0.8.x Trainer API.
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="avg_val_loss"),
    )

    # Hyperparameters suggested by Optuna for this trial.
    layers = trial.suggest_int("layers", 1, 8)
    filters = trial.suggest_int("filters", 1, 32)
    kernel_size = trial.suggest_int("kernel_size", 2, 3)
    dropout = trial.suggest_uniform("dropout", 0.2, 0.5)
    learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-1)
    model = TCN(PATH, layers, filters, kernel_size, dropout, learning_rate)
    trainer.fit(model)

    # Return the monitored metric from the last validation run.
    return metrics_callback.metrics[-1]["avg_val_loss"].item()
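
The MetricsCallback used above isn’t shown in the snippet; assuming it follows the pattern from the Optuna PyTorch Lightning example of that era, it would be roughly:

from pytorch_lightning import Callback


class MetricsCallback(Callback):
    """Collect trainer.callback_metrics at the end of every validation run."""

    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        self.metrics.append(trainer.callback_metrics)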

I’m pretty sure the main code is essentially the same as the example:

    hparams = vars(args)
    tcn = TCN(**hparams)
    pruner = optuna.pruners.MedianPruner()

    trainer = pl.Trainer.from_argparse_args(args, pruner=pruner)
    trainer.fit(tcn)
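
For comparison, Optuna’s own examples attach the pruner to the study rather than to the Trainer; a rough sketch of that wiring (the n_trials value here is arbitrary):

import optuna

pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(direction="minimize", pruner=pruner)

# The pruner only acts through trial.report() / trial.should_prune(), which is
# exactly what PyTorchLightningPruningCallback calls from inside the Trainer.
study.optimize(objective, n_trials=100)

print("Best value:", study.best_value, "params:", study.best_trial.params)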

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 22 (2 by maintainers)

Top GitHub Comments

1 reaction
tfurmston commented, May 25, 2021

Of course, no problem. Let me try to put together a minimal working example and I’ll open a new issue. Thanks.

1 reaction
vedal commented, Apr 20, 2021

Has this issue been resolved? I see some pruning behavior in the simple example with:

python 3.8.5
pytorch-lightning==1.2.8
optuna==2.7.0
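
With those versions the pruning callback is passed through the Trainer’s callbacks argument rather than as early_stop_callback; a minimal sketch, reusing the TCN model and the monitored metric from this issue (max_epochs and the suggest ranges are placeholders):

import pytorch_lightning as pl
from optuna.integration import PyTorchLightningPruningCallback


def objective(trial):
    # TCN, PATH, and the "avg_val_loss" key come from the code earlier in this issue.
    model = TCN(
        PATH,
        trial.suggest_int("layers", 1, 8),
        trial.suggest_int("filters", 1, 32),
        trial.suggest_int("kernel_size", 2, 3),
        trial.suggest_float("dropout", 0.2, 0.5),
        trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True),
    )
    trainer = pl.Trainer(
        max_epochs=10,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="avg_val_loss")],
    )
    trainer.fit(model)
    return trainer.callback_metrics["avg_val_loss"].item()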


