# This example is an adapted version of Bytedance's code: # https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py from typing import Optional import hydra import numpy as np import ray import torch import verl.utils.torch_functional as verl_F from omegaconf import OmegaConf, open_dict from torch.utils.data import Dataset from torchdata.stateful_dataloader import StatefulDataLoader from transformers import PreTrainedTokenizer from verl import DataProto from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.utils.dataset.rl_dataset import collate_fn as verl_collate_fn from verl.utils.model import compute_position_id_with_mask import reasoning_gym import reasoning_gym.utils from reasoning_gym.coaching.experiment import Experiment from reasoning_gym.composite import CompositeDataset, DatasetSpec from reasoning_gym.dataset import ProceduralDataset from reasoning_gym.utils import extract_answer class ReasoningGymDataset(Dataset): def __init__( self, tokenizer: PreTrainedTokenizer, procedural_dataset: Optional[ProceduralDataset] = None, experiment: Optional[Experiment] = None, developer_prompt: Optional[str] = None, developer_role: str = "system", max_prompt_length: int = 2048, truncation: str = "error", ## ['left', 'right', 'error'] ): assert procedural_dataset or experiment, "One of `procedural_dataset` or `experiment` must be provided" assert ( procedural_dataset is None or experiment is None ), "Only one of `procedural_dataset` or `experiment` may be provided" self.tokenizer = tokenizer self.data = procedural_dataset or experiment.composite self.experiment = experiment self.developer_prompt = developer_prompt self.developer_role = developer_role self.max_prompt_length = max_prompt_length self.truncation = truncation def __len__(self) -> int: return len(self.data) def __getitem__(self, index): row_dict = self.data[index].copy() q = row_dict["question"] chat = [] if self.developer_prompt is not None: chat.append({"role": self.developer_role, "content": self.developer_prompt}) chat.append({"role": "user", "content": q}) prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( prompt=prompt, tokenizer=self.tokenizer, max_length=self.max_prompt_length, pad_token_id=self.tokenizer.pad_token_id, left_pad=True, truncation=self.truncation, ) position_ids = compute_position_id_with_mask(attention_mask) item = {} item["index"] = index item["input_ids"] = input_ids[0] item["attention_mask"] = attention_mask[0] item["position_ids"] = position_ids[0] item["raw_prompt_ids"] = item["input_ids"].tolist() return item def make_dataset( tokenizer, data_source: Experiment | ProceduralDataset, developer_prompt: str, max_prompt_length: int = 2048, ) -> ReasoningGymDataset: """ Create ReasoningGymDataset object using either a ProceduralDataset or Experiment as the underlying data source. """ if isinstance(data_source, Experiment): return ReasoningGymDataset( tokenizer=tokenizer, experiment=data_source, developer_prompt=developer_prompt, developer_role="system", max_prompt_length=max_prompt_length, truncation="error", ) else: return ReasoningGymDataset( tokenizer=tokenizer, procedural_dataset=data_source, developer_prompt=developer_prompt, developer_role="system", max_prompt_length=max_prompt_length, truncation="error", ) def prepare_datasets(config, tokenizer) -> tuple[ReasoningGymDataset, ReasoningGymDataset]: """Prepare training and validation datasets.""" dataset_size = config.reasoning_gym.dataset_size developer_prompt_setting = config.reasoning_gym.developer_prompt developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS[developer_prompt_setting] dataset_specs = [ DatasetSpec( name=name, weight=ds.weight, config=OmegaConf.to_container(ds.config, resolve=True) if "config" in ds else {}, ) for name, ds in config.reasoning_gym.datasets.items() ] train_data_source = reasoning_gym.create_dataset("composite", seed=1, size=dataset_size, datasets=dataset_specs) val_data_source = reasoning_gym.create_dataset("composite", seed=2, size=dataset_size, datasets=dataset_specs) train_dataset = make_dataset( tokenizer, train_data_source, developer_prompt, max_prompt_length=config.data.max_prompt_length ) val_dataset = make_dataset( tokenizer, val_data_source, developer_prompt, max_prompt_length=config.data.max_prompt_length ) return train_dataset, val_dataset class RayPPOTrainerCustom(RayPPOTrainer): def __init__( self, config, tokenizer, role_worker_mapping: dict, resource_pool_manager, ray_worker_group_cls, train_dataset: ReasoningGymDataset, val_dataset: ReasoningGymDataset, dataset_name: str = "chain_sum", dataset_size: int = 10000, ): self.dataset_name = dataset_name self.dataset_size = dataset_size developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"] self.train_dataset = train_dataset self.val_dataset = val_dataset def make_reward_fn(num_examine: int): def reward_fn(data: DataProto, return_dict: bool = False, **unused_kwargs): tensor = self._score_output(data, num_examine=num_examine) if return_dict: # wrap it so trainer can pull out extras return {"reward_tensor": tensor, "reward_extra_info": {}} return tensor return reward_fn train_reward_fn = make_reward_fn(num_examine=0) val_reward_fn = make_reward_fn(num_examine=1) super().__init__( config, tokenizer, role_worker_mapping, resource_pool_manager, ray_worker_group_cls, train_reward_fn, val_reward_fn, train_dataset=train_dataset, val_dataset=val_dataset, train_sampler=None, ) def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor: reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) num_printed = 0 for i in range(len(data)): data_item = data[i] # DataProtoItem prompt_ids = data_item.batch["prompts"] # tokenized prompts prompt_length = prompt_ids.shape[-1] valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] response_ids = data_item.batch["responses"] valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode prompt_str = self.tokenizer.decode(valid_prompt_ids) response_str = self.tokenizer.decode(valid_response_ids) sequences_str = prompt_str + response_str index = data_item.non_tensor_batch["index"] score = self._compute_score( solution_str=response_str, index=index, ) reward_tensor[i, valid_response_length - 1] = score if num_printed < num_examine: print(f"reward={score}, seq={sequences_str}") num_printed += 1 return reward_tensor def _compute_score(self, solution_str: str, index: int) -> float: found_answer = extract_answer(solution_str, tag_name="answer") entry = self.train_dataset.data[index] reward = self.train_dataset.data.score_answer(found_answer, entry=entry) return reward def _create_dataloader(self, train_dataset, val_dataset, collate_fn=None, sampler=None): if collate_fn is None: collate_fn = verl_collate_fn self.train_dataloader = StatefulDataLoader( dataset=train_dataset, batch_size=self.config.data.train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn, ) self.val_dataloader = StatefulDataLoader( dataset=val_dataset, batch_size=self.config.data.val_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn, ) assert len(self.train_dataloader) >= 1 assert len(self.val_dataloader) >= 1 print(f"Size of train dataloader: {len(self.train_dataloader)}") print(f"Size of val dataloader: {len(self.val_dataloader)}") # inject total_training_steps to actor/critic optim_config. This is hacky. total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs if self.config.trainer.total_training_steps is not None: total_training_steps = self.config.trainer.total_training_steps self.total_training_steps = total_training_steps print(f"Total training steps: {self.total_training_steps}") OmegaConf.set_struct(self.config, True) with open_dict(self.config): self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps self.config.critic.optim.total_training_steps = total_training_steps @ray.remote def main_task(config): # print initial config from pprint import pprint from verl.utils import hf_tokenizer from verl.utils.fs import copy_local_path_from_hdfs pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values OmegaConf.resolve(config) # download the checkpoint from hdfs local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) # instantiate tokenizer tokenizer = hf_tokenizer(local_path) train_dataset, val_dataset = prepare_datasets(config, tokenizer) # define worker classes if config.actor_rollout_ref.actor.strategy == "fsdp": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray import RayWorkerGroup from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker ray_worker_group_cls = NVMegatronRayWorkerGroup else: raise NotImplementedError from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role role_worker_mapping = { Role.ActorRollout: ray.remote(ActorRolloutRefWorker), Role.Critic: ray.remote(CriticWorker), Role.RefPolicy: ray.remote(ActorRolloutRefWorker), } global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, } mapping = { Role.ActorRollout: global_pool_id, Role.Critic: global_pool_id, Role.RefPolicy: global_pool_id, } resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) trainer = RayPPOTrainerCustom( config=config, tokenizer=tokenizer, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, train_dataset=train_dataset, val_dataset=val_dataset, ) trainer.init_workers() trainer.fit() @hydra.main(config_path="config", config_name="grpo_trainer", version_base=None) def main(config): if not ray.is_initialized(): # this is for local ray cluster ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) ray.get(main_task.remote(config)) if __name__ == "__main__": main()