diff --git a/reasoning_gym/arc/arc_1d.py b/reasoning_gym/arc/arc_1d.py index 7e399f20..ce594ec4 100644 --- a/reasoning_gym/arc/arc_1d.py +++ b/reasoning_gym/arc/arc_1d.py @@ -58,27 +58,27 @@ class Arc1DDataset(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = Random(self.seed + idx) + rng = Random(self.seed + idx) # Select random task - task_name = item_rng.choice(self.task_names) + task_name = rng.choice(self.task_names) task_func, task_kwargs = self.ARC_1D_TASKS[task_name] # Generate training examples train_examples = [] - size = item_rng.randint(self.config.min_size, self.config.max_size) + size = rng.randint(self.config.min_size, self.config.max_size) for _ in range(self.config.num_train): example = None while example is None: - example = task_func(item_rng, size, **task_kwargs) + example = task_func(rng, size, **task_kwargs) train_examples.append(example) # Generate test example test_example = None while test_example is None: - test_example = task_func(item_rng, size, **task_kwargs) + test_example = task_func(rng, size, **task_kwargs) # Format question question = "Find the common rule that maps an input grid to an output grid, given the examples below.\n\n" diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index cbfee4b1..495a79c5 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -4,7 +4,7 @@ Arithmetic tasks for training reasoning capabilities: from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset -from .chain_sum import ChainSum, ChainSumConfig +from .chain_sum import ChainSumConfig, ChainSumDataset from .count_bits import CountBitsConfig, CountBitsDataset from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset @@ -14,13 +14,13 @@ from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset -from .products import Products, ProductsConfig +from .products import ProductsConfig, ProductsDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset __all__ = [ "BasicArithmeticDataset", "BasicArithmeticDatasetConfig", - "ChainSum", + "ChainSumDataset", "ChainSumConfig", "CalendarArithmeticConfig", "CalendarArithmeticDataset", @@ -36,7 +36,7 @@ __all__ = [ "PowerFunctionDataset", "PrimeFactorizationConfig", "PrimeFactorizationDataset", - "Products", + "ProductsDataset", "ProductsConfig", "GSMSymbolicDatasetConfig", "GSMSymbolicDataset", diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 9ec096ee..156314d9 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -78,17 +78,17 @@ class BasicArithmeticDataset(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = Random(self.seed + idx) + rng = Random(self.seed + idx) - num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) - num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) if self.config.allow_parentheses: - expression, result = self._generate_complex_task(item_rng, num_terms, num_digits) + expression, result = self._generate_complex_task(rng, num_terms, num_digits) else: - expression, result = self._generate_simple_task(item_rng, num_terms, num_digits) + expression, result = self._generate_simple_task(rng, num_terms, num_digits) - question = self._format_question(item_rng, expression) + question = self._format_question(rng, expression) return { "question": question, diff --git a/reasoning_gym/arithmetic/calendar_arithmetic.py b/reasoning_gym/arithmetic/calendar_arithmetic.py index bf12211c..3a052590 100644 --- a/reasoning_gym/arithmetic/calendar_arithmetic.py +++ b/reasoning_gym/arithmetic/calendar_arithmetic.py @@ -122,9 +122,9 @@ class CalendarArithmeticDataset(ProceduralDataset): self.tasks = [self.task_handlers[task] for task in self.config.tasks] def __getitem__(self, idx: int) -> dict: - item_rng = random.Random(self.seed + idx) - task = item_rng.choice(self.tasks) - question, answer, metadata = task(item_rng) + rng = random.Random(self.seed + idx) + task = rng.choice(self.tasks) + question, answer, metadata = task(rng) return { "question": question, "answer": str(answer), diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 969df820..6d2e43e6 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -32,7 +32,7 @@ class ChainSumConfig: assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range" -class ChainSum(ProceduralDataset): +class ChainSumDataset(ProceduralDataset): """Generates simple arithmetic tasks using only + and - operators""" def __init__(self, config: ChainSumConfig): @@ -51,16 +51,16 @@ class ChainSum(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = random.Random(self.seed + idx) + rng = random.Random(self.seed + idx) - num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) - num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) # Calculate value ranges based on number of digits min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits - expression, result = self._generate_task(item_rng, num_terms, min_value, max_value) + expression, result = self._generate_task(rng, num_terms, min_value, max_value) return { "question": f"{expression} =", @@ -143,4 +143,4 @@ class ChainSumCurriculum(BaseCurriculum): # Register the dataset -register_dataset("chain_sum", ChainSum, ChainSumConfig) +register_dataset("chain_sum", ChainSumDataset, ChainSumConfig) diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index 62ae291c..742696ec 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -26,7 +26,7 @@ class ProductsConfig: assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits" -class Products(ProceduralDataset): +class ProductsDataset(ProceduralDataset): """Generates multiplication tasks with configurable number of terms""" def __init__(self, config: ProductsConfig): @@ -45,16 +45,16 @@ class Products(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = random.Random(self.seed + idx) + rng = random.Random(self.seed + idx) - num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) - num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) # Calculate value ranges based on number of digits min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits - expression, result = self._generate_task(item_rng, num_terms, min_value, max_value) + expression, result = self._generate_task(rng, num_terms, min_value, max_value) return { "question": f"{expression} =", @@ -127,4 +127,4 @@ class ProductsCurriculum(BaseCurriculum): # Register the dataset -register_dataset("products", Products, ProductsConfig) +register_dataset("products", ProductsDataset, ProductsConfig) diff --git a/reasoning_gym/arithmetic/time_intervals.py b/reasoning_gym/arithmetic/time_intervals.py index 1b296d02..f4011ba7 100644 --- a/reasoning_gym/arithmetic/time_intervals.py +++ b/reasoning_gym/arithmetic/time_intervals.py @@ -82,14 +82,14 @@ class TimeIntervalsDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single time interval calculation task""" - item_rng = random.Random(self.seed + idx) + rng = random.Random(self.seed + idx) # Randomly choose task type from config - task_type = item_rng.choice(self.config.task_types) + task_type = rng.choice(self.config.task_types) - start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type) + start_time, end_time, format_str, expected_format = self._generate_times(rng, task_type) - template = item_rng.choice(self.TEMPLATES) + template = rng.choice(self.TEMPLATES) question = template.format(start=start_time, end=end_time, format=expected_format) # Calculate the actual difference diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py index 36b0185c..eedbaae3 100644 --- a/tests/test_chain_sum.py +++ b/tests/test_chain_sum.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.arithmetic import ChainSum, ChainSumConfig +from reasoning_gym.arithmetic import ChainSumConfig, ChainSumDataset from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum @@ -18,8 +18,8 @@ def test_chain_sum_config_validation(): def test_chain_sum_deterministic(): """Test that dataset generates same items with same seed""" config = ChainSumConfig(seed=42, size=10) - dataset1 = ChainSum(config) - dataset2 = ChainSum(config) + dataset1 = ChainSumDataset(config) + dataset2 = ChainSumDataset(config) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -28,7 +28,7 @@ def test_chain_sum_deterministic(): def test_chain_sum_items(): """Test basic properties of generated items""" config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -57,7 +57,7 @@ def test_chain_sum_number_ranges(): size=50, seed=42, ) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -71,7 +71,7 @@ def test_chain_sum_number_ranges(): # Test 1-digit numbers config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) for i in range(len(dataset)): item = dataset[i] expression = item["metadata"]["expression"] @@ -88,7 +88,7 @@ def test_chain_sum_negation(): config = ChainSumConfig( min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True ) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) # Track if we see both positive and negative numbers has_positive = False @@ -112,7 +112,7 @@ def test_chain_sum_negation(): def test_chain_sum_iteration(): """Test that iteration respects dataset size""" config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing - dataset = ChainSum(config) + dataset = ChainSumDataset(config) # Test manual iteration items = [] diff --git a/tests/test_coaching.py b/tests/test_coaching.py index 83b56768..3b1b97d7 100644 --- a/tests/test_coaching.py +++ b/tests/test_coaching.py @@ -5,7 +5,7 @@ from pathlib import Path import pytest -from reasoning_gym.arithmetic.chain_sum import ChainSum, ChainSumConfig +from reasoning_gym.arithmetic.chain_sum import ChainSumConfig, ChainSumDataset from reasoning_gym.arithmetic.leg_counting import LegCountingConfig from reasoning_gym.coaching import Coach, GroupedScores from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec @@ -14,7 +14,7 @@ from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSp def test_coach_with_chain_sum(): # Create a small ChainSum dataset config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) coach = Coach(dataset) # Simulate an agent working on tasks @@ -208,7 +208,7 @@ def test_coach_score_logging(tmp_path): # Create dataset and coach with logging config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) coach = Coach(dataset, score_log=log_file) # Score a few answers diff --git a/tests/test_products.py b/tests/test_products.py index aac77a98..34ff1623 100644 --- a/tests/test_products.py +++ b/tests/test_products.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.arithmetic import Products, ProductsConfig +from reasoning_gym.arithmetic import ProductsConfig, ProductsDataset from reasoning_gym.arithmetic.products import ProductsCurriculum @@ -18,8 +18,8 @@ def test_products_config_validation(): def test_products_deterministic(): """Test that dataset generates same items with same seed""" config = ProductsConfig(seed=42, size=10) - dataset1 = Products(config) - dataset2 = Products(config) + dataset1 = ProductsDataset(config) + dataset2 = ProductsDataset(config) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -28,7 +28,7 @@ def test_products_deterministic(): def test_products_items(): """Test basic properties of generated items""" config = ProductsConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) - dataset = Products(config) + dataset = ProductsDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -57,7 +57,7 @@ def test_products_number_ranges(): size=50, seed=42, ) - dataset = Products(config) + dataset = ProductsDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -68,7 +68,7 @@ def test_products_number_ranges(): # Test 1-digit numbers config = ProductsConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) - dataset = Products(config) + dataset = ProductsDataset(config) for i in range(len(dataset)): item = dataset[i] expression = item["metadata"]["expression"] @@ -80,7 +80,7 @@ def test_products_number_ranges(): def test_products_iteration(): """Test that iteration respects dataset size""" config = ProductsConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing - dataset = Products(config) + dataset = ProductsDataset(config) # Test manual iteration items = [] @@ -101,18 +101,18 @@ def test_products_iteration(): def test_products_scoring(): """Test that scoring works correctly""" config = ProductsConfig(min_terms=2, max_terms=2, size=10, seed=42) - dataset = Products(config) + dataset = ProductsDataset(config) # Test scoring with exact match item = dataset[0] assert dataset.score_answer(item["answer"], item) == 1.0, "Exact match should score 1.0" - + # Test scoring with wrong answer assert dataset.score_answer("wrong", item) == 0.01, "Wrong answer should score 0.01" - + # Test scoring with partial match (answer contained in response) assert dataset.score_answer(f"The answer is {item['answer']}", item) == 0.5, "Partial match should score 0.5" - + # Test scoring with None assert dataset.score_answer(None, item) == 0.0, "None should score 0.0"