
Batched Augmentations Fail With `torch.autocast()` When `0 < p < 1`

See original GitHub issue

Describe the bug

When running batched augmentations in mixed precision, kornia fails with the following error:

RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.

Interestingly, this bug doesn’t occur when p=0.0 or p=1.0, only when 0 < p < 1. My hunch is that some erroneous dangling hidden state is causing this.
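
For reference, this class of failure can be reproduced with plain PyTorch, no kornia involved. The sketch below only illustrates the dtype rule that index_put enforces (it is not kornia’s actual code path): writing a half-precision source into a float32 destination through a boolean mask raises the same error.

import torch

dst = torch.zeros(4, 3)                          # float32 destination
src = torch.ones(2, 3, dtype=torch.half)         # half-precision source
mask = torch.tensor([True, False, True, False])  # subset write, like 0 < p < 1

# RuntimeError: Index put requires the source and destination dtypes match,
# got Float for the destination and Half for the source.
dst[mask] = src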

Reproduction steps

import torch
import kornia as K

x = torch.rand(2, 3, 384, 640).cuda()

# p=0.0: the augmentation is never applied -- no error
aug = K.augmentation.RandomGaussianBlur((3, 3), (0.1, 3), p=0.0)
for _ in range(100):
    with torch.autocast("cuda", dtype=torch.half):
        aug(x)
print("Passed with p=0.0")

# p=1.0: the augmentation is applied to the whole batch -- no error
aug = K.augmentation.RandomGaussianBlur((3, 3), (0.1, 3), p=1.0)
for _ in range(100):
    with torch.autocast("cuda", dtype=torch.half):
        aug(x)
print("Passed with p=1.0")

# 0 < p < 1: the augmentation is applied to a random subset of the batch
# -- raises the RuntimeError above, so this print is never reached
aug = K.augmentation.RandomGaussianBlur((3, 3), (0.1, 3), p=0.5)
for _ in range(100):
    with torch.autocast("cuda", dtype=torch.half):
        aug(x)
print("Passed with p=0.5")

Expected behavior

Augmentations run without errors for any value of p.

Environment

Inside a Docker container built from the NVIDIA NGC base image nvcr.io/nvidia/pytorch:21.12-py3

PyTorch version: 1.11.0a0+b6df043
Is debug build: False
CUDA used to build PyTorch: 11.5
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.21.3
Libc version: glibc-2.31

Python version: 3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.13.0-44-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.5.50
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090

Nvidia driver version: 510.73.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.3.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.2
[pip3] nvidia-dlprof-pytorch-nvtx==1.8.0
[pip3] pytorch-lightning==1.6.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==1.11.0a0+b6df043
[pip3] torch-tensorrt==1.1.0a0
[pip3] torchmetrics==0.7.0
[pip3] torchtext==0.12.0a0
[pip3] torchvision==0.11.0a0
[conda] magma-cuda110             2.5.2                         5    local
[conda] mkl                       2019.5                      281    conda-forge
[conda] mkl-include               2019.5                      281    conda-forge
[conda] numpy                     1.21.2                   pypi_0    pypi
[conda] nvidia-dlprof-pytorch-nvtx 1.8.0                    pypi_0    pypi
[conda] pytorch-lightning         1.6.0                    pypi_0    pypi
[conda] pytorch-quantization      2.1.2                    pypi_0    pypi
[conda] torch                     1.11.0a0+b6df043          pypi_0    pypi
[conda] torch-tensorrt            1.1.0a0                  pypi_0    pypi
[conda] torchmetrics              0.7.0                    pypi_0    pypi
[conda] torchtext                 0.12.0a0                 pypi_0    pypi
[conda] torchvision               0.11.0a0                 pypi_0    pypi

Additional context

This issue was originally opened in pytorch-lightning (https://github.com/PyTorchLightning/pytorch-lightning/issues/13228), but as the maintainers there pointed out, the bug is in kornia.

This may be the same issue as https://github.com/kornia/kornia/issues/1477, which was closed under the assumption that the bug was in Lightning.
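
One possible stopgap (a sketch of the standard autocast opt-out pattern, not verified as a fix) would be to disable autocast locally and run the augmentation in float32:

import torch
import kornia as K

aug = K.augmentation.RandomGaussianBlur((3, 3), (0.1, 3), p=0.5)
x = torch.rand(2, 3, 384, 640).cuda()

with torch.autocast("cuda", dtype=torch.half):
    # Locally disable autocast so the augmentation runs entirely in float32.
    with torch.autocast("cuda", enabled=False):
        x_aug = aug(x.float())
    # ... the rest of the step continues under autocast ...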

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 6 (5 by maintainers)

Top GitHub Comments

2 reactions
rsomani95 commented, Jun 7, 2022

@edgarriba Fair enough, thanks for the quick response. I’d love to help, but I’m out of my depth with how PyTorch casts types. If I make any progress or find a workaround, I’ll definitely update this thread.

0 reactions
t-vi commented, Oct 5, 2022

I didn’t look at it in detail, but @JanSellner’s approach of being more diligent with casting as needed seems like the right way to do it.
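
A minimal sketch of that “cast as needed” idea at the failing write, with illustrative names rather than kornia’s actual internals:

import torch

output = torch.zeros(4, 3)                           # float32 batch buffer
transformed = torch.ones(2, 3, dtype=torch.half)     # subset result under autocast
to_apply = torch.tensor([True, False, True, False])  # which batch elements to overwrite

# Casting the source to the destination dtype before the masked write
# satisfies index_put's matching-dtype requirement.
output[to_apply] = transformed.to(output.dtype)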


