
[Feat] Improve device switching for transforms

See original GitHub issue

🚀 Feature

Improve device switching for transforms

Motivation

Currently, tensors are stored as-is within transforms (I'm using kornia.Rotate as an example here, but the same is valid for many other transforms). This makes switching between devices inconvenient, since torch.nn.Module.to() does not affect them:

import torch
import kornia

angle = torch.tensor(30.0).view(1, -1)

transform = kornia.Rotate(angle)
print(transform.angle.device)  # cpu

transform = transform.to("cuda")
print(transform.angle.device)  # still cpu: .to() did not move the stored tensor
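
Until this is fixed, a common workaround (my suggestion, not from the issue) is to move each stored tensor by hand, which is easy to forget and does not scale to transforms holding several internal tensors:

# Manual workaround, assuming a CUDA device is available:
# reassign the attribute yourself, since transform.to() will not move it.
transform.angle = transform.angle.to("cuda")
print(transform.angle.device)  # cuda:0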

Pitch

Register all tensors as buffers:

import torch
from torch import nn


class RotateMock(nn.Module):
    def __init__(self, angle: torch.Tensor) -> None:
        super().__init__()
        # Buffers are moved by Module.to() / .cuda() together with the
        # parameters, but are not returned by .parameters().
        self.register_buffer("angle", angle)

transform = RotateMock(angle)
print(transform.angle.device)  # cpu

transform = transform.to("cuda")
print(transform.angle.device)  # cuda:0
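
A useful side effect of the buffer approach: buffers are also serialized in the module's state_dict, so the angle would survive a save/load round trip. A minimal sketch, reusing the RotateMock defined above:

# Buffers show up in state_dict next to parameters.
state = transform.state_dict()
print(state.keys())  # odict_keys(['angle'])

# Restoring into a freshly constructed module recovers the angle.
restored = RotateMock(torch.zeros(1, 1))
restored.load_state_dict(state)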

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Reactions: 1
  • Comments: 9 (5 by maintainers)

Top GitHub Comments

1 reaction
pmeier commented, Oct 19, 2020

Nope. I’ll see if I have time this week to send a PR. Otherwise I’ll get back to you.

1 reaction
shijianjian commented, Sep 9, 2020

This should be enough:

import torch
from torch import nn


class RotateMock(nn.Module):
    def __init__(self, angle: torch.Tensor) -> None:
        super().__init__()
        # As an nn.Parameter, angle moves with .to() and is also trainable.
        self.angle = nn.Parameter(angle)

    def forward(self, input):
        return input * self.angle


r = RotateMock(torch.tensor(0.5, requires_grad=True))
res = r(torch.tensor(2.0))

criterion = torch.nn.L1Loss()
optimizer = torch.optim.SGD(r.parameters(), lr=10)

loss = criterion(res, torch.tensor(5.0))

loss.backward()
optimizer.step()

print(r.angle, r.angle.grad)  # the parameter was updated by the optimizer

Output:

Parameter containing:
tensor(20.5000, requires_grad=True) tensor(-2.)
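
The two approaches differ in one important way: a parameter is returned by .parameters() and hence updated by optimizers, while a buffer moves between devices but stays fixed. A minimal sketch of the difference (the class names here are hypothetical, for illustration only):

import torch
from torch import nn


class BufferRotate(nn.Module):
    def __init__(self, angle: torch.Tensor) -> None:
        super().__init__()
        self.register_buffer("angle", angle)  # device-aware, not trainable


class ParamRotate(nn.Module):
    def __init__(self, angle: torch.Tensor) -> None:
        super().__init__()
        self.angle = nn.Parameter(angle)  # device-aware and trainable


angle = torch.tensor(30.0).view(1, -1)
print(list(BufferRotate(angle).parameters()))         # [] -> invisible to optimizers
print(list(ParamRotate(angle.clone()).parameters()))  # [Parameter containing: ...]

For a transform whose angle is a fixed user input rather than a learned quantity, register_buffer is arguably the safer default.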
