add deps for veRL experiment in README

This commit is contained in:
Andreas Koepf 2025-02-01 21:27:33 +00:00
parent e671b97ab4
commit 3f24df31dc
2 changed files with 37 additions and 20 deletions

19
examples/veRL/README.md Normal file
View file

@ -0,0 +1,19 @@
### Environment setup
```
conda create --name verl python=3.12 -y
conda activate verl
pip install flash-attn --no-build-isolation
pip install vllm==0.7.0 ray wandb
```
### Clone and install veRL

Tested with veRL at commit `a65c9157bc0b85b64cd753de19f94e80a11bd871` (HEAD of the repository at the time of testing).
```
git clone https://github.com/volcengine/verl.git
cd verl
pip install -e .
```

View file

@ -20,24 +20,22 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by o
"""
from typing import Optional
from omegaconf import OmegaConf, open_dict
import reasoning_gym
from reasoning_gym.utils import extract_answer
import reasoning_gym.utils
from verl import DataProto
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer
import ray
import hydra
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.model import compute_position_id_with_mask
from verl.utils.dataset.rl_dataset import collate_fn
import ray
import torch
import verl.utils.torch_functional as verl_F
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.dataset.rl_dataset import collate_fn
from verl.utils.model import compute_position_id_with_mask
import reasoning_gym
import reasoning_gym.utils
from reasoning_gym.utils import extract_answer
class RewardManager:
@ -262,12 +260,12 @@ class RayPPOTrainerCustom(RayPPOTrainer):
@ray.remote
def main_task(config, compute_score=None):
from verl.utils.fs import copy_local_path_from_hdfs
from transformers import AutoTokenizer
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from transformers import AutoTokenizer
from verl.utils.fs import copy_local_path_from_hdfs
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
@ -283,15 +281,15 @@ def main_task(config, compute_score=None):
# define worker classes
if config.actor_rollout_ref.actor.strategy == "fsdp":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup