Inner loop incompatible with weight_norm
Hi, thanks for your work on this library!
Using a weight-normalized network in higher’s inner loop raises the following error:
load from omniglot.npy.
DB: train (1200, 20, 1, 28, 28) test (423, 20, 1, 28, 28)
Traceback (most recent call last):
  File "maml-omniglot.py", line 271, in <module>
    main()
  File "maml-omniglot.py", line 108, in main
    train(db, net, device, meta_opt, epoch, log)
  File "maml-omniglot.py", line 146, in train
    spt_logits = fnet(x_spt[i])
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/higher/patch.py", line 347, in _patched_forward
    return self.boxed_forward(*args, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/higher/patch.py", line 288, in patched_forward
    return true_forward(self, *args, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/higher/patch.py", line 288, in patched_forward
    return true_forward(self, *args, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/functional.py", line 1370, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of device type cuda but got device type cpu for argument #3 'mat2' in call to _th_addmm
I can reproduce this by modifying the maml-omniglot example to apply weight_norm to the final linear layer (full script pasted below). The error only appears inside higher’s inner loop; I can evaluate the network on input data outside the inner loop without error. I am on Ubuntu 16.04, Python 3.7.0, PyTorch 1.3.0, and CUDA 10.0.
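For readers skimming, the failure appears to distill down to a much smaller snippet. The following is a minimal, untested sketch (not part of the original report) that assumes a CUDA device is available; the explanation in the comments reflects the maintainers' diagnosis at the end of this issue, namely that higher does not patch weight_norm's hook-based reparameterization:

import torch
import higher
from torch import nn
from torch.nn.utils import weight_norm

# A weight-normalized linear layer on the GPU.
net = weight_norm(nn.Linear(4, 2)).cuda()
inner_opt = torch.optim.SGD(net.parameters(), lr=0.1)
x = torch.randn(8, 4, device='cuda')

# Outside the inner loop this works: weight_norm's forward pre-hook
# recomputes `weight` from `weight_g` and `weight_v` on the GPU.
net(x)

# Inside higher's inner loop, the patched forward apparently does not
# re-run that hook-based reparameterization, leaving a CPU copy of
# `weight` that triggers the device-mismatch RuntimeError shown above.
with higher.innerloop_ctx(net, inner_opt) as (fnet, diffopt):
    fnet(x)

The full modified maml-omniglot script follows: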
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
"""
import argparse
import time
import typing

import pandas as pd
import numpy as np

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import weight_norm

import higher

from support.omniglot_loaders import OmniglotNShot

def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument(
        '--k_spt', type=int, help='k shot for support set', default=5)
    argparser.add_argument(
        '--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument(
        '--task_num',
        type=int,
        help='meta batch size, namely task num',
        default=32)
    argparser.add_argument('--seed', type=int, help='random seed', default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = torch.device('cuda')
    db = OmniglotNShot(
        '/tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network that will be
    # automatically monkey-patched by higher later.
    # Before higher, models could *not* be created like this
    # and the parameters needed to be manually updated and copied
    # for the updates.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        weight_norm(nn.Linear(64, args.n_way))).to(device)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, net, device, meta_opt, epoch, log)
        test(db, net, device, epoch, log)
        plot(log)

def train(db, net, device, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(
                    dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100. * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })

def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(
                    qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append(
                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })

def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()

    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)

# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == '__main__':
    main()
Maintainer comments
This does help. I’ve read through the weight_norm code in PyTorch, and you are correct that this is something which higher isn’t patching. We could write a hacky fix specifically for weight_norm, I think, but I would prefer a more general solution that caters to similar use cases. I will need to think through this properly and probably talk to some people from the PyTorch team. I will attempt to look into this in the next two weeks, but it’s going to require some effort.

Hello. I’ve returned from leave and allocated some time to look into this issue over the next two weeks. Hopefully we’ll make some progress and report back, or come back to you with questions.
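Until a general solution lands in higher, one possible workaround is to avoid the hook-based reparameterization entirely and implement weight normalization directly in the module's forward pass, so that weight_g and weight_v are ordinary parameters higher can track and the effective weight is rebuilt inside the patched forward. The following is an illustrative sketch, not an official higher fix; the class name and initialization scheme are hypothetical, and it only mirrors what torch.nn.utils.weight_norm computes for nn.Linear with the default dim=0:

import torch
from torch import nn
import torch.nn.functional as F

class WeightNormLinear(nn.Module):
    """Linear layer with weight normalization computed in forward().

    Unlike torch.nn.utils.weight_norm, which recomputes `weight` in a
    forward pre-hook, this keeps the reparameterization inside forward()
    so that monkey-patching libraries see weight_g/weight_v directly.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        linear = nn.Linear(in_features, out_features)
        # Split the weight into direction (v) and per-row magnitude (g),
        # matching weight_norm's default dim=0 for an (out, in) matrix.
        self.weight_v = nn.Parameter(linear.weight.detach().clone())
        self.weight_g = nn.Parameter(
            linear.weight.norm(dim=1, keepdim=True).detach().clone())
        self.bias = nn.Parameter(linear.bias.detach().clone())

    def forward(self, input):
        # Rebuild the effective weight from g and v on every call so the
        # computation stays inside the differentiable graph.
        weight = self.weight_g * self.weight_v / self.weight_v.norm(
            dim=1, keepdim=True)
        return F.linear(input, weight, self.bias)

In the script above, the final layer would then become WeightNormLinear(64, args.n_way) in place of weight_norm(nn.Linear(64, args.n_way)).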