Avoid view in no_grad mode in GeneralizedRCNNTransform.batch_images
🚀 Feature
Thank you very much for this great project! I’m one of the maintainers of the Adversarial Robustness Toolbox (ART) and torchvision
has been very useful for ART.
Motivation
ART provides many tools for adversarial machine learning that require the gradients of the loss with respect to the model input, in the case of this issue for torchvision.models.detection.fasterrcnn_resnet50_fpn. Since torchvision>=0.8 we have been observing the following RuntimeError when calculating loss gradients with respect to the input tensor, which prevents us from using torch>=1.7 for the related tools in ART:
    batch_shape = [len(images)] + max_size
    batched_imgs = images[0].new_full(batch_shape, 0)
    for img, pad_img in zip(images, batched_imgs):
>       pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
E       RuntimeError: A view was created in no_grad mode and is being modified inplace with grad mode enabled. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

.../python3.7/site-packages/torchvision/models/detection/transform.py:224: RuntimeError
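The failure mode can be reproduced outside torchvision with a few lines (a minimal sketch, not ART's actual code): iterating a tensor goes through unbind(), which returns multiple views, and views created under no_grad cannot later be written in-place while grad mode is enabled.

```python
import torch

# Minimal sketch reproducing the error (assumption: illustrative only,
# not the ART or torchvision source).
with torch.no_grad():
    batched = torch.zeros(2, 3)
    rows = list(batched)  # iteration uses unbind() -> multi-output views
                          # created in no_grad mode

img = torch.rand(3, requires_grad=True)
try:
    # In-place write under grad mode into a view created in no_grad mode
    rows[0].copy_(img)
except RuntimeError as err:
    print(err)
```

The same pattern is triggered inside batch_images when the padded views produced by zip(images, batched_imgs) are filled with copy_.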
Pitch
Would it be possible to change https://github.com/pytorch/vision/blob/ef711591a5db69d36f904ab5c39dec13627a58ad/torchvision/models/detection/transform.py#L223-L224 into
for i in range(batched_imgs.shape[0]):
    batched_imgs[i][: images[i].shape[0], : images[i].shape[1], : images[i].shape[2]].copy_(images[i])
to avoid the RuntimeError?
If yes, I would be happy to open a pull request with the proposed solution.
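For reference, the proposed change can be sketched as a standalone function (hypothetical helper name batch_images, shapes chosen for illustration, not the torchvision source). Indexing batched_imgs with an integer produces a single-output view, which autograd allows to be modified in-place, so gradients can flow back to the inputs:

```python
import torch

# Standalone sketch of the proposed fix (hypothetical helper, not the
# torchvision source): pad a list of CHW images into one batched tensor.
def batch_images(images, max_size):
    batch_shape = [len(images)] + list(max_size)
    batched_imgs = images[0].new_full(batch_shape, 0)
    # Integer indexing yields single-output views, which may be written
    # in-place under grad mode, unlike the views produced by iterating
    # batched_imgs with zip() (which goes through unbind()).
    for i in range(batched_imgs.shape[0]):
        img = images[i]
        batched_imgs[i][: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    return batched_imgs

imgs = [torch.rand(3, 4, 5, requires_grad=True),
        torch.rand(3, 2, 3, requires_grad=True)]
batched = batch_images(imgs, (3, 4, 5))
batched.sum().backward()  # gradients now reach the input tensors
```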
Issue Analytics
- Created 2 years ago
- Comments: 12 (10 by maintainers)
Top GitHub Comments
Hi @datumbox Thank you very much! And no worries at all, I see you have a lot of other pull requests going on in parallel.
@NicolasHug I think supporting requires_grad for inputs is reasonable, as @datumbox mentioned it is useful for some applications to backprop through the input as well.
@beat-buesser please send a PR fixing this and we will be happy to review / merge it