DGL is not compatible with pytorch_lightning in the official SAGE examples since version 0.7.2
🐛 Bug
DGL is not compatible with pytorch_lightning in the official examples since version 0.7.2
To Reproduce
Following the example in
https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_lightning.py
I want to train the graph model with pytorch_lightning
on a single GPU; however, it reports a DGLError in DGL v0.7.2.
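A minimal sketch that triggers the same code path (this is not the actual example script; DummyModule, the random toy graph, and all sizes are made-up stand-ins):

import dgl
import torch
import pytorch_lightning as pl

class DummyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(1))

    def training_step(self, batch, batch_idx):
        # The batch is ignored on purpose; the DGLError comes from
        # Lightning's fit loop calling set_epoch(), not from this step.
        return self.weight.sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

g = dgl.rand_graph(100, 500)  # toy random graph, stand-in for the real dataset
sampler = dgl.dataloading.MultiLayerNeighborSampler([5])
loader = dgl.dataloading.NodeDataLoader(  # note: use_ddp defaults to False
    g, torch.arange(g.num_nodes()), sampler, batch_size=32, shuffle=True)

trainer = pl.Trainer(gpus=1, max_epochs=2)  # single GPU, as in the report
trainer.fit(DummyModule(), train_dataloaders=loader)
# DGLError: set_epoch is only available when use_ddp is True.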
My Guess
I found that this is because DGL v0.7.2 added a new feature that makes NodeDataLoader's sampler.set_epoch callable only when use_ddp
is enabled (line link):
if self.use_ddp:
    if self.use_scalar_batcher:
        self.scalar_batcher.set_epoch(epoch)
    else:
        self.dist_sampler.set_epoch(epoch)
else:
    raise DGLError('set_epoch is only available when use_ddp is True.')
However, in pytorch_lightning, though I'm not familiar with its internal implementation, I have some guesses from the code pieces below (line link):
if self.trainer.train_dataloader is not None and callable(
    getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
):
    # set seed for distributed sampler (enables shuffling for each epoch)
    self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
It seems to assume that any sampler with a set_epoch
method is a distributed sampler (which is true in a plain PyTorch environment, but not in DGL). So it always calls set_epoch
on the DGL sampler, even when use_ddp=False.
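To illustrate the mismatch, a toy sketch (loader is the NodeDataLoader from the repro sketch above; I am assuming the object Lightning resolves via .sampler is the same one the traceback below points at):

# Lightning's guard is pure duck typing: "has set_epoch" is taken to
# mean "is a DistributedSampler". DGL's loader passes the guard...
print(callable(getattr(loader, "set_epoch", None)))  # True, so Lightning calls it
# ...but then raises, because use_ddp is False:
loader.set_epoch(0)  # DGLError: set_epoch is only available when use_ddp is True.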
A Compromise Solution
If I replace the official dgl.dataloading.NodeDataLoader
with a subclass that overrides the set_epoch
method, it works:
class MyLoader(dgl.dataloading.NodeDataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def set_epoch(self, epoch):
        # Same body as DGL's set_epoch, minus the use_ddp guard.
        if self.use_scalar_batcher:
            self.scalar_batcher.set_epoch(epoch)
        else:
            self.dist_sampler.set_epoch(epoch)
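A slightly more defensive variant (my own sketch, NoOpEpochLoader is a hypothetical name; it assumes self.dist_sampler may not exist when use_ddp=False) would keep DGL's guard but turn the error into a no-op:

class NoOpEpochLoader(dgl.dataloading.NodeDataLoader):
    def set_epoch(self, epoch):
        if self.use_ddp:
            # Defer to DGL's original behaviour under DDP.
            super().set_epoch(epoch)
        # Otherwise do nothing: epoch-based reshuffling only matters
        # for the distributed sampler, which is absent without DDP.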
So I am curious: is the check for use_ddp=True
really necessary?
Environment
- DGL v0.7.2
- pytorch_lightning v1.5.2
- torch v1.10.0
Detailed Error Information
~/.conda/envs/ptorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1193
1194 # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1195 self._dispatch()
1196
1197 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
~/.conda/envs/ptorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
1272 self.training_type_plugin.start_predicting(self)
1273 else:
-> 1274 self.training_type_plugin.start_training(self)
1275
1276 def run_stage(self):
~/.conda/envs/ptorch/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
200 def start_training(self, trainer: "pl.Trainer") -> None:
201 # double dispatch to initiate the training loop
--> 202 self._results = trainer.run_stage()
203
204 def start_evaluating(self, trainer: "pl.Trainer") -> None:
~/.conda/envs/ptorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
1282 if self.predicting:
1283 return self._run_predict()
-> 1284 return self._run_train()
1285
1286 def _pre_training_routine(self):
~/.conda/envs/ptorch/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
1312 self.fit_loop.trainer = self
1313 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1314 self.fit_loop.run()
1315
1316 def _run_evaluate(self) -> _EVALUATE_OUTPUT:
~/.conda/envs/ptorch/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
142 while not self.done:
143 try:
--> 144 self.on_advance_start(*args, **kwargs)
145 self.advance(*args, **kwargs)
146 self.on_advance_end()
~/.conda/envs/ptorch/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py in on_advance_start(self)
214 ):
215 # set seed for distributed sampler (enables shuffling for each epoch)
--> 216 self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
217
218 # changing gradient according accumulation_scheduler
~/.conda/envs/ptorch/lib/python3.8/site-packages/dgl/dataloading/pytorch/dataloader.py in set_epoch(self, epoch)
557 self.dist_sampler.set_epoch(epoch)
558 else:
--> 559 raise DGLError('set_epoch is only available when use_ddp is True.')
560
561 class EdgeDataLoader:
DGLError: set_epoch is only available when use_ddp is True.
@decoherencer While I'm working on cleaning up the existing PyTorch Lightning example (and extending it to multiple GPUs), you can run the existing examples in
examples/pytorch/graphsage/advanced/train_lightning.py
. I confirmed that it works with PyTorch Lightning 1.5.10.

Is this included in the v0.8 release? Thanks