use StatefulDataLoader in veRL examples (#378)

This commit is contained in:
Oliver Stanley 2025-03-17 06:28:10 +00:00 committed by GitHub
parent c760da6b1b
commit 1c6f2d01ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 12 additions and 9 deletions

View file

@@ -8,7 +8,8 @@ import ray
import torch
import verl.utils.torch_functional as verl_F
from omegaconf import OmegaConf, open_dict
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
+from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
@@ -182,7 +183,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
return reward
def _create_dataloader(self):
-        self.train_dataloader = DataLoader(
+        self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=self.config.data.train_batch_size,
shuffle=True,
@@ -190,7 +191,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
collate_fn=collate_fn,
)
-        self.val_dataloader = DataLoader(
+        self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=len(self.val_dataset),
shuffle=True,