Issue training Longformer
Hello, apologies if this is the wrong place to ask for help. I'm currently trying to fine-tune Longformer on a text classification task; my script is below.
When I freeze the encoder with

    for param in model.longformer.encoder.parameters():
        param.requires_grad = False

so that only the classification head and the embeddings are trained, training works as expected. When I don't freeze the encoder layers, the model doesn't train at all, and when I run inference on it, it produces constant output regardless of what data I feed in. I've been reading through the papers trying to find what I've done wrong; can anyone point me in the right direction? Thank you so much for your help! Tom
    import logging

    import pandas as pd
    from transformers import AdamW, LongformerTokenizerFast, TrainingArguments, Trainer, LongformerForSequenceClassification
    import torch
    from torch.utils.data import DataLoader
    import numpy as np
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


    def compute_metrics(pred):
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        # calculate the metrics using sklearn's functions
        acc = accuracy_score(labels, preds)
        f1 = f1_score(labels, preds)
        precision = precision_score(labels, preds)
        recall = recall_score(labels, preds)
        return {
            'accuracy': acc,
            'f1': f1,
            'precision': precision,
            'recall': recall
        }


    class SupremeDataset(torch.utils.data.Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels

        def __getitem__(self, idx):
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            item['labels'] = torch.tensor(self.labels[idx])
            return item

        def __len__(self):
            return len(self.labels)


    def main():
        # Setup logging:
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO,
        )

        logging.info("*** Data processing ***")
        logging.info("importing data")
        data_train = pd.read_csv("../../../shared/benchmarking/supreme_train.csv").dropna()
        data_val = pd.read_csv("../../../shared/benchmarking/supreme_val.csv").dropna()

        tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-base-4096')
        logging.info("tokenizing data")
        train_encodings = tokenizer(list(data_train.content_decode), truncation=True, padding=True, return_tensors="pt")
        val_encodings = tokenizer(list(data_val.content_decode), truncation=True, padding=True, return_tensors="pt")

        # Local attention everywhere, global attention only on the <s> token (token id 0 for this tokenizer)
        train_encodings['global_attention_mask'] = torch.zeros_like(train_encodings['input_ids'])
        val_encodings['global_attention_mask'] = torch.zeros_like(val_encodings['input_ids'])
        train_encodings['global_attention_mask'][train_encodings['input_ids'] == 0] = 1
        val_encodings['global_attention_mask'][val_encodings['input_ids'] == 0] = 1

        train_labels = data_train.label.tolist()
        val_labels = data_val.label.tolist()

        logging.info("creating datasets")
        train_dataset = SupremeDataset(train_encodings, train_labels)
        val_dataset = SupremeDataset(val_encodings, val_labels)

        logging.info("*** Training ***")
        training_args = TrainingArguments(
            output_dir='./results',            # output directory
            num_train_epochs=3,                # total number of training epochs
            per_device_train_batch_size=1,     # batch size per device during training
            per_device_eval_batch_size=1,      # batch size for evaluation
            warmup_steps=500,                  # number of warmup steps for learning rate scheduler
            weight_decay=0.01,                 # strength of weight decay
            logging_dir='./logs',              # directory for storing logs
            logging_steps=200,
            do_eval=True,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            evaluation_strategy="steps",
        )

        logging.info("loading model")
        model = LongformerForSequenceClassification.from_pretrained('allenai/longformer-base-4096')
        # Freeze the encoder so that only the embeddings and the classification head are updated
        for param in model.longformer.encoder.parameters():
            param.requires_grad = False

        logging.info("loading trainer")
        trainer = Trainer(
            model=model,                       # the instantiated 🤗 Transformers model to be trained
            args=training_args,                # training arguments, defined above
            train_dataset=train_dataset,       # training dataset
            eval_dataset=val_dataset,          # evaluation dataset
            compute_metrics=compute_metrics
        )

        logging.info("starting training")
        trainer.train()
        torch.save(model, 'supremecourt_fullmodel.pt')


    if __name__ == "__main__":
        main()
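Not part of the original script, but a quick sanity check along the lines of the freezing experiment described above: counting how many parameters actually receive gradient updates makes the frozen and unfrozen runs easier to compare. This is a minimal sketch and assumes the model variable from the script.

    # Sketch: report how many parameters are trainable
    def count_trainable_parameters(m):
        trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
        total = sum(p.numel() for p in m.parameters())
        print(f"trainable parameters: {trainable:,} of {total:,}")

    count_trainable_parameters(model)  # run once after the freezing loop and once without it

If the full model really is being updated in the unfrozen run, a smaller learning rate (for example learning_rate=1e-5 to 3e-5 in TrainingArguments) combined with gradient_accumulation_steps is a common adjustment for a per-device batch size of 1; that is a general fine-tuning heuristic rather than something specific to Longformer.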
Top GitHub Comments
Hmm, at first glance this sounds to me like the classic case of overfitting to one class; I'm not so sure whether it is actually due to using Longformer.
Some tips:
Hope this is somewhat helpful!
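A rough way to check for the single-class collapse mentioned above (a sketch only; it assumes the trainer and val_dataset objects from the script in the issue, with integer class labels):

    import numpy as np

    # Compare the distribution of predicted classes with the true labels on the validation set
    pred_output = trainer.predict(val_dataset)
    preds = pred_output.predictions.argmax(-1)
    labels = pred_output.label_ids
    print("label counts:     ", np.bincount(labels))
    print("prediction counts:", np.bincount(preds))
    # If every prediction lands in one class while the labels are mixed,
    # the model has collapsed to the majority class rather than learned the task.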
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.