Add GPU support
🚀 Feature
Add support to transform images in GPU devices.
Motivation
Some transforms would benefit a lot from hardware acceleration. I suspect the two main sources of improvement would be resampling (Resample, RandomElasticDeformation, RandomAffine, RandomAnisotropy, RandomMotion) and Fourier transforms (RandomMotion, RandomSpike, RandomGhosting).
Most users will want to keep their GPU for training rather than for preprocessing / augmentation, but when there are enough resources available, it's nice to have GPU support in the transforms.
Pitch
Supporting FFT is easy and seems to help, using the new torch.fft module in PyTorch 1.7. I have added some support in the fourier branch; a rough sketch of the kind of k-space manipulation involved follows the timings below.
On CPU:
```
In [1]: import torchio as tio
   ...: t1 = tio.datasets.FPG().t1
   ...: t1.load()
   ...: transform = tio.RandomSpike()
   ...: %timeit transform(t1)
1.35 s ± 3.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
On GPU:
```
In [1]: import torchio as tio
   ...: t1 = tio.datasets.FPG().t1
   ...: t1.load()
   ...: t1.set_data(t1.data.to('cuda'))  # assumed step: move the image data to the GPU (elided in the original snippet)
   ...: transform = tio.RandomSpike()
   ...: %timeit transform(t1)
155 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
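For context, here is a minimal sketch of the kind of k-space manipulation behind these transforms, using torch.fft on whatever device the input lives on. This is only an illustration, not the actual RandomSpike implementation; the add_spike function and its parameters are made up.

```python
import torch

def add_spike(image, position=(5, 10, 15), intensity=0.1):
    # Illustrative only: insert a single spurious peak into k-space,
    # which produces the stripe artifacts that spike transforms simulate.
    spectrum = torch.fft.fftn(image)  # runs on image.device automatically
    spectrum[position] += intensity * spectrum.abs().max()
    return torch.fft.ifftn(spectrum).real

device = 'cuda' if torch.cuda.is_available() else 'cpu'
image = torch.rand(64, 64, 64, device=device)
corrupted = add_spike(image)  # same code path on CPU and GPU
```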
Supporting resampling is more complex because of the way PyTorch handles coordinates: affine_grid and grid_sample work with normalized coordinates in [-1, 1] rather than voxel or world coordinates. Ideally, we would be able to convert a "world" affine transform to a PyTorch one; a sketch of one possible conversion follows the links below.
Some discussions about converting to/from PyTorch conventions for affine transformations:
- Affine transformation matrix parameters conversion
- Unexpected behaviour for affine_grid and grid_sample with 3D inputs
- Generating pytorch’s theta from affine transform matrix
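As a starting point, here is a hedged sketch of converting a voxel-space affine matrix (output index to input index) into the theta expected by torch.nn.functional.affine_grid. It assumes align_corners=True and assumes the matrix is already expressed in the (x, y, z) axis order that affine_grid uses; going all the way from a world (scanner-space) affine would additionally involve the image affines, and the axis-ordering conventions are exactly where the discussions above report confusion. Both helper functions are hypothetical, not part of TorchIO.

```python
import torch
import torch.nn.functional as F

def normalization_matrix(size_xyz):
    # Map voxel indices [0, s - 1] to the normalized [-1, 1] coordinates
    # used by affine_grid / grid_sample (align_corners=True convention).
    size = torch.as_tensor(size_xyz, dtype=torch.float64)
    norm = torch.eye(4, dtype=torch.float64)
    norm[:3, :3] = torch.diag(2 / (size - 1))
    norm[:3, 3] = -1
    return norm

def voxel_affine_to_theta(voxel_affine, in_size_xyz, out_size_xyz):
    # voxel_affine: 4x4 matrix mapping output voxel indices to input voxel
    # indices, with axes ordered (x, y, z), i.e. (W, H, D)
    norm_in = normalization_matrix(in_size_xyz)
    norm_out = normalization_matrix(out_size_xyz)
    affine = torch.as_tensor(voxel_affine, dtype=torch.float64)
    theta = norm_in @ affine @ torch.inverse(norm_out)
    return theta[:3].unsqueeze(0).float()  # shape (1, 3, 4)

# Sanity check: an identity voxel affine should resample to (almost) itself
volume = torch.rand(1, 1, 8, 16, 24)  # (N, C, D, H, W)
theta = voxel_affine_to_theta(torch.eye(4), (24, 16, 8), (24, 16, 8))
grid = F.affine_grid(theta, list(volume.shape), align_corners=True)
resampled = F.grid_sample(volume, grid, align_corners=True)
assert torch.allclose(volume, resampled, atol=1e-5)
```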
The steps for this transition would be:
- Make sure everything works normally with tensors on GPU (a device-handling sketch follows this list)
- Make sure the run time for FFT transforms is improved
- Figure out how to resample medical images properly using PyTorch
- Make sure the run time for resampling transforms is improved
- Check if run time for other transforms gets better as well
- Test that everything works as before, on multiple PyTorch versions
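On the first point, the usual PyTorch pattern is to make each transform device agnostic: create any auxiliary tensors on the device of the input instead of assuming CPU. A toy sketch (the function and the transform it implements are made up):

```python
import torch

def scale_intensity(tensor, log_std=0.1):
    # Toy random-scaling transform. The random factor is created on
    # tensor.device, so the same code runs unchanged on CPU or GPU.
    factor = torch.randn(1, device=tensor.device).mul(log_std).exp()
    return tensor * factor

cpu_image = torch.rand(1, 64, 64, 64)
out_cpu = scale_intensity(cpu_image)
if torch.cuda.is_available():
    out_gpu = scale_intensity(cpu_image.to('cuda'))  # no code changes needed
```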
Top GitHub Comments
If we ever start tackling this, here are a couple of sources that might be useful:
I think a good reason for doing this would be to support differentiable augmentation. This was done in StyleGAN2-ADA to train GANs using limited data. The augmentations are performed on the generated images before they are input to the discriminator, so they have to be differentiable.
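To illustrate the requirement, here is a toy sketch (not StyleGAN2-ADA's actual pipeline) showing that an augmentation built from differentiable ops such as grid_sample lets gradients flow back to the generator output:

```python
import torch
import torch.nn.functional as F

# Stand-in for a generated volume; requires_grad mimics the generator output
fake = torch.rand(1, 1, 16, 16, 16, requires_grad=True)

# Random small affine perturbation, applied with differentiable ops
theta = torch.eye(3, 4).unsqueeze(0)
theta[:, :3, :3] += 0.05 * torch.randn(1, 3, 3)
grid = F.affine_grid(theta, list(fake.shape), align_corners=False)
augmented = F.grid_sample(fake, grid, align_corners=False)

augmented.sum().backward()    # stand-in for a discriminator loss
assert fake.grad is not None  # gradients reach the "generator"
```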