mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: Add EpochTrackingDataLoader to track epochs using trainer's global_steps
This commit is contained in:
parent
8f14c3aba7
commit
5f16d54ebe
1 changed file with 23 additions and 2 deletions
|
|
@ -9,6 +9,25 @@ import torch
|
|||
import verl.utils.torch_functional as verl_F
|
||||
from omegaconf import OmegaConf, open_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .main_ppo_custom_reward_server import RayPPOTrainerCustom
|
||||
|
||||
class EpochTrackingDataLoader(DataLoader):
    """DataLoader that tracks epochs based on the trainer's ``global_steps``.

    Every time iteration starts, the current epoch index is recomputed from
    ``trainer.global_steps`` and written to ``dataset.epoch`` before the
    regular ``DataLoader`` iterator is returned.

    Args:
        dataset: Dataset to draw batches from. It must allow an ``epoch``
            attribute to be assigned on it.
        trainer: Object exposing an integer ``global_steps`` attribute,
            which is read each time ``__iter__`` is called.
        *args: Extra positional arguments forwarded to ``DataLoader``.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``.
    """

    # Both annotations are quoted: they are forward references that must not
    # be evaluated when this class body runs (the trainer type is imported
    # only under TYPE_CHECKING, and the dataset class is not defined above
    # this point in the file).
    def __init__(self, dataset: "ReasoningGymDataset", trainer: "RayPPOTrainerCustom", *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)
        self.trainer = trainer
        self.steps_per_epoch = len(self)  # Number of batches per epoch

    def __iter__(self):
        """Update ``dataset.epoch`` from the trainer's step count, then iterate.

        Returns:
            The iterator produced by ``DataLoader.__iter__``.
        """
        if self.steps_per_epoch > 0:
            # The ``- 1`` maps steps 1..steps_per_epoch onto epoch 0.
            # Clamp at 0 so a step count of 0 does not yield epoch -1
            # (presumably possible before the first training step, e.g. an
            # initial validation pass -- confirm against the trainer).
            current_epoch = max(0, (self.trainer.global_steps - 1) // self.steps_per_epoch)
        else:
            # A loader with no batches (e.g. drop_last=True and fewer samples
            # than one batch) would otherwise divide by zero.
            current_epoch = 0
        # Update dataset's epoch counter
        self.dataset.epoch = current_epoch
        return super().__iter__()
|
||||
from transformers import PreTrainedTokenizer
|
||||
from verl import DataProto
|
||||
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
|
||||
|
|
@ -234,16 +253,18 @@ class RayPPOTrainerCustom(RayPPOTrainer):
|
|||
return reward_tensor
|
||||
|
||||
def _create_dataloader(self):
|
||||
self.train_dataloader = DataLoader(
|
||||
self.train_dataloader = EpochTrackingDataLoader(
|
||||
dataset=self.train_dataset,
|
||||
trainer=self,
|
||||
batch_size=self.config.data.train_batch_size,
|
||||
shuffle=False,
|
||||
drop_last=True,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
self.val_dataloader = DataLoader(
|
||||
self.val_dataloader = EpochTrackingDataLoader(
|
||||
dataset=self.val_dataset,
|
||||
trainer=self,
|
||||
batch_size=len(self.val_dataset),
|
||||
shuffle=False,
|
||||
drop_last=True,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue