Multi-GPU scattering2d [torch]
Hi everyone,
In order for the scattering2d torch implementation to fully leverage the fact that it inherits from nn.Module, and thus be parallelizable with nn.DataParallel, I believe the following lines of the kymatio-v2 branch should be modified:
so that they respectively become simply:
self.register_single_filter(phi, n)
self.register_single_filter(v, n)
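For readers following along, register_single_filter presumably just wraps nn.Module.register_buffer under a per-filter name. A minimal, purely illustrative sketch of that idea (the class name and the 'tensor<n>' buffer-naming scheme are my assumptions, not necessarily kymatio's actual code):

```python
import torch.nn as nn


class TorchFrontendSketch(nn.Module):
    """Illustration only: register each filter tensor as a named buffer."""

    def register_single_filter(self, v, n):
        # A registered buffer is broadcast by nn.DataParallel's replicate()
        # to every GPU, whereas a tensor that only lives inside a plain
        # dict stays tied to whichever device it was created on.
        self.register_buffer('tensor' + str(n), v)  # buffer name is an assumption
```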
and:
phis = copy.deepcopy(self.phi)
psis = copy.deepcopy(self.psi)
Indeed, since self.phi and self.psi are dicts, the replica models built by DataParallel's replicate.py on each GPU all share the same underlying self.phi and self.psi dicts at every forward pass. If we assign the named buffers directly into those dicts (in the first two lines mentioned above), then, because the buffers themselves are replicated separately on each GPU, self.phi[c] and self.psi[j][k] end up pointing to tensors on a single GPU while the inputs are scattered across all GPUs, which leads to a TypeError: Input and filter must be on the same GPU.
The problem is similar for the other two lines, and one workaround is thus for each replica model to have its own copy of the phi and psi dicts. Another workaround would be to also pass a buffer dict in the scattering call:
and load the filters within the scattering core function (but that would be less generic).
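To make the first workaround concrete, here is a self-contained toy sketch of the pattern (it mimics, but is not, the kymatio code; the class, the buffer names and the dummy computation are assumptions): filters are stored both in a plain dict and as registered buffers, and each replica rebuilds its dict from its own buffers at forward time.

```python
import copy

import torch
import torch.nn as nn


class ToyFilterBank(nn.Module):
    """Toy stand-in for the scattering module (not kymatio code)."""

    def __init__(self, filters):
        super().__init__()
        self.phi = {}
        for n, f in enumerate(filters):
            self.phi[n] = f                             # plain dict: shared by all replicas
            self.register_buffer('tensor' + str(n), f)  # buffer: replicated per GPU

    def forward(self, x):
        # Each replica works on its own deep copy of the dict...
        phi = copy.deepcopy(self.phi)
        buffers = dict(self.named_buffers())
        # ...and overwrites the entries with its own buffers, which
        # DataParallel has already moved to this replica's device.
        for n in phi:
            phi[n] = buffers['tensor' + str(n)]
        # Dummy computation standing in for the scattering transform.
        return x * phi[0]


if __name__ == '__main__':
    net = ToyFilterBank([torch.randn(8, 8)])
    x = torch.randn(4, 1, 8, 8)
    if torch.cuda.is_available():
        net = nn.DataParallel(net).cuda()
        x = x.cuda()
    print(net(x).shape)  # torch.Size([4, 1, 8, 8])
```

This way the shared self.phi dict is never mutated, and every replica multiplies its inputs by filters that live on the same device.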
The proposed solution seems to work on multiple GPUs, for instance by slightly modifying the following lines of cifar.py in examples/2d:
by:
if use_cuda:
    scattering = torch.nn.DataParallel(scattering).cuda()

model = Scattering2dCNN(K, args.classifier)
if use_cuda:
    model = torch.nn.DataParallel(model).cuda()

# DataLoaders
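For reference, a hedged end-to-end sketch of the intended usage once the fix above is in place, assuming the kymatio-v2 Scattering2D torch frontend (shapes here are arbitrary):

```python
import torch
from kymatio.torch import Scattering2D

# Scattering2D inherits from nn.Module, so it can be wrapped like any
# other module; DataParallel then splits the batch across visible GPUs.
scattering = Scattering2D(J=2, shape=(32, 32))

x = torch.randn(128, 3, 32, 32)
if torch.cuda.is_available():
    scattering = torch.nn.DataParallel(scattering).cuda()
    x = x.cuda()

Sx = scattering(x)
print(Sx.shape)  # expected: torch.Size([128, 3, 81, 8, 8]) for J=2, L=8
```

Since DataParallel only splits the batch dimension, each replica sees a slice of the batch but needs the full filter bank on its own device, which is exactly what the buffer registration discussed above provides.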
This seems to work most of the time with 2 GPUs, but less reliably with 4 GPUs, where I sometimes get a Segmentation fault (core dumped). Enabling faulthandler with faulthandler.enable() gives the following traceback:
Thread 0x00007f4b5c889700 (most recent call first):
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/kymatio/scattering2d/backend/torch_backend.py", line 231 in fft
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/kymatio/scattering2d/core/scattering2d.py", line 23 in scattering2d
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/kymatio/scattering2d/frontend/torch_frontend.py", line 126 in scattering
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/kymatio/frontend/torch_frontend.py", line 20 in forward
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532 in __call__
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60 in _worker
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 870 in run
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 926 in _bootstrap_inner
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 890 in _bootstrap
Current thread 0x00007f4b5d08a700 (most recent call first):
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/kymatio/scattering2d/backend/torch_backend.py", line 231 in fft
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/kymatio/scattering2d/core/scattering2d.py", line 23 in scattering2d
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/kymatio/scattering2d/frontend/torch_frontend.py", line 126 in scattering
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/kymatio/frontend/torch_frontend.py", line 20 in forward
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532 in __call__
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60 in _worker
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 870 in run
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 926 in _bootstrap_inner
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 890 in _bootstrap
Thread 0x00007f4b5f7fe700 (most recent call first):
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 296 in wait
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/multiprocessing/queues.py", line 224 in _feed
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 870 in run
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 926 in _bootstrap_inner
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 890 in _bootstrap
Thread 0x00007f4b5ffff700 (most recent call first):
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 296 in wait
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/multiprocessing/queues.py", line 224 in _feed
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 870 in run
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 926 in _bootstrap_inner
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 890 in _bootstrap
Thread 0x00007f4b80ff9700 (most recent call first):
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 296 in wait
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/multiprocessing/queues.py", line 224 in _feed
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 870 in run
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 926 in _bootstrap_inner
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 890 in _bootstrap
Thread 0x00007f4b817fa700 (most recent call first):
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 296 in wait
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/multiprocessing/queues.py", line 224 in _feed
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 870 in run
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 926 in _bootstrap_inner
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 890 in _bootstrap
Thread 0x00007f4b81ffb700 (most recent call first):
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/selectors.py", line 415 in select
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/multiprocessing/connection.py", line 920 in wait
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/multiprocessing/connection.py", line 414 in _poll
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/multiprocessing/connection.py", line 257 in poll
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/multiprocessing/queues.py", line 104 in get
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py", line 25 in _pin_memory_loop
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 870 in run
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 926 in _bootstrap_inner
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 890 in _bootstrap
Thread 0x00007f4ceeacb700 (most recent call first):
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 1060 in _wait_for_tstate_lock
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/threading.py", line 1044 in join
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 77 in parallel_apply
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162 in parallel_apply
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152 in forward
File "/users/data/zarka/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532 in __call__
File "cifar.py", line 75 in train
File "cifar.py", line 177 in main
File "cifar.py", line 184 in <module>
Tested on Ubuntu 16.04 and 18.04 with torch 1.4.0 and torchvision 0.5.0 (I got similar behavior with torch 1.3.1 and torchvision 0.4.2).
Top GitHub Comments
@MuawizChaudhary @eickenberg it’s fixed on my machine. Can you confirm?
Please only close an issue when it’s fixed… and otherwise refer to the fix in the issue…