mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* tweak len reward * first inter-generalisation experiment config * update inter algorithmic config * default to empty config * fix typo * change config to match experiment script * long prompt fixes * algorithmic training config tweaks * imports * update algorithmic training cfgs * first logic composite config * fix dset name * tweaks * fix syllogisms dataset * rm temp print * initial algebra config * algebra cfg tweaks * add gc * add initial games cfg * rename games cfg * fix dset name * fix sokoban metadata * remove boxnet * games cfg tweak
134 lines
5 KiB
Python
134 lines
5 KiB
Python
"""Train an LLM using GRPO over Reasoning Gym procedural dataset(s)."""
|
|
|
|
from dataclasses import replace
|
|
|
|
import hydra
|
|
import ray
|
|
from omegaconf import OmegaConf
|
|
from trainers import RayGRPOTrainer
|
|
from utils import ReasoningGymDataset, make_dataset
|
|
|
|
import reasoning_gym
|
|
import reasoning_gym.utils
|
|
from reasoning_gym.coaching.curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig
|
|
from reasoning_gym.coaching.experiment import CurriculumExperiment
|
|
from reasoning_gym.composite import CompositeDataset, DatasetSpec
|
|
|
|
|
|
def prepare_datasets(config, tokenizer) -> tuple[ReasoningGymDataset, ReasoningGymDataset]:
    """Build the (train, validation) dataset pair described by ``config``.

    When ``config.curriculum.enabled`` is truthy, the training source is a
    ``CurriculumExperiment`` (seed 1) and the validation source is a
    ``CompositeDataset`` built from a copy of the experiment's composite
    config with ``seed=2``. Otherwise both sources are "composite" datasets
    created from ``config.reasoning_gym.datasets`` with seeds 1 and 2.

    Both sources are passed through ``make_dataset`` together with the
    tokenizer, the system prompt selected by
    ``config.reasoning_gym.developer_prompt`` and
    ``config.data.max_prompt_length``.

    Args:
        config: Experiment configuration (attribute-style access).
        tokenizer: Tokenizer forwarded unchanged to ``make_dataset``.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    rg_cfg = config.reasoning_gym
    size = rg_cfg.dataset_size
    system_prompt = reasoning_gym.utils.SYSTEM_PROMPTS[rg_cfg.developer_prompt]

    def wrap(source):
        # Same make_dataset call for both splits; only the data source differs.
        return make_dataset(tokenizer, source, system_prompt, max_prompt_length=config.data.max_prompt_length)

    if config.curriculum.enabled:
        # One CurriculumAttributeConfig per named curriculum in the config.
        attribute_configs = {}
        for curriculum_name, attrs in config.curriculum.curricula.items():
            attribute_configs[curriculum_name] = CurriculumAttributeConfig(**attrs)
        experiment_config = CurriculumExperimentConfig(curricula=attribute_configs)

        train_source = CurriculumExperiment(
            name=config.trainer.experiment_name,
            config=experiment_config,
            size=size,
            seed=1,
        )
        # Validation reuses the experiment's composite config, changing only the seed.
        val_source = CompositeDataset(config=replace(train_source.composite.config, seed=2))
    else:
        specs = []
        for name, ds in rg_cfg.datasets.items():
            weight = ds.weight
            if "config" in ds:
                ds_config = OmegaConf.to_container(ds.config, resolve=True)
            else:
                # Datasets without an explicit config get an empty one.
                ds_config = {}
            specs.append(DatasetSpec(name=name, weight=weight, config=ds_config))

        # Identical specs and size for both splits; seed 1 = train, seed 2 = validation.
        train_source, val_source = (
            reasoning_gym.create_dataset("composite", seed=seed, size=size, datasets=specs) for seed in (1, 2)
        )

    return wrap(train_source), wrap(val_source)
|
|
|
|
|
|
@ray.remote
def main_task(config):
    """Run one GRPO training job as a Ray remote task.

    Prints the resolved config, resolves it in place, fetches the model
    checkpoint path, builds the tokenizer, picks the worker classes for the
    configured strategy, sets up a single shared resource pool, prepares the
    datasets and finally builds a ``RayGRPOTrainer`` and calls
    ``init_workers()`` followed by ``fit()``.

    Args:
        config: OmegaConf experiment configuration. This function reads
            ``actor_rollout_ref``, ``critic``, ``trainer`` and ``data``;
            ``prepare_datasets`` reads further keys.

    Raises:
        AssertionError: If the actor and critic strategies differ.
        NotImplementedError: If the actor strategy is neither ``"fsdp"`` nor
            ``"megatron"``.
    """
    # These are imported inside the task body rather than at module level.
    from pprint import pprint

    from verl.utils import hf_tokenizer
    from verl.utils.fs import copy_local_path_from_hdfs

    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    tokenizer = hf_tokenizer(local_path)

    # define worker classes
    strategy = config.actor_rollout_ref.actor.strategy
    if strategy == "fsdp":
        assert (
            strategy == config.critic.strategy
        ), f"actor strategy {strategy!r} must match critic strategy {config.critic.strategy!r}"
        from verl.single_controller.ray import RayWorkerGroup
        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = RayWorkerGroup
    elif strategy == "megatron":
        assert (
            strategy == config.critic.strategy
        ), f"actor strategy {strategy!r} must match critic strategy {config.critic.strategy!r}"
        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker

        ray_worker_group_cls = NVMegatronRayWorkerGroup
    else:
        # Name the offending value so a config typo is diagnosable from the traceback.
        raise NotImplementedError(f"Unsupported actor strategy {strategy!r}; expected 'fsdp' or 'megatron'")

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
        Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
    }

    # A single pool with n_gpus_per_node entries repeated once per node.
    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    # Every role shares the global pool; derived from role_worker_mapping so
    # the two mappings always cover the same roles.
    mapping = {role: global_pool_id for role in role_worker_mapping}

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    train_dataset, val_dataset = prepare_datasets(config, tokenizer)

    trainer = RayGRPOTrainer(
        config=config,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        role_worker_mapping=role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
        max_output_length=config.data.max_response_length,
    )
    trainer.init_workers()
    trainer.fit()
|
|
|
|
|
|
@hydra.main(config_path="configs", config_name="llama3.1_1b_grpo", version_base=None)
def main(config):
    """Hydra entry point: ensure Ray is initialised, then run ``main_task``.

    Args:
        config: Configuration object supplied by the ``hydra.main`` decorator.
    """
    if not ray.is_initialized():
        # this is for local ray cluster
        worker_env = {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}
        ray.init(runtime_env={"env_vars": worker_env})

    # Submit the training task and block until it has finished.
    pending = main_task.remote(config)
    ray.get(pending)
|
|
|
|
|
|
# Script entry point: only launch training when executed directly, not on import.
if __name__ == "__main__":
    main()
|