diff --git a/examples/veRL/main_ppo_custom_reward_server.py b/examples/veRL/main_ppo_custom_reward_server.py
index b14fafdc..f5757fb0 100644
--- a/examples/veRL/main_ppo_custom_reward_server.py
+++ b/examples/veRL/main_ppo_custom_reward_server.py
@@ -9,6 +9,25 @@ import torch
 import verl.utils.torch_functional as verl_F
 from omegaconf import OmegaConf, open_dict
 from torch.utils.data import DataLoader, Dataset
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .main_ppo_custom_reward_server import RayPPOTrainerCustom
+
+class EpochTrackingDataLoader(DataLoader):
+    """DataLoader that tracks epochs based on the trainer's global_steps."""
+
+    def __init__(self, dataset: "ReasoningGymDataset", trainer: "RayPPOTrainerCustom", *args, **kwargs):
+        super().__init__(dataset, *args, **kwargs)
+        self.trainer = trainer
+        self.steps_per_epoch = len(self)  # number of batches per epoch
+
+    def __iter__(self):
+        # Derive the current epoch from the trainer's 1-based global step counter
+        current_epoch = (self.trainer.global_steps - 1) // self.steps_per_epoch
+        # Update the dataset's epoch counter before handing out batches
+        self.dataset.epoch = current_epoch
+        return super().__iter__()
 from transformers import PreTrainedTokenizer
 from verl import DataProto
 from verl.trainer.ppo.ray_trainer import RayPPOTrainer
@@ -234,16 +253,18 @@ class RayPPOTrainerCustom(RayPPOTrainer):
         return reward_tensor
 
     def _create_dataloader(self):
-        self.train_dataloader = DataLoader(
+        self.train_dataloader = EpochTrackingDataLoader(
             dataset=self.train_dataset,
+            trainer=self,
             batch_size=self.config.data.train_batch_size,
             shuffle=False,
             drop_last=True,
             collate_fn=collate_fn,
         )
 
-        self.val_dataloader = DataLoader(
+        self.val_dataloader = EpochTrackingDataLoader(
             dataset=self.val_dataset,
+            trainer=self,
             batch_size=len(self.val_dataset),
             shuffle=False,
             drop_last=True,
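
For context, here is a minimal, self-contained sketch of the epoch-tracking idea in the diff above. `MockTrainer` and `CountingDataset` are hypothetical stand-ins for `RayPPOTrainerCustom` and `ReasoningGymDataset` (they are not part of this change); the only contract the loader relies on is a 1-based `global_steps` counter on the trainer and a mutable `epoch` attribute on the dataset:

```python
from dataclasses import dataclass

from torch.utils.data import DataLoader, Dataset


@dataclass
class MockTrainer:
    """Hypothetical stand-in for RayPPOTrainerCustom; only global_steps is used."""

    global_steps: int = 1  # verl's step counter is 1-based


class CountingDataset(Dataset):
    """Hypothetical stand-in for ReasoningGymDataset; exposes an `epoch` attribute."""

    def __init__(self, size: int = 8):
        self.size = size
        self.epoch = 0

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> int:
        return idx


class EpochTrackingDataLoader(DataLoader):
    """Same mechanism as the class in the diff: derive the epoch from the step count."""

    def __init__(self, dataset, trainer, *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)
        self.trainer = trainer
        self.steps_per_epoch = len(self)  # batches per full pass over the dataset

    def __iter__(self):
        # Steps 1..steps_per_epoch fall in epoch 0, the next block in epoch 1, etc.
        self.dataset.epoch = (self.trainer.global_steps - 1) // self.steps_per_epoch
        return super().__iter__()


trainer = MockTrainer()
loader = EpochTrackingDataLoader(CountingDataset(size=8), trainer, batch_size=4)

for step in range(1, 6):  # steps_per_epoch == 2, so the epoch advances every 2 steps
    trainer.global_steps = step
    iter(loader)  # __iter__ refreshes dataset.epoch as a side effect
    print(f"global_step={step} -> dataset.epoch={loader.dataset.epoch}")
# global_step=1 -> dataset.epoch=0
# global_step=2 -> dataset.epoch=0
# global_step=3 -> dataset.epoch=1
```

Deriving the epoch from `global_steps` rather than counting completed passes locally likely keeps the dataset's epoch consistent across checkpoint restores, since the trainer's step counter is the single source of truth.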