# This example is an adapted version of Bytedance's code: # https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py from typing import Optional import hydra import ray import torch import verl.utils.torch_functional as verl_F from omegaconf import OmegaConf, open_dict from torch.utils.data import DataLoader, Dataset from transformers import PreTrainedTokenizer from verl import DataProto from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.utils.dataset.rl_dataset import collate_fn from verl.utils.model import compute_position_id_with_mask import reasoning_gym import reasoning_gym.utils from reasoning_gym.utils import extract_answer class ReasoningGymDataset(Dataset): def __init__( self, tokenizer: PreTrainedTokenizer, dataset_name: str, seed: int, size: int, developer_prompt: Optional[str] = None, developer_role: str = "system", max_prompt_length: int = 2048, truncation: str = "error", ## ['left', 'right', 'error'] return_raw_chat: bool = False, ): self.tokenizer = tokenizer self.dataset_name = dataset_name self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size) self.developer_prompt = developer_prompt self.developer_role = developer_role self.max_prompt_length = max_prompt_length self.truncation = truncation self.return_raw_chat = return_raw_chat def __len__(self) -> int: return len(self.data) def __getitem__(self, index): row_dict = self.data[index].copy() q = row_dict["question"] chat = [] if self.developer_prompt is not None: chat.append({"role": self.developer_role, "content": self.developer_prompt}) chat.append({"role": "user", "content": q}) prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( prompt=prompt, tokenizer=self.tokenizer, max_length=self.max_prompt_length, pad_token_id=self.tokenizer.pad_token_id, left_pad=True, truncation=self.truncation, ) position_ids = compute_position_id_with_mask(attention_mask) row_dict["data_source"] = "reasoning_gym/" + self.dataset_name row_dict["input_ids"] = input_ids[0] row_dict["attention_mask"] = attention_mask[0] row_dict["position_ids"] = position_ids[0] # encode prompts without chat template if self.return_raw_chat: row_dict["raw_prompt"] = chat.tolist() # add index for each prompt # index = row_dict.get("extra_info", {}).get("index", 0) row_dict["index"] = index return row_dict class RayPPOTrainerCustom(RayPPOTrainer): def __init__( self, config, tokenizer, role_worker_mapping: dict, resource_pool_manager, ray_worker_group_cls, dataset_name: str = "chain_sum", dataset_size: int = 10000, ): self.dataset_name = dataset_name self.dataset_size = dataset_size developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"] self.train_dataset = ReasoningGymDataset( tokenizer=tokenizer, dataset_name=self.dataset_name, seed=1, size=self.dataset_size, developer_prompt=developer_prompt, ) self.val_dataset = ReasoningGymDataset( tokenizer=tokenizer, dataset_name=self.dataset_name, seed=2, size=self.dataset_size, developer_prompt=developer_prompt, ) train_reward_fn = lambda data: self._score_output(data, num_examine=0) val_reward_fn = lambda data: self._score_output(data, num_examine=1) super().__init__( config, tokenizer, role_worker_mapping, resource_pool_manager, ray_worker_group_cls, train_reward_fn, val_reward_fn, ) def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor: reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) num_printed = 0 for i in range(len(data)): data_item = data[i] # DataProtoItem prompt_ids = data_item.batch["prompts"] # tokenized prompts prompt_length = prompt_ids.shape[-1] valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] response_ids = data_item.batch["responses"] valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode sequences = torch.cat((valid_prompt_ids, valid_response_ids)) sequences_str = self.tokenizer.decode(sequences) index = data_item.non_tensor_batch["index"] score = self._compute_score( solution_str=sequences_str, index=index, ) reward_tensor[i, valid_response_length - 1] = score if num_printed < num_examine: print(f"reward={score}, seq={sequences_str}") num_printed += 1 return reward_tensor def _compute_score(self, solution_str: str, index: int) -> float: found_answer = extract_answer(solution_str, tag_name="answer") entry = self.train_dataset.data[index] reward = self.train_dataset.data.score_answer(found_answer, entry=entry) # print(f"found answer={found_answer}; reward: {reward};") return reward def _create_dataloader(self): self.train_dataloader = DataLoader( dataset=self.train_dataset, batch_size=self.config.data.train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn, ) self.val_dataloader = DataLoader( dataset=self.val_dataset, batch_size=len(self.val_dataset), shuffle=True, drop_last=True, collate_fn=collate_fn, ) assert len(self.train_dataloader) >= 1 assert len(self.val_dataloader) >= 1 print(f"Size of train dataloader: {len(self.train_dataloader)}") print(f"Size of val dataloader: {len(self.val_dataloader)}") # inject total_training_steps to actor/critic optim_config. This is hacky. total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs if self.config.trainer.total_training_steps is not None: total_training_steps = self.config.trainer.total_training_steps self.total_training_steps = total_training_steps print(f"Total training steps: {self.total_training_steps}") OmegaConf.set_struct(self.config, True) with open_dict(self.config): self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps self.config.critic.optim.total_training_steps = total_training_steps @ray.remote def main_task(config): # print initial config from pprint import pprint from verl.utils import hf_tokenizer from verl.utils.fs import copy_local_path_from_hdfs pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values OmegaConf.resolve(config) # download the checkpoint from hdfs local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) # instantiate tokenizer tokenizer = hf_tokenizer(local_path) # define worker classes if config.actor_rollout_ref.actor.strategy == "fsdp": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray import RayWorkerGroup from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker ray_worker_group_cls = NVMegatronRayWorkerGroup else: raise NotImplementedError from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role role_worker_mapping = { Role.ActorRollout: ray.remote(ActorRolloutRefWorker), Role.Critic: ray.remote(CriticWorker), Role.RefPolicy: ray.remote(ActorRolloutRefWorker), } global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, } mapping = { Role.ActorRollout: global_pool_id, Role.Critic: global_pool_id, Role.RefPolicy: global_pool_id, } resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) trainer = RayPPOTrainerCustom( config=config, tokenizer=tokenizer, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, ) trainer.init_workers() trainer.fit() @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) def main(config): if not ray.is_initialized(): # this is for local ray cluster ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) ray.get(main_task.remote(config)) if __name__ == "__main__": main()