Feat/open instruct example (#381)

* added open-instruct

* fixed hooks

* GRPO

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
joesharratt1229 2025-03-17 22:20:11 +00:00 committed by GitHub
parent 1511c5e301
commit 9234aa77bf
5 changed files with 629 additions and 0 deletions

@@ -0,0 +1,34 @@
# Reasoning Gym Open-Instruct Example
This example demonstrates how to use [open-instruct](https://github.com/allenai/open-instruct) with **reasoning-gym** for training language models through reinforcement learning.
## Environment Setup
Before running the training script, you may want to set the following environment variables:
```bash
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
```
- `CUDA_VISIBLE_DEVICES=0` restricts training to the specified GPUs (in this case, GPU 0)
- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` enables dynamic GPU memory allocation
## Running the Training Script
To start training, run:
```bash
bash grpo_config.sh
```
This will train a DeepSeek-R1-Distill-Qwen-1.5B model using GRPO (Group Relative Policy Optimization) on reasoning tasks.
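
GRPO normalizes each response's reward against the other samples drawn for the same prompt; the trainer implements the standard group-relative advantage:

$$A_i = \frac{r_i - \operatorname{mean}(r_1, \dots, r_G)}{\operatorname{std}(r_1, \dots, r_G) + 10^{-8}}$$
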
## Key Features
- Uses vLLM for efficient inference
- Supports multi-GPU training
- Implements reward functions for both answer correctness and response formatting (see the sketch below)
- Includes evaluation during training
- Supports model checkpointing and pushing to Hugging Face Hub
- Integrates with Weights & Biases for experiment tracking
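
The total reward for each response sums a correctness score from reasoning-gym with a soft formatting score. Here is a minimal sketch of the correctness half, assuming the default `chain_sum` task (`extract_answer` and `score_answer` are the same reasoning-gym utilities the trainer calls):

```python
import reasoning_gym
from reasoning_gym.utils import extract_answer

dataset = reasoning_gym.create_dataset("chain_sum", seed=42, size=10)
entry = dataset[0]

# The chat template asks the model to wrap its final answer in <answer> tags
response = "<think>2 + 3 = 5</think><answer>5</answer>"
answer = extract_answer(response)            # -> "5"
score = dataset.score_answer(answer, entry)  # 1.0 for a correct answer
print(entry["question"], answer, score)
```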

@@ -0,0 +1,34 @@
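# Launches GRPO training on a reasoning-gym task via open-instruct.
# Assumes the script is run from the example root so that `src/grpo_trainer.py` resolves.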
exp_name="0302_qwen2.5_math_grpo_fast1_${RANDOM}"
python src/grpo_trainer.py \
--model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
--beta 0.0 \
--num_unique_prompts_rollout 16 \
--num_samples_per_prompt_rollout 8 \
--output_dir ./open_instruct_checkpoints/$exp_name \
--save_freq 40 \
--kl_estimator kl3 \
--learning_rate 5e-7 \
--max_token_length 256 \
--max_prompt_token_length 128 \
--response_length 256 \
--pack_length 512 \
--stop_strings '"</answer>"' \
--chat_template_name r1_simple_chat_postpend_think \
--temperature 1.0 \
--total_episodes 100000 \
--deepspeed_stage 3 \
--per_device_train_batch_size 1 \
--num_mini_batches 1 \
--num_learners_per_node 1 \
--num_epochs 1 \
--vllm_tensor_parallel_size 1 \
--vllm_num_engines 1 \
--vllm_enforce_eager true \
--lr_scheduler_type linear \
--seed 1 \
--num_evals 100 \
--dataset_name chain_sum \
--gradient_checkpointing \
--with_tracking \
--single_gpu_mode false \
--gather_whole_model false

@@ -0,0 +1,457 @@
import json
import os
import shutil
import threading
import time
from argparse import Namespace
from dataclasses import asdict, dataclass
from queue import Queue
from typing import Callable, List
import numpy as np
import ray
import torch
from open_instruct.ground_truth_utils import soft_format_reward_func
from open_instruct.grpo_fast import (
Args,
ModelGroup,
PolicyTrainerRayProcess,
Timer,
calculate_runtime_args,
collate_fn,
vllm_generate_thread,
)
from open_instruct.model_utils import ModelConfig, push_folder_to_hub
from open_instruct.utils import ArgumentParserPlus, get_wandb_tags, is_beaker_job, maybe_get_beaker_config
from open_instruct.vllm_utils2 import create_vllm_engines
from ray.util.placement_group import placement_group
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, PreTrainedTokenizer
from utils import ReasoningGymDataset, list_preserving_collate, pack_sequences
from vllm import SamplingParams
from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer


@dataclass
class GRPOArgs(Args):
dataset_name: str = "chain_sum"
seed: int = 42
size: int = 10000
eval_seed: int = 42
eval_size: int = 100


def data_preparation_thread(
reward_fn: Callable,
inference_results_Q: Queue,
packed_sequences_Q: Queue,
queries_prompt_Q: Queue,
args: Args,
tokenizer: PreTrainedTokenizer,
num_training_steps: int,
reasoning_gym_dataset: ReasoningGymDataset,
):
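    """Score vLLM generations, compute group-normalized GRPO advantages, and
    pack them into fixed-length batches for the learner processes."""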
for training_step in range(1, num_training_steps + 1):
# Get next batch of prompts and responses
observations = queries_prompt_Q.get()
queries, items = observations
        if args.num_samples_per_prompt_rollout > 1:
            # repeat items alongside queries so each of the
            # num_samples_per_prompt_rollout responses can be scored against its entry
            queries = [query for query in queries for _ in range(args.num_samples_per_prompt_rollout)]
            items = [item for item in items for _ in range(args.num_samples_per_prompt_rollout)]
with Timer("🚀 [Data Preparation Thread] Getting response ids"):
responses, finish_reasons = inference_results_Q.get()
for i in range(len(finish_reasons)):
if finish_reasons[i] == "stop" and responses[i][-1] != tokenizer.eos_token_id:
responses[i].append(tokenizer.eos_token_id)
with Timer("📦 [Data Preparation Thread] Packing sequences"):
packed_sequences = pack_sequences(
queries=queries,
responses=responses,
pack_length=args.pack_length,
pad_token_id=tokenizer.pad_token_id,
)
num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)
with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
decoded_responses = tokenizer.batch_decode(responses, skip_special_tokens=True)
stop_rate = sum(int(finish_reason == "stop") for finish_reason in finish_reasons) / len(finish_reasons)
with Timer("💰 [Data Preparation Thread] Calculating rewards"):
scores, reward_metrics = reward_fn(decoded_responses, items, reasoning_gym_dataset)
with Timer("🎆 [Data Preparation Thread] Calculating advantages"):
# Calculate advantages
scores = np.array(scores)
print(f"[Data Preparation Thread] {len(scores)=}")
scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout)
global_mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
global_mean_grouped_rewards = np.repeat(
global_mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
)
global_std_grouped_rewards = scores_per_prompt.std(axis=-1)
global_std_grouped_rewards = np.repeat(
global_std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
)
global_advantages = (scores - global_mean_grouped_rewards) / (global_std_grouped_rewards + 1e-8)
# Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value
# and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed)
lookup_advantages = np.zeros(len(global_advantages) + 1, dtype=np.float32)
lookup_advantages[1:] = global_advantages
packed_advantages = [
torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32)
for packed_mask in packed_sequences.response_masks
]
packed_sequences.advantages = packed_advantages
with Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
B = (
len(packed_sequences.query_responses) // args.world_size
) # essentially doing `drop_last=True`, which is fine.
collated_data = []
for i in range(args.world_size):
per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)]
per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)]
per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)]
per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)]
per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)]
# Shuffle the batch and collate the data
b_inds = np.random.permutation(len(per_device_packed_query_responses))
collated_query_responses = []
collated_attention_masks = []
collated_position_ids = []
collated_response_masks = []
collated_advantages = []
for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size):
micro_range = b_inds[j : j + args.per_device_train_batch_size]
collated_query_responses.append(
collate_fn(
[per_device_packed_query_responses[idx] for idx in micro_range], tokenizer.pad_token_id
)
)
collated_attention_masks.append(
collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0)
)
collated_position_ids.append(
collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0)
)
collated_response_masks.append(
collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0)
)
collated_advantages.append(
collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0)
)
collated_data.append(
{
"collated_query_responses": collated_query_responses,
"collated_attention_masks": collated_attention_masks,
"collated_position_ids": collated_position_ids,
"collated_advantages": collated_advantages,
"collated_response_masks": collated_response_masks,
}
)
sequence_lengths = np.array([len(response) for response in responses])
        metrics = {
            "scores": scores.mean(),
            "val/sequence_lengths": sequence_lengths.mean(),
            "val/sequence_lengths_min": sequence_lengths.min(),
            "val/sequence_lengths_max": sequence_lengths.max(),
            "val/stop_rate": stop_rate,
            # reward_metrics values are per-response lists; log their means as scalars
            **{key: float(np.mean(value)) for key, value in reward_metrics.items()},
        }
if args.save_traces:
traces = {
"scores": scores.tolist(),
"finish_reasons": finish_reasons,
"responses": responses,
"queries": queries,
"ground_truths": items["answer"],
"datasets": reasoning_gym_dataset.dataset_name,
"training_step": training_step,
**reward_metrics,
}
os.makedirs(args.output_dir, exist_ok=True)
with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f:
json.dump(traces, f)
f.write("\n")
packed_sequences_Q.put(
{
"packed_sequences": packed_sequences, # for debugging purposes
"collated_data": collated_data,
"metrics": metrics,
"responses_count": len(responses),
"num_new_tokens": num_new_tokens,
"B": B,
}
)


def reward_fn(responses, items, reasoning_gym_dataset):
    """Score decoded responses for answer correctness and response formatting."""
    metrics = {}
    formatted_responses = [extract_answer(response) for response in responses]
    correctness_rewards = [
        reasoning_gym_dataset.data.score_answer(formatted_response, entry)
        for formatted_response, entry in zip(formatted_responses, items)
    ]
    format_rewards = soft_format_reward_func(responses)
    # sum the two components per response; list `+` would concatenate, not add
    rewards = [correctness + fmt for correctness, fmt in zip(correctness_rewards, format_rewards)]
    metrics["correctness_rewards"] = correctness_rewards
    metrics["format_rewards"] = format_rewards
    metrics["rewards"] = rewards
    return rewards, metrics


def main(args: Args, model_config: ModelConfig, reward_fn: Callable):
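    """Spin up the Ray learner group and vLLM engines, then run the async GRPO loop."""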
calculate_runtime_args(args)
# Setup experiment tracking and seeds
all_configs = {}
beaker_config = None
if is_beaker_job():
beaker_config = maybe_get_beaker_config()
all_configs.update(vars(beaker_config))
all_configs.update(**asdict(args), **asdict(model_config))
if args.with_tracking:
import wandb
wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=all_configs,
name=args.run_name,
save_code=True,
tags=[args.exp_name] + get_wandb_tags(),
)
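    # TensorBoard writer for scalar metrics; with sync_tensorboard=True above,
    # wandb mirrors everything written here (the log directory is our choice)
    writer = SummaryWriter(f"{args.output_dir}/runs/{args.run_name}")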
# Setup tokenizer and get datasets
tokenizer_revision = (
model_config.model_revision if model_config.tokenizer_revision is None else model_config.tokenizer_revision
)
tokenizer_name = (
model_config.tokenizer_name if model_config.tokenizer_name is not None else model_config.model_name_or_path
)
if tokenizer_revision != model_config.model_revision:
# Warn user if tokenizer and model use different revisions; this is an unusual
# use case.
warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different
from the model revision `{model_config.model_revision}`."""
print(warning)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision)
reasoning_gym_dataset = ReasoningGymDataset(
args.dataset_name, args.seed, args.size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
)
eval_dataset = ReasoningGymDataset(
args.dataset_name, args.eval_seed, args.eval_size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
)
bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node]
pg = placement_group(bundles, strategy="PACK")
ray.get(pg.ready())
inits = []
policy_group = ModelGroup(
pg,
PolicyTrainerRayProcess,
args.num_learners_per_node,
args.single_gpu_mode,
)
wandb_url = wandb.run.get_url() if args.with_tracking else None
inits.extend(
model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer)
for model in policy_group.models
)
max_len = args.max_prompt_token_length + args.response_length
vllm_engines = create_vllm_engines(
args.vllm_num_engines,
args.vllm_tensor_parallel_size,
args.vllm_enforce_eager,
model_config.model_name_or_path,
model_config.model_revision,
args.seed,
args.vllm_enable_prefix_caching,
max_len,
args.vllm_gpu_memory_utilization,
args.single_gpu_mode,
pg=pg if args.single_gpu_mode else None,
)
ray.get(inits)
print("======== ✅ all models and vLLM engines initialized =========")
ray.get([m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models])
print("======== ✅ model update group setup successfully =========")
# Setup training
generation_config = SamplingParams(
temperature=args.temperature,
top_p=1.0,
max_tokens=args.response_length,
include_stop_str_in_output=True,
n=args.num_samples_per_prompt_rollout,
stop=args.stop_strings,
)
eval_generation_config = SamplingParams(
temperature=0.0,
top_p=1.0,
max_tokens=args.response_length,
include_stop_str_in_output=True,
n=1, # since we are doing greedy sampling, don't need to generate more
stop=args.stop_strings,
)
iter_dataloader = DataLoader(
reasoning_gym_dataset, batch_size=args.num_unique_prompts_rollout, collate_fn=list_preserving_collate
)
inference_results_Q = Queue(maxsize=1)
param_prompt_Q = Queue(maxsize=1)
evaluation_inference_results_Q = Queue(maxsize=1)
packed_sequences_Q = Queue(maxsize=1)
queries_prompt_Q = Queue(maxsize=1)
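    # single-slot queues form the async pipeline: main thread -> (queries_prompt_Q /
    # param_prompt_Q) -> vLLM generation thread -> (inference_results_Q) ->
    # data preparation thread -> (packed_sequences_Q) -> main thread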
    num_eval_samples = 50
    eval_prompt_token_ids = None
    eval_ground_truths = None
    if eval_dataset is not None:
        # take the first num_eval_samples prompts; `next(iter(eval_dataset))`
        # would yield only a single example
        eval_samples = [eval_dataset[i] for i in range(min(num_eval_samples, len(eval_dataset)))]
        eval_prompt_token_ids = [sample[0] for sample in eval_samples]
        eval_ground_truths = [sample[1]["answer"] for sample in eval_samples]
resume_training_step = 1
thread = threading.Thread(
target=vllm_generate_thread,
args=(
vllm_engines,
generation_config,
eval_generation_config,
inference_results_Q,
param_prompt_Q,
args.num_training_steps,
eval_prompt_token_ids,
evaluation_inference_results_Q,
args.eval_freq,
resume_training_step,
),
)
thread.start()
print("======== ✅ vllm generate thread starts =========")
packing_thread = threading.Thread(
target=data_preparation_thread,
args=(
reward_fn,
inference_results_Q,
packed_sequences_Q,
queries_prompt_Q,
args,
tokenizer,
args.num_training_steps,
reasoning_gym_dataset,
),
)
packing_thread.start()
print("======== ✅ data preparation thread starts =========")
    data_iter = iter(iter_dataloader)
    model_inputs, items = next(data_iter)
queries_prompt_Q.put((model_inputs, items))
param_prompt_Q.put((None, model_inputs))
episode = 0
num_total_tokens = 0
start_time = time.time()
for training_step in range(resume_training_step, args.num_training_steps + 1):
episode += args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
        if training_step != 1:
            # advance the persistent iterator (a fresh `iter(iter_dataloader)` each
            # step would repeat the first batch forever); restart when exhausted
            try:
                model_inputs, items = next(data_iter)
            except StopIteration:
                data_iter = iter(iter_dataloader)
                model_inputs, items = next(data_iter)
with Timer("🔄 Loading weights using shared memory"):
ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models])
queries_prompt_Q.put((model_inputs, items))
param_prompt_Q.put((None, model_inputs))
with Timer("[Main Thread] 📦 Getting packed sequences from thread"):
packed_data = packed_sequences_Q.get()
data_thread_metrics = packed_data["metrics"]
B = packed_data["B"]
collated_data = packed_data["collated_data"]
num_total_tokens += packed_data["num_new_tokens"]
if B == 0:
print("[Main Thread] 🤡 After packing, there is not enough data to train")
continue
# ------------------------------------------------------------------------------------------------
# Train the model
with Timer("[Main Thread] 🗡️ Training"):
metrics_list: List[dict[str, float]] = ray.get(
[
policy_group.models[i].train.remote(
**collated_data[i],
pad_token_id=tokenizer.pad_token_id,
num_mini_batches=args.num_mini_batches,
)
for i in range(args.world_size)
]
)
average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]}
metrics = {
"episode": episode,
"training_step": training_step,
"val/num_total_tokens": num_total_tokens,
"epoch": episode / len(reasoning_gym_dataset),
"tokens_per_second": num_total_tokens / (time.time() - start_time),
**data_thread_metrics,
**average_metrics,
}
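        for key, value in metrics.items():
            writer.add_scalar(key, value, episode)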
if args.save_freq > 0 and training_step % args.save_freq == 0:
with Timer("[Main Thread] 🗡️ Saving model"):
checkpoint_dir = f"{args.output_dir}_checkpoints"
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
print(f"Saving model at step {training_step} to {step_dir}")
ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])
print(f"Saving final model at step {training_step} to {args.output_dir}")
with Timer("[Main Thread] 🗡️ Saving model"):
ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])
thread.join()
print("======== ✅ vllm generate thread ends =========")
packing_thread.join()
print("======== ✅ data preparation thread ends =========")
ray.shutdown()
if (
args.try_auto_save_to_beaker
and is_beaker_job()
and len(beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir.rstrip("/") != "/output"
):
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
print("finished training")
accelerator = Namespace()
accelerator.is_main_process = True # hack
if args.push_to_hub:
print("Pushing model to hub")
push_folder_to_hub(
accelerator,
args.output_dir,
args.hf_repo_id,
args.hf_repo_revision,
)


if __name__ == "__main__":
parser = ArgumentParserPlus((GRPOArgs, ModelConfig))
args, model_config = parser.parse_args_into_dataclasses()
assert isinstance(args, GRPOArgs)
assert isinstance(model_config, ModelConfig)
main(args, model_config, reward_fn)

@@ -0,0 +1,104 @@
from typing import List

import torch
from open_instruct.rl_utils2 import PackedSequences, reset_position_ids
from torch.utils.data import Dataset

import reasoning_gym


def list_preserving_collate(batch):
    """Collate function that preserves lists instead of converting to tensors."""
    token_ids = [item[0] for item in batch]
    items = [item[1] for item in batch]
    return token_ids, items


class ReasoningGymDataset(Dataset):
    """Torch dataset over a procedurally generated reasoning-gym dataset."""

    def __init__(self, dataset_name, seed, size, tokenizer, developer_role, developer_prompt):
        self.dataset_name = dataset_name  # referenced by the trainer when saving traces
        self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
        self.tokenizer = tokenizer
        self.developer_role = developer_role
        self.developer_prompt = developer_prompt

    def __len__(self):
        return self.data.size

    def __getitem__(self, index):
        item = self.data[index]
        chat_message = [
            {"role": self.developer_role, "content": self.developer_prompt},
            {"role": "user", "content": item["question"]},
        ]
        # apply_chat_template with tokenize=True returns token ids, not text
        prompt_ids = self.tokenizer.apply_chat_template(chat_message, tokenize=True, add_generation_prompt=True)
        prompt_ids = [token for token in prompt_ids if token != self.tokenizer.pad_token_id]
        return prompt_ids, item


def pack_sequences(
    queries: List[List[int]],
    responses: List[List[int]],
    pack_length: int,
    pad_token_id: int,
) -> PackedSequences:
    """Greedily pack query+response pairs into buffers of at most `pack_length` tokens.

    Each token's response mask stores the 1-indexed global id of the response it
    belongs to (0 for prompt tokens), so per-response advantages can later be
    scattered in with a simple lookup table.
    """
query_responses = []
attention_masks = []
response_masks = []
num_actions = []
packed_seq_lens = []
cur_data = []
cur_response_mask = []
cur_num_actions = []
cur_packed_seq_lens = []
cur_attention_mask = []
offset = 0
for i in range(len(queries)):
query = queries[i]
response = responses[i]
        # strip padding defensively (vLLM outputs should not contain pad tokens)
query = [t for t in query if t != pad_token_id]
response = [t for t in response if t != pad_token_id]
query_response = query + response
        # flush the current buffer before it overflows; never flush an empty buffer,
        # so a single over-long sequence cannot create an empty pack
        if len(cur_data) > 0 and len(query_response) + len(cur_data) > pack_length:
query_responses.append(cur_data)
response_masks.append(cur_response_mask)
attention_masks.append(cur_attention_mask)
num_actions.append(cur_num_actions)
packed_seq_lens.append(cur_packed_seq_lens)
cur_data = []
cur_response_mask = []
cur_attention_mask = []
cur_num_actions = []
cur_packed_seq_lens = []
offset = i
cur_data.extend(query_response)
cur_num_actions.append(len(response))
cur_packed_seq_lens.append(len(query_response))
cur_response_mask.extend([0 for _ in range(len(query))] + [i + 1 for _ in range(len(response))])
cur_attention_mask.extend([i + 1 - offset for _ in range(len(query_response))])
if len(cur_data) > 0:
query_responses.append(cur_data)
response_masks.append(cur_response_mask)
attention_masks.append(cur_attention_mask)
num_actions.append(cur_num_actions)
packed_seq_lens.append(cur_packed_seq_lens)
attention_masks_list = [torch.tensor(t) for t in attention_masks]
return PackedSequences(
query_responses=[torch.tensor(t) for t in query_responses],
attention_masks=attention_masks_list,
position_ids=[reset_position_ids(t.unsqueeze(0)).squeeze(0) for t in attention_masks_list],
response_masks=[torch.tensor(t) for t in response_masks],
original_responses=responses,
num_actions=[torch.tensor(t) for t in num_actions],
packed_seq_lens=[torch.tensor(t) for t in packed_seq_lens],
)