Early experiments with Multi-GPU training and PyTorch Lightning
I spent a chunk of the weekend experimenting with multi-GPU training using PyTorch Lightning. I’m starting up a discussion here to document the steps I took and to hopefully start brainstorming how we can bring multi-GPU training to DeepChem.
- I set up a multi-GPU instance on SageMaker following the guide at https://forum.deepchem.io/t/making-a-sagemaker-deepchem-development-environment/530. Once the multi-GPU instance was up, I verified that I could access 8 K80 GPUs by running `nvidia-smi`.
- I started with PyTorch Lightning following this tutorial: https://towardsdatascience.com/trivial-multi-node-training-with-pytorch-lightning-ff75dfb809bd. Parts of the tutorial were out of date, but I was able to use the docs to figure out the needed changes: https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html.
- The API was very easy; I just set `gpus=4` in the `Trainer` call and it worked out of the box (a minimal sketch is below, after this list). I was able to verify that multiple GPUs were being used by calling `nvidia-smi`.
- GPU utilization was quite low (5% per GPU). I was just using the tutorial code and made no attempt to optimize, but I suspect that getting good utilization may take some effort.
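For concreteness, here is a minimal sketch of the `Trainer` setup described above. The toy model and data are placeholders rather than the tutorial code; the key line is `gpus=4` in the `Trainer` call. (As noted below, flag names have shifted across Lightning versions; newer releases use `accelerator="gpu", devices=4` instead.)

```python
# Minimal sketch: a toy LightningModule trained on 4 GPUs.
# The model/data are placeholders; only the Trainer call matters here.
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

dataset = TensorDataset(torch.randn(1024, 16), torch.randn(1024, 1))
# Lightning handles device placement and multi-GPU distribution internally.
trainer = pl.Trainer(gpus=4, max_epochs=1)
trainer.fit(ToyModel(), DataLoader(dataset, batch_size=64))
```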
Here are a few initial thoughts on how to bring multi-GPU support to DeepChem:
- The “simplest” strategy would be to start using PyTorch Lightning in `TorchModel`. The challenge here is that PyTorch Lightning requires a good amount of “magic”: instead of `nn.Module`, you need to use `pl.LightningModule`. It’s not yet clear how well other frameworks like DGL/PyG work with PyTorch Lightning, and we depend heavily on these frameworks for graph-conv primitives. PyTorch Lightning is also under rapid development and not entirely API-stable; I found a lot of broken tutorials and forum discussions explaining that flags/arguments had changed.
- As one potential alternative, we could use lower-level distributed primitives from PyTorch directly (see https://pytorch.org/tutorials/beginner/dist_overview.html) for distributed training; a rough sketch follows below. I haven’t yet experimented with these myself, so I’m not sure how easy or hard this would be.
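To make the second option concrete, here is a rough sketch of what training on the lower-level primitives might look like, using `torch.distributed` with `DistributedDataParallel`. The model and data here are illustrative placeholders; wiring this into DeepChem’s fit loop would be the actual work.

```python
# Rough sketch: one training step per GPU using torch.distributed + DDP.
# The linear model and random data are placeholders.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = torch.nn.Linear(16, 1).to(rank)
    ddp_model = DDP(model, device_ids=[rank])  # syncs gradients across GPUs
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

    x = torch.randn(64, 16, device=rank)
    y = torch.randn(64, 1, device=rank)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(ddp_model(x), y)
    loss.backward()
    optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size)
```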
It’s not yet clear to me what the right path to multi-GPU is for DeepChem. We have many models using many backend frameworks. We’re shifting towards PyTorch as our mainstay, but even there we have PyG/DGL dependencies. The ideal implementation would be to upgrade `TorchModel` to support distributed training so that all DeepChem Torch models are distributed out of the box, but this will take some careful work. I think this will be a very powerful feature for us, since many new models for the sciences (AlphaFold-2, ChemBERTa, ProteinBERT, etc.) all depend on large-scale training.
Here are a couple suggested questions for us to think about:
- Are there any serious blockers to adopting PyTorch Lightning beyond API instability? (For example, a serious incompatibility with DGL/PyG.)
- How does PyTorch Lightning implement distributed training under the hood? Can we try to understand how Lightning does it and recreate similar infrastructure directly in `TorchModel`?
- How serious an issue is GPU utilization? Will we need to make infrastructure upgrades to `DiskDataset` to really get the most out of multi-GPU?
Top GitHub Comments
@ncfrey A tutorial along the lines you mention would be great! It could be a really nice way to scale DeepChem models right away without needing to add heavy-duty new infrastructure. If you have the bandwidth, a tutorial PR would be an awesome contribution 😃
I recently migrated completely to PL for distributed model training, so this is great!
I think that to fully utilize PL, you simply need your data available as a `LightningDataModule`, and you can inject your PyTorch model into a `LightningModule`. I really like this framework because the PyTorch dataloaders and models remain exactly the same - you simply wrap them in the corresponding PL classes.

Following that approach, to @peastman’s point, it isn’t even necessary to directly support PL as a dependency. There could be a tutorial that shows how to take any DeepChem PyTorch model and dataset, wrap them in the PL style, and do distributed training (see the sketch below). I am doing this already, so it would be easy to put together.
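As an illustration of this wrapping approach, here is a minimal sketch. `TorchModelWrapper` and `WrappedDataModule` are hypothetical names, and the loss and optimizer are placeholders; the point is that the underlying `nn.Module` and `DataLoader` are injected unchanged.

```python
# Sketch: wrap an existing PyTorch model and dataloader in PL classes.
# TorchModelWrapper / WrappedDataModule are hypothetical names.
import torch
import pytorch_lightning as pl

class TorchModelWrapper(pl.LightningModule):
    def __init__(self, pytorch_model, loss_fn):
        super().__init__()
        self.model = pytorch_model  # the unchanged nn.Module, e.g. deepchem_model.model
        self.loss_fn = loss_fn

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss_fn(self.model(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

class WrappedDataModule(pl.LightningDataModule):
    def __init__(self, train_loader):
        super().__init__()
        self._train_loader = train_loader  # the unchanged PyTorch DataLoader

    def train_dataloader(self):
        return self._train_loader

# Usage sketch: wrap both pieces, then hand them to a multi-GPU Trainer.
# wrapped = TorchModelWrapper(deepchem_model.model, torch.nn.functional.mse_loss)
# pl.Trainer(gpus=4).fit(wrapped, datamodule=WrappedDataModule(train_loader))
```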