
Detaching attributions from computation graph

See original GitHub issue

Hi! I’m interested in regularizing the model via some attribution-based loss, as described in a previous post. As a baseline, I would like to train a model without such regularization (i.e., using only task loss), but using the attributions to compute some evaluation metric.

The desired workflow for this baseline is as follows:

  1. Compute the gradient-based attributions (e.g., using IG) for the model, but without keeping the gradients used to obtain the attributions beyond Step 1. That is, all I want is the final attribution vector, before starting with an empty computation graph for Step 2.
  2. Perform a forward pass through the model, use the task labels to compute the loss, then backprop this loss only via gradients computed in Step 2. Crucially, the attributions from Step 1 are not used here in Step 2. In other words, when I perform the backward pass here, I don’t want there to be any connection to the computation in Step 1.
  3. Use the attributions obtained from Step 1 to compute the evaluation metric.
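For concreteness, here is a minimal sketch of that workflow using Captum's IntegratedGradients on a toy classifier. The model, data, and the compute_eval_metric helper are hypothetical stand-ins, not from the issue. Note that IG itself needs gradients to run, so Step 1 cannot sit under torch.no_grad(); detaching the returned attributions and clearing any leftover parameter gradients is the closest equivalent:

    import torch
    import torch.nn as nn
    from captum.attr import IntegratedGradients

    # Toy setup (hypothetical): a small classifier with dropout.
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Dropout(0.5), nn.Linear(32, 2))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    inputs, labels = torch.randn(4, 10), torch.randint(0, 2, (4,))

    # Step 1: compute attributions. IG uses torch.autograd.grad internally,
    # which should not touch the parameters' .grad fields, but detaching the
    # result and zeroing grads makes the isolation explicit. Running in eval
    # mode also keeps the extra forward passes from consuming dropout RNG.
    model.eval()
    ig = IntegratedGradients(model)
    attributions = ig.attribute(inputs, target=labels).detach()
    model.zero_grad()

    # Step 2: an ordinary training step, with no connection to Step 1.
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()

    # Step 3: the detached attributions feed only the evaluation metric.
    # eval_metric = compute_eval_metric(attributions)  # hypothetical helper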

So far, I’ve tried: (a) using detach() and with torch.no_grad() in Step 1, (b) performing Step 1 on a deepcopy of the model, and (c) removing Step 3. However, the train loss in Step 2 is somehow still being affected by the attribution computation from Step 1.

I’d appreciate any advice on how to resolve this. Thanks!

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

2 reactions
NarineK commented, Oct 21, 2021

Thank you for the example, @aarzchan! Do you have Dropout or BatchNorm in the model? When you use Dropout, for instance, it randomly chooses which neurons to drop when your model runs in train mode. That’s why you might be seeing different losses. In eval mode the randomization is turned off, which is why you don’t see the same effect. These are just some thoughts I have; I don’t know whether you are using Dropout or BatchNorm.
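To make the Dropout point concrete: even a fully detached attribution pass advances PyTorch’s global RNG, so the dropout masks drawn in the subsequent training step differ from a run that skipped Step 1, and the loss differs with them. A tiny illustrative sketch (toy tensors, not from the issue):

    import torch
    import torch.nn as nn

    drop = nn.Dropout(0.5).train()  # dropout active, as in training
    x = torch.ones(1, 64)

    torch.manual_seed(0)
    mask_without_step1 = drop(x)

    torch.manual_seed(0)
    _ = torch.rand(100)  # stands in for Step 1's extra forward passes
    mask_with_step1 = drop(x)

    # Expect False: the masks differ because Step 1 consumed RNG draws,
    # even though nothing is connected in the autograd graph.
    print(torch.equal(mask_without_step1, mask_with_step1))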

1 reaction
NarineK commented, Oct 22, 2021

GradientShap uses randomization: it selects baselines randomly and, in addition, randomly samples data points between the input and the baseline; that’s the only big difference compared to IG. GradientShap’s attribution results won’t be deterministic if the seeds aren’t fixed, but that shouldn’t affect the loss. If there is no randomization during eval, I wonder whether the model’s forward pass still depends on some seed value.

https://github.com/pytorch/captum/blob/master/captum/attr/_core/gradient_shap.py#L412
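As a concrete version of the seed-fixing point, here is a small sketch (toy model and inputs, purely illustrative) showing GradientShap becoming reproducible once the global seed is pinned before each call:

    import torch
    import torch.nn as nn
    from captum.attr import GradientShap

    model = nn.Sequential(nn.Linear(10, 2)).eval()
    gs = GradientShap(model)
    inputs, baselines = torch.randn(4, 10), torch.zeros(1, 10)

    torch.manual_seed(1234)
    attr_a = gs.attribute(inputs, baselines=baselines, target=0)
    torch.manual_seed(1234)
    attr_b = gs.attribute(inputs, baselines=baselines, target=0)

    # Identical attributions: the random baseline choice and the random
    # interpolation points are now drawn from the same RNG state.
    assert torch.allclose(attr_a, attr_b)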


Top Results From Across the Web

  • How to Detach specific components in the loss? - autograd
    I’m a little confused how to detach certain model from the loss computation graph. If I have 3 models that generate an output:...
  • Why by changing tensor using detach method make ...
    When calling c = y.detach() you effectively detach c from the computation graph, while y remains attached. However, c shares the same data ...
  • Evaluating Attribution for Graph Neural Networks
    If multiple attributions are valid (e.g., a subgraph is present twice in a graph), we take the maximum attribution value of all possible...
  • TensorFlow Graph Optimizations
    Redundant computation removal through constant folding, CSE, ... Whole graph analysis to identify and remove hidden identity and other unnecessary ops (e.g..
  • clone_module(module, memo=None) - learn2learn
    Detaches all parameters/buffers of a previously cloned module from its computational graph. Note: detach works in-place, so it does not return a copy....
