PyTorchLightningPruningCallback not pruning
I've only been using PyTorch Lightning and Optuna for a few days (kudos, it literally took less than 30 minutes to implement with Optuna; I've been using hyperopt, and Optuna has a much nicer API).

It seems to me, though, that PyTorchLightningPruningCallback implements on_epoch_end when it should in fact implement on_validation_end. Otherwise, the MedianPruner doesn't appear to get invoked.
Expected behavior
I expect PyTorchLightningPruningCallback to be invoked on each epoch. If you rename PyTorchLightningPruningCallback.on_epoch_end to PyTorchLightningPruningCallback.on_validation_end, the MedianPruner is invoked.

Alternatively, if I just remove the early_stop_callback, pruning does work, but I'm not sure what strategy it's using. I'm pretty sure it's not optuna.pruners.MedianPruner; I guess it's a default pytorch_lightning mechanism?
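The renaming described above can be sketched as a small subclass. This is a sketch under assumptions, not Optuna's own fix; it uses a stand-in base class so it runs without optuna or pytorch-lightning installed, but with the real libraries the same pattern would apply to optuna.integration.PyTorchLightningPruningCallback:

```python
# Stand-in for optuna.integration.PyTorchLightningPruningCallback, so this
# sketch runs without optuna/pytorch-lightning installed. The real class
# performs its pruning check in on_epoch_end.
class StubPruningCallback:
    def on_epoch_end(self, trainer, pl_module):
        return "pruning-check"  # stand-in for trial.report + should_prune

# Workaround sketch: re-route the check to on_validation_end, where the
# monitored validation metric ("avg_val_loss") is actually available.
class ValidationEndPruningCallback(StubPruningCallback):
    def on_validation_end(self, trainer, pl_module):
        # Reuse the parent's pruning logic, but at validation end.
        return super().on_epoch_end(trainer, pl_module)

    def on_epoch_end(self, trainer, pl_module):
        return None  # disable the original hook
```

With the real callback, the subclass body would be the same; only the base class changes.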
Environment
- Optuna version: 1.5.0
- Python version: 3.7
- OS: Ubuntu 18.04
- Other libraries and their versions: pytorch-lightning 0.8.1
Additional context (optional)
I simplified my objective (the example doesn't log to TensorBoard; I found that after commenting out much of the code it started logging correctly). I'm training a TCN with only 1 batch per epoch. I also modified the example so as not to pass the trial object into my LightningModule:
```python
def objective(trial):
    metrics_callback = MetricsCallback()
    trainer = pl.Trainer(
        callbacks=[metrics_callback],
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="avg_val_loss"),
    )
    layers = trial.suggest_int("layers", 1, 8)
    filters = trial.suggest_int("filters", 1, 32)
    kernel_size = trial.suggest_int("kernel_size", 2, 3)
    dropout = trial.suggest_uniform("dropout", 0.2, 0.5)
    learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-1)
    model = TCN(PATH, layers, filters, kernel_size, dropout, learning_rate)
    trainer.fit(model)
    return metrics_callback.metrics[-1]["avg_val_loss"].item()
```
I'm pretty sure the main code is essentially the same as the example:
```python
hparams = vars(args)
tcn = TCN(**hparams)
pruner = optuna.pruners.MedianPruner()
trainer = pl.Trainer.from_argparse_args(args, pruner=pruner)
trainer.fit(tcn)
```
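For reference, the decision MedianPruner makes at each reported step can be sketched in plain Python. This is a simplified version of the rule, not Optuna's implementation: prune when the trial's intermediate value is worse than the median of the values that completed trials reported at the same step.

```python
import statistics

def should_prune(step, value, completed_histories):
    """Simplified MedianPruner rule: prune if `value` (a loss, lower is
    better) is worse than the median of the values that completed trials
    reported at the same step. `completed_histories` is a list of
    {step: value} dicts, one per completed trial."""
    peers = [h[step] for h in completed_histories if step in h]
    if not peers:
        return False  # nothing to compare against yet
    return value > statistics.median(peers)
```

For example, with completed trials that reported 1.0, 2.0, and 3.0 at step 0, a new trial reporting 2.5 at step 0 is pruned (above the median of 2.0), while one reporting 1.5 is kept.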
Issue Analytics
- Created: 3 years ago
- Comments: 22 (2 by maintainers)
Top GitHub Comments
Of course, no problem. Let me try to get a minimal working example and I can make a new issue. Thanks.

Has this issue been resolved? I see some pruning behavior in the simple example for