Issues: Dynamic Batching
(Posting a couple of issues to get features upstreamed from OpenNMT-py, cc @da03)
For the transformer model, we need some improvements to dynamic batching. In particular, the `batch_size_fn` interface has some issues. Let me give an example.
We need batches of 4096 tokens (including padding).
- We can't really do this with `batch_size_fn` because, while it lets us count the total number of tokens, it doesn't let us account for padding (the max size in the batch). One bad example either causes tons of padding, or a huge batch and an OOM.
Our current terrible hack:
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
global max_src_in_batch, max_tgt_in_batch
if count == 1:
max_src_in_batch = 0
max_tgt_in_batch = 0
max_src_in_batch = max(max_src_in_batch, len(new.src) + 2)
max_tgt_in_batch = max(max_tgt_in_batch, len(new.tgt) + 1)
src_elements = count * max_src_in_batch
tgt_elements = count * max_tgt_in_batch
return max(src_elements, tgt_elements)
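For reference, a minimal sketch of how this plugs into the (legacy) torchtext iterator; the `train` dataset and its `src`/`tgt` fields are assumptions here:

```python
import torchtext.data as data

# Assumed: `train` is a torchtext Dataset whose examples have .src and .tgt.
train_iter = data.Iterator(
    train,
    batch_size=4096,              # interpreted as a token budget via batch_size_fn
    batch_size_fn=batch_size_fn,  # the hack above
    sort_key=lambda ex: len(ex.src),
    train=True,
)
```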
- The Iterator uses this line to buffer data for batching:

```python
for p in batch(data, batch_size * 100, batch_size_fn):
```

https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L271
Unfortunately, even though there is a 100 here, it doesn't help: if we are counting padding in `batch_size_fn`, then one long example will make every other sentence in the buffer take a ton of space. We think we need control over this.
Our current hack (don't use `batch_size_fn` for buffering):

```python
for p in torchtext.data.batch(data, self.batch_size * 100):
```
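Putting the two hacks together, this is roughly the Iterator subclass we end up with; a sketch only, and the class name `TokenIterator` is illustrative. It buffers by plain example count, sorts the buffer, then cuts token-sized batches with `batch_size_fn`:

```python
import torchtext.data as data

class TokenIterator(data.Iterator):
    # Illustrative subclass: buffer 100x the batch size by *count*, then
    # cut token-budgeted batches out of the sorted buffer.
    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                        sorted(p, key=self.sort_key),
                        self.batch_size, self.batch_size_fn)
                    # Shuffle batch order so training doesn't see batches
                    # strictly by length.
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))
```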
- Minor: batching uses `sort` for two different purposes: one to find the batches themselves, and one for the order in which each batch is created. I would like to be able to have a `batch_construction_sort` to find sentences of the same length, and then a `batch_sort` within each batch. For example, in MT I would like to sort by a weighted src x tgt length during batch construction (to minimize padding), but then have the batch itself sorted by src length to make cudnn work (see the sketch below).
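To make the proposal concrete, here is a minimal sketch of the two-key scheme. `batch_construction_key`, `batch_sort_key`, and the 0.5 weight are hypothetical, not an existing torchtext API:

```python
import torchtext.data as data

def batch_construction_key(ex):
    # Hypothetical weighted src x tgt length: group similar-cost examples
    # together so batches carry little padding.
    return len(ex.src) + 0.5 * len(ex.tgt)

def batch_sort_key(ex):
    # cudnn RNN kernels want each batch sorted by descending source length.
    return -len(ex.src)

# Assumed: `examples`, `batch_size`, and `batch_size_fn` are in scope.
buffered = sorted(examples, key=batch_construction_key)
batches = [sorted(b, key=batch_sort_key)
           for b in data.batch(buffered, batch_size, batch_size_fn)]
```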
Top GitHub Comments
@patelrajnath I'll instead suggest that you have a look at huggingface, who have a similar implementation (which I only found out about recently), which you can use to compare against or take inspiration from: https://github.com/huggingface/transformers/blob/c2cd02ac625cd0ab64cf42124ad71ce9158fb67c/src/transformers/trainer_pt_utils.py#L501

@patelrajnath For my use case I've implemented a subclass of `torch.utils.data.Sampler` that simply generates lists of indices, corresponding to dynamically generated batches, based on the number of tokens within the batch. For each call of `__iter__` I do the following:
- Build a list of `Tuple[int, Tuple[int, int]]` which maps `Dataset` indices to their `(src_len, tgt_len)` tuples.
- Sort by `src_len`, then by `tgt_len`.
- Grow each batch until `batch_size * max(src_len) + batch_size * max(tgt_len)` reaches your desired max batch tokens (see the sketch below).
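A minimal sketch of that idea, assuming precomputed per-example lengths; the class name `TokenBatchSampler` and all details are mine, not the commenter's actual code:

```python
import random
from typing import Iterator, List, Sequence, Tuple

from torch.utils.data import Sampler


class TokenBatchSampler(Sampler):
    """Yield lists of dataset indices whose padded token cost stays under
    max_tokens. Hypothetical implementation of the description above."""

    def __init__(self, lengths: Sequence[Tuple[int, int]], max_tokens: int,
                 shuffle: bool = True):
        # lengths[i] == (src_len, tgt_len) for dataset example i
        self.lengths = lengths
        self.max_tokens = max_tokens
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[List[int]]:
        # 1) Map dataset indices to their (src_len, tgt_len) tuples.
        indexed = list(enumerate(self.lengths))
        # 2) Sort by src_len, then tgt_len, so neighbours pad similarly.
        indexed.sort(key=lambda item: (item[1][0], item[1][1]))

        batches: List[List[int]] = []
        batch: List[int] = []
        max_src = max_tgt = 0
        for idx, (src_len, tgt_len) in indexed:
            new_src = max(max_src, src_len)
            new_tgt = max(max_tgt, tgt_len)
            n = len(batch) + 1
            # 3) Flush when adding this example would push the padded
            #    cost past the cap.
            if batch and n * new_src + n * new_tgt > self.max_tokens:
                batches.append(batch)
                batch = []
                new_src, new_tgt = src_len, tgt_len
            batch.append(idx)
            max_src, max_tgt = new_src, new_tgt
        if batch:
            batches.append(batch)

        if self.shuffle:
            random.shuffle(batches)  # shuffle batch order, not contents
        return iter(batches)
```

You can then pass it to a `DataLoader` via `batch_sampler=TokenBatchSampler(lengths, max_tokens=4096)`, together with a collate function that pads each side to its batch max.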