mirror of https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00

Feat/open instruct example (#381)

* added open-instruct
* fixed hooks
* GRPO

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>

parent 1511c5e301
commit 9234aa77bf

5 changed files with 629 additions and 0 deletions
examples/open-instruct/README.md (new file, 34 lines)
@@ -0,0 +1,34 @@
# Reasoning Gym Open-Instruct Example

This example demonstrates how to use [open-instruct](https://github.com/allenai/open-instruct) with **reasoning-gym** for training language models through reinforcement learning.

## Environment Setup

Before running the training script, you may want to set the following environment variables:

```bash
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
```

- `CUDA_VISIBLE_DEVICES=0` specifies which GPU to use (in this case, GPU 0)
- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` enables dynamic GPU memory allocation
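For multi-GPU runs (the trainer supports them; see Key Features below), list several devices instead, e.g. `export CUDA_VISIBLE_DEVICES=0,1`.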
## Running the Training Script

To start training, run:

```bash
bash grpo_config.sh
```

This will train a DeepSeek-R1-Distill-Qwen-1.5B model using GRPO (Group Relative Policy Optimization) on reasoning tasks.

## Key Features

- Uses vLLM for efficient inference
- Supports multi-GPU training
- Implements reward functions for both answer correctness and response formatting (see the sketch below)
- Includes evaluation during training
- Supports model checkpointing and pushing to Hugging Face Hub
- Integrates with Weights & Biases for experiment tracking
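For a concrete sense of how the correctness reward is computed, here is a minimal sketch of the reasoning-gym calls the trainer builds on (`reasoning_gym.create_dataset` in `src/utils.py` and `score_answer` in `src/grpo_trainer.py`); the printed values are illustrative:

```python
import reasoning_gym

# create the same procedural dataset the trainer uses (task, seed, size as in GRPOArgs)
data = reasoning_gym.create_dataset("chain_sum", seed=42, size=3)
for item in data:
    print(item["question"])  # an arithmetic question string
    # scoring the reference answer against its own entry yields the maximum score (1.0)
    print(data.score_answer(item["answer"], item))
```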
examples/open-instruct/grpo_config.sh (new executable file, 34 lines)
@@ -0,0 +1,34 @@
exp_name="0302_qwen2.5_math_grpo_fast1_${RANDOM}"
|
||||
python src/grpo_trainer.py \
|
||||
--model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
|
||||
--beta 0.0 \
|
||||
--num_unique_prompts_rollout 16 \
|
||||
--num_samples_per_prompt_rollout 8 \
|
||||
--output_dir ./open_instruct_checkpoints/$exp_name \
|
||||
--save_freq 40 \
|
||||
--kl_estimator kl3 \
|
||||
--learning_rate 5e-7 \
|
||||
--max_token_length 256 \
|
||||
--max_prompt_token_length 128 \
|
||||
--response_length 256 \
|
||||
--pack_length 512 \
|
||||
--stop_strings '"</answer>"' \
|
||||
--chat_template_name r1_simple_chat_postpend_think \
|
||||
--temperature 1.0 \
|
||||
--total_episodes 100000 \
|
||||
--deepspeed_stage 3 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--num_mini_batches 1\
|
||||
--num_learners_per_node 1 \
|
||||
--num_epochs 1 \
|
||||
--vllm_tensor_parallel_size 1 \
|
||||
--vllm_num_engines 1 \
|
||||
--vllm_enforce_eager true \
|
||||
--lr_scheduler_type linear \
|
||||
--seed 1 \
|
||||
--num_evals 100 \
|
||||
--dataset_name chain_sum \
|
||||
--gradient_checkpointing \
|
||||
--with_tracking \
|
||||
--single_gpu_mode false \
|
||||
--gather_whole_model false
|
||||
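For orientation: with these flags each training step samples 16 unique prompts × 8 samples per prompt = 128 rollouts, and the packing budget is consistent, since `pack_length` (512) covers `max_prompt_token_length` (128) + `response_length` (256) with headroom.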
examples/open-instruct/src/__init__.py (new empty file)

examples/open-instruct/src/grpo_trainer.py (new file, 457 lines)
@@ -0,0 +1,457 @@
import json
import os
import shutil
import threading
import time
from argparse import Namespace
from dataclasses import asdict, dataclass
from queue import Queue
from typing import Callable, List

import numpy as np
import pandas as pd
import ray
import torch
from open_instruct.ground_truth_utils import soft_format_reward_func
from open_instruct.grpo_fast import (
    Args,
    ModelGroup,
    PolicyTrainerRayProcess,
    Timer,
    calculate_runtime_args,
    collate_fn,
    vllm_generate_thread,
)
from open_instruct.model_utils import ModelConfig, push_folder_to_hub
from open_instruct.rl_utils2 import PackedSequences, reset_position_ids
from open_instruct.utils import ArgumentParserPlus, get_wandb_tags, is_beaker_job, maybe_get_beaker_config
from open_instruct.vllm_utils2 import create_vllm_engines
from ray.util.placement_group import placement_group
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, PreTrainedTokenizer
from utils import ReasoningGymDataset, list_preserving_collate, pack_sequences
from vllm import SamplingParams

from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer

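# GRPOArgs extends open-instruct's Args, so every grpo_fast flag (e.g.
# --num_samples_per_prompt_rollout) stays available alongside the
# reasoning-gym-specific fields added below.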
@dataclass
class GRPOArgs(Args):
    dataset_name: str = "chain_sum"
    seed: int = 42
    size: int = 10000
    eval_seed: int = 42
    eval_size: int = 100

def data_preparation_thread(
    reward_fn: Callable,
    inference_results_Q: Queue,
    packed_sequences_Q: Queue,
    queries_prompt_Q: Queue,
    args: Args,
    tokenizer: PreTrainedTokenizer,
    num_training_steps: int,
    reasoning_gym_dataset: ReasoningGymDataset,
):
    for training_step in range(1, num_training_steps + 1):
        # Get next batch of prompts and responses
        observations = queries_prompt_Q.get()
        queries, items = observations

        if args.num_samples_per_prompt_rollout > 1:
            queries = [item for item in queries for _ in range(args.num_samples_per_prompt_rollout)]
            # repeat the dataset items to match the duplicated queries, so that
            # reward_fn and trace logging see one entry per sampled response
            items = [item for item in items for _ in range(args.num_samples_per_prompt_rollout)]
with Timer("🚀 [Data Preparation Thread] Getting response ids"):
|
||||
responses, finish_reasons = inference_results_Q.get()
|
||||
for i in range(len(finish_reasons)):
|
||||
if finish_reasons[i] == "stop" and responses[i][-1] != tokenizer.eos_token_id:
|
||||
responses[i].append(tokenizer.eos_token_id)
|
||||
|
||||
with Timer("📦 [Data Preparation Thread] Packing sequences"):
|
||||
packed_sequences = pack_sequences(
|
||||
queries=queries,
|
||||
responses=responses,
|
||||
pack_length=args.pack_length,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)
|
||||
|
||||
with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
|
||||
decoded_responses = tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||
stop_rate = sum(int(finish_reason == "stop") for finish_reason in finish_reasons) / len(finish_reasons)
|
||||
|
||||
with Timer("💰 [Data Preparation Thread] Calculating rewards"):
|
||||
scores, reward_metrics = reward_fn(decoded_responses, items, reasoning_gym_dataset)
|
||||
|
||||
with Timer("🎆 [Data Preparation Thread] Calculating advantages"):
|
||||
# Calculate advantages
|
||||
scores = np.array(scores)
|
||||
print(f"[Data Preparation Thread] {len(scores)=}")
|
||||
scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout)
|
||||
global_mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
|
||||
global_mean_grouped_rewards = np.repeat(
|
||||
global_mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
|
||||
)
|
||||
global_std_grouped_rewards = scores_per_prompt.std(axis=-1)
|
||||
global_std_grouped_rewards = np.repeat(
|
||||
global_std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
|
||||
)
|
||||
global_advantages = (scores - global_mean_grouped_rewards) / (global_std_grouped_rewards + 1e-8)
|
||||
|
||||
# Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value
|
||||
# and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed)
|
||||
lookup_advantages = np.zeros(len(global_advantages) + 1, dtype=np.float32)
|
||||
lookup_advantages[1:] = global_advantages
|
||||
packed_advantages = [
|
||||
torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32)
|
||||
for packed_mask in packed_sequences.response_masks
|
||||
]
|
||||
packed_sequences.advantages = packed_advantages
|
||||
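            # Worked example with toy numbers: with num_samples_per_prompt_rollout=2
            # and scores [1.0, 0.0, 0.5, 0.5], the groups are [[1.0, 0.0], [0.5, 0.5]],
            # group means broadcast to [0.5, 0.5, 0.5, 0.5] and stds to [0.5, 0.5, 0.0, 0.0],
            # so global_advantages ≈ [1.0, -1.0, 0.0, 0.0]; lookup_advantages then maps
            # response-mask value i+1 (tokens of response i) to advantage i, and 0 for prompt tokens.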

        with Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
            B = (
                len(packed_sequences.query_responses) // args.world_size
            )  # essentially doing `drop_last=True`, which is fine.
            collated_data = []
            for i in range(args.world_size):
                per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)]
                per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)]
                per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)]
                per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)]
                per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)]

                # Shuffle the batch and collate the data
                b_inds = np.random.permutation(len(per_device_packed_query_responses))
                collated_query_responses = []
                collated_attention_masks = []
                collated_position_ids = []
                collated_response_masks = []
                collated_advantages = []
                for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size):
                    micro_range = b_inds[j : j + args.per_device_train_batch_size]
                    collated_query_responses.append(
                        collate_fn(
                            [per_device_packed_query_responses[idx] for idx in micro_range], tokenizer.pad_token_id
                        )
                    )
                    collated_attention_masks.append(
                        collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0)
                    )
                    collated_position_ids.append(
                        collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0)
                    )
                    collated_response_masks.append(
                        collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0)
                    )
                    collated_advantages.append(
                        collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0)
                    )
                collated_data.append(
                    {
                        "collated_query_responses": collated_query_responses,
                        "collated_attention_masks": collated_attention_masks,
                        "collated_position_ids": collated_position_ids,
                        "collated_advantages": collated_advantages,
                        "collated_response_masks": collated_response_masks,
                    }
                )
        sequence_lengths = np.array([len(response) for response in responses])
        metrics = {
            "scores": np.array(scores).mean(),
            "val/sequence_lengths": sequence_lengths.mean(),
            "val/sequence_lengths_min": sequence_lengths.min(),
            "val/sequence_lengths_max": sequence_lengths.max(),
            "val/stop_rate": stop_rate,
            **reward_metrics,
        }

        if args.save_traces:
            traces = {
                "scores": scores.tolist(),
                "finish_reasons": finish_reasons,
                "responses": responses,
                "queries": queries,
                # items is a list of dataset entries, so collect the answer per entry
                "ground_truths": [item["answer"] for item in items],
                "datasets": reasoning_gym_dataset.dataset_name,
                "training_step": training_step,
                **reward_metrics,
            }
            os.makedirs(args.output_dir, exist_ok=True)
            with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f:
                json.dump(traces, f)
                f.write("\n")
        packed_sequences_Q.put(
            {
                "packed_sequences": packed_sequences,  # for debugging purposes
                "collated_data": collated_data,
                "metrics": metrics,
                "responses_count": len(responses),
                "num_new_tokens": num_new_tokens,
                "B": B,
            }
        )

def reward_fn(responses, items, reasoning_gym_dataset):
    metrics = {}
    formatted_responses = [extract_answer(response) for response in responses]
    correctness_rewards = [
        reasoning_gym_dataset.data.score_answer(formatted_response, entry)
        for formatted_response, entry in zip(formatted_responses, items)
    ]
    format_rewards = soft_format_reward_func(responses)
    # sum the two reward components per response (`+` on lists would concatenate them)
    rewards = [correctness + formatting for correctness, formatting in zip(correctness_rewards, format_rewards)]
    metrics["correctness_rewards"] = correctness_rewards
    metrics["format_rewards"] = format_rewards
    metrics["rewards"] = rewards
    return rewards, metrics

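# Illustrative flow of reward_fn: extract_answer pulls the text inside a response's
# <answer>...</answer> tags, score_answer grades it against its dataset entry
# (1.0 for a correct chain_sum answer), and soft_format_reward_func adds a bonus
# (scaled as defined in open-instruct) when the <think>/<answer> format is followed.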
def main(args: Args, model_config: ModelConfig, reward_fn: Callable):
    calculate_runtime_args(args)

    # Setup experiment tracking and seeds
    all_configs = {}
    beaker_config = None
    if is_beaker_job():
        beaker_config = maybe_get_beaker_config()
        all_configs.update(vars(beaker_config))
    all_configs.update(**asdict(args), **asdict(model_config))
    if args.with_tracking:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=all_configs,
            name=args.run_name,
            save_code=True,
            tags=[args.exp_name] + get_wandb_tags(),
        )

    # Setup tokenizer and get datasets
    tokenizer_revision = (
        model_config.model_revision if model_config.tokenizer_revision is None else model_config.tokenizer_revision
    )
    tokenizer_name = (
        model_config.tokenizer_name if model_config.tokenizer_name is not None else model_config.model_name_or_path
    )
    if tokenizer_revision != model_config.model_revision:
        # Warn user if tokenizer and model use different revisions; this is an unusual
        # use case.
        warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different
from the model revision `{model_config.model_revision}`."""
        print(warning)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision)
    reasoning_gym_dataset = ReasoningGymDataset(
        args.dataset_name, args.seed, args.size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
    )
    eval_dataset = ReasoningGymDataset(
        args.dataset_name, args.eval_seed, args.eval_size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
    )

    bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node]
    pg = placement_group(bundles, strategy="PACK")
    ray.get(pg.ready())
    inits = []
    policy_group = ModelGroup(
        pg,
        PolicyTrainerRayProcess,
        args.num_learners_per_node,
        args.single_gpu_mode,
    )

    wandb_url = wandb.run.get_url() if args.with_tracking else None
    inits.extend(
        model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer)
        for model in policy_group.models
    )
    max_len = args.max_prompt_token_length + args.response_length
    vllm_engines = create_vllm_engines(
        args.vllm_num_engines,
        args.vllm_tensor_parallel_size,
        args.vllm_enforce_eager,
        model_config.model_name_or_path,
        model_config.model_revision,
        args.seed,
        args.vllm_enable_prefix_caching,
        max_len,
        args.vllm_gpu_memory_utilization,
        args.single_gpu_mode,
        pg=pg if args.single_gpu_mode else None,
    )
    ray.get(inits)
    print("======== ✅ all models and vLLM engines initialized =========")

    ray.get([m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models])
    print("======== ✅ model update group setup successfully =========")

    # Setup training
    generation_config = SamplingParams(
        temperature=args.temperature,
        top_p=1.0,
        max_tokens=args.response_length,
        include_stop_str_in_output=True,
        n=args.num_samples_per_prompt_rollout,
        stop=args.stop_strings,
    )
    eval_generation_config = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=args.response_length,
        include_stop_str_in_output=True,
        n=1,  # since we are doing greedy sampling, don't need to generate more
        stop=args.stop_strings,
    )
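    # Note: rollouts sample at the training temperature for exploration, while
    # evaluation decodes greedily (temperature=0.0, n=1); both stop on
    # args.stop_strings ("</answer>") and keep the stop string for reward parsing.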

    iter_dataloader = DataLoader(
        reasoning_gym_dataset, batch_size=args.num_unique_prompts_rollout, collate_fn=list_preserving_collate
    )
    inference_results_Q = Queue(maxsize=1)
    param_prompt_Q = Queue(maxsize=1)
    evaluation_inference_results_Q = Queue(maxsize=1)
    packed_sequences_Q = Queue(maxsize=1)
    queries_prompt_Q = Queue(maxsize=1)
    num_eval_samples = 50

    eval_prompt_token_ids = None
    eval_ground_truths = None
    if eval_dataset is not None:
        model_eval_input, eval_items = next(iter(eval_dataset))
        eval_prompt_token_ids = model_eval_input
    resume_training_step = 1
    thread = threading.Thread(
        target=vllm_generate_thread,
        args=(
            vllm_engines,
            generation_config,
            eval_generation_config,
            inference_results_Q,
            param_prompt_Q,
            args.num_training_steps,
            eval_prompt_token_ids,
            evaluation_inference_results_Q,
            args.eval_freq,
            resume_training_step,
        ),
    )

    thread.start()
    print("======== ✅ vllm generate thread starts =========")

    packing_thread = threading.Thread(
        target=data_preparation_thread,
        args=(
            reward_fn,
            inference_results_Q,
            packed_sequences_Q,
            queries_prompt_Q,
            args,
            tokenizer,
            args.num_training_steps,
            reasoning_gym_dataset,
        ),
    )

    packing_thread.start()
    print("======== ✅ data preparation thread starts =========")
    # create the dataloader iterator once; calling iter() anew at every step
    # would always yield the first batch again
    dataloader_iter = iter(iter_dataloader)
    model_inputs, items = next(dataloader_iter)
    queries_prompt_Q.put((model_inputs, items))
    param_prompt_Q.put((None, model_inputs))

    episode = 0
    num_total_tokens = 0
    start_time = time.time()
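    # Pipeline overview: the main thread feeds prompts into queries_prompt_Q /
    # param_prompt_Q, vllm_generate_thread produces rollouts into inference_results_Q,
    # data_preparation_thread packs them into packed_sequences_Q, and the loop below
    # consumes packed batches to train the Ray policy workers. All queues have
    # maxsize=1, so each stage runs at most one step ahead of the next.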
    for training_step in range(resume_training_step, args.num_training_steps + 1):
        episode += args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout

        if training_step != 1:
            try:
                model_inputs, items = next(dataloader_iter)
            except StopIteration:
                # restart the dataloader once it is exhausted
                dataloader_iter = iter(iter_dataloader)
                model_inputs, items = next(dataloader_iter)
            with Timer("🔄 Loading weights using shared memory"):
                ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models])

            queries_prompt_Q.put((model_inputs, items))
            param_prompt_Q.put((None, model_inputs))

        with Timer("[Main Thread] 📦 Getting packed sequences from thread"):
            packed_data = packed_sequences_Q.get()
            data_thread_metrics = packed_data["metrics"]
            B = packed_data["B"]
            collated_data = packed_data["collated_data"]
            num_total_tokens += packed_data["num_new_tokens"]

        if B == 0:
            print("[Main Thread] 🤡 After packing, there is not enough data to train")
            continue

        # ------------------------------------------------------------------------------------------------
        # Train the model
        with Timer("[Main Thread] 🗡️ Training"):
            metrics_list: List[dict[str, float]] = ray.get(
                [
                    policy_group.models[i].train.remote(
                        **collated_data[i],
                        pad_token_id=tokenizer.pad_token_id,
                        num_mini_batches=args.num_mini_batches,
                    )
                    for i in range(args.world_size)
                ]
            )
            average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]}
            metrics = {
                "episode": episode,
                "training_step": training_step,
                "val/num_total_tokens": num_total_tokens,
                "epoch": episode / len(reasoning_gym_dataset),
                "tokens_per_second": num_total_tokens / (time.time() - start_time),
                **data_thread_metrics,
                **average_metrics,
            }

        if args.save_freq > 0 and training_step % args.save_freq == 0:
            with Timer("[Main Thread] 🗡️ Saving model"):
                checkpoint_dir = f"{args.output_dir}_checkpoints"
                step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
                print(f"Saving model at step {training_step} to {step_dir}")
                ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])

    print(f"Saving final model at step {training_step} to {args.output_dir}")
    with Timer("[Main Thread] 🗡️ Saving model"):
        ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])

    thread.join()
    print("======== ✅ vllm generate thread ends =========")
    packing_thread.join()
    print("======== ✅ data preparation thread ends =========")
    ray.shutdown()

    if (
        args.try_auto_save_to_beaker
        and is_beaker_job()
        and len(beaker_config.beaker_dataset_id_urls) > 0
        and args.output_dir.rstrip("/") != "/output"
    ):
        shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
    print("finished training")

    accelerator = Namespace()
    accelerator.is_main_process = True  # hack
    if args.push_to_hub:
        print("Pushing model to hub")
        push_folder_to_hub(
            accelerator,
            args.output_dir,
            args.hf_repo_id,
            args.hf_repo_revision,
        )


if __name__ == "__main__":
    parser = ArgumentParserPlus((GRPOArgs, ModelConfig))
    args, model_config = parser.parse_args_into_dataclasses()
    assert isinstance(args, GRPOArgs)
    assert isinstance(model_config, ModelConfig)
    main(args, model_config, reward_fn)
examples/open-instruct/src/utils.py (new file, 104 lines)
@@ -0,0 +1,104 @@
from typing import List

import torch
from open_instruct.rl_utils2 import PackedSequences, reset_position_ids
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

import reasoning_gym
from reasoning_gym.utils import SYSTEM_PROMPTS


def list_preserving_collate(batch):
    """
    Custom collate function that preserves lists instead of converting to tensors.
    """
    token_ids = [item[0] for item in batch]
    items = [item[1] for item in batch]
    return token_ids, items

class ReasoningGymDataset(Dataset):
    def __init__(self, dataset_name, seed, size, tokenizer, developer_role, developer_prompt):
        self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
        self.dataset_name = dataset_name  # recorded so that trace logging can report the task
        self.tokenizer = tokenizer
        self.developer_role = developer_role
        self.developer_prompt = developer_prompt

    def __len__(self):
        return self.data.size

    def __getitem__(self, index):
        chat_message = [{"role": self.developer_role, "content": self.developer_prompt}]
        item = self.data[index]
        chat_message.append({"role": "user", "content": item["question"]})
        # apply_chat_template(..., tokenize=True) returns prompt token ids, not text
        prompt_tokens = self.tokenizer.apply_chat_template(chat_message, tokenize=True, add_generation_prompt=True)
        prompt_tokens = [token for token in prompt_tokens if token != self.tokenizer.pad_token_id]
        return prompt_tokens, item

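# Illustrative usage: ReasoningGymDataset("chain_sum", 42, 10000, tokenizer,
# "system", SYSTEM_PROMPTS["DeepSeekZero"])[0] returns (prompt_token_ids, item);
# the prompt embeds the developer/system prompt plus item["question"] via the
# tokenizer's chat template.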
def pack_sequences(
    queries: List[List[int]],
    responses: List[List[int]],
    pack_length: int,
    pad_token_id: int,
) -> PackedSequences:
    # assert padding token does not exist in queries and responses
    query_responses = []
    attention_masks = []
    response_masks = []
    num_actions = []
    packed_seq_lens = []
    cur_data = []
    cur_response_mask = []
    cur_num_actions = []
    cur_packed_seq_lens = []
    cur_attention_mask = []
    offset = 0

    for i in range(len(queries)):
        query = queries[i]
        response = responses[i]
        # remove padding (but using vllm so this should not be needed, but just in case)
        query = [t for t in query if t != pad_token_id]
        response = [t for t in response if t != pad_token_id]
        query_response = query + response

        # start a new pack once the current one would overflow pack_length
        if len(query_response) + len(cur_data) > pack_length:
            query_responses.append(cur_data)
            response_masks.append(cur_response_mask)
            attention_masks.append(cur_attention_mask)
            num_actions.append(cur_num_actions)
            packed_seq_lens.append(cur_packed_seq_lens)
            cur_data = []
            cur_response_mask = []
            cur_attention_mask = []
            cur_num_actions = []
            cur_packed_seq_lens = []
            offset = i

        cur_data.extend(query_response)
        cur_num_actions.append(len(response))
        cur_packed_seq_lens.append(len(query_response))
        # response mask: 0 for prompt tokens, i + 1 (1-indexed sequence id) for response tokens
        cur_response_mask.extend([0 for _ in range(len(query))] + [i + 1 for _ in range(len(response))])
        # attention mask: a distinct id per sequence within the pack, starting at 1
        cur_attention_mask.extend([i + 1 - offset for _ in range(len(query_response))])

    if len(cur_data) > 0:
        query_responses.append(cur_data)
        response_masks.append(cur_response_mask)
        attention_masks.append(cur_attention_mask)
        num_actions.append(cur_num_actions)
        packed_seq_lens.append(cur_packed_seq_lens)

    attention_masks_list = [torch.tensor(t) for t in attention_masks]

    return PackedSequences(
        query_responses=[torch.tensor(t) for t in query_responses],
        attention_masks=attention_masks_list,
        position_ids=[reset_position_ids(t.unsqueeze(0)).squeeze(0) for t in attention_masks_list],
        response_masks=[torch.tensor(t) for t in response_masks],
        original_responses=responses,
        num_actions=[torch.tensor(t) for t in num_actions],
        packed_seq_lens=[torch.tensor(t) for t in packed_seq_lens],
    )
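# Worked example with toy numbers: with pack_length=8,
# queries=[[1, 2], [3, 4]] and responses=[[5, 6], [7, 8]], both pairs fit into
# one pack: query_responses=[[1, 2, 5, 6, 3, 4, 7, 8]],
# response_masks=[[0, 0, 1, 1, 0, 0, 2, 2]] and
# attention_masks=[[1, 1, 1, 1, 2, 2, 2, 2]].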