From 8811598a1799bcca34dfb12eb4bd5837d1e69544 Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Thu, 13 Feb 2025 00:49:31 -0600 Subject: [PATCH 01/18] Add useful instructions to question template of some datasets --- .../algebra/intermediate_integration.py | 9 ++++++++- reasoning_gym/algebra/polynomial_equations.py | 17 ++++++++++++----- .../algebra/polynomial_multiplication.py | 11 ++++++++--- reasoning_gym/algebra/simple_integration.py | 9 ++++++++- reasoning_gym/algorithmic/number_sorting.py | 9 ++++++++- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/reasoning_gym/algebra/intermediate_integration.py b/reasoning_gym/algebra/intermediate_integration.py index 6335b6b7..34e144df 100644 --- a/reasoning_gym/algebra/intermediate_integration.py +++ b/reasoning_gym/algebra/intermediate_integration.py @@ -76,6 +76,12 @@ class IntermediateIntegrationDataset(ProceduralDataset): "Calculate the antiderivative: ∫ {integrand} dx", "Evaluate the indefinite integral: ∫ {integrand} dx", ] + self.added_instruction = """ + \n\n + In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems + ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. + ## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. + """ def _get_outer_constant(self, rng: random.Random) -> int: """Helper to generate signed outer constant from config""" @@ -222,9 +228,10 @@ class IntermediateIntegrationDataset(ProceduralDataset): answer = sympy.integrate(integrand, x) answer_str = str(answer) + " + C" + question = rng.choice(self.prompt_template).format(integrand=integrand) + self.added_instruction return { - "question": rng.choice(self.prompt_template).format(integrand=integrand), + "question": question, "answer": answer_str, "metadata": { "integrand": str(integrand), diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index eec45285..f3e9d2da 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -62,6 +62,15 @@ class PolynomialEquationsDataset(ProceduralDataset): "Determine the real value(s) of {variable} that satisfies: {polynomial_expanded} = 0", "Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0", ] + self.added_instruction = """ + \n\n + In solving the equations, please abide by the following instruction: + ## 1. All answers should be inserted in square brackets. For example [], [-0.3773, 0.4005] etc. + ## 2. In cases where your answer is b = [2 + sqrt(4560)] / 172 and b = [2 - sqrt(4560)] / 172. Since b can be 2 numbers Resolve your answer like this instead, [-0.3773, 0.4005] + ## 3. If there are no real values of i that satisfy the equation, report your answer as empty square bracket, [] + ## 4. If there are 2 answers, resolve the answers as floats and fill the 2 numbers in square bracket, if 3 answers, fill it with 3 answers. + ## 5. Resolve all numbers in square brackets as floats. Round the floats higher than 4 decimal place(d.p) down to 4 d.p. + """ super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: @@ -89,15 +98,13 @@ class PolynomialEquationsDataset(ProceduralDataset): for sol in solutions: if sol.is_real: # Evaluate symbolic solution to a floating approximation - real_solutions.append(float(sol.evalf())) + real_solutions.append(round(float(sol.evalf()), 4)) real_solutions.sort() answer_str = str(real_solutions) + question = rng.choice(self._prompt_templates).format(variable=variable, polynomial_expanded=polynomial_expanded) + self.added_instruction return { - "question": rng.choice(self._prompt_templates).format( - variable=variable, - polynomial_expanded=polynomial_expanded, - ), + "question": question, "answer": answer_str, "metadata": { "polynomial_expr": str(polynomial_expanded), diff --git a/reasoning_gym/algebra/polynomial_multiplication.py b/reasoning_gym/algebra/polynomial_multiplication.py index 9a74679f..d0c472b3 100644 --- a/reasoning_gym/algebra/polynomial_multiplication.py +++ b/reasoning_gym/algebra/polynomial_multiplication.py @@ -61,6 +61,12 @@ class PolynomialMultiplicationDataset(ProceduralDataset): "Simplify this expression: {polynomial_expr}", "Calculate the following: {polynomial_expr}", ] + self.added_instruction = """ + \n\n + In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems + ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. + ## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. + """ super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: @@ -79,11 +85,10 @@ class PolynomialMultiplicationDataset(ProceduralDataset): polynomial_expr = sp.prod(polynomials) product = sp.expand(polynomial_expr) + question = rng.choice(self._prompt_templates).format(polynomial_expr=polynomial_expr) + self.added_instruction return { - "question": rng.choice(self._prompt_templates).format( - polynomial_expr=polynomial_expr, - ), + "question": question, "answer": product, "metadata": { "polynomial_expr": str(polynomial_expr), diff --git a/reasoning_gym/algebra/simple_integration.py b/reasoning_gym/algebra/simple_integration.py index a8ca3be2..d056bc7c 100644 --- a/reasoning_gym/algebra/simple_integration.py +++ b/reasoning_gym/algebra/simple_integration.py @@ -41,6 +41,12 @@ class SimpleIntegrationDataset(ProceduralDataset): "Calculate the antiderivative: ∫ {integrand} dx", "Evaluate the indefinite integral: ∫ {integrand} dx", ] + self.added_instruction = """ + \n\n + In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems + ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. + ## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. + """ super().__init__(config=config, seed=config.seed, size=config.size) def _generate_coefficient(self, rng: random.Random) -> Fraction: @@ -69,9 +75,10 @@ class SimpleIntegrationDataset(ProceduralDataset): rng = random.Random(self.seed + idx) symbol, polynomial = self._generate_polynomial(rng) derivative = sympy.diff(polynomial, symbol) + question = rng.choice(self._prompt_templates).format(integrand=derivative) + self.added_instruction return { - "question": rng.choice(self._prompt_templates).format(integrand=derivative), + "question": question, "answer": str(polynomial) + " + C", "metadata": { "integrand": str(derivative), diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index d922aa74..39dfda99 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -34,6 +34,12 @@ class NumberSortingDataset(ProceduralDataset): def __init__(self, config: NumberSortingConfig): super().__init__(config=config, seed=config.seed, size=config.size) + self.added_instruction = """ + \n\n + Please follow the instruction below: + # 1. Let all your answers be a list of numbers. Instead of reporting your answer as -69, -13, 1, 7, 11, 43, 59, 61, use ['-69', '-13', '1', '7', '11', '43', '59', '61'] instead + # 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61'] + """ def _format_number(self, num: float, decimals: int) -> str: """Format number with specified decimal places""" @@ -78,9 +84,10 @@ class NumberSortingDataset(ProceduralDataset): is_ascending = rng.choice([True, False]) direction = "ascending" if is_ascending else "descending" answer = asc_answer if is_ascending else desc_answer + question = f"Sort these numbers in {direction} order: {', '.join(number_strs)}" + self.added_instruction return { - "question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}", + "question": question, "answer": str(answer), "metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer}, } From 0adde9853f0a1d512c1fde7dd99931f2521c0a21 Mon Sep 17 00:00:00 2001 From: Rich Jones Date: Thu, 13 Feb 2025 13:25:38 +0100 Subject: [PATCH 02/18] json gol --- reasoning_gym/games/game_of_life.py | 20 +++++++++++++++----- tests/test_game_of_life.py | 2 +- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/reasoning_gym/games/game_of_life.py b/reasoning_gym/games/game_of_life.py index c8cdc0d1..b77f3fd6 100644 --- a/reasoning_gym/games/game_of_life.py +++ b/reasoning_gym/games/game_of_life.py @@ -1,3 +1,4 @@ +import json from dataclasses import dataclass from random import Random from typing import Dict, List, Optional, Tuple @@ -31,7 +32,7 @@ class GameOfLifeDataset(ProceduralDataset): def __init__(self, config: GameOfLifeConfig): self._prompt_templates = [ - "What will this Game of Life board look like after {simulation_steps} steps of simulation?\n\n{board}" + "What will this Game of Life board look like after {simulation_steps} steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]])\n\n{board}." ] super().__init__(config=config, seed=config.seed, size=config.size) @@ -63,7 +64,9 @@ class GameOfLifeDataset(ProceduralDataset): ) board_str = str(board[0]) - result_str = str(evolved[-1]) + final_step = evolved[-1] + final_step_list = final_step.tolist() + result_str = json.dumps(final_step_list, separators=(",", ":")) return { "question": rng.choice(self._prompt_templates).format( @@ -93,10 +96,17 @@ class GameOfLifeDataset(ProceduralDataset): if answer == None: return 0.0 - if answer.replace("\n", "") != entry["answer"].replace("\n", ""): + + try: + ans_arr = json.loads(answer) + correct_arr = json.loads(entry["answer"]) + + if correct_arr != ans_arr: + return 0.01 + else: + return 1.0 # Yay + except Exception as e: return 0.01 - else: - return 1.0 # Yay register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig) diff --git a/tests/test_game_of_life.py b/tests/test_game_of_life.py index df0f133d..10ec5c7a 100644 --- a/tests/test_game_of_life.py +++ b/tests/test_game_of_life.py @@ -7,7 +7,7 @@ def test_game_of_life(): """Test basic properties and solution of generated items""" # Easy - config = GameOfLifeConfig(seed=42, size=1, grid_size_x=20, grid_size_y=20, filled_cells=10, simulation_steps=1) + config = GameOfLifeConfig(seed=42, size=10, grid_size_x=20, grid_size_y=20, filled_cells=200, simulation_steps=1) dataset = GameOfLifeDataset(config) for item in dataset: From 3ead141db5cede761576f6bc93a54597ce7d2a73 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Thu, 13 Feb 2025 17:50:17 +0100 Subject: [PATCH 03/18] feat: Add PowerFunctionConfig and PowerFunctionDataset to arithmetic module exports --- reasoning_gym/arithmetic/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index cc94bf56..4c10e3b2 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -31,6 +31,8 @@ __all__ = [ "LCMDataset", "LegCountingConfig", "LegCountingDataset", + "PowerFunctionConfig", + "PowerFunctionDataset", "PrimeFactorizationConfig", "PrimeFactorizationDataset", "GSMSymbolicDatasetConfig", From bdcaeff42a6e549e1063366e0b75e8e294fc3a92 Mon Sep 17 00:00:00 2001 From: "Andreas Koepf (aider)" Date: Thu, 13 Feb 2025 17:50:19 +0100 Subject: [PATCH 04/18] feat: Add ProductsDataset with configurable terms and digits --- reasoning_gym/arithmetic/__init__.py | 3 + reasoning_gym/arithmetic/products.py | 130 +++++++++++++++++++++++++++ tests/test_products.py | 125 ++++++++++++++++++++++++++ 3 files changed, 258 insertions(+) create mode 100644 reasoning_gym/arithmetic/products.py create mode 100644 tests/test_products.py diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 4c10e3b2..cbfee4b1 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -14,6 +14,7 @@ from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset +from .products import Products, ProductsConfig from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset __all__ = [ @@ -35,6 +36,8 @@ __all__ = [ "PowerFunctionDataset", "PrimeFactorizationConfig", "PrimeFactorizationDataset", + "Products", + "ProductsConfig", "GSMSymbolicDatasetConfig", "GSMSymbolicDataset", "TimeIntervalsConfig", diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py new file mode 100644 index 00000000..62ae291c --- /dev/null +++ b/reasoning_gym/arithmetic/products.py @@ -0,0 +1,130 @@ +import random +from dataclasses import dataclass +from typing import Optional + +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class ProductsConfig: + """Configuration for products task generation""" + + min_terms: int = 2 + max_terms: int = 2 + min_digits: int = 1 + max_digits: int = 5 + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.size > 0, "size must be positive" + assert self.min_terms > 0, "min_terms must be positive" + assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" + assert self.min_digits > 0, "min_digits must be positive" + assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits" + + +class Products(ProceduralDataset): + """Generates multiplication tasks with configurable number of terms""" + + def __init__(self, config: ProductsConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single multiplication task + + Args: + idx: Index of the item to generate + + Returns: + dict with keys: + - question: str, the formatted multiplication expression + - answer: str, the ground truth result + - metadata: dict with generation parameters + """ + # Create deterministic RNG from base seed and idx + item_rng = random.Random(self.seed + idx) + + num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + + # Calculate value ranges based on number of digits + min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit + max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits + + expression, result = self._generate_task(item_rng, num_terms, min_value, max_value) + + return { + "question": f"{expression} =", + "answer": str(result), + "metadata": { + "difficulty": { + "num_terms": num_terms, + "num_digits": num_digits, + }, + "expression": expression, + }, + } + + def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]: + """Generate a multiplication task + + Args: + rng: Random number generator + num_terms: Number of terms in the expression + min_value: Minimum value for generated numbers + max_value: Maximum value for generated numbers + + Returns: + Tuple of (expression string, result integer) + """ + # Generate random numbers within the specified range + constants = [rng.randint(min_value, max_value) for _ in range(num_terms)] + + # Build expression and compute result + expression_parts = [] + result = constants[0] + + expression_parts.append(str(constants[0])) + for i in range(1, len(constants)): + expression_parts.append("*") + expression_parts.append(str(constants[i])) + result *= constants[i] + + expression = " ".join(expression_parts) + return expression, result + + +class ProductsCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(ProductsCurriculum.__name__, ProductsConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="num_terms", + levels=[2, 3, 4, 5], + default_level=0, # Start with 2 terms + description="Maximum number of terms in the expression", + attr_type=AttributeType.APPEND, + min_value=2, # Ensure at least 2 terms + lower_field_name="min_terms", + upper_field_name="max_terms", + ), + RangeAttributeDefinition( + name="num_digits", + levels=[1, 2, 3, 4], + default_level=0, # Start with 1-digit numbers + description="Number of digits in each operand", + attr_type=AttributeType.APPEND, + min_value=1, # Ensure numbers are at least 1 digit + lower_field_name="min_digits", + upper_field_name="max_digits", + ), + ) + + +# Register the dataset +register_dataset("products", Products, ProductsConfig) diff --git a/tests/test_products.py b/tests/test_products.py new file mode 100644 index 00000000..a569209c --- /dev/null +++ b/tests/test_products.py @@ -0,0 +1,125 @@ +import pytest + +from reasoning_gym.arithmetic import Products, ProductsConfig +from reasoning_gym.arithmetic.products import ProductsCurriculum + + +def test_products_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = ProductsConfig(min_terms=0) + config.validate() + + with pytest.raises(AssertionError): + config = ProductsConfig(min_terms=3, max_terms=2) + config.validate() + + +def test_products_deterministic(): + """Test that dataset generates same items with same seed""" + config = ProductsConfig(seed=42, size=10) + dataset1 = Products(config) + dataset2 = Products(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_products_items(): + """Test basic properties of generated items""" + config = ProductsConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) + dataset = Products(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Verify only * is used + expression = item["metadata"]["expression"] + assert all(op in ["*", " "] or op.isdigit() for op in expression) + + # Verify the answer matches the expression + answer = eval(expression) # Safe here as we control the expression + assert str(answer) == item["answer"] + + +def test_products_number_ranges(): + """Test that generated numbers respect digit constraints""" + # Test 3-digit numbers + config = ProductsConfig( + min_terms=2, + max_terms=2, # Fix to 2 terms for easier testing + min_digits=3, # Should generate numbers >= 100 + max_digits=3, # Should generate numbers <= 999 + size=50, + seed=42, + ) + dataset = Products(config) + + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits" + + # Test 1-digit numbers + config = ProductsConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) + dataset = Products(config) + for i in range(len(dataset)): + item = dataset[i] + expression = item["metadata"]["expression"] + numbers = [int(n) for n in expression.split() if n.isdigit()] + for num in numbers: + assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit" + + +def test_products_iteration(): + """Test that iteration respects dataset size""" + config = ProductsConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing + dataset = Products(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test list conversion + items = list(dataset) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test multiple iterations + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield same items" + + +def test_products_curriculum(): + curriculum = ProductsCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: ProductsConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1 + assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2 + + # test incrementing attribute levels for num_terms & num_digits attributes + curriculum.increment_attr_level("num_terms") + curriculum.increment_attr_level("num_digits") + + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2 + assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3 + + # test decrementing attribute level for num_digits again + curriculum.decrement_attr_level("num_digits") + + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1 + assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3 From ce305366279ea7c1c447adbe8c49b878e62216c7 Mon Sep 17 00:00:00 2001 From: "Andreas Koepf (aider)" Date: Thu, 13 Feb 2025 17:52:32 +0100 Subject: [PATCH 05/18] test: Add scoring tests for Products dataset --- tests/test_products.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_products.py b/tests/test_products.py index a569209c..aac77a98 100644 --- a/tests/test_products.py +++ b/tests/test_products.py @@ -98,6 +98,25 @@ def test_products_iteration(): assert first_items == second_items, "Multiple iterations should yield same items" +def test_products_scoring(): + """Test that scoring works correctly""" + config = ProductsConfig(min_terms=2, max_terms=2, size=10, seed=42) + dataset = Products(config) + + # Test scoring with exact match + item = dataset[0] + assert dataset.score_answer(item["answer"], item) == 1.0, "Exact match should score 1.0" + + # Test scoring with wrong answer + assert dataset.score_answer("wrong", item) == 0.01, "Wrong answer should score 0.01" + + # Test scoring with partial match (answer contained in response) + assert dataset.score_answer(f"The answer is {item['answer']}", item) == 0.5, "Partial match should score 0.5" + + # Test scoring with None + assert dataset.score_answer(None, item) == 0.0, "None should score 0.0" + + def test_products_curriculum(): curriculum = ProductsCurriculum() From 5410bb78a0818b871a464171862af1c43c358d11 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Thu, 13 Feb 2025 17:59:02 +0100 Subject: [PATCH 06/18] add ProductsDataset (multiplication tasks) --- reasoning_gym/arc/arc_1d.py | 10 ++++----- reasoning_gym/arithmetic/__init__.py | 8 +++---- reasoning_gym/arithmetic/basic_arithmetic.py | 12 +++++----- .../arithmetic/calendar_arithmetic.py | 6 ++--- reasoning_gym/arithmetic/chain_sum.py | 12 +++++----- reasoning_gym/arithmetic/products.py | 12 +++++----- reasoning_gym/arithmetic/time_intervals.py | 8 +++---- tests/test_chain_sum.py | 16 +++++++------- tests/test_coaching.py | 6 ++--- tests/test_products.py | 22 +++++++++---------- 10 files changed, 56 insertions(+), 56 deletions(-) diff --git a/reasoning_gym/arc/arc_1d.py b/reasoning_gym/arc/arc_1d.py index 7e399f20..ce594ec4 100644 --- a/reasoning_gym/arc/arc_1d.py +++ b/reasoning_gym/arc/arc_1d.py @@ -58,27 +58,27 @@ class Arc1DDataset(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = Random(self.seed + idx) + rng = Random(self.seed + idx) # Select random task - task_name = item_rng.choice(self.task_names) + task_name = rng.choice(self.task_names) task_func, task_kwargs = self.ARC_1D_TASKS[task_name] # Generate training examples train_examples = [] - size = item_rng.randint(self.config.min_size, self.config.max_size) + size = rng.randint(self.config.min_size, self.config.max_size) for _ in range(self.config.num_train): example = None while example is None: - example = task_func(item_rng, size, **task_kwargs) + example = task_func(rng, size, **task_kwargs) train_examples.append(example) # Generate test example test_example = None while test_example is None: - test_example = task_func(item_rng, size, **task_kwargs) + test_example = task_func(rng, size, **task_kwargs) # Format question question = "Find the common rule that maps an input grid to an output grid, given the examples below.\n\n" diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index cbfee4b1..495a79c5 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -4,7 +4,7 @@ Arithmetic tasks for training reasoning capabilities: from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset -from .chain_sum import ChainSum, ChainSumConfig +from .chain_sum import ChainSumConfig, ChainSumDataset from .count_bits import CountBitsConfig, CountBitsDataset from .dice import DiceConfig, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset @@ -14,13 +14,13 @@ from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset -from .products import Products, ProductsConfig +from .products import ProductsConfig, ProductsDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset __all__ = [ "BasicArithmeticDataset", "BasicArithmeticDatasetConfig", - "ChainSum", + "ChainSumDataset", "ChainSumConfig", "CalendarArithmeticConfig", "CalendarArithmeticDataset", @@ -36,7 +36,7 @@ __all__ = [ "PowerFunctionDataset", "PrimeFactorizationConfig", "PrimeFactorizationDataset", - "Products", + "ProductsDataset", "ProductsConfig", "GSMSymbolicDatasetConfig", "GSMSymbolicDataset", diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 9ec096ee..156314d9 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -78,17 +78,17 @@ class BasicArithmeticDataset(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = Random(self.seed + idx) + rng = Random(self.seed + idx) - num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) - num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) if self.config.allow_parentheses: - expression, result = self._generate_complex_task(item_rng, num_terms, num_digits) + expression, result = self._generate_complex_task(rng, num_terms, num_digits) else: - expression, result = self._generate_simple_task(item_rng, num_terms, num_digits) + expression, result = self._generate_simple_task(rng, num_terms, num_digits) - question = self._format_question(item_rng, expression) + question = self._format_question(rng, expression) return { "question": question, diff --git a/reasoning_gym/arithmetic/calendar_arithmetic.py b/reasoning_gym/arithmetic/calendar_arithmetic.py index bf12211c..3a052590 100644 --- a/reasoning_gym/arithmetic/calendar_arithmetic.py +++ b/reasoning_gym/arithmetic/calendar_arithmetic.py @@ -122,9 +122,9 @@ class CalendarArithmeticDataset(ProceduralDataset): self.tasks = [self.task_handlers[task] for task in self.config.tasks] def __getitem__(self, idx: int) -> dict: - item_rng = random.Random(self.seed + idx) - task = item_rng.choice(self.tasks) - question, answer, metadata = task(item_rng) + rng = random.Random(self.seed + idx) + task = rng.choice(self.tasks) + question, answer, metadata = task(rng) return { "question": question, "answer": str(answer), diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 969df820..6d2e43e6 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -32,7 +32,7 @@ class ChainSumConfig: assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range" -class ChainSum(ProceduralDataset): +class ChainSumDataset(ProceduralDataset): """Generates simple arithmetic tasks using only + and - operators""" def __init__(self, config: ChainSumConfig): @@ -51,16 +51,16 @@ class ChainSum(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = random.Random(self.seed + idx) + rng = random.Random(self.seed + idx) - num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) - num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) # Calculate value ranges based on number of digits min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits - expression, result = self._generate_task(item_rng, num_terms, min_value, max_value) + expression, result = self._generate_task(rng, num_terms, min_value, max_value) return { "question": f"{expression} =", @@ -143,4 +143,4 @@ class ChainSumCurriculum(BaseCurriculum): # Register the dataset -register_dataset("chain_sum", ChainSum, ChainSumConfig) +register_dataset("chain_sum", ChainSumDataset, ChainSumConfig) diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index 62ae291c..742696ec 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -26,7 +26,7 @@ class ProductsConfig: assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits" -class Products(ProceduralDataset): +class ProductsDataset(ProceduralDataset): """Generates multiplication tasks with configurable number of terms""" def __init__(self, config: ProductsConfig): @@ -45,16 +45,16 @@ class Products(ProceduralDataset): - metadata: dict with generation parameters """ # Create deterministic RNG from base seed and idx - item_rng = random.Random(self.seed + idx) + rng = random.Random(self.seed + idx) - num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) - num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + num_digits = rng.randint(self.config.min_digits, self.config.max_digits) # Calculate value ranges based on number of digits min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits - expression, result = self._generate_task(item_rng, num_terms, min_value, max_value) + expression, result = self._generate_task(rng, num_terms, min_value, max_value) return { "question": f"{expression} =", @@ -127,4 +127,4 @@ class ProductsCurriculum(BaseCurriculum): # Register the dataset -register_dataset("products", Products, ProductsConfig) +register_dataset("products", ProductsDataset, ProductsConfig) diff --git a/reasoning_gym/arithmetic/time_intervals.py b/reasoning_gym/arithmetic/time_intervals.py index 1b296d02..f4011ba7 100644 --- a/reasoning_gym/arithmetic/time_intervals.py +++ b/reasoning_gym/arithmetic/time_intervals.py @@ -82,14 +82,14 @@ class TimeIntervalsDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single time interval calculation task""" - item_rng = random.Random(self.seed + idx) + rng = random.Random(self.seed + idx) # Randomly choose task type from config - task_type = item_rng.choice(self.config.task_types) + task_type = rng.choice(self.config.task_types) - start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type) + start_time, end_time, format_str, expected_format = self._generate_times(rng, task_type) - template = item_rng.choice(self.TEMPLATES) + template = rng.choice(self.TEMPLATES) question = template.format(start=start_time, end=end_time, format=expected_format) # Calculate the actual difference diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py index 36b0185c..eedbaae3 100644 --- a/tests/test_chain_sum.py +++ b/tests/test_chain_sum.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.arithmetic import ChainSum, ChainSumConfig +from reasoning_gym.arithmetic import ChainSumConfig, ChainSumDataset from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum @@ -18,8 +18,8 @@ def test_chain_sum_config_validation(): def test_chain_sum_deterministic(): """Test that dataset generates same items with same seed""" config = ChainSumConfig(seed=42, size=10) - dataset1 = ChainSum(config) - dataset2 = ChainSum(config) + dataset1 = ChainSumDataset(config) + dataset2 = ChainSumDataset(config) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -28,7 +28,7 @@ def test_chain_sum_deterministic(): def test_chain_sum_items(): """Test basic properties of generated items""" config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -57,7 +57,7 @@ def test_chain_sum_number_ranges(): size=50, seed=42, ) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -71,7 +71,7 @@ def test_chain_sum_number_ranges(): # Test 1-digit numbers config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) for i in range(len(dataset)): item = dataset[i] expression = item["metadata"]["expression"] @@ -88,7 +88,7 @@ def test_chain_sum_negation(): config = ChainSumConfig( min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True ) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) # Track if we see both positive and negative numbers has_positive = False @@ -112,7 +112,7 @@ def test_chain_sum_negation(): def test_chain_sum_iteration(): """Test that iteration respects dataset size""" config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing - dataset = ChainSum(config) + dataset = ChainSumDataset(config) # Test manual iteration items = [] diff --git a/tests/test_coaching.py b/tests/test_coaching.py index 83b56768..3b1b97d7 100644 --- a/tests/test_coaching.py +++ b/tests/test_coaching.py @@ -5,7 +5,7 @@ from pathlib import Path import pytest -from reasoning_gym.arithmetic.chain_sum import ChainSum, ChainSumConfig +from reasoning_gym.arithmetic.chain_sum import ChainSumConfig, ChainSumDataset from reasoning_gym.arithmetic.leg_counting import LegCountingConfig from reasoning_gym.coaching import Coach, GroupedScores from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec @@ -14,7 +14,7 @@ from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSp def test_coach_with_chain_sum(): # Create a small ChainSum dataset config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) coach = Coach(dataset) # Simulate an agent working on tasks @@ -208,7 +208,7 @@ def test_coach_score_logging(tmp_path): # Create dataset and coach with logging config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) - dataset = ChainSum(config) + dataset = ChainSumDataset(config) coach = Coach(dataset, score_log=log_file) # Score a few answers diff --git a/tests/test_products.py b/tests/test_products.py index aac77a98..34ff1623 100644 --- a/tests/test_products.py +++ b/tests/test_products.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.arithmetic import Products, ProductsConfig +from reasoning_gym.arithmetic import ProductsConfig, ProductsDataset from reasoning_gym.arithmetic.products import ProductsCurriculum @@ -18,8 +18,8 @@ def test_products_config_validation(): def test_products_deterministic(): """Test that dataset generates same items with same seed""" config = ProductsConfig(seed=42, size=10) - dataset1 = Products(config) - dataset2 = Products(config) + dataset1 = ProductsDataset(config) + dataset2 = ProductsDataset(config) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -28,7 +28,7 @@ def test_products_deterministic(): def test_products_items(): """Test basic properties of generated items""" config = ProductsConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) - dataset = Products(config) + dataset = ProductsDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -57,7 +57,7 @@ def test_products_number_ranges(): size=50, seed=42, ) - dataset = Products(config) + dataset = ProductsDataset(config) for i in range(len(dataset)): item = dataset[i] @@ -68,7 +68,7 @@ def test_products_number_ranges(): # Test 1-digit numbers config = ProductsConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) - dataset = Products(config) + dataset = ProductsDataset(config) for i in range(len(dataset)): item = dataset[i] expression = item["metadata"]["expression"] @@ -80,7 +80,7 @@ def test_products_number_ranges(): def test_products_iteration(): """Test that iteration respects dataset size""" config = ProductsConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing - dataset = Products(config) + dataset = ProductsDataset(config) # Test manual iteration items = [] @@ -101,18 +101,18 @@ def test_products_iteration(): def test_products_scoring(): """Test that scoring works correctly""" config = ProductsConfig(min_terms=2, max_terms=2, size=10, seed=42) - dataset = Products(config) + dataset = ProductsDataset(config) # Test scoring with exact match item = dataset[0] assert dataset.score_answer(item["answer"], item) == 1.0, "Exact match should score 1.0" - + # Test scoring with wrong answer assert dataset.score_answer("wrong", item) == 0.01, "Wrong answer should score 0.01" - + # Test scoring with partial match (answer contained in response) assert dataset.score_answer(f"The answer is {item['answer']}", item) == 0.5, "Partial match should score 0.5" - + # Test scoring with None assert dataset.score_answer(None, item) == 0.0, "None should score 0.0" From 9ae8a29ff2d134d5918822bb42435507592066e0 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Thu, 13 Feb 2025 19:06:20 +0100 Subject: [PATCH 07/18] use json formatting for initial state of game-of-life board --- reasoning_gym/games/game_of_life.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/reasoning_gym/games/game_of_life.py b/reasoning_gym/games/game_of_life.py index b77f3fd6..2b5e369b 100644 --- a/reasoning_gym/games/game_of_life.py +++ b/reasoning_gym/games/game_of_life.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass from random import Random -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional import cellpylib as cpl @@ -12,8 +12,8 @@ from ..factory import ProceduralDataset, register_dataset class GameOfLifeConfig: """Configuration for sudoku puzzle generation""" - grid_size_x: int = 20 - grid_size_y: int = 20 + grid_size_x: int = 10 + grid_size_y: int = 10 filled_cells: int = 100 # actually a max simulation_steps: int = 1 seed: Optional[int] = None @@ -60,10 +60,15 @@ class GameOfLifeDataset(ProceduralDataset): # Simulate the result to get the answer evolved = cpl.evolve2d( - board, timesteps=self.config.simulation_steps + 1, apply_rule=cpl.game_of_life_rule, memoize="recursive" + board, + timesteps=self.config.simulation_steps + 1, + apply_rule=cpl.game_of_life_rule, + memoize="recursive", ) - board_str = str(board[0]) + rows = [json.dumps(board[0, i].tolist(), separators=(",", ":")) for i in range(board.shape[1])] + board_str = "[" + ", \n ".join(rows) + "]" + final_step = evolved[-1] final_step_list = final_step.tolist() result_str = json.dumps(final_step_list, separators=(",", ":")) From 12af794381867e8cfadb5ffd594284ab1ec3d087 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 12 Feb 2025 15:18:51 +0100 Subject: [PATCH 08/18] string insertion --- reasoning_gym/algorithmic/__init__.py | 3 + reasoning_gym/algorithmic/string_insertion.py | 98 +++++++++++++++++++ tests/test_string_insertion.py | 94 ++++++++++++++++++ 3 files changed, 195 insertions(+) create mode 100644 reasoning_gym/algorithmic/string_insertion.py create mode 100644 tests/test_string_insertion.py diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 875ab539..ebe4b397 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -28,6 +28,7 @@ from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset +from .string_insertion import StringInsertionConfig, StringInsertionDataset __all__ = [ "SpellBackwardConfig", @@ -75,4 +76,6 @@ __all__ = [ "ABDataset", "CountPrimesConfig", "CountPrimesDataset", + "StringInsertionConfig", + "StringInsertionDataset", ] diff --git a/reasoning_gym/algorithmic/string_insertion.py b/reasoning_gym/algorithmic/string_insertion.py new file mode 100644 index 00000000..fb1d35e0 --- /dev/null +++ b/reasoning_gym/algorithmic/string_insertion.py @@ -0,0 +1,98 @@ +"""Insert into string according to a pattern + +https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dataset_string_insertion.py +""" + +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + + +QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern: +1. If there is a substring ABCD in the string, insert the character A after the substring. +2. If there is a substring BCDE in the string, insert the character B after the substring. +3. If there is a substring CDEA in the string, insert the character C after the substring. +4. If there is a substring DEAB in the string, insert the character D after the substring. +5. If there is a substring EABC in the string, insert the character E after the substring. + +Once you have inserted a character, you have to skip over the substring and the inserted character and continue the search from the next character. + +Example +- Input: DDABCDEEDEAB +- Output: DDABCDAEEDEABD +- Explanation: + - Theere are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets) + - First, we insert A after ABCD. + - Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character. + - Lastly, we insert D after DEAB. + +Given the following string, provide the answer after inserting the characters according to the pattern: {string} +""" + + +@dataclass +class StringInsertionConfig: + """Configuration for String Insertion dataset generation""" + + min_string_length: int = 5 # Minimum string length + max_string_length: int = 20 # Maximum string length + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 5 <= self.min_string_length, "Minimum string length should be at least 5" + assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum" + +class StringInsertionDataset(ProceduralDataset): + """Generates String Insertion exercises with configurable difficulty""" + + def __init__(self, config: StringInsertionConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.vocabulary = ['A', 'B', 'C', 'D', 'E'] + self.insertion_rules = [ + ("ABCD", "A"), + ("BCDE", "B"), + ("CDEA", "C"), + ("DEAB", "D"), + ("EABC", "E"), + ] + + def _get_answer(self, string: str) -> str: + """Apply insertion rules to a string""" + output = [] + i = 0 + while i < len(string): + inserted = False + for pattern, char in self.insertion_rules: + substring = string[i:i+len(pattern)] + if substring == pattern: + output.append(substring + char) + i += len(pattern) + inserted = True + break + if not inserted: + output.append(string[i]) + i += 1 + return "".join(output) + + def __getitem__(self, idx: int) -> dict: + """Generate a single String Insertion question""" + rng = Random(self.seed + idx) + + string_length = rng.randint(self.config.min_string_length, self.config.max_string_length) + string = [rng.choice(self.vocabulary) for _ in range(string_length)] + + answer = self._get_answer(string) + + return { + "question": QUESTION_TEMPLATE.format(string=string), + "answer": str(answer), + "metadata": {"string": string, "solution": answer}, + } + + +register_dataset("string_insertion", StringInsertionDataset, StringInsertionConfig) diff --git a/tests/test_string_insertion.py b/tests/test_string_insertion.py new file mode 100644 index 00000000..746ff5c5 --- /dev/null +++ b/tests/test_string_insertion.py @@ -0,0 +1,94 @@ +"""Tests for String Insertion questions generation""" + +import pytest + +from reasoning_gym.algorithmic.string_insertion import StringInsertionConfig, StringInsertionDataset + + +def test_string_insertion_config_validation(): + """Test that invalid configs raise appropriate errors""" + + for field in ["min_string_length", "max_string_length"]: + for i in range(-1, 5): + with pytest.raises(AssertionError): + config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid + config.validate() + + +def test_string_insertion_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = StringInsertionConfig(seed=42, size=10) + dataset1 = StringInsertionDataset(config) + dataset2 = StringInsertionDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_string_insertion_dataset_items(): + """Test basic properties of generated items""" + config = StringInsertionConfig(min_string_length=5, max_string_length=30, size=10, seed=42) + dataset = StringInsertionDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "string" in item["metadata"] + assert "solution" in item["metadata"] + + string = item["metadata"]["string"] + solution = item["metadata"]["solution"] + + # Verify string dimensions + assert 5 <= len(string) <= 30 + assert len(string) <= len(solution) + + +def test_string_insertion_dataset_iteration(): + """Test that iteration respects dataset size""" + config = StringInsertionConfig(size=5, seed=42) + dataset = StringInsertionDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_string_insertion_answer(): + """Test the _get_rotated method""" + config = StringInsertionConfig(seed=42) + dataset = StringInsertionDataset(config) + + # No pattern match + assert dataset._get_answer("AAAAAAA") == "AAAAAAA" + assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA" + assert dataset._get_answer("ADEACA") == "ADEACA" + + # Insert A after ABCD + assert dataset._get_answer("ABCDE") == "ABCDAE" + + # Insert B after BCDE + assert dataset._get_answer("AEBCDEC") == "AEBCDEBC" + + # Insert C after CDEA + assert dataset._get_answer("BBACDEAC") == "BBACDEACC" + + # Insert D after DEAB + assert dataset._get_answer("BAAABDEAB") == "BAAABDEABD" + + # Insert E after EABC + assert dataset._get_answer("EABCBCBC") == "EABCEBCBC" + + # Multiple insertions + assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA" + + # No reuse of newly inserted characters + assert dataset._get_answer("ABCDBCD") == "ABCDABCD" \ No newline at end of file From 6f9036631adb37ad858f8b55819e436d0f96c6f0 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 12 Feb 2025 17:26:23 +0100 Subject: [PATCH 09/18] lint --- reasoning_gym/algorithmic/__init__.py | 2 +- reasoning_gym/algorithmic/string_insertion.py | 12 ++++++------ tests/test_string_insertion.py | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index ebe4b397..1b528794 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -25,10 +25,10 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset +from .string_insertion import StringInsertionConfig, StringInsertionDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset -from .string_insertion import StringInsertionConfig, StringInsertionDataset __all__ = [ "SpellBackwardConfig", diff --git a/reasoning_gym/algorithmic/string_insertion.py b/reasoning_gym/algorithmic/string_insertion.py index fb1d35e0..b217ed76 100644 --- a/reasoning_gym/algorithmic/string_insertion.py +++ b/reasoning_gym/algorithmic/string_insertion.py @@ -9,7 +9,6 @@ from typing import Optional from ..factory import ProceduralDataset, register_dataset - QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern: 1. If there is a substring ABCD in the string, insert the character A after the substring. 2. If there is a substring BCDE in the string, insert the character B after the substring. @@ -22,7 +21,7 @@ Once you have inserted a character, you have to skip over the substring and the Example - Input: DDABCDEEDEAB - Output: DDABCDAEEDEABD -- Explanation: +- Explanation: - Theere are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets) - First, we insert A after ABCD. - Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character. @@ -37,7 +36,7 @@ class StringInsertionConfig: """Configuration for String Insertion dataset generation""" min_string_length: int = 5 # Minimum string length - max_string_length: int = 20 # Maximum string length + max_string_length: int = 20 # Maximum string length size: int = 500 # Virtual dataset size seed: Optional[int] = None @@ -47,12 +46,13 @@ class StringInsertionConfig: assert 5 <= self.min_string_length, "Minimum string length should be at least 5" assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum" + class StringInsertionDataset(ProceduralDataset): """Generates String Insertion exercises with configurable difficulty""" def __init__(self, config: StringInsertionConfig): super().__init__(config=config, seed=config.seed, size=config.size) - self.vocabulary = ['A', 'B', 'C', 'D', 'E'] + self.vocabulary = ["A", "B", "C", "D", "E"] self.insertion_rules = [ ("ABCD", "A"), ("BCDE", "B"), @@ -68,7 +68,7 @@ class StringInsertionDataset(ProceduralDataset): while i < len(string): inserted = False for pattern, char in self.insertion_rules: - substring = string[i:i+len(pattern)] + substring = string[i : i + len(pattern)] if substring == pattern: output.append(substring + char) i += len(pattern) @@ -82,7 +82,7 @@ class StringInsertionDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single String Insertion question""" rng = Random(self.seed + idx) - + string_length = rng.randint(self.config.min_string_length, self.config.max_string_length) string = [rng.choice(self.vocabulary) for _ in range(string_length)] diff --git a/tests/test_string_insertion.py b/tests/test_string_insertion.py index 746ff5c5..12225954 100644 --- a/tests/test_string_insertion.py +++ b/tests/test_string_insertion.py @@ -7,13 +7,13 @@ from reasoning_gym.algorithmic.string_insertion import StringInsertionConfig, St def test_string_insertion_config_validation(): """Test that invalid configs raise appropriate errors""" - + for field in ["min_string_length", "max_string_length"]: for i in range(-1, 5): with pytest.raises(AssertionError): - config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid + config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid config.validate() - + def test_string_insertion_dataset_deterministic(): """Test that dataset generates same items with same seed""" @@ -67,7 +67,7 @@ def test_string_insertion_answer(): config = StringInsertionConfig(seed=42) dataset = StringInsertionDataset(config) - # No pattern match + # No pattern match assert dataset._get_answer("AAAAAAA") == "AAAAAAA" assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA" assert dataset._get_answer("ADEACA") == "ADEACA" @@ -91,4 +91,4 @@ def test_string_insertion_answer(): assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA" # No reuse of newly inserted characters - assert dataset._get_answer("ABCDBCD") == "ABCDABCD" \ No newline at end of file + assert dataset._get_answer("ABCDBCD") == "ABCDABCD" From 90a923c78ea8f35efb55ae890bbafd342693a707 Mon Sep 17 00:00:00 2001 From: joesharratt1229 Date: Thu, 13 Feb 2025 03:51:01 +0000 Subject: [PATCH 10/18] updated async impl and added r1 --- eval/r1/eval.py | 146 ++++++++++++++++++---------------- eval/r1/yaml/algorithmic.yaml | 3 +- 2 files changed, 77 insertions(+), 72 deletions(-) diff --git a/eval/r1/eval.py b/eval/r1/eval.py index 737707c7..3dbc39b1 100644 --- a/eval/r1/eval.py +++ b/eval/r1/eval.py @@ -1,4 +1,5 @@ import argparse +import asyncio import json import logging import os @@ -6,10 +7,9 @@ from dataclasses import asdict from datetime import datetime from typing import Any, Dict, List -import requests +import aiohttp from eval_config import EvalConfig -from requests.exceptions import RequestException -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential +from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt, wait_exponential import reasoning_gym from reasoning_gym.utils import extract_answer @@ -30,9 +30,9 @@ class OpenRouterEvaluator: "X-Title": os.getenv("OR_APP_NAME", "Model Evaluation"), "Content-Type": "application/json", } + self.semaphore = asyncio.Semaphore(10) # Control concurrency def save_results(self, results: List[Dict[str, Any]], dataset, dataset_name) -> Dict[str, Any]: - file_name = f"{self.output_dir}/{dataset_name}.json" total_score = sum(r["score"] for r in results) @@ -45,7 +45,7 @@ class OpenRouterEvaluator: "total_examples": len(results), "timestamp": datetime.now().isoformat(), "config": asdict(dataset.config), - "results": results, # save results to allow for performance recalculation + "results": results, } with open(file_name, "w") as f: @@ -53,87 +53,93 @@ class OpenRouterEvaluator: return metrics def prepare_messages(self, prompt: str) -> List[Dict[str, str]]: - messages = [ - {"role": self.config.developer_role, "content": self.config.developer_prompt}, - {"role": "user", "content": prompt}, - ] + return { + "model": self.model, + "messages": [ + {"role": self.config.developer_role, "content": self.config.developer_prompt}, + {"role": "user", "content": prompt}, + ], + "provider": {"order": ["Nebius"], "allow_fallbacks": False}, + } + + async def get_model_response(self, session: aiohttp.ClientSession, prompt: str) -> str: payload = { "model": self.model, - "messages": messages, - "provider": {"order": ["Nebius"], "allow_fallbacks": False}, - } # make sure only one provider is used + "messages": [ + {"role": self.config.developer_role, "content": self.config.developer_prompt}, + {"role": "user", "content": prompt}, + ], + } - return payload + async for attempt in AsyncRetrying( + stop=stop_after_attempt(20), + wait=wait_exponential(multiplier=1, min=1, max=60), + retry=retry_if_exception_type( + (aiohttp.ClientError, asyncio.TimeoutError, json.JSONDecodeError, ValueError) + ), + ): + with attempt: + async with session.post(self.base_url, json=payload) as response: + data = await response.json() - @retry( - retry=retry_if_exception_type(RequestException), - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=4, max=60), - ) - def get_model_response(self, prompt: str) -> str: - """Get response from the model via OpenRouter API.""" + if not data: + raise ValueError("Empty response") - payload = self.prepare_messages(prompt) - try: - response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30) - response.raise_for_status() - except requests.exceptions.RequestException as e: - raise RequestException( - f"API request failed: {str(e)}", {"endpoint": self.base_url, "model": self.model} - ) from e - return response.json()["choices"][0]["message"]["content"] + if not data.get("choices"): + raise ValueError("Missing choices in response") - def evaluate_datasets(self) -> List[Dict[str, Any]]: - """Evaluate model on multiple datasets with their respective configurations.""" + return data["choices"][0]["message"]["content"] + + raise Exception("Failed to get valid response after retries") + + async def process_entry(self, session: aiohttp.ClientSession, dataset: Any, entry: Any) -> Dict[str, Any]: + """Process a single entry with concurrency control.""" + async with self.semaphore: + response = await self.get_model_response(session, entry["question"]) + model_answer = extract_answer(response) + score = dataset.score_answer(answer=model_answer, entry=entry) + print(f"Question: {entry['question']}") + + return { + "question": entry["question"], + "expected_answer": str(entry["answer"]), + "model_answer": model_answer, + "score": score, + "metadata": str(entry["metadata"]), + } + + async def evaluate_dataset(self, session: aiohttp.ClientSession, dataset_name: str) -> Dict[str, Any]: + """Evaluate a single dataset asynchronously.""" + self.logger.info(f"\nEvaluating dataset: {dataset_name}") + dataset = reasoning_gym.create_dataset( + dataset_name, size=self.config.dataset_size, seed=self.config.dataset_seed + ) + + tasks = [self.process_entry(session, dataset, entry) for entry in dataset] + results = await asyncio.gather(*tasks) + return self.save_results(results, dataset, dataset_name) + + async def evaluate_datasets(self) -> List[Dict[str, Any]]: + """Main async evaluation entry point.""" all_results = [] - - for dataset_name in self.config.datasets: - self.logger.info(f"\nEvaluating dataset: {dataset_name}") - - # Create dataset with its specific configuration - dataset = reasoning_gym.create_dataset( - dataset_name, size=self.config.dataset_size, seed=self.config.dataset_seed - ) - results = [] - - for i, entry in enumerate(dataset): - print(f"On example {i+1} of {len(dataset)}") - response = self.get_model_response(entry["question"]) - model_answer = extract_answer(response) - - score = dataset.score_answer(answer=model_answer, entry=entry) - - result = { - "question": entry["question"], - "expected_answer": str(entry["answer"]), - "model_answer": model_answer, - "score": score, - "metadata": str(entry["metadata"]), - } - results.append(result) - - metrics = self.save_results(results, dataset, dataset_name) - - all_results.append({"metrics": metrics, "results": results}) - - return all_results + async with aiohttp.ClientSession(headers=self.headers) as session: + return await asyncio.gather(*(self.evaluate_dataset(session, name) for name in self.config.datasets)) -def main(): +async def async_main(): parser = argparse.ArgumentParser(description="Evaluate models on reasoning datasets") parser.add_argument("--yaml", required=True, help="Path to YAML configuration file") - args = parser.parse_args() + config = EvalConfig.from_yaml(args.yaml) + evaluator = OpenRouterEvaluator(model=config.model, config=config) + results = await evaluator.evaluate_datasets() + output_dir = f"{config.eval_dir}/{config.category}" os.makedirs(output_dir, exist_ok=True) - - evaluator = OpenRouterEvaluator(model=config.model, config=config) - all_results = evaluator.evaluate_datasets() - with open(f"{output_dir}/summary.json", "w") as f: - json.dump(all_results, f, indent=2) + json.dump(results, f, indent=2) if __name__ == "__main__": - main() + asyncio.run(async_main()) diff --git a/eval/r1/yaml/algorithmic.yaml b/eval/r1/yaml/algorithmic.yaml index c1c043ce..5d0d630a 100644 --- a/eval/r1/yaml/algorithmic.yaml +++ b/eval/r1/yaml/algorithmic.yaml @@ -1,9 +1,8 @@ model: deepseek/deepseek-r1 category: algorithmic datasets: - - base_conversion - binary_matrix - - caesar _cipher + - caesar_cipher - group_anagrams - isomorphic_strings - letter_counting From 6ec029c0301622665e628ca1389d785d85701bb2 Mon Sep 17 00:00:00 2001 From: "Andreas Koepf (aider)" Date: Thu, 13 Feb 2025 11:48:16 +0100 Subject: [PATCH 11/18] test: Add test to verify perfect score for PolynomialEquationsDataset --- tests/test_polynomial_equations.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_polynomial_equations.py b/tests/test_polynomial_equations.py index e7caf654..4d2ec50c 100644 --- a/tests/test_polynomial_equations.py +++ b/tests/test_polynomial_equations.py @@ -138,3 +138,13 @@ def test_polynomial_solutions_score_answer(oracle_answer, predicted_answer, expe actual_reward = ds.score_answer(predicted_answer, {"answer": oracle_answer}) assert actual_reward == pytest.approx(expected_reward, rel=1e-3) # Fuzzy comparison for floats + + +def test_polynomial_perfect_score(): + """Test that scoring an item's own answer gives a perfect score""" + cfg = PolynomialEquationsConfig(seed=42, size=10) + ds = PolynomialEquationsDataset(cfg) + + for item in ds: + score = ds.score_answer(item["answer"], item) + assert score == pytest.approx(1.0, rel=1e-6) From 3493703c331f0df71fbb87b3ee118a9aef010007 Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Thu, 13 Feb 2025 21:21:11 -0600 Subject: [PATCH 12/18] Fix conflict during rebasing --- reasoning_gym/algebra/polynomial_equations.py | 43 +++++++++++-------- tests/test_polynomial_equations.py | 5 +-- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index f3e9d2da..a65cdff9 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -79,28 +79,32 @@ class PolynomialEquationsDataset(ProceduralDataset): Returns: A dict with: - - question: str (e.g. "Solve the polynomial equation: 2*x^2 - 3*x + 1 = 0") - - answer: str (the sorted list of real solutions, e.g. "[0.5, 1.0]") + - question: str (e.g. "Solve the polynomial equation: 2*x**2 - 3*x + 1 = 0") + - answer: str (the sorted list of real solutions, e.g. "0.5, 1.0") - metadata: dict with details (polynomial_expr, degree, etc.) """ rng = random.Random(self.seed + idx) + for _ in range(8): + # Get variable and generate polynomial equation in standard form + variable = self._get_variable(rng) + degree = rng.randint(self.config.min_degree, self.config.max_degree) + polynomial_expr = self._generate_polynomial_expr(rng, variable, degree) + polynomial_expanded = expand(polynomial_expr) - # Get variable and generate polynomial equation in standard form - variable = self._get_variable(rng) - degree = rng.randint(self.config.min_degree, self.config.max_degree) - polynomial_expr = self._generate_polynomial_expr(rng, variable, degree) - polynomial_expanded = expand(polynomial_expr) + # Solve the polynomial = 0 + # We filter real solutions only + solutions = solve(Eq(polynomial_expanded, 0), variable, dict=False) + real_solutions = [] + for sol in solutions: + if sol.is_real: + # Evaluate symbolic solution to a floating approximation + real_solutions.append(round(float(sol.evalf()), 4)) - # Solve the polynomial = 0 - # We filter real solutions only - solutions = solve(Eq(polynomial_expanded, 0), variable, dict=False) - real_solutions = [] - for sol in solutions: - if sol.is_real: - # Evaluate symbolic solution to a floating approximation - real_solutions.append(round(float(sol.evalf()), 4)) - real_solutions.sort() - answer_str = str(real_solutions) + if len(real_solutions) > 0: + real_solutions.sort() + break + + answer_str = ", ".join(str(x) for x in real_solutions) question = rng.choice(self._prompt_templates).format(variable=variable, polynomial_expanded=polynomial_expanded) + self.added_instruction return { @@ -116,7 +120,7 @@ class PolynomialEquationsDataset(ProceduralDataset): def _get_variable(self, rng: random.Random) -> str: """Get a random lowercase variable name""" - return rng.choice(string.ascii_lowercase) + return rng.choice("abcdefghklmnopqrstuvwxyz") # remove ij to avoid confusion with complex numbers def _generate_polynomial_expr(self, rng: random.Random, variable: Symbol, degree: int): """ @@ -209,6 +213,9 @@ class PolynomialEquationsDataset(ProceduralDataset): oracle_solutions = self._parse_score_to_list(entry["answer"]) # Parse oracle solutions predicted_solutions = self._parse_score_to_list(answer) # Parse predicted solutions + if len(oracle_solutions) == 0 and len(predicted_solutions) == 0: + return 1.0 + total_reward = 0.0 matched_solutions = 0 extra_solutions = 0 diff --git a/tests/test_polynomial_equations.py b/tests/test_polynomial_equations.py index 4d2ec50c..420187de 100644 --- a/tests/test_polynomial_equations.py +++ b/tests/test_polynomial_equations.py @@ -144,7 +144,6 @@ def test_polynomial_perfect_score(): """Test that scoring an item's own answer gives a perfect score""" cfg = PolynomialEquationsConfig(seed=42, size=10) ds = PolynomialEquationsDataset(cfg) - + for item in ds: - score = ds.score_answer(item["answer"], item) - assert score == pytest.approx(1.0, rel=1e-6) + assert ds.score_answer(item["answer"], item) == 1.0 From b7a721fce0d80fb1d36a6b3104ec085f9fb1c2e4 Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Thu, 13 Feb 2025 21:24:05 -0600 Subject: [PATCH 13/18] Fix more conflict --- reasoning_gym/algorithmic/__init__.py | 6 +- .../algorithmic/string_manipulation.py | 199 ++++++++++++++ tests/test_string_manipulation.py | 257 ++++++++++++++++++ 3 files changed, 459 insertions(+), 3 deletions(-) create mode 100644 reasoning_gym/algorithmic/string_manipulation.py create mode 100644 tests/test_string_manipulation.py diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 1b528794..3dbbe0d2 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -25,7 +25,7 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset -from .string_insertion import StringInsertionConfig, StringInsertionDataset +from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset @@ -76,6 +76,6 @@ __all__ = [ "ABDataset", "CountPrimesConfig", "CountPrimesDataset", - "StringInsertionConfig", - "StringInsertionDataset", + "StringManipulationConfig", + "StringManipulationDataset", ] diff --git a/reasoning_gym/algorithmic/string_manipulation.py b/reasoning_gym/algorithmic/string_manipulation.py new file mode 100644 index 00000000..b382921f --- /dev/null +++ b/reasoning_gym/algorithmic/string_manipulation.py @@ -0,0 +1,199 @@ +"""Manipulate a string according to a set of rules + +https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dataset_string_deletion_and_modification.py +""" + +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """Your job is to repeatedly transform a string according to a set of rules until no further transformations can be performed, or a state is repeated. + +Evaluate the following rules in order, and apply the first applicable rule to the string: +{rules} + +Once you have applied a rule, repeat the process with the new string until no further transformations can be performed (i.e. the string doesn't change), or a state is repeated. +If a state is repeated, the process is terminated, and the repeated state is discarded (i.e. is not considered as the final answer) and the state before the repeated state is considered as the final answer. + +Example: +- Input: + - String: abbac + - Rules: + 1. If the string prefix is 'ab', replace it with 'ca'. + 2. If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end. + 3. If the string ends with 'aa', replace it with 'cc'. +- Output: bbbacc +- Explanation: + - In the first iteration, rule 1 is applied to the string abbac, resulting in cabac + - In the second interation, rule 1 doesn't apply, but rule 2 is applied to the string cabac, resulting in bbbacc + - In the third iteration, none of the rules (1, 2, 3) apply, so the process is terminated, and the final answer is bbbacc + +Transform the following string according to the above list of rules: +{string} +""" + + +@dataclass +class StringManipulationConfig: + """Configuration for String Insertion dataset generation""" + + min_string_length: int = 5 # Minimum string length + max_string_length: int = 20 # Maximum string length + min_num_rules: int = 3 # Minimum number of rules/transforms + max_num_rules: int = 8 # Maximum number of rules/transforms + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 5 <= self.min_string_length, "Minimum string length should be at least 5" + assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum" + assert 3 <= self.min_num_rules, "Minimum number of rules should be at least 3" + assert self.min_num_rules <= self.max_num_rules, "Minimum number of rules should be less than maximum" + + +class StringManipulationDataset(ProceduralDataset): + """Generates String Insertion exercises with configurable difficulty""" + + def __init__(self, config: StringManipulationConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.vocabulary = ["a", "b", "c"] + self.rules = [ + ( + "If the string prefix is 'ab', replace it with 'ca'.", + lambda s: ("ca" + s[2:], 1) if s.startswith("ab") else (s, 0), + ), + ( + "If the string suffix is 'ac', replace it with 'cb'.", + lambda s: (s[:-2] + "cb", 2) if s.endswith("ac") else (s, 0), + ), + ( + "If the string prefix is 'bc', delete the first two characters and append 'aa' to the end.", + lambda s: (s[2:] + "aa", 3) if s.startswith("bc") else (s, 0), + ), + ( + "If the string suffix is 'bb', delete the last two characters.", + lambda s: (s[:-2], 4) if s.endswith("bb") else (s, 0), + ), + ( + "If the string prefix is 'cb', replace it with 'aa' and delete the last character.", + lambda s: ("aa" + s[2:-1], 5) if s.startswith("cb") and len(s) > 1 else (s, 0), + ), + ( + "If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.", + lambda s: ("bb" + s[2:] + "c", 6) if s.startswith("ca") else (s, 0), + ), + ( + "If the string suffix is 'cc', replace it with 'b' and prepend 'a' to the start.", + lambda s: ("a" + s[:-2] + "b", 7) if s.endswith("cc") else (s, 0), + ), + ( + "If the string prefix is 'aa', remove the first character.", + lambda s: (s[1:], 8) if s.startswith("aa") else (s, 0), + ), + ( + "If the string contains 'abc', replace the first occurrence with 'cab'.", + lambda s: (s.replace("abc", "cab", 1), 9) if "abc" in s else (s, 0), + ), + ( + "If the string contains 'bca', delete the first occurrence entirely.", + lambda s: (s.replace("bca", "", 1), 10) if "bca" in s else (s, 0), + ), + ( + "If the string ends with 'ba', replace it with 'ab'.", + lambda s: (s[:-2] + "ab", 11) if s.endswith("ba") else (s, 0), + ), + ( + "If the string starts with 'cc', remove the first two characters.", + lambda s: (s[2:], 12) if s.startswith("cc") else (s, 0), + ), + ( + "If the string contains 'acb', replace the first occurrence with its reverse ('bca').", + lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0), + ), + ( + "If the string ends with 'ca', remove the last character.", + lambda s: (s[:-1], 14) if s.endswith("ca") and len(s) > 0 else (s, 0), + ), + ( + "If the string starts with 'bb', remove the second character.", + lambda s: (s[0] + s[2:], 15) if s.startswith("bb") and len(s) >= 2 else (s, 0), + ), + ( + "If the string ends with 'aa', replace it with 'cc'.", + lambda s: (s[:-2] + "cc", 16) if s.endswith("aa") else (s, 0), + ), + ( + "If the string contains 'ca' (not at the start), remove the first occurrence found after the first character.", + lambda s: (s[:idx] + s[idx + 2 :], 17) if (idx := s.find("ca", 1)) != -1 else (s, 0), + ), + ( + "If the string contains an even number of 'b's (and at least one 'b'), append 'ab' at the end.", + lambda s: (s + "ab", 18) if (s.count("b") > 0 and s.count("b") % 2 == 0) else (s, 0), + ), + ( + "If the string length is greater than 15, remove the middle character.", + lambda s: (s[: len(s) // 2] + s[len(s) // 2 + 1 :], 19) if len(s) > 15 else (s, 0), + ), + ( + "If the string starts with 'ac', replace the first two characters with 'zz'.", + lambda s: ("zz" + s[2:], 20) if s.startswith("ac") else (s, 0), + ), + ] + + def _apply_rule(self, string: str, selected_rules: list[tuple[str, callable]]) -> tuple[str, int]: + """ + Apply the first applicable rule from the list of selected rules. + Returns a tuple containing the modified string and the rule index (1-based) that was applied. + If no rule is applicable, returns (s, 0). + """ + for _, rule_fn in selected_rules: + new_string, op_idx = rule_fn(string) + if op_idx != 0: + return new_string, op_idx + return string, 0 + + def _get_all_transforms(self, string: str, selected_rules: list[tuple[str, callable]]) -> list[str]: + """ + Repeatedly apply transformation rules to a string until no further transformations can be performed, + or a state is repeated. If a state is repeated, the process is terminated, and the state is not added to the list. + Returns a list of string states from the initial string to the final state (i.e. the desired answer). + """ + states = [string] + while True: + new_string, op_idx = self._apply_rule(states[-1], selected_rules) + if op_idx == 0 or new_string in states: + break + states.append(new_string) + return states + + def __getitem__(self, idx: int) -> dict: + """Generate a single String Insertion question""" + rng = Random(self.seed + idx) + + string_length = rng.randint(self.config.min_string_length, self.config.max_string_length) + string = "".join(rng.choice(self.vocabulary) for _ in range(string_length)) + + num_rules = rng.randint(self.config.min_num_rules, self.config.max_num_rules) + selected_rules = rng.sample(self.rules, num_rules) + rules_str = "\n".join(f"{i+1}. {rule}" for i, (rule, _) in enumerate(selected_rules)) + + states = self._get_all_transforms(string, selected_rules) + answer = states[-1] + + return { + "question": QUESTION_TEMPLATE.format(string=string, rules=rules_str), + "answer": str(answer), + "metadata": { + "string": string, + "solution": answer, + "states": states, + "selected_rules": [rule for rule, _ in selected_rules], + }, + } + + +register_dataset("string_manipulation", StringManipulationDataset, StringManipulationConfig) diff --git a/tests/test_string_manipulation.py b/tests/test_string_manipulation.py new file mode 100644 index 00000000..f62a7acd --- /dev/null +++ b/tests/test_string_manipulation.py @@ -0,0 +1,257 @@ +"""Tests for String Manipulation questions generation""" + +import pytest + +from reasoning_gym.algorithmic.string_manipulation import StringManipulationConfig, StringManipulationDataset + + +def test_string_manipulation_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = StringManipulationConfig(min_string_length=4) # Minimum string length should be at least 5 + config.validate() + + with pytest.raises(AssertionError): + config = StringManipulationConfig(min_string_length=10, max_string_length=7) # Max must be greater than min + config.validate() + + with pytest.raises(AssertionError): + config = StringManipulationConfig(min_num_rules=2) # Min number of rules should be at least 3 + config.validate() + + with pytest.raises(AssertionError): + config = StringManipulationConfig(min_num_rules=5, max_num_rules=3) # Max must be greater than min + config.validate() + + +def test_string_manipulation_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = StringManipulationConfig(seed=42, size=10) + dataset1 = StringManipulationDataset(config) + dataset2 = StringManipulationDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_string_manipulation_dataset_items(): + """Test basic properties of generated items""" + config = StringManipulationConfig( + min_string_length=7, max_string_length=25, min_num_rules=5, max_num_rules=12, size=10, seed=42 + ) + dataset = StringManipulationDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "string" in item["metadata"] + assert "states" in item["metadata"] + # assert "selected_rules" in item["metadata"] + assert "solution" in item["metadata"] + + string = item["metadata"]["string"] + solution = item["metadata"]["solution"] + states = item["metadata"]["states"] + selected_rules = item["metadata"]["selected_rules"] + + # Verify dimensions + assert config.min_string_length <= len(string) <= config.max_string_length + assert config.min_num_rules <= len(selected_rules) <= config.max_num_rules + assert len(states) >= 1 + assert solution == states[-1] + + +def test_string_manipulation_dataset_iteration(): + """Test that iteration respects dataset size""" + config = StringManipulationConfig(size=5, seed=42) + dataset = StringManipulationDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_string_manipulation_answer(): + """Test the method for getting the answer""" + config = StringManipulationConfig(seed=42) + dataset = StringManipulationDataset(config) + + rules = [ + ( + "If the string prefix is 'ab', replace it with 'ca'.", + lambda s: ("ca" + s[2:], 1) if s.startswith("ab") else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbbab", rules)[-1] == "cabbab" + + rules = [ + ( + "If the string suffix is 'ac', replace it with 'cb'.", + lambda s: (s[:-2] + "cb", 2) if s.endswith("ac") else (s, 0), + ), + ] + assert dataset._get_all_transforms("abbbac", rules)[-1] == "abbbcb" + + rules = [ + ( + "If the string prefix is 'bc', delete the first two characters and append 'aa' to the end.", + lambda s: (s[2:] + "aa", 3) if s.startswith("bc") else (s, 0), + ), + ] + assert dataset._get_all_transforms("bcabbb", rules)[-1] == "abbbaa" + + rules = [ + ( + "If the string suffix is 'bb', delete the last two characters.", + lambda s: (s[:-2], 4) if s.endswith("bb") else (s, 0), + ), + ] + assert dataset._get_all_transforms("abbbabb", rules)[-1] == "abbba" + + rules = [ + ( + "If the string prefix is 'cb', replace it with 'aa' and delete the last character.", + lambda s: ("aa" + s[2:-1], 5) if s.startswith("cb") and len(s) > 1 else (s, 0), + ) + ] + assert dataset._get_all_transforms("cbabbb", rules)[-1] == "aaabb" + + rules = [ + ( + "If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.", + lambda s: ("bb" + s[2:] + "c", 6) if s.startswith("ca") else (s, 0), + ) + ] + assert dataset._get_all_transforms("caabbb", rules)[-1] == "bbabbbc" + + rules = [ + ( + "If the string suffix is 'cc', replace it with 'b' and prepend 'a' to the start.", + lambda s: ("a" + s[:-2] + "b", 7) if s.endswith("cc") else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbbcc", rules)[-1] == "aabbbb" + + rules = [ + ( + "If the string prefix is 'aa', remove the first character.", + lambda s: (s[1:], 8) if s.startswith("aa") else (s, 0), + ) + ] + assert dataset._get_all_transforms("aabbb", rules)[-1] == "abbb" + + rules = [ + ( + "If the string contains 'abc', replace the first occurrence with 'cab'.", + lambda s: (s.replace("abc", "cab", 1), 9) if "abc" in s else (s, 0), + ) + ] + assert dataset._get_all_transforms("ababcb", rules)[-1] == "cababb" # 'ababcb' -> 'abcabb' -> 'cababb' + + rules = [ + ( + "If the string contains 'bca', delete the first occurrence entirely.", + lambda s: (s.replace("bca", "", 1), 10) if "bca" in s else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbcab", rules)[-1] == "abb" + + rules = [ + ( + "If the string ends with 'ba', replace it with 'ab'.", + lambda s: (s[:-2] + "ab", 11) if s.endswith("ba") else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbbba", rules)[-1] == "abbbab" + + rules = [ + ( + "If the string starts with 'cc', remove the first two characters.", + lambda s: (s[2:], 12) if s.startswith("cc") else (s, 0), + ) + ] + assert dataset._get_all_transforms("ccabbb", rules)[-1] == "abbb" + + rules = [ + ( + "If the string contains 'acb', replace the first occurrence with its reverse ('bca').", + lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0), + ) + ] + assert dataset._get_all_transforms("abacbb", rules)[-1] == "abbcab" + + rules = [ + ( + "If the string contains 'acb', replace the first occurrence with its reverse ('bca').", + lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0), + ) + ] + assert dataset._get_all_transforms("abacbb", rules)[-1] == "abbcab" + + rules = [ + ( + "If the string ends with 'ca', remove the last character.", + lambda s: (s[:-1], 14) if s.endswith("ca") and len(s) > 0 else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbbca", rules)[-1] == "abbbc" + + rules = [ + ( + "If the string starts with 'bb', remove the second character.", + lambda s: (s[0] + s[2:], 15) if s.startswith("bb") and len(s) >= 2 else (s, 0), + ) + ] + assert dataset._get_all_transforms("bbabcbb", rules)[-1] == "babcbb" + + rules = [ + ( + "If the string ends with 'aa', replace it with 'cc'.", + lambda s: (s[:-2] + "cc", 16) if s.endswith("aa") else (s, 0), + ) + ] + assert dataset._get_all_transforms("abccbaa", rules)[-1] == "abccbcc" + + rules = [ + ( + "If the string contains 'ca' (not at the start), remove the first occurrence found after the first character.", + lambda s: (s[:idx] + s[idx + 2 :], 17) if (idx := s.find("ca", 1)) != -1 else (s, 0), + ) + ] + assert dataset._get_all_transforms("abacab", rules)[-1] == "abab" + assert dataset._get_all_transforms("caabab", rules)[-1] == "caabab" + + rules = [ + ( + "If the string contains an even number of 'b's (and at least one 'b'), append 'ab' at the end.", + lambda s: (s + "ab", 18) if (s.count("b") > 0 and s.count("b") % 2 == 0) else (s, 0), + ) + ] + assert dataset._get_all_transforms("abab", rules)[-1] == "ababab" + assert dataset._get_all_transforms("abbab", rules)[-1] == "abbab" + + rules = [ + ( + "If the string length is greater than 15, remove the middle character.", + lambda s: (s[: len(s) // 2] + s[len(s) // 2 + 1 :], 19) if len(s) > 15 else (s, 0), + ) + ] + assert ( + dataset._get_all_transforms("bccbcbbbcbbbbcccc", rules)[-1] == "bccbcbbbbbbcccc" + ) # bccbcbbbcbbbbcccc -> "bccbcbbbbbbbcccc" -> "bccbcbbbbbbcccc" + + rules = [ + ( + "If the string starts with 'ac', replace the first two characters with 'zz'.", + lambda s: ("zz" + s[2:], 20) if s.startswith("ac") else (s, 0), + ) + ] + assert dataset._get_all_transforms("acab", rules)[-1] == "zzab" From fc5fb8553701082a031fdf817d9263dc2779ca7d Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Thu, 13 Feb 2025 21:26:54 -0600 Subject: [PATCH 14/18] Fix more conflict --- reasoning_gym/algebra/polynomial_equations.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index a65cdff9..88aac334 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -62,15 +62,13 @@ class PolynomialEquationsDataset(ProceduralDataset): "Determine the real value(s) of {variable} that satisfies: {polynomial_expanded} = 0", "Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0", ] - self.added_instruction = """ - \n\n - In solving the equations, please abide by the following instruction: - ## 1. All answers should be inserted in square brackets. For example [], [-0.3773, 0.4005] etc. - ## 2. In cases where your answer is b = [2 + sqrt(4560)] / 172 and b = [2 - sqrt(4560)] / 172. Since b can be 2 numbers Resolve your answer like this instead, [-0.3773, 0.4005] - ## 3. If there are no real values of i that satisfy the equation, report your answer as empty square bracket, [] - ## 4. If there are 2 answers, resolve the answers as floats and fill the 2 numbers in square bracket, if 3 answers, fill it with 3 answers. - ## 5. Resolve all numbers in square brackets as floats. Round the floats higher than 4 decimal place(d.p) down to 4 d.p. - """ + self.added_instruction = """In solving the equations, please abide by the following instruction: +## 1. All answers should be comma-separated. For example "-0.3773, 0.4005" etc. +## 2. In cases where your answer is b = 2 + sqrt(4560) / 172 and b = 2 - sqrt(4560) / 172. Since b can be 2 numbers, resolve your answer like this instead, "-0.3773, 0.4005". +## 3. If there are no real values of i that satisfy the equation, report your answer as empty string, "". +## 4. If there are 2 answers, resolve the answers as comma-separated floats of 2 numbers, if 3 answers, make it comma-separated floats of 3 numbers. +## 5. Resolve all numbers as floats in the string of comma-separated numbers. Round the floats higher than 4 decimal place(d.p) down to 4 d.p. +""" super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: From 58e03067f96cb5c32d0b7fb55780b9ad64bbd853 Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Thu, 13 Feb 2025 21:14:59 -0600 Subject: [PATCH 15/18] Fix formatting of added instructions --- play.py | 6 ++++++ reasoning_gym/algebra/intermediate_integration.py | 9 ++++----- reasoning_gym/algebra/polynomial_equations.py | 3 ++- reasoning_gym/algebra/polynomial_multiplication.py | 9 ++++----- reasoning_gym/algebra/simple_integration.py | 9 ++++----- reasoning_gym/algorithmic/number_sorting.py | 9 ++++----- tests/test_polynomial_equations.py | 2 +- 7 files changed, 25 insertions(+), 22 deletions(-) create mode 100644 play.py diff --git a/play.py b/play.py new file mode 100644 index 00000000..d7b42410 --- /dev/null +++ b/play.py @@ -0,0 +1,6 @@ +import reasoning_gym +data = reasoning_gym.create_dataset('polynomial_equations', size=3, seed=42) +for i, x in enumerate(data): + print(f"{i}: question={x['question']}\n") + print(f"{i}: answer={x['answer']}\n") + print('metadata:', x['metadata']) \ No newline at end of file diff --git a/reasoning_gym/algebra/intermediate_integration.py b/reasoning_gym/algebra/intermediate_integration.py index 34e144df..84e01d9b 100644 --- a/reasoning_gym/algebra/intermediate_integration.py +++ b/reasoning_gym/algebra/intermediate_integration.py @@ -77,11 +77,10 @@ class IntermediateIntegrationDataset(ProceduralDataset): "Evaluate the indefinite integral: ∫ {integrand} dx", ] self.added_instruction = """ - \n\n - In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems - ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. - ## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. - """ +In addition, when doing calculation, use the following instructions together with your mathematical ingenuity to solve the integral problems +## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. +## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. +""" def _get_outer_constant(self, rng: random.Random) -> int: """Helper to generate signed outer constant from config""" diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index 88aac334..dad8a0b0 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -62,7 +62,8 @@ class PolynomialEquationsDataset(ProceduralDataset): "Determine the real value(s) of {variable} that satisfies: {polynomial_expanded} = 0", "Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0", ] - self.added_instruction = """In solving the equations, please abide by the following instruction: + self.added_instruction = """ +In solving the equations, please abide by the following instruction: ## 1. All answers should be comma-separated. For example "-0.3773, 0.4005" etc. ## 2. In cases where your answer is b = 2 + sqrt(4560) / 172 and b = 2 - sqrt(4560) / 172. Since b can be 2 numbers, resolve your answer like this instead, "-0.3773, 0.4005". ## 3. If there are no real values of i that satisfy the equation, report your answer as empty string, "". diff --git a/reasoning_gym/algebra/polynomial_multiplication.py b/reasoning_gym/algebra/polynomial_multiplication.py index d0c472b3..67a914e8 100644 --- a/reasoning_gym/algebra/polynomial_multiplication.py +++ b/reasoning_gym/algebra/polynomial_multiplication.py @@ -62,11 +62,10 @@ class PolynomialMultiplicationDataset(ProceduralDataset): "Calculate the following: {polynomial_expr}", ] self.added_instruction = """ - \n\n - In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems - ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. - ## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. - """ +In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems +## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. +## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. +""" super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: diff --git a/reasoning_gym/algebra/simple_integration.py b/reasoning_gym/algebra/simple_integration.py index d056bc7c..8dfa775b 100644 --- a/reasoning_gym/algebra/simple_integration.py +++ b/reasoning_gym/algebra/simple_integration.py @@ -42,11 +42,10 @@ class SimpleIntegrationDataset(ProceduralDataset): "Evaluate the indefinite integral: ∫ {integrand} dx", ] self.added_instruction = """ - \n\n - In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems - ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. - ## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. - """ +In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems +## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. +## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. +""" super().__init__(config=config, seed=config.seed, size=config.size) def _generate_coefficient(self, rng: random.Random) -> Fraction: diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index 39dfda99..f906d230 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -35,11 +35,10 @@ class NumberSortingDataset(ProceduralDataset): def __init__(self, config: NumberSortingConfig): super().__init__(config=config, seed=config.seed, size=config.size) self.added_instruction = """ - \n\n - Please follow the instruction below: - # 1. Let all your answers be a list of numbers. Instead of reporting your answer as -69, -13, 1, 7, 11, 43, 59, 61, use ['-69', '-13', '1', '7', '11', '43', '59', '61'] instead - # 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61'] - """ +Please follow the instruction below: +## 1. Let all your answers be a list of numbers. Instead of reporting your answer as -69, -13, 1, 7, 11, 43, 59, 61, use ['-69', '-13', '1', '7', '11', '43', '59', '61'] instead +## 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61'] +""" def _format_number(self, num: float, decimals: int) -> str: """Format number with specified decimal places""" diff --git a/tests/test_polynomial_equations.py b/tests/test_polynomial_equations.py index 420187de..e4e72b18 100644 --- a/tests/test_polynomial_equations.py +++ b/tests/test_polynomial_equations.py @@ -112,7 +112,7 @@ def test_polynomial_solutions_evaluation(): evaluated_value = poly_expr.subs(x, solution) # Ensure the evaluated value is close to zero (numerical stability threshold) - assert abs(evaluated_value) < 1e-6, ( + assert abs(evaluated_value) < 1e-5, ( f"Solution {solution} does not satisfy the polynomial {poly_str}. " f"Evaluated value: {evaluated_value}" ) From 94d4bc03fc31aacf981f722f4b7ab09f14b5b426 Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Thu, 13 Feb 2025 22:39:05 -0600 Subject: [PATCH 16/18] Remove play file and format with pre-commit --- .pre-commit-config.yaml | 2 +- play.py | 6 ------ reasoning_gym/algebra/polynomial_equations.py | 7 +++++-- reasoning_gym/algebra/polynomial_multiplication.py | 2 +- 4 files changed, 7 insertions(+), 10 deletions(-) delete mode 100644 play.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 49987591..b921ca29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: rev: 24.1.1 hooks: - id: black - language_version: python3.11 + language_version: python3.13 - repo: https://github.com/pycqa/isort rev: 5.13.2 diff --git a/play.py b/play.py deleted file mode 100644 index d7b42410..00000000 --- a/play.py +++ /dev/null @@ -1,6 +0,0 @@ -import reasoning_gym -data = reasoning_gym.create_dataset('polynomial_equations', size=3, seed=42) -for i, x in enumerate(data): - print(f"{i}: question={x['question']}\n") - print(f"{i}: answer={x['answer']}\n") - print('metadata:', x['metadata']) \ No newline at end of file diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index dad8a0b0..cd4842ee 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -102,9 +102,12 @@ In solving the equations, please abide by the following instruction: if len(real_solutions) > 0: real_solutions.sort() break - + answer_str = ", ".join(str(x) for x in real_solutions) - question = rng.choice(self._prompt_templates).format(variable=variable, polynomial_expanded=polynomial_expanded) + self.added_instruction + question = ( + rng.choice(self._prompt_templates).format(variable=variable, polynomial_expanded=polynomial_expanded) + + self.added_instruction + ) return { "question": question, diff --git a/reasoning_gym/algebra/polynomial_multiplication.py b/reasoning_gym/algebra/polynomial_multiplication.py index 67a914e8..6076c32a 100644 --- a/reasoning_gym/algebra/polynomial_multiplication.py +++ b/reasoning_gym/algebra/polynomial_multiplication.py @@ -64,7 +64,7 @@ class PolynomialMultiplicationDataset(ProceduralDataset): self.added_instruction = """ In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems ## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. -## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. +## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. """ super().__init__(config=config, seed=config.seed, size=config.size) From c762bae9f0e3a12294862ea9828c8b9f2cfeda4e Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Thu, 13 Feb 2025 22:52:16 -0600 Subject: [PATCH 17/18] Remove pre-commit config --- .gitignore | 3 +++ .pre-commit-config.yaml | 21 --------------------- 2 files changed, 3 insertions(+), 21 deletions(-) delete mode 100644 .pre-commit-config.yaml diff --git a/.gitignore b/.gitignore index d1e0d496..ec14783f 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,6 @@ htmlcov/ # Jupyter Notebook .ipynb_checkpoints/ .virtual_documents/ + +# Pre-commit config +.pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index b921ca29..00000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,21 +0,0 @@ -repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 - hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - - id: check-yaml - - id: check-added-large-files - -- repo: https://github.com/psf/black - rev: 24.1.1 - hooks: - - id: black - language_version: python3.13 - -- repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - name: isort (python) -exclude: GALLERY.md From ad318c5b8ed36865ebc2377fe2c040c1aebde6e8 Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Thu, 13 Feb 2025 23:15:10 -0600 Subject: [PATCH 18/18] Remove pre-commit config in gitignore --- .gitignore | 3 --- .pre-commit-config.yaml | 21 +++++++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.gitignore b/.gitignore index ec14783f..d1e0d496 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,3 @@ htmlcov/ # Jupyter Notebook .ipynb_checkpoints/ .virtual_documents/ - -# Pre-commit config -.pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..49987591 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + +- repo: https://github.com/psf/black + rev: 24.1.1 + hooks: + - id: black + language_version: python3.11 + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) +exclude: GALLERY.md