Support DistributedDataParallel and DataParallel, and publish Python package
First of all, thank you for the great package!
1. Support DistributedDataParallel and DataParallel
I’m working on large-scale experiments that take quite long to train, and I’m wondering whether this framework can support DataParallel and DistributedDataParallel.
The current examples/train.py appears to support DataParallel through CustomDataParallel, but it returned the following error:
Traceback (most recent call last):
  File "examples/train.py", line 369, in <module>
    main(sys.argv[1:])
  File "examples/train.py", line 348, in main
    args.clip_max_norm,
  File "examples/train.py", line 159, in train_one_epoch
    out_net = model(d)
  File "/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 160, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in replicate
    return replicate(module, device_ids, not torch.is_grad_enabled())
  File "/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/parallel/replicate.py", line 140, in replicate
    param_idx = param_indices[param]
KeyError: Parameter containing:
tensor([[[-10., 0., 10.]],
        [[-10., 0., 10.]],
        [[-10., 0., 10.]],
        ...  (the same [[-10., 0., 10.]] row repeated for all 128 channels)  ...
        [[-10., 0., 10.]]], device='cuda:0', requires_grad=True)
(Command: pipenv run python examples/train.py --data ./dataset/ --batch-size 4 --cuda, run on a machine with 3 GPUs. Judging from its shape and values, the failing parameter looks like the EntropyBottleneck's quantiles.)
When I comment out these two lines (https://github.com/InterDigitalInc/CompressAI/blob/master/examples/train.py#L333-L334), it seems to work fine:
/home/yoshitom/.local/share/virtualenvs/yoshitom-lJAkl1qx/lib/python3.6/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported.
warnings.warn("Setting attributes on ParameterList is not supported.")
Train epoch 0: [0/5000 (0%)] Loss: 183.278 | MSE loss: 0.278 | Bpp loss: 2.70 | Aux loss: 5276.71
Train epoch 0: [40/5000 (1%)] Loss: 65.175 | MSE loss: 0.096 | Bpp loss: 2.70 | Aux loss: 5273.95
Train epoch 0: [80/5000 (2%)] Loss: 35.178 | MSE loss: 0.050 | Bpp loss: 2.69 | Aux loss: 5271.21
Train epoch 0: [120/5000 (2%)] Loss: 36.634 | MSE loss: 0.052 | Bpp loss: 2.68 | Aux loss: 5268.45
Train epoch 0: [160/5000 (3%)] Loss: 26.010 | MSE loss: 0.036 | Bpp loss: 2.68 | Aux loss: 5265.67
...
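My understanding (from reading the script, so treat the snippet below as an assumption rather than a quote) is that those two lines are the multi-GPU wrapping, roughly:

# Presumed content of examples/train.py L333-L334: wrap the model for multi-GPU.
# With these lines commented out, net stays a plain single-GPU nn.Module,
# which is why the run above works but no longer uses all 3 GPUs.
if args.cuda and torch.cuda.device_count() > 1:
    net = CustomDataParallel(net)

So the workaround trains fine, but only on a single GPU.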
Could you please fix this issue and also support DistributedDataParallel?
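For reference, the usage I have in mind is just the standard PyTorch DDP pattern below (a sketch only; build_model is a placeholder for constructing any CompressAI model, and the snippet assumes a single-node launch with torch.distributed.launch):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Standard one-process-per-GPU setup; build_model is hypothetical.
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

net = build_model().to(local_rank)
net = DDP(net, device_ids=[local_rank])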
If you need more examples to help identify the components causing this issue, let me know. I have a few more examples (error messages) for both DataParallel and DistributedDataParallel with different network architectures (containing CompressionModel).
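For instance, a minimal script along these lines (a sketch, assuming CompressAI is installed from source and at least two GPUs are visible; I have not re-run this exact snippet) goes through the same replicate() code path as train.py:

import torch
import torch.nn as nn
from compressai.zoo import bmshj2018_factorized

# Build a small model from the zoo and wrap it in plain nn.DataParallel.
model = bmshj2018_factorized(quality=3).cuda()
model = nn.DataParallel(model)

x = torch.rand(4, 3, 256, 256).cuda()
out = model(x)  # DataParallel.forward calls replicate(), the step that fails above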
2. Publish Python package
It would be much more convenient if you could publish this framework as a Python package so that we can install it with pip install compressai.
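For example, the workflow I would hope for (hypothetical until the package is actually on PyPI) is simply:

# After `pip install compressai`:
from compressai.zoo import bmshj2018_factorized

net = bmshj2018_factorized(quality=3, pretrained=True)  # downloads pretrained weights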
Thank you!
Issue Analytics
- Created 3 years ago
- Comments: 13 (7 by maintainers)
We’ll revisit DDP support at a later date.
When will DistributedDataParallel be supported? Looking forward to it!