Why not accumulate loss and then take derivative in MAML?

Why do you not do this:

def inner_loop2():
    n_inner_iter = 5
    inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

    qry_losses = []
    qry_accs = []
    meta_opt.zero_grad()
    meta_loss = 0
    for i in range(task_num):
        with higher.innerloop_ctx(
            net, inner_opt, copy_initial_weights=False
        ) as (fnet, diffopt):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            # higher is able to automatically keep copies of
            # your network's parameters as they are being updated.
            for _ in range(n_inner_iter):
                spt_logits = fnet(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                diffopt.step(spt_loss)

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = fnet(x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_losses.append(qry_loss.detach())
            qry_acc = (qry_logits.argmax(
                dim=1) == y_qry[i]).sum().item() / querysz
            qry_accs.append(qry_acc)

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            #qry_loss.backward()
            meta_loss += qry_loss

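    # qry_losses holds detached tensors, so it is only useful for logging.
    # Backpropagate through the accumulated, non-detached meta_loss instead.
    qry_losses = sum(qry_losses) / task_num
    meta_loss.backward()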
    meta_opt.step()
    qry_accs = 100. * sum(qry_accs) / task_num
    i = epoch + float(batch_idx) / n_train_iter
    iter_time = time.time() - start_time

instead of what you have:

def inner_loop1():
    n_inner_iter = 5
    inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

    qry_losses = []
    qry_accs = []
    meta_opt.zero_grad()
    for i in range(task_num):
        with higher.innerloop_ctx(
            net, inner_opt, copy_initial_weights=False
        ) as (fnet, diffopt):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            # higher is able to automatically keep copies of
            # your network's parameters as they are being updated.
            for _ in range(n_inner_iter):
                spt_logits = fnet(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                diffopt.step(spt_loss)

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = fnet(x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_losses.append(qry_loss.detach())
            qry_acc = (qry_logits.argmax(
                dim=1) == y_qry[i]).sum().item() / querysz
            qry_accs.append(qry_acc)

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            qry_loss.backward()

    meta_opt.step()
    qry_losses = sum(qry_losses) / task_num
    qry_accs = 100. * sum(qry_accs) / task_num
    i = epoch + float(batch_idx) / n_train_iter
    iter_time = time.time() - start_time

https://github.com/facebookresearch/higher/blob/e45c1a059e39a16fa016d37bc15397824c65547c/examples/maml-omniglot.py#L130


https://stackoverflow.com/questions/62394411/why-not-accumulate-query-loss-and-then-take-derivative-in-maml-with-pytorch-and

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

3 reactions
egrefen commented, Jun 17, 2020

The answer to your Stack Overflow question is the right answer. For any x and theta, grad(sum(x), theta) = sum(grad(x, theta)), or in LaTeX:

\nabla_\theta \sum_i f_\theta(x_i) = \sum_i \nabla_\theta f_\theta(x_i)
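
The identity above can be checked numerically. The following is a minimal sketch, not the code from the higher example: it assumes a plain linear layer and random data as hypothetical stand-ins for the meta-learner and the task batches, and it skips the inner-loop adaptation entirely. It only illustrates that calling backward() once per task (which accumulates into .grad) gives the same gradients as summing the losses and calling backward() once.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny model and three random "tasks".
net = torch.nn.Linear(4, 2)
task_num = 3
xs = [torch.randn(8, 4) for _ in range(task_num)]
ys = [torch.randint(0, 2, (8,)) for _ in range(task_num)]


def task_loss(i):
    """Cross-entropy loss of the shared model on task i."""
    return F.cross_entropy(net(xs[i]), ys[i])


# Variant 1: backward() per task; PyTorch accumulates gradients in .grad.
net.zero_grad()
for i in range(task_num):
    task_loss(i).backward()
grads_per_task = [p.grad.clone() for p in net.parameters()]

# Variant 2: accumulate the losses, then a single backward() on the sum.
net.zero_grad()
meta_loss = sum(task_loss(i) for i in range(task_num))
meta_loss.backward()
grads_summed = [p.grad.clone() for p in net.parameters()]

# Compare the two sets of gradients up to floating-point tolerance.
same = all(
    torch.allclose(a, b, atol=1e-6)
    for a, b in zip(grads_per_task, grads_summed)
)
print("gradients match:", same)
```

Because the gradients agree, the choice between the two variants is about bookkeeping rather than about the update itself.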

0 reactions
renesax14 commented, Jun 30, 2020

Also, why don’t you do:

meta_loss += qry_loss.detach()

instead of tracking them all in a list?

No particular reason. I guess this allows you to plot task losses separately if you wanted, or measure mean/variance rather than just mean?

No, what I mean is: why not use .item() or .data? Why .detach() in particular?
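
For readers comparing the three options named in this question, here is a small sketch of what each one returns in PyTorch. The tensor `w` and the toy loss are hypothetical and unrelated to the MAML example. It does not speak for the maintainers' reasons. It only shows the observable differences: .detach() and .data both return tensors without gradient history, while .item() returns a plain Python number.

```python
import torch

# Hypothetical toy loss, unrelated to the MAML example.
w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w ** 2).sum()

detached = loss.detach()  # tensor without grad history, shares storage with loss
as_float = loss.item()    # plain Python float, only valid for one-element tensors
legacy = loss.data        # tensor without grad history; older, pre-0.4 style access

print(type(detached).__name__, detached.requires_grad)
print(type(as_float).__name__)
print(type(legacy).__name__, legacy.requires_grad)

# None of the three calls alters the graph of the original loss,
# so it can still be backpropagated afterwards.
loss.backward()
print(w.grad)
```

A list of detached tensors can still be summed and averaged as a tensor for logging, whereas .item() values are ordinary floats. The PyTorch documentation recommends .detach() over .data because in-place changes made through .data are not tracked by autograd.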
