reasoning-gym/training/train_grpo.py
Oliver Stanley 224532f12a
first inter-domain generalisation experiments (#412)
* tweak len reward

* first inter-generalisation experiment config

* update inter algorithmic config

* default to empty config

* fix typo

* change config to match experiment script

* long prompt fixes

* algorithmic training config tweaks

* imports

* update algorithmic training cfgs

* first logic composite config

* fix dset name

* tweaks

* fix syllogisms dataset

* rm temp print

* initial algebra config

* algebra cfg tweaks

* add gc

* add initial games cfg

* rename games cfg

* fix dset name

* fix sokoban metadata

* remove boxnet

* games cfg tweak
2025-04-14 21:06:40 +01:00

134 lines
5 KiB
Python

"""Train an LLM using GRPO over Reasoning Gym procedural dataset(s)."""
from dataclasses import replace
import hydra
import ray
from omegaconf import OmegaConf
from trainers import RayGRPOTrainer
from utils import ReasoningGymDataset, make_dataset
import reasoning_gym
import reasoning_gym.utils
from reasoning_gym.coaching.curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig
from reasoning_gym.coaching.experiment import CurriculumExperiment
from reasoning_gym.composite import CompositeDataset, DatasetSpec
def prepare_datasets(config, tokenizer) -> tuple[ReasoningGymDataset, ReasoningGymDataset]:
    """Build the (train, validation) ReasoningGymDataset pair from the run config.

    Train and validation sources share the same composition but use distinct
    seeds (1 and 2) so their procedurally generated examples differ.
    """
    size = config.reasoning_gym.dataset_size
    prompt_key = config.reasoning_gym.developer_prompt
    developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS[prompt_key]

    if config.curriculum.enabled:
        # Curriculum mode: wrap each configured curriculum in its attribute config.
        attr_configs = {
            name: CurriculumAttributeConfig(**attrs)
            for name, attrs in config.curriculum.curricula.items()
        }
        experiment_config = CurriculumExperimentConfig(curricula=attr_configs)
        train_source = CurriculumExperiment(
            name=config.trainer.experiment_name, config=experiment_config, size=size, seed=1
        )
        # Validation mirrors the curriculum's composite dataset, reseeded.
        val_source = CompositeDataset(config=replace(train_source.composite.config, seed=2))
    else:
        # Plain composite mode: one weighted spec per configured dataset.
        specs = []
        for name, ds in config.reasoning_gym.datasets.items():
            ds_config = OmegaConf.to_container(ds.config, resolve=True) if "config" in ds else {}
            specs.append(DatasetSpec(name=name, weight=ds.weight, config=ds_config))
        train_source = reasoning_gym.create_dataset("composite", seed=1, size=size, datasets=specs)
        val_source = reasoning_gym.create_dataset("composite", seed=2, size=size, datasets=specs)

    train_dataset = make_dataset(
        tokenizer, train_source, developer_prompt, max_prompt_length=config.data.max_prompt_length
    )
    val_dataset = make_dataset(
        tokenizer, val_source, developer_prompt, max_prompt_length=config.data.max_prompt_length
    )
    return train_dataset, val_dataset
@ray.remote
def main_task(config):
    """Run one GRPO training job as a Ray task.

    Resolves the OmegaConf config, fetches the model checkpoint, loads the
    tokenizer, selects the worker backend matching the configured strategy
    (FSDP or Megatron), prepares the datasets, and drives RayGRPOTrainer
    through init and fit.
    """
    from pprint import pprint

    from verl.utils import hf_tokenizer
    from verl.utils.fs import copy_local_path_from_hdfs

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    tokenizer = hf_tokenizer(local_path)

    # define worker classes — import the backend matching the actor strategy.
    if config.actor_rollout_ref.actor.strategy == "fsdp":
        # Actor and critic must be trained under the same distributed strategy.
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = RayWorkerGroup
    elif config.actor_rollout_ref.actor.strategy == "megatron":
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = NVMegatronRayWorkerGroup
    else:
        raise NotImplementedError

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    # The reference policy reuses the same worker class as the actor/rollout role.
    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
        Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
    }
    # Single shared resource pool: every role runs on the same set of GPUs,
    # sized as n_gpus_per_node replicated across nnodes.
    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
        Role.Critic: global_pool_id,
        Role.RefPolicy: global_pool_id,
    }
    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    train_dataset, val_dataset = prepare_datasets(config, tokenizer)

    trainer = RayGRPOTrainer(
        config=config,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
        max_output_length=config.data.max_response_length,
    )
    trainer.init_workers()
    trainer.fit()
@hydra.main(config_path="configs", config_name="llama3.1_1b_grpo", version_base=None)
def main(config):
    """Hydra entry point: ensure a Ray cluster exists, then run training remotely."""
    if not ray.is_initialized():
        # this is for local ray cluster
        env_vars = {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}
        ray.init(runtime_env={"env_vars": env_vars})
    # Block until the remote training task finishes.
    ray.get(main_task.remote(config))


if __name__ == "__main__":
    main()