diff --git a/examples/veRL/.gitignore b/examples/veRL/.gitignore
new file mode 100644
index 00000000..c54a47c0
--- /dev/null
+++ b/examples/veRL/.gitignore
@@ -0,0 +1,3 @@
+outputs/
+wandb/
+verl_output.log
diff --git a/examples/veRL/config/ppo_trainer.yaml b/examples/veRL/config/ppo_trainer.yaml
new file mode 100644
index 00000000..b294a7cb
--- /dev/null
+++ b/examples/veRL/config/ppo_trainer.yaml
@@ -0,0 +1,167 @@
+data:
+  tokenizer: null
+  train_files: ~/data/rlhf/gsm8k/train.parquet
+  val_files: ~/data/rlhf/gsm8k/test.parquet
+  prompt_key: prompt
+  max_prompt_length: 512
+  max_response_length: 512
+  train_batch_size: 1024
+  val_batch_size: 1312
+  return_raw_input_ids: False  # set to True when the policy and reward-model tokenizers differ
+  return_raw_chat: False
+
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    path: ~/models/deepseek-llm-7b-chat
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: False
+  actor:
+    strategy: fsdp  # kept for backward compatibility
+    ppo_mini_batch_size: 256
+    ppo_micro_batch_size: null  # will be deprecated; use ppo_micro_batch_size_per_gpu
+    ppo_micro_batch_size_per_gpu: null
+    use_dynamic_bsz: False
+    ppo_max_token_len_per_gpu: 16384  # n * (${data.max_prompt_length} + ${data.max_response_length})
+    grad_clip: 1.0
+    clip_ratio: 0.2
+    entropy_coeff: 0.001
+    use_kl_loss: False  # True for GRPO
+    kl_loss_coef: 0.001  # for GRPO
+    kl_loss_type: low_var_kl  # for GRPO
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1  # sp size
+    optim:
+      lr: 1e-6
+      lr_warmup_steps_ratio: 0.  # the total steps will be injected at runtime
+      min_lr_ratio: null  # only useful for warmup with cosine
+      warmup_style: constant  # select from constant/cosine
+      total_training_steps: -1  # must be overridden by the program
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      grad_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
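+  # Rough sketch of how the batch knobs above nest (not part of the upstream verl
+  # config; check verl's fsdp_workers for the exact behavior): each PPO iteration
+  # samples data.train_batch_size prompts, the actor update consumes them in
+  # ppo_mini_batch_size chunks per optimizer step, and each GPU runs
+  # forward/backward on ppo_micro_batch_size_per_gpu samples at a time,
+  # accumulating gradients in between:
+  #   grad_accum_steps ~= ppo_mini_batch_size / (n_gpus * ppo_micro_batch_size_per_gpu)
+  #   e.g. 256 / (8 * 8) = 4 if ppo_micro_batch_size_per_gpu were set to 8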
+  ref:
+    fsdp_config:
+      param_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    log_prob_micro_batch_size: null  # will be deprecated; use log_prob_micro_batch_size_per_gpu
+    log_prob_micro_batch_size_per_gpu: null
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}  # sp size
+  rollout:
+    name: vllm
+    temperature: 1.0
+    top_k: -1  # 0 for hf rollout, -1 for vllm rollout
+    top_p: 1
+    prompt_length: ${data.max_prompt_length}  # not used in the open-source version
+    response_length: ${data.max_response_length}
+    # for vllm rollout
+    dtype: bfloat16  # should align with FSDP
+    gpu_memory_utilization: 0.5
+    ignore_eos: False
+    enforce_eager: True
+    free_cache_engine: True
+    load_format: dummy_dtensor
+    tensor_model_parallel_size: 2
+    max_num_batched_tokens: 8192
+    max_num_seqs: 1024
+    log_prob_micro_batch_size: null  # will be deprecated; use log_prob_micro_batch_size_per_gpu
+    log_prob_micro_batch_size_per_gpu: null
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    disable_log_stats: True
+    enable_chunked_prefill: True  # may give higher throughput
+    # for hf rollout
+    do_sample: True
+    # number of responses (i.e. number of sample times)
+    n: 1  # > 1 for GRPO
+
+critic:
+  strategy: fsdp
+  optim:
+    lr: 1e-5
+    lr_warmup_steps_ratio: 0.  # the total steps will be injected at runtime
+    min_lr_ratio: null  # only useful for warmup with cosine
+    warmup_style: constant  # select from constant/cosine
+    total_training_steps: -1  # must be overridden by the program
+  model:
+    path: ~/models/deepseek-llm-7b-chat
+    tokenizer_path: ${actor_rollout_ref.model.path}
+    override_config: { }
+    external_lib: ${actor_rollout_ref.model.external_lib}
+    enable_gradient_checkpointing: True
+    use_remove_padding: False
+    fsdp_config:
+      param_offload: False
+      grad_offload: False
+      optimizer_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      fsdp_size: -1
+  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
+  ppo_micro_batch_size: null  # will be deprecated; use ppo_micro_batch_size_per_gpu
+  ppo_micro_batch_size_per_gpu: null
+  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
+  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
+  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+  ppo_max_token_len_per_gpu: 32768  # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
+  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
+  ulysses_sequence_parallel_size: 1  # sp size
+  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+  shuffle: ${actor_rollout_ref.actor.shuffle}
+  grad_clip: 1.0
+  cliprange_value: 0.5
+
+reward_model:
+  enable: False
+  strategy: fsdp
+  model:
+    input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical
+    path: ~/models/FsfairX-LLaMA3-RM-v0.1
+    external_lib: ${actor_rollout_ref.model.external_lib}
+    use_remove_padding: False
+    fsdp_config:
+      min_num_params: 0
+      param_offload: False
+      fsdp_size: -1
+  micro_batch_size: null  # will be deprecated; use micro_batch_size_per_gpu
+  micro_batch_size_per_gpu: null  # set a number
+  max_length: null
+  ulysses_sequence_parallel_size: 1  # sp size
+  use_dynamic_bsz: ${critic.use_dynamic_bsz}
+  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
+
+algorithm:
+  gamma: 1.0
+  lam: 1.0
+  adv_estimator: gae
+  kl_penalty: kl  # how to estimate the KL divergence
+  kl_ctrl:
+    type: fixed
+    kl_coef: 0.001
+
+trainer:
+  total_epochs: 30
+  total_training_steps: null
+  project_name: verl_examples
+  experiment_name: gsm8k
+  logger: [ 'console', 'wandb' ]
+  nnodes: 1
+  n_gpus_per_node: 8
+  save_freq: -1
+  test_freq: -1
+  critic_warmup: 0
+  default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
+  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
diff --git a/examples/veRL/launch_on_4gpu.sh b/examples/veRL/launch_on_4gpu.sh
new file mode 100755
index 00000000..0a51f68c
--- /dev/null
+++ b/examples/veRL/launch_on_4gpu.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+export N_GPUS=4
+export BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
+export ROLLOUT_TP_SIZE=2
+export EXPERIMENT_NAME=chain_sum_llama
+export VLLM_ATTENTION_BACKEND=XFORMERS
+
+bash ./train.sh
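+
+# Usage note: train.sh also reads $DATA_DIR for the gsm8k parquet paths, but this
+# example's custom trainer builds its train/val sets procedurally with
+# reasoning_gym, so those files should not actually be loaded.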
diff --git a/examples/veRL/main_ppo_custom_reward.py b/examples/veRL/main_ppo_custom_reward.py
new file mode 100644
index 00000000..ef28ecf5
--- /dev/null
+++ b/examples/veRL/main_ppo_custom_reward.py
@@ -0,0 +1,353 @@
+# This example is a modified version of:
+# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py
+
+
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Note that we don't combine main with ray_trainer, since ray_trainer is used by other main entry points.
+"""
+
+from typing import Optional
+from omegaconf import OmegaConf, open_dict
+
+import reasoning_gym
+import reasoning_gym.utils
+from reasoning_gym.utils import extract_answer
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+from transformers import PreTrainedTokenizer
+
+import ray
+import hydra
+
+from verl import DataProto
+from verl.trainer.ppo.ray_trainer import RayPPOTrainer
+from verl.utils.model import compute_position_id_with_mask
+from verl.utils.dataset.rl_dataset import collate_fn
+import verl.utils.torch_functional as verl_F
+
+
+class RewardManager:
+    """Computes token-level rewards for a batch of sampled sequences."""
+
+    def __init__(self, tokenizer, num_examine, compute_score) -> None:
+        self.tokenizer = tokenizer
+        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
+        self.compute_score = compute_score
+
+    def __call__(self, data: DataProto):
+        """We will expand this function gradually based on the available datasets."""
+
+        # If there is an rm score, return it directly. Otherwise, compute the score via compute_score.
+        if "rm_scores" in data.batch.keys():
+            return data.batch["rm_scores"]
+
+        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
+
+        already_print_data_sources = {}
+
+        for i in range(len(data)):
+            data_item = data[i]  # DataProtoItem
+
+            prompt_ids = data_item.batch["prompts"]
+            prompt_length = prompt_ids.shape[-1]
+
+            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
+            valid_prompt_ids = prompt_ids[-valid_prompt_length:]
+
+            response_ids = data_item.batch["responses"]
+            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
+            valid_response_ids = response_ids[:valid_response_length]
+
+            # decode the full prompt + response
+            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
+            sequences_str = self.tokenizer.decode(sequences)
+
+            data_source = data_item.non_tensor_batch["data_source"]
+            ground_truth = data_item.non_tensor_batch["answer"]
+            index = data_item.non_tensor_batch["index"]
+
+            score = self.compute_score(
+                data_source=data_source,
+                solution_str=sequences_str,
+                ground_truth=ground_truth,
+                index=index,
+            )
+            reward_tensor[i, valid_response_length - 1] = score
+
+            if data_source not in already_print_data_sources:
+                already_print_data_sources[data_source] = 0
+
+            if already_print_data_sources[data_source] < self.num_examine:
+                already_print_data_sources[data_source] += 1
+                print(sequences_str)
+
+        return reward_tensor
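+
+# Reward layout produced above, for orientation: the tensor is sparse. For a batch
+# row whose response has 5 valid tokens and scores 1.0, reward_tensor[i] looks like
+#   [0, 0, 0, 0, 1.0, 0, ..., 0]
+# i.e. the whole sequence-level score sits on the last valid response token and all
+# other positions (including padding) stay zero.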
+
+
+@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
+def main(config):
+    if not ray.is_initialized():
+        # this is for a local ray cluster
+        ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+
+    ray.get(main_task.remote(config))
+
+
+class ReasoningGymDataset(Dataset):
+    def __init__(
+        self,
+        dataset_name: str,
+        tokenizer: PreTrainedTokenizer,
+        seed: int,
+        size: int,
+        developer_prompt: Optional[str] = None,
+        developer_role: str = "system",
+        max_prompt_length: int = 2048,
+        truncation: str = "error",  # one of ['left', 'right', 'error']
+        return_raw_chat: bool = False,
+    ):
+        self.tokenizer = tokenizer
+        self.dataset_name = dataset_name
+        self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
+        self.developer_prompt = developer_prompt
+        self.developer_role = developer_role
+        self.max_prompt_length = max_prompt_length
+        self.truncation = truncation
+        self.return_raw_chat = return_raw_chat
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def __getitem__(self, index):
+        row_dict = self.data[index].copy()
+        q = row_dict["question"]
+
+        chat = []
+        if self.developer_prompt is not None:
+            chat.append({"role": self.developer_role, "content": self.developer_prompt})
+        chat.append({"role": "user", "content": q})
+
+        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+
+        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
+            prompt=prompt,
+            tokenizer=self.tokenizer,
+            max_length=self.max_prompt_length,
+            pad_token_id=self.tokenizer.pad_token_id,
+            left_pad=True,
+            truncation=self.truncation,
+        )
+
+        position_ids = compute_position_id_with_mask(attention_mask)
+
+        row_dict["data_source"] = "reasoning_gym/" + self.dataset_name
+        row_dict["input_ids"] = input_ids[0]
+        row_dict["attention_mask"] = attention_mask[0]
+        row_dict["position_ids"] = position_ids[0]
+
+        # optionally return the raw chat messages (without the chat template applied)
+        if self.return_raw_chat:
+            row_dict["raw_prompt"] = chat
+
+        # add an index for each prompt so the reward function can look up the dataset entry
+        row_dict["index"] = index
+
+        return row_dict
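+
+# Minimal usage sketch (hypothetical, not exercised by this example; assumes a
+# tokenizer with a chat template and a pad token):
+#   tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+#   ds = ReasoningGymDataset("chain_sum", tok, seed=1, size=4,
+#                            developer_prompt=reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"])
+#   row = ds[0]  # dict with input_ids, attention_mask, position_ids, and index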
+
+
+class RayPPOTrainerCustom(RayPPOTrainer):
+    def __init__(
+        self,
+        config,
+        tokenizer,
+        role_worker_mapping: dict,
+        resource_pool_manager,
+        ray_worker_group_cls,
+        reward_fn=None,
+        val_reward_fn=None,
+        dataset_name: str = "chain_sum",
+        dataset_size: int = 10000,
+    ):
+        self.dataset_name = dataset_name
+        self.dataset_size = dataset_size
+
+        developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
+        self.train_dataset = ReasoningGymDataset(
+            dataset_name=self.dataset_name,
+            tokenizer=tokenizer,
+            seed=1,
+            size=self.dataset_size,
+            developer_prompt=developer_prompt,
+        )
+
+        self.val_dataset = ReasoningGymDataset(
+            dataset_name=self.dataset_name,
+            tokenizer=tokenizer,
+            seed=2,
+            size=self.dataset_size,
+            developer_prompt=developer_prompt,
+        )
+
+        # override any passed-in reward_fn with one backed by the procedural dataset
+        reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0, compute_score=self._compute_score)
+
+        super().__init__(
+            config,
+            tokenizer,
+            role_worker_mapping,
+            resource_pool_manager,
+            ray_worker_group_cls,
+            reward_fn,
+            val_reward_fn,
+        )
+
+    def _compute_score(self, data_source, solution_str, ground_truth, index) -> float:
+        print("Solution:", solution_str, ground_truth, index, data_source)
+        found_answer = extract_answer(solution_str, tag_name="answer")
+        entry = self.train_dataset.data[index]
+        return self.train_dataset.data.score_answer(found_answer, entry=entry)
+
+    def _create_dataloader(self):
+        self.train_dataloader = DataLoader(
+            dataset=self.train_dataset,
+            batch_size=self.config.data.train_batch_size,
+            shuffle=True,
+            drop_last=True,
+            collate_fn=collate_fn,
+        )
+
+        self.val_dataloader = DataLoader(
+            dataset=self.val_dataset,
+            batch_size=len(self.val_dataset),
+            shuffle=True,
+            drop_last=True,
+            collate_fn=collate_fn,
+        )
+
+        assert len(self.train_dataloader) >= 1
+        assert len(self.val_dataloader) >= 1
+
+        print(f"Size of train dataloader: {len(self.train_dataloader)}")
+        print(f"Size of val dataloader: {len(self.val_dataloader)}")
+
+        # inject total_training_steps into the actor/critic optim configs. This is hacky.
+        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
+
+        if self.config.trainer.total_training_steps is not None:
+            total_training_steps = self.config.trainer.total_training_steps
+
+        self.total_training_steps = total_training_steps
+        print(f"Total training steps: {self.total_training_steps}")
+
+        OmegaConf.set_struct(self.config, True)
+        with open_dict(self.config):
+            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
+            self.config.critic.optim.total_training_steps = total_training_steps
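+
+# Scoring flow, for orientation: RewardManager decodes each sampled sequence and
+# calls _compute_score, which pulls the model's answer out of <answer>...</answer>
+# tags via reasoning_gym's extract_answer (the format the DeepSeekZero developer
+# prompt asks for) and scores it with score_answer against the dataset entry at
+# the sample's index.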
+
+
+@ray.remote
+def main_task(config, compute_score=None):
+    from verl.utils.fs import copy_local_path_from_hdfs
+
+    # print the initial config
+    from pprint import pprint
+    from omegaconf import OmegaConf
+
+    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will evaluate symbol values
+    OmegaConf.resolve(config)
+
+    # download the checkpoint from hdfs
+    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
+
+    # instantiate the tokenizer
+    from verl.utils import hf_tokenizer
+
+    tokenizer = hf_tokenizer(local_path)
+
+    # define worker classes
+    if config.actor_rollout_ref.actor.strategy == "fsdp":
+        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
+        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.single_controller.ray import RayWorkerGroup
+
+        ray_worker_group_cls = RayWorkerGroup
+
+    elif config.actor_rollout_ref.actor.strategy == "megatron":
+        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
+        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
+
+        ray_worker_group_cls = NVMegatronRayWorkerGroup
+
+    else:
+        raise NotImplementedError
+
+    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
+
+    role_worker_mapping = {
+        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
+        Role.Critic: ray.remote(CriticWorker),
+        Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
+    }
+
+    global_pool_id = "global_pool"
+    resource_pool_spec = {
+        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
+    }
+    mapping = {
+        Role.ActorRollout: global_pool_id,
+        Role.Critic: global_pool_id,
+        Role.RefPolicy: global_pool_id,
+    }
+
+    # we should adopt a multi-source reward function here:
+    # - for a rule-based rm, we directly call a reward score
+    # - for a model-based rm, we call a model
+    # - for code-related prompts, we send them to a sandbox if there are test cases
+    # - finally, we combine all the rewards together
+    # - the reward type depends on the tag of the data
+    if config.reward_model.enable:
+        if config.reward_model.strategy == "fsdp":
+            from verl.workers.fsdp_workers import RewardModelWorker
+        elif config.reward_model.strategy == "megatron":
+            from verl.workers.megatron_workers import RewardModelWorker
+        else:
+            raise NotImplementedError
+        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+        mapping[Role.RewardModel] = global_pool_id
+
+    # Note that we always use a function-based RM for validation; compute_score is
+    # None unless a custom scoring function is passed to main_task.
+    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)
+
+    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+
+    trainer = RayPPOTrainerCustom(
+        config=config,
+        tokenizer=tokenizer,
+        role_worker_mapping=role_worker_mapping,
+        resource_pool_manager=resource_pool_manager,
+        ray_worker_group_cls=ray_worker_group_cls,
+        val_reward_fn=val_reward_fn,
+    )
+    trainer.init_workers()
+    trainer.fit()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/veRL/train.sh b/examples/veRL/train.sh
new file mode 100755
index 00000000..92ed0b84
--- /dev/null
+++ b/examples/veRL/train.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+python3 -u main_ppo_custom_reward.py \
+data.train_files=$DATA_DIR/train.parquet \
+data.val_files=$DATA_DIR/test.parquet \
+data.train_batch_size=256 \
+data.val_batch_size=1312 \
+data.max_prompt_length=256 \
+data.max_response_length=1024 \
+actor_rollout_ref.model.path=$BASE_MODEL \
+actor_rollout_ref.actor.optim.lr=1e-6 \
+actor_rollout_ref.actor.ppo_mini_batch_size=128 \
+actor_rollout_ref.actor.ppo_micro_batch_size=8 \
+actor_rollout_ref.rollout.log_prob_micro_batch_size=8 \
+actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
+actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
+actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
+critic.optim.lr=1e-5 \
+critic.model.path=$BASE_MODEL \
+critic.ppo_micro_batch_size=8 \
+algorithm.kl_ctrl.kl_coef=0.001 \
+trainer.logger=['wandb'] \
++trainer.val_before_train=False \
+trainer.default_hdfs_dir=null \
+trainer.n_gpus_per_node=$N_GPUS \
+trainer.nnodes=1 \
+trainer.save_freq=100 \
+trainer.test_freq=100 \
+trainer.project_name=verl_chain_sum \
+trainer.experiment_name=$EXPERIMENT_NAME \
+trainer.total_epochs=15 2>&1 | tee verl_output.log
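+
+# Note on syntax: the leading '+' in '+trainer.val_before_train=False' is Hydra's
+# override marker for adding a key that ppo_trainer.yaml does not define; the
+# remaining overrides replace values the config file already sets.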