diff --git a/.gitignore b/.gitignore
index ce057fd8..d1e0d496 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,6 +21,7 @@ wheels/
*.egg-info/
.installed.cfg
*.egg
+.python-version
# Virtual Environment
venv/
diff --git a/GALLERY.md b/GALLERY.md
index ff56e124..239d07a9 100644
--- a/GALLERY.md
+++ b/GALLERY.md
@@ -12,6 +12,7 @@ This gallery shows examples from all available datasets using their default conf
- [calendar_arithmetic](#calendar_arithmetic)
- [chain_sum](#chain_sum)
- [color_cube_rotation](#color_cube_rotation)
+- [complex_arithmetic](#complex_arithmetic)
- [countdown](#countdown)
- [course_schedule](#course_schedule)
- [family_relationships](#family_relationships)
@@ -23,6 +24,7 @@ This gallery shows examples from all available datasets using their default conf
- [ransom_note](#ransom_note)
- [gsm_symbolic](#gsm_symbolic)
- [intermediate_integration](#intermediate_integration)
+- [isomorphic_strings](#isomorphic_strings)
- [largest_island](#largest_island)
- [lcm](#lcm)
- [leg_counting](#leg_counting)
@@ -36,10 +38,12 @@ This gallery shows examples from all available datasets using their default conf
- [number_sorting](#number_sorting)
- [palindrome](#palindrome)
- [polynomial_equations](#polynomial_equations)
+- [polynomial_multiplication](#polynomial_multiplication)
- [prime_factorization](#prime_factorization)
- [propositional_logic](#propositional_logic)
- [quantum_lock](#quantum_lock)
- [rubiks_cube](#rubiks_cube)
+- [self_reference](#self_reference)
- [sentence_reordering](#sentence_reordering)
- [simple_equations](#simple_equations)
- [simple_geometry](#simple_geometry)
@@ -50,6 +54,7 @@ This gallery shows examples from all available datasets using their default conf
- [syllogism](#syllogism)
- [time_intervals](#time_intervals)
- [tower_of_hanoi](#tower_of_hanoi)
+- [tsumego](#tsumego)
- [word_ladder](#word_ladder)
- [word_sequence_reversal](#word_sequence_reversal)
- [word_sorting](#word_sorting)
@@ -490,6 +495,39 @@ Metadata: {'initial_state': {'top': 'orange', 'right': 'cyan', 'front': 'violet'
````
+### complex_arithmetic
+Generates complex number arithmetic problems.
+
+Default configuration:
+```python
+min_real = -10
+max_real = 10
+min_imag = -10
+max_imag = 10
+operations = ('+', '-', '*', '/')
+seed = 42
+size = 500
+```
+
+Example tasks:
+````
+Example 1:
+Question: Add the complex numbers: (-10.0 - 2.0i) + (-3.0 - 3.0i)
+Answer: -13.0 - 5.0i
+Metadata: {'num1': (-10.0, -2.0), 'num2': (-3.0, -3.0), 'operation': '+', 'result': (-13, -5)}
+
+Example 2:
+Question: Add the complex numbers: (-1.0 - 6.0i) + (4.0 + 1.0i)
+Answer: 3.0 - 5.0i
+Metadata: {'num1': (-1.0, -6.0), 'num2': (4.0, 1.0), 'operation': '+', 'result': (3, -5)}
+
+Example 3:
+Question: Divide the complex numbers: (-7.0 - 79.0i) ÷ (-7.0 - 5.0i)
+Answer: 6.0 + 7.0i
+Metadata: {'num1': (-7.0, -79.0), 'num2': (-7.0, -5.0), 'operation': '/', 'result': (6, 7)}
+
+````
+
### countdown
Generates Countdown Number Game tasks
@@ -1122,6 +1160,99 @@ Metadata: {'integrand': '2*asin(x)', 'problem_type': 'by_parts', 'variable': 'x'
````
+### isomorphic_strings
+Generates Isomorphic Strings exercises with configurable difficulty
+
+Default configuration:
+```python
+max_string_length = 10
+p_solvable = 0.5
+size = 500
+seed = 42
+```
+
+Example tasks:
+````
+Example 1:
+Question: Two strings are isomorphic if the characters in one string can be replaced to get the second string.
+
+All occurrences of a character must be replaced with another character while preserving the order of characters.
+
+No two characters may map to the same character, but a character may map to itself.
+
+Example 1:
+Input: egg add
+Output: True
+Explanation: The strings s and t can be made identical by:
+ - Mapping 'e' to 'a'.
+ - Mapping 'g' to 'd'.
+
+Example 2:
+Input: foo bar
+Output: False
+Explanation:
+ - The strings cannot be made identical as 'o' needs to be mapped to both 'a' and 'r'.
+
+Return True if the following two strings are isomorphic, or False otherwise:
+cc bw
+
+Answer: False
+Metadata: {'words': ['cc', 'bw'], 'solution': False, 'solvable': False}
+
+Example 2:
+Question: Two strings are isomorphic if the characters in one string can be replaced to get the second string.
+
+All occurrences of a character must be replaced with another character while preserving the order of characters.
+
+No two characters may map to the same character, but a character may map to itself.
+
+Example 1:
+Input: egg add
+Output: True
+Explanation: The strings s and t can be made identical by:
+ - Mapping 'e' to 'a'.
+ - Mapping 'g' to 'd'.
+
+Example 2:
+Input: foo bar
+Output: False
+Explanation:
+ - The strings cannot be made identical as 'o' needs to be mapped to both 'a' and 'r'.
+
+Return True if the following two strings are isomorphic, or False otherwise:
+nai oik
+
+Answer: True
+Metadata: {'words': ['nai', 'oik'], 'solution': True, 'solvable': True}
+
+Example 3:
+Question: Two strings are isomorphic if the characters in one string can be replaced to get the second string.
+
+All occurrences of a character must be replaced with another character while preserving the order of characters.
+
+No two characters may map to the same character, but a character may map to itself.
+
+Example 1:
+Input: egg add
+Output: True
+Explanation: The strings s and t can be made identical by:
+ - Mapping 'e' to 'a'.
+ - Mapping 'g' to 'd'.
+
+Example 2:
+Input: foo bar
+Output: False
+Explanation:
+ - The strings cannot be made identical as 'o' needs to be mapped to both 'a' and 'r'.
+
+Return True if the following two strings are isomorphic, or False otherwise:
+hogtytyof kgqwfwfgh
+
+Answer: True
+Metadata: {'words': ['hogtytyof', 'kgqwfwfgh'], 'solution': True, 'solvable': True}
+
+````
+
### largest_island
Generates Largest Island exercises with configurable difficulty
@@ -1745,6 +1876,46 @@ Metadata: {'polynomial_expr': '71*n**3 - 2*n - 29', 'variable': 'n', 'degree': 3
````
+### polynomial_multiplication
+Generates [min_polynomials, max_polynomials] random polynomials of degree in [min_degree, max_degree].
+ - The polynomial is formed by summing random terms of the form: coeff * x^exponent.
+ - Then we find "F = P_0 * ... * P_n" using Sympy.
+
+Default configuration:
+```python
+min_terms = 2
+max_terms = 4
+min_value = 1
+max_value = 100
+min_degree = 1
+max_degree = 3
+min_polynomials = 2
+max_polynomials = 3
+single_variable = (True,)
+operators = ('+', '-')
+seed = 42
+size = 500
+```
+
+Example tasks:
+````
+Example 1:
+Question: Calculate the following: (65*x - 72)*(105*x - 125)
+Answer: 6825*x**2 - 15685*x + 9000
+Metadata: {'polynomial_expr': '(65*x - 72)*(105*x - 125)', 'single_variable': (True,), 'result': '6825*x**2 - 15685*x + 9000'}
+
+Example 2:
+Question: Calculate the following: (-9*x**2 - 28*x)*(86*x**2 - 2*x - 13)
+Answer: -774*x**4 - 2390*x**3 + 173*x**2 + 364*x
+Metadata: {'polynomial_expr': '(-9*x**2 - 28*x)*(86*x**2 - 2*x - 13)', 'single_variable': (True,), 'result': '-774*x**4 - 2390*x**3 + 173*x**2 + 364*x'}
+
+Example 3:
+Question: Calculate the following: (43 - 91*x)*(3*x**2 - 10*x)*(71*x**3 - 2*x - 29)
+Answer: -19383*x**6 + 73769*x**5 - 29984*x**4 + 5839*x**3 - 29271*x**2 + 12470*x
+Metadata: {'polynomial_expr': '(43 - 91*x)*(3*x**2 - 10*x)*(71*x**3 - 2*x - 29)', 'single_variable': (True,), 'result': '-19383*x**6 + 73769*x**5 - 29984*x**4 + 5839*x**3 - 29271*x**2 + 12470*x'}
+
+````
+
### prime_factorization
Generates prime factorization tasks
@@ -1943,6 +2114,56 @@ Metadata: {'cube_size': 3, 'scramble_steps': 3, 'scramble_moves': "U R' R'", 'ex
````
+### self_reference
+Generates self-referential puzzles
+
+Default configuration:
+```python
+difficulty = 5
+seed = 42
+size = 500
+```
+
+Example tasks:
+````
+Example 1:
+Question: Given the truthfulness of these statements, please tell me the number of possible solutions:
+ - Statement 1: 'At least 1 of these 7 statements are true.'
+ - Statement 2: 'At most 3 of these 7 statements are false.'
+ - Statement 3: 'Exactly 4 of these 7 statements are true.'
+ - Statement 4: 'Exactly 3 of these 7 statements are false.'
+ - Statement 5: 'Either Statement 3 or Statement 4 is true, but not both.'
+ - Statement 6: 'The number of true statements is a prime number.'
+ - Statement 7: 'The number of false statements is a composite number.'
+
+Answer: 4
+
+Example 2:
+Question: Given the truthfulness of these statements, please tell me the number of possible solutions:
+ - Statement 1: 'At least 4 of these 7 statements are true.'
+ - Statement 2: 'At most 5 of these 7 statements are false.'
+ - Statement 3: 'Exactly 7 of these 7 statements are true.'
+ - Statement 4: 'Exactly 1 of these 7 statements are false.'
+ - Statement 5: 'Either Statement 3 or Statement 4 is true, but not both.'
+ - Statement 6: 'The number of true statements is a prime number.'
+ - Statement 7: 'The number of false statements is a composite number.'
+
+Answer: 4
+
+Example 3:
+Question: Given the truthfulness of these statements, please tell me the number of possible solutions:
+ - Statement 1: 'At least 2 of these 7 statements are true.'
+ - Statement 2: 'At most 5 of these 7 statements are false.'
+ - Statement 3: 'Exactly 0 of these 7 statements are true.'
+ - Statement 4: 'Exactly 3 of these 7 statements are false.'
+ - Statement 5: 'Either Statement 3 or Statement 4 is true, but not both.'
+ - Statement 6: 'The number of true statements is a prime number.'
+ - Statement 7: 'The number of false statements is a composite number.'
+
+Answer: 2
+
+````
+
### sentence_reordering
Generates sentence reordering tasks from text spans
@@ -2295,12 +2516,10 @@ Generates syllogism reasoning tasks
Default configuration:
```python
-terms = None
allow_all = True
allow_no = True
allow_some = True
allow_some_not = True
-include_invalid = True
invalid_ratio = 0.3
seed = 42
size = 500
@@ -2311,24 +2530,24 @@ Example tasks:
Example 1:
Question: Consider these statements:
1. No students are humans
-2. No humans are chefs
+2. All humans are chefs
Does it logically follow that:
-No students are chefs?
+All students are chefs?
(Answer Yes or No)
-Answer: Yes
-Metadata: {'premise1': 'No students are humans', 'premise2': 'No humans are chefs', 'conclusion': 'No students are chefs', 'is_valid': True}
+Answer: No
+Metadata: {'premise1': 'No students are humans', 'premise2': 'All humans are chefs', 'conclusion': 'All students are chefs', 'is_valid': False}
Example 2:
Question: Consider these statements:
-1. Some children are not animals
-2. Some animals are doctors
+1. All children are animals
+2. No animals are doctors
Does it logically follow that:
-All children are doctors?
+Some children are not doctors?
(Answer Yes or No)
Answer: Yes
-Metadata: {'premise1': 'Some children are not animals', 'premise2': 'Some animals are doctors', 'conclusion': 'All children are doctors', 'is_valid': True}
+Metadata: {'premise1': 'All children are animals', 'premise2': 'No animals are doctors', 'conclusion': 'Some children are not doctors', 'is_valid': True}
Example 3:
Question: Consider these statements:
@@ -2338,8 +2557,8 @@ Question: Consider these statements:
Does it logically follow that:
Some butterflies are not whales?
(Answer Yes or No)
-Answer: No
-Metadata: {'premise1': 'All butterflies are tigers', 'premise2': 'No tigers are whales', 'conclusion': 'Some butterflies are not whales', 'is_valid': False}
+Answer: Yes
+Metadata: {'premise1': 'All butterflies are tigers', 'premise2': 'No tigers are whales', 'conclusion': 'Some butterflies are not whales', 'is_valid': True}
````
@@ -2442,6 +2661,96 @@ Metadata: {'num_disks': 6, 'num_pegs': 3, 'start_peg': 1, 'target_peg': 2, 'auxi
````
+### tsumego
+Generates (one-move) Tsumego problems with configurable parameters
+
+Default configuration:
+```python
+min_board_size = 9
+max_board_size = 13
+max_stones = 15
+size = 10
+seed = 42
+```
+
+Example tasks:
+````
+Example 1:
+Question: I have a Go problem for you. Black moves next - can you capture some of the white stones?
+
+ A B C D E F G H I
+ 9 X . . . X . . . .
+ 8 . . . . . . . . .
+ 7 . O . O . . X . .
+ 6 . . . X . . . . O
+ 5 O . X O X . . . .
+ 4 . X O O . O . . .
+ 3 . . X O X . . . .
+ 2 . . . X . . . . .
+ 1 . O . O . . X . .
+
+X - Black
+O - White
+
+Specify your move in coordinates (e.g. 'C4' for column C, row 4)
+Answer: E4
+
+Metadata: {'difficulty': {'board_size': 9}, 'board': [['X', '.', '.', '.', 'X', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', 'O', '.', 'O', '.', '.', 'X', '.', '.'], ['.', '.', '.', 'X', '.', '.', '.', '.', 'O'], ['O', '.', 'X', 'O', 'X', '.', '.', '.', '.'], ['.', 'X', 'O', 'O', '.', 'O', '.', '.', '.'], ['.', '.', 'X', 'O', 'X', '.', '.', '.', '.'], ['.', '.', '.', 'X', '.', '.', '.', '.', '.'], ['.', 'O', '.', 'O', '.', '.', 'X', '.', '.']], 'solution': 'E4'}
+
+--------------------------------------------------
+
+Example 2:
+Question: Here's a Go challenge. Playing as Black, how can you capture as many white stones as possible?
+
+ A B C D E F G H I
+ 9 . . O . . . . . .
+ 8 . X O . . . . . .
+ 7 X . X . . . . . .
+ 6 O O O X . . . . .
+ 5 X O O . . . . . .
+ 4 . X . . . . . . O
+ 3 . X . . . . X . .
+ 2 O . O . . . . . .
+ 1 . . . . O . . . .
+
+X - Black
+O - White
+
+Specify your move in coordinates (e.g. 'C4' for column C, row 4)
+Answer: B7
+
+Metadata: {'difficulty': {'board_size': 9}, 'board': [['.', '.', 'O', '.', '.', '.', '.', '.', '.'], ['.', 'X', 'O', '.', '.', '.', '.', '.', '.'], ['X', '.', 'X', '.', '.', '.', '.', '.', '.'], ['O', 'O', 'O', 'X', '.', '.', '.', '.', '.'], ['X', 'O', 'O', '.', '.', '.', '.', '.', '.'], ['.', 'X', '.', '.', '.', '.', '.', '.', 'O'], ['.', 'X', '.', '.', '.', '.', 'X', '.', '.'], ['O', '.', 'O', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', 'O', '.', '.', '.', '.']], 'solution': 'B7'}
+
+--------------------------------------------------
+
+Example 3:
+Question: Tsumego time. Black to play and capture some stones.
+Find the key move.
+
+ A B C D E F G H I J K L
+12 . . . . . . . . . . . .
+11 . . X . . . . . . . . .
+10 . . . . . . . . . . . .
+ 9 . . . . . . . . . . . .
+ 8 X . . . . X . . . X . .
+ 7 . X . . . . . . . . . .
+ 6 . O X X . . . . . . . O
+ 5 . X O O X . . . . . . .
+ 4 . O O . . . . . O . . O
+ 3 X . X . . . . . . . . .
+ 2 . . . . . . . . . . . .
+ 1 . . . . . . . . . . X .
+
+X - Black
+O - White
+
+Specify your move in coordinates (e.g. 'C4' for column C, row 4)
+Answer: D4
+
+Metadata: {'difficulty': {'board_size': 12}, 'board': [['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', 'X', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['X', '.', '.', '.', '.', 'X', '.', '.', '.', 'X', '.', '.'], ['.', 'X', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', 'O', 'X', 'X', '.', '.', '.', '.', '.', '.', '.', 'O'], ['.', 'X', 'O', 'O', 'X', '.', '.', '.', '.', '.', '.', '.'], ['.', 'O', 'O', '.', '.', '.', '.', '.', 'O', '.', '.', 'O'], ['X', '.', 'X', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', 'X', '.']], 'solution': 'D4'}
+
+````
+
### word_ladder
Generates word ladder transformation tasks
diff --git a/README.md b/README.md
index 0dd159b9..a899df26 100644
--- a/README.md
+++ b/README.md
@@ -64,7 +64,7 @@ metadata: {'animals': {'sheep': 2, 'dog': 2}, 'total_legs': 16}
...
```
-See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets with examples.
+See the [Dataset Gallery](https://github.com/open-thought/reasoning-gym/blob/main/GALLERY.md) for a complete list of available datasets with examples.
## Task Overview
@@ -72,6 +72,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets
- `SimpleEquationsDataset`: Generate linear equations with one variable to solve (e.g. "3\*x + 2 = 14")
- `PolynomialEquationsDataset`: Generate polynomial equations with one variable to solve (e.g. "-6*h\*\*4 + 4*h\**2 - 5*h = 0")
+- `PolynomialMultiplicationDataset`: Generate polynomial multiplications (e.g. "(8x^3 + x + 2)\*(y - 3)")
### Arithmetic Tasks
@@ -100,6 +101,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets
- `WordLadderDataset`: Generate word ladder puzzles where one word is transformed into another by changing one letter at a time
- `GroupAnagramsDataset`: Group anagrams together in a list of words
- `RansomNoteDataset`: Check if a ransom note can be created from a given set of letters in a magazine
+- `IsomorphicStringsDataset`: Check if two strings are isomorphic (have the same character mapping)
### Code Tasks
@@ -118,6 +120,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets
- `SyllogismDataset`: Generates a [syllogism](https://en.wikipedia.org/wiki/Syllogism) reasoning dataset
- `AliceInWonderlandDataset`: Generates [AIW](https://openreview.net/forum?id=Mkl7dzjYiW) (Alice In Wonderland) problems with a few variations
- `ZebraDataset`: Generates [Zebra Puzzles](https://en.wikipedia.org/wiki/Zebra_Puzzle) of varying difficulty.
+- `SelfReferenceDataset`: Generates self-referencing logic puzzles.
### Graph Tasks
@@ -134,6 +137,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets
- `MazeDataset`: Generate a maze with a start and a goal
- `CountdownDataset`: Generate number game tasks where numbers and operators must be combined to reach a target value
- `NQueensDataset`: Generate N-Queens puzzles with configurable board size and number of starting queens
+- `TsumegoDataset`: Generate Tsumego capture puzzles with variable board sizes and stone placements
## Future Generator Ideas
diff --git a/pyproject.toml b/pyproject.toml
index c3cc31b7..794077d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "reasoning_gym"
-version = "0.1.3"
+version = "0.1.5"
authors = [
{ name = "Open-Thought community", email = "andreas.koepf@xamla.com" },
]
@@ -31,20 +31,20 @@ license = "Apache-2.0"
license-files = ["LICENSE*"]
[project.optional-dependencies]
-test = [
- "pytest>=7.0.0",
- "pytest-cov>=4.0.0",
-]
+test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"]
[project.urls]
"Homepage" = "https://github.com/open-thought/reasoning-gym"
"Bug Tracker" = "https://github.com/open-thought/reasoning-gym/issues"
-[tool.hatch.build.targets.wheel]
-packages = ["reasoning_gym"]
[tool.hatch.build]
-include = ["reasoning_gym/**/*.txt"]
+packages = ["reasoning_gym"]
+include = [
+ "reasoning_gym/**/*.py",
+ "reasoning_gym/**/*.txt",
+ "reasoning_gym/**/levels/*",
+]
[tool.black]
line-length = 120
@@ -58,6 +58,4 @@ line_length = 120
[tool.pytest.ini_options]
addopts = "-ra -q"
-testpaths = [
- "tests",
-]
+testpaths = ["tests"]
diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py
index 019873ff..ecca7f3f 100644
--- a/reasoning_gym/__init__.py
+++ b/reasoning_gym/__init__.py
@@ -5,7 +5,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasonin
from . import algebra, algorithmic, arithmetic, code, cognition, data, games, geometry, graphs, logic
from .factory import create_dataset, register_dataset
-__version__ = "0.1.3"
+__version__ = "0.1.5"
__all__ = [
"algebra",
"algorithmic",
diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py
index fc7a867a..fc77b977 100644
--- a/reasoning_gym/algebra/__init__.py
+++ b/reasoning_gym/algebra/__init__.py
@@ -1,9 +1,13 @@
+from .complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset
from .intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset
from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset
+from .polynomial_multiplication import PolynomialMultiplicationConfig, PolynomialMultiplicationDataset
from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset
from .simple_integration import SimpleIntegrationConfig, SimpleIntegrationDataset
__all__ = [
+ "ComplexArithmeticConfig",
+ "ComplexArithmeticDataset",
"IntermediateIntegrationConfig",
"IntermediateIntegrationDataset",
"PolynomialEquationsConfig",
@@ -12,4 +16,6 @@ __all__ = [
"SimpleEquationsConfig",
"SimpleIntegrationConfig",
"SimpleIntegrationDataset",
+ "PolynomialMultiplicationConfig",
+ "PolynomialMultiplicationDataset",
]
diff --git a/reasoning_gym/algebra/complex_arithmetic.py b/reasoning_gym/algebra/complex_arithmetic.py
new file mode 100644
index 00000000..7c749eaa
--- /dev/null
+++ b/reasoning_gym/algebra/complex_arithmetic.py
@@ -0,0 +1,147 @@
+import cmath
+import math
+import random
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+from ..factory import ProceduralDataset, register_dataset
+
+
+@dataclass
+class ComplexArithmeticConfig:
+    """Configuration for complex number arithmetic task generation."""
+
+    min_real: int = -10  # Lower bound for randomly drawn real parts
+    max_real: int = 10  # Upper bound for randomly drawn real parts
+    min_imag: int = -10  # Lower bound for randomly drawn imaginary parts
+    max_imag: int = 10  # Upper bound for randomly drawn imaginary parts
+    operations: Tuple[str, ...] = ("+", "-", "*", "/")  # Operators to sample from
+    seed: Optional[int] = None
+    size: int = 500
+
+    def validate(self) -> None:
+        """Validate configuration parameters.
+
+        Raises:
+            AssertionError: If a range is inverted, no operator is given, an
+                operator is unsupported, or division is enabled while the
+                ranges only permit the value 0 (no non-zero divisor exists).
+        """
+        assert self.max_real >= self.min_real, "max_real must be >= min_real"
+        assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag"
+        # An empty tuple would only fail later, when an operator is sampled.
+        assert len(self.operations) > 0, "operations must not be empty"
+        assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator"
+        if "/" in self.operations:
+            # Division re-draws the divisor until it is non-zero; if 0 is the
+            # only possible operand that loop would never terminate.
+            only_zero = self.min_real == self.max_real == 0 and self.min_imag == self.max_imag == 0
+            assert not only_zero, "division requires ranges that allow a non-zero divisor"
+
+
+class ComplexArithmeticDataset(ProceduralDataset):
+    """Generates complex number arithmetic problems.
+
+    Operands have integer real and imaginary parts drawn from the configured
+    ranges. Division problems are constructed backwards from the quotient so
+    that the expected result always has whole-number components.
+    """
+
+    def __init__(self, config: ComplexArithmeticConfig):
+        # Question templates keyed by operator symbol.
+        self._prompt_templates = {
+            "+": "Add the complex numbers: ({a}) + ({b})",
+            "-": "Subtract the complex numbers: ({a}) - ({b})",
+            "*": "Multiply the complex numbers: ({a}) × ({b})",
+            "/": "Divide the complex numbers: ({a}) ÷ ({b})",
+        }
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def _generate_complex(self, rng: random.Random) -> complex:
+        """Draw a complex number with integer parts from the configured ranges."""
+        real = rng.randint(self.config.min_real, self.config.max_real)
+        imag = rng.randint(self.config.min_imag, self.config.max_imag)
+        return complex(real, imag)
+
+    def _format_complex(self, z: complex) -> str:
+        """Render ``z`` in ``a + bi`` notation.
+
+        Purely real or purely imaginary values are printed with two decimal
+        places (e.g. ``3.00`` or ``-5.00i``); values with both parts use the
+        default float representation (e.g. ``-13.0 - 5.0i``).
+        """
+        real, imag = z.real, z.imag
+        if abs(imag) < 1e-10:
+            return f"{real:.2f}"
+        if abs(real) < 1e-10:
+            return f"{imag:.2f}i"
+        sign = "+" if imag >= 0 else "-"
+        return f"{real} {sign} {abs(imag)}i"
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate the problem at position ``idx``.
+
+        Returns:
+            A dict with ``question``, ``answer`` and ``metadata`` (operands as
+            ``(real, imag)`` tuples, the operator, and the integer result).
+        """
+        rng = random.Random(self.seed + idx)
+
+        op = rng.choice(self.config.operations)
+
+        if op == "/":
+            # Work backwards: draw the quotient and a non-zero divisor, then
+            # multiply them so the dividend divides without a remainder.
+            result = self._generate_complex(rng)
+            b = self._generate_complex(rng)
+            while b == 0:
+                b = self._generate_complex(rng)
+            a = result * b
+        else:
+            a = self._generate_complex(rng)
+            b = self._generate_complex(rng)
+            if op == "+":
+                result = a + b
+            elif op == "-":
+                result = a - b
+            else:  # op == "*"
+                result = a * b
+
+        question = self._prompt_templates[op].format(a=self._format_complex(a), b=self._format_complex(b))
+
+        return {
+            "question": question,
+            "answer": self._format_complex(result),
+            "metadata": {
+                "num1": (a.real, a.imag),
+                "num2": (b.real, b.imag),
+                "operation": op,
+                # Operands are integers, so the result has whole-number parts.
+                "result": (int(result.real), int(result.imag)),
+            },
+        }
+
+    @staticmethod
+    def parse_string_to_complex(answer: str) -> Optional[complex]:
+        """Parse an answer such as ``"3 - 5i"`` into a complex number.
+
+        Spaces and letter case are ignored and ``i`` is accepted as the
+        imaginary unit.
+
+        Returns:
+            The parsed value, or None if the text is not a finite complex number.
+        """
+        # complex() natively understands "3", "2j", "j", "-j" and "3-5j", so
+        # no manual rewriting of the imaginary part is needed. (Rewriting
+        # "3j" to "3+1j" used to turn pure-imaginary answers into wrong values.)
+        normalized = answer.replace(" ", "").lower().replace("i", "j")
+        try:
+            value = complex(normalized)
+        except ValueError:
+            return None
+        # NaN would propagate through the distance computation and end up as a
+        # perfect score, because min(1.0, nan) evaluates to 1.0.
+        if not cmath.isfinite(value):
+            return None
+        return value
+
+    def score_answer(self, answer: str, metadata: dict) -> float:
+        """Score an answer by its distance from the expected result.
+
+        Returns:
+            ``exp(-|answer - expected|)``: 1.0 for an exact match, decaying
+            towards 0.0 as the distance grows. Missing or unparseable answers
+            score 0.0.
+        """
+        if answer is None:
+            return 0.0
+
+        try:
+            student_result = self.parse_string_to_complex(answer)
+            if student_result is None:
+                return 0.0
+            expected_result = complex(*metadata["result"])
+            # abs() raises OverflowError for extremely large magnitudes.
+            distance = abs(student_result - expected_result)
+        except (ValueError, TypeError, OverflowError):
+            return 0.0
+
+        return min(1.0, math.exp(-distance))
+
+
+# Register the dataset class together with its config class under the name "complex_arithmetic".
+register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig)
diff --git a/reasoning_gym/algebra/polynomial_multiplication.py b/reasoning_gym/algebra/polynomial_multiplication.py
new file mode 100644
index 00000000..9bcadc66
--- /dev/null
+++ b/reasoning_gym/algebra/polynomial_multiplication.py
@@ -0,0 +1,161 @@
+import random
+import string
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple
+
+import sympy as sp
+from sympy import Eq, Symbol, expand, solve
+
+from ..factory import ProceduralDataset, register_dataset
+
+
+@dataclass
+class PolynomialMultiplicationConfig:
+    """
+    Configuration for polynomial multiplication task generation.
+    """
+
+    min_terms: int = 2  # Minimum number of polynomial terms
+    max_terms: int = 4  # Maximum number of polynomial terms
+    min_value: int = 1  # Minimum value for coefficients
+    max_value: int = 100  # Maximum value for coefficients
+    min_degree: int = 1  # Minimum polynomial degree
+    max_degree: int = 3  # Maximum polynomial degree
+    min_polynomials: int = 2  # Minimum number of polynomials being multiplied
+    max_polynomials: int = 3  # Maximum number of polynomials being multiplied
+    # Use "x" in every polynomial; when False each polynomial gets a random lowercase letter.
+    # (The default used to be the one-element tuple `(True,)` because of a stray trailing comma.)
+    single_variable: bool = True
+    # Allowed operators between terms. Avoid adding '*' or '/' because they will affect the degree.
+    operators: Tuple[str, ...] = ("+", "-")
+    seed: Optional[int] = None
+    size: int = 500
+
+    def validate(self) -> None:
+        """Validate configuration parameters.
+
+        Raises:
+            AssertionError: If any bound is non-positive or inverted, fewer
+                than two polynomials are requested, or an operator other
+                than "+" / "-" is configured.
+        """
+        assert self.min_terms > 0, "min_terms must be positive."
+        assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms."
+
+        assert self.min_value > 0, "min_value must be positive."
+        assert self.max_value >= self.min_value, "max_value must be >= min_value."
+
+        assert self.min_degree >= 1, "min_degree must be >= 1."
+        assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
+
+        assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
+        assert self.max_polynomials >= self.min_polynomials, "max_polynomials must be >= min_polynomials."
+
+        allowed_ops = {"+", "-"}
+        assert len(self.operators) > 0, "operators tuple cannot be empty."
+        assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
+
+
+class PolynomialMultiplicationDataset(ProceduralDataset):
+    """
+    Generates [min_polynomials, max_polynomials] random polynomials of degree in [min_degree, max_degree].
+    - The polynomial is formed by summing random terms of the form: coeff * x^exponent.
+    - Then we find "F = P_0 * ... * P_n" using Sympy.
+    """
+
+    def __init__(self, config: PolynomialMultiplicationConfig):
+        self._prompt_templates = [
+            "Simplify this expression: {polynomial_expr}",
+            "Calculate the following: {polynomial_expr}",
+        ]
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def __getitem__(self, idx: int) -> dict:
+        """
+        Generate a single polynomial multiplication item.
+
+        Returns:
+            A dict with:
+                - question: str (e.g. "Calculate the following: (65*x - 72)*(105*x - 125)")
+                - answer: str (expanded product, e.g. "6825*x**2 - 15685*x + 9000")
+                - metadata: dict with details (polynomial_expr, single_variable, result)
+        """
+        rng = random.Random(self.seed + idx)
+        number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
+        polynomials = [self._generate_polynomial_expr(rng) for _ in range(number_polynomials)]
+
+        polynomial_expr = sp.prod(polynomials)
+        # Render once so that "answer" and metadata["result"] are the same string.
+        product = str(sp.expand(polynomial_expr))
+
+        return {
+            "question": rng.choice(self._prompt_templates).format(
+                polynomial_expr=polynomial_expr,
+            ),
+            # Return text rather than a SymPy object: score_answer works on strings.
+            "answer": product,
+            "metadata": {
+                "polynomial_expr": str(polynomial_expr),
+                "single_variable": self.config.single_variable,
+                "result": product,
+            },
+        }
+
+    def _get_variable(self, rng: random.Random) -> str:
+        """Return "x" in single-variable mode, otherwise a random lowercase letter."""
+        if self.config.single_variable:
+            return "x"
+        return rng.choice(string.ascii_lowercase)
+
+    def _generate_polynomial_expr(self, rng: random.Random):
+        """
+        Randomly generate a polynomial expression of 'degree'.
+        We'll use the config parameters:
+            - min_degree, max_degree: range for the highest exponent
+            - min_terms, max_terms: how many total terms to combine
+            - min_value, max_value: range for coefficients
+            - operators: to decide sign flips or direct addition
+
+        Args:
+            rng: Random number generator
+
+        Returns:
+            SymPy expression holding the sum of the generated terms
+        """
+        variable = self._get_variable(rng)
+        degree = rng.randint(self.config.min_degree, self.config.max_degree)
+
+        x = Symbol(variable)
+
+        num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
+        # Exponents can repeat or skip, but the highest exponent is always included;
+        # the remaining ones are drawn from [0, degree].
+        chosen_exponents = [degree]
+        for _ in range(num_terms - 1):
+            chosen_exponents.append(rng.randint(0, degree))
+
+        # Build the polynomial: sum_{term}( coeff * x^exponent ), with optional sign flip
+        polynomial_expr = 0
+        for exp in chosen_exponents:
+            coeff = rng.randint(self.config.min_value, self.config.max_value)
+            # If '-' in operators, we can randomly flip the sign
+            if "-" in self.config.operators and rng.random() < 0.5:
+                coeff = -coeff
+            polynomial_expr += coeff * (x**exp)
+
+        return polynomial_expr
+
+    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
+        """
+        Score an answer against metadata["result"].
+
+        Returns:
+            1.0 if the answer is algebraically equal to the expected product,
+            0.05 if it parses but differs, 0.01 if it cannot be parsed,
+            and 0.0 if no answer was given.
+        """
+        if answer is None:
+            return 0.0
+
+        try:
+            # SECURITY NOTE: sympy.parse_expr evaluates its input with eval();
+            # only score answers from sources you are prepared to execute code from.
+            predicted_poly = sp.parse_expr(answer)
+            target_poly = sp.parse_expr(metadata["result"])
+
+            # Check if the difference simplifies to zero (i.e. they are equivalent).
+            if sp.simplify(predicted_poly - target_poly) == 0:
+                return 1.0
+            return 0.05 if answer.strip() else 0.01
+        except Exception:
+            # Unparseable answers still earn a small reward.
+            return 0.01
+
+
+# Register the dataset class together with its config class under the name "polynomial_multiplication".
+register_dataset("polynomial_multiplication", PolynomialMultiplicationDataset, PolynomialMultiplicationConfig)
diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py
index e5326af5..4c33d08f 100644
--- a/reasoning_gym/algorithmic/__init__.py
+++ b/reasoning_gym/algorithmic/__init__.py
@@ -9,6 +9,7 @@ Algorithmic tasks for training reasoning capabilities:
from .base_conversion import BaseConversionConfig, BaseConversionDataset
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
+from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
from .letter_counting import LetterCountingConfig, LetterCountingDataset
from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
@@ -51,4 +52,6 @@ __all__ = [
"GroupAnagramsDataset",
"RansomNoteConfig",
"RansomNoteDataset",
+ "IsomorphicStringsConfig",
+ "IsomorphicStringsDataset",
]
diff --git a/reasoning_gym/algorithmic/isomorphic_strings.py b/reasoning_gym/algorithmic/isomorphic_strings.py
new file mode 100644
index 00000000..3b4a59e5
--- /dev/null
+++ b/reasoning_gym/algorithmic/isomorphic_strings.py
@@ -0,0 +1,121 @@
+"""Check if two strings are isomorphic.
+
+Two strings are isomorphic if the characters in one string can be replaced to get the second string.
+
+A popular Leetcode problem:
+https://leetcode.com/problems/isomorphic-strings/description/
+"""
+
+from dataclasses import dataclass
+from random import Random
+from typing import Optional
+
+from ..factory import ProceduralDataset, register_dataset
+
+QUESTION_TEMPLATE = """Two strings are isomorphic if the characters in one string can be replaced to get the second string.
+
+All occurrences of a character must be replaced with another character while preserving the order of characters.
+
+No two characters may map to the same character, but a character may map to itself.
+
+Example 1:
+Input: egg add
+Output: True
+Explanation: The strings s and t can be made identical by:
+ - Mapping 'e' to 'a'.
+ - Mapping 'g' to 'd'.
+
+Example 2:
+Input: foo bar
+Output: False
+Explanation:
+ - The strings cannot be made identical as 'o' needs to be mapped to both 'a' and 'r'.
+
+Return True if the following two strings are isomorphic, or False otherwise:
+{s} {t}
+"""
+
+
+@dataclass
+class IsomorphicStringsConfig:
+    """Settings controlling how Isomorphic Strings questions are generated."""
+
+    max_string_length: int = 10  # Upper bound on the length of each generated string
+    p_solvable: float = 0.5  # Chance that a question is generated as an isomorphic pair
+
+    size: int = 500  # Number of items exposed by the (virtual) dataset
+    seed: Optional[int] = None
+
+    def validate(self):
+        """Assert that every setting lies in its supported range."""
+        assert self.max_string_length >= 2, "max_string_length must be at least 2"
+        assert 0 <= self.p_solvable <= 1, "p_solvable must be between 0 and 1"
+
+
+class IsomorphicStringsDataset(ProceduralDataset):
+    """Generates Isomorphic Strings exercises with configurable difficulty"""
+
+    def __init__(self, config: IsomorphicStringsConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+        self.letters = {chr(i) for i in range(ord("a"), ord("z") + 1)}
+
+    def _check_isomorphic(self, s: str, t: str) -> bool:
+        """Return True if ``s`` and ``t`` are isomorphic.
+
+        The strings are isomorphic when they have equal length and there is a
+        one-to-one character mapping that turns ``s`` into ``t``.
+        """
+        if len(s) != len(t):
+            return False
+
+        mapping, inverse_mapping = {}, {}  # s -> t, t -> s
+        for char_s, char_t in zip(s, t):
+            # setdefault records a first-seen pairing and otherwise returns the
+            # existing one, so a mismatch in either direction is a conflict.
+            if mapping.setdefault(char_s, char_t) != char_t:
+                return False
+            if inverse_mapping.setdefault(char_t, char_s) != char_s:
+                return False
+
+        return True
+
+    def _generate_inputs(self, rng: Random, solvable: bool) -> tuple[str, str]:
+        """Generate the two input strings.
+
+        An isomorphic pair is built first; when ``solvable`` is False one extra
+        position is inserted that maps an already-used character of ``s`` to a
+        different character of ``t``, which breaks the mapping.
+        """
+        # Draw from a sorted list rather than from the set: the iteration order
+        # of a set of str varies between interpreter runs (hash randomization),
+        # which would make the output depend on more than the seed.
+        alphabet = sorted(self.letters)
+        s, t = [], []
+        mapping = {}
+
+        # Generate a valid isomorphic pair first (leave one character for potential conflict)
+        for _ in range(rng.randint(1, self.config.max_string_length - 1)):
+            char_s = rng.choice(alphabet)
+            if char_s not in mapping:
+                # Choose a random character that is not already mapped
+                used = set(mapping.values())
+                char_t = rng.choice([c for c in alphabet if c not in used])
+                mapping[char_s] = char_t
+            else:
+                # Use the existing mapping
+                char_t = mapping[char_s]
+            s.append(char_s)
+            t.append(char_t)
+
+        if not solvable:
+            # Solution should be unsolvable, create conflict
+            letter = rng.choice(list(mapping))
+            conflict = rng.choice([c for c in alphabet if c != mapping[letter]])
+            insert_idx = rng.randint(0, len(s))
+            s.insert(insert_idx, letter)
+            t.insert(insert_idx, conflict)
+
+        return "".join(s), "".join(t)
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single Isomorphic Strings question.
+
+        Returns a dict with "question", "answer" ("True"/"False" as computed by
+        ``_check_isomorphic``) and "metadata" (the two words, the boolean
+        solution and the ``solvable`` flag used during generation).
+        """
+        rng = Random(self.seed + idx)
+
+        solvable = rng.random() < self.config.p_solvable
+        s, t = self._generate_inputs(rng, solvable)
+        answer = self._check_isomorphic(s, t)
+
+        return {
+            "question": QUESTION_TEMPLATE.format(s=s, t=t),
+            "answer": str(answer),
+            "metadata": {"words": [s, t], "solution": answer, "solvable": solvable},
+        }
+
+
+register_dataset("isomorphic_strings", IsomorphicStringsDataset, IsomorphicStringsConfig)
diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py
index 958dcd01..295f6cdf 100644
--- a/reasoning_gym/games/__init__.py
+++ b/reasoning_gym/games/__init__.py
@@ -14,6 +14,7 @@ from .n_queens import NQueensDataset
from .sokoban import SokobanConfig, SokobanDataset
from .sudoku import SudokuConfig, SudokuDataset
from .tower_of_hanoi import HanoiConfig, HanoiDataset
+from .tsumego import TsumegoConfig, TsumegoDataset
__all__ = [
"CountdownConfig",
@@ -31,4 +32,6 @@ __all__ = [
"HanoiConfig",
"HanoiDataset",
"NQueensDataset",
+ "TsumegoConfig",
+ "TsumegoDataset",
]
diff --git a/reasoning_gym/games/tsumego.py b/reasoning_gym/games/tsumego.py
new file mode 100644
index 00000000..be1e4fd6
--- /dev/null
+++ b/reasoning_gym/games/tsumego.py
@@ -0,0 +1,305 @@
+"""Go problem (tsumego) generator"""
+
+"""
+This module generates one-move Tsumego puzzles, which are Go problems focused on tactical capture scenarios.
+
+The puzzles generated here have the following characteristics:
+- They are created on a board of configurable size (with a minimum and maximum board size).
+- A number of stones are randomly placed on the board, subject to a maximum stone limit.
+- A specific capture problem is then constructed by arranging white stones in a randomly chosen plus-, L- or T-shaped formation.
+- Extra liberties surrounding this white group are filled with black stones, except for one key liberty.
+ This forces a situation where a single move by Black (at the remaining liberty) results in a capture.
+- Puzzle generation is deterministic given a seed, which ensures reproducibility.
+
+These puzzles are intended to provide focused practice on reading and executing capturing moves in Go.
+
+TODO: Generate multi-step Tsumego problems.
+"""
+
+import re
+from dataclasses import dataclass
+from random import Random
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+from ..factory import ProceduralDataset, register_dataset
+
+# Added constant to avoid repetition of adjacent directions
+DIRECTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]
+
+
+@dataclass
+class TsumegoConfig:
+    """Configuration for Tsumego problem generation"""
+
+    min_board_size: int = 9  # Smallest board edge length that may be drawn
+    max_board_size: int = 13  # Largest board edge length that may be drawn
+    max_stones: int = 15  # Stone budget; 4 of these are reserved for the capture setup
+    size: int = 100  # Virtual dataset size
+    seed: Optional[int] = None
+
+    def __post_init__(self):
+        """Validate configuration parameters as soon as the config is built."""
+        self.validate()
+
+    def validate(self):
+        """Validate configuration parameters.
+
+        Provided under the same name as the other dataset configs so the
+        checks can be re-run after fields are changed.
+
+        Raises:
+            ValueError: if a board size or the stone budget is out of range.
+        """
+        if self.min_board_size < 5:
+            raise ValueError("min_board_size must be at least 5")
+        if self.max_board_size > 19:
+            raise ValueError("max_board_size must be at most 19")
+        if self.min_board_size > self.max_board_size:
+            raise ValueError("min_board_size must be less than or equal to max_board_size")
+        if self.max_stones < 5:
+            raise ValueError("max_stones must be at least 5")
+
+
+class TsumegoDataset(ProceduralDataset):
+    """Generates Tsumego problems with configurable parameters"""
+
+    def __init__(self, config: TsumegoConfig):
+        self._prompt_templates = [
+            "Tsumego time. Black to play and capture some stones.\nFind the key move.",
+            "I have a Go problem for you. Black moves next - can you capture some of the white stones?",
+            "Here's a Go challenge. Playing as Black, how can you capture as many white stones as possible?",
+        ]
+        # Point that may not be played next because of the ko rule (None if no ko).
+        self._ko_point = None
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def _copy_board(self, board: List[List[str]]) -> List[List[str]]:
+        """Return a deep copy of the board."""
+        return [row[:] for row in board]
+
+    def _get_liberties(self, board: List[List[str]], row: int, col: int) -> Set[Tuple[int, int]]:
+        """Get empty adjacent points (liberties) for a stone"""
+        size = len(board)
+        liberties = set()
+        for dr, dc in DIRECTIONS:
+            r, c = row + dr, col + dc
+            if 0 <= r < size and 0 <= c < size and board[r][c] == ".":
+                liberties.add((r, c))
+        return liberties
+
+    def _get_group(self, board: List[List[str]], row: int, col: int) -> Set[Tuple[int, int]]:
+        """Get all stones in the same group (connected stones of same color).
+
+        Returns an empty set when the starting point is empty.
+        """
+        size = len(board)
+        color = board[row][col]
+        if color == ".":
+            return set()
+
+        group = {(row, col)}
+        # The result is a set, so visiting order is irrelevant; popping from the
+        # end of the list avoids the O(n) cost of list.pop(0).
+        pending = [(row, col)]
+        while pending:
+            r, c = pending.pop()
+            for dr, dc in DIRECTIONS:
+                nr, nc = r + dr, c + dc
+                if 0 <= nr < size and 0 <= nc < size and board[nr][nc] == color and (nr, nc) not in group:
+                    group.add((nr, nc))
+                    pending.append((nr, nc))
+        return group
+
+    def _count_liberties(self, board: List[List[str]], group: Set[Tuple[int, int]]) -> int:
+        """Count total (distinct) liberties for a group of stones"""
+        liberties = set()
+        for row, col in group:
+            liberties.update(self._get_liberties(board, row, col))
+        return len(liberties)
+
+    def _would_capture(self, board: List[List[str]], row: int, col: int, color: str) -> bool:
+        """Check if a move would capture any opponent stones"""
+        size = len(board)
+        opponent = "O" if color == "X" else "X"
+
+        # Make a copy of the board and place the stone
+        board_copy = self._copy_board(board)
+        board_copy[row][col] = color
+
+        checked = set()  # opponent stones whose group was already examined
+        for dr, dc in DIRECTIONS:
+            r, c = row + dr, col + dc
+            if 0 <= r < size and 0 <= c < size and board_copy[r][c] == opponent and (r, c) not in checked:
+                group = self._get_group(board_copy, r, c)
+                checked.update(group)
+                if self._count_liberties(board_copy, group) == 0:
+                    return True
+        return False
+
+    def _is_valid_move(self, board: List[List[str]], row: int, col: int, color: str) -> bool:
+        """Check if a move is legal (not suicide, unless it captures)"""
+        size = len(board)
+        if not (0 <= row < size and 0 <= col < size):
+            return False
+        if board[row][col] != ".":
+            return False
+        if (row, col) == self._ko_point:
+            return False
+
+        # If the move captures opponent stones, it's valid
+        if self._would_capture(board, row, col, color):
+            return True
+
+        # Otherwise the placed stone's group must keep at least one liberty.
+        board_copy = self._copy_board(board)
+        board_copy[row][col] = color
+        group = self._get_group(board_copy, row, col)
+        return self._count_liberties(board_copy, group) > 0
+
+    def _make_move(self, board: List[List[str]], row: int, col: int, color: str) -> bool:
+        """Make a move and update ko point. Returns True if move was valid.
+
+        The board is modified in place: the stone is placed and any adjacent
+        opponent groups left without liberties are removed.
+        """
+        if not self._is_valid_move(board, row, col, color):
+            return False
+
+        self._ko_point = None
+        board[row][col] = color
+        opponent = "O" if color == "X" else "X"
+        # A set, so a group touching the new stone on two sides is counted once.
+        captured_stones: Set[Tuple[int, int]] = set()
+
+        for dr, dc in DIRECTIONS:
+            r, c = row + dr, col + dc
+            if 0 <= r < len(board) and 0 <= c < len(board) and board[r][c] == opponent:
+                group = self._get_group(board, r, c)
+                if self._count_liberties(board, group) == 0:
+                    captured_stones |= group
+
+        # Exactly one stone captured by a lone stone: remember the point as ko.
+        if len(captured_stones) == 1 and len(self._get_group(board, row, col)) == 1:
+            (self._ko_point,) = captured_stones
+
+        for r, c in captured_stones:
+            board[r][c] = "."
+
+        return True
+
+    def _generate_capture_problem(self, size: int, rng: Random) -> Tuple[List[List[str]], Tuple[int, int]]:
+        """Generate a capture problem.
+
+        Random stones are played first, then a white formation is placed on
+        empty points, every liberty of its stones except the designated
+        "forced move" point is filled with black stones, and one decoy stone
+        may be added nearby.
+
+        Returns:
+            (board, forced_move) where forced_move is a (row, col) index pair.
+
+        Raises:
+            RuntimeError: if no formation could be placed within 50 tries.
+        """
+        board = [["." for _ in range(size)] for _ in range(size)]
+        stones_placed = 0
+        max_stones = self.config.max_stones - 4  # Reserve space for capture setup
+
+        # NOTE(review): this loop only ends once enough legal random moves were
+        # found; a max_stones far above the board area looks like it could spin
+        # forever - confirm whether the config should bound it.
+        while stones_placed < max_stones:
+            row = rng.randint(0, size - 1)
+            col = rng.randint(0, size - 1)
+            color = "X" if rng.random() < 0.5 else "O"
+            if board[row][col] == "." and self._is_valid_move(board, row, col, color):
+                self._make_move(board, row, col, color)
+                stones_placed += 1
+
+        tries = 0
+        # Offsets are (d_row, d_col) relative to the randomly chosen anchor point.
+        formation_options = {
+            "plus": {
+                "white_offsets": [(0, 0), (-1, 0), (1, 0), (0, -1)],
+                "forced_move_offset": (0, 1),
+                "neighbor_offsets": [(0, 0), (-1, 0), (1, 0), (0, -1), (0, 1)],
+            },
+            "L": {
+                "white_offsets": [(0, 0), (0, 1), (1, 0)],
+                "forced_move_offset": (1, 1),
+                "neighbor_offsets": [(0, 0), (0, 1), (1, 0), (1, 1)],
+            },
+            "T": {
+                "white_offsets": [(0, -1), (0, 0), (0, 1), (1, 0)],
+                "forced_move_offset": (-1, 0),
+                "neighbor_offsets": [(0, -1), (0, 0), (0, 1), (1, 0), (-1, 0)],
+            },
+        }
+
+        while tries < 50:
+            # The anchor stays one point away from the edge so all offsets are on the board.
+            row = rng.randint(1, size - 2)
+            col = rng.randint(1, size - 2)
+            formation_type = rng.choice(list(formation_options.keys()))
+            formation = formation_options[formation_type]
+            if all(board[row + dr][col + dc] == "." for dr, dc in formation["neighbor_offsets"]):
+                # Place white stones according to chosen formation
+                for dr, dc in formation["white_offsets"]:
+                    board[row + dr][col + dc] = "O"
+                forced_move = (row + formation["forced_move_offset"][0], col + formation["forced_move_offset"][1])
+                white_group = {(row + dr, col + dc) for dr, dc in formation["white_offsets"]}
+                # Fill every liberty of the formation except the forced move with black.
+                extra_liberties = set()
+                for r, c in white_group:
+                    extra_liberties |= self._get_liberties(board, r, c)
+                extra_liberties.discard(forced_move)
+                for r, c in extra_liberties:
+                    board[r][c] = "X"
+
+                # Add decoy stone to enhance puzzle difficulty
+                current_stone_count = sum(cell in "XO" for line in board for cell in line)
+                if current_stone_count < self.config.max_stones + 7:
+                    center = (row, col)  # using the base white stone as center
+                    decoy_candidates = []
+                    for i in range(center[0] - 2, center[0] + 3):
+                        for j in range(center[1] - 2, center[1] + 3):
+                            # Only empty points at Manhattan distance 2 from the center qualify.
+                            if abs(i - center[0]) + abs(j - center[1]) == 2:
+                                if 0 <= i < size and 0 <= j < size and board[i][j] == "." and (i, j) != forced_move:
+                                    decoy_candidates.append((i, j))
+                    if decoy_candidates:
+                        decoy_pos = rng.choice(decoy_candidates)
+                        decoy_color = "X" if rng.random() < 0.5 else "O"
+                        board[decoy_pos[0]][decoy_pos[1]] = decoy_color
+
+                # NOTE(review): this only checks that the forced move is legal,
+                # not that it captures; stones placed on a rejected try also
+                # stay on the board. Confirm both are intended.
+                if self._is_valid_move(board, forced_move[0], forced_move[1], "X"):
+                    return board, forced_move
+            tries += 1
+        raise RuntimeError("Failed to generate a capture problem")
+
+    def _board_to_string(self, board: List[List[str]]) -> str:
+        """Convert board to string representation"""
+        size = len(board)
+        # Column labels
+        cols = " " + " ".join(chr(ord("A") + i) for i in range(size)) + "\n"
+        # Board with row numbers
+        rows = [f"{size-i:2d} {' '.join(row)}" for i, row in enumerate(board)]
+        return cols + "\n".join(rows)
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single Tsumego problem
+
+        Returns:
+            dict with:
+            - "question": Problem description and board state
+            - "answer": Solution move as a coordinate string (e.g. "C4")
+            - "metadata": board size, board rows and the solution string
+        """
+        rng = Random(self.seed + idx if self.seed is not None else None)
+        size = rng.randint(self.config.min_board_size, self.config.max_board_size)
+
+        # Start from a clean ko state so a previous call that raised cannot
+        # influence this item.
+        self._ko_point = None
+        board, solution = self._generate_capture_problem(size, rng)
+        board_str = self._board_to_string(board)
+        solution_str = f"{chr(ord('A')+solution[1])}{size - solution[0]}"
+        self._ko_point = None
+
+        return {
+            "question": (
+                rng.choice(self._prompt_templates) + "\n\n" + board_str + "\n\n"
+                "X - Black\n"
+                "O - White\n\n"
+                "Specify your move in coordinates (e.g. 'C4' for column C, row 4)"
+            ),
+            "answer": solution_str,
+            "metadata": {"difficulty": {"board_size": size}, "board": board, "solution": solution_str},
+        }
+
+    @staticmethod
+    def _parse_coordinate(text: str, board_size: int) -> Optional[Tuple[int, int]]:
+        """Convert a letter-number coordinate such as "C4" into (row, col) indices, or None."""
+        m = re.match(r"^([A-Za-z])(\d+)$", text)
+        if not m:
+            return None
+        try:
+            row = board_size - int(m.group(2))
+        except ValueError:
+            return None
+        return row, ord(m.group(1).upper()) - ord("A")
+
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
+        """Score the answer against the solution.
+
+        Returns 0.0 for a missing answer, 1.0 for the expected move, 0.05 for a
+        well-formed coordinate on the board that is not the expected move, and
+        0.01 for anything else.
+        """
+        if answer is None:
+            return 0.0
+        answer = answer.strip()
+        if not answer:
+            return 0.01
+        metadata = entry["metadata"]
+        board_size = len(metadata["board"])
+
+        # __getitem__ stores the solution as a coordinate string ("C4"), so it
+        # is parsed the same way as the answer; a (row, col) pair is accepted too.
+        solution = metadata["solution"]
+        if isinstance(solution, str):
+            expected = self._parse_coordinate(solution.strip(), board_size)
+        else:
+            expected = tuple(solution)
+
+        move = self._parse_coordinate(answer, board_size)
+        if move is None:
+            return 0.01
+        if move == expected:
+            return 1.0
+
+        row, col = move
+        if 0 <= row < board_size and 0 <= col < board_size:
+            return 0.05
+        return 0.01
+
+
+# Register the dataset
+register_dataset("tsumego", TsumegoDataset, TsumegoConfig)
diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py
index dfa1c7ad..c05c4dba 100644
--- a/reasoning_gym/logic/__init__.py
+++ b/reasoning_gym/logic/__init__.py
@@ -4,6 +4,7 @@ Logic tasks for training reasoning capabilities.
from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
+from .self_reference import SelfReferenceConfig, SelfReferenceDataset
from .syllogisms import SyllogismConfig, SyllogismDataset, Term
from .zebra_puzzles import ZebraConfig, ZebraDataset
@@ -18,4 +19,6 @@ __all__ = [
"Term",
"ZebraConfig",
"ZebraDataset",
+    "SelfReferenceConfig",
+    "SelfReferenceDataset",
]
diff --git a/reasoning_gym/logic/self_reference.py b/reasoning_gym/logic/self_reference.py
new file mode 100644
index 00000000..d8155b4c
--- /dev/null
+++ b/reasoning_gym/logic/self_reference.py
@@ -0,0 +1,373 @@
+from dataclasses import dataclass
+from random import Random
+from typing import Any, Dict, Optional
+
+from ..factory import ProceduralDataset, register_dataset
+
+
+def is_prime(n):
+    """Tell whether ``n`` is prime (values below 2 are never prime)."""
+    if n < 2:
+        return False
+    # Trial division by every candidate up to the integer square root.
+    limit = int(n**0.5) + 1
+    return all(n % divisor != 0 for divisor in range(2, limit))
+
+
+def is_composite(n):
+    """
+    Tell whether n is composite, i.e. greater than 1 and not prime.
+    """
+    if n <= 1:
+        return False
+    return not is_prime(n)
+
+
+def generate_dynamic_puzzle(difficulty, rng):
+    """
+    Dynamically generates a 7-statement self-referential puzzle.
+
+    The seven statements (with parameters determined by this function) are:
+
+      1. "At least a of these 7 statements are true."
+      2. "At most b of these 7 statements are false."
+      3. "Exactly c of these 7 statements are true."
+      4. "Exactly d of these 7 statements are false."
+      5. "Either Statement 3 or Statement 4 is true, but not both."
+      6. "The number of true statements is a prime number."
+      7. "The number of false statements is a composite number."
+
+    An intended number T (1 <= T <= 6) of true statements is drawn and an
+    intended truth assignment is "planted": the values for Statements 6 and 7
+    follow from T, and for Statements 1-4 the numeric parameters (a, b, c, d)
+    are drawn so that each statement evaluates to its intended value at T.
+    The intended value of Statement 5 is drawn at random and is not checked
+    against Statements 3 and 4 here.
+
+    NOTE(review): ``difficulty`` is passed to the parameter choosers but none
+    of them reads it; it is only echoed back in the result. The order of the
+    ``rng`` calls determines the generated puzzle for a given seed.
+
+    Args:
+        difficulty: value stored under 'difficulty' in the result.
+        rng: random source providing choice/sample/randint.
+
+    Returns:
+        dict: A puzzle dictionary containing:
+            - 'n': number of statements (always 7 here),
+            - 'statements_text': a list of 7 strings (one per statement),
+            - 'parameters': a dict with the numeric parameters (for statements 1-4),
+            - 'intended_assignment': the intended truth values (list of 7 booleans),
+            - 'intended_T': the intended number of true statements,
+            - 'difficulty': the difficulty argument, unchanged.
+    """
+    n = 7
+
+    # Choose an intended number of true statements, T, from 1 to 6 (nontrivial).
+    T = rng.choice(range(1, n))
+
+    # For the global statements (6 and 7), the intended truth is forced:
+    intended6 = is_prime(T)  # Statement 6 must be true if T is prime.
+    intended7 = is_composite(n - T)  # Statement 7 must be true if (# false) is composite.
+
+    # Among statements 1-5, we need exactly k trues such that overall the total becomes T.
+    # Let k = T - (truth from statements 6 and 7).
+    forced_true_count = (1 if intended6 else 0) + (1 if intended7 else 0)
+    k = T - forced_true_count
+    # k must be between 0 and 5.
+    if not (0 <= k <= 5):
+        # Out of range (this happens for T == 6, where neither forced statement
+        # is true and k would be 6): fall back to a fixed configuration (T=4).
+        T = 4
+        intended6 = False
+        intended7 = False
+        k = 4  # so that overall T=4.
+        intended_assignment_15 = [True, True, True, True, False]
+    else:
+        # For statements 1-5, randomly choose which ones are intended true.
+        # We'll index these as 0..4 corresponding to statements 1..5.
+        intended_assignment_15 = [False] * 5
+        if k > 0:
+            true_indices = set(rng.sample(range(5), k))
+            for i in true_indices:
+                intended_assignment_15[i] = True
+
+    # Now, for statements 1-4, choose numeric parameters based on whether the statement is
+    # intended to be true or false. The `diff` argument of the helpers below is accepted
+    # but currently unused.
+    #
+    # For statement 1: "At least a of these 7 statements are true."
+    # The condition is: T >= a.
+    def choose_at_least_param(T, intended, diff, rng):
+        # Returns a value of `a` for which (T >= a) equals `intended`.
+        if intended:  # must have a <= T.
+            low = 1
+            high = T
+            # Uniform draw from [1, T].
+            return rng.randint(low, high)
+        else:  # must have a > T.
+            low = T + 1
+            high = n  # a can be at most n.
+            if low > high:
+                return n
+            return rng.randint(low, high)
+
+    a_param = choose_at_least_param(T, intended_assignment_15[0], difficulty, rng)
+
+    # For statement 2: "At most b of these 7 statements are false."
+    # F = n - T, so condition is: (n - T) <= b <=> T >= n - b.
+    def choose_at_most_param(T, intended, diff, rng):
+        # Returns a value of `b` for which ((n - T) <= b) equals `intended`.
+        if intended:  # b must be >= n - T.
+            low = n - T
+            high = n
+            return rng.randint(low, high)
+        else:
+            # b must be < n - T.
+            low = 0
+            high = max(n - T - 1, 0)
+            return rng.randint(low, high)
+
+    b_param = choose_at_most_param(T, intended_assignment_15[1], difficulty, rng)
+
+    # For statement 3: "Exactly c of these 7 statements are true."
+    def choose_exactly_true_param(T, intended, diff, rng):
+        # c equals T when the statement should hold, any other value in 0..n otherwise.
+        if intended:
+            return T
+        else:
+            choices = [x for x in range(0, n + 1) if x != T]
+            return rng.choice(choices)
+
+    c_param = choose_exactly_true_param(T, intended_assignment_15[2], difficulty, rng)
+
+    # For statement 4: "Exactly d of these 7 statements are false."
+    # Condition: (n - T) == d.
+    def choose_exactly_false_param(T, intended, diff, rng):
+        # d equals n - T when the statement should hold, any other value in 0..n otherwise.
+        false_count = n - T
+        if intended:
+            return false_count
+        else:
+            choices = [x for x in range(0, n + 1) if x != false_count]
+            return rng.choice(choices)
+
+    d_param = choose_exactly_false_param(T, intended_assignment_15[3], difficulty, rng)
+
+    # For statement 5: "Either Statement 3 or Statement 4 is true, but not both."
+    # No parameter is needed; its intended truth value is simply taken from the
+    # random assignment above.
+
+    # Build the intended assignment for all 7 statements.
+    # For statements 1-5, we use our generated intended_assignment_15.
+    intended_assignment = [
+        intended_assignment_15[0],
+        intended_assignment_15[1],
+        intended_assignment_15[2],
+        intended_assignment_15[3],
+        intended_assignment_15[4],
+        intended6,
+        intended7,
+    ]
+
+    # (If the total intended true count doesn't equal T, adjust statement 5.)
+    current_T = sum(intended_assignment)
+    if current_T != T:
+        # Since only statement 5 is free (its parameter wasn't numeric),
+        # force its intended truth to be what is needed.
+        intended_assignment[4] = T - (current_T - (1 if intended_assignment[4] else 0)) == 1
+
+    # Now build the text for each statement.
+    statements_text = [
+        f"Statement 1: 'At least {a_param} of these 7 statements are true.'",
+        f"Statement 2: 'At most {b_param} of these 7 statements are false.'",
+        f"Statement 3: 'Exactly {c_param} of these 7 statements are true.'",
+        f"Statement 4: 'Exactly {d_param} of these 7 statements are false.'",
+        "Statement 5: 'Either Statement 3 or Statement 4 is true, but not both.'",
+        "Statement 6: 'The number of true statements is a prime number.'",
+        "Statement 7: 'The number of false statements is a composite number.'",
+    ]
+
+    return {
+        "n": n,
+        "statements_text": statements_text,
+        "parameters": {
+            "a": a_param,
+            "b": b_param,
+            "c": c_param,
+            "d": d_param,
+        },
+        "intended_assignment": intended_assignment,
+        "intended_T": T,
+        "difficulty": difficulty,
+    }
+
+
+def verify_solution_dynamic(puzzle, solution):
+    """
+    Checks whether a candidate truth assignment is self-consistent for a puzzle.
+
+    An assignment is consistent when each statement's marked truth value agrees
+    with whether its claim actually holds, given T true and F false statements:
+
+      1. "At least a ... are true."   holds when T >= a
+      2. "At most b ... are false."   holds when F <= b
+      3. "Exactly c ... are true."    holds when T == c
+      4. "Exactly d ... are false."   holds when F == d
+      5. "Either Statement 3 or Statement 4 is true, but not both."
+                                      holds when solution[2] != solution[3]
+      6. "... true statements is a prime number."        holds when is_prime(T)
+      7. "... false statements is a composite number."   holds when is_composite(F)
+
+    Parameters:
+        puzzle (dict): The puzzle dictionary returned by generate_dynamic_puzzle.
+        solution (list of bool): A candidate assignment (length 7).
+
+    Returns:
+        bool: True if candidate is self-consistent; False otherwise
+        (including when the candidate has the wrong length).
+    """
+    n = puzzle["n"]
+    if len(solution) != n:
+        return False
+
+    T = sum(solution)
+    F = n - T
+    params = puzzle["parameters"]
+
+    # One lazily evaluated claim per statement, in statement order.
+    claims = (
+        lambda: T >= params["a"],
+        lambda: F <= params["b"],
+        lambda: T == params["c"],
+        lambda: F == params["d"],
+        lambda: solution[2] != solution[3],
+        lambda: is_prime(T),
+        lambda: is_composite(F),
+    )
+
+    for position, claim in enumerate(claims):
+        holds = claim()
+        # Marked true but the claim fails, or marked false but the claim holds.
+        if bool(solution[position]) != bool(holds):
+            return False
+
+    return True
+
+
+def print_puzzle_dynamic(puzzle):
+    """Return the puzzle's statements as text, one " - "-prefixed line each.
+
+    Despite the name, nothing is printed; the formatted string is returned.
+    """
+    return "".join(" - " + statement + "\n" for statement in puzzle["statements_text"])
+
+
+def solve_puzzle_dynamic(puzzle):
+    """
+    Brute-forces the puzzle: enumerates all 2^n truth assignments (n taken
+    from the puzzle) and returns, in enumeration order, those accepted by
+    verify_solution_dynamic.
+    """
+    n = puzzle["n"]
+    # Bit j of the counter is the truth value of statement j + 1.
+    candidates = ([(mask >> bit) & 1 == 1 for bit in range(n)] for mask in range(2**n))
+    return [candidate for candidate in candidates if verify_solution_dynamic(puzzle, candidate)]
+
+
+@dataclass
+class SelfReferenceConfig:
+    """Settings for generating SelfReference puzzles."""
+
+    difficulty: int = 5  # Handed to the puzzle generator; must lie in 1..10
+    seed: Optional[int] = None
+    size: int = 500
+
+    def validate(self):
+        """Assert that the configured difficulty is within the supported range."""
+        assert self.difficulty >= 1 and self.difficulty <= 10, "difficulty must be between 1 and 10"
+
+
+class SelfReferenceDataset(ProceduralDataset):
+    """Generates self-referential puzzles"""
+
+    def __init__(self, config: SelfReferenceConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single SelfReference task
+
+        Returns:
+            dict with keys:
+                - question: str, the task description
+                - answer: int, the number of self-consistent truth assignments
+                - metadata: dict (currently empty)
+        """
+        rng = Random(self.seed + idx)
+
+        # Generate puzzle
+        puzzle = generate_dynamic_puzzle(self.config.difficulty, rng)
+        puzz_s = (
+            "Given the truthfulness of these statements, please tell me the number of possible solutions: \n"
+            + print_puzzle_dynamic(puzzle)
+        )
+
+        # Solve puzzle by exhaustive search; the answer is the solution count.
+        solutions = solve_puzzle_dynamic(puzzle)
+        answer = len(solutions)
+
+        return {
+            "question": puzz_s,
+            "answer": answer,
+            "metadata": {},
+        }
+
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
+        """Determine if the solution provided solves the SelfReference task.
+
+        Both sides are compared as strings, so an int or str answer is accepted.
+
+        Args:
+            answer (Optional[str]): The user's answer.
+            entry (Dict[str, Any]): The original dataset entry containing the correct answer.
+
+        Returns:
+            float: 1.0 for a matching answer, 0.1 for any other answer,
+            0.0 when no answer was given.
+        """
+        if answer is None:
+            return 0.0
+        if str(answer) != str(entry["answer"]):
+            return 0.1
+        return 1.0
+
+
+register_dataset("self_reference", SelfReferenceDataset, SelfReferenceConfig)
diff --git a/reasoning_gym/logic/syllogisms.py b/reasoning_gym/logic/syllogisms.py
index a5bbb219..37b87a6f 100644
--- a/reasoning_gym/logic/syllogisms.py
+++ b/reasoning_gym/logic/syllogisms.py
@@ -22,23 +22,21 @@ class Term:
self.name = name
self.plural = plural
+    def __repr__(self) -> str:
+        """Render the term as ``Term(<name>, <plural>)`` for debugging output."""
+        return "Term({}, {})".format(self.name, self.plural)
+
@dataclass
class SyllogismConfig:
"""Configuration for syllogism task generation"""
- # Lists of terms to use in syllogisms
- terms: List[Term] = None # Will be populated with defaults if None
-
# Control which quantifiers to use
allow_all: bool = True
allow_no: bool = True
allow_some: bool = True
allow_some_not: bool = True
- # Whether to include invalid syllogisms as negative examples
- include_invalid: bool = True
-
# Percentage of invalid examples if included (0.0 to 1.0)
invalid_ratio: float = 0.3
@@ -101,7 +99,7 @@ class SyllogismDataset(ProceduralDataset):
def __init__(self, config: SyllogismConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
- self.terms = self.DEFAULT_TERMS if config.terms is None else config.terms
+ self.terms = self.DEFAULT_TERMS
def _get_allowed_quantifiers(self) -> List[Quantifier]:
"""Get list of allowed quantifiers based on config"""
@@ -116,95 +114,126 @@ class SyllogismDataset(ProceduralDataset):
quantifiers.append(Quantifier.SOME_NOT)
return quantifiers
+ @staticmethod
def _is_valid_syllogism(
- self,
- premise1: Tuple[Quantifier, Term, Term],
- premise2: Tuple[Quantifier, Term, Term],
- conclusion: Tuple[Quantifier, Term, Term],
+ premise1: Tuple[Quantifier, "Term", "Term"],
+ premise2: Tuple[Quantifier, "Term", "Term"],
+ conclusion: Tuple[Quantifier, "Term", "Term"],
) -> bool:
"""
- Check if a syllogism is logically valid using classical logic rules.
-
- Rules implemented:
- 1. Universal Affirmative (ALL):
- - If both premises are ALL, conclusion must be ALL
- - ALL A are B + ALL B are C → ALL A are C (Barbara)
-
- 2. Universal Negative (NO):
- - If one premise is NO and other is ALL, conclusion must be NO
- - NO A are B + ALL C are B → NO A are C (Celarent)
- - ALL A are B + NO C are B → NO A are C (Cesare)
-
- 3. Particular Affirmative (SOME):
- - If one premise is SOME and other is ALL, conclusion must be SOME
- - SOME A are B + ALL B are C → SOME A are C (Darii)
- - ALL A are B + SOME C are B → SOME A are C (Disamis)
-
- 4. Particular Negative (SOME_NOT):
- - If one premise is SOME_NOT and other is ALL, conclusion can be SOME_NOT
- - SOME A are not B + ALL B are C → SOME A are not C (Ferio)
- - ALL A are B + SOME C are not B → SOME A are not C (Festino)
-
- 5. Invalid combinations:
- - Two negative premises never yield a valid conclusion
- - Two particular premises never yield a valid conclusion
- - If both premises are particular, no valid conclusion
- - If conclusion is universal but either premise is particular, invalid
+ Checks whether a given syllogism is valid under classical (Aristotelian) rules,
+ including the distribution rule:
+ - If a term is distributed in the conclusion, it must be distributed
+ in the premise where it appears as subject/predicate.
"""
- q1, t1_1, t1_2 = premise1
- q2, t2_1, t2_2 = premise2
- qc, tc_1, tc_2 = conclusion
- # Rule 5: Two negative premises -> invalid
- if q1 in (Quantifier.NO, Quantifier.SOME_NOT) and q2 in (Quantifier.NO, Quantifier.SOME_NOT):
+ # --- 1) Extract data ---
+ q1, p1_subj, p1_pred = premise1
+ q2, p2_subj, p2_pred = premise2
+ q3, c_subj, c_pred = conclusion
+
+ negative_set = {Quantifier.NO, Quantifier.SOME_NOT}
+ particular_set = {Quantifier.SOME, Quantifier.SOME_NOT}
+ universal_set = {Quantifier.ALL, Quantifier.NO}
+
+ # --- 2) Identify a unique middle term ---
+ premise1_terms = {p1_subj, p1_pred}
+ premise2_terms = {p2_subj, p2_pred}
+ common_terms = premise1_terms.intersection(premise2_terms)
+
+ if len(common_terms) != 1:
+ return False
+ middle_term = next(iter(common_terms))
+
+ # Gather all terms => must be exactly 3 distinct terms
+ all_terms = premise1_terms.union(premise2_terms)
+ if len(all_terms) != 3:
return False
- # Rule 5: Two particular premises -> invalid
- if q1 in (Quantifier.SOME, Quantifier.SOME_NOT) and q2 in (Quantifier.SOME, Quantifier.SOME_NOT):
+ # The conclusion must use the other two terms (not the middle)
+ other_two = all_terms - {middle_term}
+ conclusion_terms = {c_subj, c_pred}
+ if conclusion_terms != other_two:
return False
- # Rule 5: Universal conclusion with particular premise -> invalid
- if qc in (Quantifier.ALL, Quantifier.NO) and (
- q1 in (Quantifier.SOME, Quantifier.SOME_NOT) or q2 in (Quantifier.SOME, Quantifier.SOME_NOT)
- ):
+ # --- 3) Identify which premise is major vs. minor ---
+ def premise_contains(premise, term):
+ return (premise[1] == term) or (premise[2] == term)
+
+ if premise_contains(premise1, c_pred):
+ major = premise1
+ minor = premise2
+ elif premise_contains(premise2, c_pred):
+ major = premise2
+ minor = premise1
+ else:
return False
- # Rule 1: Barbara syllogism
- if q1 == Quantifier.ALL and q2 == Quantifier.ALL:
- if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2:
- return qc == Quantifier.ALL
+ # The minor premise must contain the conclusion's subject
+ if not premise_contains(minor, c_subj):
+ return False
- # Rule 2: Celarent syllogism
- if q1 == Quantifier.NO and q2 == Quantifier.ALL:
- if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2:
- return qc == Quantifier.NO
+ # --- 4) Quick checks (traditional “no two negative,” etc.) ---
+ if (q1 in negative_set) and (q2 in negative_set):
+ return False
+ if (q1 in particular_set) and (q2 in particular_set):
+ return False
+ if q3 in universal_set:
+ if (q1 in particular_set) or (q2 in particular_set):
+ return False
+ if q3 in negative_set:
+ if not ((q1 in negative_set) or (q2 in negative_set)):
+ return False
- # Rule 2: Cesare syllogism
- if q1 == Quantifier.ALL and q2 == Quantifier.NO:
- if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2:
- return qc == Quantifier.NO
+ # --- 5) Distribution checks ---
+ def distribution(q: Quantifier):
+ if q == Quantifier.ALL: # A
+ return (True, False)
+ elif q == Quantifier.NO: # E
+ return (True, True)
+ elif q == Quantifier.SOME: # I
+ return (False, False)
+ elif q == Quantifier.SOME_NOT: # O
+ return (False, True)
+ else:
+ raise ValueError(f"Unknown quantifier: {q}")
- # Rule 3: Darii syllogism
- if q1 == Quantifier.SOME and q2 == Quantifier.ALL:
- if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2:
- return qc == Quantifier.SOME
+ # Conclusion distribution
+ dist_c_subj, dist_c_pred = distribution(q3)
- # Rule 3: Disamis syllogism
- if q1 == Quantifier.ALL and q2 == Quantifier.SOME:
- if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2:
- return qc == Quantifier.SOME
+ # Major premise distribution
+ q_major, major_subj, major_pred = major
+ dist_major_subj, dist_major_pred = distribution(q_major)
- # Rule 4: Ferio syllogism
- if q1 == Quantifier.SOME_NOT and q2 == Quantifier.ALL:
- if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2:
- return qc == Quantifier.SOME_NOT
+ # Minor premise distribution
+ q_minor, minor_subj, minor_pred = minor
+ dist_minor_subj, dist_minor_pred = distribution(q_minor)
- # Rule 4: Festino syllogism
- if q1 == Quantifier.ALL and q2 == Quantifier.SOME_NOT:
- if t1_2 == t2_1 and tc_1 == t1_1 and tc_2 == t2_2:
- return qc == Quantifier.SOME_NOT
+ # If the conclusion's subject is distributed, check it in the minor premise
+ if dist_c_subj:
+ if c_subj == minor_subj:
+ if not dist_minor_subj:
+ return False
+ elif c_subj == minor_pred:
+ if not dist_minor_pred:
+ return False
- return False
+ # If the conclusion's predicate is distributed, check it in the major premise
+ if dist_c_pred:
+ if c_pred == major_subj:
+ if not dist_major_subj:
+ return False
+ elif c_pred == major_pred:
+ if not dist_major_pred:
+ return False
+
+ # If either premise is negative, the conclusion must be negative.
+ if (q1 in negative_set) or (q2 in negative_set):
+ if q3 not in negative_set:
+ return False
+
+ # If all checks pass, it's valid
+ return True
def _format_quantifier_statement(self, quantifier: Quantifier, subject: Term, predicate: Term) -> str:
"""Format a quantified statement in natural language"""
@@ -219,18 +248,29 @@ class SyllogismDataset(ProceduralDataset):
terms = rng.sample(self.terms, 3)
quantifiers = self._get_allowed_quantifiers()
- # Generate premises and conclusion
- premise1 = (rng.choice(quantifiers), terms[0], terms[1])
- premise2 = (rng.choice(quantifiers), terms[1], terms[2])
- conclusion = (rng.choice(quantifiers), terms[0], terms[2])
+ target_valid = rng.random() > self.config.invalid_ratio # Invert ratio to match meaning
+ max_attempts = 100
+ attempts = 0
- # Decide if this should be a valid or invalid syllogism
- is_valid = True
- if self.config.include_invalid and rng.random() < self.config.invalid_ratio:
- is_valid = False
- # If should be invalid, regenerate conclusion until invalid
- while self._is_valid_syllogism(premise1, premise2, conclusion):
- conclusion = (rng.choice(quantifiers), terms[0], terms[2])
+ while attempts < max_attempts:
+ # Generate premises and conclusion
+ premise1 = (rng.choice(quantifiers), terms[0], terms[1])
+ premise2 = (rng.choice(quantifiers), terms[1], terms[2])
+ conclusion = (rng.choice(quantifiers), terms[0], terms[2])
+
+ # Check if validity matches target
+ is_valid = self._is_valid_syllogism(premise1, premise2, conclusion)
+ if is_valid == target_valid:
+ break
+
+ attempts += 1
+
+ if attempts >= max_attempts:
+ # If we couldn't find a matching syllogism, return a basic valid one
+ premise1 = (Quantifier.ALL, terms[0], terms[1])
+ premise2 = (Quantifier.ALL, terms[1], terms[2])
+ conclusion = (Quantifier.ALL, terms[0], terms[2])
+ is_valid = True
# Format the syllogism as text
premise1_text = self._format_quantifier_statement(premise1[0], premise1[1], premise1[2])
diff --git a/tests/test_complex_arithmetic.py b/tests/test_complex_arithmetic.py
new file mode 100644
index 00000000..0d369fc1
--- /dev/null
+++ b/tests/test_complex_arithmetic.py
@@ -0,0 +1,90 @@
+import pytest
+
+from reasoning_gym.algebra.complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset
+
+
+def test_complex_arithmetic_basic():
+ """Test basic functionality of complex arithmetic dataset."""
+ config = ComplexArithmeticConfig(
+ min_real=-5, max_real=5, min_imag=-5, max_imag=5, operations=("+", "-", "*", "/"), seed=42, size=10
+ )
+ dataset = ComplexArithmeticDataset(config)
+
+ print(dataset)
+
+ # Test dataset size
+ assert len(dataset) == 10
+
+ # Test a specific item
+ item = dataset[0]
+ assert "question" in item
+ assert "answer" in item
+ assert "metadata" in item
+
+ # Add more detailed assertions
+ assert isinstance(item["question"], str)
+ assert isinstance(item["answer"], str)
+ assert isinstance(item["metadata"], dict)
+
+ # Check metadata structure
+ assert "num1" in item["metadata"]
+ assert "num2" in item["metadata"]
+ assert "operation" in item["metadata"]
+ assert "result" in item["metadata"]
+
+ # Check data types in metadata
+ assert isinstance(item["metadata"]["num1"], tuple)
+ assert isinstance(item["metadata"]["num2"], tuple)
+ assert len(item["metadata"]["num1"]) == 2 # Real and imaginary parts
+ assert len(item["metadata"]["num2"]) == 2
+ assert isinstance(item["metadata"]["operation"], str)
+ assert isinstance(item["metadata"]["result"], tuple)
+
+ # Make sure answer matches the result in metadata
+ # result is a tuple of two floats (real, imag) and answer is a string
+ # answer is formatted as "real + imagi"
+ assert ComplexArithmeticDataset.parse_string_to_complex(item["answer"]) == complex(*item["metadata"]["result"])
+
+
+def test_complex_arithmetic_scoring():
+ """Test scoring function with various answer formats and accuracies."""
+ config = ComplexArithmeticConfig(seed=42)
+ dataset = ComplexArithmeticDataset(config)
+
+ # Test case with answer 3 + 2i
+ metadata = {"result": (3.0, 2.0)}
+
+ # Test exact matches (should get score of 1.0)
+ assert dataset.score_answer("3 + 2i", metadata) == 1.0
+ assert dataset.score_answer("3+2i", metadata) == 1.0
+ assert dataset.score_answer("3.0 + 2.0i", metadata) == 1.0
+
+ # Test answers with small errors (should get high but < 1.0 scores)
+ print(dataset.score_answer("3.1 + 2i", metadata))
+ assert 0.9 < dataset.score_answer("3.1 + 2i", metadata) < 1.0
+ assert 0.9 < dataset.score_answer("3 + 2.1i", metadata) < 1.0
+ assert 0.7 < dataset.score_answer("3.1 + 2.1i", metadata) < 0.95
+
+ # Test answers with moderate errors (should get medium scores)
+ assert 0.3 < dataset.score_answer("4 + 2i", metadata) < 0.4
+ assert 0.3 < dataset.score_answer("3 + 3i", metadata) < 0.4
+
+ # Test answers with large errors (should get very low scores)
+ assert dataset.score_answer("10 + 10i", metadata) < 0.01
+
+ # Test invalid answers (should get 0.0)
+ assert dataset.score_answer("invalid", metadata) == 0.0
+ assert dataset.score_answer(None, metadata) == 0.0
+ assert dataset.score_answer("inf + 2i", metadata) == 0.0
+
+
+def test_complex_arithmetic_division_by_zero():
+ """Test that division by zero is handled properly."""
+ config = ComplexArithmeticConfig(operations=("/",), seed=42) # Only test division
+ dataset = ComplexArithmeticDataset(config)
+
+ # Check multiple items to ensure no division by zero
+ for i in range(10):
+ item = dataset[i]
+ num2 = complex(*item["metadata"]["num2"])
+ assert num2 != 0
diff --git a/tests/test_isomorphic_strings.py b/tests/test_isomorphic_strings.py
new file mode 100644
index 00000000..6e515cf7
--- /dev/null
+++ b/tests/test_isomorphic_strings.py
@@ -0,0 +1,108 @@
+"""Tests for Isomorphic Strings questions generation"""
+
+import json
+
+import pytest
+
+from reasoning_gym.algorithmic.isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
+
+
+def test_isomorphic_strings_config_validation():
+ """Test that invalid configs raise appropriate errors"""
+ with pytest.raises(AssertionError):
+ config = IsomorphicStringsConfig(max_string_length=-1) # Negative not allowed
+ config.validate()
+
+ with pytest.raises(AssertionError):
+ config = IsomorphicStringsConfig(max_string_length=0) # Zero not allowed
+ config.validate()
+
+ with pytest.raises(AssertionError):
+ config = IsomorphicStringsConfig(max_string_length=1) # One not allowed
+ config.validate()
+
+ with pytest.raises(AssertionError):
+ config = IsomorphicStringsConfig(p_solvable=-0.01) # < 0 not allowed
+ config.validate()
+
+ with pytest.raises(AssertionError):
+ config = IsomorphicStringsConfig(p_solvable=1.01) # > 1 not allowed
+ config.validate()
+
+
+def test_isomorphic_strings_dataset_deterministic():
+ """Test that dataset generates same items with same seed"""
+ config = IsomorphicStringsConfig(seed=42, size=10)
+ dataset1 = IsomorphicStringsDataset(config)
+ dataset2 = IsomorphicStringsDataset(config)
+
+ for i in range(len(dataset1)):
+ assert dataset1[i] == dataset2[i]
+
+
+def test_isomorphic_strings_dataset_items():
+ """Test basic properties of generated items"""
+ config = IsomorphicStringsConfig(max_string_length=10, size=10, seed=42)
+ dataset = IsomorphicStringsDataset(config)
+
+ for i in range(len(dataset)):
+ item = dataset[i]
+ # Check item structure
+ assert isinstance(item, dict)
+ assert "question" in item
+ assert "answer" in item
+ assert "metadata" in item
+
+ # Check metadata
+ assert "words" in item["metadata"]
+ assert "solution" in item["metadata"]
+ assert "solvable" in item["metadata"]
+
+ words = item["metadata"]["words"]
+ solution = item["metadata"]["solution"]
+ solvable = item["metadata"]["solvable"]
+
+ # Verify list dimensions
+ assert len(words) == 2
+ assert solution in {True, False}
+ assert solvable in {True, False}
+ assert solution == solvable
+
+
+def test_isomorphic_strings_dataset_iteration():
+ """Test that iteration respects dataset size"""
+ config = IsomorphicStringsConfig(size=5, seed=42)
+ dataset = IsomorphicStringsDataset(config)
+
+ items = list(dataset)
+ assert len(items) == config.size
+
+ # Test multiple iterations yield same items
+ assert items == list(dataset)
+
+
+def test_isomorphic_strings_answer():
+ """Test the _check_isomorphic method"""
+ config = IsomorphicStringsConfig(seed=42)
+ dataset = IsomorphicStringsDataset(config)
+
+ # General use case
+ s, t = "foo", "bar"
+ assert dataset._check_isomorphic(s, t) == False
+
+ s, t = "foo", "baa"
+ assert dataset._check_isomorphic(s, t) == True
+
+ # Unequal lengths
+ s, t = "foo", "bo"
+ assert dataset._check_isomorphic(s, t) == False
+
+ # Empty strings
+ (
+ s,
+ t,
+ ) = (
+ "",
+ "",
+ )
+ assert dataset._check_isomorphic(s, t) == True
diff --git a/tests/test_polynomial_multiplication.py b/tests/test_polynomial_multiplication.py
new file mode 100644
index 00000000..a27bd6bf
--- /dev/null
+++ b/tests/test_polynomial_multiplication.py
@@ -0,0 +1,166 @@
+import pytest
+import sympy as sp
+
+from reasoning_gym import create_dataset
+from reasoning_gym.algebra.polynomial_multiplication import (
+ PolynomialMultiplicationConfig,
+ PolynomialMultiplicationDataset,
+)
+
+
+def test_polynomial_config_validation():
+ """Test that invalid configs raise appropriate errors"""
+ with pytest.raises(AssertionError):
+ PolynomialMultiplicationConfig(min_terms=0).validate()
+
+ with pytest.raises(AssertionError):
+ PolynomialMultiplicationConfig(min_value=0).validate()
+
+ with pytest.raises(AssertionError):
+ PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate()
+
+ with pytest.raises(AssertionError):
+ PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
+
+ with pytest.raises(AssertionError):
+ PolynomialMultiplicationConfig(operators=("^",)).validate()
+
+ with pytest.raises(AssertionError):
+ PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
+
+
+def test_polynomial_multiplication_dataset_basic():
+ """Test dataset creation and length"""
+ dataset_size = 50
+ config = PolynomialMultiplicationConfig(
+ min_terms=2,
+ max_terms=3,
+ min_value=1,
+ max_value=5,
+ min_degree=1,
+ max_degree=2,
+ min_polynomials=2,
+ max_polynomials=3,
+ single_variable=True,
+ seed=42,
+ size=dataset_size,
+ )
+
+ dataset = PolynomialMultiplicationDataset(config)
+
+ assert len(dataset) == dataset_size
+
+
+def test_polynomial_equations_dataset_items():
+ """Test that generated items have correct structure"""
+ ds = create_dataset(
+ "polynomial_multiplication",
+ min_terms=2,
+ max_terms=3,
+ min_value=1,
+ max_value=5,
+ min_degree=1,
+ max_degree=2,
+ min_polynomials=2,
+ max_polynomials=5,
+ single_variable=False,
+ size=3,
+ seed=100,
+ )
+
+ for item in ds:
+ assert "question" in item
+ assert "answer" in item
+ assert "metadata" in item
+
+ # Check metadata
+ assert isinstance(item["metadata"]["polynomial_expr"], str)
+ assert isinstance(item["metadata"]["single_variable"], bool)
+
+ # Check polynomial_expr existence
+ poly_str = item["metadata"]["polynomial_expr"]
+ # Ensure it can parse with sympy
+ sp.sympify(poly_str)
+
+
+def test_polynomial_equations_dataset_deterministic():
+ """Test dataset reproducibility with fixed seed."""
+ cfg = PolynomialMultiplicationConfig(seed=999, size=3)
+ ds1 = PolynomialMultiplicationDataset(cfg)
+ ds2 = PolynomialMultiplicationDataset(cfg)
+
+ for i in range(len(ds1)):
+ assert ds1[i] == ds2[i], "Polynomial datasets with same seed should match exactly."
+
+
+def test_polynomial_solutions_evaluation():
+ """Test that the stored answer equals the expanded polynomial product."""
+ ds = create_dataset(
+ "polynomial_multiplication",
+ min_terms=2,
+ max_terms=4,
+ min_value=1,
+ max_value=10,
+ min_degree=1,
+ max_degree=3,
+ min_polynomials=2,
+ max_polynomials=5,
+ single_variable=False,
+ size=5,
+ seed=42,
+ )
+
+ for item in ds:
+ # Extract the polynomial expression
+ poly_str = item["metadata"]["polynomial_expr"]
+ # Get the polynomial product
+ poly_expr = sp.expand(poly_str)
+
+ # Verify that the stored answer equals the expanded product
+ assert poly_expr == item["answer"]
+
+
+def test_score_function():
+ """Test the scoring function on a single-variable problem."""
+ ds = create_dataset(
+ "polynomial_multiplication",
+ min_terms=2,
+ max_terms=4,
+ min_value=1,
+ max_value=10,
+ min_degree=1,
+ max_degree=3,
+ min_polynomials=2,
+ max_polynomials=5,
+ single_variable=True,
+ size=1,
+ seed=42,
+ )
+
+ assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
+ assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]["metadata"]) == 1
+ assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
+ assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
+
+
+def test_multivariate_score_function():
+ """Test the scoring function on a multivariate problem."""
+ ds = create_dataset(
+ "polynomial_multiplication",
+ min_terms=2,
+ max_terms=4,
+ min_value=1,
+ max_value=10,
+ min_degree=1,
+ max_degree=3,
+ min_polynomials=2,
+ max_polynomials=5,
+ single_variable=False,
+ size=1,
+ seed=42,
+ )
+
+ assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
+ assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]["metadata"]) == 1
+ assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
+ assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
diff --git a/tests/test_self_reference.py b/tests/test_self_reference.py
new file mode 100644
index 00000000..66f15081
--- /dev/null
+++ b/tests/test_self_reference.py
@@ -0,0 +1,55 @@
+import pytest
+
+from reasoning_gym.logic.self_reference import SelfReferenceConfig, SelfReferenceDataset
+
+
+def test_self_reference():
+ """Test basic properties and solution of generated items"""
+
+ # Easy
+ config = SelfReferenceConfig(seed=42, size=20, difficulty=1)
+ dataset = SelfReferenceDataset(config)
+
+ for item in dataset:
+ assert isinstance(item, dict)
+ assert "question" in item
+ assert "answer" in item
+ assert "metadata" in item
+
+ # Test the scoring
+ assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
+ assert dataset.score_answer(answer=99, entry=item) == 0.1
+ assert dataset.score_answer(answer="99", entry=item) == 0.1
+ assert dataset.score_answer(answer=None, entry=item) == 0.0
+
+ # Medium
+ config = SelfReferenceConfig(seed=42, size=1, difficulty=5)
+ dataset = SelfReferenceDataset(config)
+
+ for item in dataset:
+ assert isinstance(item, dict)
+ assert "question" in item
+ assert "answer" in item
+ assert "metadata" in item
+
+ # Test the scoring
+ assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
+ assert dataset.score_answer(answer=99, entry=item) == 0.1
+ assert dataset.score_answer(answer="99", entry=item) == 0.1
+ assert dataset.score_answer(answer=None, entry=item) == 0.0
+
+ # Hard
+ config = SelfReferenceConfig(seed=42, size=1, difficulty=10)
+ dataset = SelfReferenceDataset(config)
+
+ for item in dataset:
+ assert isinstance(item, dict)
+ assert "question" in item
+ assert "answer" in item
+ assert "metadata" in item
+
+ # Test the scoring
+ assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
+ assert dataset.score_answer(answer=99, entry=item) == 0.1
+ assert dataset.score_answer(answer="99", entry=item) == 0.1
+ assert dataset.score_answer(answer=None, entry=item) == 0.0
diff --git a/tests/test_syllogisms.py b/tests/test_syllogisms.py
index 498be586..9f2c5607 100644
--- a/tests/test_syllogisms.py
+++ b/tests/test_syllogisms.py
@@ -64,6 +64,204 @@ def test_syllogism_dataset_items():
assert "Does it logically follow that:" in item["question"]
+def test_valid_syllogism_forms():
+ """Test specific valid syllogistic forms"""
+ config = SyllogismConfig(size=1, seed=42)
+ dataset = SyllogismDataset(config)
+
+ # Create some test terms
+ A = Term("mortal", "mortals")
+ B = Term("human", "humans")
+ C = Term("animal", "animals")
+
+ # Test Barbara (AAA-1)
+ # Major premise: All M are P
+ # Minor premise: All S are M
+ # Conclusion: All S are P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.ALL, B, C), # All B (M) are C (P)
+ (Quantifier.ALL, A, B), # All A (S) are B (M)
+ (Quantifier.ALL, A, C), # All A (S) are C (P)
+ )
+
+ # Test Celarent (EAE-1)
+ # Major premise: No M are P
+ # Minor premise: All S are M
+ # Conclusion: No S are P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.NO, B, C), # No B (M) are C (P)
+ (Quantifier.ALL, A, B), # All A (S) are B (M)
+ (Quantifier.NO, A, C), # No A (S) are C (P)
+ )
+
+ # Test Cesare (EAE-2) — corrected order
+ # Major premise: No P are M
+ # Minor premise: All S are M
+ # Conclusion: No S are P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.NO, C, B), # No C (P) are B (M) [Major premise]
+ (Quantifier.ALL, A, B), # All A (S) are B (M) [Minor premise]
+ (Quantifier.NO, A, C), # No A (S) are C (P)
+ )
+
+ # Test Darii (AII-1)
+ # Major premise: All M are P
+ # Minor premise: Some S are M
+ # Conclusion: Some S are P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.ALL, B, C), # All B (M) are C (P)
+ (Quantifier.SOME, A, B), # Some A (S) are B (M)
+ (Quantifier.SOME, A, C), # Some A (S) are C (P)
+ )
+
+ # Test Disamis (IAI-3)
+ # Major premise: Some M are P
+ # Minor premise: All M are S
+ # Conclusion: Some S are P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.SOME, B, C), # Some B (M) are C (P)
+ (Quantifier.ALL, B, A), # All B (M) are A (S)
+ (Quantifier.SOME, A, C), # Some A (S) are C (P)
+ )
+
+ # Test Ferio (EIO-1)
+ # Major premise: No M are P
+ # Minor premise: Some S are M
+ # Conclusion: Some S are not P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.NO, B, C), # No B (M) are C (P)
+ (Quantifier.SOME, A, B), # Some A (S) are B (M)
+ (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P)
+ )
+
+ # Test Festino (EIO-2)
+ # Major premise: No P are M
+ # Minor premise: Some S are M
+ # Conclusion: Some S are not P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.NO, C, B), # No C (P) are B (M)
+ (Quantifier.SOME, A, B), # Some A (S) are B (M)
+ (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P)
+ )
+
+ # Test Datisi (AII-3)
+ # Major premise: All M are P
+ # Minor premise: Some M are S
+ # Conclusion: Some S are P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.ALL, B, C), # All B (M) are C (P)
+ (Quantifier.SOME, B, A), # Some B (M) are A (S)
+ (Quantifier.SOME, A, C), # Some A (S) are C (P)
+ )
+
+ # Test Bocardo (OAO-3)
+ # Major premise: Some M are not P
+ # Minor premise: All M are S
+ # Conclusion: Some S are not P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.SOME_NOT, B, C), # Some B (M) are not C (P)
+ (Quantifier.ALL, B, A), # All B (M) are A (S)
+ (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P)
+ )
+
+ # Test Baroco (AOO-2)
+ # Major premise: All P are M
+ # Minor premise: Some S are not M
+ # Conclusion: Some S are not P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.ALL, C, B), # All C (P) are B (M)
+ (Quantifier.SOME_NOT, A, B), # Some A (S) are not B (M)
+ (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P)
+ )
+
+ # Test Camestres (AEE-2)
+ # Major premise: All P are M
+ # Minor premise: No S are M
+ # Conclusion: No S are P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.ALL, C, B), # All C (P) are B (M)
+ (Quantifier.NO, A, B), # No A (S) are B (M)
+ (Quantifier.NO, A, C), # No A (S) are C (P)
+ )
+
+ # Test Dimaris (IAI-4)
+ # Major premise: Some P are M
+ # Minor premise: All M are S
+ # Conclusion: Some S are P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.SOME, C, B), # Some C (P) are B (M)
+ (Quantifier.ALL, B, A), # All B (M) are A (S)
+ (Quantifier.SOME, A, C), # Some A (S) are C (P)
+ )
+
+ # Test Ferison (EIO-3)
+ # Major premise: No M are P
+ # Minor premise: Some M are S
+ # Conclusion: Some S are not P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.NO, B, C), # No B (M) are C (P)
+ (Quantifier.SOME, B, A), # Some B (M) are A (S)
+ (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P)
+ )
+
+ # Test Fresison (EIO-4)
+ # Major premise: No P are M
+ # Minor premise: Some M are S
+ # Conclusion: Some S are not P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.NO, C, B), # No C (P) are B (M)
+ (Quantifier.SOME, B, A), # Some B (M) are A (S)
+ (Quantifier.SOME_NOT, A, C), # Some A (S) are not C (P)
+ )
+
+ # Test Camenes (AEE-4)
+ # Major premise: All P are M
+ # Minor premise: No M are S
+ # Conclusion: No S are P
+ assert dataset._is_valid_syllogism(
+ (Quantifier.ALL, C, B), # All C (P) are B (M)
+ (Quantifier.NO, B, A), # No B (M) are A (S)
+ (Quantifier.NO, A, C), # No A (S) are C (P)
+ )
+
+ # Test invalid forms
+ assert not dataset._is_valid_syllogism(
+ (Quantifier.SOME, B, C), # Some B are C
+ (Quantifier.SOME, A, B), # Some A are B
+ (Quantifier.SOME, A, C), # Some A are C (invalid: two particular premises)
+ )
+
+ assert not dataset._is_valid_syllogism(
+ (Quantifier.NO, B, C), # No B are C
+ (Quantifier.NO, A, B), # No A are B
+ (Quantifier.NO, A, C), # No A are C (invalid: two negative premises)
+ )
+
+ # Test specific invalid case with two negative premises
+ S = Term("student", "students")
+ M = Term("human", "humans")
+ P = Term("chef", "chefs")
+ assert not dataset._is_valid_syllogism(
+ (Quantifier.NO, S, M), # No students are humans
+ (Quantifier.NO, M, P), # No humans are chefs
+ (Quantifier.NO, S, P), # No students are chefs (invalid!)
+ )
+
+ child = Term("child", "children")
+ animal = Term("animal", "animals")
+ doctor = Term("doctor", "doctors")
+
+ # Premise 1: Some children are not animals
+ # Premise 2: All animals are doctors
+ # Conclusion: Some children are not doctors
+ # We expect this NOT to be a valid syllogism
+ assert not dataset._is_valid_syllogism(
+ (Quantifier.SOME_NOT, child, animal), # Some children are not animals
+ (Quantifier.ALL, animal, doctor), # All animals are doctors
+ (Quantifier.SOME_NOT, child, doctor), # Some children are not doctors
+ )
+
+
def test_syllogism_dataset_iteration():
"""Test that iteration respects dataset size"""
config = SyllogismConfig(size=5, seed=42)
@@ -74,41 +272,3 @@ def test_syllogism_dataset_iteration():
# Test multiple iterations yield same items
assert items == list(dataset)
-
-
-def test_syllogism_custom_terms():
- """Test syllogism generation with custom terms"""
- custom_terms = [
- Term("programmer", "programmers"),
- Term("coder", "coders"),
- Term("developer", "developers"),
- ]
- config = SyllogismConfig(terms=custom_terms, size=10, seed=42)
- dataset = SyllogismDataset(config)
-
- for item in dataset:
- # Verify only custom terms are used
- text = item["question"] + str(item["metadata"])
- assert any(term.name in text or term.plural in text for term in custom_terms)
- # Verify default terms are not used
- assert "mortal" not in text
- assert "human" not in text
-
-
-def test_syllogism_validity():
- """Test logical validity rules"""
- config = SyllogismConfig(
- allow_all=True,
- allow_no=False,
- allow_some=False,
- allow_some_not=False,
- include_invalid=False, # Only generate valid syllogisms
- size=10,
- seed=42,
- )
- dataset = SyllogismDataset(config)
-
- for item in dataset:
- # All valid ALL syllogisms should have "Yes" as answer
- assert item["answer"] == "Yes"
- assert item["metadata"]["is_valid"] is True
diff --git a/tests/test_tsumego.py b/tests/test_tsumego.py
new file mode 100644
index 00000000..e979bcac
--- /dev/null
+++ b/tests/test_tsumego.py
@@ -0,0 +1,281 @@
+"""Tests for Tsumego problem generation"""
+
+import re
+
+import pytest
+
+from reasoning_gym.games.tsumego import TsumegoConfig, TsumegoDataset
+
+
+def test_config_validation():
+ # Valid configuration
+ TsumegoConfig(min_board_size=9, max_board_size=13, max_stones=10, size=100, seed=42)
+
+ # Invalid configurations
+ with pytest.raises(ValueError):
+ TsumegoConfig(min_board_size=4, max_board_size=13, max_stones=10) # min_board_size too low
+ with pytest.raises(ValueError):
+ TsumegoConfig(min_board_size=9, max_board_size=20, max_stones=10) # max_board_size too high
+ with pytest.raises(ValueError):
+ TsumegoConfig(min_board_size=13, max_board_size=9, max_stones=10) # min_board_size > max_board_size
+ with pytest.raises(ValueError):
+ TsumegoConfig(min_board_size=9, max_board_size=13, max_stones=2) # max_stones too low
+
+
+def test_dataset_item_properties():
+ config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=15, size=100, seed=42)
+ dataset = TsumegoDataset(config)
+ item = dataset[0]
+ # Check that item has the required keys
+ for key in ["question", "answer", "metadata"]:
+ assert key in item
+
+ metadata = item["metadata"]
+ for key in ["difficulty", "board", "solution"]:
+ assert key in metadata
+
+ board = metadata["board"]
+ # Board size should be equal to the fixed min_board_size for this test
+ assert len(board) == config.min_board_size
+ assert all(len(row) == config.min_board_size for row in board)
+ # Check stone count does not exceed max_stones + 7 (to account for extra fill in capture formation)
+ stone_count = sum(cell in "XO" for row in board for cell in row)
+ assert stone_count <= config.max_stones + 7
+
+
+def test_deterministic_generation():
+ config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=10, seed=42)
+ dataset1 = TsumegoDataset(config)
+ dataset2 = TsumegoDataset(config)
+ for i in range(3):
+ item1 = dataset1[i]
+ item2 = dataset2[i]
+ assert item1["metadata"]["board"] == item2["metadata"]["board"]
+ assert item1["answer"] == item2["answer"]
+
+
+def test_liberties_and_move():
+ # Use a small board for simplicity
+ config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=10)
+ dataset = TsumegoDataset(config)
+
+ # Part 1: Liberty counting test
+ board_liberties = [
+ [".", "O", ".", ".", "."],
+ ["O", "X", "O", ".", "."],
+ [".", "O", ".", ".", "."],
+ [".", ".", ".", ".", "."],
+ [".", ".", ".", ".", "."],
+ ]
+ liberties = dataset._get_liberties(board_liberties, 1, 1)
+ assert len(liberties) == 0
+ liberties_edge = dataset._get_liberties(board_liberties, 0, 1)
+ assert len(liberties_edge) == 2
+
+ # Part 2: Test capturing move
+ # Construct a board where an enemy stone at (2,2) is surrounded on three sides,
+ # so that placing an "X" at (2,3) will remove its last liberty and capture it.
+ board_capture = [["." for _ in range(5)] for _ in range(5)]
+ board_capture[1][2] = "X"
+ board_capture[2][1] = "X"
+ board_capture[3][2] = "X"
+ board_capture[2][2] = "O"
+ # Now, (2,2) (enemy) has only one liberty at (2,3).
+ # Placing "X" at (2,3) should capture the enemy stone.
+ assert dataset._is_valid_move(board_capture, 2, 3, "X")
+ dataset._make_move(board_capture, 2, 3, "X")
+ # After move, captured_stones should be [(2,2)] and ko point set to (2,2).
+ assert not dataset._is_valid_move(board_capture, 2, 2, "O"), "Ko move should be invalid"
+
+ # Part 3: Test suicide move (without capture)
+ board_move = [
+ [".", "O", ".", ".", "."],
+ ["O", ".", "O", ".", "."],
+ [".", "O", ".", ".", "."],
+ [".", ".", ".", ".", "."],
+ [".", ".", ".", ".", "."],
+ ]
+ # Placing "X" at (1,1) would be suicide as all adjacent positions are occupied by "O".
+ assert not dataset._is_valid_move(board_move, 1, 1, "X")
+
+
+def convert_solution(sol, board_size):
+ # sol is expected to be a string like 'E5'
+ letter = sol[0].upper()
+ number = int(sol[1:])
+ return (board_size - number, ord(letter) - ord("A"))
+
+
+def test_score_answer():
+ config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=10, size=5)
+ dataset = TsumegoDataset(config)
+
+ # prepare dummy with letter+number format solution
+ entry = dataset[0].copy()
+ entry["metadata"]["solution"] = "E5"
+
+ # Patch score_answer to convert metadata solution if needed
+ original_score_answer = dataset.score_answer
+
+ def patched_score_answer(answer, entry):
+ board_size = len(entry["metadata"]["board"])
+ sol = entry["metadata"]["solution"]
+ if isinstance(sol, str):
+ entry["metadata"]["solution"] = convert_solution(sol, board_size)
+ return original_score_answer(answer, entry)
+
+ dataset.score_answer = patched_score_answer
+
+ # Correct letter-number answer (E5 corresponds to board coordinate (4,4) for a 9x9 board)
+ assert dataset.score_answer("E5", entry) == 1.0
+
+ # Valid but incorrect letter-number move (D4 corresponds to (5,3) for a 9x9 board)
+ assert dataset.score_answer("D4", entry) == 0.05
+
+ # Invalid format
+ assert dataset.score_answer("invalid", entry) == 0.01
+
+ # Empty answer
+ assert dataset.score_answer("", entry) == 0.01
+
+ # None answer
+ assert dataset.score_answer(None, entry) == 0.0
+
+ # Out-of-bound letter-number move: 'J' corresponds to 10 which is greater than board size = 9
+ assert dataset.score_answer("J9", entry) == 0.01
+
+ # test optimal score for answers, patching each entry
+ for x in dataset:
+ board_size = len(x["metadata"]["board"])
+ sol = x["metadata"]["solution"]
+ if isinstance(sol, str):
+ x["metadata"]["solution"] = convert_solution(sol, board_size)
+ assert len(x["metadata"]["board"]) == x["metadata"]["difficulty"]["board_size"]
+ assert dataset.score_answer(x["answer"], entry=x) == 1.0
+
+
+# Additional tests for game logic edge cases
+
+
+def test_get_group():
+ config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
+ dataset = TsumegoDataset(config)
+ board = [
+ ["X", "X", "."],
+ [".", "X", "O"],
+ [".", ".", "O"],
+ ]
+ group_X = dataset._get_group(board, 0, 0)
+ expected_group_X = {(0, 0), (0, 1), (1, 1)}
+ assert group_X == expected_group_X
+
+ group_O = dataset._get_group(board, 1, 2)
+ expected_group_O = {(1, 2), (2, 2)}
+ assert group_O == expected_group_O
+
+
+def test_count_liberties():
+ config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
+ dataset = TsumegoDataset(config)
+ board = [
+ ["X", "X", "."],
+ [".", "X", "O"],
+ [".", ".", "O"],
+ ]
+ group_X = {(0, 0), (0, 1), (1, 1)}
+ liberties_X = dataset._count_liberties(board, group_X)
+ # For (0,0): neighbor (1,0); (0,1): neighbor (0,2); (1,1): neighbors (1,0) and (2,1)
+ # Combined unique liberties: {(1,0), (0,2), (2,1)} so count should be 3
+ assert liberties_X == 3
+
+
+def test_out_of_bounds_move():
+ config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
+ dataset = TsumegoDataset(config)
+ board = [["." for _ in range(5)] for _ in range(5)]
+ # Test moves that are out of bounds
+ assert not dataset._is_valid_move(board, -1, 0, "X")
+ assert not dataset._is_valid_move(board, 0, -1, "X")
+ assert not dataset._is_valid_move(board, 5, 0, "X")
+ assert not dataset._is_valid_move(board, 0, 5, "X")
+
+
+def test_move_on_occupied_intersection():
+ config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
+ dataset = TsumegoDataset(config)
+ board = [["." for _ in range(5)] for _ in range(5)]
+ board[1][1] = "X"
+ # Attempting to play on an occupied spot should be invalid
+ assert not dataset._is_valid_move(board, 1, 1, "O")
+ assert not dataset._is_valid_move(board, 1, 1, "X")
+
+
+def test_valid_non_capturing_move():
+ config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
+ dataset = TsumegoDataset(config)
+ board = [["." for _ in range(5)] for _ in range(5)]
+ # A move on an empty board that doesn't result in capture or suicide should be valid
+ assert dataset._is_valid_move(board, 0, 0, "X")
+ move_result = dataset._make_move(board, 0, 0, "X")
+ assert move_result
+ assert board[0][0] == "X"
+
+
+def test_multiple_capture():
+ # Set up a board where a move will capture multiple opponent stones,
+ # which should not trigger the ko rule (ko point remains None)
+ config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
+ dataset = TsumegoDataset(config)
+ board = [
+ [".", ".", ".", ".", "."],
+ [".", "X", "X", "X", "."],
+ ["X", "O", "O", ".", "."],
+ [".", "X", "X", "X", "."],
+ [".", ".", ".", ".", "."],
+ ]
+ # Move at (2,3) with 'X' should capture the opponent stones at (2,1) and (2,2)
+ assert dataset._is_valid_move(board, 2, 3, "X")
+ move_result = dataset._make_move(board, 2, 3, "X")
+ assert move_result, "Move should be successfully made"
+ assert board[2][1] == ".", "Stone at (2,1) should be captured"
+ assert board[2][2] == ".", "Stone at (2,2) should be captured"
+ assert dataset._ko_point is None, "Ko point should not be set for multiple captures"
+
+
+def test_would_capture():
+ config = TsumegoConfig(min_board_size=5, max_board_size=5, max_stones=10, size=1, seed=42)
+ dataset = TsumegoDataset(config)
+ # Create a scenario similar to the one in test_liberties_and_move for capturing
+ board_capture = [["." for _ in range(5)] for _ in range(5)]
+ board_capture[1][2] = "X"
+ board_capture[2][1] = "X"
+ board_capture[3][2] = "X"
+ board_capture[2][2] = "O"
+ # Placing 'X' at (2,3) should capture the stone at (2,2)
+ assert dataset._would_capture(board_capture, 2, 3, "X")
+ # In a scenario with no capture, the move should not be considered capturing
+ board_no_capture = [["." for _ in range(5)] for _ in range(5)]
+ board_no_capture[2][2] = "O"
+ assert not dataset._would_capture(board_no_capture, 0, 0, "X")
+
+
+def test_capture_verification():
+ """Verifies that the solution move in a generated puzzle captures at least one opponent stone."""
+ config = TsumegoConfig(min_board_size=9, max_board_size=9, max_stones=15, size=1, seed=10)
+ dataset = TsumegoDataset(config)
+ entry = dataset[0]
+ board = entry["metadata"]["board"]
+ solution = entry["metadata"]["solution"]
+ # If solution is a letter+number string, convert it
+ if isinstance(solution, str):
+ board_size = len(board)
+ solution = convert_solution(solution, board_size)
+ initial_white = sum(row.count("O") for row in board)
+
+ # Make a deep copy of the board to simulate the move
+ board_after = [row[:] for row in board]
+ move_success = dataset._make_move(board_after, solution[0], solution[1], "X")
+ assert move_success, "The solution move should be legal."
+
+ final_white = sum(row.count("O") for row in board_after)
+ assert final_white < initial_white, "The solution move should capture at least one opponent stone."