[tune] pbt - checkpointing trials and stopping
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04
- Ray version: 0.8.0
- Python version: 3.6.9
Experiment
I have been experimenting with PBT for training a convnet in PyTorch. Everything is working fine, but I am a bit frustrated with the checkpointing. My `_save` and `_restore` methods are similar to https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_convnet_example.py
I am running an experiment with 4 samples, each one on a single GPU. My stopping condition is an accuracy threshold.
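For context, the `_save`/`_restore` pair in that example looks roughly like the sketch below (class and attribute names here are illustrative rather than the exact code from the repo; `self.model` is assumed to hold the convnet):

```python
import os

import torch
from ray.tune import Trainable


class PytorchTrainable(Trainable):
    # _setup and _train omitted; self.model is assumed to be the convnet.

    def _save(self, checkpoint_dir):
        # Tune passes in a (temporary) checkpoint directory; return the file path.
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        # Called e.g. when PBT makes a trial exploit another trial's weights.
        self.model.load_state_dict(torch.load(checkpoint_path))
```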
Questions
1. Is it possible to stop all the trials as soon as one reaches the stopping condition? For example, in the training run below, the blue trial reached the stopping condition and the orange one should be killed.
2. The top-performing models are saved in temp directories during training. How can I recover them if the training crashes or I want to kill it early? I understand that I could checkpoint every epoch with `checkpoint_freq`, but that seems a bit sub-optimal. Is there a way to save the current best model in the trial's local directory?
Top GitHub Comments
See https://ray.readthedocs.io/en/latest/tune-usage.html#custom-stopping-criteria and pass in a stateful custom stopping criterion function, as in the example there.
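A minimal sketch of such a stateful criterion, assuming `stop` accepts a callable taking `(trial_id, result)` as described in the linked docs (the metric name and threshold are placeholders for your own values):

```python
from ray import tune


class StopAllOnThreshold:
    """Return True for every trial once any trial has crossed the threshold."""

    def __init__(self, metric="mean_accuracy", threshold=0.95):
        self.metric = metric
        self.threshold = threshold
        self.done = False  # shared state across all trials

    def __call__(self, trial_id, result):
        if result[self.metric] >= self.threshold:
            self.done = True
        return self.done


tune.run(
    PytorchTrainable,  # trainable from the sketch above
    num_samples=4,
    resources_per_trial={"gpu": 1},
    stop=StopAllOnThreshold(metric="mean_accuracy", threshold=0.95),
)
```

Note that the remaining trials are only stopped the next time they report a result, so they are not killed instantly.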
Set `checkpoint_score_attr` to whatever metric you want to use to determine how to score the checkpoints, and set `keep_checkpoints_num` so that the worst checkpoints are deleted.

Thanks for the reply @ujvl. Your suggestion should work. I'm thinking about two things:
1. What's the right behavior when the user only specifies `checkpoint_score_attr` and `keep_checkpoints_num`? Right now we don't do anything unless `checkpoint_freq` is set. I'm not sure that's right.
2. What's the right behavior when users specify all three parameters, `checkpoint_score_attr`, `keep_checkpoints_num`, and `checkpoint_freq`? Right now it strictly respects `checkpoint_freq`. This is fine for me.
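For reference, here is a hedged sketch of how the three parameters fit together in `tune.run` (metric name and values are illustrative; per the discussion above, `checkpoint_freq` currently has to be set for the other two to take effect):

```python
from ray import tune

analysis = tune.run(
    PytorchTrainable,                       # trainable from the sketch above
    num_samples=4,
    resources_per_trial={"gpu": 1},
    checkpoint_freq=1,                      # checkpoint every training iteration
    checkpoint_score_attr="mean_accuracy",  # rank checkpoints by this result metric
    keep_checkpoints_num=2,                 # keep only the best-scoring checkpoints
    local_dir="~/ray_results",              # checkpoints end up under the trial directories here
)
```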