mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
use StatefulDataLoader in veRL examples (#378)
This commit is contained in:
parent
c760da6b1b
commit
1c6f2d01ee
3 changed files with 12 additions and 9 deletions
|
|
@ -8,7 +8,8 @@ import ray
|
|||
import torch
|
||||
import verl.utils.torch_functional as verl_F
|
||||
from omegaconf import OmegaConf, open_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data import Dataset
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from transformers import PreTrainedTokenizer
|
||||
from verl import DataProto
|
||||
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
|
||||
|
|
@ -182,7 +183,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
|
|||
return reward
|
||||
|
||||
def _create_dataloader(self):
|
||||
self.train_dataloader = DataLoader(
|
||||
self.train_dataloader = StatefulDataLoader(
|
||||
dataset=self.train_dataset,
|
||||
batch_size=self.config.data.train_batch_size,
|
||||
shuffle=True,
|
||||
|
|
@ -190,7 +191,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
|
|||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
self.val_dataloader = DataLoader(
|
||||
self.val_dataloader = StatefulDataLoader(
|
||||
dataset=self.val_dataset,
|
||||
batch_size=len(self.val_dataset),
|
||||
shuffle=True,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue