diff --git a/examples/veRL/basic_curriculum/ppo_curriculum.py b/examples/veRL/basic_curriculum/ppo_curriculum.py
index 4c583e13..15bc0fa9 100644
--- a/examples/veRL/basic_curriculum/ppo_curriculum.py
+++ b/examples/veRL/basic_curriculum/ppo_curriculum.py
@@ -8,7 +8,8 @@ import ray
 import torch
 import verl.utils.torch_functional as verl_F
 from omegaconf import OmegaConf, open_dict
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
+from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import PreTrainedTokenizer
 from verl import DataProto
 from verl.trainer.ppo.ray_trainer import RayPPOTrainer
@@ -182,7 +183,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
         return reward
 
     def _create_dataloader(self):
-        self.train_dataloader = DataLoader(
+        self.train_dataloader = StatefulDataLoader(
             dataset=self.train_dataset,
             batch_size=self.config.data.train_batch_size,
             shuffle=True,
@@ -190,7 +191,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
             collate_fn=collate_fn,
         )
 
-        self.val_dataloader = DataLoader(
+        self.val_dataloader = StatefulDataLoader(
             dataset=self.val_dataset,
             batch_size=len(self.val_dataset),
             shuffle=True,
diff --git a/examples/veRL/chain_sum/main_ppo_custom_reward.py b/examples/veRL/chain_sum/main_ppo_custom_reward.py
index 6c5863f3..dc6418c6 100644
--- a/examples/veRL/chain_sum/main_ppo_custom_reward.py
+++ b/examples/veRL/chain_sum/main_ppo_custom_reward.py
@@ -7,7 +7,8 @@ import ray
 import torch
 import verl.utils.torch_functional as verl_F
 from omegaconf import OmegaConf, open_dict
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
+from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import PreTrainedTokenizer
 from verl import DataProto
 from verl.trainer.ppo.ray_trainer import RayPPOTrainer
@@ -170,7 +171,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
         return reward
 
     def _create_dataloader(self):
-        self.train_dataloader = DataLoader(
+        self.train_dataloader = StatefulDataLoader(
             dataset=self.train_dataset,
             batch_size=self.config.data.train_batch_size,
             shuffle=True,
@@ -178,7 +179,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
             collate_fn=collate_fn,
         )
 
-        self.val_dataloader = DataLoader(
+        self.val_dataloader = StatefulDataLoader(
             dataset=self.val_dataset,
             batch_size=len(self.val_dataset),
             shuffle=True,
diff --git a/examples/veRL/chain_sum/main_ppo_custom_reward_server.py b/examples/veRL/chain_sum/main_ppo_custom_reward_server.py
index f97b25e9..5dc5cfa7 100644
--- a/examples/veRL/chain_sum/main_ppo_custom_reward_server.py
+++ b/examples/veRL/chain_sum/main_ppo_custom_reward_server.py
@@ -8,7 +8,8 @@ import ray
 import torch
 import verl.utils.torch_functional as verl_F
 from omegaconf import OmegaConf, open_dict
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
+from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import PreTrainedTokenizer
 from verl import DataProto
 from verl.trainer.ppo.ray_trainer import RayPPOTrainer
@@ -229,7 +230,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
         return reward_tensor
 
     def _create_dataloader(self):
-        self.train_dataloader = DataLoader(
+        self.train_dataloader = StatefulDataLoader(
             dataset=self.train_dataset,
             batch_size=self.config.data.train_batch_size,
             shuffle=False,
@@ -237,7 +238,7 @@ class RayPPOTrainerCustom(RayPPOTrainer):
             collate_fn=collate_fn,
         )
 
-        self.val_dataloader = DataLoader(
+        self.val_dataloader = StatefulDataLoader(
             dataset=self.val_dataset,
             batch_size=len(self.val_dataset),
             shuffle=False,
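
Note on the change above: torchdata's StatefulDataLoader is a drop-in replacement for torch.utils.data.DataLoader that additionally exposes state_dict()/load_state_dict(), so a trainer can checkpoint the dataloader and resume mid-epoch instead of replaying the epoch from its first batch. A minimal, self-contained sketch of that resume flow (the toy RangeDataset is hypothetical, for illustration only; it is not part of this diff):

    from torch.utils.data import Dataset
    from torchdata.stateful_dataloader import StatefulDataLoader

    class RangeDataset(Dataset):
        """Hypothetical toy dataset: the integers 0..15."""
        def __len__(self):
            return 16

        def __getitem__(self, idx):
            return idx

    loader = StatefulDataLoader(RangeDataset(), batch_size=4, shuffle=True)
    it = iter(loader)
    first_batch = next(it)            # consume one batch
    snapshot = loader.state_dict()    # capture sampler/iterator position

    # A fresh loader restored from the snapshot continues with the second
    # batch of the same shuffled order rather than restarting the epoch.
    resumed = StatefulDataLoader(RangeDataset(), batch_size=4, shuffle=True)
    resumed.load_state_dict(snapshot)
    for batch in resumed:
        print(batch)

Because the saved state includes the sampler's RNG state, the restored loader reproduces the original shuffle; this is what makes the swap useful for fault-tolerant PPO runs even though the constructor arguments in the diff are unchanged.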