
Interpreting Fine-tuned Bert model using LIME

See original GitHub issue

Thanks for this amazing work. I am trying to interpret a fine-tuned BERT model using the Transformers library. There seems to be a tokenization issue when I try to use LIME with BERT. Here is the error I am getting:

Traceback (most recent call last):
  File "src/predict.py", line 351, in <module>
    exp = explainer.explain_instance(s, prediction.predictor, num_features=6)
  File "/home/ramesh/.virtualenvs/transformer-env/lib/python3.6/site-packages/lime/lime_text.py", line 417, in explain_instance
    distance_metric=distance_metric)
  File "/home/ramesh/.virtualenvs/transformer-env/lib/python3.6/site-packages/lime/lime_text.py", line 484, in __data_labels_distances
    labels = classifier_fn(inverse_data)
  File "src/predict.py", line 297, in predictor
    input_ids, input_mask, segment_ids = self.convert_text_to_features(text)
  File "src/predict.py", line 135, in convert_text_to_features
    tokens_a = self.tokenizer.tokenize(text_a)
  File "/home/ramesh/.virtualenvs/transformer-env/lib/python3.6/site-packages/transformers/tokenization_utils.py", line 649, in tokenize
    tokenized_text = split_on_tokens(added_tokens, text)
  File "/home/ramesh/.virtualenvs/transformer-env/lib/python3.6/site-packages/transformers/tokenization_utils.py", line 637, in split_on_tokens
    if sub_text not in self.added_tokens_encoder \
TypeError: unhashable type: 'list'
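Context for the traceback: the labels = classifier_fn(inverse_data) frame shows that LimeTextExplainer perturbs the input and passes a list of perturbed strings to the predictor, so tokenizer.tokenize ends up being called on a list rather than on a single string, which is what raises the TypeError. A minimal sketch of the contract LIME expects, using a stand-in classifier rather than the code from this issue:

import numpy as np
from lime.lime_text import LimeTextExplainer

def dummy_predictor(texts):
    # LIME hands classifier_fn a list of perturbed strings; it must return one
    # probability row per string, i.e. an array of shape (len(texts), num_classes).
    return np.array([[0.5, 0.5] for _ in texts])

explainer = LimeTextExplainer(class_names=[0, 1])
exp = explainer.explain_instance("a sample sentence to explain",
                                 dummy_predictor, num_features=6)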

Here is my code:

def predictor(self, text):

        max_seq_length=128
        input_ids, input_mask, segment_ids = self.convert_text_to_features(text)
        self.model.to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids, input_mask, segment_ids)

        logits = outputs[0]
        logits = F.softmax(logits, dim=1)

        return logits.numpy()

def convert_text_to_features(self, text_a, text_b=None):

        features = []
        cls_token = self.tokenizer.cls_token
        sep_token = self.tokenizer.sep_token
        cls_token_at_end = False
        sequence_a_segment_id = 0
        sequence_b_segment_id = 1
        cls_token_segment_id = 1
        pad_token_segment_id = 0
        mask_padding_with_zero = True
        pad_token = 0
        tokens_a = self.tokenizer.tokenize(text_a)
        tokens_b = None

        self._truncate_seq_pair(tokens_a, self.max_seq_length - 2)

        tokens = tokens_a + [sep_token]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        if tokens_b:
            tokens += tokens_b + [sep_token]
            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)


        tokens = [cls_token] + tokens
        segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
        padding_length = self.max_seq_length - len(input_ids)


        input_ids = input_ids + ([pad_token] * padding_length)
        input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)

        assert len(input_ids) == self.max_seq_length
        assert len(input_mask) == self.max_seq_length
        assert len(segment_ids) == self.max_seq_length

        input_ids = torch.tensor([input_ids], dtype=torch.long).to(self.device)
        input_mask = torch.tensor([input_mask], dtype=torch.long).to(self.device)
        segment_ids = torch.tensor([segment_ids], dtype=torch.long).to(self.device)
        return input_ids, input_mask, segment_ids


if __name__ == '__main__':

    model_path = "models/mrpc"
    bert_model_class = "bert"
    prediction = Prediction(bert_model_class, model_path, lower_case=True, seq_length=128)
    label_names = [0, 1]
    explainer = LimeTextExplainer(class_names=label_names)
    train_df = pd.read_csv("data/train.tsv", sep = '\t')

    for example in train_df["string"]:
        exp = explainer.explain_instance(example, prediction.predictor, num_features=6)
        print(exp.as_list())

I have checked issue #356, but I still cannot figure out my problem.

Any leads would be appreciated.

Thank you 😃

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 12 (1 by maintainers)

Top GitHub Comments

20 reactions
rameshjes commented, Jul 17, 2020

Thanks for your reply. I figured it out. For anyone interested in interpreting BERT using LIME, here is the correct example 😃

class Prediction:

    def __init__(self, bert_model_class, model_path, lower_case, seq_length):

        self.model, self.tokenizer, self.model_config = \
                    self.load_model(bert_model_class, model_path, lower_case)
        self.max_seq_length = seq_length
        self.device = "cpu"
        self.model.to("cpu")

    def load_model(self, bert_model_class, model_path, lower_case):

        config_class, model_class, tokenizer_class = MODEL_CLASSES[bert_model_class]
        config = config_class.from_pretrained(model_path)
        tokenizer = tokenizer_class.from_pretrained(model_path, do_lower_case=lower_case)
        model = model_class.from_pretrained(model_path, config=config)

        return model, tokenizer, config

    def predict_label(self, text_a, text_b):

        self.model.to(self.device)

        input_ids, input_mask, segment_ids = self.convert_text_to_features(text_a, text_b)
        with torch.no_grad():
            outputs = self.model(input_ids, segment_ids, input_mask)

        logits = outputs[0]
        logits = F.softmax(logits, dim=1)
        # print(logits)
        logits_label = torch.argmax(logits, dim=1)
        label = logits_label.detach().cpu().numpy()

        # print("logits label ", logits_label)
        logits_confidence = logits[0][logits_label]
        label_confidence_ = logits_confidence.detach().cpu().numpy()
        # print("logits confidence ", label_confidence_)

        return label, label_confidence_


    def _truncate_seq_pair(self, tokens_a, max_length):
        """Truncates a sequence pair in place to the maximum length."""

        # This is a simple heuristic which will always truncate the longer sequence
        # one token at a time. This makes more sense than truncating an equal percent
        # of tokens from each, since if one sequence is very short then each token
        # that's truncated likely contains more information than a longer sequence.
        while True:
            total_length = len(tokens_a)
            if total_length <= max_length:
                break
            if len(tokens_a) > max_length:
                tokens_a.pop()

    def convert_text_to_features(self, text_a, text_b=None):

        features = []
        cls_token = self.tokenizer.cls_token
        sep_token = self.tokenizer.sep_token
        cls_token_at_end = False
        sequence_a_segment_id = 0
        sequence_b_segment_id = 1
        cls_token_segment_id = 1
        pad_token_segment_id = 0
        mask_padding_with_zero = True
        pad_token = 0
        tokens_a = self.tokenizer.tokenize(text_a)
        tokens_b = None

        self._truncate_seq_pair(tokens_a, self.max_seq_length - 2)

        tokens = tokens_a + [sep_token]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        if tokens_b:
            tokens += tokens_b + [sep_token]
            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)


        tokens = [cls_token] + tokens
        segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
        #
        # # Zero-pad up to the sequence length.
        padding_length = self.max_seq_length - len(input_ids)


        input_ids = input_ids + ([pad_token] * padding_length)
        input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)

        assert len(input_ids) == self.max_seq_length
        assert len(input_mask) == self.max_seq_length
        assert len(segment_ids) == self.max_seq_length

        input_ids = torch.tensor([input_ids], dtype=torch.long).to(self.device)
        input_mask = torch.tensor([input_mask], dtype=torch.long).to(self.device)
        segment_ids = torch.tensor([segment_ids], dtype=torch.long).to(self.device)

        return input_ids, input_mask, segment_ids

    def predictor(self, text):

        examples = []
        print(text)
        for example in text:
            examples.append(self.convert_text_to_features(example))

        results = []
        for example in examples:

            with torch.no_grad():
                outputs = self.model(example[0], example[1], example[2])
            logits = outputs[0]
            logits = F.softmax(logits, dim = 1)
            results.append(logits.cpu().detach().numpy()[0])

        results_array = np.array(results)

        return results_array



if __name__ == '__main__':

    model_path = "models/mrpc"
    bert_model_class = "bert"
    prediction = Prediction(bert_model_class, model_path,
                                lower_case = True, seq_length = 512)
    label_names = [0, 1]
    explainer = LimeTextExplainer(class_names=label_names)
    train_df = pd.read_csv("data/train.tsv", sep = '\t')

    train_ls = train_df["string"].tolist()

    for example in train_ls:

        exp = explainer.explain_instance(example, prediction.predictor)
        words = exp.as_list()
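The key change from the original predictor is that it now takes the list of perturbed strings LIME passes in, runs the model on each one, and stacks the per-text softmax outputs into a NumPy array of shape (num_texts, num_classes), which is exactly what LimeTextExplainer expects from classifier_fn. One caveat worth checking: predict_label passes (input_ids, segment_ids, input_mask) positionally while predictor passes (input_ids, input_mask, segment_ids); in recent transformers releases the positional order of BertForSequenceClassification.forward is (input_ids, attention_mask, token_type_ids), so it is safer to pass these as keyword arguments and match your installed version. On a recent release the per-example loop can also be replaced by a single batched call; a hedged sketch of a hypothetical predictor_batched method (not part of the thread's code), assuming transformers 3.x or later where the tokenizer is callable:

import torch
import torch.nn.functional as F

# Hypothetical drop-in replacement for Prediction.predictor (place it inside
# the Prediction class); requires a transformers version with callable tokenizers.
def predictor_batched(self, texts):
    # Tokenize the whole list LIME passes in one shot; padding/truncation
    # replaces the manual convert_text_to_features bookkeeping.
    enc = self.tokenizer(list(texts), padding=True, truncation=True,
                         max_length=self.max_seq_length, return_tensors="pt")
    enc = {k: v.to(self.device) for k, v in enc.items()}
    with torch.no_grad():
        logits = self.model(**enc)[0]     # logits for the whole batch
    probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()            # shape: (len(texts), num_classes)

With LIME's default of 5,000 perturbed samples per explanation, the per-example loop above means 5,000 single-sentence forward passes, so batching speeds each explain_instance call up considerably.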

5 reactions
Elizabithi1-dev commented, Oct 18, 2020

What is “MODEL_CLASSES” in your code?
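MODEL_CLASSES is not defined in the snippet above; in the Transformers example scripts (for instance run_glue.py) it is conventionally a dict mapping a model-type string to its config, model, and tokenizer classes. A minimal sketch covering only the "bert" entry used here:

from transformers import BertConfig, BertForSequenceClassification, BertTokenizer

# Only the "bert" entry needed by load_model() above is shown; the example
# scripts list more model families.
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
}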

