
DeepLiftShap example doesn't work on ImageNet data

See original GitHub issue

My code is:

import numpy as np
import torch
from torchvision import models  # needed for models.resnet50 below; missing from the original snippet
from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel,
)

model = models.resnet50(pretrained=True).eval()
print(test_images.shape)
print(background.shape)

dl = DeepLiftShap(model)
attributions, delta = dl.attribute(inputs=test_images, baselines=background)
print('DeepLiftSHAP Attributions:', attributions)
print('Convergence Delta:', delta)
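(Note: test_images and background are not defined in the snippet above. A minimal stand-in, assuming standard ImageNet normalization and using random tensors as hypothetical placeholders for real preprocessed images, could look like this:)

import torch
from torchvision import transforms

# hypothetical placeholders shaped like the batches printed in the log below;
# recent torchvision applies Normalize across a whole (N, C, H, W) batch
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
test_images = normalize(torch.rand(10, 3, 224, 224))
background = normalize(torch.rand(30, 3, 224, 224))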

🐛 Log

torch.Size([10, 3, 224, 224]) torch.Size([30, 3, 224, 224])


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-33adec3e9593> in <module>
     18
     19 dl = DeepLiftShap(model)
---> 20 attributions, delta = dl.attribute(inputs=test_images, baselines=background)
     21 print('DeepLiftSHAP Attributions:', attributions)
     22 print('Convergence Delta:', delta)

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in attribute(self, inputs, baselines, target, additional_forward_args, return_convergence_delta, custom_attribution_func)
    764                 Literal[True, False], return_convergence_delta
    765             ),
--> 766             custom_attribution_func=custom_attribution_func,
    767         )
    768         if return_convergence_delta:

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in attribute(self, inputs, baselines, target, additional_forward_args, return_convergence_delta, custom_attribution_func)
    320             self.model, (inputs, baselines), expanded_target, input_base_additional_args
    321         )
--> 322         gradients = self.gradient_func(wrapped_forward_func, inputs,)
    323         if custom_attribution_func is None:
    324             attributions = tuple(

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_utils/gradient.py in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
     94     with torch.autograd.set_grad_enabled(True):
     95         # runs forward pass
---> 96         outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
     97         assert outputs[0].numel() == 1, (
     98             "Target not provided when necessary, cannot"

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    490     forward_func_args = signature(forward_func).parameters
    491     if len(forward_func_args) == 0:
--> 492         output = forward_func()
    493         return output if target is None else _select_targets(output, target)
    494

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in forward_fn()
    354 ) -> Callable:
    355     def forward_fn():
--> 356         return _run_forward(forward_func, inputs, target, additional_forward_args)
    357
    358     if hasattr(forward_func, "device_ids"):

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    501         *(*inputs, *additional_forward_args)
    502         if additional_forward_args is not None
--> 503         else inputs
    504     )
    505     return _select_targets(output, target)

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

TypeError: forward() takes 2 positional arguments but 3 were given

Can someone help me out?

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 13 (5 by maintainers)

Top GitHub Comments

1 reaction
vivekmig commented, Oct 5, 2020

Hi @giangnguyen2412, your usage looks generally correct; I don't see any obvious issues. Just double-check that the target=3 parameter corresponds to the output you want to interpret for each input. This is often the true class index, but it can be set to other classes to see pixel importance for an alternate or incorrect decision; this FAQ answer provides more context.
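As a rough sketch of that suggestion (reusing the names from the snippet above, and using each input's predicted class as a per-example target instead of a fixed target=3):

import torch

# attribute with respect to the class the model actually predicts for each input
with torch.no_grad():
    preds = model(test_images).argmax(dim=1)

attributions, delta = dl.attribute(
    inputs=test_images,
    baselines=background,
    target=preds,
    return_convergence_delta=True,
)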

When visualizing, it may also be worthwhile to look at the positive, negative, and absolute values of the attributions, rather than only the positive signed results.
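For instance, a minimal sketch using Captum's visualization helper (assuming the attributions computed above, converted to HWC numpy arrays):

from captum.attr import visualization as viz

# Captum's visualizer expects (H, W, C) numpy arrays
attr = attributions[0].detach().permute(1, 2, 0).numpy()
img = test_images[0].detach().permute(1, 2, 0).numpy()  # un-normalize first for a faithful display

# sign="all" shows positive and negative contributions; "absolute_value" shows magnitude only
viz.visualize_image_attr(attr, img, method="blended_heat_map", sign="all")
viz.visualize_image_attr(attr, img, method="blended_heat_map", sign="absolute_value")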

For more examples of using other methods, you can take a look at this tutorial.

1 reaction
vivekmig commented, Oct 1, 2020

Hi @giangnguyen2412 , from a closer look at the stack trace, I think it’s possible you might be using torch version 1.1, which is not supported by Captum. Can you double check the version of PyTorch that you are using with print(torch.__version__)?

If it is prior to version 1.2, can you upgrade to a newer version of PyTorch and try again?
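Something along these lines (version check in Python, upgrade command left as a shell comment):

import torch
print(torch.__version__)  # should be >= 1.2 for Captum, per the comment above

# if it is older, upgrade outside Python, e.g.:
#   pip install --upgrade torch torchvision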


