Question: Integrated Gradient w/ Embedded Categorical Data
Hi everyone,
Question:
How can I apply integrated gradient to a dataset with numerical and embedded categorical data?
I am somewhat of a beginner with PyTorch, and the available resources just aren't clicking with my use case. My ultimate goal is to plot the feature importances of a model, but I am stuck on computing the attributions. Any help or guidance would be much appreciated.
What I’ve reviewed:
- Multimodal_VQA_Captum_Insights tutorial
- BERT tutorials
- https://github.com/pytorch/captum/issues/282
(These resources all use very different data structures (images/sentences), which makes them hard for a beginner to translate to a simpler tabular numerical/categorical dataset.)
My Problem:
Model:
(all_embeddings): ModuleList(
  (0): Embedding(3, 2)
  (1): Embedding(2, 1)
  (2): Embedding(2, 1)
  (3): Embedding(2, 1)
)
(embedding_dropout): Dropout(p=0.4, inplace=False)
(batch_norm_num): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layers): Sequential(
  (0): Linear(in_features=11, out_features=200, bias=True)
  (1): ReLU(inplace=True)
  (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.4, inplace=False)
  (4): Linear(in_features=200, out_features=100, bias=True)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.4, inplace=False)
  (8): Linear(in_features=100, out_features=50, bias=True)
  (9): ReLU(inplace=True)
  (10): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): Dropout(p=0.4, inplace=False)
  (12): Linear(in_features=50, out_features=2, bias=True)
)
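For orientation: the four embedding widths sum to 2 + 1 + 1 + 1 = 5, which together with the 6 numerical features accounts for the 11 in_features of the first Linear layer. The issue never shows the model's forward(), so the following is only a hedged sketch consistent with the printed module tree (the class name, argument order, and concatenation step are assumptions):

```python
import torch
import torch.nn as nn

class TabularModel(nn.Module):
    """Hypothetical reconstruction of the printed module tree."""

    def __init__(self):
        super().__init__()
        self.all_embeddings = nn.ModuleList([
            nn.Embedding(3, 2), nn.Embedding(2, 1),
            nn.Embedding(2, 1), nn.Embedding(2, 1),
        ])
        self.embedding_dropout = nn.Dropout(p=0.4)
        self.batch_norm_num = nn.BatchNorm1d(6)
        self.layers = nn.Sequential(
            nn.Linear(11, 200), nn.ReLU(inplace=True), nn.BatchNorm1d(200), nn.Dropout(0.4),
            nn.Linear(200, 100), nn.ReLU(inplace=True), nn.BatchNorm1d(100), nn.Dropout(0.4),
            nn.Linear(100, 50), nn.ReLU(inplace=True), nn.BatchNorm1d(50), nn.Dropout(0.4),
            nn.Linear(50, 2),
        )

    def forward(self, x_cat, x_num):
        # Embed each categorical column and concatenate: 2 + 1 + 1 + 1 = 5 dims.
        emb = torch.cat(
            [e(x_cat[:, i]) for i, e in enumerate(self.all_embeddings)], dim=1
        )
        emb = self.embedding_dropout(emb)
        num = self.batch_norm_num(x_num)
        # 5 embedded + 6 numerical = 11 features into the first Linear layer.
        return self.layers(torch.cat([emb, num], dim=1))

model = TabularModel().eval()
x_cat = torch.tensor([[0, 0, 1, 1], [2, 0, 0, 1]])
x_num = torch.randn(2, 6)
print(model(x_cat, x_num).shape)  # torch.Size([2, 2])
```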
Categorical Data Example:
tensor([[0, 0, 1, 1],
[2, 0, 0, 1],
[0, 0, 1, 0],
[0, 0, 0, 0],
[2, 0, 1, 1]])
Numerical Data Example:
tensor([[6.1900e+02, 4.2000e+01, 2.0000e+00, 0.0000e+00, 1.0000e+00, 1.0135e+05],
[6.0800e+02, 4.1000e+01, 1.0000e+00, 8.3808e+04, 1.0000e+00, 1.1254e+05],
[5.0200e+02, 4.2000e+01, 8.0000e+00, 1.5966e+05, 3.0000e+00, 1.1393e+05],
[6.9900e+02, 3.9000e+01, 1.0000e+00, 0.0000e+00, 2.0000e+00, 9.3827e+04],
[8.5000e+02, 4.3000e+01, 2.0000e+00, 1.2551e+05, 1.0000e+00, 7.9084e+04]])
Output Data Example:
tensor([1, 0, 1, 0, 0])
My Failing Attempt at Attribution
interpretable_embedding = configure_interpretable_embedding_layer(model, 'all_embeddings')
# I received a NotImplementedError on the next line:
cat_input_embedding = interpretable_embedding.indices_to_embeddings(categorical_train_data).unsqueeze(0)
ig = IntegratedGradients(model)
ig_attr_train = ig.attribute(
    inputs=(numerical_train_data, categorical_train_data),
    baselines=(numerical_train_data * 0.0, cat_input_embedding),
    target=train_outputs,
    n_steps=50,
)
Issue Analytics
- Created: 3 years ago
- Reactions: 2
- Comments: 9 (5 by maintainers)
Awesome, that is much cleaner. I was planning on refactoring once I understood it, but you’ve nailed it here. Thanks so much @NarineK!
Yeah, I think we can clean things up and make it more modular with something like this:
Here is all you need for interpretability:
I didn’t specify baselines; feel free to specify them too.