Add minimal unsloth GRPO example

This commit is contained in:
Oliver 2025-02-20 22:03:40 +00:00
parent e25973b118
commit 8bd18f5387

View file

@@ -0,0 +1,154 @@
"""
Minimal example using Unsloth and vLLM for efficient GRPO training of a model with (Q)LoRA.
Adapted from Unsloth's documentation examples.
"""
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)
import argparse
import re
from torch.utils.data import Dataset
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
import reasoning_gym
from reasoning_gym import utils
class ReasoningGymDataset(Dataset):
    """Torch dataset wrapping a procedurally generated reasoning-gym dataset.

    Each item is rendered into a single chat-template prompt string; the raw
    reasoning-gym entry is returned alongside it as ``metadata`` so reward
    functions can score completions against the original problem.
    """

    def __init__(self, dataset_name, seed, size, tokenizer, developer_prompt, developer_role="system") -> None:
        super().__init__()
        # Deterministic for a given (dataset_name, seed, size) triple.
        self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
        self.tokenizer = tokenizer
        self.developer_role = developer_role
        self.developer_prompt = developer_prompt

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        messages = []
        # A developer_role of None disables the system/developer message entirely.
        if self.developer_role is not None:
            messages.append({"role": self.developer_role, "content": self.developer_prompt})
        messages.append({"role": "user", "content": entry["question"]})
        rendered = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return {"prompt": rendered, "metadata": entry}
def get_model_and_tokenizer(model_id, max_seq_length, lora_rank, quantize, gpu_memory_utilization) -> tuple:
    """Load a base model with vLLM fast inference and attach LoRA adapters.

    Args:
        model_id: Hugging Face model identifier to load.
        max_seq_length: maximum sequence length for the model.
        lora_rank: LoRA rank; also reused as ``lora_alpha``.
        quantize: when True, load the base weights in 4-bit.
        gpu_memory_utilization: fraction of GPU memory vLLM may use.

    Returns:
        (peft_model, tokenizer) tuple.
    """
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        max_lora_rank=lora_rank,
        gpu_memory_utilization=gpu_memory_utilization,
        load_in_4bit=quantize,
        fast_inference=True,
    )
    # LoRA on every attention and MLP projection layer.
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=lora_rank,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=lora_rank,
        use_gradient_checkpointing="unsloth",
    )
    return peft_model, tokenizer
class GRPOTrainerCustom(GRPOTrainer):
    """GRPOTrainer specialization wired to reasoning-gym rewards.

    Two reward functions are registered:
      * accuracy — extract the answer and score it with the dataset's scorer
      * format   — 1.0 only for a well-formed <think>...</think> / <answer>...</answer> layout
    """

    def __init__(self, model, args: GRPOConfig, tokenizer, train_dataset: Dataset):
        super().__init__(
            model,
            reward_funcs=[self._accuracy_reward, self._format_reward],
            args=args,
            train_dataset=train_dataset,
            processing_class=tokenizer,
        )

    def _format_reward(self, completions, **kwargs):
        """Binary reward for exact think/answer tag structure."""
        pattern = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
        return [1.0 if re.match(pattern, text, flags=re.DOTALL) else 0.0 for text in completions]

    def _accuracy_reward(self, completions, metadata, **kwargs):
        """Score each completion's extracted answer against its dataset entry."""
        scores = []
        for completion, entry in zip(completions, metadata):
            extracted = utils.extract_answer(completion)
            scores.append(self.train_dataset.data.score_answer(extracted, entry=entry))
        return scores
def train(model, tokenizer, dataset, training_args):
    """Construct the custom GRPO trainer and run a full training pass."""
    GRPOTrainerCustom(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=dataset,
    ).train()
def main(args):
    """Entry point: load model + tokenizer, build the dataset, configure GRPO, and train."""
    model, tokenizer = get_model_and_tokenizer(
        args.model_id,
        args.max_seq_length,
        args.lora_rank,
        args.quantize,
        args.gpu_memory_utilization,
    )
    developer_prompt = utils.SYSTEM_PROMPTS["DeepSeekZero"]
    dataset = ReasoningGymDataset(
        args.dataset_name,
        args.dataset_seed,
        args.dataset_size,
        tokenizer,
        developer_prompt,
    )
    # Prefer bf16 where the hardware supports it, else fall back to fp16.
    use_bf16 = is_bfloat16_supported()
    training_args = GRPOConfig(
        output_dir="outputs",
        use_vllm=True,  # generations are served by the vLLM fast-inference engine
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        optim="adamw_8bit",
        logging_steps=1,
        bf16=use_bf16,
        fp16=not use_bf16,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        num_generations=args.num_generations,
        num_train_epochs=args.train_epochs,
        save_steps=100,
        max_grad_norm=0.1,
    )
    train(model, tokenizer, dataset, training_args)
if __name__ == "__main__":
    # CLI for the minimal Unsloth GRPO training example.
    parser = argparse.ArgumentParser(description="Minimal Unsloth + vLLM GRPO training on a reasoning-gym dataset.")
    parser.add_argument("--model-id", type=str, default="Qwen/Qwen2.5-1.5B-Instruct", help="Hugging Face model id to fine-tune")
    # Fix: this argument had no default, so omitting it previously passed None
    # through to reasoning_gym.create_dataset and failed with an obscure error.
    # required=True makes argparse reject the invocation with a clear message.
    parser.add_argument("--dataset-name", type=str, required=True, help="reasoning-gym dataset name to train on")
    parser.add_argument("--max-seq-length", type=int, default=1024, help="maximum model sequence length")
    parser.add_argument("--lora-rank", type=int, default=64, help="LoRA rank (also used as lora_alpha)")
    parser.add_argument("--quantize", action="store_true", help="load base weights in 4-bit")
    parser.add_argument("--num-generations", type=int, default=8, help="GRPO generations per prompt")
    parser.add_argument("--train-epochs", type=int, default=1, help="number of training epochs")
    parser.add_argument("--dataset-seed", type=int, default=42, help="seed for procedural dataset generation")
    parser.add_argument("--dataset-size", type=int, default=1000, help="number of generated dataset items")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.7, help="fraction of GPU memory vLLM may use")
    args = parser.parse_args()
    main(args)