Imposing an orthogonality condition on convolution weights in v0.3.0
Hi,
I am struggling to impose orthogonality constraints on the weights of convolutional layers, as was also asked in Issue 10, where we unfold the weight tensor (in, out, k, k) as a matrix (in, out * k * k) on which we impose orthogonality. If I try to run the code from @lezcano’s answer, I get:
line 28, in _register_manifold
    tensor.copy_(X)
RuntimeError: The size of tensor a (3) must match the size of tensor b (20) at non-singleton dimension 3
Trying to run the other answer from @bokveizen, I get:
raise InManifoldError(X, self)
geotorch.exceptions.InManifoldError: Tensor not contained in FlattenedStiefel(n=180, k=40, triv=linalg_matrix_exp, transposed). Got: tensor([[[[ ...
I tried rewriting the in_manifold check to run self.forward() first, but the code still raises the RuntimeError.
What works is reverting to the commit from a year ago; however, I would like to reproduce the same functionality in geotorch@v0.3.0.
Thanks for the help!
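To make the shapes concrete, the unfolding discussed above is just a reshape of the 4-D kernel tensor into a 2-D matrix whose rows are then required to be orthonormal. A minimal sketch (the random tensor and shapes are only for illustration, matching the Conv2d layer used further down):

import torch

out_ch, in_ch, k = 40, 20, 3
W = torch.randn(out_ch, in_ch, k, k)   # conv weight of shape (out, in, k, k)
W_flat = W.flatten(1)                  # unfolded to (out, in * k * k) = (40, 180)
# The constraint asks for orthonormal rows, i.e. W_flat @ W_flat.T == I_40
print(W_flat.shape)                    # torch.Size([40, 180])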
Code from the answer of @bokveizen in Issue 10:
from numpy import prod
import torch
import torch.nn as nn
import geotorch


def size_flattened(size, dim):
    size = list(size)
    size_dim = size[dim]
    size[dim] = 1
    return (size_dim, prod(size))


class FlattenedStiefel(geotorch.Stiefel):
    def __init__(self, size, triv="expm"):
        # We assume that you want to flatten the dimensions [1:n]
        # See the comment in forward for why we keep the dim=1 and the dim=0
        super().__init__(size_flattened(size, 0), triv)
        # size = (out, in, k, k) so we transpose it
        self.size = size
        # size = list(size)
        # size[0], size[1] = size[1], size[0]
        # self.size = tuple(size)

    def forward(self, X):
        # The weight of a CNN with params (in, out, k, k)
        # is of size (out, in, k, k), so we transpose it before flattening it
        # X = X.T
        X = X.flatten(1)
        X = super().forward(X)
        X = X.view(self.size)
        # return X.T
        return X

    def initialize_(self, X, check_in_manifold=True):
        # X = X.T
        X = X.flatten(1)
        X = super().initialize_(X, check_in_manifold)
        X = X.view(self.size)
        # return X.T
        return X

    def sample(self, distribution="uniform", init_=None):
        X = super().sample(distribution, init_)
        X = X.view(self.size)
        # return X.T
        return X


def flattened_orthogonal(module, tensor_name="weight", triv="expm"):
    # Register the constraint through geotorch's (private) _register_manifold helper
    return geotorch.constraints._register_manifold(module, tensor_name, FlattenedStiefel, triv)


layer = nn.Conv2d(20, 40, 3, 3)  # Make the kernels orthogonal
flattened_orthogonal(layer, "weight")
print(layer)

W = layer.weight  # W has size (40, 20, 3, 3)
# W = W.T.flatten(1)  # W has size (20, 360) with orthogonal rows
# W = W.T  # W has size (360, 20) with orthogonal columns
# # Check that W.T @ W = Id
# print(torch.allclose(W.T @ W, torch.eye(40), atol=1e-4))
W = W.flatten(1)
print(torch.allclose(W @ W.T, torch.eye(40), atol=1e-4))
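As a point of comparison, here is an alternative sketch that is not from the original thread: the same flatten → orthogonalize → reshape idea written with plain PyTorch parametrizations (torch >= 1.10), bypassing geotorch entirely. The class name FlattenedOrthogonal and the QR-based orthogonalization are illustrative choices, not geotorch's API.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class FlattenedOrthogonal(nn.Module):
    """Parametrization making the rows of the flattened conv weight orthonormal."""
    def __init__(self, size):
        super().__init__()
        self.size = size  # e.g. (40, 20, 3, 3)

    def forward(self, X):
        # Flatten (out, in, k, k) -> (out, in*k*k), then take the Q factor of a
        # QR decomposition of the transpose so the rows become orthonormal.
        X = X.flatten(1)
        Q, _ = torch.linalg.qr(X.T)      # Q: (in*k*k, out), orthonormal columns
        return Q.T.reshape(self.size)

layer = nn.Conv2d(20, 40, 3)
parametrize.register_parametrization(layer, "weight", FlattenedOrthogonal(layer.weight.shape))
W = layer.weight.flatten(1)              # (40, 180) with orthonormal rows
print(torch.allclose(W @ W.T, torch.eye(40), atol=1e-5))

Note that this uses a QR-based orthogonalization rather than geotorch's matrix-exponential trivialization, so the optimization dynamics differ, but the resulting weight satisfies the same flattened-orthogonality constraint.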
Top GitHub Comments
also, hi Simon! Hope you’re doing great in Belgium working with Absil 😃
Great that you made it work in the end! Closing this one