TypeError in forward() while trying to use ClassifierExplainer
Hi all,
I am trying to use this dashboard to explain a PyTorch NN binary classifier, but I get a TypeError from my `forward()` method when I create a `ClassifierExplainer` on my data.
My NN script (wrapped with skorch), which works and fits the model:
```python
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from skorch import NeuralNetBinaryClassifier

# X, X_train and y_train come from my own preprocessing (not shown);
# the data has 298 feature columns.


def get_skorch_classifier():
    # PyTorch layers expect float32 inputs and targets
    X_train_m = X_train.astype(np.float32)
    y_train_m = y_train.astype(np.float32)
    X_train_df = pd.DataFrame(X_train_m, columns=X.columns)

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.layer_1 = nn.Linear(298, 60)
            self.layer_2 = nn.Linear(60, 60)
            self.layer_3 = nn.Linear(60, 60)
            self.layer_4 = nn.Linear(60, 60)
            self.layer_out = nn.Linear(60, 1)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(0.1)
            self.batchnorm1 = nn.BatchNorm1d(60, momentum=0.2)
            self.batchnorm2 = nn.BatchNorm1d(60, momentum=0.2)
            self.batchnorm3 = nn.BatchNorm1d(60, momentum=0.2)
            self.batchnorm4 = nn.BatchNorm1d(60, momentum=0.2)
            self.sigmoid = nn.Sigmoid()

        def forward(self, inputs):
            # Four Linear -> ReLU -> BatchNorm (-> Dropout) blocks
            x = self.relu(self.layer_1(inputs))
            x = self.batchnorm1(x)
            x = self.dropout(x)
            x = self.relu(self.layer_2(x))
            x = self.batchnorm2(x)
            x = self.dropout(x)
            x = self.relu(self.layer_3(x))
            x = self.batchnorm3(x)
            x = self.dropout(x)
            x = self.relu(self.layer_4(x))
            x = self.batchnorm4(x)
            # x = self.dropout(x)
            # Return raw logits: NeuralNetBinaryClassifier's default loss
            # (BCEWithLogitsLoss) applies the sigmoid itself.
            x = self.layer_out(x)
            # x = self.sigmoid(x)
            return x

    model = NeuralNetBinaryClassifier(MyModule, max_epochs=10, lr=0.01, optimizer=optim.Adam)
    model.fit(X_train_m, torch.FloatTensor(y_train_m))
    return model, X_train_df, y_train_m


model, Xm_df, ym = get_skorch_classifier()
```
The error occurs on the next lines:

```python
from explainerdashboard import ClassifierExplainer, ExplainerDashboard

explainer = ClassifierExplainer(model, pd.DataFrame(X_test, columns=X.columns), y_test)
ExplainerDashboard(explainer, mode='inline').run(port=8051)
```
The error:

```
WARNING: Parameter shap='guess', but failed to guess the type of shap explainer to use for NeuralNetBinaryClassifier. Defaulting to the model agnostic shap.KernelExplainer (shap='kernel'). However this will be slow, so if your model is compatible with e.g. shap.TreeExplainer or shap.LinearExplainer then pass shap='tree' or shap='linear'!
WARNING: For shap='kernel', shap interaction values can unfortunately not be calculated!
Note: shap values for shap='kernel' normally get calculated against X_background, but paramater X_background=None, so setting X_background=shap.sample(X, 50)...
Generating self.shap_explainer = shap.KernelExplainer(model, X, link='identity')
Provided model function fails when applied to the provided data set.

TypeError                                 Traceback (most recent call last)
<ipython-input-24-dce008cfed2f> in <module>()
----> 1 explainer = ClassifierExplainer(model, pd.DataFrame(X_test, columns=X.columns), y_test)
      2 ExplainerDashboard(explainer, mode='inline').run(port=8051)

10 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'temp'
```
Here `temp` is the name of one of the columns in `X_test`:
| temp | cond | ph | do | turb | tds | District_Adilabad, Andhra Pradesh | … | Type1032_Well |
|---|---|---|---|---|---|---|---|---|
| -0.257550 | 6.345038 | -0.711990 | -1.536070 | -0.437670 | -1.000000 | -0.083247 | … | -0.221192 |
| -0.257550 | 0.008256 | -0.531209 | -0.027831 | -0.429356 | -1.000000 | -0.083247 | … | -0.221192 |
| 0.448392 | -0.114126 | 0.734259 | 0.596268 | -0.449309 | 0.129979 | 12.012508 | … | -0.221192 |
| -0.257550 | -0.059310 | 0.915041 | -0.599921 | -0.442658 | -1.000000 | -0.083247 | … | -0.221192 |
| -1.198806 | -0.177749 | 0.372697 | 0.440244 | 0.814380 | -0.492370 | -0.083247 | … | -0.221192 |

5 rows × 298 columns (the omitted columns are the remaining `District_*` and `Type1032_*` dummy columns)

Any ideas how to fix this?
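A likely explanation for the `temp` keyword: skorch appears to treat a pandas DataFrame as a dict of columns and pass each column to `forward()` as a keyword argument, so a module whose `forward` takes a single `inputs` argument fails on the first column name. The minimal sketch below reproduces that failure without torch or skorch. `forward` and `call_module` are hypothetical stand-ins, and the dict-of-columns behavior is an assumption about skorch, not a copy of its code. Under that assumption, handing the model a plain float32 NumPy array instead of a DataFrame avoids the keyword unpacking.

```python
import numpy as np
import pandas as pd


def forward(inputs):
    """Stand-in for MyModule.forward: takes a single positional input."""
    return np.asarray(inputs).sum(axis=1)


def call_module(fn, X):
    """Hypothetical mimic of how skorch feeds data to a module.

    Assumption: a DataFrame is split into {column_name: column_values}
    and unpacked as keyword arguments; anything else is passed as-is.
    """
    if isinstance(X, pd.DataFrame):
        return fn(**{col: X[col].to_numpy().reshape(-1, 1) for col in X.columns})
    return fn(X)


# Toy data with the first three column names from the issue
X_test_df = pd.DataFrame(
    np.zeros((5, 3), dtype=np.float32), columns=["temp", "cond", "ph"]
)

try:
    call_module(forward, X_test_df)
except TypeError as err:
    print(err)  # forward() got an unexpected keyword argument 'temp'

# Workaround under the same assumption: pass a plain float32 array instead.
out = call_module(forward, X_test_df.to_numpy(dtype=np.float32))
print(out.shape)  # (5,)
```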
Issue Analytics
- Created 2 years ago
- Comments: 7 (4 by maintainers)
Top GitHub Comments
Hi, try passing `shap='skorch'` to the explainer? (It should autodetect this, but somehow doesn't.)

Ah, great, so I will just add `skorch.classifier.NeuralNetBinaryClassifier` to the skorch list… thanks!
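The maintainer's fix extends the list of model classes that trigger skorch handling. The sketch below shows one way such class-name-based detection can work; `SKORCH_MODELS`, `class_path` and `guess_shap` are hypothetical names, and this is not explainerdashboard's actual code. Walking the class's MRO (rather than matching only the exact class) means user subclasses of a listed skorch class are detected too.

```python
# Hypothetical sketch of class-name-based model detection; explainerdashboard's
# real implementation may differ.
SKORCH_MODELS = [
    "skorch.classifier.NeuralNetClassifier",
    "skorch.classifier.NeuralNetBinaryClassifier",  # the class from this issue
    "skorch.regressor.NeuralNetRegressor",
]


def class_path(cls):
    """Return the dotted 'module.QualName' path of a class."""
    return f"{cls.__module__}.{cls.__qualname__}"


def guess_shap(model, default="kernel"):
    """Return 'skorch' if the model's class or any base class is listed."""
    # Walking the MRO also catches user subclasses of a listed skorch class.
    for cls in type(model).__mro__:
        if class_path(cls) in SKORCH_MODELS:
            return "skorch"
    return default


# Stand-in class so the sketch runs without skorch installed (assumption).
FakeNet = type("NeuralNetBinaryClassifier", (), {"__module__": "skorch.classifier"})
print(guess_shap(FakeNet()))  # skorch
print(guess_shap(object()))   # kernel
```

Until such a change is released, passing `shap='skorch'` explicitly, as suggested above, bypasses the guess entirely.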