[tune] saving mechanism and PBT
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): OSX
- Ray installed from (source or binary): from pip
- Ray version: 0.7.2
- Python version: 3.6.8
Describe the problem
Hi everyone, I’ve been playing with the PBT algorithm and I have some feedback. I would also be happy to contribute to the lib on any of these behaviours, but I prefer to start with an issue.
After using PBT, I think some behaviours could be relaxed a little to allow better compatibility with different frameworks (TF2 in my case). Mainly, it boils down to the log directory.
About PBT: The problem I have is that a trainable keeps its logdir “alive” even when it is underperforming: since trainables are reused with PBT, logdirs are reused too, and this does not fit well with TensorBoard’s folder mechanism.
Let’s say Tr1 and Tr2 are my population. At t_n, Tr1 happens to perform worse than Tr2, so it receives the config of Tr2 and gets mutated. But Tr1 also keeps its own logdir, because the class is reused and I can’t access the trial_runner. When I log the learning rate, instead of seeing a new line branching in TensorBoard from the point at which Tr2 was cloned, I might see a “jump” of the learning rate backward or forward in terms of iterations, because the cloned model may have fewer or more timesteps than the model it replaces.
What I would prefer to see is: at the exact point of cloning, on any metric, a new line starting. This could be done if I could set a new logdir along with the new config, but I would need access all the way up to the Trial to do that, if I understand tune well enough. I could bypass the centralised logging system of tune, but that’s not the point, as I would have to rebuild part of an existing mechanism.
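For illustration, the kind of bypass I mean (and would rather not rebuild) is a trainable that manages its own summary writer and rotates it into a fresh subdirectory whenever it receives a new config. The names below (_clone_id, _rotate_writer, WriterRotatingMixin) are made up for this sketch and are not part of tune:
import os
import tensorflow as tf

class WriterRotatingMixin:
    """Sketch only: trainable-side logging that starts a new TensorBoard
    run directory each time the trial is handed a new (cloned) config."""

    def _init_writer(self, logdir):
        # Each clone gets its own subdirectory, so TensorBoard draws a new
        # curve starting at the cloning point instead of a "jump".
        self._clone_id = getattr(self, '_clone_id', -1) + 1
        self._writer = tf.summary.create_file_writer(
            os.path.join(logdir, 'clone_{}'.format(self._clone_id)))

    def _rotate_writer(self, logdir):
        # Would be called from reset_config when PBT hands over a new config.
        self._writer.close()
        self._init_writer(logdir)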
On a side note: I’m not sure it is OK to be able to clone from a less trained model. I have found myself in an infinite loop where a model kept underperforming and kept being cloned from a less trained model, blocking the other models from being trained. This looks like a bug, but I couldn’t pin down the culprit so far. My understanding of PBT is that all agents must reach timestep t_n before any of the models is killed/cloned/mutated, to keep the flow forward and “continuous”.
Related to #3655
About the trainable saving mechanism: in the Trainable class, we have these two pieces of code.
def save(self, checkpoint_dir=None):
    ...
    if isinstance(checkpoint, string_types):
        if (not checkpoint.startswith(checkpoint_dir)
                or checkpoint == checkpoint_dir):
            raise ValueError(
                "The returned checkpoint path must be within the "
                "given checkpoint dir {}: {}".format(
                    checkpoint_dir, checkpoint))
        if not os.path.exists(checkpoint):
            raise ValueError(
                "The returned checkpoint path does not exist: {}".format(
                    checkpoint))
def save_to_object(self):
    ...
    checkpoint_prefix = self.save(tmpdir)
    data = {}
    base_dir = os.path.dirname(checkpoint_prefix)
    for path in os.listdir(base_dir):
        path = os.path.join(base_dir, path)
        if path.startswith(checkpoint_prefix):
            with open(path, "rb") as f:
                data[os.path.basename(path)] = f.read()
I believe this is too restrictive, because when you save using the TF checkpoint mechanism you return a file prefix, which is only a prefix: no file exists at that exact path, so the save function won’t be happy. On the other hand, if I return a real file, the second function won’t load into memory all the files TF needs, because it no longer behaves like a prefix. I think this could be relaxed by replacing the check if not os.path.exists(checkpoint): in the save function with a prefix-based check: is there any file starting with this prefix?
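To make the prefix behaviour concrete, here is a small sketch; checkpoint_prefix_exists is a name I made up for the proposed relaxed check, not an existing tune function, and the /tmp path is just for the demo:
import glob
import os

import tensorflow as tf

# tf.train.Checkpoint.write returns a *prefix*; the files actually written
# are e.g. "<prefix>.index" and "<prefix>.data-00000-of-00001", so
# os.path.exists(prefix) is False even though the checkpoint was written.
os.makedirs('/tmp/prefix_demo', exist_ok=True)
ckpt = tf.train.Checkpoint(step=tf.Variable(0))
prefix = ckpt.write(os.path.join('/tmp/prefix_demo', 'ckpt'))
assert not os.path.exists(prefix)
assert glob.glob(prefix + '*')  # files starting with the prefix do exist

# Proposed relaxed check (sketch): accept the returned path if any file
# starts with it, instead of requiring the exact path to exist.
def checkpoint_prefix_exists(checkpoint):
    return os.path.exists(checkpoint) or bool(glob.glob(checkpoint + '*'))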
Source code / logs
import os
import random
import tensorflow as tf
import ray
from ray.tune import schedulers
from ray.tune import Trainable
from ray import tune
from ray.tune.logger import Logger
from ray.tune.result import TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL
ray.init()
config = {
    'seed': 1,
    'verbose': False,
    'model': {
        'h_size': 256,
        'nb_h_layer': 3,
        'activation': 'leakyrelu',
    },
    'training': {
        'batch_size': 64,
        'lr': tune.sample_from(lambda spec: random.uniform(5e-4, 5e-2)),
        'optimizer': 'adam',
        'reg_param': tune.sample_from(lambda spec: random.uniform(1e-2, 5e1)),
    },
    'eval': {
        'batch_size': 256,
    },
    'tune': {
        'resources_per_trial': {
            'cpu': 2,
            'gpu': 0,
        },
        "stop": {
            TRAINING_ITERATION: 25,
        },
        'local_dir': os.path.dirname(os.path.realpath(__file__)) + '/pbt',
        'num_samples': 8,
        'checkpoint_freq': 0,
        'checkpoint_at_end': True,
        'verbose': 1,
    },
}
tune_scheduler_config = {
    'time_attr': TRAINING_ITERATION,
    'metric': 'loss',
    'mode': 'min',
    'perturbation_interval': 5,
    'quantile_fraction': .5,
    'resample_probability': .25,  # per param: 25% new, 75% (50% *1.2, 50% *.8)
    'hyperparam_mutations': {
        'training': {
            'lr': lambda: random.uniform(5e-4, 1e-1),
            'reg_param': lambda: random.uniform(1e-2, 5e1),
        }
    }
}
class PBTTrain(Trainable):
    def __getitem__(self, item):
        return getattr(self, item)

    def _setup(self, config):
        self.model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(16, batch_input_shape=[256, 1], activation='relu'),
            tf.keras.layers.Dense(16, activation='tanh'),
            tf.keras.layers.Dense(1),
        ])
        self.lr = tf.Variable(config['training']['lr'])
        self.reg_param = tf.Variable(config['training']['reg_param'])
        self.optim = tf.keras.optimizers.SGD(self.lr)
        self.ckpt = tf.train.Checkpoint(
            model=self.model,
            lr=self.lr,
        )
        self.new_config = None

    def _train(self):
        for i in range(100):
            input = tf.random.normal([256, 1])
            y_true = input**2
            with tf.GradientTape() as tape:
                output = self.model(input)
                reg = 0.
                for w in self.model.trainable_weights:
                    reg += tf.reduce_mean(w)
                reg /= len(self.model.trainable_weights)
                loss = tf.reduce_mean(tf.square(output - y_true)) + self.reg_param * reg
            grads = tape.gradient(loss, self.model.trainable_weights)
            self.optim.apply_gradients(zip(grads, self.model.trainable_weights))
        # Eval
        loss = 0.
        for i in range(100):
            input = tf.random.normal([256, 1])
            y_true = input**2
            output = self.model(input)
            loss += tf.reduce_mean(tf.square(output - y_true))
        loss /= 100  # average over the 100 eval batches
        tune_dict = {
            'loss': loss.numpy(),
            'lr': self.lr.numpy(),
            'reg_param': self.reg_param.numpy(),
        }
        return tune_dict

    def _save(self, checkpoint_dir):
        save_path_prefix = self.ckpt.write(os.path.join(checkpoint_dir, 'ckpt'))
        # This is needed to comply with tune framework
        open(save_path_prefix, 'a').close()
        return save_path_prefix

    def _restore(self, save_path_prefix):
        self.ckpt.restore(save_path_prefix)
        if self.new_config is not None:
            for key, val in self.new_config.items():
                self[key].assign(val)
            self.new_config = None
        return

    def reset_config(self, new_config):
        self.new_config = {
            'lr': new_config['training']['lr'],
            'reg_param': new_config['training']['reg_param'],
        }
        return True
# A simple TF2 logger implementation based on TFLogger
class TF2Logger(Logger):
    def _init(self):
        self._file_writer = tf.summary.create_file_writer(self.logdir)

    def on_result(self, result):
        with tf.device('/CPU:0'):
            with self._file_writer.as_default():
                step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
                tmp = result.copy()
                for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
                    if k in tmp:
                        del tmp[k]  # not useful to log these
                for attr, value in tmp.items():
                    if type(value) in [int, float]:
                        tf.summary.scalar(attr, value, step=step)
        self._file_writer.flush()

    def flush(self):
        self._file_writer.flush()

    def close(self):
        self._file_writer.close()
scheduler = schedulers.PopulationBasedTraining(**tune_scheduler_config)
trial_list = tune.run(
    PBTTrain,
    config=config,
    scheduler=scheduler,
    loggers=[TF2Logger],
    **config['tune']
)
Top GitHub Comments
Sorry, I haven’t had much time lately, but I will prototype it and create a PR soon. I’m OK with your suggestions so far; I just need to code it to see if it fits well. 😊
I think that’s fair. One option would be to modify PBT to include a trial.close_logger(); trial.init_logger() during exploration.
I’m not quite sure about dumping one eval datapoint, although this would not be hard (you can just log trial.last_result).
Could you prototype this and see how it feels? Just want to make sure there’s a good user experience.
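For what it’s worth, a rough, untested sketch of that suggestion done via subclassing rather than editing PBT in place; the _exploit signature is an assumption about tune internals, and close_logger()/init_logger() are the Trial methods mentioned above:
from ray.tune.schedulers import PopulationBasedTraining

class LoggerRotatingPBT(PopulationBasedTraining):
    # Untested sketch: _exploit's signature here is assumed, not verified
    # against a specific tune release.
    def _exploit(self, trial_executor, trial, trial_to_clone):
        # Close the current logger so the old TensorBoard curve ends here,
        # then re-open it after the clone/mutation so a new curve starts.
        trial.close_logger()
        super()._exploit(trial_executor, trial, trial_to_clone)
        trial.init_logger()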