diff --git a/examples/open-instruct/README.md b/examples/open-instruct/README.md
new file mode 100644
index 00000000..cda4ce5c
--- /dev/null
+++ b/examples/open-instruct/README.md
@@ -0,0 +1,34 @@
+# Reasoning Gym Open-Instruct Example
+
+This example demonstrates how to use [open-instruct](https://github.com/allenai/open-instruct) with **reasoning-gym** for training language models through reinforcement learning.
+
+## Environment Setup
+
+Before running the training script, you may want to set the following environment variables:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+```
+
+- `CUDA_VISIBLE_DEVICES=0` specifies which GPUs to use (in this case, GPU 0)
+- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` enables dynamic GPU memory allocation
+
+## Running the Training Script
+
+To start training, run:
+
+```bash
+bash grpo_config.sh
+```
+
+This will train a DeepSeek-R1-Distill-Qwen-1.5B model using GRPO (Group Relative Policy Optimization) on reasoning tasks.
+
+## Key Features
+
+- Uses vLLM for efficient inference
+- Supports multi-GPU training
+- Implements reward functions for both answer correctness and response formatting
+- Includes evaluation during training
+- Supports model checkpointing and pushing to Hugging Face Hub
+- Integrates with Weights & Biases for experiment tracking
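+
+## How Rewards Are Computed
+
+Correctness rewards come from reasoning-gym's procedural datasets: each entry carries a question and a reference answer, and the dataset scores a candidate answer against that entry. The snippet below is a minimal sketch of the mechanism used by `reward_fn` in `src/grpo_trainer.py` (the printed values are illustrative):
+
+```python
+import reasoning_gym
+from reasoning_gym.utils import extract_answer
+
+# Build a small procedural dataset; "chain_sum" is the task this example trains on.
+data = reasoning_gym.create_dataset("chain_sum", seed=42, size=5)
+entry = data[0]
+print(entry["question"])  # an arithmetic chain; the reference is in entry["answer"]
+
+# Model responses are expected to wrap the final result in <answer> tags;
+# extract_answer pulls the tagged span out before scoring.
+answer = extract_answer(f"<answer>{entry['answer']}</answer>")
+print(data.score_answer(answer, entry))  # 1.0 for a correct answer
+```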
diff --git a/examples/open-instruct/grpo_config.sh b/examples/open-instruct/grpo_config.sh
new file mode 100755
index 00000000..3d574ec3
--- /dev/null
+++ b/examples/open-instruct/grpo_config.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+exp_name="0302_qwen2.5_math_grpo_fast1_${RANDOM}"
+python src/grpo_trainer.py \
+    --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+    --beta 0.0 \
+    --num_unique_prompts_rollout 16 \
+    --num_samples_per_prompt_rollout 8 \
+    --output_dir ./open_instruct_checkpoints/$exp_name \
+    --save_freq 40 \
+    --kl_estimator kl3 \
+    --learning_rate 5e-7 \
+    --max_token_length 256 \
+    --max_prompt_token_length 128 \
+    --response_length 256 \
+    --pack_length 512 \
+    --stop_strings '""' \
+    --chat_template_name r1_simple_chat_postpend_think \
+    --temperature 1.0 \
+    --total_episodes 100000 \
+    --deepspeed_stage 3 \
+    --per_device_train_batch_size 1 \
+    --num_mini_batches 1 \
+    --num_learners_per_node 1 \
+    --num_epochs 1 \
+    --vllm_tensor_parallel_size 1 \
+    --vllm_num_engines 1 \
+    --vllm_enforce_eager true \
+    --lr_scheduler_type linear \
+    --seed 1 \
+    --num_evals 100 \
+    --dataset_name chain_sum \
+    --gradient_checkpointing \
+    --with_tracking \
+    --single_gpu_mode false \
+    --gather_whole_model false
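+
+# Illustrative notes on how these settings interact (see src/grpo_trainer.py):
+# - pack_length (512) should be at least max_prompt_token_length + response_length
+#   (128 + 256 = 384) so that a full prompt/response pair fits into a single pack.
+# - num_unique_prompts_rollout * num_samples_per_prompt_rollout (16 * 8 = 128)
+#   responses are generated and scored per training step.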
diff --git a/examples/open-instruct/src/__init__.py b/examples/open-instruct/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/open-instruct/src/grpo_trainer.py b/examples/open-instruct/src/grpo_trainer.py
new file mode 100644
index 00000000..a8a86b90
--- /dev/null
+++ b/examples/open-instruct/src/grpo_trainer.py
@@ -0,0 +1,457 @@
+import json
+import os
+import shutil
+import threading
+import time
+from argparse import Namespace
+from dataclasses import asdict, dataclass
+from queue import Queue
+from typing import Callable, List
+
+import numpy as np
+import ray
+import torch
+from open_instruct.ground_truth_utils import soft_format_reward_func
+from open_instruct.grpo_fast import (
+    Args,
+    ModelGroup,
+    PolicyTrainerRayProcess,
+    Timer,
+    calculate_runtime_args,
+    collate_fn,
+    vllm_generate_thread,
+)
+from open_instruct.model_utils import ModelConfig, push_folder_to_hub
+from open_instruct.utils import ArgumentParserPlus, get_wandb_tags, is_beaker_job, maybe_get_beaker_config
+from open_instruct.vllm_utils2 import create_vllm_engines
+from ray.util.placement_group import placement_group
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer, PreTrainedTokenizer
+from utils import ReasoningGymDataset, list_preserving_collate, pack_sequences
+from vllm import SamplingParams
+
+from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer
+
+
+@dataclass
+class GRPOArgs(Args):
+    dataset_name: str = "chain_sum"
+    seed: int = 42
+    size: int = 10000
+    eval_seed: int = 42
+    eval_size: int = 100
+
+
+def data_preparation_thread(
+    reward_fn: Callable,
+    inference_results_Q: Queue,
+    packed_sequences_Q: Queue,
+    queries_prompt_Q: Queue,
+    args: Args,
+    tokenizer: PreTrainedTokenizer,
+    num_training_steps: int,
+    reasoning_gym_dataset: ReasoningGymDataset,
+):
+    for training_step in range(1, num_training_steps + 1):
+        # Get the next batch of prompts and their reasoning-gym items
+        queries, items = queries_prompt_Q.get()
+
+        if args.num_samples_per_prompt_rollout > 1:
+            # vLLM returns num_samples_per_prompt_rollout responses per prompt, so
+            # repeat each query and its item to line up with the responses.
+            queries = [query for query in queries for _ in range(args.num_samples_per_prompt_rollout)]
+            items = [item for item in items for _ in range(args.num_samples_per_prompt_rollout)]
+        with Timer("🚀 [Data Preparation Thread] Getting response ids"):
+            responses, finish_reasons = inference_results_Q.get()
+            for i in range(len(finish_reasons)):
+                if finish_reasons[i] == "stop" and responses[i][-1] != tokenizer.eos_token_id:
+                    responses[i].append(tokenizer.eos_token_id)
+
+        with Timer("📦 [Data Preparation Thread] Packing sequences"):
+            packed_sequences = pack_sequences(
+                queries=queries,
+                responses=responses,
+                pack_length=args.pack_length,
+                pad_token_id=tokenizer.pad_token_id,
+            )
+            num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)
+
+        with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
+            decoded_responses = tokenizer.batch_decode(responses, skip_special_tokens=True)
+            stop_rate = sum(int(finish_reason == "stop") for finish_reason in finish_reasons) / len(finish_reasons)
+
+        with Timer("💰 [Data Preparation Thread] Calculating rewards"):
+            scores, reward_metrics = reward_fn(decoded_responses, items, reasoning_gym_dataset)
+
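+        # GRPO normalizes each score within its group of num_samples_per_prompt_rollout
+        # samples drawn from the same prompt. Illustrative example (not real output):
+        # with 2 samples per prompt and scores [1.0, 0.0] for one prompt, the group
+        # mean is 0.5 and the std 0.5, so the advantages are roughly [+1.0, -1.0].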
+        with Timer("🎆 [Data Preparation Thread] Calculating advantages"):
+            scores = np.array(scores)
+            print(f"[Data Preparation Thread] {len(scores)=}")
+            scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout)
+            global_mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
+            global_mean_grouped_rewards = np.repeat(
+                global_mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
+            )
+            global_std_grouped_rewards = scores_per_prompt.std(axis=-1)
+            global_std_grouped_rewards = np.repeat(
+                global_std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
+            )
+            global_advantages = (scores - global_mean_grouped_rewards) / (global_std_grouped_rewards + 1e-8)
+
+            # Vectorized advantage lookup: response masks are 1-indexed (0 marks prompt
+            # tokens), so index 0 maps to advantage 0 and index k holds the advantage
+            # of the k-th response in the batch.
+            lookup_advantages = np.zeros(len(global_advantages) + 1, dtype=np.float32)
+            lookup_advantages[1:] = global_advantages
+            packed_advantages = [
+                torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32)
+                for packed_mask in packed_sequences.response_masks
+            ]
+            packed_sequences.advantages = packed_advantages
+
+        with Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
+            B = (
+                len(packed_sequences.query_responses) // args.world_size
+            )  # essentially doing `drop_last=True`, which is fine
+            collated_data = []
+            for i in range(args.world_size):
+                per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)]
+                per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)]
+                per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)]
+                per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)]
+                per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)]
+
+                # Shuffle the batch and collate the data
+                b_inds = np.random.permutation(len(per_device_packed_query_responses))
+                collated_query_responses = []
+                collated_attention_masks = []
+                collated_position_ids = []
+                collated_response_masks = []
+                collated_advantages = []
+                for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size):
+                    micro_range = b_inds[j : j + args.per_device_train_batch_size]
+                    collated_query_responses.append(
+                        collate_fn(
+                            [per_device_packed_query_responses[idx] for idx in micro_range], tokenizer.pad_token_id
+                        )
+                    )
+                    collated_attention_masks.append(
+                        collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0)
+                    )
+                    collated_position_ids.append(
+                        collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0)
+                    )
+                    collated_response_masks.append(
+                        collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0)
+                    )
+                    collated_advantages.append(
+                        collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0)
+                    )
+                collated_data.append(
+                    {
+                        "collated_query_responses": collated_query_responses,
+                        "collated_attention_masks": collated_attention_masks,
+                        "collated_position_ids": collated_position_ids,
+                        "collated_advantages": collated_advantages,
+                        "collated_response_masks": collated_response_masks,
+                    }
+                )
+        sequence_lengths = np.array([len(response) for response in responses])
+        metrics = {
+            "scores": scores.mean(),
+            "val/sequence_lengths": sequence_lengths.mean(),
+            "val/sequence_lengths_min": sequence_lengths.min(),
+            "val/sequence_lengths_max": sequence_lengths.max(),
+            "val/stop_rate": stop_rate,
+            # log scalar means of the per-response reward lists from reward_fn
+            **{k: float(np.mean(v)) for k, v in reward_metrics.items()},
+        }
+
+        if args.save_traces:
+            traces = {
+                "scores": scores.tolist(),
+                "finish_reasons": finish_reasons,
+                "responses": responses,
+                "queries": queries,
+                "ground_truths": [item["answer"] for item in items],
+                "datasets": reasoning_gym_dataset.dataset_name,
+                "training_step": training_step,
+                **reward_metrics,
+            }
+            os.makedirs(args.output_dir, exist_ok=True)
+            with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f:
+                json.dump(traces, f)
+                f.write("\n")
+        packed_sequences_Q.put(
+            {
+                "packed_sequences": packed_sequences,  # for debugging purposes
+                "collated_data": collated_data,
+                "metrics": metrics,
+                "responses_count": len(responses),
+                "num_new_tokens": num_new_tokens,
+                "B": B,
+            }
+        )
+
+
+def reward_fn(responses, items, reasoning_gym_dataset):
+    metrics = {}
+    formatted_responses = [extract_answer(response) for response in responses]
+    correctness_rewards = [
+        reasoning_gym_dataset.data.score_answer(formatted_response, entry)
+        for formatted_response, entry in zip(formatted_responses, items)
+    ]
+    format_rewards = soft_format_reward_func(responses)
+    # Combine the reward components element-wise; a bare list `+` would concatenate
+    # the two lists instead of summing them.
+    rewards = [correctness + fmt for correctness, fmt in zip(correctness_rewards, format_rewards)]
+    metrics["correctness_rewards"] = correctness_rewards
+    metrics["format_rewards"] = format_rewards
+    metrics["rewards"] = rewards
+    return rewards, metrics
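+
+
+# Note: soft_format_reward_func comes from open_instruct's ground_truth_utils and
+# returns a per-response formatting bonus for responses that match its expected
+# output format (the exact pattern is defined there), so the total reward combines
+# answer correctness with response formatting.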
+
+
+def main(args: Args, model_config: ModelConfig, reward_fn: Callable):
+    calculate_runtime_args(args)
+
+    # Setup experiment tracking and seeds
+    all_configs = {}
+    beaker_config = None
+    if is_beaker_job():
+        beaker_config = maybe_get_beaker_config()
+        all_configs.update(vars(beaker_config))
+    all_configs.update(**asdict(args), **asdict(model_config))
+    if args.with_tracking:
+        import wandb
+
+        wandb.init(
+            project=args.wandb_project_name,
+            entity=args.wandb_entity,
+            sync_tensorboard=True,
+            config=all_configs,
+            name=args.run_name,
+            save_code=True,
+            tags=[args.exp_name] + get_wandb_tags(),
+        )
+
+    # Setup tokenizer and get datasets
+    tokenizer_revision = (
+        model_config.model_revision if model_config.tokenizer_revision is None else model_config.tokenizer_revision
+    )
+    tokenizer_name = (
+        model_config.tokenizer_name if model_config.tokenizer_name is not None else model_config.model_name_or_path
+    )
+    if tokenizer_revision != model_config.model_revision:
+        # Warn the user if the tokenizer and model use different revisions; this is
+        # an unusual use case.
+        warning = (
+            f"Requested tokenizer revision `{tokenizer_revision}` is different "
+            f"from the model revision `{model_config.model_revision}`."
+        )
+        print(warning)
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision)
+    reasoning_gym_dataset = ReasoningGymDataset(
+        args.dataset_name, args.seed, args.size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
+    )
+    eval_dataset = ReasoningGymDataset(
+        args.dataset_name, args.eval_seed, args.eval_size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
+    )
+
+    bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node]
+    pg = placement_group(bundles, strategy="PACK")
+    ray.get(pg.ready())
+    inits = []
+    policy_group = ModelGroup(
+        pg,
+        PolicyTrainerRayProcess,
+        args.num_learners_per_node,
+        args.single_gpu_mode,
+    )
+
+    wandb_url = wandb.run.get_url() if args.with_tracking else None
+    inits.extend(
+        model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer)
+        for model in policy_group.models
+    )
+    max_len = args.max_prompt_token_length + args.response_length
+    vllm_engines = create_vllm_engines(
+        args.vllm_num_engines,
+        args.vllm_tensor_parallel_size,
+        args.vllm_enforce_eager,
+        model_config.model_name_or_path,
+        model_config.model_revision,
+        args.seed,
+        args.vllm_enable_prefix_caching,
+        max_len,
+        args.vllm_gpu_memory_utilization,
+        args.single_gpu_mode,
+        pg=pg if args.single_gpu_mode else None,
+    )
+    ray.get(inits)
+    print("======== ✅ all models and vLLM engines initialized =========")
+
+    ray.get([m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models])
+    print("======== ✅ model update group setup successfully =========")
+
+    # Setup training
+    generation_config = SamplingParams(
+        temperature=args.temperature,
+        top_p=1.0,
+        max_tokens=args.response_length,
+        include_stop_str_in_output=True,
+        n=args.num_samples_per_prompt_rollout,
+        stop=args.stop_strings,
+    )
+    eval_generation_config = SamplingParams(
+        temperature=0.0,
+        top_p=1.0,
+        max_tokens=args.response_length,
+        include_stop_str_in_output=True,
+        n=1,  # since we are doing greedy sampling, we don't need more than one sample
+        stop=args.stop_strings,
+    )
+
+    iter_dataloader = DataLoader(
+        reasoning_gym_dataset, batch_size=args.num_unique_prompts_rollout, collate_fn=list_preserving_collate
+    )
+    data_iter = iter(iter_dataloader)
+    inference_results_Q = Queue(maxsize=1)
+    param_prompt_Q = Queue(maxsize=1)
+    evaluation_inference_results_Q = Queue(maxsize=1)
+    packed_sequences_Q = Queue(maxsize=1)
+    queries_prompt_Q = Queue(maxsize=1)
+    num_eval_samples = 50
+
+    eval_prompt_token_ids = None
+    if eval_dataset is not None:
+        # Collect prompt token ids for periodic greedy evaluation; vLLM expects a
+        # list of token-id lists.
+        eval_prompt_token_ids = [eval_dataset[i][0] for i in range(min(num_eval_samples, len(eval_dataset)))]
+    resume_training_step = 1
+    thread = threading.Thread(
+        target=vllm_generate_thread,
+        args=(
+            vllm_engines,
+            generation_config,
+            eval_generation_config,
+            inference_results_Q,
+            param_prompt_Q,
+            args.num_training_steps,
+            eval_prompt_token_ids,
+            evaluation_inference_results_Q,
+            args.eval_freq,
+            resume_training_step,
+        ),
+    )
+
+    thread.start()
+    print("======== ✅ vllm generate thread starts =========")
+
+    packing_thread = threading.Thread(
+        target=data_preparation_thread,
+        args=(
+            reward_fn,
+            inference_results_Q,
+            packed_sequences_Q,
+            queries_prompt_Q,
+            args,
+            tokenizer,
+            args.num_training_steps,
+            reasoning_gym_dataset,
+        ),
+    )
+
+    packing_thread.start()
+    print("======== ✅ data preparation thread starts =========")
+    model_inputs, items = next(data_iter)
+    queries_prompt_Q.put((model_inputs, items))
+    param_prompt_Q.put((None, model_inputs))
+
+    episode = 0
+    num_total_tokens = 0
+    start_time = time.time()
+    for training_step in range(resume_training_step, args.num_training_steps + 1):
+        episode += args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
+
+        if training_step != 1:
+            # Advance the dataloader, restarting it when the dataset is exhausted
+            # (a fresh `iter()` per step would always return the first batch).
+            try:
+                model_inputs, items = next(data_iter)
+            except StopIteration:
+                data_iter = iter(iter_dataloader)
+                model_inputs, items = next(data_iter)
+            with Timer("🔄 Loading weights using shared memory"):
+                ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models])
+
+            queries_prompt_Q.put((model_inputs, items))
+            param_prompt_Q.put((None, model_inputs))
+
+        with Timer("[Main Thread] 📦 Getting packed sequences from thread"):
+            packed_data = packed_sequences_Q.get()
+            data_thread_metrics = packed_data["metrics"]
+            B = packed_data["B"]
+            collated_data = packed_data["collated_data"]
+            num_total_tokens += packed_data["num_new_tokens"]
+
+        if B == 0:
+            print("[Main Thread] 🤡 After packing, there is not enough data to train")
+            continue
+
+        # ------------------------------------------------------------------------------------------------
+        # Train the model
+        with Timer("[Main Thread] 🗡️ Training"):
+            metrics_list: List[dict[str, float]] = ray.get(
+                [
+                    policy_group.models[i].train.remote(
+                        **collated_data[i],
+                        pad_token_id=tokenizer.pad_token_id,
+                        num_mini_batches=args.num_mini_batches,
+                    )
+                    for i in range(args.world_size)
+                ]
+            )
+            average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]}
+            metrics = {
+                "episode": episode,
+                "training_step": training_step,
+                "val/num_total_tokens": num_total_tokens,
+                "epoch": episode / len(reasoning_gym_dataset),
+                "tokens_per_second": num_total_tokens / (time.time() - start_time),
+                **data_thread_metrics,
+                **average_metrics,
+            }
+            # metrics would otherwise go unused; report them to the tracker
+            if args.with_tracking:
+                wandb.log(metrics)
+
+        if args.save_freq > 0 and training_step % args.save_freq == 0:
+            with Timer("[Main Thread] 🗡️ Saving model"):
+                checkpoint_dir = f"{args.output_dir}_checkpoints"
+                step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
+                print(f"Saving model at step {training_step} to {step_dir}")
+                ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])
+
+    print(f"Saving final model at step {training_step} to {args.output_dir}")
+    with Timer("[Main Thread] 🗡️ Saving model"):
+        ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])
+
+    thread.join()
+    print("======== ✅ vllm generate thread ends =========")
+    packing_thread.join()
+    print("======== ✅ data preparation thread ends =========")
+    ray.shutdown()
+
+    if (
+        args.try_auto_save_to_beaker
+        and is_beaker_job()
+        and len(beaker_config.beaker_dataset_id_urls) > 0
+        and args.output_dir.rstrip("/") != "/output"
+    ):
+        shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
+    print("finished training")
+
+    accelerator = Namespace()
+    accelerator.is_main_process = True  # hack
+    if args.push_to_hub:
+        print("Pushing model to hub")
+        push_folder_to_hub(
+            accelerator,
+            args.output_dir,
+            args.hf_repo_id,
+            args.hf_repo_revision,
+        )
+
+
+if __name__ == "__main__":
+    parser = ArgumentParserPlus((GRPOArgs, ModelConfig))
+    args, model_config = parser.parse_args_into_dataclasses()
+    assert isinstance(args, GRPOArgs)
+    assert isinstance(model_config, ModelConfig)
+    main(args, model_config, reward_fn)
diff --git a/examples/open-instruct/src/utils.py b/examples/open-instruct/src/utils.py
new file mode 100644
index 00000000..ad6708ad
--- /dev/null
+++ b/examples/open-instruct/src/utils.py
@@ -0,0 +1,104 @@
+from typing import List
+
+import torch
+from open_instruct.rl_utils2 import PackedSequences, reset_position_ids
+from torch.utils.data import Dataset
+from transformers import PreTrainedTokenizer
+
+import reasoning_gym
+
+
+def list_preserving_collate(batch):
+    """
+    Custom collate function that preserves lists instead of converting them to tensors.
+    """
+    token_ids = [item[0] for item in batch]
+    items = [item[1] for item in batch]
+    return token_ids, items
+
+
+class ReasoningGymDataset(Dataset):
+    def __init__(
+        self,
+        dataset_name: str,
+        seed: int,
+        size: int,
+        tokenizer: PreTrainedTokenizer,
+        developer_role: str,
+        developer_prompt: str,
+    ):
+        self.dataset_name = dataset_name  # kept for trace logging in the trainer
+        self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
+        self.tokenizer = tokenizer
+        self.developer_role = developer_role
+        self.developer_prompt = developer_prompt
+
+    def __len__(self):
+        return self.data.size
+
+    def __getitem__(self, index):
+        chat_message = [{"role": self.developer_role, "content": self.developer_prompt}]
+        item = self.data[index]
+        chat_message.append({"role": "user", "content": item["question"]})
+        # apply_chat_template with tokenize=True returns token ids, not text
+        prompt_ids = self.tokenizer.apply_chat_template(chat_message, tokenize=True, add_generation_prompt=True)
+        prompt_ids = [token for token in prompt_ids if token != self.tokenizer.pad_token_id]
+        return prompt_ids, item
+
+
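+# Worked example (illustrative): packing queries [[1, 2], [3]] with responses
+# [[4], [5, 6]] and pack_length=8 yields a single pack:
+#   query_responses -> [1, 2, 4, 3, 5, 6]
+#   response_masks  -> [0, 0, 1, 0, 2, 2]  (k marks tokens of the k-th response)
+#   attention_masks -> [1, 1, 1, 2, 2, 2]  (one id per original sequence)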
+def pack_sequences(
+    queries: List[List[int]],
+    responses: List[List[int]],
+    pack_length: int,
+    pad_token_id: int,
+) -> PackedSequences:
+    """Pack query/response pairs into sequences of at most pack_length tokens.
+
+    Response mask values are 1-indexed (value k marks tokens of the k-th response;
+    0 marks prompt tokens), and attention mask ids restart from 1 within each pack.
+    """
+    query_responses = []
+    attention_masks = []
+    response_masks = []
+    num_actions = []
+    packed_seq_lens = []
+    cur_data = []
+    cur_response_mask = []
+    cur_num_actions = []
+    cur_packed_seq_lens = []
+    cur_attention_mask = []
+    offset = 0
+
+    for i in range(len(queries)):
+        query = queries[i]
+        response = responses[i]
+        # Strip padding defensively (vLLM output should already be unpadded).
+        query = [t for t in query if t != pad_token_id]
+        response = [t for t in response if t != pad_token_id]
+        query_response = query + response
+
+        # Start a new pack when the current one cannot fit this pair (the non-empty
+        # check avoids emitting an empty pack when a single pair exceeds pack_length).
+        if len(query_response) + len(cur_data) > pack_length and len(cur_data) > 0:
+            query_responses.append(cur_data)
+            response_masks.append(cur_response_mask)
+            attention_masks.append(cur_attention_mask)
+            num_actions.append(cur_num_actions)
+            packed_seq_lens.append(cur_packed_seq_lens)
+            cur_data = []
+            cur_response_mask = []
+            cur_attention_mask = []
+            cur_num_actions = []
+            cur_packed_seq_lens = []
+            offset = i
+
+        cur_data.extend(query_response)
+        cur_num_actions.append(len(response))
+        cur_packed_seq_lens.append(len(query_response))
+        cur_response_mask.extend([0 for _ in range(len(query))] + [i + 1 for _ in range(len(response))])
+        cur_attention_mask.extend([i + 1 - offset for _ in range(len(query_response))])
+
+    if len(cur_data) > 0:
+        query_responses.append(cur_data)
+        response_masks.append(cur_response_mask)
+        attention_masks.append(cur_attention_mask)
+        num_actions.append(cur_num_actions)
+        packed_seq_lens.append(cur_packed_seq_lens)
+
+    attention_masks_list = [torch.tensor(t) for t in attention_masks]
+    return PackedSequences(
+        query_responses=[torch.tensor(t) for t in query_responses],
+        attention_masks=attention_masks_list,
+        position_ids=[reset_position_ids(t.unsqueeze(0)).squeeze(0) for t in attention_masks_list],
+        response_masks=[torch.tensor(t) for t in response_masks],
+        original_responses=responses,
+        num_actions=[torch.tensor(t) for t in num_actions],
+        packed_seq_lens=[torch.tensor(t) for t in packed_seq_lens],
+    )