diff --git a/training/README.md b/training/README.md
new file mode 100644
index 00000000..6234ac3d
--- /dev/null
+++ b/training/README.md
@@ -0,0 +1,66 @@
+# Reasoning Gym Model Training
+
+Codebase for training LLMs with Reasoning Gym procedural dataset generators.
+
+### Requirements
+
+1. Prepare and activate a Python 3.11 virtual environment however you prefer.
+2. Install Reasoning Gym:
+
+```bash
+cd reasoning-gym/
+pip install -e .
+```
+
+3. Install training-specific Python package dependencies:
+
+```bash
+pip install ray wandb
+pip install torch==2.6.0
+pip install flash-attn --no-build-isolation
+```
+
+4. Install veRL (tested with HEAD c34206925e2a50fd452e474db857b4d488f8602d):
+
+```bash
+git clone https://github.com/volcengine/verl.git
+cd verl
+pip install -e .
+```
+
+5. Install vLLM:
+
+```bash
+pip install -U vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
+```
+
+6. Log in to Hugging Face and Weights & Biases:
+
+```bash
+huggingface-cli login
+wandb login
+```
+
+### Usage
+
+First, activate the virtual environment you prepared.
+
+Example GRPO training usage:
+
+```bash
+python3 -u train_grpo.py --config-name llama3.1_1b_grpo \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    trainer.project_name=rg-test \
+    trainer.experiment_name=verl_grpo_llama3.1_1b \
+    trainer.n_gpus_per_node=2 $@ 2>&1 | tee verl_output.log
+```
+
+Then, having saved this as a bash script such as `train.sh`, run it:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 bash train.sh
+```
+
+`CUDA_VISIBLE_DEVICES` is set to `0,1` to use the first two GPUs on the machine (see `nvidia-smi` output); adjust as needed. `tensor_model_parallel_size` and `n_gpus_per_node` should also be set to the number of GPUs you are using.
+
+You can change any configuration option by either modifying the config YAML (in this case, `configs/llama3.1_1b_grpo.yaml`) or passing it as an argument to the Python script. Note that the batch sizes in the Llama 1B and Qwen 1.5B configs are the highest I could run for the puzzles dataset mix on 2x A6000 GPUs without OOM errors. Depending on your hardware and the datasets you train on, you may need to adjust them.
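+
+For example, to halve those batch sizes from the command line rather than editing the YAML (a hypothetical override; the key names are taken from `configs/llama3.1_1b_grpo.yaml`):
+
+```bash
+python3 -u train_grpo.py --config-name llama3.1_1b_grpo \
+    data.train_batch_size=32 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=16 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8
+```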
diff --git a/training/configs/llama3.1_1b_grpo.yaml b/training/configs/llama3.1_1b_grpo.yaml
new file mode 100644
index 00000000..74200cad
--- /dev/null
+++ b/training/configs/llama3.1_1b_grpo.yaml
@@ -0,0 +1,217 @@
+reasoning_gym:
+  dataset_size: 10000
+  developer_prompt: DeepSeekZero
+  enable_curriculum_learning: False
+  datasets:  # Used if enable_curriculum_learning is False
+    mini_sudoku:
+      weight: 0.33
+      config:
+        min_empty: 6
+    futoshiki:
+      weight: 0.33
+      config:
+        max_board_size: 5
+    sudoku:
+      weight: 0.34
+      config:
+        min_empty: 20
+  curricula:
+    leg_counting:
+      attribute_levels:
+        num_animals: 2
+      weight: 1.0
+    products:
+      attribute_levels:
+        num_terms: 4
+        num_digits: 4
+      weight: 1.0
+    chain_sum:
+      attribute_levels:
+        num_terms: 4
+        num_digits: 4
+      weight: 1.0
+
+reward:
+  format_reward:
+    enable: True
+    scaling_factor: 0.2
+  length_reward:
+    enable: True
+    scaling_factor: 0.2
+
+data:
+  tokenizer: null
+  train_files: train.parquet
+  val_files: test.parquet
+  prompt_key: prompt
+  max_prompt_length: 512
+  max_response_length: 1024
+  train_batch_size: 64
+  val_batch_size: 64
+  return_raw_input_ids: True  # Should be True when the policy and reward model tokenizers differ
+  return_raw_chat: True
+
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    path: meta-llama/Llama-3.2-1B-Instruct
+    external_lib: null
+    override_config: {}
+    enable_gradient_checkpointing: True
+    use_remove_padding: True
+  actor:
+    strategy: fsdp  # This is for backward-compatibility
+    ppo_mini_batch_size: 32
+    ppo_micro_batch_size: null  # will be deprecated; use ppo_micro_batch_size_per_gpu
+    ppo_micro_batch_size_per_gpu: 16
+    use_dynamic_bsz: False
+    ppo_max_token_len_per_gpu: 12288  # n * (${data.max_prompt_length} + ${data.max_response_length})
+    grad_clip: 1.0
+    clip_ratio: 0.2
+    entropy_coeff: 0.001
+    use_kl_loss: True  # True for GRPO
+    kl_loss_coef: 0.001  # for GRPO
+    kl_loss_type: low_var_kl  # for GRPO
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1  # sp size
+    optim:
+      lr: 1e-6
+      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
+      min_lr_ratio: null  # only useful for warmup with cosine
+      warmup_style: constant  # select from constant/cosine
+      total_training_steps: -1  # must be overridden by the program
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+  ref:
+    fsdp_config:
+      param_offload: True
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    log_prob_micro_batch_size: null  # will be deprecated; use log_prob_micro_batch_size_per_gpu
+    log_prob_micro_batch_size_per_gpu: 160
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}  # sp size
+  rollout:
+    name: vllm
+    temperature: 1.0
+    top_k: -1  # 0 for hf rollout, -1 for vllm rollout
+    top_p: 1
+    prompt_length: ${data.max_prompt_length}  # not used for open-source models
+    response_length: ${data.max_response_length}
+    # for vllm rollout
+    dtype: bfloat16  # should align with FSDP
+    gpu_memory_utilization: 0.6
+    ignore_eos: False
+    enforce_eager: True
+    free_cache_engine: True
+    load_format: dummy_dtensor
+    tensor_model_parallel_size: 2
+    max_num_batched_tokens: 8192
+    max_num_seqs: 1024
+    log_prob_micro_batch_size: null  # will be deprecated; use log_prob_micro_batch_size_per_gpu
+    log_prob_micro_batch_size_per_gpu: 160
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    disable_log_stats: True
+    enable_chunked_prefill: True  # could get higher throughput
+    # for hf rollout
+    do_sample: True
+    use_fire_sampling: False
+    # number of responses (i.e. num sample times)
+    n: 8  # > 1 for GRPO
+    val_kwargs:
+      do_sample: True
+
+algorithm:
+  gamma: 1.0
+  lam: 1.0
+  adv_estimator: grpo
+  kl_penalty: kl  # how to estimate kl divergence
+  kl_ctrl:
+    type: fixed
+    kl_coef: 0.001
+
+trainer:
+  balance_batch: True
+  total_epochs: 10
+  total_training_steps: null
+  project_name: rg-test
+  experiment_name: verl_grpo_llama3.1_1b
+  logger: ['console', 'wandb']
+  val_generations_to_log_to_wandb: 0
+  nnodes: 1
+  n_gpus_per_node: 2
+  save_freq: 100
+  # auto: find the last ckpt to resume. If can't find, start from scratch.
+  resume_mode: auto  # or disable, or resume_path if resume_from_path is set
+  resume_from_path: False
+  test_freq: 100
+  critic_warmup: 0
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
+
+critic:
+  strategy: fsdp
+  optim:
+    lr: 1e-5
+    lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
+    min_lr_ratio: null  # only useful for warmup with cosine
+    warmup_style: constant  # select from constant/cosine
+    total_training_steps: -1  # must be overridden by the program
+  model:
+    path: ~/models/deepseek-llm-7b-chat
+    tokenizer_path: ${actor_rollout_ref.model.path}
+    override_config: {}
+    external_lib: ${actor_rollout_ref.model.external_lib}
+    enable_gradient_checkpointing: True
+    use_remove_padding: False
+    fsdp_config:
+      param_offload: False
+      optimizer_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      fsdp_size: -1
+  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
+  ppo_micro_batch_size: null  # will be deprecated; use ppo_micro_batch_size_per_gpu
+  ppo_micro_batch_size_per_gpu: null
+  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
+  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
+  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+  ppo_max_token_len_per_gpu: 32768  # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
+  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
+  ulysses_sequence_parallel_size: 1  # sp size
+  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+  shuffle: ${actor_rollout_ref.actor.shuffle}
+  grad_clip: 1.0
+  cliprange_value: 0.5
+
+# Reward model not used for GRPO
+reward_model:
+  enable: False
+  strategy: fsdp
+  model:
+    input_tokenizer: ${actor_rollout_ref.model.path}
+    path: ~/models/FsfairX-LLaMA3-RM-v0.1
+    external_lib: ${actor_rollout_ref.model.external_lib}
+    use_remove_padding: False
+    fsdp_config:
+      min_num_params: 0
+      param_offload: False
+      fsdp_size: -1
+  micro_batch_size: null
+  micro_batch_size_per_gpu: null
+  max_length: null
+  ulysses_sequence_parallel_size: 1
+  use_dynamic_bsz: ${critic.use_dynamic_bsz}
+  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
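+
+# To train with curriculum learning instead of the fixed `datasets` mix above, flip the
+# toggle (a sketch; see prepare_datasets in train_grpo.py, which then reads `curricula`):
+#
+#   reasoning_gym:
+#     enable_curriculum_learning: True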
diff --git a/training/configs/qwen2.5_1.5b_grpo.yaml b/training/configs/qwen2.5_1.5b_grpo.yaml
new file mode 100644
index 00000000..3ad49d60
--- /dev/null
+++ b/training/configs/qwen2.5_1.5b_grpo.yaml
@@ -0,0 +1,217 @@
+reasoning_gym:
+  dataset_size: 10000
+  developer_prompt: DeepSeekZero
+  enable_curriculum_learning: False
+  datasets:  # Used if enable_curriculum_learning is False
+    mini_sudoku:
+      weight: 0.33
+      config:
+        min_empty: 6
+    futoshiki:
+      weight: 0.33
+      config:
+        max_board_size: 5
+    sudoku:
+      weight: 0.34
+      config:
+        min_empty: 20
+  curricula:
+    leg_counting:
+      attribute_levels:
+        num_animals: 2
+      weight: 1.0
+    products:
+      attribute_levels:
+        num_terms: 4
+        num_digits: 4
+      weight: 1.0
+    chain_sum:
+      attribute_levels:
+        num_terms: 4
+        num_digits: 4
+      weight: 1.0
+
+reward:
+  format_reward:
+    enable: True
+    scaling_factor: 0.2
+  length_reward:
+    enable: True
+    scaling_factor: 0.2
+
+data:
+  tokenizer: null
+  train_files: train.parquet
+  val_files: test.parquet
+  prompt_key: prompt
+  max_prompt_length: 512
+  max_response_length: 1024
+  train_batch_size: 16
+  val_batch_size: 16
+  return_raw_input_ids: True  # Should be True when the policy and reward model tokenizers differ
+  return_raw_chat: True
+
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    path: Qwen/Qwen2.5-1.5B-Instruct
+    external_lib: null
+    override_config: {}
+    enable_gradient_checkpointing: True
+    use_remove_padding: True
+  actor:
+    strategy: fsdp  # This is for backward-compatibility
+    ppo_mini_batch_size: 16
+    ppo_micro_batch_size: null  # will be deprecated; use ppo_micro_batch_size_per_gpu
+    ppo_micro_batch_size_per_gpu: 8
+    use_dynamic_bsz: False
+    ppo_max_token_len_per_gpu: 12288  # n * (${data.max_prompt_length} + ${data.max_response_length})
+    grad_clip: 1.0
+    clip_ratio: 0.2
+    entropy_coeff: 0.001
+    use_kl_loss: True  # True for GRPO
+    kl_loss_coef: 0.001  # for GRPO
+    kl_loss_type: low_var_kl  # for GRPO
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1  # sp size
+    optim:
+      lr: 1e-6
+      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
+      min_lr_ratio: null  # only useful for warmup with cosine
+      warmup_style: constant  # select from constant/cosine
+      total_training_steps: -1  # must be overridden by the program
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+  ref:
+    fsdp_config:
+      param_offload: True
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    log_prob_micro_batch_size: null  # will be deprecated; use log_prob_micro_batch_size_per_gpu
+    log_prob_micro_batch_size_per_gpu: 160
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}  # sp size
+  rollout:
+    name: vllm
+    temperature: 1.0
+    top_k: -1  # 0 for hf rollout, -1 for vllm rollout
+    top_p: 1
+    prompt_length: ${data.max_prompt_length}  # not used for open-source models
+    response_length: ${data.max_response_length}
+    # for vllm rollout
+    dtype: bfloat16  # should align with FSDP
+    gpu_memory_utilization: 0.6
+    ignore_eos: False
+    enforce_eager: True
+    free_cache_engine: True
+    load_format: dummy_dtensor
+    tensor_model_parallel_size: 2
+    max_num_batched_tokens: 8192
+    max_num_seqs: 1024
+    log_prob_micro_batch_size: null  # will be deprecated; use log_prob_micro_batch_size_per_gpu
+    log_prob_micro_batch_size_per_gpu: 160
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    disable_log_stats: True
+    enable_chunked_prefill: True  # could get higher throughput
+    # for hf rollout
+    do_sample: True
+    use_fire_sampling: False
+    # number of responses (i.e. num sample times)
+    n: 8  # > 1 for GRPO
+    val_kwargs:
+      do_sample: True
+
+algorithm:
+  gamma: 1.0
+  lam: 1.0
+  adv_estimator: grpo
+  kl_penalty: kl  # how to estimate kl divergence
+  kl_ctrl:
+    type: fixed
+    kl_coef: 0.001
+
+trainer:
+  balance_batch: True
+  total_epochs: 10
+  total_training_steps: null
+  project_name: rg-test
+  experiment_name: verl_grpo_qwen2.5_1.5b
+  logger: ['console', 'wandb']
+  val_generations_to_log_to_wandb: 0
+  nnodes: 1
+  n_gpus_per_node: 2
+  save_freq: 100
+  # auto: find the last ckpt to resume. If can't find, start from scratch.
+  resume_mode: auto  # or disable, or resume_path if resume_from_path is set
+  resume_from_path: False
+  test_freq: 100
+  critic_warmup: 0
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
+
+critic:
+  strategy: fsdp
+  optim:
+    lr: 1e-5
+    lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
+    min_lr_ratio: null  # only useful for warmup with cosine
+    warmup_style: constant  # select from constant/cosine
+    total_training_steps: -1  # must be overridden by the program
+  model:
+    path: ~/models/deepseek-llm-7b-chat
+    tokenizer_path: ${actor_rollout_ref.model.path}
+    override_config: {}
+    external_lib: ${actor_rollout_ref.model.external_lib}
+    enable_gradient_checkpointing: True
+    use_remove_padding: False
+    fsdp_config:
+      param_offload: False
+      optimizer_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      fsdp_size: -1
+  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
+  ppo_micro_batch_size: null  # will be deprecated; use ppo_micro_batch_size_per_gpu
+  ppo_micro_batch_size_per_gpu: null
+  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
+  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
+  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+  ppo_max_token_len_per_gpu: 32768  # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
+  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
+  ulysses_sequence_parallel_size: 1  # sp size
+  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+  shuffle: ${actor_rollout_ref.actor.shuffle}
+  grad_clip: 1.0
+  cliprange_value: 0.5
+
+# Reward model not used for GRPO
+reward_model:
+  enable: False
+  strategy: fsdp
+  model:
+    input_tokenizer: ${actor_rollout_ref.model.path}
+    path: ~/models/FsfairX-LLaMA3-RM-v0.1
+    external_lib: ${actor_rollout_ref.model.external_lib}
+    use_remove_padding: False
+    fsdp_config:
+      min_num_params: 0
+      param_offload: False
+      fsdp_size: -1
+  micro_batch_size: null
+  micro_batch_size_per_gpu: null
+  max_length: null
+  ulysses_sequence_parallel_size: 1
+  use_dynamic_bsz: ${critic.use_dynamic_bsz}
+  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
diff --git a/training/train_grpo.py b/training/train_grpo.py
new file mode 100644
index 00000000..9eec6ff7
--- /dev/null
+++ b/training/train_grpo.py
@@ -0,0 +1,128 @@
+"""Train an LLM using GRPO over Reasoning Gym procedural dataset(s)."""
+
+from dataclasses import replace
+
+import hydra
+import ray
+from omegaconf import OmegaConf
+from trainers import RayGRPOTrainer
+from utils import ReasoningGymDataset, make_dataset
+
+import reasoning_gym
+import reasoning_gym.utils
+from reasoning_gym.coaching.curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig
+from reasoning_gym.coaching.experiment import CurriculumExperiment
+from reasoning_gym.composite import CompositeDataset, DatasetSpec
+
+
+def prepare_datasets(config, tokenizer) -> tuple[ReasoningGymDataset, ReasoningGymDataset]:
+    """Prepare training and validation datasets."""
+    dataset_size = config.reasoning_gym.dataset_size
+    developer_prompt_setting = config.reasoning_gym.developer_prompt
+    developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS[developer_prompt_setting]
+
+    if config.reasoning_gym.enable_curriculum_learning:
+        curricula = config.reasoning_gym.curricula
+        curriculum_config = CurriculumExperimentConfig(
+            curricula={
+                curriculum_name: CurriculumAttributeConfig(**curriculum_config)
+                for curriculum_name, curriculum_config in curricula.items()
+            }
+        )
+        curriculum_config.validate()
+
+        train_data_source = CurriculumExperiment(
+            name=config.trainer.experiment_name, config=curriculum_config, size=dataset_size, seed=1
+        )
+        val_data_source = CompositeDataset(config=replace(train_data_source.composite.config, seed=2))
+    else:
+        dataset_specs = [
+            DatasetSpec(name=name, weight=ds.weight, config=OmegaConf.to_container(ds.config, resolve=True))
+            for name, ds in config.reasoning_gym.datasets.items()
+        ]
+        train_data_source = reasoning_gym.create_dataset("composite", seed=1, size=dataset_size, datasets=dataset_specs)
+        val_data_source = reasoning_gym.create_dataset("composite", seed=2, size=dataset_size, datasets=dataset_specs)
+
+    train_dataset = make_dataset(tokenizer, train_data_source, developer_prompt)
+    val_dataset = make_dataset(tokenizer, val_data_source, developer_prompt)
+    return train_dataset, val_dataset
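+
+# For the `curricula` section in the bundled configs, prepare_datasets builds roughly the
+# following (an illustration of the mapping above, not executed code):
+#   CurriculumExperimentConfig(curricula={
+#       "leg_counting": CurriculumAttributeConfig(attribute_levels={"num_animals": 2}, weight=1.0),
+#       "products": CurriculumAttributeConfig(attribute_levels={"num_terms": 4, "num_digits": 4}, weight=1.0),
+#       "chain_sum": CurriculumAttributeConfig(attribute_levels={"num_terms": 4, "num_digits": 4}, weight=1.0),
+#   })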
+
+
+@ray.remote
+def main_task(config):
+    from pprint import pprint
+
+    from verl.utils import hf_tokenizer
+    from verl.utils.fs import copy_local_path_from_hdfs
+
+    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
+    OmegaConf.resolve(config)
+
+    # download the checkpoint from hdfs
+    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
+
+    # instantiate tokenizer
+    tokenizer = hf_tokenizer(local_path)
+
+    # define worker classes
+    if config.actor_rollout_ref.actor.strategy == "fsdp":
+        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
+        from verl.single_controller.ray import RayWorkerGroup
+        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+
+        ray_worker_group_cls = RayWorkerGroup
+    elif config.actor_rollout_ref.actor.strategy == "megatron":
+        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
+        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
+        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+
+        ray_worker_group_cls = NVMegatronRayWorkerGroup
+    else:
+        raise NotImplementedError
+
+    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
+
+    role_worker_mapping = {
+        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
+        Role.Critic: ray.remote(CriticWorker),
+        Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
+    }
+
+    global_pool_id = "global_pool"
+    resource_pool_spec = {
+        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
+    }
+    mapping = {
+        Role.ActorRollout: global_pool_id,
+        Role.Critic: global_pool_id,
+        Role.RefPolicy: global_pool_id,
+    }
+
+    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+
+    train_dataset, val_dataset = prepare_datasets(config, tokenizer)
+
+    trainer = RayGRPOTrainer(
+        config=config,
+        tokenizer=tokenizer,
+        train_dataset=train_dataset,
+        val_dataset=val_dataset,
+        role_worker_mapping=role_worker_mapping,
+        resource_pool_manager=resource_pool_manager,
+        ray_worker_group_cls=ray_worker_group_cls,
+        max_output_length=config.data.max_response_length,
+    )
+    trainer.init_workers()
+    trainer.fit()
+
+
+@hydra.main(config_path="configs", config_name="llama3.1_1b_grpo", version_base=None)
+def main(config):
+    if not ray.is_initialized():
+        # this is for local ray cluster
+        ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+    ray.get(main_task.remote(config))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/training/trainers/__init__.py b/training/trainers/__init__.py
new file mode 100644
index 00000000..8509e769
--- /dev/null
+++ b/training/trainers/__init__.py
@@ -0,0 +1,3 @@
+from .ray_grpo_trainer import RayGRPOTrainer
+
+__all__ = ["RayGRPOTrainer"]
diff --git a/training/trainers/ray_grpo_trainer.py b/training/trainers/ray_grpo_trainer.py
new file mode 100644
index 00000000..a1d17824
--- /dev/null
+++ b/training/trainers/ray_grpo_trainer.py
@@ -0,0 +1,189 @@
+# Adapted version of Bytedance code:
+# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py
+
+import re
+
+import torch
+from omegaconf import OmegaConf, open_dict
+from torchdata.stateful_dataloader import StatefulDataLoader
+from utils import ReasoningGymDataset
+from verl import DataProto
+from verl.trainer.ppo.ray_trainer import RayPPOTrainer
+from verl.utils.dataset.rl_dataset import collate_fn
+
+from reasoning_gym.utils import extract_answer
+
+
+class RayGRPOTrainer(RayPPOTrainer):
+    def __init__(
+        self,
+        config,
+        tokenizer,
+        train_dataset: ReasoningGymDataset,
+        val_dataset: ReasoningGymDataset,
+        role_worker_mapping: dict,
+        resource_pool_manager,
+        ray_worker_group_cls,
+        max_output_length: int = 1024,
+    ):
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.max_output_length = max_output_length
+
+        self.format_reward_scaling_factor = config.reward.format_reward.scaling_factor
+        self.length_reward_scaling_factor = config.reward.length_reward.scaling_factor
+
+        train_reward_fn = lambda data: self._score_output(data, num_examine=0)
+        val_reward_fn = lambda data: self._score_output(data, num_examine=1)
+
+        super().__init__(
+            config,
+            tokenizer,
+            role_worker_mapping,
+            resource_pool_manager,
+            ray_worker_group_cls,
+            train_reward_fn,
+            val_reward_fn,
+        )
+
+    def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor:
+        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
+
+        num_printed = 0
+        for i in range(len(data)):
+            data_item = data[i]  # DataProtoItem
+
+            prompt_ids = data_item.batch["prompts"]  # tokenized prompts
+            prompt_length = prompt_ids.shape[-1]
+
+            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
+            valid_prompt_ids = prompt_ids[-valid_prompt_length:]
+
+            response_ids = data_item.batch["responses"]
+            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
+            valid_response_ids = response_ids[:valid_response_length]
+
+            # decode
+            prompt_str = self.tokenizer.decode(valid_prompt_ids)
+            response_str = self.tokenizer.decode(valid_response_ids)
+            sequences_str = prompt_str + response_str
+
+            index = data_item.non_tensor_batch["index"]
+
+            reward = score = self._compute_correctness_score(
+                solution_str=response_str,
+                index=index,
+            )
+
+            if self.config.reward.format_reward.enable:
+                format_reward = self._compute_format_reward(response_str)
+                reward += format_reward
+            else:
+                format_reward = 0.0
+
+            if self.config.reward.length_reward.enable:
+                length_reward = self._compute_length_reward(response_str, score)
+                reward += length_reward
+            else:
+                length_reward = 0.0
+
+            reward_tensor[i, valid_response_length - 1] = reward
+
+            if num_printed < num_examine:
+                print(
+                    f"reward={reward} (score={score}, format={format_reward}, length={length_reward}), seq={sequences_str}"
+                )
+                num_printed += 1
+
+        return reward_tensor
+
+    def _compute_format_reward(self, solution_str: str) -> float:
+        """Reward use of exactly one correctly structured <think> block and one <answer> block."""
+        scaling_factor = self.format_reward_scaling_factor
+        # check <think> and <answer> blocks are present
+        pattern = r"\s*<think>.*?</think>\s*<answer>.*?</answer>"
+        if not re.match(pattern, solution_str, re.DOTALL):
+            return 0.0
+        # check exactly one properly structured <think> block and one <answer> block
+        think_matches = list(re.finditer(r"<think>(.*?)</think>", solution_str, re.DOTALL))
+        answer_matches = list(re.finditer(r"<answer>(.*?)</answer>", solution_str, re.DOTALL))
+        if len(think_matches) != 1 or len(answer_matches) != 1:
+            return 0.0
+        # check for nested <think> or <answer> tags inside <think>
+        think_content = think_matches[0].group(1)
+        if "<think>" in think_content or "<answer>" in think_content:
+            return 0.0
+        # check for nested <think> or <answer> tags inside <answer>
+        answer_content = answer_matches[0].group(1)
+        if "<answer>" in answer_content or "<think>" in answer_content:
+            return 0.0
+        return 1.0 * scaling_factor
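+
+    # Illustrative outcomes of _compute_format_reward above (hypothetical responses,
+    # with scaling_factor=0.2 as in the bundled configs):
+    #   "<think>2+2=4</think><answer>4</answer>"              -> 0.2 (well-formed)
+    #   "<answer>4</answer>"                                  -> 0.0 (missing <think> block)
+    #   "<think><think>hm</think></think><answer>4</answer>"  -> 0.0 (nested <think> tags)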
+
+    def _compute_length_reward(
+        self,
+        solution_str: str,
+        correctness_score: float,
+        max_score: float = 1.0,
+    ) -> float:
+        """
+        Reward shorter solutions for perfect answers and longer solutions for imperfect answers.
+        The scaling factor should be set well below 1.0 so the length reward does not dominate
+        the correctness signal.
+        """
+        epsilon = 1e-6
+        scaling_factor = self.length_reward_scaling_factor
+        generation_len = len(solution_str)
+        progress = min(generation_len / self.max_output_length, 1.0)
+        if correctness_score < max_score - epsilon:
+            # for imperfect answers, incentivise longer ones
+            length_reward = (max_score - correctness_score) * progress
+        else:
+            # for perfect answers, penalise longer ones
+            length_reward = -progress
+        return length_reward * scaling_factor
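+
+    # Illustrative values for _compute_length_reward above (hypothetical, with
+    # max_output_length=1024 and scaling_factor=0.2):
+    #   perfect answer (score=1.0), 512-char response -> -0.5 * 0.2 = -0.1 (shorter is rewarded)
+    #   wrong answer   (score=0.0), 512-char response ->  0.5 * 0.2 =  0.1 (longer reasoning is encouraged)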
+
+    def _compute_correctness_score(self, solution_str: str, index: int) -> float:
+        found_answer = extract_answer(solution_str, tag_name="answer")
+        data = self.train_dataset.data
+        entry = data[index]
+        if self.train_dataset.experiment:
+            experiment = self.train_dataset.experiment
+            return experiment.score_answer_with_id(found_answer, entry["metadata"]["entry_id"])
+        else:
+            return data.score_answer(found_answer, entry=entry)
+
+    def _create_dataloader(self):
+        self.train_dataloader = StatefulDataLoader(
+            dataset=self.train_dataset,
+            batch_size=self.config.data.train_batch_size,
+            shuffle=True,
+            drop_last=True,
+            collate_fn=collate_fn,
+        )
+
+        self.val_dataloader = StatefulDataLoader(
+            dataset=self.val_dataset,
+            batch_size=len(self.val_dataset),
+            shuffle=True,
+            drop_last=True,
+            collate_fn=collate_fn,
+        )
+
+        assert len(self.train_dataloader) >= 1
+        assert len(self.val_dataloader) >= 1
+
+        print(f"Size of train dataloader: {len(self.train_dataloader)}")
+        print(f"Size of val dataloader: {len(self.val_dataloader)}")
+
+        # inject total_training_steps into the actor/critic optim configs. This is hacky.
+        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
+
+        if self.config.trainer.total_training_steps is not None:
+            total_training_steps = self.config.trainer.total_training_steps
+
+        self.total_training_steps = total_training_steps
+        print(f"Total training steps: {self.total_training_steps}")
+
+        OmegaConf.set_struct(self.config, True)
+        with open_dict(self.config):
+            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
+            self.config.critic.optim.total_training_steps = total_training_steps
diff --git a/training/utils/__init__.py b/training/utils/__init__.py
new file mode 100644
index 00000000..a81a20cd
--- /dev/null
+++ b/training/utils/__init__.py
@@ -0,0 +1,3 @@
+from .datasets import ReasoningGymDataset, make_dataset
+
+__all__ = ["ReasoningGymDataset", "make_dataset"]
diff --git a/training/utils/datasets.py b/training/utils/datasets.py
new file mode 100644
index 00000000..80da6b96
--- /dev/null
+++ b/training/utils/datasets.py
@@ -0,0 +1,87 @@
+from typing import Optional
+
+import verl.utils.torch_functional as verl_F
+from torch.utils.data import Dataset
+from transformers import PreTrainedTokenizer
+from verl.utils.model import compute_position_id_with_mask
+
+from reasoning_gym.coaching.experiment import Experiment
+from reasoning_gym.dataset import ProceduralDataset
+
+
+class ReasoningGymDataset(Dataset):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        procedural_dataset: Optional[ProceduralDataset] = None,
+        experiment: Optional[Experiment] = None,
+        developer_prompt: Optional[str] = None,
+        developer_role: str = "system",
+        max_prompt_length: int = 2048,
+        truncation: str = "error",  # one of ['left', 'right', 'error']
+    ):
+        assert procedural_dataset or experiment, "One of `procedural_dataset` or `experiment` must be provided"
+        assert (
+            procedural_dataset is None or experiment is None
+        ), "Only one of `procedural_dataset` or `experiment` may be provided"
+
+        self.tokenizer = tokenizer
+        self.data = procedural_dataset or experiment.composite
+        self.experiment = experiment
+        self.developer_prompt = developer_prompt
+        self.developer_role = developer_role
+        self.max_prompt_length = max_prompt_length
+        self.truncation = truncation
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def __getitem__(self, index):
+        row_dict = self.data[index].copy()
+        q = row_dict["question"]
+
+        chat = []
+        if self.developer_prompt is not None:
+            chat.append({"role": self.developer_role, "content": self.developer_prompt})
+        chat.append({"role": "user", "content": q})
+
+        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+
+        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
+            prompt=prompt,
+            tokenizer=self.tokenizer,
+            max_length=self.max_prompt_length,
+            pad_token_id=self.tokenizer.pad_token_id,
+            left_pad=True,
+            truncation=self.truncation,
+        )
+
+        position_ids = compute_position_id_with_mask(attention_mask)
+
+        row_dict["data_source"] = "reasoning_gym"
+        row_dict["input_ids"] = input_ids[0]
+        row_dict["attention_mask"] = attention_mask[0]
+        row_dict["position_ids"] = position_ids[0]
+        row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
+        row_dict["raw_prompt"] = chat
+        row_dict["index"] = index
+        return row_dict
+
+
+def make_dataset(
+    tokenizer,
+    data_source: Experiment | ProceduralDataset,
+    developer_prompt: str,
+) -> ReasoningGymDataset:
+    """
+    Create a ReasoningGymDataset using either a ProceduralDataset or an
+    Experiment as the underlying data source.
+    """
+    kwargs = {
+        "tokenizer": tokenizer,
+        "developer_prompt": developer_prompt,
+    }
+    if isinstance(data_source, Experiment):
+        kwargs["experiment"] = data_source
+    else:
+        kwargs["procedural_dataset"] = data_source
+    return ReasoningGymDataset(**kwargs)
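+
+
+# Hypothetical usage sketch (the model name and prompt are assumptions, not part of this codebase):
+#   from transformers import AutoTokenizer
+#   import reasoning_gym
+#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+#   source = reasoning_gym.create_dataset("mini_sudoku", seed=42, size=10)
+#   dataset = make_dataset(tokenizer, source, developer_prompt="Reason step by step.")
+#   print(dataset[0]["raw_prompt"])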