diff --git a/reasoning_gym/algebra/intermediate_integration.py b/reasoning_gym/algebra/intermediate_integration.py index 6335b6b7..84e01d9b 100644 --- a/reasoning_gym/algebra/intermediate_integration.py +++ b/reasoning_gym/algebra/intermediate_integration.py @@ -76,6 +76,11 @@ class IntermediateIntegrationDataset(ProceduralDataset): "Calculate the antiderivative: ∫ {integrand} dx", "Evaluate the indefinite integral: ∫ {integrand} dx", ] + self.added_instruction = """ +In addition, when doing calculation, use the following instructions together with your mathematical ingenuity to solve the integral problems +## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. +## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. +""" def _get_outer_constant(self, rng: random.Random) -> int: """Helper to generate signed outer constant from config""" @@ -222,9 +227,10 @@ class IntermediateIntegrationDataset(ProceduralDataset): answer = sympy.integrate(integrand, x) answer_str = str(answer) + " + C" + question = rng.choice(self.prompt_template).format(integrand=integrand) + self.added_instruction return { - "question": rng.choice(self.prompt_template).format(integrand=integrand), + "question": question, "answer": answer_str, "metadata": { "integrand": str(integrand), diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index 058d5dbb..cd4842ee 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -62,6 +62,14 @@ class PolynomialEquationsDataset(ProceduralDataset): "Determine the real value(s) of {variable} that satisfies: {polynomial_expanded} = 0", "Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0", ] + self.added_instruction = """ +In solving the equations, please abide by the following instruction: +## 1. All answers should be comma-separated. For example "-0.3773, 0.4005" etc. +## 2. In cases where your answer is b = 2 + sqrt(4560) / 172 and b = 2 - sqrt(4560) / 172. Since b can be 2 numbers, resolve your answer like this instead, "-0.3773, 0.4005". +## 3. If there are no real values of i that satisfy the equation, report your answer as empty string, "". +## 4. If there are 2 answers, resolve the answers as comma-separated floats of 2 numbers, if 3 answers, make it comma-separated floats of 3 numbers. +## 5. Resolve all numbers as floats in the string of comma-separated numbers. Round the floats higher than 4 decimal place(d.p) down to 4 d.p. +""" super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: @@ -89,19 +97,20 @@ class PolynomialEquationsDataset(ProceduralDataset): for sol in solutions: if sol.is_real: # Evaluate symbolic solution to a floating approximation - real_solutions.append(float(sol.evalf())) + real_solutions.append(round(float(sol.evalf()), 4)) if len(real_solutions) > 0: real_solutions.sort() break answer_str = ", ".join(str(x) for x in real_solutions) + question = ( + rng.choice(self._prompt_templates).format(variable=variable, polynomial_expanded=polynomial_expanded) + + self.added_instruction + ) return { - "question": rng.choice(self._prompt_templates).format( - variable=variable, - polynomial_expanded=polynomial_expanded, - ), + "question": question, "answer": answer_str, "metadata": { "polynomial_expr": str(polynomial_expanded), diff --git a/reasoning_gym/algebra/polynomial_multiplication.py b/reasoning_gym/algebra/polynomial_multiplication.py index 9a74679f..6076c32a 100644 --- a/reasoning_gym/algebra/polynomial_multiplication.py +++ b/reasoning_gym/algebra/polynomial_multiplication.py @@ -61,6 +61,11 @@ class PolynomialMultiplicationDataset(ProceduralDataset): "Simplify this expression: {polynomial_expr}", "Calculate the following: {polynomial_expr}", ] + self.added_instruction = """ +In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems +## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. +## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. +""" super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: @@ -79,11 +84,10 @@ class PolynomialMultiplicationDataset(ProceduralDataset): polynomial_expr = sp.prod(polynomials) product = sp.expand(polynomial_expr) + question = rng.choice(self._prompt_templates).format(polynomial_expr=polynomial_expr) + self.added_instruction return { - "question": rng.choice(self._prompt_templates).format( - polynomial_expr=polynomial_expr, - ), + "question": question, "answer": product, "metadata": { "polynomial_expr": str(polynomial_expr), diff --git a/reasoning_gym/algebra/simple_integration.py b/reasoning_gym/algebra/simple_integration.py index a8ca3be2..8dfa775b 100644 --- a/reasoning_gym/algebra/simple_integration.py +++ b/reasoning_gym/algebra/simple_integration.py @@ -41,6 +41,11 @@ class SimpleIntegrationDataset(ProceduralDataset): "Calculate the antiderivative: ∫ {integrand} dx", "Evaluate the indefinite integral: ∫ {integrand} dx", ] + self.added_instruction = """ +In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems +## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. +## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. +""" super().__init__(config=config, seed=config.seed, size=config.size) def _generate_coefficient(self, rng: random.Random) -> Fraction: @@ -69,9 +74,10 @@ class SimpleIntegrationDataset(ProceduralDataset): rng = random.Random(self.seed + idx) symbol, polynomial = self._generate_polynomial(rng) derivative = sympy.diff(polynomial, symbol) + question = rng.choice(self._prompt_templates).format(integrand=derivative) + self.added_instruction return { - "question": rng.choice(self._prompt_templates).format(integrand=derivative), + "question": question, "answer": str(polynomial) + " + C", "metadata": { "integrand": str(derivative), diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 3d7c8d0b..d4124b86 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -26,7 +26,6 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset -from .string_insertion import StringInsertionConfig, StringInsertionDataset from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index d922aa74..f906d230 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -34,6 +34,11 @@ class NumberSortingDataset(ProceduralDataset): def __init__(self, config: NumberSortingConfig): super().__init__(config=config, seed=config.seed, size=config.size) + self.added_instruction = """ +Please follow the instruction below: +## 1. Let all your answers be a list of numbers. Instead of reporting your answer as -69, -13, 1, 7, 11, 43, 59, 61, use ['-69', '-13', '1', '7', '11', '43', '59', '61'] instead +## 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61'] +""" def _format_number(self, num: float, decimals: int) -> str: """Format number with specified decimal places""" @@ -78,9 +83,10 @@ class NumberSortingDataset(ProceduralDataset): is_ascending = rng.choice([True, False]) direction = "ascending" if is_ascending else "descending" answer = asc_answer if is_ascending else desc_answer + question = f"Sort these numbers in {direction} order: {', '.join(number_strs)}" + self.added_instruction return { - "question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}", + "question": question, "answer": str(answer), "metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer}, } diff --git a/reasoning_gym/arc/arc_1d.py b/reasoning_gym/arc/arc_1d.py index 7e399f20..ce594ec4 100644 --- a/reasoning_gym/arc/arc_1d.py +++ b/reasoning_gym/arc/arc_1d.py @@ -58,27 +58,27 @@ class Arc1DDataset(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = Random(self.seed + idx) + rng = Random(self.seed + idx) # Select random task - task_name = item_rng.choice(self.task_names) + task_name = rng.choice(self.task_names) task_func, task_kwargs = self.ARC_1D_TASKS[task_name] # Generate training examples train_examples = [] - size = item_rng.randint(self.config.min_size, self.config.max_size) + size = rng.randint(self.config.min_size, self.config.max_size) for _ in range(self.config.num_train): example = None while example is None: - example = task_func(item_rng, size, **task_kwargs) + example = task_func(rng, size, **task_kwargs) train_examples.append(example) # Generate test example test_example = None while test_example is None: - test_example = task_func(item_rng, size, **task_kwargs) + test_example = task_func(rng, size, **task_kwargs) # Format question question = "Find the common rule that maps an input grid to an output grid, given the examples below.\n\n" diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index cc94bf56..495a79c5 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -4,7 +4,7 @@ Arithmetic tasks for training reasoning capabilities: from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset -from .chain_sum import ChainSum, ChainSumConfig +from .chain_sum import ChainSumConfig, ChainSumDataset from .count_bits import CountBitsConfig, CountBitsDataset from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset @@ -14,12 +14,13 @@ from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset +from .products import ProductsConfig, ProductsDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset __all__ = [ "BasicArithmeticDataset", "BasicArithmeticDatasetConfig", - "ChainSum", + "ChainSumDataset", "ChainSumConfig", "CalendarArithmeticConfig", "CalendarArithmeticDataset", @@ -31,8 +32,12 @@ __all__ = [ "LCMDataset", "LegCountingConfig", "LegCountingDataset", + "PowerFunctionConfig", + "PowerFunctionDataset", "PrimeFactorizationConfig", "PrimeFactorizationDataset", + "ProductsDataset", + "ProductsConfig", "GSMSymbolicDatasetConfig", "GSMSymbolicDataset", "TimeIntervalsConfig", diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 9ec096ee..156314d9 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -78,17 +78,17 @@ class BasicArithmeticDataset(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = Random(self.seed + idx) + rng = Random(self.seed + idx) - num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) - num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) if self.config.allow_parentheses: - expression, result = self._generate_complex_task(item_rng, num_terms, num_digits) + expression, result = self._generate_complex_task(rng, num_terms, num_digits) else: - expression, result = self._generate_simple_task(item_rng, num_terms, num_digits) + expression, result = self._generate_simple_task(rng, num_terms, num_digits) - question = self._format_question(item_rng, expression) + question = self._format_question(rng, expression) return { "question": question, diff --git a/reasoning_gym/arithmetic/calendar_arithmetic.py b/reasoning_gym/arithmetic/calendar_arithmetic.py index bf12211c..3a052590 100644 --- a/reasoning_gym/arithmetic/calendar_arithmetic.py +++ b/reasoning_gym/arithmetic/calendar_arithmetic.py @@ -122,9 +122,9 @@ class CalendarArithmeticDataset(ProceduralDataset): self.tasks = [self.task_handlers[task] for task in self.config.tasks] def __getitem__(self, idx: int) -> dict: - item_rng = random.Random(self.seed + idx) - task = item_rng.choice(self.tasks) - question, answer, metadata = task(item_rng) + rng = random.Random(self.seed + idx) + task = rng.choice(self.tasks) + question, answer, metadata = task(rng) return { "question": question, "answer": str(answer), diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 969df820..6d2e43e6 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -32,7 +32,7 @@ class ChainSumConfig: assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range" -class ChainSum(ProceduralDataset): +class ChainSumDataset(ProceduralDataset): """Generates simple arithmetic tasks using only + and - operators""" def __init__(self, config: ChainSumConfig): @@ -51,16 +51,16 @@ class ChainSum(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = random.Random(self.seed + idx) + rng = random.Random(self.seed + idx) - num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) - num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) # Calculate value ranges based on number of digits min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits - expression, result = self._generate_task(item_rng, num_terms, min_value, max_value) + expression, result = self._generate_task(rng, num_terms, min_value, max_value) return { "question": f"{expression} =", @@ -143,4 +143,4 @@ class ChainSumCurriculum(BaseCurriculum): # Register the dataset -register_dataset("chain_sum", ChainSum, ChainSumConfig) +register_dataset("chain_sum", ChainSumDataset, ChainSumConfig) diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py new file mode 100644 index 00000000..742696ec --- /dev/null +++ b/reasoning_gym/arithmetic/products.py @@ -0,0 +1,130 @@ +import random +from dataclasses import dataclass +from typing import Optional + +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class ProductsConfig: + """Configuration for products task generation""" + + min_terms: int = 2 + max_terms: int = 2 + min_digits: int = 1 + max_digits: int = 5 + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.size > 0, "size must be positive" + assert self.min_terms > 0, "min_terms must be positive" + assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" + assert self.min_digits > 0, "min_digits must be positive" + assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits" + + +class ProductsDataset(ProceduralDataset): + """Generates multiplication tasks with configurable number of terms""" + + def __init__(self, config: ProductsConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single multiplication task + + Args: + idx: Index of the item to generate + + Returns: + dict with keys: + - question: str, the formatted multiplication expression + - answer: str, the ground truth result + - metadata: dict with generation parameters + """ + # Create deterministic RNG from base seed and idx + rng = random.Random(self.seed + idx) + + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) + + # Calculate value ranges based on number of digits + min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit + max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits + + expression, result = self._generate_task(rng, num_terms, min_value, max_value) + + return { + "question": f"{expression} =", + "answer": str(result), + "metadata": { + "difficulty": { + "num_terms": num_terms, + "num_digits": num_digits, + }, + "expression": expression, + }, + } + + def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]: + """Generate a multiplication task + + Args: + rng: Random number generator + num_terms: Number of terms in the expression + min_value: Minimum value for generated numbers + max_value: Maximum value for generated numbers + + Returns: + Tuple of (expression string, result integer) + """ + # Generate random numbers within the specified range + constants = [rng.randint(min_value, max_value) for _ in range(num_terms)] + + # Build expression and compute result + expression_parts = [] + result = constants[0] + + expression_parts.append(str(constants[0])) + for i in range(1, len(constants)): + expression_parts.append("*") + expression_parts.append(str(constants[i])) + result *= constants[i] + + expression = " ".join(expression_parts) + return expression, result + + +class ProductsCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(ProductsCurriculum.__name__, ProductsConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="num_terms", + levels=[2, 3, 4, 5], + default_level=0, # Start with 2 terms + description="Maximum number of terms in the expression", + attr_type=AttributeType.APPEND, + min_value=2, # Ensure at least 2 terms + lower_field_name="min_terms", + upper_field_name="max_terms", + ), + RangeAttributeDefinition( + name="num_digits", + levels=[1, 2, 3, 4], + default_level=0, # Start with 1-digit numbers + description="Number of digits in each operand", + attr_type=AttributeType.APPEND, + min_value=1, # Ensure numbers are at least 1 digit + lower_field_name="min_digits", + upper_field_name="max_digits", + ), + ) + + +# Register the dataset +register_dataset("products", ProductsDataset, ProductsConfig) diff --git a/reasoning_gym/arithmetic/time_intervals.py b/reasoning_gym/arithmetic/time_intervals.py index 1b296d02..f4011ba7 100644 --- a/reasoning_gym/arithmetic/time_intervals.py +++ b/reasoning_gym/arithmetic/time_intervals.py @@ -82,14 +82,14 @@ class TimeIntervalsDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single time interval calculation task""" - item_rng = random.Random(self.seed + idx) + rng = random.Random(self.seed + idx) # Randomly choose task type from config - task_type = item_rng.choice(self.config.task_types) + task_type = rng.choice(self.config.task_types) - start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type) + start_time, end_time, format_str, expected_format = self._generate_times(rng, task_type) - template = item_rng.choice(self.TEMPLATES) + template = rng.choice(self.TEMPLATES) question = template.format(start=start_time, end=end_time, format=expected_format) # Calculate the actual difference diff --git a/reasoning_gym/games/game_of_life.py b/reasoning_gym/games/game_of_life.py index c8cdc0d1..2b5e369b 100644 --- a/reasoning_gym/games/game_of_life.py +++ b/reasoning_gym/games/game_of_life.py @@ -1,6 +1,7 @@ +import json from dataclasses import dataclass from random import Random -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional import cellpylib as cpl @@ -11,8 +12,8 @@ from ..factory import ProceduralDataset, register_dataset class GameOfLifeConfig: """Configuration for sudoku puzzle generation""" - grid_size_x: int = 20 - grid_size_y: int = 20 + grid_size_x: int = 10 + grid_size_y: int = 10 filled_cells: int = 100 # actually a max simulation_steps: int = 1 seed: Optional[int] = None @@ -31,7 +32,7 @@ class GameOfLifeDataset(ProceduralDataset): def __init__(self, config: GameOfLifeConfig): self._prompt_templates = [ - "What will this Game of Life board look like after {simulation_steps} steps of simulation?\n\n{board}" + "What will this Game of Life board look like after {simulation_steps} steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]])\n\n{board}." ] super().__init__(config=config, seed=config.seed, size=config.size) @@ -59,11 +60,18 @@ class GameOfLifeDataset(ProceduralDataset): # Simulate the result to get the answer evolved = cpl.evolve2d( - board, timesteps=self.config.simulation_steps + 1, apply_rule=cpl.game_of_life_rule, memoize="recursive" + board, + timesteps=self.config.simulation_steps + 1, + apply_rule=cpl.game_of_life_rule, + memoize="recursive", ) - board_str = str(board[0]) - result_str = str(evolved[-1]) + rows = [json.dumps(board[0, i].tolist(), separators=(",", ":")) for i in range(board.shape[1])] + board_str = "[" + ", \n ".join(rows) + "]" + + final_step = evolved[-1] + final_step_list = final_step.tolist() + result_str = json.dumps(final_step_list, separators=(",", ":")) return { "question": rng.choice(self._prompt_templates).format( @@ -93,10 +101,17 @@ class GameOfLifeDataset(ProceduralDataset): if answer == None: return 0.0 - if answer.replace("\n", "") != entry["answer"].replace("\n", ""): + + try: + ans_arr = json.loads(answer) + correct_arr = json.loads(entry["answer"]) + + if correct_arr != ans_arr: + return 0.01 + else: + return 1.0 # Yay + except Exception as e: return 0.01 - else: - return 1.0 # Yay register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig) diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py index 36b0185c..eedbaae3 100644 --- a/tests/test_chain_sum.py +++ b/tests/test_chain_sum.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.arithmetic import ChainSum, ChainSumConfig +from reasoning_gym.arithmetic import ChainSumConfig, ChainSumDataset from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum @@ -18,8 +18,8 @@ def test_chain_sum_config_validation(): def test_chain_sum_deterministic(): """Test that dataset generates same items with same seed""" config = ChainSumConfig(seed=42, size=10) - dataset1 = ChainSum(config) - dataset2 = ChainSum(config) + dataset1 = ChainSumDataset(config) + dataset2 = ChainSumDataset(config) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -28,7 +28,7 @@ def test_chain_sum_deterministic(): def test_chain_sum_items(): """Test basic properties of generated items""" config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -57,7 +57,7 @@ def test_chain_sum_number_ranges(): size=50, seed=42, ) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -71,7 +71,7 @@ def test_chain_sum_number_ranges(): # Test 1-digit numbers config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) for i in range(len(dataset)): item = dataset[i] expression = item["metadata"]["expression"] @@ -88,7 +88,7 @@ def test_chain_sum_negation(): config = ChainSumConfig( min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True ) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) # Track if we see both positive and negative numbers has_positive = False @@ -112,7 +112,7 @@ def test_chain_sum_negation(): def test_chain_sum_iteration(): """Test that iteration respects dataset size""" config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing - dataset = ChainSum(config) + dataset = ChainSumDataset(config) # Test manual iteration items = [] diff --git a/tests/test_coaching.py b/tests/test_coaching.py index 83b56768..3b1b97d7 100644 --- a/tests/test_coaching.py +++ b/tests/test_coaching.py @@ -5,7 +5,7 @@ from pathlib import Path import pytest -from reasoning_gym.arithmetic.chain_sum import ChainSum, ChainSumConfig +from reasoning_gym.arithmetic.chain_sum import ChainSumConfig, ChainSumDataset from reasoning_gym.arithmetic.leg_counting import LegCountingConfig from reasoning_gym.coaching import Coach, GroupedScores from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec @@ -14,7 +14,7 @@ from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSp def test_coach_with_chain_sum(): # Create a small ChainSum dataset config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) coach = Coach(dataset) # Simulate an agent working on tasks @@ -208,7 +208,7 @@ def test_coach_score_logging(tmp_path): # Create dataset and coach with logging config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) coach = Coach(dataset, score_log=log_file) # Score a few answers diff --git a/tests/test_game_of_life.py b/tests/test_game_of_life.py index df0f133d..10ec5c7a 100644 --- a/tests/test_game_of_life.py +++ b/tests/test_game_of_life.py @@ -7,7 +7,7 @@ def test_game_of_life(): """Test basic properties and solution of generated items""" # Easy - config = GameOfLifeConfig(seed=42, size=1, grid_size_x=20, grid_size_y=20, filled_cells=10, simulation_steps=1) + config = GameOfLifeConfig(seed=42, size=10, grid_size_x=20, grid_size_y=20, filled_cells=200, simulation_steps=1) dataset = GameOfLifeDataset(config) for item in dataset: diff --git a/tests/test_polynomial_equations.py b/tests/test_polynomial_equations.py index 420187de..e4e72b18 100644 --- a/tests/test_polynomial_equations.py +++ b/tests/test_polynomial_equations.py @@ -112,7 +112,7 @@ def test_polynomial_solutions_evaluation(): evaluated_value = poly_expr.subs(x, solution) # Ensure the evaluated value is close to zero (numerical stability threshold) - assert abs(evaluated_value) < 1e-6, ( + assert abs(evaluated_value) < 1e-5, ( f"Solution {solution} does not satisfy the polynomial {poly_str}. " f"Evaluated value: {evaluated_value}" ) diff --git a/tests/test_products.py b/tests/test_products.py new file mode 100644 index 00000000..34ff1623 --- /dev/null +++ b/tests/test_products.py @@ -0,0 +1,144 @@ +import pytest + +from reasoning_gym.arithmetic import ProductsConfig, ProductsDataset +from reasoning_gym.arithmetic.products import ProductsCurriculum + + +def test_products_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = ProductsConfig(min_terms=0) + config.validate() + + with pytest.raises(AssertionError): + config = ProductsConfig(min_terms=3, max_terms=2) + config.validate() + + +def test_products_deterministic(): + """Test that dataset generates same items with same seed""" + config = ProductsConfig(seed=42, size=10) + dataset1 = ProductsDataset(config) + dataset2 = ProductsDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_products_items(): + """Test basic properties of generated items""" + config = ProductsConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) + dataset = ProductsDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Verify only * is used + expression = item["metadata"]["expression"] + assert all(op in ["*", " "] or op.isdigit() for op in expression) + + # Verify the answer matches the expression + answer = eval(expression) # Safe here as we control the expression + assert str(answer) == item["answer"] + + +def test_products_number_ranges(): + """Test that generated numbers respect digit constraints""" + # Test 3-digit numbers + config = ProductsConfig( + min_terms=2, + max_terms=2, # Fix to 2 terms for easier testing + min_digits=3, # Should generate numbers >= 100 + max_digits=3, # Should generate numbers <= 999 + size=50, + seed=42, + ) + dataset = ProductsDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits" + + # Test 1-digit numbers + config = ProductsConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) + dataset = ProductsDataset(config) + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit" + + +def test_products_iteration(): + """Test that iteration respects dataset size""" + config = ProductsConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing + dataset = ProductsDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test list conversion + items = list(dataset) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test multiple iterations + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield same items" + + +def test_products_scoring(): + """Test that scoring works correctly""" + config = ProductsConfig(min_terms=2, max_terms=2, size=10, seed=42) + dataset = ProductsDataset(config) + + # Test scoring with exact match + item = dataset[0] + assert dataset.score_answer(item["answer"], item) == 1.0, "Exact match should score 1.0" + + # Test scoring with wrong answer + assert dataset.score_answer("wrong", item) == 0.01, "Wrong answer should score 0.01" + + # Test scoring with partial match (answer contained in response) + assert dataset.score_answer(f"The answer is {item['answer']}", item) == 0.5, "Partial match should score 0.5" + + # Test scoring with None + assert dataset.score_answer(None, item) == 0.0, "None should score 0.0" + + +def test_products_curriculum(): + curriculum = ProductsCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: ProductsConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1 + assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2 + + # test incrementing attribute levels for num_terms & num_digits attributes + curriculum.increment_attr_level("num_terms") + curriculum.increment_attr_level("num_digits") + + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2 + assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3 + + # test decrementing attribute level for num_digits again + curriculum.decrement_attr_level("num_digits") + + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1 + assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3