feat: Add EpochTrackingDataLoader to track epochs using trainer's global_steps

Andreas Koepf (aider) 2025-02-22 21:11:55 +00:00
parent 8f14c3aba7
commit 5f16d54ebe


@@ -9,6 +9,25 @@ import torch
 import verl.utils.torch_functional as verl_F
 from omegaconf import OmegaConf, open_dict
 from torch.utils.data import DataLoader, Dataset
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .main_ppo_custom_reward_server import RayPPOTrainerCustom
+
+
+class EpochTrackingDataLoader(DataLoader):
+    """DataLoader that tracks epochs based on trainer's global_steps"""
+    def __init__(self, dataset: ReasoningGymDataset, trainer: "RayPPOTrainerCustom", *args, **kwargs):
+        super().__init__(dataset, *args, **kwargs)
+        self.trainer = trainer
+        self.steps_per_epoch = len(self)  # Number of batches per epoch
+
+    def __iter__(self):
+        # Calculate current epoch from global_steps
+        current_epoch = (self.trainer.global_steps - 1) // self.steps_per_epoch
+        # Update dataset's epoch counter
+        self.dataset.epoch = current_epoch
+        return super().__iter__()
 from transformers import PreTrainedTokenizer
 from verl import DataProto
 from verl.trainer.ppo.ray_trainer import RayPPOTrainer
@@ -234,16 +253,18 @@ class RayPPOTrainerCustom(RayPPOTrainer):
         return reward_tensor

     def _create_dataloader(self):
-        self.train_dataloader = DataLoader(
+        self.train_dataloader = EpochTrackingDataLoader(
             dataset=self.train_dataset,
+            trainer=self,
             batch_size=self.config.data.train_batch_size,
             shuffle=False,
             drop_last=True,
             collate_fn=collate_fn,
         )

-        self.val_dataloader = DataLoader(
+        self.val_dataloader = EpochTrackingDataLoader(
             dataset=self.val_dataset,
+            trainer=self,
             batch_size=len(self.val_dataset),
             shuffle=False,
             drop_last=True,
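
For reference, a minimal sketch of how the epoch counter propagated by EpochTrackingDataLoader might be consumed downstream. The ReasoningGymDataset itself is not shown in this diff, so the ToyEpochDataset, its per-epoch reseeding scheme, and the _FakeTrainer stand-in below are assumptions for illustration only; EpochTrackingDataLoader is assumed to be importable as defined above.

# Sketch only (not part of this commit): shows how the `epoch` attribute that
# EpochTrackingDataLoader assigns in __iter__ could drive per-epoch data variation.
# ToyEpochDataset and _FakeTrainer are hypothetical; the real ReasoningGymDataset
# and trainer are not reproduced here.
import torch
from torch.utils.data import Dataset


class ToyEpochDataset(Dataset):
    """Toy dataset that yields different samples once `epoch` advances."""

    def __init__(self, size: int = 8, seed: int = 0):
        self.size = size
        self.seed = seed
        self.epoch = 0  # overwritten by EpochTrackingDataLoader.__iter__

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Derive a per-epoch seed so the same index yields new data each epoch
        g = torch.Generator().manual_seed(self.seed + self.epoch * 100_000 + idx)
        return torch.randint(0, 100, (4,), generator=g)


class _FakeTrainer:
    """Hypothetical stand-in: only the `global_steps` counter is needed."""
    global_steps = 1


trainer = _FakeTrainer()
dataset = ToyEpochDataset(size=8)
loader = EpochTrackingDataLoader(dataset, trainer, batch_size=4)  # 2 batches per epoch

for step in range(1, 7):
    trainer.global_steps = step
    next(iter(loader))  # __iter__ recomputes and assigns dataset.epoch
    # steps 1-2 -> epoch 0, steps 3-4 -> epoch 1, steps 5-6 -> epoch 2
    assert dataset.epoch == (step - 1) // 2

The assertion spells out the arithmetic from __iter__: with steps_per_epoch = len(loader), global_steps 1..N maps to epoch (global_steps - 1) // steps_per_epoch.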