hyperparameter search requirements/gpt2 metric
I am trying to fine-tune GPT-2 to respond to certain prompts in a specific way, so I am training it on strings of the form prompt + "someDivider" + output.
I have about 1300 samples of training data, and I want to use hyperparameter_search to pick decent hyperparameters. I'm not sure whether this requires a validation set and, if it does, which metric I have to supply. Do I even need a metric?
I'm also not sure what to do with the output of hyperparameter_search.
I've tried researching this but haven't gotten far. I am relatively new to training AI models.
- ray/raytune: @richardliaw, @amogkam
- gpt2: @patrickvonplaten, @LysandreJik
import csv
import io
import random

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset, load_metric
from google.colab import files
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
    get_linear_schedule_with_warmup,
)
#version of gpt we use
model_version = 'gpt2-medium'
#create the dataset
raw_datasets = load_dataset('csv', data_files=['train.csv'])
print(raw_datasets)
print(raw_datasets["train"][1])
#initialize tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained(model_version)
tokenizer.pad_token = tokenizer.unk_token #GPT-2 has no pad token by default; reusing the unk token prevents the missing-pad-token error
#function called by trainer to initialize model. I'm doing it this way so the hyperparameters can be tuned
def model_init():
    return GPT2LMHeadModel.from_pretrained(model_version)
#helper for tokenizing everything
def tokenize_function(examples):
    return tokenizer(examples["triplet"], truncation=True)
#tokenize all our data
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
#gets rid of original string data
tokenized_datasets=tokenized_datasets.remove_columns(["triplet"])
print(tokenized_datasets)
print(tokenized_datasets["train"]["input_ids"][1])
#collate data
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
#training args (you can control hyperparameters from here)
training_args = TrainingArguments(
    output_dir="Finetuned",
    overwrite_output_dir=True,
    prediction_loss_only=True,  #TODO if you are going to use a validation set and compute metrics, this must be False, but since we aren't, I set it to True
)
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
# Automatically finds good hyperparameters, you can pass arguments into it but idk if I want to mess with them rn
trainer.hyperparameter_search()
trainer.train()
trainer.save_model()
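A note on the questions above, based on how the Trainer documentation describes it: `hyperparameter_search` scores each trial by running evaluation, so the `Trainer` needs an `eval_dataset`. When no `compute_metrics` function is given, the default objective is the evaluation loss, so no extra metric should be required. The call returns a `BestRun` whose `hyperparameters` dict has to be copied onto `trainer.args` before the final `train()`; calling `train()` right after the search does not pick them up on its own. The sketch below shows only that bookkeeping in plain Python. The helper names `objective_from_metrics` and `apply_best_hyperparameters`, the stand-in `args` object, and all numbers are hypothetical; the `Trainer` calls appear as comments because they need the model and data.

```python
import math
from types import SimpleNamespace


def objective_from_metrics(metrics):
    """Return the value the search should minimize: the validation loss."""
    return metrics["eval_loss"]


def perplexity(eval_loss):
    """Perplexity is exp(mean cross-entropy loss); handy for reading results."""
    return math.exp(eval_loss)


def apply_best_hyperparameters(args, hyperparameters):
    """Copy the winning trial's values onto an arguments object in place."""
    for name, value in hyperparameters.items():
        if not hasattr(args, name):
            raise AttributeError(f"unknown training argument: {name}")
        setattr(args, name, value)
    return args


# How this would be wired into the script above (assumed, not run here):
#   split = tokenized_datasets["train"].train_test_split(test_size=0.1)
#   trainer = Trainer(model_init=model_init, args=training_args,
#                     train_dataset=split["train"], eval_dataset=split["test"],
#                     data_collator=data_collator, tokenizer=tokenizer)
#   best_run = trainer.hyperparameter_search(
#       direction="minimize", compute_objective=objective_from_metrics, n_trials=10)
#   apply_best_hyperparameters(trainer.args, best_run.hyperparameters)
#   trainer.train()

# Stand-ins with made-up values so the helpers can be exercised on their own.
args = SimpleNamespace(learning_rate=5e-5, num_train_epochs=3, output_dir="Finetuned")
best_hyperparameters = {"learning_rate": 3e-5, "num_train_epochs": 4}
apply_best_hyperparameters(args, best_hyperparameters)
print(args.learning_rate, args.num_train_epochs)
```

With only about 1300 samples, a 10% hold-out leaves roughly 130 validation examples, so the objective will be noisy; the split size and `n_trials=10` here are illustrative choices, not recommendations from the thread.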
Issue Analytics
- Created 2 years ago
- Comments:8 (1 by maintainers)
Top GitHub Comments
Ya I figured it out, thank you
I believe so