
Backpropagation through Translation

See original GitHub issue

Hi, thank you for the wonderful package! I have been trying to learn translation (x, y) parameters in the following manner:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import Parameter
import kornia


class DTranslation(nn.Module):
    def __init__(self, x_translation, y_translation):
        super().__init__()
        self.translations = torch.stack([x_translation, y_translation], 1)
        self.angle = torch.tensor([0])

    def forward(self, input):
        _, _, h, w = input.shape
        if self.angle.shape[0] != input.shape[0]:
            angle = self.angle.repeat(input.shape[0])
        else:
            angle = self.angle

        if self.translations.shape[0] != input.shape[0]:
            translations = self.translations.repeat([input.shape[0], 1])
        else:
            translations = self.translations

        translations[:, 0] *= h
        translations[:, 1] *= w

        # define the rotation center
        center = torch.ones(2)
        center[..., 0] = input.shape[3] / 2  # x
        center[..., 1] = input.shape[2] / 2  # y
        center = center.repeat(input.shape[0], 1)

        # define the scale factor
        scale = torch.ones(input.shape[0])

        # compute the transformation matrix
        M = kornia.get_rotation_matrix2d(center, -angle, scale)

        # Translate
        M[..., 2] += translations  # tx/ty

        # apply the transformation to original image
        out = kornia.warp_affine(input, M, dsize=(h, w))

        return out


tx = torch.tensor([0.3], dtype=torch.float32)
tx_p = Parameter(tx, requires_grad=True)

ty = torch.tensor([-0.2], dtype=torch.float32)
ty_p = Parameter(ty, requires_grad=True)

translation = DTranslation(x_translation=tx_p, y_translation=ty_p)

criterion = nn.MSELoss()
optimizer = optim.Adam([tx_p, ty_p], lr=1)

for x, y in dataloader:  # dataloader is defined elsewhere
    optimizer.zero_grad()
    loss = criterion(x, translation(x))
    loss.backward()
    optimizer.step()
```

The first backward call passes, but the second one fails:

```
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
```

When I follow the instructions and call `loss.backward(retain_graph=True)` instead, I receive the following error:

```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```

I also tried to avoid the in-place operation `M[..., 2] += translations` and use instead:

```python
shape = list(M.shape)
shape[-1] -= 1
M = M + torch.cat([torch.zeros(shape), translations.unsqueeze(-1)], 2)
```

but got the same errors.
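For context, the first error generically means that some part of the autograd graph was built once but is backpropagated through on every iteration, so its saved buffers are gone after the first `backward()`. A minimal kornia-free sketch of that failure mode (the self-multiplication here just stands in for any op that saves tensors for backward):

```python
import torch
from torch.nn import Parameter

tx_p = Parameter(torch.tensor([0.3]))
ty_p = Parameter(torch.tensor([-0.2]))

# Built once, outside the training loop (as in DTranslation.__init__ above):
translations = torch.stack([tx_p, ty_p], 1)
shared = translations * translations  # this node saves its inputs for backward

for step in range(2):
    loss = shared.sum()
    loss.backward()  # 2nd iteration: "Trying to backward through the graph a second time"
```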

Maybe another in-place operation causes this? E.g. https://github.com/kornia/kornia/blob/5a736409a9a133da27c3dfa581bba2bc71f27286/kornia/geometry/conversions.py#L122

Or is it something else?

It is worth mentioning that I do manage to backpropagate through the rotation angle and shear (x, y) parameters; only translation seems to be the problem. That is, when I comment out `M[..., 2] += translations`, no errors occur, but of course I cannot learn the translation parameters either.

Any thoughts?

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 11 (4 by maintainers)

Top GitHub Comments

1 reaction
shijianjian commented, Sep 22, 2020

Hi @NivNayman, impressive gif you have made!

We are currently actively looking for tutorials/examples using Kornia, especially ones demonstrating differentiability in real-world use cases. I think your example could be a very impressive one. Let me know if you are okay with making a Jupyter notebook!

1 reaction
NivNayman commented, Sep 21, 2020

Wonderful! It works now: [animated GIF: learned x-translation ≈ 0.3, y-translation ≈ −0.2, entropy 0.01 per step]

Indeed, the only problem was placing `translations = torch.stack([self.x_translation, self.y_translation], 1)` in `__init__()` rather than in `forward()`.
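Concretely, a sketch of the corrected module based on that fix (keeping the structure and kornia calls from the question's code; this is not NivNayman's verbatim version):

```python
class DTranslation(nn.Module):
    def __init__(self, x_translation, y_translation):
        super().__init__()
        # Keep the raw parameters; building graph nodes here is what broke backward().
        self.x_translation = x_translation
        self.y_translation = y_translation
        self.angle = torch.tensor([0.0])

    def forward(self, input):
        _, _, h, w = input.shape

        # Re-stack on every call so each backward() traverses a fresh graph.
        translations = torch.stack([self.x_translation, self.y_translation], 1)

        if self.angle.shape[0] != input.shape[0]:
            angle = self.angle.repeat(input.shape[0])
        else:
            angle = self.angle
        if translations.shape[0] != input.shape[0]:
            translations = translations.repeat([input.shape[0], 1])

        # In-place scaling of a freshly created non-leaf tensor is differentiable.
        translations[:, 0] *= h
        translations[:, 1] *= w

        # Rotation center, scale, and affine matrix, as in the original code.
        center = torch.ones(input.shape[0], 2)
        center[:, 0] = w / 2  # x
        center[:, 1] = h / 2  # y
        scale = torch.ones(input.shape[0])

        M = kornia.get_rotation_matrix2d(center, -angle, scale)
        M[..., 2] += translations  # tx/ty
        return kornia.warp_affine(input, M, dsize=(h, w))
```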

FYI, the in-place operations don't interfere with the differentiation, as they are all differentiable, i.e. both `translations[:, 0] *= h`, `translations[:, 1] *= w` and `M[..., 2] += translations`.
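A quick standalone sanity check of that claim (illustrative, not from the thread): gradients still flow through an in-place scale applied to a freshly built non-leaf tensor:

```python
import torch

p = torch.tensor([0.3, -0.2], requires_grad=True)
t = torch.stack([p[0], p[1]])  # fresh non-leaf tensor, rebuilt per step in the fixed module
t *= 32.0                      # in-place scale, analogous to translations[:, 0] *= h
t.sum().backward()
print(p.grad)                  # tensor([32., 32.])
```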

Closing the issue.
