diff --git a/.gitignore b/.gitignore index ce057fd8..d1e0d496 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ wheels/ *.egg-info/ .installed.cfg *.egg +.python-version # Virtual Environment venv/ diff --git a/GALLERY.md b/GALLERY.md index ff56e124..239d07a9 100644 --- a/GALLERY.md +++ b/GALLERY.md @@ -12,6 +12,7 @@ This gallery shows examples from all available datasets using their default conf - [calendar_arithmetic](#calendar_arithmetic) - [chain_sum](#chain_sum) - [color_cube_rotation](#color_cube_rotation) +- [complex_arithmetic](#complex_arithmetic) - [countdown](#countdown) - [course_schedule](#course_schedule) - [family_relationships](#family_relationships) @@ -23,6 +24,7 @@ This gallery shows examples from all available datasets using their default conf - [ransom_note](#ransom_note) - [gsm_symbolic](#gsm_symbolic) - [intermediate_integration](#intermediate_integration) +- [isomorphic_strings](#isomorphic_strings) - [largest_island](#largest_island) - [lcm](#lcm) - [leg_counting](#leg_counting) @@ -36,10 +38,12 @@ This gallery shows examples from all available datasets using their default conf - [number_sorting](#number_sorting) - [palindrome](#palindrome) - [polynomial_equations](#polynomial_equations) +- [polynomial_multiplication](#polynomial_multiplication) - [prime_factorization](#prime_factorization) - [propositional_logic](#propositional_logic) - [quantum_lock](#quantum_lock) - [rubiks_cube](#rubiks_cube) +- [self_reference](#self_reference) - [sentence_reordering](#sentence_reordering) - [simple_equations](#simple_equations) - [simple_geometry](#simple_geometry) @@ -50,6 +54,7 @@ This gallery shows examples from all available datasets using their default conf - [syllogism](#syllogism) - [time_intervals](#time_intervals) - [tower_of_hanoi](#tower_of_hanoi) +- [tsumego](#tsumego) - [word_ladder](#word_ladder) - [word_sequence_reversal](#word_sequence_reversal) - [word_sorting](#word_sorting) @@ -490,6 +495,39 @@ Metadata: {'initial_state': 
{'top': 'orange', 'right': 'cyan', 'front': 'violet' ```` +### complex_arithmetic +Generates complex number arithmetic problems. + +Default configuration: +```python +min_real = -10 +max_real = 10 +min_imag = -10 +max_imag = 10 +operations = ('+', '-', '*', '/') +seed = 42 +size = 500 +``` + +Example tasks: +```` +Example 1: +Question: Add the complex numbers: (-10.0 - 2.0i) + (-3.0 - 3.0i) +Answer: -13.0 - 5.0i +Metadata: {'num1': (-10.0, -2.0), 'num2': (-3.0, -3.0), 'operation': '+', 'result': (-13, -5)} + +Example 2: +Question: Add the complex numbers: (-1.0 - 6.0i) + (4.0 + 1.0i) +Answer: 3.0 - 5.0i +Metadata: {'num1': (-1.0, -6.0), 'num2': (4.0, 1.0), 'operation': '+', 'result': (3, -5)} + +Example 3: +Question: Divide the complex numbers: (-7.0 - 79.0i) ÷ (-7.0 - 5.0i) +Answer: 6.0 + 7.0i +Metadata: {'num1': (-7.0, -79.0), 'num2': (-7.0, -5.0), 'operation': '/', 'result': (6, 7)} + +```` + ### countdown Generates Countdown Number Game tasks @@ -1122,6 +1160,99 @@ Metadata: {'integrand': '2*asin(x)', 'problem_type': 'by_parts', 'variable': 'x' ```` +### isomorphic_strings +Generates Isomorphic Strings exercises with configurable difficulty + +Default configuration: +```python +max_string_length = 10 +p_solvable = 0.5 +size = 500 +seed = 42 +``` + +Example tasks: +```` +Example 1: +Question: Two strings are isomorphic if the characters in one string can be replaced to get the second string. + +All occurrences of a character must be replaced with another character while preserving the order of characters. + +No two characters may map to the same character, but a character may map to itself. + +Example 1: +Input: egg add +Output: True +Explanation: The strings s and t can be made identical by: + - Mapping 'e' to 'a'. + - Mapping 'g' to 'd'. + +Example 2: +Input: foo bar +Output: False +Explanation: + - The strings cannot be made identical as 'o' needs to be mapped to both 'a' and 'r'. 
+ +Return True if the following two strings are isomorphic, or False otherwise: +cc bw + +Answer: False +Metadata: {'words': ['cc', 'bw'], 'solution': False, 'solvable': False} + +Example 2: +Question: Two strings are isomorphic if the characters in one string can be replaced to get the second string. + +All occurrences of a character must be replaced with another character while preserving the order of characters. + +No two characters may map to the same character, but a character may map to itself. + +Example 1: +Input: egg add +Output: True +Explanation: The strings s and t can be made identical by: + - Mapping 'e' to 'a'. + - Mapping 'g' to 'd'. + +Example 2: +Input: foo bar +Output: False +Explanation: + - The strings cannot be made identical as 'o' needs to be mapped to both 'a' and 'r'. + +Return True if the following two strings are isomorphic, or False otherwise: +nai oik + +Answer: True +Metadata: {'words': ['nai', 'oik'], 'solution': True, 'solvable': True} + +Example 3: +Question: Two strings are isomorphic if the characters in one string can be replaced to get the second string. + +All occurrences of a character must be replaced with another character while preserving the order of characters. + +No two characters may map to the same character, but a character may map to itself. + +Example 1: +Input: egg add +Output: True +Explanation: The strings s and t can be made identical by: + - Mapping 'e' to 'a'. + - Mapping 'g' to 'd'. + +Example 2: +Input: foo bar +Output: False +Explanation: + - The strings cannot be made identical as 'o' needs to be mapped to both 'a' and 'r'. 
+ +Return True if the following two strings are isomorphic, or False otherwise: +hogtytyof kgqwfwfgh + +Answer: True +Metadata: {'words': ['hogtytyof', 'kgqwfwfgh'], 'solution': True, 'solvable': True} + +```` + ### largest_island Generates Largest Island exercises with configurable difficulty @@ -1745,6 +1876,46 @@ Metadata: {'polynomial_expr': '71*n**3 - 2*n - 29', 'variable': 'n', 'degree': 3 ```` +### polynomial_multiplication +Generates [min_polynomials, max_polynomials] random polynomials of degree in [min_degree, max_degree]. + - The polynomial is formed by summing random terms of the form: coeff * x^exponent. + - Then we find "F = P_0 * ... * P_1" using Sympy. + +Default configuration: +```python +min_terms = 2 +max_terms = 4 +min_value = 1 +max_value = 100 +min_degree = 1 +max_degree = 3 +min_polynomials = 2 +max_polynomials = 3 +single_variable = (True,) +operators = ('+', '-') +seed = 42 +size = 500 +``` + +Example tasks: +```` +Example 1: +Question: Calculate the following: (65*x - 72)*(105*x - 125) +Answer: 6825*x**2 - 15685*x + 9000 +Metadata: {'polynomial_expr': '(65*x - 72)*(105*x - 125)', 'single_variable': (True,), 'result': '6825*x**2 - 15685*x + 9000'} + +Example 2: +Question: Calculate the following: (-9*x**2 - 28*x)*(86*x**2 - 2*x - 13) +Answer: -774*x**4 - 2390*x**3 + 173*x**2 + 364*x +Metadata: {'polynomial_expr': '(-9*x**2 - 28*x)*(86*x**2 - 2*x - 13)', 'single_variable': (True,), 'result': '-774*x**4 - 2390*x**3 + 173*x**2 + 364*x'} + +Example 3: +Question: Calculate the following: (43 - 91*x)*(3*x**2 - 10*x)*(71*x**3 - 2*x - 29) +Answer: -19383*x**6 + 73769*x**5 - 29984*x**4 + 5839*x**3 - 29271*x**2 + 12470*x +Metadata: {'polynomial_expr': '(43 - 91*x)*(3*x**2 - 10*x)*(71*x**3 - 2*x - 29)', 'single_variable': (True,), 'result': '-19383*x**6 + 73769*x**5 - 29984*x**4 + 5839*x**3 - 29271*x**2 + 12470*x'} + +```` + ### prime_factorization Generates prime factorization tasks @@ -1943,6 +2114,56 @@ Metadata: {'cube_size': 3, 'scramble_steps': 
3, 'scramble_moves': "U R' R'", 'ex ```` +### self_reference +Generates self-referential puzzles + +Default configuration: +```python +difficulty = 5 +seed = 42 +size = 500 +``` + +Example tasks: +```` +Example 1: +Question: Given the truthfulness of these statements, please tell me the number of possible solutions: + - Statement 1: 'At least 1 of these 7 statements are true.' + - Statement 2: 'At most 3 of these 7 statements are false.' + - Statement 3: 'Exactly 4 of these 7 statements are true.' + - Statement 4: 'Exactly 3 of these 7 statements are false.' + - Statement 5: 'Either Statement 3 or Statement 4 is true, but not both.' + - Statement 6: 'The number of true statements is a prime number.' + - Statement 7: 'The number of false statements is a composite number.' + +Answer: 4 + +Example 2: +Question: Given the truthfulness of these statements, please tell me the number of possible solutions: + - Statement 1: 'At least 4 of these 7 statements are true.' + - Statement 2: 'At most 5 of these 7 statements are false.' + - Statement 3: 'Exactly 7 of these 7 statements are true.' + - Statement 4: 'Exactly 1 of these 7 statements are false.' + - Statement 5: 'Either Statement 3 or Statement 4 is true, but not both.' + - Statement 6: 'The number of true statements is a prime number.' + - Statement 7: 'The number of false statements is a composite number.' + +Answer: 4 + +Example 3: +Question: Given the truthfulness of these statements, please tell me the number of possible solutions: + - Statement 1: 'At least 2 of these 7 statements are true.' + - Statement 2: 'At most 5 of these 7 statements are false.' + - Statement 3: 'Exactly 0 of these 7 statements are true.' + - Statement 4: 'Exactly 3 of these 7 statements are false.' + - Statement 5: 'Either Statement 3 or Statement 4 is true, but not both.' + - Statement 6: 'The number of true statements is a prime number.' + - Statement 7: 'The number of false statements is a composite number.' 
+ +Answer: 2 + +```` + ### sentence_reordering Generates sentence reordering tasks from text spans @@ -2295,12 +2516,10 @@ Generates syllogism reasoning tasks Default configuration: ```python -terms = None allow_all = True allow_no = True allow_some = True allow_some_not = True -include_invalid = True invalid_ratio = 0.3 seed = 42 size = 500 @@ -2311,24 +2530,24 @@ Example tasks: Example 1: Question: Consider these statements: 1. No students are humans -2. No humans are chefs +2. All humans are chefs Does it logically follow that: -No students are chefs? +All students are chefs? (Answer Yes or No) -Answer: Yes -Metadata: {'premise1': 'No students are humans', 'premise2': 'No humans are chefs', 'conclusion': 'No students are chefs', 'is_valid': True} +Answer: No +Metadata: {'premise1': 'No students are humans', 'premise2': 'All humans are chefs', 'conclusion': 'All students are chefs', 'is_valid': False} Example 2: Question: Consider these statements: -1. Some children are not animals -2. Some animals are doctors +1. All children are animals +2. No animals are doctors Does it logically follow that: -All children are doctors? +Some children are not doctors? (Answer Yes or No) Answer: Yes -Metadata: {'premise1': 'Some children are not animals', 'premise2': 'Some animals are doctors', 'conclusion': 'All children are doctors', 'is_valid': True} +Metadata: {'premise1': 'All children are animals', 'premise2': 'No animals are doctors', 'conclusion': 'Some children are not doctors', 'is_valid': True} Example 3: Question: Consider these statements: @@ -2338,8 +2557,8 @@ Question: Consider these statements: Does it logically follow that: Some butterflies are not whales? 
(Answer Yes or No) -Answer: No -Metadata: {'premise1': 'All butterflies are tigers', 'premise2': 'No tigers are whales', 'conclusion': 'Some butterflies are not whales', 'is_valid': False} +Answer: Yes +Metadata: {'premise1': 'All butterflies are tigers', 'premise2': 'No tigers are whales', 'conclusion': 'Some butterflies are not whales', 'is_valid': True} ```` @@ -2442,6 +2661,96 @@ Metadata: {'num_disks': 6, 'num_pegs': 3, 'start_peg': 1, 'target_peg': 2, 'auxi ```` +### tsumego +Generates (one-move) Tsumego problems with configurable parameters + +Default configuration: +```python +min_board_size = 9 +max_board_size = 13 +max_stones = 15 +size = 10 +seed = 42 +``` + +Example tasks: +```` +Example 1: +Question: I have a Go problem for you. Black moves next - can you capture some of the white stones? + + A B C D E F G H I + 9 X . . . X . . . . + 8 . . . . . . . . . + 7 . O . O . . X . . + 6 . . . X . . . . O + 5 O . X O X . . . . + 4 . X O O . O . . . + 3 . . X O X . . . . + 2 . . . X . . . . . + 1 . O . O . . X . . + +X - Black +O - White + +Specify your move in coordinates (e.g. 'C4' for column C, row 4) +Answer: E4 + +Metadata: {'difficulty': {'board_size': 9}, 'board': [['X', '.', '.', '.', 'X', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', 'O', '.', 'O', '.', '.', 'X', '.', '.'], ['.', '.', '.', 'X', '.', '.', '.', '.', 'O'], ['O', '.', 'X', 'O', 'X', '.', '.', '.', '.'], ['.', 'X', 'O', 'O', '.', 'O', '.', '.', '.'], ['.', '.', 'X', 'O', 'X', '.', '.', '.', '.'], ['.', '.', '.', 'X', '.', '.', '.', '.', '.'], ['.', 'O', '.', 'O', '.', '.', 'X', '.', '.']], 'solution': 'E4'} + +-------------------------------------------------- + +Example 2: +Question: Here's a Go challenge. Playing as Black, how can you capture as many white stones as possible? + + A B C D E F G H I + 9 . . O . . . . . . + 8 . X O . . . . . . + 7 X . X . . . . . . + 6 O O O X . . . . . + 5 X O O . . . . . . + 4 . X . . . . . . O + 3 . X . . . . X . . + 2 O . O . . 
. . . . + 1 . . . . O . . . . + +X - Black +O - White + +Specify your move in coordinates (e.g. 'C4' for column C, row 4) +Answer: B7 + +Metadata: {'difficulty': {'board_size': 9}, 'board': [['.', '.', 'O', '.', '.', '.', '.', '.', '.'], ['.', 'X', 'O', '.', '.', '.', '.', '.', '.'], ['X', '.', 'X', '.', '.', '.', '.', '.', '.'], ['O', 'O', 'O', 'X', '.', '.', '.', '.', '.'], ['X', 'O', 'O', '.', '.', '.', '.', '.', '.'], ['.', 'X', '.', '.', '.', '.', '.', '.', 'O'], ['.', 'X', '.', '.', '.', '.', 'X', '.', '.'], ['O', '.', 'O', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', 'O', '.', '.', '.', '.']], 'solution': 'B7'} + +-------------------------------------------------- + +Example 3: +Question: Tsumego time. Black to play and capture some stones. +Find the key move. + + A B C D E F G H I J K L +12 . . . . . . . . . . . . +11 . . X . . . . . . . . . +10 . . . . . . . . . . . . + 9 . . . . . . . . . . . . + 8 X . . . . X . . . X . . + 7 . X . . . . . . . . . . + 6 . O X X . . . . . . . O + 5 . X O O X . . . . . . . + 4 . O O . . . . . O . . O + 3 X . X . . . . . . . . . + 2 . . . . . . . . . . . . + 1 . . . . . . . . . . X . + +X - Black +O - White + +Specify your move in coordinates (e.g. 
'C4' for column C, row 4) +Answer: D4 + +Metadata: {'difficulty': {'board_size': 12}, 'board': [['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', 'X', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['X', '.', '.', '.', '.', 'X', '.', '.', '.', 'X', '.', '.'], ['.', 'X', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', 'O', 'X', 'X', '.', '.', '.', '.', '.', '.', '.', 'O'], ['.', 'X', 'O', 'O', 'X', '.', '.', '.', '.', '.', '.', '.'], ['.', 'O', 'O', '.', '.', '.', '.', '.', 'O', '.', '.', 'O'], ['X', '.', 'X', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', 'X', '.']], 'solution': 'D4'} + +```` + ### word_ladder Generates word ladder transformation tasks diff --git a/README.md b/README.md index 0dd159b9..a899df26 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ metadata: {'animals': {'sheep': 2, 'dog': 2}, 'total_legs': 16} ... ``` -See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets with examples. +See the [Dataset Gallery](https://github.com/open-thought/reasoning-gym/blob/main/GALLERY.md) for a complete list of available datasets with examples. ## Task Overview @@ -72,6 +72,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets - `SimpleEquationsDataset`: Generate linear equations with one variable to solve (e.g. "3\*x + 2 = 14") - `PolynomialEquationsDataset`: Generate polynomial equations with one variable to solve (e.g. "-6*h\*\*4 + 4*h\**2 - 5*h = 0") +- `PolynomialMultiplicationDataset`: Generate polynomial multiplications (e.g. 
"(8x^3 + x + 2)\*(y - 3)") ### Arithmetic Tasks @@ -100,6 +101,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets - `WordLadderDataset`: Generate word ladder puzzles where one word is transformed into another by changing one letter at a time - `GroupAnagramsDataset`: Group anagrams together in a list of words - `RansomNoteDataset`: Check if a ransom note can be created from a given set of letters in a magazine +- `IsomorphicStringsDataset`: Check if two strings are isomorphic (have the same character mapping) ### Code Tasks @@ -118,6 +120,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets - `SyllogismDataset`: Generates a [syllogism](https://en.wikipedia.org/wiki/Syllogism) reasoning dataset - `AliceInWonderlandDataset`: Generates [AIW](https://openreview.net/forum?id=Mkl7dzjYiW) (Alice In Wonderland) problems with a few variations - `ZebraDataset`: Generates [Zebra Puzzles](https://en.wikipedia.org/wiki/Zebra_Puzzle) of varying difficulty. +- `SelfReferenceDataset`: Generates self-referencing logic puzzles. 
### Graph Tasks @@ -134,6 +137,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets - `MazeDataset`: Generate a maze with a start and a goal - `CountdownDataset`: Generate number game tasks where numbers and operators must be combined to reach a target value - `NQueensDataset`: Generate N-Queens puzzles with configurable board size and number of starting queens +- `TsumegoDataset`: Generate Tsumego capture puzzles with variable board sizes and stone placements ## Future Generator Ideas diff --git a/pyproject.toml b/pyproject.toml index c3cc31b7..794077d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "reasoning_gym" -version = "0.1.3" +version = "0.1.5" authors = [ { name = "Open-Thought community", email = "andreas.koepf@xamla.com" }, ] @@ -31,20 +31,20 @@ license = "Apache-2.0" license-files = ["LICENSE*"] [project.optional-dependencies] -test = [ - "pytest>=7.0.0", - "pytest-cov>=4.0.0", -] +test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"] [project.urls] "Homepage" = "https://github.com/open-thought/reasoning-gym" "Bug Tracker" = "https://github.com/open-thought/reasoning-gym/issues" -[tool.hatch.build.targets.wheel] -packages = ["reasoning_gym"] [tool.hatch.build] -include = ["reasoning_gym/**/*.txt"] +packages = ["reasoning_gym"] +include = [ + "reasoning_gym/**/*.py", + "reasoning_gym/**/*.txt", + "reasoning_gym/**/levels/*", +] [tool.black] line-length = 120 @@ -58,6 +58,4 @@ line_length = 120 [tool.pytest.ini_options] addopts = "-ra -q" -testpaths = [ - "tests", -] +testpaths = ["tests"] diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py index 019873ff..ecca7f3f 100644 --- a/reasoning_gym/__init__.py +++ b/reasoning_gym/__init__.py @@ -5,7 +5,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasonin from . 
import algebra, algorithmic, arithmetic, code, cognition, data, games, geometry, graphs, logic from .factory import create_dataset, register_dataset -__version__ = "0.1.3" +__version__ = "0.1.5" __all__ = [ "algebra", "algorithmic", diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py index fc7a867a..fc77b977 100644 --- a/reasoning_gym/algebra/__init__.py +++ b/reasoning_gym/algebra/__init__.py @@ -1,9 +1,13 @@ +from .complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset from .intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset +from .polynomial_multiplication import PolynomialMultiplicationConfig, PolynomialMultiplicationDataset from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset from .simple_integration import SimpleIntegrationConfig, SimpleIntegrationDataset __all__ = [ + "ComplexArithmeticConfig", + "ComplexArithmeticDataset", "IntermediateIntegrationConfig", "IntermediateIntegrationDataset", "PolynomialEquationsConfig", @@ -12,4 +16,6 @@ __all__ = [ "SimpleEquationsConfig", "SimpleIntegrationConfig", "SimpleIntegrationDataset", + "PolynomialMultiplicationConfig", + "PolynomialMultiplicationDataset", ] diff --git a/reasoning_gym/algebra/complex_arithmetic.py b/reasoning_gym/algebra/complex_arithmetic.py new file mode 100644 index 00000000..7c749eaa --- /dev/null +++ b/reasoning_gym/algebra/complex_arithmetic.py @@ -0,0 +1,147 @@ +import cmath +import math +import random +from dataclasses import dataclass +from typing import Optional, Tuple + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class ComplexArithmeticConfig: + min_real: int = -10 + max_real: int = 10 + min_imag: int = -10 + max_imag: int = 10 + operations: Tuple[str, ...] 
= ("+", "-", "*", "/") + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters.""" + assert self.max_real >= self.min_real, "max_real must be >= min_real" + assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag" + assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator" + + +class ComplexArithmeticDataset(ProceduralDataset): + """Generates complex number arithmetic problems.""" + + def __init__(self, config: ComplexArithmeticConfig): + self._prompt_templates = { + "+": "Add the complex numbers: ({a}) + ({b})", + "-": "Subtract the complex numbers: ({a}) - ({b})", + "*": "Multiply the complex numbers: ({a}) × ({b})", + "/": "Divide the complex numbers: ({a}) ÷ ({b})", + } + super().__init__(config=config, seed=config.seed, size=config.size) + + def _generate_complex(self, rng: random.Random) -> complex: + """Generate a random complex number.""" + real = rng.randint(self.config.min_real, self.config.max_real) + imag = rng.randint(self.config.min_imag, self.config.max_imag) + return complex(real, imag) + + def _format_complex(self, z: complex) -> str: + """Format complex number with 2 decimal places.""" + real, imag = z.real, z.imag + if abs(imag) < 1e-10: + return f"{real:.2f}" + elif abs(real) < 1e-10: + return f"{imag:.2f}i" + else: + sign = "+" if imag >= 0 else "-" + return f"{real} {sign} {abs(imag)}i" + + def __getitem__(self, idx: int) -> dict: + rng = random.Random(self.seed + idx) + + # Choose random operation + op = rng.choice(self.config.operations) + + if op == "/": + # For division, first generate the quotient (a) and divisor (b) + # Then calculate the dividend (result) as a * b + a = self._generate_complex(rng) # This will be the final result + b = self._generate_complex(rng) + while b == 0: # Ensure non-zero divisor + b = self._generate_complex(rng) + result = a # Store the intended result + a = result * b # Calculate dividend to ensure whole number 
division + else: + # For other operations, generate numbers normally + a = self._generate_complex(rng) + b = self._generate_complex(rng) + + # Calculate result + if op == "+": + result = a + b + elif op == "-": + result = a - b + else: # op == "*" + result = a * b + + question = self._prompt_templates[op].format(a=self._format_complex(a), b=self._format_complex(b)) + + return { + "question": question, + "answer": self._format_complex(result), + "metadata": { + "num1": (a.real, a.imag), + "num2": (b.real, b.imag), + "operation": op, + "result": (int(result.real), int(result.imag)), # Convert to int since we ensure whole numbers + }, + } + + @staticmethod + def parse_string_to_complex(answer: str) -> complex: + try: + # Normalize the answer string by removing spaces and converting to lowercase + answer = answer.replace(" ", "").lower() + # Convert mathematical notation 'i' to Python's 'j' for complex numbers + answer = answer.replace("i", "j") + + # Handle real numbers (no imaginary part) + if "j" not in answer: + student_result = complex(float(answer)) + else: + # Handle cases like "j" or "2j" (implicit coefficient) + if answer[0] == "j": + # Convert "j" to "1j", "2j" remains unchanged + answer = "1" + answer + # Handle cases like "3j" where there's no explicit + or - before j + elif answer[-1] == "j" and not any(c in answer[:-1] for c in "+-"): + # Convert "3j" to "3+1j" + answer = answer.replace("j", "+1j") + + # Ensure the string has an imaginary part, even if zero + if "j" not in answer: + answer += "+0j" + + # Parse the normalized string into a complex number + student_result = complex(answer) + + except ValueError: + return None + + return student_result + + def score_answer(self, answer: str, metadata: dict) -> float: + """Score the answer using exponential distance-based scoring.""" + if answer is None: + return 0.0 + + try: + student_result = self.parse_string_to_complex(answer) + expected_result = complex(*metadata["result"]) + # Calculate distance-based 
score using exponential decay + distance = abs(student_result - expected_result) + score = min(1.0, math.exp(-distance)) # Add 'import math' at the top + return score + + except (ValueError, TypeError): + return 0.0 + + +register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig) diff --git a/reasoning_gym/algebra/polynomial_multiplication.py b/reasoning_gym/algebra/polynomial_multiplication.py new file mode 100644 index 00000000..9bcadc66 --- /dev/null +++ b/reasoning_gym/algebra/polynomial_multiplication.py @@ -0,0 +1,161 @@ +import random +import string +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +import sympy as sp +from sympy import Eq, Symbol, expand, solve + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class PolynomialMultiplicationConfig: + """ + Configuration for polynomial multiplication task generation. + """ + + min_terms: int = 2 # Minimum number of polynomial terms + max_terms: int = 4 # Maximum number of polynomial terms + min_value: int = 1 # Minimum value for coefficients + max_value: int = 100 # Maximum value for coefficients + min_degree: int = 1 # Minimum polynomial degree + max_degree: int = 3 # Maximum polynomial degree + min_polynomials: int = 2 # Minimum number of polynomials being multiplied + max_polynomials: int = 3 # Maximum number of polynomials being multiplied + single_variable: bool = (True,) + operators: Tuple[str, ...] = ( + "+", + "-", + ) # Allowed operators between terms, Avoid adding '*' or '/' because they will affect the degree + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters.""" + assert self.min_terms > 0, "min_terms must be positive." + assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms." + + assert self.min_value > 0, "min_value must be positive." + assert self.max_value >= self.min_value, "max_value must be >= min_value." 
+ + assert self.min_degree >= 1, "min_degree must be >= 1." + assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree." + + assert self.min_polynomials >= 2, "min_polynomials must be >= 2." + assert self.max_polynomials >= self.min_polynomials, "max_polynomials must be >= min_polynomials." + + allowed_ops = {"+", "-"} + assert len(self.operators) > 0, "operators tuple cannot be empty." + assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}." + + +class PolynomialMultiplicationDataset(ProceduralDataset): + """ + Generates [min_polynomials, max_polynomials] random polynomials of degree in [min_degree, max_degree]. + - The polynomial is formed by summing random terms of the form: coeff * x^exponent. + - Then we find "F = P_0 * ... * P_1" using Sympy. + """ + + def __init__(self, config: PolynomialMultiplicationConfig): + self._prompt_templates = [ + "Simplify this expression: {polynomial_expr}", + "Calculate the following: {polynomial_expr}", + ] + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """ + Generate a single polynomial multiplication item. + + Returns: + A dict with: + - question: str (e.g. "Multiply polynomials: (8x^3 + x + 2)*(x - 3)") + - answer: str (Product, e.g. 
"8x^4 - 24x^3 + x^2 - x - 6") + - metadata: dict with details (polynomial_expr, single_variable) + """ + rng = random.Random(self.seed + idx) + number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials) + polynomials = [self._generate_polynomial_expr(rng) for i in range(number_polynomials)] + + polynomial_expr = sp.prod(polynomials) + product = sp.expand(polynomial_expr) + + return { + "question": rng.choice(self._prompt_templates).format( + polynomial_expr=polynomial_expr, + ), + "answer": product, + "metadata": { + "polynomial_expr": str(polynomial_expr), + "single_variable": self.config.single_variable, + "result": str(product), + }, + } + + def _get_variable(self, rng: random.Random) -> str: + """Get a random lowercase variable name""" + if self.config.single_variable: + return "x" + return rng.choice(string.ascii_lowercase) + + def _generate_polynomial_expr(self, rng: random.Random): + """ + Randomly generate a polynomial expression of 'degree'. + We'll use the config parameters: + - min_terms, max_terms: how many total terms to combine + - min_value, max_value: range for coefficients + - operators: to decide sign flips or direct addition + + Args: + rng: Random number generator + + Returns: + Polynomial string + """ + variable = self._get_variable(rng) + degree = rng.randint(self.config.min_degree, self.config.max_degree) + + x = Symbol(variable) + + # Choose the number of terms and their respective degrees + num_terms = rng.randint(self.config.min_terms, self.config.max_terms) + # Keep track of exponents, exponents can repeat or skip but we force the highest exponent + chosen_exponents = [degree] + # Fill the rest randomly in [0, degree] + for _ in range(num_terms - 1): + exp = rng.randint(0, degree) + chosen_exponents.append(exp) + + # Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign + polynomial_expr = 0 + for exp in chosen_exponents: + coeff = rng.randint(self.config.min_value, 
self.config.max_value) + # If '-' in operators, we can randomly flip the sign + if "-" in self.config.operators and rng.random() < 0.5: + coeff = -coeff + term_expr = coeff * (x**exp) + polynomial_expr += term_expr + + return polynomial_expr + + def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float: + reward = 0.0 + if answer is not None: + try: + predicted_poly = sp.parse_expr(answer) + target_poly = sp.parse_expr(metadata["result"]) + + # Check if the difference simplifies to zero (i.e. they are equivalent). + if sp.simplify(predicted_poly - target_poly) == 0: + reward = 1.0 + elif answer.strip(): + reward = 0.05 + else: + reward = 0.01 + except Exception: + reward = 0.01 + return reward + + +register_dataset("polynomial_multiplication", PolynomialMultiplicationDataset, PolynomialMultiplicationConfig) diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index e5326af5..4c33d08f 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -9,6 +9,7 @@ Algorithmic tasks for training reasoning capabilities: from .base_conversion import BaseConversionConfig, BaseConversionDataset from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset +from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset from .letter_counting import LetterCountingConfig, LetterCountingDataset from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset @@ -51,4 +52,6 @@ __all__ = [ "GroupAnagramsDataset", "RansomNoteConfig", "RansomNoteDataset", + "IsomorphicStringsConfig", + "IsomorphicStringsDataset", ] diff --git a/reasoning_gym/algorithmic/isomorphic_strings.py b/reasoning_gym/algorithmic/isomorphic_strings.py new file mode 100644 index 00000000..3b4a59e5 --- /dev/null +++ 
b/reasoning_gym/algorithmic/isomorphic_strings.py @@ -0,0 +1,121 @@ +"""Check if two strings are isomorphic. + +Two strings are isomorphic if the characters in one string can be replaced to get the second string. + +A popular Leetcode problem: +https://leetcode.com/problems/isomorphic-strings/description/ +""" + +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """Two strings are isomorphic if the characters in one string can be replaced to get the second string. + +All occurrences of a character must be replaced with another character while preserving the order of characters. + +No two characters may map to the same character, but a character may map to itself. + +Example 1: +Input: egg add +Output: True +Explanation: The strings s and t can be made identical by: + - Mapping 'e' to 'a'. + - Mapping 'g' to 'd'. + +Example 2: +Input: foo bar +Output: False +Explanation: + - The strings cannot be made identical as 'o' needs to be mapped to both 'a' and 'r'. 
class IsomorphicStringsDataset(ProceduralDataset):
    """Generates Isomorphic Strings exercises with configurable difficulty

    Every item presents two lowercase strings and asks whether a one-to-one
    character mapping turns the first into the second. Items are derived
    from ``seed + idx`` only, so a given seed always yields the same pair.
    """

    def __init__(self, config: IsomorphicStringsConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)
        self.letters = {chr(i) for i in range(ord("a"), ord("z") + 1)}

    def _check_isomorphic(self, s: str, t: str) -> bool:
        """Return True when ``s`` and ``t`` are isomorphic.

        Strings of different length are never isomorphic. Otherwise every
        character of ``s`` must always pair with the same character of ``t``
        and vice versa.
        """
        if len(s) != len(t):
            return False

        forward, backward = {}, {}  # s -> t, t -> s
        for char_s, char_t in zip(s, t):
            # setdefault records the first pairing and returns it afterwards,
            # so any later disagreement shows up as a mismatch.
            if forward.setdefault(char_s, char_t) != char_t:
                return False
            if backward.setdefault(char_t, char_s) != char_s:
                return False
        return True

    def _generate_inputs(self, rng: Random, solvable: bool) -> tuple[str, str]:
        """Generate the two input strings.

        Args:
            rng: Seeded generator; all randomness comes from it.
            solvable: When False, one extra position is inserted that maps an
                already-used source letter to a different target letter,
                which makes the pair non-isomorphic.

        Returns:
            The pair ``(s, t)`` of equal-length lowercase strings.
        """
        s, t = [], []
        mapping = {}

        # Iteration order of a set of strings changes between interpreter
        # runs (hash randomisation), so candidates are always sorted before
        # being handed to rng.choice; otherwise the same seed would not
        # reproduce the same strings.
        alphabet = sorted(self.letters)

        # Generate a valid isomorphic pair first (leave one character for potential conflict)
        for _ in range(rng.randint(1, self.config.max_string_length - 1)):
            char_s = rng.choice(alphabet)
            if char_s not in mapping:
                # Choose a random character that is not already mapped
                unused_targets = sorted(self.letters - set(mapping.values()))
                mapping[char_s] = rng.choice(unused_targets)
            char_t = mapping[char_s]
            s.append(char_s)
            t.append(char_t)

        if not solvable:
            # Solution should be unsolvable, create conflict
            letter = rng.choice(list(mapping.keys()))  # dict keeps insertion order
            conflict = rng.choice(sorted(self.letters - {mapping[letter]}))
            insert_idx = rng.randint(0, len(s))
            s.insert(insert_idx, letter)
            t.insert(insert_idx, conflict)

        return "".join(s), "".join(t)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Isomorphic Strings question

        Returns:
            dict with "question" (formatted prompt), "answer" ("True" or
            "False") and "metadata" holding the two words, the boolean
            solution and the ``solvable`` flag used during generation.
        """
        rng = Random(self.seed + idx)

        solvable = rng.random() < self.config.p_solvable
        s, t = self._generate_inputs(rng, solvable)
        answer = self._check_isomorphic(s, t)

        return {
            "question": QUESTION_TEMPLATE.format(s=s, t=t),
            "answer": str(answer),
            "metadata": {"words": [s, t], "solution": answer, "solvable": solvable},
        }
class TsumegoDataset(ProceduralDataset):
    """Generates Tsumego problems with configurable parameters

    Boards are lists of rows; each cell is "." (empty), "X" (black) or "O"
    (white). Coordinates are ``(row, col)`` with row 0 printed at the top.
    """

    def __init__(self, config: TsumegoConfig):
        self._prompt_templates = [
            "Tsumego time. Black to play and capture some stones.\nFind the key move.",
            "I have a Go problem for you. Black moves next - can you capture some of the white stones?",
            "Here's a Go challenge. Playing as Black, how can you capture as many white stones as possible?",
        ]
        # Point that may not be replayed immediately (simple ko), or None.
        self._ko_point: Optional[Tuple[int, int]] = None
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _copy_board(self, board: List[List[str]]) -> List[List[str]]:
        """Return a deep copy of the board."""
        return [row[:] for row in board]

    def _get_liberties(self, board: List[List[str]], row: int, col: int) -> Set[Tuple[int, int]]:
        """Get empty adjacent points (liberties) for a stone"""
        size = len(board)
        liberties = set()
        for dr, dc in DIRECTIONS:
            r, c = row + dr, col + dc
            if 0 <= r < size and 0 <= c < size and board[r][c] == ".":
                liberties.add((r, c))
        return liberties

    def _get_group(self, board: List[List[str]], row: int, col: int) -> Set[Tuple[int, int]]:
        """Get all stones in the same group (connected stones of same color)

        Returns an empty set when the starting point is empty.
        """
        size = len(board)
        color = board[row][col]
        if color == ".":
            return set()

        group = {(row, col)}
        # Flood fill; visiting order does not matter because only the set of
        # members is returned, so a stack (O(1) pop) is used.
        pending = [(row, col)]
        while pending:
            r, c = pending.pop()
            for dr, dc in DIRECTIONS:
                nr, nc = r + dr, c + dc
                if 0 <= nr < size and 0 <= nc < size and board[nr][nc] == color and (nr, nc) not in group:
                    group.add((nr, nc))
                    pending.append((nr, nc))
        return group

    def _count_liberties(self, board: List[List[str]], group: Set[Tuple[int, int]]) -> int:
        """Count total liberties for a group of stones"""
        liberties = set()
        for row, col in group:
            liberties.update(self._get_liberties(board, row, col))
        return len(liberties)

    def _would_capture(self, board: List[List[str]], row: int, col: int, color: str) -> bool:
        """Check if a move would capture any opponent stones

        The move is simulated on a copy; ``board`` itself is not modified.
        """
        size = len(board)
        opponent = "O" if color == "X" else "X"

        # Make a copy of the board and place the stone
        board_copy = self._copy_board(board)
        board_copy[row][col] = color

        checked = set()
        for dr, dc in DIRECTIONS:
            r, c = row + dr, col + dc
            if 0 <= r < size and 0 <= c < size and board_copy[r][c] == opponent and (r, c) not in checked:
                group = self._get_group(board_copy, r, c)
                checked.update(group)
                if self._count_liberties(board_copy, group) == 0:
                    return True
        return False

    def _is_valid_move(self, board: List[List[str]], row: int, col: int, color: str) -> bool:
        """Check if a move is legal (not suicide, unless it captures)

        A move is rejected when it is off the board, on an occupied point or
        on the current ko point.
        """
        size = len(board)
        if not (0 <= row < size and 0 <= col < size):
            return False
        if board[row][col] != ".":
            return False
        if (row, col) == self._ko_point:
            return False

        # If the move captures opponent stones, it's valid
        if self._would_capture(board, row, col, color):
            return True

        board_copy = self._copy_board(board)
        board_copy[row][col] = color
        group = self._get_group(board_copy, row, col)
        return self._count_liberties(board_copy, group) > 0

    def _make_move(self, board: List[List[str]], row: int, col: int, color: str) -> bool:
        """Make a move and update ko point. Returns True if move was valid.

        On success the stone is placed on ``board`` in place and every
        adjacent opponent group left without liberties is removed.
        """
        if not self._is_valid_move(board, row, col, color):
            return False

        self._ko_point = None
        board[row][col] = color
        opponent = "O" if color == "X" else "X"
        captured_stones = []

        for dr, dc in DIRECTIONS:
            r, c = row + dr, col + dc
            if 0 <= r < len(board) and 0 <= c < len(board) and board[r][c] == opponent:
                group = self._get_group(board, r, c)
                if self._count_liberties(board, group) == 0:
                    captured_stones.extend(group)

        # A lone stone capturing exactly one stone creates a ko at that point.
        if len(captured_stones) == 1 and len(self._get_group(board, row, col)) == 1:
            self._ko_point = captured_stones[0]

        for r, c in captured_stones:
            board[r][c] = "."

        return True

    def _generate_capture_problem(self, size: int, rng: Random) -> Tuple[List[List[str]], Tuple[int, int]]:
        """Generate a capture problem

        First scatters ``config.max_stones - 4`` legal random moves, then
        tries up to 50 times to drop a white formation on empty points, fill
        all but one of its liberties with black stones and optionally add a
        decoy stone nearby.

        Returns:
            ``(board, key_move)`` where playing black at ``key_move`` is
            legal and captures white stones.

        Raises:
            RuntimeError: if no attempt produced such a position.
        """
        board = [["." for _ in range(size)] for _ in range(size)]
        stones_placed = 0
        max_stones = self.config.max_stones - 4  # Reserve space for capture setup

        while stones_placed < max_stones:
            row = rng.randint(0, size - 1)
            col = rng.randint(0, size - 1)
            color = "X" if rng.random() < 0.5 else "O"
            if board[row][col] == "." and self._is_valid_move(board, row, col, color):
                self._make_move(board, row, col, color)
                stones_placed += 1

        formation_options = {
            "plus": {
                "white_offsets": [(0, 0), (-1, 0), (1, 0), (0, -1)],
                "forced_move_offset": (0, 1),
                "neighbor_offsets": [(0, 0), (-1, 0), (1, 0), (0, -1), (0, 1)],
            },
            "L": {
                "white_offsets": [(0, 0), (0, 1), (1, 0)],
                "forced_move_offset": (1, 1),
                "neighbor_offsets": [(0, 0), (0, 1), (1, 0), (1, 1)],
            },
            "T": {
                "white_offsets": [(0, -1), (0, 0), (0, 1), (1, 0)],
                "forced_move_offset": (-1, 0),
                "neighbor_offsets": [(0, -1), (0, 0), (0, 1), (1, 0), (-1, 0)],
            },
        }

        for _ in range(50):
            row = rng.randint(1, size - 2)
            col = rng.randint(1, size - 2)
            formation_type = rng.choice(list(formation_options.keys()))
            formation = formation_options[formation_type]
            if not all(board[row + dr][col + dc] == "." for dr, dc in formation["neighbor_offsets"]):
                continue

            # Build the attempt on a copy so that a rejected attempt leaves no
            # half-finished formation behind for the next one.
            trial = self._copy_board(board)

            # Place white stones according to chosen formation
            for dr, dc in formation["white_offsets"]:
                trial[row + dr][col + dc] = "O"
            forced_move = (row + formation["forced_move_offset"][0], col + formation["forced_move_offset"][1])
            white_group = {(row + dr, col + dc) for dr, dc in formation["white_offsets"]}
            extra_liberties = set()
            for r, c in white_group:
                extra_liberties |= self._get_liberties(trial, r, c)
            extra_liberties.discard(forced_move)
            for r, c in extra_liberties:
                trial[r][c] = "X"

            # Add decoy stone to enhance puzzle difficulty
            current_stone_count = sum(cell in "XO" for line in trial for cell in line)
            if current_stone_count < self.config.max_stones + 7:
                decoy_candidates = []
                # Points at Manhattan distance 2 from the base white stone.
                for i in range(row - 2, row + 3):
                    for j in range(col - 2, col + 3):
                        if abs(i - row) + abs(j - col) == 2:
                            if 0 <= i < size and 0 <= j < size and trial[i][j] == "." and (i, j) != forced_move:
                                decoy_candidates.append((i, j))
                if decoy_candidates:
                    decoy_pos = rng.choice(decoy_candidates)
                    decoy_color = "X" if rng.random() < 0.5 else "O"
                    trial[decoy_pos[0]][decoy_pos[1]] = decoy_color

            # Legality alone is not enough: a white stone from the random
            # phase touching the formation gives the group extra liberties,
            # in which case the "key move" would not capture anything.
            if self._is_valid_move(trial, forced_move[0], forced_move[1], "X") and self._would_capture(
                trial, forced_move[0], forced_move[1], "X"
            ):
                return trial, forced_move
        raise RuntimeError("Failed to generate a capture problem")

    def _board_to_string(self, board: List[List[str]]) -> str:
        """Convert board to string representation"""
        size = len(board)
        # Column labels; the three-character gutter matches the "{:2d} " row prefix
        cols = "   " + " ".join(chr(ord("A") + i) for i in range(size)) + "\n"
        # Board with row numbers
        rows = [f"{size-i:2d} {' '.join(row)}" for i, row in enumerate(board)]
        return cols + "\n".join(rows)

    @staticmethod
    def _parse_move(text: str, board_size: int) -> Optional[Tuple[int, int]]:
        """Turn a move such as "C4" into ``(row, col)``; None if malformed."""
        m = re.match(r"^([A-Za-z])(\d+)$", text.strip())
        if not m:
            return None
        col_letter, row_str = m.group(1), m.group(2)
        return board_size - int(row_str), ord(col_letter.upper()) - ord("A")

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Tsumego problem

        Returns:
            dict with:
            - "question": Problem description and board state
            - "answer": Solution move(s)
            - "metadata": Problem details and configuration
        """
        rng = Random(self.seed + idx if self.seed is not None else None)
        size = rng.randint(self.config.min_board_size, self.config.max_board_size)

        # Start from a clean ko state even if an earlier generation raised.
        self._ko_point = None
        board, solution = self._generate_capture_problem(size, rng)
        board_str = self._board_to_string(board)
        solution_str = f"{chr(ord('A')+solution[1])}{size - solution[0]}"
        self._ko_point = None

        return {
            "question": (
                rng.choice(self._prompt_templates) + "\n\n" + board_str + "\n\n"
                "X - Black\n"
                "O - White\n\n"
                "Specify your move in coordinates (e.g. 'C4' for column C, row 4)"
            ),
            "answer": solution_str,
            "metadata": {"difficulty": {"board_size": size}, "board": board, "solution": solution_str},
        }

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Score the answer against the solution

        ``entry["metadata"]["solution"]`` is the coordinate string written by
        ``__getitem__`` (for example "C4"); a ``(row, col)`` pair is accepted
        as well.

        Returns:
            1.0 for the key move, 0.05 for any other on-board point, 0.01 for
            a blank, malformed or off-board answer, 0.0 when ``answer`` is
            None.
        """
        if answer is None:
            return 0.0
        answer = answer.strip()
        if not answer:
            return 0.01
        metadata = entry["metadata"]
        board_size = len(metadata["board"])

        solution = metadata["solution"]
        if isinstance(solution, str):
            expected = self._parse_move(solution, board_size)
        else:
            expected = tuple(solution)

        played = self._parse_move(answer, board_size)
        if played is None:
            return 0.01
        if played == expected:
            return 1.0

        row, col = played
        if 0 <= row < board_size and 0 <= col < board_size:
            return 0.05
        return 0.01
def generate_dynamic_puzzle(difficulty, rng):
    """
    Build a 7-statement self-referential puzzle around a planted truth count.

    Statements:
      1. "At least a of these 7 statements are true."
      2. "At most b of these 7 statements are false."
      3. "Exactly c of these 7 statements are true."
      4. "Exactly d of these 7 statements are false."
      5. "Either Statement 3 or Statement 4 is true, but not both."
      6. "The number of true statements is a prime number."
      7. "The number of false statements is a composite number."

    A target number T of true statements is drawn from 1..6. The intended
    truth of statements 6 and 7 follows from T; the remaining true
    statements are spread at random over statements 1-5, and a, b, c, d are
    then drawn so that statements 1-4 evaluate to their intended truth value
    for that T.

    Args:
        difficulty: Stored in the returned dict under "difficulty"; it does
            not influence how the parameters are drawn.
        rng: Source of all randomness (``choice``, ``sample`` and ``randint``
            are used).

    Returns:
        dict with keys 'n', 'statements_text', 'parameters' (a, b, c, d),
        'intended_assignment' (7 booleans), 'intended_T' and 'difficulty'.
    """
    n = 7

    # Target number of true statements, 1..6.
    T = rng.choice(range(1, n))

    # Statements 6 and 7 are determined by T alone.
    intended6 = is_prime(T)
    intended7 = is_composite(n - T)

    # k = how many of statements 1-5 have to be true to reach T overall.
    k = T - (int(intended6) + int(intended7))
    if 0 <= k <= 5:
        intended_assignment_15 = [False] * 5
        if k > 0:
            for position in set(rng.sample(range(5), k)):
                intended_assignment_15[position] = True
    else:
        # Out-of-range k: use the fixed T=4 configuration instead.
        T = 4
        intended6 = False
        intended7 = False
        k = 4
        intended_assignment_15 = [True, True, True, True, False]

    want_s1, want_s2, want_s3, want_s4 = intended_assignment_15[:4]
    false_total = n - T

    # Statement 1 holds exactly when T >= a.
    if want_s1:
        a_param = rng.randint(1, T)
    elif T + 1 > n:
        a_param = n
    else:
        a_param = rng.randint(T + 1, n)

    # Statement 2 holds exactly when (n - T) <= b.
    if want_s2:
        b_param = rng.randint(false_total, n)
    else:
        b_param = rng.randint(0, max(false_total - 1, 0))

    # Statement 3 holds exactly when T == c.
    if want_s3:
        c_param = T
    else:
        c_param = rng.choice([value for value in range(0, n + 1) if value != T])

    # Statement 4 holds exactly when (n - T) == d.
    if want_s4:
        d_param = false_total
    else:
        d_param = rng.choice([value for value in range(0, n + 1) if value != false_total])

    # Statement 5 has no numeric parameter; its intended value comes straight
    # from the random assignment above.
    intended_assignment = list(intended_assignment_15) + [intended6, intended7]

    # If the planted values do not add up to T, statement 5 is the only slot
    # without a parameter, so it absorbs the difference.
    current_T = sum(intended_assignment)
    if current_T != T:
        true_without_s5 = current_T - (1 if intended_assignment[4] else 0)
        intended_assignment[4] = (T - true_without_s5) == 1

    statements_text = [
        f"Statement 1: 'At least {a_param} of these 7 statements are true.'",
        f"Statement 2: 'At most {b_param} of these 7 statements are false.'",
        f"Statement 3: 'Exactly {c_param} of these 7 statements are true.'",
        f"Statement 4: 'Exactly {d_param} of these 7 statements are false.'",
        "Statement 5: 'Either Statement 3 or Statement 4 is true, but not both.'",
        "Statement 6: 'The number of true statements is a prime number.'",
        "Statement 7: 'The number of false statements is a composite number.'",
    ]

    parameters = {
        "a": a_param,
        "b": b_param,
        "c": c_param,
        "d": d_param,
    }
    return {
        "n": n,
        "statements_text": statements_text,
        "parameters": parameters,
        "intended_assignment": intended_assignment,
        "intended_T": T,
        "difficulty": difficulty,
    }
def print_puzzle_dynamic(puzzle):
    """Render the puzzle's statements as a bulleted block of text.

    Nothing is written to stdout: every entry of ``puzzle["statements_text"]``
    becomes one " - <statement>" line terminated by a newline, and the
    concatenated text is returned (empty string for no statements).
    """
    return "".join(" - " + stmt + "\n" for stmt in puzzle["statements_text"])
class SelfReferenceDataset(ProceduralDataset):
    """Generates self-referential puzzles"""

    def __init__(self, config: SelfReferenceConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single SelfReference task

        Returns:
            dict with keys:
                - question: str, the task description
                - answer: int, the number of self-consistent truth assignments
                - metadata: dict, currently empty
        """
        rng = Random(self.seed + idx)

        # Generate puzzle
        puzzle = generate_dynamic_puzzle(self.config.difficulty, rng)
        puzz_s = (
            "Given the truthfulness of these statements, please tell me the number of possible solutions: \n"
            + print_puzzle_dynamic(puzzle)
        )

        # Solve puzzle: only the count of consistent assignments is asked for.
        solutions = solve_puzzle_dynamic(puzzle)
        answer = len(solutions)

        return {
            "question": puzz_s,
            "answer": answer,
            "metadata": {},
        }

    def score_answer(self, answer: Optional[str], entry: Dict[str, object]) -> float:
        """Determine if the solution provided solves the SelfReference task.

        Args:
            answer (Optional[str]): The user's answer.
            entry (Dict[str, object]): The original dataset entry containing the correct answer.

        Returns:
            float: 0.0 when no answer was given, 1.0 when the answer equals
            the stored one (compared as text, ignoring surrounding
            whitespace), 0.1 otherwise.
        """
        if answer is None:
            return 0.0
        if str(answer).strip() != str(entry["answer"]).strip():
            return 0.1
        return 1.0
valid using classical logic rules. - - Rules implemented: - 1. Universal Affirmative (ALL): - - If both premises are ALL, conclusion must be ALL - - ALL A are B + ALL B are C → ALL A are C (Barbara) - - 2. Universal Negative (NO): - - If one premise is NO and other is ALL, conclusion must be NO - - NO A are B + ALL C are B → NO A are C (Celarent) - - ALL A are B + NO C are B → NO A are C (Cesare) - - 3. Particular Affirmative (SOME): - - If one premise is SOME and other is ALL, conclusion must be SOME - - SOME A are B + ALL B are C → SOME A are C (Darii) - - ALL A are B + SOME C are B → SOME A are C (Disamis) - - 4. Particular Negative (SOME_NOT): - - If one premise is SOME_NOT and other is ALL, conclusion can be SOME_NOT - - SOME A are not B + ALL B are C → SOME A are not C (Ferio) - - ALL A are B + SOME C are not B → SOME A are not C (Festino) - - 5. Invalid combinations: - - Two negative premises never yield a valid conclusion - - Two particular premises never yield a valid conclusion - - If both premises are particular, no valid conclusion - - If conclusion is universal but either premise is particular, invalid + Checks whether a given syllogism is valid under classical (Aristotelian) rules, + including the distribution rule: + - If a term is distributed in the conclusion, it must be distributed + in the premise where it appears as subject/predicate. 
""" - q1, t1_1, t1_2 = premise1 - q2, t2_1, t2_2 = premise2 - qc, tc_1, tc_2 = conclusion - # Rule 5: Two negative premises -> invalid - if q1 in (Quantifier.NO, Quantifier.SOME_NOT) and q2 in (Quantifier.NO, Quantifier.SOME_NOT): + # --- 1) Extract data --- + q1, p1_subj, p1_pred = premise1 + q2, p2_subj, p2_pred = premise2 + q3, c_subj, c_pred = conclusion + + negative_set = {Quantifier.NO, Quantifier.SOME_NOT} + particular_set = {Quantifier.SOME, Quantifier.SOME_NOT} + universal_set = {Quantifier.ALL, Quantifier.NO} + + # --- 2) Identify a unique middle term --- + premise1_terms = {p1_subj, p1_pred} + premise2_terms = {p2_subj, p2_pred} + common_terms = premise1_terms.intersection(premise2_terms) + + if len(common_terms) != 1: + return False + middle_term = next(iter(common_terms)) + + # Gather all terms => must be exactly 3 distinct terms + all_terms = premise1_terms.union(premise2_terms) + if len(all_terms) != 3: return False - # Rule 5: Two particular premises -> invalid - if q1 in (Quantifier.SOME, Quantifier.SOME_NOT) and q2 in (Quantifier.SOME, Quantifier.SOME_NOT): + # The conclusion must use the other two terms (not the middle) + other_two = all_terms - {middle_term} + conclusion_terms = {c_subj, c_pred} + if conclusion_terms != other_two: return False - # Rule 5: Universal conclusion with particular premise -> invalid - if qc in (Quantifier.ALL, Quantifier.NO) and ( - q1 in (Quantifier.SOME, Quantifier.SOME_NOT) or q2 in (Quantifier.SOME, Quantifier.SOME_NOT) - ): + # --- 3) Identify which premise is major vs. 
minor --- + def premise_contains(premise, term): + return (premise[1] == term) or (premise[2] == term) + + if premise_contains(premise1, c_pred): + major = premise1 + minor = premise2 + elif premise_contains(premise2, c_pred): + major = premise2 + minor = premise1 + else: return False - # Rule 1: Barbara syllogism - if q1 == Quantifier.ALL and q2 == Quantifier.ALL: - if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2: - return qc == Quantifier.ALL + # The minor premise must contain the conclusion's subject + if not premise_contains(minor, c_subj): + return False - # Rule 2: Celarent syllogism - if q1 == Quantifier.NO and q2 == Quantifier.ALL: - if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2: - return qc == Quantifier.NO + # --- 4) Quick checks (traditional “no two negative,” etc.) --- + if (q1 in negative_set) and (q2 in negative_set): + return False + if (q1 in particular_set) and (q2 in particular_set): + return False + if q3 in universal_set: + if (q1 in particular_set) or (q2 in particular_set): + return False + if q3 in negative_set: + if not ((q1 in negative_set) or (q2 in negative_set)): + return False - # Rule 2: Cesare syllogism - if q1 == Quantifier.ALL and q2 == Quantifier.NO: - if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2: - return qc == Quantifier.NO + # --- 5) Distribution checks --- + def distribution(q: Quantifier): + if q == Quantifier.ALL: # A + return (True, False) + elif q == Quantifier.NO: # E + return (True, True) + elif q == Quantifier.SOME: # I + return (False, False) + elif q == Quantifier.SOME_NOT: # O + return (False, True) + else: + raise ValueError(f"Unknown quantifier: {q}") - # Rule 3: Darii syllogism - if q1 == Quantifier.SOME and q2 == Quantifier.ALL: - if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2: - return qc == Quantifier.SOME + # Conclusion distribution + dist_c_subj, dist_c_pred = distribution(q3) - # Rule 3: Disamis syllogism - if q1 == Quantifier.ALL and q2 == Quantifier.SOME: - if t1_2 == t2_1 and tc_1 == t1_1 and 
tc_2 == t2_2: - return qc == Quantifier.SOME + # Major premise distribution + q_major, major_subj, major_pred = major + dist_major_subj, dist_major_pred = distribution(q_major) - # Rule 4: Ferio syllogism - if q1 == Quantifier.SOME_NOT and q2 == Quantifier.ALL: - if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2: - return qc == Quantifier.SOME_NOT + # Minor premise distribution + q_minor, minor_subj, minor_pred = minor + dist_minor_subj, dist_minor_pred = distribution(q_minor) - # Rule 4: Festino syllogism - if q1 == Quantifier.ALL and q2 == Quantifier.SOME_NOT: - if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2: - return qc == Quantifier.SOME_NOT + # If the conclusion's subject is distributed, check it in the minor premise + if dist_c_subj: + if c_subj == minor_subj: + if not dist_minor_subj: + return False + elif c_subj == minor_pred: + if not dist_minor_pred: + return False - return False + # If the conclusion's predicate is distributed, check it in the major premise + if dist_c_pred: + if c_pred == major_subj: + if not dist_major_subj: + return False + elif c_pred == major_pred: + if not dist_major_pred: + return False + + # If either premise is negative, the conclusion must be negative. 
+ if (q1 in negative_set) or (q2 in negative_set): + if q3 not in negative_set: + return False + + # If all checks pass, it's valid + return True def _format_quantifier_statement(self, quantifier: Quantifier, subject: Term, predicate: Term) -> str: """Format a quantified statement in natural language""" @@ -219,18 +248,29 @@ class SyllogismDataset(ProceduralDataset): terms = rng.sample(self.terms, 3) quantifiers = self._get_allowed_quantifiers() - # Generate premises and conclusion - premise1 = (rng.choice(quantifiers), terms[0], terms[1]) - premise2 = (rng.choice(quantifiers), terms[1], terms[2]) - conclusion = (rng.choice(quantifiers), terms[0], terms[2]) + target_valid = rng.random() > self.config.invalid_ratio # Invert ratio to match meaning + max_attempts = 100 + attempts = 0 - # Decide if this should be a valid or invalid syllogism - is_valid = True - if self.config.include_invalid and rng.random() < self.config.invalid_ratio: - is_valid = False - # If should be invalid, regenerate conclusion until invalid - while self._is_valid_syllogism(premise1, premise2, conclusion): - conclusion = (rng.choice(quantifiers), terms[0], terms[2]) + while attempts < max_attempts: + # Generate premises and conclusion + premise1 = (rng.choice(quantifiers), terms[0], terms[1]) + premise2 = (rng.choice(quantifiers), terms[1], terms[2]) + conclusion = (rng.choice(quantifiers), terms[0], terms[2]) + + # Check if validity matches target + is_valid = self._is_valid_syllogism(premise1, premise2, conclusion) + if is_valid == target_valid: + break + + attempts += 1 + + if attempts >= max_attempts: + # If we couldn't find a matching syllogism, return a basic valid one + premise1 = (Quantifier.ALL, terms[0], terms[1]) + premise2 = (Quantifier.ALL, terms[1], terms[2]) + conclusion = (Quantifier.ALL, terms[0], terms[2]) + is_valid = True # Format the syllogism as text premise1_text = self._format_quantifier_statement(premise1[0], premise1[1], premise1[2]) diff --git 
a/tests/test_complex_arithmetic.py b/tests/test_complex_arithmetic.py new file mode 100644 index 00000000..0d369fc1 --- /dev/null +++ b/tests/test_complex_arithmetic.py @@ -0,0 +1,90 @@ +import pytest + +from reasoning_gym.algebra.complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset + + +def test_complex_arithmetic_basic(): + """Test basic functionality of complex arithmetic dataset.""" + config = ComplexArithmeticConfig( + min_real=-5, max_real=5, min_imag=-5, max_imag=5, operations=("+", "-", "*", "/"), seed=42, size=10 + ) + dataset = ComplexArithmeticDataset(config) + + print(dataset) + + # Test dataset size + assert len(dataset) == 10 + + # Test a specific item + item = dataset[0] + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Add more detailed assertions + assert isinstance(item["question"], str) + assert isinstance(item["answer"], str) + assert isinstance(item["metadata"], dict) + + # Check metadata structure + assert "num1" in item["metadata"] + assert "num2" in item["metadata"] + assert "operation" in item["metadata"] + assert "result" in item["metadata"] + + # Check data types in metadata + assert isinstance(item["metadata"]["num1"], tuple) + assert isinstance(item["metadata"]["num2"], tuple) + assert len(item["metadata"]["num1"]) == 2 # Real and imaginary parts + assert len(item["metadata"]["num2"]) == 2 + assert isinstance(item["metadata"]["operation"], str) + assert isinstance(item["metadata"]["result"], tuple) + + # Make sure answer matches the result in metadata + # results is a tuple of two floats (real, imag) and answer is a string + # answer is formatted as "real + imagi" + assert ComplexArithmeticDataset.parse_string_to_complex(item["answer"]) == complex(*item["metadata"]["result"]) + + +def test_complex_arithmetic_scoring(): + """Test scoring function with various answer formats and accuracies.""" + config = ComplexArithmeticConfig(seed=42) + dataset = 
ComplexArithmeticDataset(config) + + # Test case with answer 3 + 2i + metadata = {"result": (3.0, 2.0)} + + # Test exact matches (should get score of 1.0) + assert dataset.score_answer("3 + 2i", metadata) == 1.0 + assert dataset.score_answer("3+2i", metadata) == 1.0 + assert dataset.score_answer("3.0 + 2.0i", metadata) == 1.0 + + # Test answers with small errors (should get high but < 1.0 scores) + print(dataset.score_answer("3.1 + 2i", metadata)) + assert 0.9 < dataset.score_answer("3.1 + 2i", metadata) < 1.0 + assert 0.9 < dataset.score_answer("3 + 2.1i", metadata) < 1.0 + assert 0.7 < dataset.score_answer("3.1 + 2.1i", metadata) < 0.95 + + # Test answers with moderate errors (should get medium scores) + assert 0.3 < dataset.score_answer("4 + 2i", metadata) < 0.4 + assert 0.3 < dataset.score_answer("3 + 3i", metadata) < 0.4 + + # Test answers with large errors (should get very low scores) + assert dataset.score_answer("10 + 10i", metadata) < 0.01 + + # Test invalid answers (should get 0.0) + assert dataset.score_answer("invalid", metadata) == 0.0 + assert dataset.score_answer(None, metadata) == 0.0 + assert dataset.score_answer("inf + 2i", metadata) == 0.0 + + +def test_complex_arithmetic_division_by_zero(): + """Test that division by zero is handled properly.""" + config = ComplexArithmeticConfig(operations=("/",), seed=42) # Only test division + dataset = ComplexArithmeticDataset(config) + + # Check multiple items to ensure no division by zero + for i in range(10): + item = dataset[i] + num2 = complex(*item["metadata"]["num2"]) + assert num2 != 0 diff --git a/tests/test_isomorphic_strings.py b/tests/test_isomorphic_strings.py new file mode 100644 index 00000000..6e515cf7 --- /dev/null +++ b/tests/test_isomorphic_strings.py @@ -0,0 +1,108 @@ +"""Tests for Isomorphic Strings questions generation""" + +import json + +import pytest + +from reasoning_gym.algorithmic.isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset + + +def 
test_isomorphic_strings_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = IsomorphicStringsConfig(max_string_length=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = IsomorphicStringsConfig(max_string_length=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = IsomorphicStringsConfig(max_string_length=1) # One not allowed + config.validate() + + with pytest.raises(AssertionError): + config = IsomorphicStringsConfig(p_solvable=-0.01) # < 0 not allowed + config.validate() + + with pytest.raises(AssertionError): + config = IsomorphicStringsConfig(p_solvable=1.01) # > 1 not allowed + config.validate() + + +def test_isomorphic_strings_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = IsomorphicStringsConfig(seed=42, size=10) + dataset1 = IsomorphicStringsDataset(config) + dataset2 = IsomorphicStringsDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_isomorphic_strings_dataset_items(): + """Test basic properties of generated items""" + config = IsomorphicStringsConfig(max_string_length=10, size=10, seed=42) + dataset = IsomorphicStringsDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "words" in item["metadata"] + assert "solution" in item["metadata"] + assert "solvable" in item["metadata"] + + words = item["metadata"]["words"] + solution = item["metadata"]["solution"] + solvable = item["metadata"]["solvable"] + + # Verify list dimensions + assert len(words) == 2 + assert solution in {True, False} + assert solvable in {True, False} + assert solution == solvable + + +def test_isomorphic_strings_dataset_iteration(): + """Test 
that iteration respects dataset size""" + config = IsomorphicStringsConfig(size=5, seed=42) + dataset = IsomorphicStringsDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_isomorphic_strings_answer(): + """Test the _check_isomorphic method""" + config = IsomorphicStringsConfig(seed=42) + dataset = IsomorphicStringsDataset(config) + + # General use case + s, t = "foo", "bar" + assert dataset._check_isomorphic(s, t) == False + + s, t = "foo", "baa" + assert dataset._check_isomorphic(s, t) == True + + # Unequal lengths + s, t = "foo", "bo" + assert dataset._check_isomorphic(s, t) == False + + # Empty strings + ( + s, + t, + ) = ( + "", + "", + ) + assert dataset._check_isomorphic(s, t) == True diff --git a/tests/test_polynomial_multiplication.py b/tests/test_polynomial_multiplication.py new file mode 100644 index 00000000..a27bd6bf --- /dev/null +++ b/tests/test_polynomial_multiplication.py @@ -0,0 +1,166 @@ +import pytest +import sympy as sp + +from reasoning_gym import create_dataset +from reasoning_gym.algebra.polynomial_multiplication import ( + PolynomialMultiplicationConfig, + PolynomialMultiplicationDataset, +) + + +def test_polynomial_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + PolynomialMultiplicationConfig(min_terms=0).validate() + + with pytest.raises(AssertionError): + PolynomialMultiplicationConfig(min_value=0).validate() + + with pytest.raises(AssertionError): + PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate() + + with pytest.raises(AssertionError): + PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate() + + with pytest.raises(AssertionError): + PolynomialMultiplicationConfig(operators=("^",)).validate() + + with pytest.raises(AssertionError): + PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate() + 
+ +def test_polynomial_multiplication_dataset_basic(): + """Test dataset creation and length""" + dataset_size = 50 + config = PolynomialMultiplicationConfig( + min_terms=2, + max_terms=3, + min_value=1, + max_value=5, + min_degree=1, + max_degree=2, + min_polynomials=2, + max_polynomials=3, + single_variable=True, + seed=42, + size=dataset_size, + ) + + dataset = PolynomialMultiplicationDataset(config) + + assert len(dataset) == dataset_size + + +def test_polynomial_equations_dataset_items(): + """Test that generated items have correct structure""" + ds = create_dataset( + "polynomial_multiplication", + min_terms=2, + max_terms=3, + min_value=1, + max_value=5, + min_degree=1, + max_degree=2, + min_polynomials=2, + max_polynomials=5, + single_variable=False, + size=3, + seed=100, + ) + + for item in ds: + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert isinstance(item["metadata"]["polynomial_expr"], str) + assert isinstance(item["metadata"]["single_variable"], bool) + + # Check polynomial_expr existence + poly_str = item["metadata"]["polynomial_expr"] + # Ensure it can parse with sympy + sp.sympify(poly_str) + + +def test_polynomial_equations_dataset_deterministic(): + """Test dataset reproducibility with fixed seed.""" + cfg = PolynomialMultiplicationConfig(seed=999, size=3) + ds1 = PolynomialMultiplicationDataset(cfg) + ds2 = PolynomialMultiplicationDataset(cfg) + + for i in range(len(ds1)): + assert ds1[i] == ds2[i], "Polynomial datasets with same seed should match exactly." 
+ + +def test_polynomial_solutions_evaluation(): + """Test that solution satisfy the polynomial multiplication.""" + ds = create_dataset( + "polynomial_multiplication", + min_terms=2, + max_terms=4, + min_value=1, + max_value=10, + min_degree=1, + max_degree=3, + min_polynomials=2, + max_polynomials=5, + single_variable=False, + size=5, + seed=42, + ) + + for item in ds: + # Extract the polynomial expression + poly_str = item["metadata"]["polynomial_expr"] + # Get the polynomial product + poly_expr = sp.expand(poly_str) + + # Verify that each solution satisfies the polynomial + assert poly_expr == item["answer"] + + +def test_score_function(): + """Test that solution satisfy the polynomial multiplication.""" + ds = create_dataset( + "polynomial_multiplication", + min_terms=2, + max_terms=4, + min_value=1, + max_value=10, + min_degree=1, + max_degree=3, + min_polynomials=2, + max_polynomials=5, + single_variable=True, + size=1, + seed=42, + ) + + assert ds.score_answer(None, ds[0]["metadata"]) == 0.00 + assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]["metadata"]) == 1 + assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01 + assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05 + + +def test_multivariate_score_function(): + """Test that solution satisfy the polynomial multiplication.""" + ds = create_dataset( + "polynomial_multiplication", + min_terms=2, + max_terms=4, + min_value=1, + max_value=10, + min_degree=1, + max_degree=3, + min_polynomials=2, + max_polynomials=5, + single_variable=False, + size=1, + seed=42, + ) + + assert ds.score_answer(None, ds[0]["metadata"]) == 0.00 + assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]["metadata"]) == 1 + assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01 + assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05 diff --git a/tests/test_self_reference.py b/tests/test_self_reference.py new file mode 100644 index 00000000..66f15081 --- 
/dev/null +++ b/tests/test_self_reference.py @@ -0,0 +1,55 @@ +import pytest + +from reasoning_gym.logic.self_reference import SelfReferenceConfig, SelfReferenceDataset + + +def test_self_reference(): + """Test basic properties and solution of generated items""" + + # Easy + config = SelfReferenceConfig(seed=42, size=20, difficulty=1) + dataset = SelfReferenceDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Test the scoring + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=99, entry=item) == 0.1 + assert dataset.score_answer(answer="99", entry=item) == 0.1 + assert dataset.score_answer(answer=None, entry=item) == 0.0 + + # # Medium + config = SelfReferenceConfig(seed=42, size=1, difficulty=5) + dataset = SelfReferenceDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Test the scoring + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=99, entry=item) == 0.1 + assert dataset.score_answer(answer="99", entry=item) == 0.1 + assert dataset.score_answer(answer=None, entry=item) == 0.0 + + # # Hard + config = SelfReferenceConfig(seed=42, size=1, difficulty=10) + dataset = SelfReferenceDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Test the scoring + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=99, entry=item) == 0.1 + assert dataset.score_answer(answer="99", entry=item) == 0.1 + assert dataset.score_answer(answer=None, entry=item) == 0.0 diff --git a/tests/test_syllogisms.py b/tests/test_syllogisms.py index 498be586..9f2c5607 100644 --- a/tests/test_syllogisms.py +++ 
b/tests/test_syllogisms.py @@ -64,6 +64,204 @@ def test_syllogism_dataset_items(): assert "Does it logically follow that:" in item["question"] +def test_valid_syllogism_forms(): + """Test specific valid syllogistic forms""" + config = SyllogismConfig(size=1, seed=42) + dataset = SyllogismDataset(config) + + # Create some test terms + A = Term("mortal", "mortals") + B = Term("human", "humans") + C = Term("animal", "animals") + + # Test Barbara (AAA-1) + # Major premise: All M are P + # Minor premise: All S are M + # Conclusion: All S are P + assert dataset._is_valid_syllogism( + (Quantifier.ALL, B, C), # All B (M) are C (P) + (Quantifier.ALL, A, B), # All A (S) are B (M) + (Quantifier.ALL, A, C), # All A (S) are C (P) + ) + + # Test Celarent (EAE-1) + # Major premise: No M are P + # Minor premise: All S are M + # Conclusion: No S are P + assert dataset._is_valid_syllogism( + (Quantifier.NO, B, C), # No B (M) are C (P) + (Quantifier.ALL, A, B), # All A (S) are B (M) + (Quantifier.NO, A, C), # No A (S) are C (P) + ) + + # Test Cesare (EAE-2) — corrected order + # Major premise: No P are M + # Minor premise: All S are M + # Conclusion: No S are P + assert dataset._is_valid_syllogism( + (Quantifier.NO, C, B), # No C (P) are B (M) [Major premise] + (Quantifier.ALL, A, B), # All A (S) are B (M) [Minor premise] + (Quantifier.NO, A, C), # No A (S) are C (P) + ) + + # Test Darii (AII-1) + # Major premise: All M are P + # Minor premise: Some S are M + # Conclusion: Some S are P + assert dataset._is_valid_syllogism( + (Quantifier.ALL, B, C), # All B (M) are C (P) + (Quantifier.SOME, A, B), # Some A (S) are B (M) + (Quantifier.SOME, A, C), # Some A (S) are C (P) + ) + + # Test Disamis (IAI-3) + # Major premise: Some M are P + # Minor premise: All M are S + # Conclusion: Some S are P + assert dataset._is_valid_syllogism( + (Quantifier.SOME, B, C), # Some B (M) are C (P) + (Quantifier.ALL, B, A), # All B (M) are A (S) + (Quantifier.SOME, A, C), # Some A (S) are C (P) + ) + + # 
Test Ferio (EIO-1) + # Major premise: No M are P + # Minor premise: Some S are M + # Conclusion: Some S are not P + assert dataset._is_valid_syllogism( + (Quantifier.NO, B, C), # No B (M) are C (P) + (Quantifier.SOME, A, B), # Some A (S) are B (M) + (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P) + ) + + # Test Festino (EIO-2) + # Major premise: No P are M + # Minor premise: Some S are M + # Conclusion: Some S are not P + assert dataset._is_valid_syllogism( + (Quantifier.NO, C, B), # No C (P) are B (M) + (Quantifier.SOME, A, B), # Some A (S) are B (M) + (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P) + ) + + # Test Datisi (AII-3) + # Major premise: All M are P + # Minor premise: Some M are S + # Conclusion: Some S are P + assert dataset._is_valid_syllogism( + (Quantifier.ALL, B, C), # All B (M) are C (P) + (Quantifier.SOME, B, A), # Some B (M) are A (S) + (Quantifier.SOME, A, C), # Some A (S) are C (P) + ) + + # Test Bocardo (OAO-3) + # Major premise: Some M are not P + # Minor premise: All M are S + # Conclusion: Some S are not P + assert dataset._is_valid_syllogism( + (Quantifier.SOME_NOT, B, C), # Some B (M) are not C (P) + (Quantifier.ALL, B, A), # All B (M) are A (S) + (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P) + ) + + # Test Baroco (AOO-2) + # Major premise: All P are M + # Minor premise: Some S are not M + # Conclusion: Some S are not P + assert dataset._is_valid_syllogism( + (Quantifier.ALL, C, B), # All C (P) are B (M) + (Quantifier.SOME_NOT, A, B), # Some A (S) are not B (M) + (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P) + ) + + # Test Camestres (AEE-2) + # Major premise: All P are M + # Minor premise: No S are M + # Conclusion: No S are P + assert dataset._is_valid_syllogism( + (Quantifier.ALL, C, B), # All C (P) are B (M) + (Quantifier.NO, A, B), # No A (S) are B (M) + (Quantifier.NO, A, C), # No A (S) are C (P) + ) + + # Test Dimaris (IAI-4) + # Major premise: Some P are M + # Minor premise: All M are S + # 
Conclusion: Some S are P + assert dataset._is_valid_syllogism( + (Quantifier.SOME, C, B), # Some C (P) are B (M) + (Quantifier.ALL, B, A), # All B (M) are A (S) + (Quantifier.SOME, A, C), # Some A (S) are C (P) + ) + + # Test Ferison (EIO-3) + # Major premise: No M are P + # Minor premise: Some M are S + # Conclusion: Some S are not P + assert dataset._is_valid_syllogism( + (Quantifier.NO, B, C), # No B (M) are C (P) + (Quantifier.SOME, B, A), # Some B (M) are A (S) + (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P) + ) + + # Test Fresison (EIO-4) + # Major premise: No P are M + # Minor premise: Some M are S + # Conclusion: Some S are not P + assert dataset._is_valid_syllogism( + (Quantifier.NO, C, B), # No C (P) are B (M) + (Quantifier.SOME, B, A), # Some B (M) are A (S) + (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P) + ) + + # Test Camenes (AEE-4) + # Major premise: All P are M + # Minor premise: No M are S + # Conclusion: No S are P + assert dataset._is_valid_syllogism( + (Quantifier.ALL, C, B), # All C (P) are B (M) + (Quantifier.NO, B, A), # No B (M) are A (S) + (Quantifier.NO, A, C), # No A (S) are C (P) + ) + + # Test invalid forms + assert not dataset._is_valid_syllogism( + (Quantifier.SOME, B, C), # Some B are C + (Quantifier.SOME, A, B), # Some A are B + (Quantifier.SOME, A, C), # Some A are C (invalid: two particular premises) + ) + + assert not dataset._is_valid_syllogism( + (Quantifier.NO, B, C), # No B are C + (Quantifier.NO, A, B), # No A are B + (Quantifier.NO, A, C), # No A are C (invalid: two negative premises) + ) + + # Test specific invalid case with two negative premises + S = Term("student", "students") + M = Term("human", "humans") + P = Term("chef", "chefs") + assert not dataset._is_valid_syllogism( + (Quantifier.NO, S, M), # No students are humans + (Quantifier.NO, M, P), # No humans are chefs + (Quantifier.NO, S, P), # No students are chefs (invalid!) 
+ ) + + child = Term("child", "children") + animal = Term("animal", "animals") + doctor = Term("doctor", "doctors") + + # Premise 1: Some children are not animals + # Premise 2: All animals are doctors + # Conclusion: Some children are not doctors + # We expect this NOT to be a valid syllogism + assert not dataset._is_valid_syllogism( + (Quantifier.SOME_NOT, child, animal), # Some children are not animals + (Quantifier.ALL, animal, doctor), # All animals are doctors + (Quantifier.SOME_NOT, child, doctor), # Some children are not doctors + ) + + def test_syllogism_dataset_iteration(): """Test that iteration respects dataset size""" config = SyllogismConfig(size=5, seed=42) @@ -74,41 +272,3 @@ def test_syllogism_dataset_iteration(): # Test multiple iterations yield same items assert items == list(dataset) - - -def test_syllogism_custom_terms(): - """Test syllogism generation with custom terms""" - custom_terms = [ - Term("programmer", "programmers"), - Term("coder", "coders"), - Term("developer", "developers"), - ] - config = SyllogismConfig(terms=custom_terms, size=10, seed=42) - dataset = SyllogismDataset(config) - - for item in dataset: - # Verify only custom terms are used - text = item["question"] + str(item["metadata"]) - assert any(term.name in text or term.plural in text for term in custom_terms) - # Verify default terms are not used - assert "mortal" not in text - assert "human" not in text - - -def test_syllogism_validity(): - """Test logical validity rules""" - config = SyllogismConfig( - allow_all=True, - allow_no=False, - allow_some=False, - allow_some_not=False, - include_invalid=False, # Only generate valid syllogisms - size=10, - seed=42, - ) - dataset = SyllogismDataset(config) - - for item in dataset: - # All valid ALL syllogisms should have "Yes" as answer - assert item["answer"] == "Yes" - assert item["metadata"]["is_valid"] is True diff --git a/tests/test_tsumego.py b/tests/test_tsumego.py new file mode 100644 index 00000000..e979bcac --- 
/dev/null +++ b/tests/test_tsumego.py @@ -0,0 +1,281 @@ +"""Tests for Ttsumego problem generation""" + +import re + +import pytest + +from reasoning_gym.games.tsumego import TsumegoConfig, TsumegoDataset + + +def test_config_validation(): + # Valid configuration + TsumegoConfig(min_board_size=9, max_board_size=13, max_stones=10, size=100, seed=42) + + # Invalid configurations + with pytest.raises(ValueError): + TsumegoConfig(min_board_size=4, max_board_size=13, max_stones=10) # min_board_size too low + with pytest.raises(ValueError): + TsumegoConfig(min_board_size=9, max_board_size=20, max_stones=10) # max_board_size too high + with pytest.raises(ValueError): + TsumegoConfig(min_board_size=13, max_board_size=9, max_stones=10) # min_board_size > max_board_size + with pytest.raises(ValueError): + TsumegoConfig(min_board_size=9, max_board_size=13, max_stones=2) # max_stones too low + + +def test_dataset_item_properties(): + config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=15, size=100, seed=42) + dataset = TsumegoDataset(config) + item = dataset[0] + # Check that item has the required keys + for key in ["question", "answer", "metadata"]: + assert key in item + + metadata = item["metadata"] + for key in ["difficulty", "board", "solution"]: + assert key in metadata + + board = metadata["board"] + # Board size should be equal to the fixed min_board_size for this test + assert len(board) == config.min_board_size + assert all(len(row) == config.min_board_size for row in board) + # Check stone count does not exceed max_stones + 7 (to account for extra fill in capture formation) + stone_count = sum(cell in "XO" for row in board for cell in row) + assert stone_count <= config.max_stones + 7 + + +def test_deterministic_generation(): + config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=10, seed=42) + dataset1 = TsumegoDataset(config) + dataset2 = TsumegoDataset(config) + for i in range(3): + item1 = dataset1[i] + item2 = dataset2[i] + 
# NOTE(review): this chunk opens with the last two assertions of a test whose
# `def` line lies before the chunk. They are preserved verbatim below so they
# can be re-attached to that test; they are not valid at module level.
#     assert item1["metadata"]["board"] == item2["metadata"]["board"]
#     assert item1["answer"] == item2["answer"]


def test_liberties_and_move():
    """Exercise liberty counting, a capturing move and a suicide move on 5x5 boards."""
    # Use a small board for simplicity
    config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=10)
    dataset = TsumegoDataset(config)

    # Part 1: Liberty counting test
    board_liberties = [
        [".", "O", ".", ".", "."],
        ["O", "X", "O", ".", "."],
        [".", "O", ".", ".", "."],
        [".", ".", ".", ".", "."],
        [".", ".", ".", ".", "."],
    ]
    # The "X" at (1,1) has an "O" on all four sides, so it has no liberties.
    liberties = dataset._get_liberties(board_liberties, 1, 1)
    assert len(liberties) == 0
    # The "O" at (0,1) sits on the top edge: empty points at (0,0) and (0,2).
    liberties_edge = dataset._get_liberties(board_liberties, 0, 1)
    assert len(liberties_edge) == 2

    # Part 2: Test capturing move
    # Construct a board where an enemy stone at (2,2) is surrounded on three sides,
    # so that placing an "X" at (2,3) will remove its last liberty and capture it.
    board_capture = [["." for _ in range(5)] for _ in range(5)]
    board_capture[1][2] = "X"
    board_capture[2][1] = "X"
    board_capture[3][2] = "X"
    board_capture[2][2] = "O"
    # Now, (2,2) (enemy) has only one liberty at (2,3).
    # Placing "X" at (2,3) should capture the enemy stone.
    assert dataset._is_valid_move(board_capture, 2, 3, "X")
    dataset._make_move(board_capture, 2, 3, "X")
    # After the move the original author expects captured_stones == [(2,2)] and
    # the ko point to be (2,2), so an immediate "O" recapture must be rejected.
    # NOTE(review): an "O" at (2,2) would presumably also be rejected as suicide
    # (all four neighbours are "X"), so this assertion alone does not prove that
    # the ko rule is what fired.
    assert not dataset._is_valid_move(board_capture, 2, 2, "O"), "Ko move should be invalid"

    # Part 3: Test suicide move (without capture)
    board_move = [
        [".", "O", ".", ".", "."],
        ["O", ".", "O", ".", "."],
        [".", "O", ".", ".", "."],
        [".", ".", ".", ".", "."],
        [".", ".", ".", ".", "."],
    ]
    # Placing "X" at (1,1) would be suicide as all adjacent positions are occupied by "O".
    assert not dataset._is_valid_move(board_move, 1, 1, "X")


def convert_solution(sol, board_size):
    """Convert a letter+number move such as 'E5' into a (row, col) tuple.

    The column is the zero-based offset of the (case-insensitive) letter from
    'A'; the row is ``board_size - number``, i.e. rank 1 is the last row.

    Args:
        sol: Move string, one letter followed by a decimal number.
        board_size: Number of rows on the board.

    Returns:
        ``(row, col)`` as a tuple of ints. No bounds checking is performed.

    Raises:
        IndexError: If ``sol`` is empty.
        ValueError: If the part after the first character is not an integer.
    """
    # sol is expected to be a string like 'E5'
    letter = sol[0].upper()
    number = int(sol[1:])
    return (board_size - number, ord(letter) - ord("A"))


def test_score_answer():
    """Check score_answer against a planted 'E5' solution and against every generated item."""
    # Local import keeps this block self-contained without touching file-level imports.
    import copy

    config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=10, size=5)
    dataset = TsumegoDataset(config)

    # prepare dummy with letter+number format solution
    # A deep copy is required: a shallow ``.copy()`` would share the nested
    # ``metadata`` dict, so planting a solution here would also rewrite the
    # object handed out by the dataset.
    entry = copy.deepcopy(dataset[0])
    entry["metadata"]["solution"] = "E5"

    # Patch score_answer to convert metadata solution if needed
    original_score_answer = dataset.score_answer

    def patched_score_answer(answer, entry):
        # Normalise a string solution to (row, col) in place, then delegate.
        board_size = len(entry["metadata"]["board"])
        sol = entry["metadata"]["solution"]
        if isinstance(sol, str):
            entry["metadata"]["solution"] = convert_solution(sol, board_size)
        return original_score_answer(answer, entry)

    dataset.score_answer = patched_score_answer

    # Correct letter-number answer (E5 maps to board coordinate (4,4) on a 9x9 board)
    assert dataset.score_answer("E5", entry) == 1.0

    # Valid but incorrect letter-number move (D4 maps to (5,3) on a 9x9 board)
    assert dataset.score_answer("D4", entry) == 0.05

    # Invalid format
    assert dataset.score_answer("invalid", entry) == 0.01

    # Empty answer
    assert dataset.score_answer("", entry) == 0.01

    # None answer
    assert dataset.score_answer(None, entry) == 0.0

    # Out-of-bound letter-number move: 'J' is column index 9, past the last
    # column (index 8) of a 9x9 board
    assert dataset.score_answer("J9", entry) == 0.01

    # test optimal score for answers, patching each entry
    for x in dataset:
        board_size = len(x["metadata"]["board"])
        sol = x["metadata"]["solution"]
        if isinstance(sol, str):
            x["metadata"]["solution"] = convert_solution(sol, board_size)
        assert len(x["metadata"]["board"]) == x["metadata"]["difficulty"]["board_size"]
        assert dataset.score_answer(x["answer"], entry=x) == 1.0


# Additional tests for game logic edge cases


def test_get_group():
    """_get_group returns the orthogonally connected same-colour stones."""
    config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
    dataset = TsumegoDataset(config)
    board = [
        ["X", "X", "."],
        [".", "X", "O"],
        [".", ".", "O"],
    ]
    group_X = dataset._get_group(board, 0, 0)
    expected_group_X = {(0, 0), (0, 1), (1, 1)}
    assert group_X == expected_group_X

    group_O = dataset._get_group(board, 1, 2)
    expected_group_O = {(1, 2), (2, 2)}
    assert group_O == expected_group_O


def test_count_liberties():
    """_count_liberties counts each empty neighbour of a group exactly once."""
    config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
    dataset = TsumegoDataset(config)
    board = [
        ["X", "X", "."],
        [".", "X", "O"],
        [".", ".", "O"],
    ]
    group_X = {(0, 0), (0, 1), (1, 1)}
    liberties_X = dataset._count_liberties(board, group_X)
    # For (0,0): neighbor (1,0); (0,1): neighbor (0,2); (1,1): neighbors (1,0) and (2,1)
    # Combined unique liberties: {(1,0), (0,2), (2,1)} so count should be 3
    assert liberties_X == 3


def test_out_of_bounds_move():
    """Moves with a negative or too-large row/column index are rejected."""
    config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
    dataset = TsumegoDataset(config)
    board = [["." for _ in range(5)] for _ in range(5)]
    # Test moves that are out of bounds
    assert not dataset._is_valid_move(board, -1, 0, "X")
    assert not dataset._is_valid_move(board, 0, -1, "X")
    assert not dataset._is_valid_move(board, 5, 0, "X")
    assert not dataset._is_valid_move(board, 0, 5, "X")


def test_move_on_occupied_intersection():
    """Neither colour may play on a point that already holds a stone."""
    config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
    dataset = TsumegoDataset(config)
    board = [["." for _ in range(5)] for _ in range(5)]
    board[1][1] = "X"
    # Attempting to play on an occupied spot should be invalid
    assert not dataset._is_valid_move(board, 1, 1, "O")
    assert not dataset._is_valid_move(board, 1, 1, "X")


def test_valid_non_capturing_move():
    """A plain move on an empty board is valid and places the stone."""
    config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
    dataset = TsumegoDataset(config)
    board = [["." for _ in range(5)] for _ in range(5)]
    # A move on an empty board that doesn't result in capture or suicide should be valid
    assert dataset._is_valid_move(board, 0, 0, "X")
    move_result = dataset._make_move(board, 0, 0, "X")
    assert move_result
    assert board[0][0] == "X"


def test_multiple_capture():
    """Capturing two stones at once clears both points and leaves no ko point."""
    # Set up a board where a move will capture multiple opponent stones,
    # which should not trigger the ko rule (ko point remains None)
    config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
    dataset = TsumegoDataset(config)
    board = [
        [".", ".", ".", ".", "."],
        [".", "X", "X", "X", "."],
        ["X", "O", "O", ".", "."],
        [".", "X", "X", "X", "."],
        [".", ".", ".", ".", "."],
    ]
    # Move at (2,3) with 'X' should capture the opponent stones at (2,1) and (2,2)
    assert dataset._is_valid_move(board, 2, 3, "X")
    move_result = dataset._make_move(board, 2, 3, "X")
    assert move_result, "Move should be successfully made"
    assert board[2][1] == ".", "Stone at (2,1) should be captured"
    assert board[2][2] == ".", "Stone at (2,2) should be captured"
    assert dataset._ko_point is None, "Ko point should not be set for multiple captures"


def test_would_capture():
    """_would_capture is true for a liberty-removing move and false for an unrelated one."""
    config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
    dataset = TsumegoDataset(config)
    # Create a scenario similar to the one in test_liberties_and_move for capturing
    board_capture = [["." for _ in range(5)] for _ in range(5)]
    board_capture[1][2] = "X"
    board_capture[2][1] = "X"
    board_capture[3][2] = "X"
    board_capture[2][2] = "O"
    # Placing 'X' at (2,3) should capture the stone at (2,2)
    assert dataset._would_capture(board_capture, 2, 3, "X")
    # In a scenario with no capture, the move should not be considered capturing
    board_no_capture = [["." for _ in range(5)] for _ in range(5)]
    board_no_capture[2][2] = "O"
    assert not dataset._would_capture(board_no_capture, 0, 0, "X")


def test_capture_verification():
    """Verifies that the solution move in a generated puzzle captures at least one opponent stone."""
    config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=15, size=1, seed=10)
    dataset = TsumegoDataset(config)
    entry = dataset[0]
    board = entry["metadata"]["board"]
    solution = entry["metadata"]["solution"]
    # If solution is a letter+number string, convert it
    if isinstance(solution, str):
        board_size = len(board)
        solution = convert_solution(solution, board_size)
    initial_white = sum(row.count("O") for row in board)

    # Copy every row so the simulated move does not modify the entry's board
    board_after = [row[:] for row in board]
    move_success = dataset._make_move(board_after, solution[0], solution[1], "X")
    assert move_success, "The solution move should be legal."

    final_white = sum(row.count("O") for row in board_after)
    assert final_white < initial_white, "The solution move should capture at least one opponent stone."