# Time Abstraction
## Motivation
There are various measures of time during training, and we need a common, steppable abstraction to handle conversion between units. In the CV community, it is common to track time in terms of samples and batches, whereas in NLP it is more common to track time in terms of tokens and the overall duration of the training process. Here, we propose a time-tracking solution.
## Implementation
After discussion with @abhi-mosaic and @moinnadeem, we are leaning towards the following design:
1. `Time` objects will simplify time arithmetic. A `Time` object consists of an integer and a unit, which will be one of `epochs`, `batches`, `samples`, or `tokens`. Via overloaded operators, `Time` objects will support comparisons, addition, and subtraction against other `Time` objects of the same unit and, for backwards compatibility, against raw integers (though in that case a `UserWarning` will be emitted). They will also have getters to retrieve the underlying value as an integer and the unit. (See the `Time`/`Timer` sketch after this list.)
2. A `Timer` object, attached to the trainer's state, will track `epochs`, `batches`, `samples`, and `tokens`. These fields are `Time` objects, except for tokens, which may be `None` (for non-NLP training jobs). The timer object will have getters for each of these fields and a single update function that the training loop will call at the end of every batch, e.g. `timer.update(samples=X, tokens=Y)`.
3. To determine the number of samples and number of tokens in a batch, a dataset can provide `get_batch_size(batch)` and `get_num_tokens(batch)`. If not specified, the default `get_batch_size()` will be used, and tokens will NOT be tracked. (See the dataset-hooks sketch after this list.)
4. Datasets can optionally provide `__len__` and `get_num_tokens()`. By PyTorch convention, `__len__` should return the number of samples in the dataset. `get_num_tokens` can either return a constant number, perform some computation upon initialization to determine the number of tokens in the dataset, or (by default) return `None` if the number of tokens is unknown.
5. The `max_epochs` property in the trainer hparams will be replaced with `max_duration`, where the duration can be specified in terms of `epochs`, `steps`, `samples`, or `tokens`.
6. The trainer will have a function `trainer.get_elapsed_duration()` that will query the timer object and return a float in `[0, 1]` representing how much of the training process has been completed (relative to the `max_duration` parameter).
7. The timing module (NOT the timer object) will have a static method like:

   ```python
   def convert(time_string, desired_unit,
               dataset_num_samples: Optional[int] = None,
               dataset_num_tokens: Optional[int] = None,
               max_training_duration: Optional[str] = None,
               batch_size: Optional[int] = None):
       pass
   ```

   This method performs a static conversion between the specified time string and the desired unit. Depending on the conversion being performed, `dataset_num_samples`, `dataset_num_tokens`, `max_training_duration`, and/or `batch_size` will need to be provided. These parameters must be passed explicitly to emphasize that this is a static conversion, performed at the time of the call, which may become inaccurate if these parameters later change (e.g. an algorithm changes the training duration). The following conversions are allowed (see the conversion sketch after this list):
   1. epochs <-> batches, if `dataset_num_samples` and `batch_size` are defined
   2. epochs <-> samples, if `dataset_num_samples` is defined
   3. batches <-> samples, if `batch_size` is defined
   4. epochs <-> tokens, if `dataset_num_tokens` is defined
   5. duration <-> unit of `max_duration`: a duration string (e.g. "0.1dur") can be converted into the unit (e.g. `ep`) of `max_duration` (e.g. "90ep"); this example would return `9`
   6. duration <-> other units: if a unit other than that of `max_duration` is specified, the conversion will attempt to chain one or more of the above conversions.
- We will rewrite all schedulers to query the time object and perform a closed-form calculation to determine the learning rate, using `timer.get_elapsed_duration` and `timer.get_num_XXX` calls, so that they are compatible with datasets of unknown size or token count (see the scheduler sketches below). However, this can be done later; for the time being, `timer.convert` calls can be used to properly initialize schedulers upon creation.
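
To make items 1 and 2 concrete, below is a minimal sketch of what `Time` and `Timer` could look like. Only the class names and the four units come from the proposal; the `TimeUnit` enum, the method names (`update`, `on_epoch_end`), and all other implementation details are illustrative assumptions rather than a final API.

```python
import warnings
from enum import Enum
from typing import Optional, Union


class TimeUnit(Enum):
    # Hypothetical unit enum; the proposal only names the four units.
    EPOCHS = "ep"
    BATCHES = "ba"
    SAMPLES = "sp"
    TOKENS = "tok"


class Time:
    """An integer value tagged with a unit (item 1). Sketch only."""

    def __init__(self, value: int, unit: TimeUnit):
        self.value = value
        self.unit = unit

    def _coerce(self, other: Union["Time", int]) -> "Time":
        # Raw ints are accepted for backwards compatibility, but emit a UserWarning.
        if isinstance(other, int):
            warnings.warn(f"Mixing a raw int with a Time object; assuming unit {self.unit}",
                          UserWarning)
            return Time(other, self.unit)
        if other.unit != self.unit:
            raise ValueError(f"Unit mismatch: {self.unit} vs {other.unit}")
        return other

    def __add__(self, other):
        other = self._coerce(other)
        return Time(self.value + other.value, self.unit)

    def __sub__(self, other):
        other = self._coerce(other)
        return Time(self.value - other.value, self.unit)

    def __lt__(self, other):
        return self.value < self._coerce(other).value

    def __eq__(self, other):
        return isinstance(other, (Time, int)) and self.value == self._coerce(other).value


class Timer:
    """Tracks elapsed epochs, batches, samples, and tokens (item 2). Sketch only."""

    def __init__(self, track_tokens: bool = False):
        self.epochs = Time(0, TimeUnit.EPOCHS)
        self.batches = Time(0, TimeUnit.BATCHES)
        self.samples = Time(0, TimeUnit.SAMPLES)
        # Tokens may be None for non-NLP training jobs.
        self.tokens: Optional[Time] = Time(0, TimeUnit.TOKENS) if track_tokens else None

    def update(self, samples: int, tokens: Optional[int] = None) -> None:
        # Called by the training loop at the end of every batch.
        self.batches += Time(1, TimeUnit.BATCHES)
        self.samples += Time(samples, TimeUnit.SAMPLES)
        if self.tokens is not None and tokens is not None:
            self.tokens += Time(tokens, TimeUnit.TOKENS)

    def on_epoch_end(self) -> None:
        self.epochs += Time(1, TimeUnit.EPOCHS)
```

With something like this, the loop body reduces to a single call such as `state.timer.update(samples=dataset.get_batch_size(batch), tokens=dataset.get_num_tokens(batch))`.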
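
Items 3 and 4 amount to two optional hooks on the dataset. A rough sketch follows; the method names match the proposal, but the batch format (a dict of tensors) and the merged `get_num_tokens` signature covering both the per-batch and whole-dataset cases are assumptions.

```python
from typing import Any, Optional


class StreamingTextDataset:
    """Hypothetical dataset illustrating the optional hooks from items 3 and 4."""

    def __len__(self) -> int:
        # By PyTorch convention, the number of samples in the dataset.
        return 1_000_000

    def get_batch_size(self, batch: Any) -> int:
        # Number of samples in this batch; the batch is assumed to be a dict of tensors.
        return batch["input_ids"].shape[0]

    def get_num_tokens(self, batch: Optional[Any] = None) -> Optional[int]:
        # With a batch: the number of tokens in that batch (item 3).
        # Without a batch: the number of tokens in the whole dataset, or None if unknown (item 4).
        if batch is not None:
            return int(batch["attention_mask"].sum())
        return None  # unknown for a streaming dataset
```

A default `get_batch_size` could fall back to the leading dimension of the first tensor in the batch, matching the "if not specified" behaviour described above.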
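
The conversions in item 7 are simple arithmetic once the required parameters are supplied. Below is a sketch of a subset of them; the unit abbreviations other than `ep` and `dur`, the `_parse` helper, and the rounding behaviour are assumptions.

```python
import math
from typing import Optional, Tuple


def _parse(time_string: str) -> Tuple[float, str]:
    # Hypothetical parser: split the numeric prefix from the alphabetic unit suffix.
    idx = next(i for i, c in enumerate(time_string) if c.isalpha())
    return float(time_string[:idx]), time_string[idx:]


def convert(time_string: str, desired_unit: str,
            dataset_num_samples: Optional[int] = None,
            dataset_num_tokens: Optional[int] = None,
            max_training_duration: Optional[str] = None,
            batch_size: Optional[int] = None) -> int:
    """Static conversion between time units (a subset of item 7's cases)."""
    value, unit = _parse(time_string)

    if unit == "dur":
        # A duration string is a fraction of max_training_duration,
        # e.g. 0.1 of "90ep" -> 9 epochs.
        assert max_training_duration is not None
        max_value, unit = _parse(max_training_duration)
        value *= max_value

    if unit == desired_unit:
        return int(round(value))
    if unit == "ep" and desired_unit == "ba":
        assert dataset_num_samples is not None and batch_size is not None
        return int(round(value * math.ceil(dataset_num_samples / batch_size)))
    if unit == "ep" and desired_unit == "sp":
        assert dataset_num_samples is not None
        return int(round(value * dataset_num_samples))
    if unit == "ep" and desired_unit == "tok":
        assert dataset_num_tokens is not None
        return int(round(value * dataset_num_tokens))
    raise NotImplementedError(f"conversion {unit} -> {desired_unit} not sketched here")


# e.g. convert("0.1dur", "ep", max_training_duration="90ep") == 9, matching the example above.
```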
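
As an illustration of the closed-form scheduler idea, here is a sketch of a warmup factor computed directly from the timer's counters; the function name and the exact accessors (`timer.get_num_tokens()`, `timer.get_elapsed_duration()`) are assumed shapes of the getters described above.

```python
def warmup_lr_factor(timer, warmup_tokens: int = 1_000_000) -> float:
    """Closed-form linear warmup driven by the tokens seen so far.

    The scheduler keeps no internal step counter, so it cannot fall out of sync, and it
    works even when the dataset's total size or token count is unknown.
    """
    tokens_seen = timer.get_num_tokens()  # running counter from the Timer (assumed getter)
    if tokens_seen is None:
        # Non-NLP job: fall back to warming up over the first 5% of the training duration.
        return min(1.0, timer.get_elapsed_duration() / 0.05)
    return min(1.0, tokens_seen / warmup_tokens)
```

Each batch, the training loop would multiply the base learning rate by this factor.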
## TODO
- PR 1: Build out the timer, and use the timer to track progress in the training loop. Update the state object. Should be a non-breaking change.
- PR 2: Update the rest of the codebase to support timing strings (e.g. in schedulers, checkpoint intervals, flush intervals, etc.). If needed, use `timer.convert` to stay compatible with existing PyTorch components.
- PR 3: Create our own drop-in replacements for the PyTorch schedulers that do not depend on `timer.convert`.
- PR 4 (can be concurrent, or maybe should be done with PR 3): Update the algorithms. Try to avoid using the timer in the functional form.
## Comments (8 by maintainers)
I don't have a ton to add atm.
The only thing that comes to mind is that it looks like the scale schedule algorithm modifies `max_epochs` (to become `max_duration`) in a `state` object, which has two implications to me:

1. it is `state`'s (and not `hparam`'s) `max_duration` that should be treated as the source of truth;
2. are `time` or `timer` objects used prior to `scale_schedule` being called, and if so, will it cause problems that `scale_schedule` has changed `max_duration` between calls/uses of `time`/`timer`?

---

Big +1 here.
Re. `batch_size`, I think we are planning ahead for variable-batch-size algorithms, which are already used in NLP for warmups. So it would be safer to query the current batch size at each step rather than hard-code it at the start.

For schedulers, I think the main concern is that using `scheduler.step` assumes things about how Time passes and places time state within the Scheduler (which can fall out of sync), whereas the cleaner way would be to treat the scheduler as a stateless function that returns the decay factor given the current Time, something like `scheduler.get_factor(timer)`.

My hope is that our reimplementations of the common schedulers will actually be cleaner than PyTorch's, almost like one-line functions. And making a custom Scheduler should also be pretty easy. Something like:
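
A minimal sketch of what such a stateless scheduler interface could look like, assuming a `get_factor(timer)` method; the `MultiStepScheduler`/`CosineScheduler` classes and the duck-typed fake timer are purely illustrative.

```python
import math
from types import SimpleNamespace
from typing import List


class Scheduler:
    """Stateless scheduler: the LR factor is a pure function of the current time."""

    def get_factor(self, timer) -> float:
        raise NotImplementedError


class MultiStepScheduler(Scheduler):
    """Step decay at fixed epoch milestones, with no internal counters to fall out of sync."""

    def __init__(self, milestones: List[int], gamma: float = 0.1):
        self.milestones = milestones  # e.g. [30, 60, 80] (epochs)
        self.gamma = gamma

    def get_factor(self, timer) -> float:
        passed = sum(1 for m in self.milestones if timer.epochs >= m)
        return self.gamma ** passed


class CosineScheduler(Scheduler):
    def get_factor(self, timer) -> float:
        # Almost a one-liner, driven entirely by the elapsed fraction of training.
        return 0.5 * (1.0 + math.cos(math.pi * timer.get_elapsed_duration()))


# Duck-typed stand-in for the real timer, just to show the call pattern:
fake_timer = SimpleNamespace(epochs=45, get_elapsed_duration=lambda: 0.5)
print(MultiStepScheduler([30, 60, 80]).get_factor(fake_timer))  # 0.1 (one milestone passed)
print(CosineScheduler().get_factor(fake_timer))                 # 0.5
```
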
And YAMLs could then express schedulers and their milestones directly as time strings, whereas our current Trainer is only capable of handling the existing epoch-based configuration.
What do you think, @A-Jacobson?