Unable to use predict() for sentence classification task
I have been trying to use the predict() method to make predictions on a data sample with a pre-trained XLNet model that has been fine-tuned on a small sample of data specific to my task. Upon execution, an IndexError is thrown saying the list index is out of range (see below).
Code Snippet:
model = ClassificationModel('bert', 'bert-base-uncased', use_cuda=True)
model.train_model(train_df)
valdn_set = valdn_df.drop("labels", axis=1).values.tolist()
pred_labels, _ = model.predict(valdn_set)
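For what it's worth, the IndexError is consistent with what `.values.tolist()` returns when only one text column remains after the drop: a list of single-element lists, so predict() takes the list branch and indexes `text[1]`, which does not exist. A minimal reproduction without simpletransformers (the sentences below are placeholders, not my actual data):

```python
# Mimic what valdn_df.drop("labels", axis=1).values.tolist() yields
# when a single text column remains: a list of one-element lists.
to_predict = [["first validation sentence"], ["second validation sentence"]]

if isinstance(to_predict[0], list):
    # predict() takes this branch and indexes text[1], which is missing here
    try:
        pairs = [(text[0], text[1]) for text in to_predict]
    except IndexError as exc:
        print("IndexError:", exc)  # same failure as in the traceback

# Flattening to a plain list of strings avoids the list branch entirely
to_predict_flat = [row[0] for row in to_predict]
print(to_predict_flat)
```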
Error:
Running loss: 0.709356Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 32768.0 | 0/1 [00:00<?, ?it/s]
Running loss: 0.677126Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 16384.0 | 8/748 [00:06<10:31, 1.17it/s]
Running loss: 0.457629Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8192.0 | 25/748 [00:21<10:11, 1.18it/s]
Running loss: 0.558106Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4096.0 | 104/748 [01:27<09:05, 1.18it/s]
Current iteration: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 748/748 [10:32<00:00, 1.18it/s]
Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [10:33<00:00, 633.39s/it]
Training of bert model complete. Saved to outputs/.
Traceback (most recent call last):
File "zero_threshold.py", line 83, in <module>
main()
File "zero_threshold.py", line 64, in main
pred_labels, model_outputs = model.predict(valdn_set)
File "/.local/lib/python3.6/site-packages/simpletransformers/classification/classification_model.py", line 821, in predict
eval_examples = [InputExample(i, text[0], text[1], 0) for i, text in enumerate(to_predict)]
File "/.local/lib/python3.6/site-packages/simpletransformers/classification/classification_model.py", line 821, in <listcomp>
eval_examples = [InputExample(i, text[0], text[1], 0) for i, text in enumerate(to_predict)]
IndexError: list index out of range
As per the documentation, predict() expects a list of strings, and I did convert my dataframe column to one in the code above. To understand what was causing the issue, I looked into code lines #812 to #820 (attached below) in simpletransformers/classification/classification_model.py and noticed that the failure stems from how the isinstance() condition is evaluated.
simpletransformers/classification/classification_model.py:
if multi_label:
    eval_examples = [
        InputExample(i, text, None, [0 for i in range(self.num_labels)])
        for i, text in enumerate(to_predict)
    ]
else:
    if isinstance(to_predict[0], list):
        eval_examples = [InputExample(i, text[0], text[1], 0) for i, text in enumerate(to_predict)]
    else:
        eval_examples = [InputExample(i, text, None, 0) for i, text in enumerate(to_predict)]
From what I understood of the existing code: the first element of the to_predict variable (a list) is checked to see whether it is itself a list. If it is a list, that implies the data represents a sentence classification task where each input fed to the model is a single sentence. If it is not a list, that indicates data for a sentence pair classification task where each input holds more than one sentence. Am I correct in understanding the logic behind the evaluation of the isinstance() condition?
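As a sanity check against the quoted code, the two input shapes that the isinstance() check distinguishes can be sketched as follows (the example sentences are made up for illustration):

```python
# Single-sentence classification: a flat list of strings.
single = ["this is one sentence", "and another one"]

# Sentence-pair classification: each element is a [text_a, text_b] pair,
# matching the text[0], text[1] indexing in the quoted branch.
pairs = [["premise one", "hypothesis one"],
         ["premise two", "hypothesis two"]]

# The check predict() performs on the first element:
print(isinstance(single[0], list))  # False -> (text, None) branch
print(isinstance(pairs[0], list))   # True  -> (text[0], text[1]) branch
```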
If the above explanation I provided is correct, then I think the statements to be executed based on the if condition need to be reversed as follows:
if isinstance(to_predict[0], list):
    # do when True (indicating data for sentence classification)
    eval_examples = [InputExample(i, text, None, 0) for i, text in enumerate(to_predict)]
else:
    # do when False (indicating data for sentence pair classification)
    eval_examples = [InputExample(i, text[0], text[1], 0) for i, text in enumerate(to_predict)]
Am I correct in thinking this through, or am I missing something that's implied? Please help. Thanks in advance.
Issue Analytics
- State:
- Created 4 years ago
- Comments: 11 (4 by maintainers)
Top GitHub Comments
Yes, all set - predict() has been working, I just had to set "use_cached_eval_features" to False. Thanks!
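For anyone hitting this later, the fix above amounts to passing the cache-related flags through the model's args dict. A sketch (arg names as mentioned in this thread; verify them against your simpletransformers version):

```python
# Hedged sketch: disable cached eval features so predict() re-converts
# features instead of loading a stale cache from a previous run.
model_args = {
    "use_cached_eval_features": False,  # do not reuse cached eval features
    "no_cache": True,                   # skip writing the feature cache
}

# model = ClassificationModel("bert", "bert-base-uncased",
#                             args=model_args, use_cuda=True)
print(model_args)
```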
Yeah, this can happen if cached features are used.
Starting to wonder if cache should be turned off by default. I think I kept it on by default before implementing multiprocessing for feature conversion as it could take hours to convert. Maybe caching is more trouble than it is worth now (at least when turned on by default).