Improve moving data / model to GPU using torchtext
🚀 Feature
Improve moving data and the model to the GPU when torchtext is used.
Motivation
Case 1:
The Batch object generated by torchtext.data.Iterator doesn't follow the rules described here: https://github.com/PyTorchLightning/pytorch-lightning/blob/45d671a4a81788b9d97fd6b47763816926e58e95/pytorch_lightning/trainer/distrib_parts.py#L420
As a result, the data is not moved to the GPU. The torchtext.data.Iterator is returned by the method train_dataloader. Keep in mind that torchtext.data.Iterator has a device argument that is not properly utilized by pytorch-lightning.
@ptl.data_loader
def train_dataloader(self):
    ...
    return Iterator(dataset=dataset, batch_size=self.batch_size, shuffle=False, device=DEVICE)
Partially reported here https://github.com/PyTorchLightning/pytorch-lightning/issues/226
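As a stop-gap, the batch can be moved by hand at the top of training_step. The sketch below is only a workaround, not pytorch-lightning behaviour; it assumes the incoming batch is a torchtext Batch whose tensor attributes are listed in batch.fields.

import torch

# inside your ptl.LightningModule; a manual workaround until torchtext
# batches are recognised by the automatic GPU transfer
def training_step(self, batch, batch_idx):
    device = next(self.parameters()).device
    for name in batch.fields:
        value = getattr(batch, name, None)
        if torch.is_tensor(value):
            setattr(batch, name, value.to(device))
        elif isinstance(value, tuple):
            # e.g. (data, lengths) when a Field uses include_lengths=True
            setattr(batch, name, tuple(v.to(device) for v in value))
    # ... usual forward pass and loss computation on the moved batch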
Case 2:
Using torchtext, you can read pre-trained embeddings and create an nn.Embedding object as follows:
def train_dataloader(self):
    ...
    self.text_field.build_vocab(
        dataset,
        vectors=Vectors("/data/embeddings/glove/glove.840B.300d.txt"),
    )
    self.embeddings = nn.Embedding(
        ...
        padding_idx=self.text_field.vocab.stoi[PAD_TOKEN],
        _weight=self.text_field.vocab.vectors.to(DEVICE),
    )
The nn.Embedding layer clearly depends on self.text_field.vocab, which in turn depends on the dataset used by train_dataloader. Currently, any part of the model that is not created fully in __init__ of the ptl.LightningModule is not moved to the GPU. This still requires a global variable that determines the device, i.e. DEVICE, and it makes Trainer(gpus=...) useless.
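One way to sidestep this today is to build the vocabulary and the embedding layer inside __init__, so the module is complete before the Trainer moves it, and to drop the device argument from the Iterator. The snippet below is only a sketch: the field and dataset construction are placeholders, and PAD_TOKEN is assumed to be defined as in the snippet above.

import pytorch_lightning as ptl
import torch.nn as nn
from torchtext.data import Field, Iterator
from torchtext.vocab import Vectors

class TextModel(ptl.LightningModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.text_field = Field(lower=True)  # placeholder field definition
        self._dataset = ...                  # placeholder: build the Dataset with self.text_field here
        self.text_field.build_vocab(
            self._dataset,
            vectors=Vectors("/data/embeddings/glove/glove.840B.300d.txt"),
        )
        # from_pretrained infers the embedding shape from the vectors; no manual
        # .to(DEVICE) is needed because the Trainer moves the fully-built module
        self.embeddings = nn.Embedding.from_pretrained(
            self.text_field.vocab.vectors,
            freeze=False,  # keep the weights trainable, as nn.Embedding(_weight=...) does
            padding_idx=self.text_field.vocab.stoi[PAD_TOKEN],
        )

    def train_dataloader(self):
        # no device argument: batch placement is left to the Trainer (or a hook)
        return Iterator(dataset=self._dataset, batch_size=self.batch_size, shuffle=False)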
Pitch
I would like not to have to worry about moving data to the GPU when using torchtext together with pytorch-lightning.
Issue Analytics
- Created 3 years ago
- Reactions: 1
- Comments: 28 (24 by maintainers)
@awaelchli I would prefer a hook, since that will be generally useful and will likely reduce the long-term maintenance burden.
good ideas everyone!
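For illustration, the hook-based approach could look roughly like the override below. This is a sketch only: the hook name transfer_batch_to_device and its signature are assumptions, and the per-field move mirrors the training_step workaround above, just applied once for the train, validation and test loops alike.

import torch
import torchtext
import pytorch_lightning as ptl

class TextModel(ptl.LightningModule):
    ...

    def transfer_batch_to_device(self, batch, device):
        # special-case torchtext Batch objects, which the default transfer
        # logic does not recognise
        if isinstance(batch, torchtext.data.Batch):
            for name in batch.fields:
                value = getattr(batch, name, None)
                if torch.is_tensor(value):
                    setattr(batch, name, value.to(device))
            return batch
        # everything else falls back to the default behaviour
        return super().transfer_batch_to_device(batch, device)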