Why not accumulate loss and then take derivative in MAML?
See original GitHub issueWhy do you not do this:
def inner_loop2():
n_inner_iter = 5
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
qry_losses = []
qry_accs = []
meta_opt.zero_grad()
meta_loss = 0
for i in range(task_num):
with higher.innerloop_ctx(
net, inner_opt, copy_initial_weights=False
) as (fnet, diffopt):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.
# higher is able to automatically keep copies of
# your network's parameters as they are being updated.
for _ in range(n_inner_iter):
spt_logits = fnet(x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
diffopt.step(spt_loss)
# The final set of adapted parameters will induce some
# final loss and accuracy on the query dataset.
# These will be used to update the model's meta-parameters.
qry_logits = fnet(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_losses.append(qry_loss.detach())
qry_acc = (qry_logits.argmax(
dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc)
# Update the model's meta-parameters to optimize the query
# losses across all of the tasks sampled in this batch.
# This unrolls through the gradient steps.
#qry_loss.backward()
meta_loss += qry_loss
qry_losses = sum(qry_losses) / task_num
qry_losses.backward()
meta_opt.step()
qry_accs = 100. * sum(qry_accs) / task_num
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
instead of what you have:
def inner_loop1():
n_inner_iter = 5
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
qry_losses = []
qry_accs = []
meta_opt.zero_grad()
for i in range(task_num):
with higher.innerloop_ctx(
net, inner_opt, copy_initial_weights=False
) as (fnet, diffopt):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.
# higher is able to automatically keep copies of
# your network's parameters as they are being updated.
for _ in range(n_inner_iter):
spt_logits = fnet(x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
diffopt.step(spt_loss)
# The final set of adapted parameters will induce some
# final loss and accuracy on the query dataset.
# These will be used to update the model's meta-parameters.
qry_logits = fnet(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_losses.append(qry_loss.detach())
qry_acc = (qry_logits.argmax(
dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc)
# Update the model's meta-parameters to optimize the query
# losses across all of the tasks sampled in this batch.
# This unrolls through the gradient steps.
qry_loss.backward()
meta_opt.step()
qry_losses = sum(qry_losses) / task_num
qry_accs = 100. * sum(qry_accs) / task_num
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
Issue Analytics
- State:
- Created 3 years ago
- Comments:8 (8 by maintainers)
Top Results From Across the Web
Why not accumulate query loss and then take derivative in ...
@stackoverflowuser2010 I'm very familiar with MAML at this point. Honestly, I just read the original paper. I think I've read it a couple...
Read more >How to Train MAML(Model-Agnostic Meta-Learning)
MAML reduces the second-order derivative cost by completely ignoring it. This could impair the final generalize performance in some cases. Solution: Derivative- ...
Read more >How to train your MAML - arXiv
If the annealing is not used, we found that the final loss might be higher than with the original formulation. Second Order Derivative...
Read more >An Interactive Introduction to Model-Agnostic Meta-Learning ...
Interactive introduction to model-agnostic meta-learning (MAML), a research field that attempts to equip conventional machine learning architectures with ...
Read more >How to train your MAML - OpenReview
In this paper, we propose various modifications to MAML that not only ... of the losses from different step; 2) anneal the second...
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
The answer to your stack overflow question is the right answer. For any x and theta,
grad(sum(x), theta) = sum(grad(x, theta))
, or in latex:No, what I mean is why not use
.item()
or.data
.Why
.detach()
in particular