DeepLiftShap example doesn't work on ImageNet data
My code is:
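```python
import numpy as np
import torch
from torchvision import models
from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel,
)

model = models.resnet50(pretrained=True).eval()

# test_images and background are loaded elsewhere (not shown);
# their shapes are printed in the log below.
print(test_images.shape)  # torch.Size([10, 3, 224, 224])
print(background.shape)   # torch.Size([30, 3, 224, 224])

dl = DeepLiftShap(model)
# ResNet-50 returns 1000 logits per image, so a class index (target) is required,
# and return_convergence_delta=True is needed to get (attributions, delta) back.
attributions, delta = dl.attribute(
    inputs=test_images,
    baselines=background,
    target=3,
    return_convergence_delta=True,
)
print('DeepLiftSHAP Attributions:', attributions)
print('Convergence Delta:', delta)
```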
🐛 Log
```
torch.Size([10, 3, 224, 224])
torch.Size([30, 3, 224, 224])

TypeError                                 Traceback (most recent call last)
<ipython-input-8-33adec3e9593> in <module>
     18 
     19 dl = DeepLiftShap(model)
---> 20 attributions, delta = dl.attribute(inputs=test_images, baselines=background)
     21 print('DeepLiftSHAP Attributions:', attributions)
     22 print('Convergence Delta:', delta)

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in attribute(self, inputs, baselines, target, additional_forward_args, return_convergence_delta, custom_attribution_func)
    764                 Literal[True, False], return_convergence_delta
    765             ),
--> 766             custom_attribution_func=custom_attribution_func,
    767         )
    768         if return_convergence_delta:

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in attribute(self, inputs, baselines, target, additional_forward_args, return_convergence_delta, custom_attribution_func)
    320             self.model, (inputs, baselines), expanded_target, input_base_additional_args
    321         )
--> 322         gradients = self.gradient_func(wrapped_forward_func, inputs,)
    323         if custom_attribution_func is None:
    324             attributions = tuple(

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_utils/gradient.py in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
     94     with torch.autograd.set_grad_enabled(True):
     95         # runs forward pass
---> 96         outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
     97         assert outputs[0].numel() == 1, (
     98             "Target not provided when necessary, cannot"

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    490     forward_func_args = signature(forward_func).parameters
    491     if len(forward_func_args) == 0:
--> 492         output = forward_func()
    493         return output if target is None else _select_targets(output, target)
    494 

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in forward_fn()
    354     ) -> Callable:
    355         def forward_fn():
--> 356             return _run_forward(forward_func, inputs, target, additional_forward_args)
    357 
    358         if hasattr(forward_func, "device_ids"):

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/captum/attr/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    501         *(*inputs, *additional_forward_args)
    502         if additional_forward_args is not None
--> 503         else inputs
    504     )
    505     return _select_targets(output, target)

~/miniconda3/envs/advanced_ml/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

TypeError: forward() takes 2 positional arguments but 3 were given
```
Can someone help me out?
Hi @giangnguyen2412, your usage looks generally correct; I don't see any obvious issues. Just double-check that the `target=3` parameter corresponds to the output you want to interpret for each input. This is often chosen to be the true class index, but it can also be set to other classes to see pixel importance for an alternate or incorrect decision; this FAQ answer provides more context. When visualizing, it may also be worthwhile to look at positive, negative, or absolute-value attributions, in addition to only the positive signed results.
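The two suggestions above can be sketched as follows. The class index is passed explicitly in the call, for example `dl.attribute(test_images, baselines=background, target=3)`, and the resulting tensor can be moved to NumPy with `attributions.detach().cpu().numpy()`. The helper `signed_heatmaps` below is a hypothetical name, not a Captum function; it assumes attributions of shape `(N, C, H, W)` and sums over the channel axis to produce separate positive, negative, and absolute-value maps for each image. Captum's own `captum.attr.visualization.visualize_image_attr` offers a similar choice through its `sign` argument if you prefer the built-in plotting.

```python
import numpy as np


def signed_heatmaps(attr):
    """Split an attribution array of shape (N, C, H, W) into per-pixel maps.

    Hypothetical helper. Returns a dict of (N, H, W) arrays, each summed
    over the channel axis:
      "positive": evidence for the target class (>= 0),
      "negative": evidence against the target class (<= 0),
      "absolute": overall magnitude, regardless of sign.
    """
    attr = np.asarray(attr, dtype=np.float64)
    if attr.ndim != 4:
        raise ValueError(f"expected shape (N, C, H, W), got {attr.shape}")
    positive = np.clip(attr, 0, None).sum(axis=1)  # keep only values > 0
    negative = np.clip(attr, None, 0).sum(axis=1)  # keep only values < 0
    absolute = np.abs(attr).sum(axis=1)            # magnitude of every value
    return {"positive": positive, "negative": negative, "absolute": absolute}


if __name__ == "__main__":
    # Tiny synthetic example: 1 image, 2 channels, a 1x2 "image".
    demo = np.array([[[[1.0, -2.0]], [[3.0, -1.0]]]])
    for name, heatmap in signed_heatmaps(demo).items():
        print(name, heatmap.tolist())
```

Looking at all three maps guards against a misleading picture: a pixel with large negative attribution is just as influential as one with large positive attribution, but it is invisible in a positive-only plot.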
For more examples of using other methods, you can take a look at this tutorial.
Hi @giangnguyen2412, from a closer look at the stack trace, I think you might be using PyTorch 1.1, which is not supported by Captum. Can you double-check the version of PyTorch you are using with `print(torch.__version__)`? If it is prior to version 1.2, can you upgrade to a newer version of PyTorch and try again?
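To act on this suggestion programmatically, the installed version can be compared against the 1.2 minimum before any Captum attribution objects are built. The sketch below uses a hypothetical helper, `torch_meets_minimum` (not part of PyTorch or Captum); it compares only the leading major and minor numbers, so local build tags such as `+cu117` and pre-release suffixes are ignored.

```python
import re


def torch_meets_minimum(version, minimum=(1, 2)):
    """Return True if a PyTorch version string is at least `minimum`.

    Hypothetical helper: only the leading "major.minor" digits are compared,
    so local tags ("+cu117") and pre-release suffixes ("a0", "rc1") are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= tuple(minimum)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        print(torch.__version__)
        if not torch_meets_minimum(torch.__version__):
            print("PyTorch < 1.2 is not supported by Captum; please upgrade.")
```

Comparing `(major, minor)` tuples avoids the classic string-comparison pitfall where `"1.13"` would sort before `"1.2"`.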