mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
fix: Move EpochTrackingDataLoader after ReasoningGymDataset to resolve undefined name error
This commit is contained in:
parent
5f16d54ebe
commit
8dc6cb5228
4 changed files with 24 additions and 30 deletions
|
|
@ -9,25 +9,6 @@ import torch
|
|||
import verl.utils.torch_functional as verl_F
|
||||
from omegaconf import OmegaConf, open_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .main_ppo_custom_reward_server import RayPPOTrainerCustom
|
||||
|
||||
class EpochTrackingDataLoader(DataLoader):
    """DataLoader that tracks epochs based on the trainer's ``global_steps``.

    Every time iteration starts, the epoch index is recomputed from
    ``trainer.global_steps`` and written to ``dataset.epoch``.

    Args:
        dataset: Dataset to load from; its ``epoch`` attribute is overwritten
            on every ``__iter__`` call.
        trainer: Object exposing an integer ``global_steps`` attribute.
        *args: Forwarded unchanged to ``DataLoader.__init__``.
        **kwargs: Forwarded unchanged to ``DataLoader.__init__``.
    """

    # Both annotations are quoted: at this point in the module neither
    # ReasoningGymDataset nor RayPPOTrainerCustom is defined yet, and an
    # unquoted name in a signature is evaluated when the ``def`` executes.
    def __init__(self, dataset: "ReasoningGymDataset", trainer: "RayPPOTrainerCustom", *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)
        self.trainer = trainer
        self.steps_per_epoch = len(self)  # Number of batches per epoch

    def __iter__(self):
        """Update ``dataset.epoch`` from the trainer, then start iterating.

        Returns:
            The iterator produced by ``DataLoader.__iter__``.
        """
        if self.steps_per_epoch > 0:
            # The ``- 1`` assumes global_steps is 1-based while training
            # (TODO confirm against the trainer). Clamp at 0 so a loader
            # iterated while global_steps is still 0 does not report epoch -1.
            current_epoch = max(0, (self.trainer.global_steps - 1) // self.steps_per_epoch)
        else:
            # A loader with no batches has no meaningful epoch boundary;
            # avoid dividing by zero and stay on epoch 0.
            current_epoch = 0
        # Update dataset's epoch counter
        self.dataset.epoch = current_epoch
        return super().__iter__()
|
||||
from transformers import PreTrainedTokenizer
|
||||
from verl import DataProto
|
||||
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
|
||||
|
|
@ -94,10 +75,7 @@ class ReasoningGymDataset(Dataset):
|
|||
if batch_idx not in self._batch_cache:
|
||||
base_index = batch_idx * self.batch_size
|
||||
response = self.client.get_batch(
|
||||
self.dataset_name,
|
||||
base_index=base_index,
|
||||
batch_size=self.batch_size,
|
||||
epoch=self.epoch
|
||||
self.dataset_name, base_index=base_index, batch_size=self.batch_size, epoch=self.epoch
|
||||
)
|
||||
self._batch_cache[batch_idx] = response.entries
|
||||
|
||||
|
|
@ -152,6 +130,22 @@ class ReasoningGymDataset(Dataset):
|
|||
return row_dict
|
||||
|
||||
|
||||
class EpochTrackingDataLoader(DataLoader):
    """DataLoader that tracks epochs based on the trainer's ``global_steps``.

    Every time iteration starts, the epoch index is recomputed from
    ``trainer.global_steps`` and written to ``dataset.epoch``.

    Args:
        dataset: Dataset to load from; its ``epoch`` attribute is overwritten
            on every ``__iter__`` call.
        trainer: Object exposing an integer ``global_steps`` attribute.
        *args: Forwarded unchanged to ``DataLoader.__init__``.
        **kwargs: Forwarded unchanged to ``DataLoader.__init__``.
    """

    # Annotations are quoted so they are never evaluated when the ``def``
    # executes; the class therefore does not depend on definition order.
    def __init__(self, dataset: "ReasoningGymDataset", trainer: "RayPPOTrainerCustom", *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)
        self.trainer = trainer
        self.steps_per_epoch = len(self)  # Number of batches per epoch

    def __iter__(self):
        """Update ``dataset.epoch`` from the trainer, then start iterating.

        Returns:
            The iterator produced by ``DataLoader.__iter__``.
        """
        if self.steps_per_epoch > 0:
            # The ``- 1`` assumes global_steps is 1-based while training
            # (TODO confirm against the trainer). Clamp at 0 so a loader
            # iterated while global_steps is still 0 does not report epoch -1.
            current_epoch = max(0, (self.trainer.global_steps - 1) // self.steps_per_epoch)
        else:
            # A loader with no batches has no meaningful epoch boundary;
            # avoid dividing by zero and stay on epoch 0.
            current_epoch = 0
        # Update dataset's epoch counter
        self.dataset.epoch = current_epoch
        return super().__iter__()
|
||||
|
||||
|
||||
class RayPPOTrainerCustom(RayPPOTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 3, 4, 5],
|
||||
levels=list(range(2, 8)),
|
||||
default_level=0, # Start with 2 terms
|
||||
description="Maximum number of terms in the expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
|
|
@ -132,7 +132,7 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
),
|
||||
RangeAttributeDefinition(
|
||||
name="num_digits",
|
||||
levels=[1, 2, 4, 10],
|
||||
levels=list(range(1, 10)),
|
||||
default_level=0, # Start with 1-digit numbers
|
||||
description="Number of digits in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ class ProductsCurriculum(BaseCurriculum):
|
|||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 3, 4, 5],
|
||||
levels=list(range(2, 8)),
|
||||
default_level=0, # Start with 2 terms
|
||||
description="Maximum number of terms in the expression",
|
||||
attr_type=AttributeType.APPEND,
|
||||
|
|
@ -124,7 +124,7 @@ class ProductsCurriculum(BaseCurriculum):
|
|||
),
|
||||
RangeAttributeDefinition(
|
||||
name="num_digits",
|
||||
levels=[1, 2, 3, 4],
|
||||
levels=list(range(1, 10)),
|
||||
default_level=0, # Start with 1-digit numbers
|
||||
description="Number of digits in each operand",
|
||||
attr_type=AttributeType.APPEND,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue