From 8dc6cb52281f290610c16e348d7e730fd0bfde1f Mon Sep 17 00:00:00 2001 From: "Andreas Koepf (aider)" Date: Sat, 22 Feb 2025 21:12:15 +0000 Subject: [PATCH] fix: Move EpochTrackingDataLoader after ReasoningGymDataset to resolve undefined name error --- .../veRL/main_ppo_custom_reward_server.py | 40 ++++++++----------- reasoning_gym/arithmetic/chain_sum.py | 4 +- reasoning_gym/arithmetic/products.py | 4 +- tools/server/server.py | 6 +-- 4 files changed, 24 insertions(+), 30 deletions(-) diff --git a/examples/veRL/main_ppo_custom_reward_server.py b/examples/veRL/main_ppo_custom_reward_server.py index f5757fb0..4b0e31ba 100644 --- a/examples/veRL/main_ppo_custom_reward_server.py +++ b/examples/veRL/main_ppo_custom_reward_server.py @@ -9,25 +9,6 @@ import torch import verl.utils.torch_functional as verl_F from omegaconf import OmegaConf, open_dict from torch.utils.data import DataLoader, Dataset -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from .main_ppo_custom_reward_server import RayPPOTrainerCustom - -class EpochTrackingDataLoader(DataLoader): - """DataLoader that tracks epochs based on trainer's global_steps""" - - def __init__(self, dataset: ReasoningGymDataset, trainer: "RayPPOTrainerCustom", *args, **kwargs): - super().__init__(dataset, *args, **kwargs) - self.trainer = trainer - self.steps_per_epoch = len(self) # Number of batches per epoch - - def __iter__(self): - # Calculate current epoch from global_steps - current_epoch = (self.trainer.global_steps - 1) // self.steps_per_epoch - # Update dataset's epoch counter - self.dataset.epoch = current_epoch - return super().__iter__() from transformers import PreTrainedTokenizer from verl import DataProto from verl.trainer.ppo.ray_trainer import RayPPOTrainer @@ -94,10 +75,7 @@ class ReasoningGymDataset(Dataset): if batch_idx not in self._batch_cache: base_index = batch_idx * self.batch_size response = self.client.get_batch( - self.dataset_name, - base_index=base_index, - batch_size=self.batch_size, - epoch=self.epoch + self.dataset_name, base_index=base_index, batch_size=self.batch_size, epoch=self.epoch ) self._batch_cache[batch_idx] = response.entries @@ -152,6 +130,22 @@ class ReasoningGymDataset(Dataset): return row_dict +class EpochTrackingDataLoader(DataLoader): + """DataLoader that tracks epochs based on trainer's global_steps""" + + def __init__(self, dataset: ReasoningGymDataset, trainer: "RayPPOTrainerCustom", *args, **kwargs): + super().__init__(dataset, *args, **kwargs) + self.trainer = trainer + self.steps_per_epoch = len(self) # Number of batches per epoch + + def __iter__(self): + # Calculate current epoch from global_steps + current_epoch = (self.trainer.global_steps - 1) // self.steps_per_epoch + # Update dataset's epoch counter + self.dataset.epoch = current_epoch + return super().__iter__() + + class RayPPOTrainerCustom(RayPPOTrainer): def __init__( self, diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 2072983c..a0fd2e1f 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -122,7 +122,7 @@ class ChainSumCurriculum(BaseCurriculum): self._define_attributes( RangeAttributeDefinition( name="num_terms", - levels=[2, 3, 4, 5], + levels=list(range(2, 8)), default_level=0, # Start with 2 terms description="Maximum number of terms in the expression", attr_type=AttributeType.APPEND, @@ -132,7 +132,7 @@ class ChainSumCurriculum(BaseCurriculum): ), RangeAttributeDefinition( name="num_digits", - levels=[1, 2, 4, 10], + levels=list(range(1, 10)), default_level=0, # Start with 1-digit numbers description="Number of digits in each operand", attr_type=AttributeType.APPEND, diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index 8401be91..20289ccc 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -114,7 +114,7 @@ class ProductsCurriculum(BaseCurriculum): self._define_attributes( RangeAttributeDefinition( name="num_terms", - levels=[2, 3, 4, 5], + levels=list(range(2, 8)), default_level=0, # Start with 2 terms description="Maximum number of terms in the expression", attr_type=AttributeType.APPEND, @@ -124,7 +124,7 @@ class ProductsCurriculum(BaseCurriculum): ), RangeAttributeDefinition( name="num_digits", - levels=[1, 2, 3, 4], + levels=list(range(1, 10)), default_level=0, # Start with 1-digit numbers description="Number of digits in each operand", attr_type=AttributeType.APPEND, diff --git a/tools/server/server.py b/tools/server/server.py index 47e79cd7..35615942 100644 --- a/tools/server/server.py +++ b/tools/server/server.py @@ -76,12 +76,12 @@ def create_app(config: ServerConfig) -> FastAPI: def permute_index(idx: int, epoch_seed: int, dataset_size: int) -> int: """Generate a deterministic permuted index without materializing full permutation. - + Args: idx: Original index to permute epoch_seed: Seed for this epoch's permutation dataset_size: Size of the dataset - + Returns: Permuted index in range [0, dataset_size) """ @@ -107,7 +107,7 @@ def create_app(config: ServerConfig) -> FastAPI: dataset_size = len(experiment.dataset) base_seed = experiment.config.seed if experiment.config.seed is not None else 0 epoch_seed = base_seed + (epoch * dataset_size) - + entries = [] for i in range(base_index, base_index + batch_size): # Get permuted index for this position