mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
Merge branch 'main' into env/rotate-matrix
Commit e9d87a6933
11 changed files with 516 additions and 124 deletions
77  CONTRIBUTING.md  Normal file
@ -0,0 +1,77 @@
# Contributing to Reasoning Gym

Thank you for your interest in contributing to Reasoning Gym! This document provides guidelines and instructions for contributing to the project.

## Development Setup

1. Clone the repository:

```bash
git clone https://github.com/open-thought/reasoning-gym.git
```

2. Create a virtual environment (using conda):

```bash
conda create --name reasoning_gym python=3.11 -y
conda activate reasoning_gym
```

3. Install the package in editable mode:

```bash
pip install -e .
```

4. Install development dependencies:

```bash
pip install -r requirements-dev.txt
```

## Creating Procedural Datasets

When creating new datasets, please follow these guidelines:

1. **Focus on Complex Problems**:

   - Prioritize problems where guessing has a low probability of success (e.g., multi-digit multiplication)
   - Avoid tasks with small answer sets (true/false, multiple choice), as they produce noisy rewards for RL

2. **Implementation Requirements**:

   - Create a configuration class
   - Derive your dataset class from `ProceduralDataset` (see [dataset.py](https://github.com/open-thought/reasoning-gym/blob/main/reasoning_gym/dataset.py))
   - Include comprehensive unit tests
   - Return dictionary items with the keys `"question"`, `"answer"`, and `"metadata"`
   - For datasets with multiple correct answers, override the `score_answer()` method (return values in the range [0, 1])

3. **Getting Started**:

   - Review an example implementation:
     - Configuration & dataset class: [chain_sum.py](reasoning_gym/arithmetic/chain_sum.py)
     - Unit tests: [test_chain_sum.py](https://github.com/open-thought/reasoning-gym/blob/main/tests/test_chain_sum.py)
   - Write clear question prompts that an average human can understand and answer correctly
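A minimal stdlib-only sketch of the item contract from point 2 above — the toy multiplication task and the function names here are illustrative, not library APIs:

```python
# Sketch: the shape every dataset item should have, plus an exact-match
# score_answer mapping into [0, 1]. A toy multiplication task stands in
# for a real procedural dataset.

def make_item(a: int, b: int) -> dict:
    return {
        "question": f"What is {a} * {b}?",
        "answer": str(a * b),
        "metadata": {"a": a, "b": b},
    }


def score_answer(answer: str, entry: dict) -> float:
    """Exact match; datasets with many correct answers override this with partial credit."""
    return 1.0 if answer.strip() == entry["answer"] else 0.0


item = make_item(12, 34)
```

Real datasets derive from `ProceduralDataset` and generate such items deterministically from a seed, so the same index always yields the same item.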

## Pull Request Process

1. **Fork and Clone**:

   - [Fork the repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
   - Clone your fork locally
   - Read more about [forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks)

2. **Create a Feature Branch**:

   - Work on a [new branch](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-and-deleting-branches-within-your-repository)
   - Keep changes focused and minimal

3. **Code Quality**:

   - Install pre-commit hooks: `pre-commit install`
   - Run `pre-commit run -a` before committing
   - When using AI coding assistants (Cursor, Aider, etc.), ensure the output is properly formatted

4. **Submit Your PR**:

   - [Create a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
   - [Request review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review)
   - Do not include changes to `GALLERY.md` (it is updated automatically)

5. **Review Process**:

   - Address reviewer feedback promptly
   - Keep discussions constructive
   - Once approved, your changes will be merged into `main`

## Need Help?

Join our community discussion in the `#reasoning-gym` channel on the [GPU-Mode Discord server](https://discord.gg/gpumode).
14  GALLERY.md
@ -1005,6 +1005,7 @@ Metadata: {'words': ['eagerest', 'granitite', 'helium', 'nizam', 'nazim', 'strip
````

### gsm_symbolic
Default configuration:
```python

@ -2502,7 +2503,7 @@ Metadata: {'task_type': 'datetime_tz', 'start_time': datetime.datetime(2964, 6,
Example 2:
Question: A video call started at 09:44 and ended at 12:22. How long was the call? Answer in HH:MM.
Answer: 02:38
Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 7, 9, 44), 'end_time': datetime.datetime(2025, 2, 7, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'}
Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 8, 9, 44), 'end_time': datetime.datetime(2025, 2, 8, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'}

Example 3:
Question: Calculate the time difference between Sat Dec 22 2677 and Thu Mar 21 2678. Express the result in D days.

@ -2576,14 +2577,14 @@ Metadata: {'num_disks': 6, 'num_pegs': 3, 'start_peg': 1, 'target_peg': 2, 'auxi
````

### tsumego
Generates (one-move) Tsumego problems with configurable parameters
Generates Tsumego problems with configurable parameters

Default configuration:
```python
min_board_size = 9
max_board_size = 13
max_stones = 15
size = 10
size = 100
seed = 42
```

@ -2608,11 +2609,8 @@ O - White

Specify your move in coordinates (e.g. 'C4' for column C, row 4)
Answer: E4

Metadata: {'difficulty': {'board_size': 9}, 'board': [['X', '.', '.', '.', 'X', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', 'O', '.', 'O', '.', '.', 'X', '.', '.'], ['.', '.', '.', 'X', '.', '.', '.', '.', 'O'], ['O', '.', 'X', 'O', 'X', '.', '.', '.', '.'], ['.', 'X', 'O', 'O', '.', 'O', '.', '.', '.'], ['.', '.', 'X', 'O', 'X', '.', '.', '.', '.'], ['.', '.', '.', 'X', '.', '.', '.', '.', '.'], ['.', 'O', '.', 'O', '.', '.', 'X', '.', '.']], 'solution': 'E4'}

--------------------------------------------------

Example 2:
Question: Here's a Go challenge. Playing as Black, how can you capture as many white stones as possible?

@ -2632,11 +2630,8 @@ O - White

Specify your move in coordinates (e.g. 'C4' for column C, row 4)
Answer: B7

Metadata: {'difficulty': {'board_size': 9}, 'board': [['.', '.', 'O', '.', '.', '.', '.', '.', '.'], ['.', 'X', 'O', '.', '.', '.', '.', '.', '.'], ['X', '.', 'X', '.', '.', '.', '.', '.', '.'], ['O', 'O', 'O', 'X', '.', '.', '.', '.', '.'], ['X', 'O', 'O', '.', '.', '.', '.', '.', '.'], ['.', 'X', '.', '.', '.', '.', '.', '.', 'O'], ['.', 'X', '.', '.', '.', '.', 'X', '.', '.'], ['O', '.', 'O', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', 'O', '.', '.', '.', '.']], 'solution': 'B7'}

--------------------------------------------------

Example 3:
Question: Tsumego time. Black to play and capture some stones.
Find the key move.

@ -2660,7 +2655,6 @@ O - White

Specify your move in coordinates (e.g. 'C4' for column C, row 4)
Answer: D4

Metadata: {'difficulty': {'board_size': 12}, 'board': [['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', 'X', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['X', '.', '.', '.', '.', 'X', '.', '.', '.', 'X', '.', '.'], ['.', 'X', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', 'O', 'X', 'X', '.', '.', '.', '.', '.', '.', '.', 'O'], ['.', 'X', 'O', 'O', 'X', '.', '.', '.', '.', '.', '.', '.'], ['.', 'O', 'O', '.', '.', '.', '.', '.', 'O', '.', '.', 'O'], ['X', '.', 'X', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'], ['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', 'X', '.']], 'solution': 'D4'}

````
121  README.md
@ -1,44 +1,28 @@
# Reasoning Gym

We are building a python library of procedural dataset generators and algorithmically verifiable reasoning environments for training Reasoning Models with reinforcement learning (RL).
We are building a Python library of procedural dataset generators and algorithmically verifiable reasoning environments for training reasoning models with reinforcement learning (RL).

The goal is to generate virtually infinite data with adjustable complexity.

Algorithmic verification allows training on tasks like Rubik's cube or [Countdown](<https://en.wikipedia.org/wiki/Countdown_(game_show)#Numbers_Round>) which have many correct solutions.

## Set up for development
## Dataset Gallery

1. Clone the project
In [GALLERY.md](https://github.com/open-thought/reasoning-gym/blob/main/GALLERY.md) you can find example outputs of all datasets available in reasoning-gym.

```
git clone https://github.com/open-thought/reasoning-gym.git
```
## Installation

2. Create a virtual environment (here we use conda)
The `reasoning-gym` package requires Python >= 3.11.

```
conda create --name reasoning_gym python=3.11 -y
conda activate reasoning_gym
```

3. Link project and install dependencies

```
pip install -e .
```

4. Install development dependencies

```
pip install -r requirements-dev.txt
```

> NOTE: To consume the APIs in reasoning_gym, just install from pip using the following
Install via pip:

```
pip install reasoning-gym
```

For development setup see [CONTRIBUTING.md](CONTRIBUTING.md#development-setup).
## How to instantiate a task dataset?

Example:

@ -64,89 +48,10 @@ metadata: {'animals': {'sheep': 2, 'dog': 2}, 'total_legs': 16}
...
```

See the [Dataset Gallery](https://github.com/open-thought/reasoning-gym/blob/main/GALLERY.md) for a complete list of available datasets with examples.
## Contributing

## Task Overview
Please see [CONTRIBUTING.md](CONTRIBUTING.md).

### <small>Algebra Tasks</small>
If you have ideas for dataset generators please create an issue here or contact us in the `#reasoning-gym` channel of the [GPU-Mode discord server](https://discord.gg/gpumode).

- `SimpleEquationsDataset`: Generate linear equations with one variable to solve (e.g. "3\*x + 2 = 14")
- `PolynomialEquationsDataset`: Generate polynomial equations with one variable to solve (e.g. "-6*h\*\*4 + 4*h\**2 - 5*h = 0")
- `PolynomialMultiplicationDataset`: Generate polynomial multiplications (e.g. "(8x^3 + x + 2)\*(y - 3)")

### <small>Arithmetic Tasks</small>

- `BasicArithmeticDataset`: Generate arithmetic expressions with configurable complexity and operators (+, -, \*, /)
- `CalendarArithmeticDataset`: Generate arithmetic problems around calendar navigation logic
- `ChainSum`: Generate addition/subtraction chains with configurable length and digit counts
- `FractionSimplificationDataset`: Generate fraction simplification tasks with configurable complexity
- `GCDDataset`: Generate Greatest Common Divisor problems with configurable number of integers
- `LCMDataset`: Generate Least Common Multiple problems with configurable number of integers
- `LegCountingDataset`: Generate animal leg counting word problems with various animals
- `PrimeFactorizationDataset`: Generate prime factorization tasks with configurable number ranges
- `TimeIntervalsDataset`: Generate time interval calculation tasks with various formats (time, date, datetime) and complexities

### <small>Algorithmic Tasks</small>

- `BaseConversionDataset`: Convert numbers between different bases (binary, hex, etc.)
- `CaesarCipherDataset`: Encrypt/decrypt text using Caesar cipher with configurable rotation
- `LetterCountingDataset`: Count letter occurrences in text spans
- `NumberFilteringDataset`: Filter numbers based on comparison with threshold
- `NumberSortingDataset`: Sort lists of numbers in ascending or descending order
- `WordSortingDataset`: Sort words in ascending or descending order using ASCII/Unicode ordering
- `LetterJumbleDataset`: Unscramble words that have had their letters randomly jumbled
- `SentenceReorderingDataset`: Reorder a sentence whose words have been randomly shuffled
- `SpellBackwardDataset`: Spell individual words backward (e.g. "sun" -> "nus")
- `WordSequenceReversalDataset`: Reverse word order in text spans
- `WordLadderDataset`: Generate word ladder puzzles where one word is transformed into another by changing one letter at a time
- `GroupAnagramsDataset`: Group anagrams together in a list of words
- `IsomorphicStrings`: Check if two strings are isomorphic (have the same character mapping)
- `RotateMatrix`: Rotate a matrix K times by 90 degrees clockwise

### <small>Code Tasks</small>

- `BFDataset`: Generates BF programs of various difficulty, from simple string printing to loops and conditional logic

### <small>Cognition Tasks</small>

- `NumberSequenceDataset`: Generate number sequences with discoverable patterns
- `ColorCubeRotationDataset`: Generate 3D spatial reasoning tasks with colored cube rotations and orientation tracking
- `RubiksCubeDataset`: Generate Rubik's Cube configurations and check correct solutions
- `FigletFontDataset`: Generate random words in different "Figlet" fonts for reasoning about the structure of letters

### <small>Logic Tasks</small>

- `PropositionalLogicDataset`: Generate propositional logic reasoning problems
- `SyllogismDataset`: Generates a [syllogism](https://en.wikipedia.org/wiki/Syllogism) reasoning dataset
- `AliceInWonderlandDataset`: Generates [AIW](https://openreview.net/forum?id=Mkl7dzjYiW) (Alice In Wonderland) problems with a few variations
- `ZebraDataset`: Generates [Zebra Puzzles](https://en.wikipedia.org/wiki/Zebra_Puzzle) of varying difficulty
- `SelfReferenceDataset`: Generates self-referencing logic puzzles

### <small>Graph Tasks</small>

- `FamilyRelationshipsDataset`: Generate family relationship reasoning tasks with family trees
- `QuantumLockDataset`: Generates puzzles which involve stateful arithmetic and a correct sequence of operations
- `LargestIslandDataset`: Generate a grid with islands and find the largest one
- `CourseScheduleDataset`: Generate a course schedule with prerequisites and determine whether all courses can be completed

### <small>Game Tasks</small>

- `SudokuDataset`: Generate 9x9 Sudoku puzzles with configurable number of empty cells
- `SokobanDataset`: Generate [Sokoban](https://en.wikipedia.org/wiki/Sokoban) puzzles with configurable size and detail
- `MiniSudokuDataset`: Generate 4x4 Mini Sudoku puzzles with configurable difficulty
- `MazeDataset`: Generate a maze with a start and a goal
- `CountdownDataset`: Generate number game tasks where numbers and operators must be combined to reach a target value
- `NQueensDataset`: Generate N-Queens puzzles with configurable board size and number of starting queens
- `TsumegoDataset`: Generate Tsumego capture puzzles with variable board sizes and stone placements

## Future Generator Ideas

- More complex math tasks (algebra, geometry)
- Algorithmic tasks (counting, sorting, re-ordering)
- Logic riddles
- Logic inductive programming tasks
- ARC-AGI synthetic riddles

## Call for Contributions

If you have ideas for additional procedural dataset generators please create an issue here or contact us in the `#reasoning-gym` channel of the [GPU-Mode discord server](https://discord.gg/gpumode).
[](https://discord.gg/gpumode)
@ -15,6 +15,7 @@ from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .palindrome_generation import PalindromeConfig, PalindromeDataset
from .ransom_note import RansomNoteConfig, RansomNoteDataset
from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset

@ -50,6 +51,8 @@ __all__ = [
    "PalindromeDataset",
    "GroupAnagramsConfig",
    "GroupAnagramsDataset",
    "RansomNoteConfig",
    "RansomNoteDataset",
    "IsomorphicStringsConfig",
    "IsomorphicStringsDataset",
    "RotateMatrixConfig",
99  reasoning_gym/algorithmic/ransom_note.py  Normal file

@ -0,0 +1,99 @@
"""Check if you can construct a ransom note from letters in a magazine.
|
||||
|
||||
A popular Leetcode problem:
|
||||
https://leetcode.com/problems/ransom-note/description/
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
MAX_NOTE_LENGTH = 100_000
|
||||
MAX_MAGAZINE_LENGTH = 100_001
|
||||
|
||||
QUESTION_TEMPLATE = """Given two strings representing a ransom note and a magazine, return True if you can construct the ransom note using the letters in the magazine, and False otherwise.
|
||||
|
||||
Each letter in the magazine string can only be used once in your ransom note.
|
||||
|
||||
Ransom note: {ransom_note}
|
||||
Magazine: {magazine}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RansomNoteConfig:
|
||||
"""Configuration for Ransom Note dataset generation"""
|
||||
|
||||
max_note_length: int = 10 # Maximum length of the ransom note
|
||||
max_magazine_length: int = 30 # Maximum length of the magazine
|
||||
p_solvable: float = 0.5 # Probability that the ransom note can be constructed
|
||||
|
||||
size: int = 500 # Virtual dataset size
|
||||
seed: Optional[int] = None
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert 1 <= self.max_note_length <= MAX_NOTE_LENGTH, "max_note_length must be between 1 and MAX_NOTE_LENGTH"
|
||||
assert (
|
||||
2 <= self.max_magazine_length <= MAX_MAGAZINE_LENGTH
|
||||
), "max_magazine_length must be between 2 and MAX_MAGAZINE_LENGTH"
|
||||
assert self.max_note_length < self.max_magazine_length, "max_note_length must be less than max_magazine_length"
|
||||
assert 0 <= self.p_solvable <= 1, "p_solvable must be between 0 and 1"
|
||||
|
||||
|
||||
class RansomNoteDataset(ProceduralDataset):
|
||||
"""Generates Ransom Note exercises with configurable difficulty"""
|
||||
|
||||
def __init__(self, config: RansomNoteConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.letters = {chr(i) for i in range(ord("a"), ord("z") + 1)}
|
||||
|
||||
def _get_inputs(self, rng: Random, solvable: bool) -> tuple[str, str]:
|
||||
"""Generate random ransom note and magazine"""
|
||||
ransom_note_len = rng.randint(1, self.config.max_note_length)
|
||||
ransom_note = [rng.choice(list(self.letters)) for _ in range(ransom_note_len)]
|
||||
|
||||
magazine_len = rng.randint(ransom_note_len, self.config.max_magazine_length)
|
||||
magazine = ransom_note.copy()
|
||||
if solvable:
|
||||
magazine.extend([rng.choice(list(self.letters)) for _ in range(magazine_len - ransom_note_len)])
|
||||
else:
|
||||
remove_letter = rng.choice(magazine)
|
||||
magazine.remove(remove_letter)
|
||||
magazine.extend(
|
||||
[rng.choice(list(self.letters - {remove_letter})) for _ in range(magazine_len - ransom_note_len + 1)]
|
||||
)
|
||||
|
||||
rng.shuffle(ransom_note)
|
||||
rng.shuffle(magazine)
|
||||
return "".join(ransom_note), "".join(magazine)
|
||||
|
||||
def _can_construct(self, ransom_note: str, magazine: str) -> bool:
|
||||
"""Check if ransom note can be constructed from magazine"""
|
||||
count = defaultdict(int)
|
||||
for c in magazine:
|
||||
count[c] += 1
|
||||
for c in ransom_note:
|
||||
if count[c] <= 0:
|
||||
return False
|
||||
count[c] -= 1
|
||||
return True
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single Group Anagrams question"""
|
||||
rng = Random(self.seed + idx)
|
||||
solvable = rng.random() < self.config.p_solvable
|
||||
ransom_note, magazine = self._get_inputs(rng, solvable)
|
||||
answer = self._can_construct(ransom_note, magazine)
|
||||
|
||||
return {
|
||||
"question": QUESTION_TEMPLATE.format(ransom_note=ransom_note, magazine=magazine),
|
||||
"answer": str(answer),
|
||||
"metadata": {"ransom_note": ransom_note, "magazine": magazine, "solution": answer, "solvable": solvable},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig)
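As an aside (not part of this commit), the `_can_construct` loop above is a hand-rolled multiset check; a stdlib-only sketch of the same logic using `collections.Counter`:

```python
from collections import Counter


def can_construct(ransom_note: str, magazine: str) -> bool:
    # Counter subtraction drops zero and negative counts, so the difference
    # is empty exactly when the magazine covers every letter the note needs.
    return not (Counter(ransom_note) - Counter(magazine))
```

Because Counter subtraction never keeps non-positive entries, the truthiness test is equivalent to the explicit count-and-decrement loop in `_can_construct`.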

@ -11,14 +11,14 @@ from ..factory import ProceduralDataset, register_dataset
class SokobanConfig:
    """Configuration for sokoban puzzle generation"""

    seed: Optional[int] = None
    size: int = 500
    min_w: int = 6  # Minimum width of the puzzle.
    min_h: int = 6  # Minimum height of the puzzle.
    max_w: int = 10  # Maximum width of the puzzle.
    max_h: int = 10  # Maximum height of the puzzle.
    min_boxes: int = 6  # Minimum number of boxes.
    max_boxes: int = 10  # Maximum number of boxes.
    seed: Optional[int] = None
    size: int = 500

    def validate(self):
        """Validate configuration parameters"""
@ -27,7 +27,7 @@ class HanoiConfig:
    max_disks: int = 7
    min_pegs: int = 3
    max_pegs: int = 4
    size: int = 50
    size: int = 500
    seed: Optional[int] = None
    visualize: bool = False  # New parameter
@ -34,7 +34,7 @@ class TsumegoConfig:
    min_board_size: int = 9
    max_board_size: int = 13
    max_stones: int = 15
    size: int = 100
    size: int = 500
    seed: Optional[int] = None

    def __post_init__(self):
@ -40,6 +40,9 @@ class SyllogismConfig:
    # Percentage of invalid examples if included (0.0 to 1.0)
    invalid_ratio: float = 0.3

    # Probability of generating inversion problems instead of syllogisms (0.0 to 1.0)
    inversion_probability: float = 0.3

    seed: Optional[int] = None
    size: int = 500

@ -49,6 +52,7 @@ class SyllogismConfig:
            [self.allow_all, self.allow_no, self.allow_some, self.allow_some_not]
        ), "At least one quantifier type must be allowed"
        assert 0.0 <= self.invalid_ratio <= 1.0, "invalid_ratio must be between 0.0 and 1.0"
        assert 0.0 <= self.inversion_probability <= 1.0, "inversion_probability must be between 0.0 and 1.0"


class SyllogismDataset(ProceduralDataset):
|
|||
else:
|
||||
return f"{quantifier.value} {subject.plural} are {predicate.plural}"
|
||||
|
||||
def _check_logical_equivalence(
|
||||
self, premise: Tuple[Quantifier, Term, Term], conclusion: Tuple[Quantifier, Term, Term]
|
||||
) -> bool:
|
||||
"""Check if a conclusion is logically equivalent to a premise"""
|
||||
p_quant, p_subj, p_pred = premise
|
||||
c_quant, c_subj, c_pred = conclusion
|
||||
|
||||
# Direct inversion for universal negative
|
||||
if p_quant == Quantifier.NO:
|
||||
if c_quant == Quantifier.NO:
|
||||
return p_subj == c_pred and p_pred == c_subj
|
||||
return False
|
||||
|
||||
# Particular inversion for universal affirmative
|
||||
if p_quant == Quantifier.ALL:
|
||||
if c_quant == Quantifier.SOME:
|
||||
return p_subj == c_pred and p_pred == c_subj
|
||||
return False
|
||||
|
||||
# Rules for particular statements
|
||||
if p_quant == Quantifier.SOME:
|
||||
if c_quant == Quantifier.SOME:
|
||||
return p_subj == c_pred and p_pred == c_subj
|
||||
return False
|
||||
|
||||
if p_quant == Quantifier.SOME_NOT:
|
||||
# Some A are not B does not imply Some B are not A
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def _generate_syllogism(self, rng: Random) -> dict:
|
||||
"""Generate a single syllogism problem"""
|
||||
# Select three different terms
|
||||
terms = rng.sample(self.terms, 3)
|
||||
quantifiers = self._get_allowed_quantifiers()
|
||||
|
||||
# Decide whether to generate a traditional syllogism or an inversion problem
|
||||
if rng.random() < self.config.inversion_probability:
|
||||
# Generate two premises, one will be used for inversion, the other as distractor
|
||||
quantifier1 = rng.choice(quantifiers)
|
||||
quantifier2 = rng.choice(quantifiers)
|
||||
term1, term2, term3 = terms # Use all three terms
|
||||
|
||||
# Create two different premises
|
||||
premise1 = (quantifier1, term1, term2)
|
||||
premise2 = (quantifier2, term2, term3)
|
||||
|
||||
# Format both premises
|
||||
premise1_text = self._format_quantifier_statement(premise1[0], premise1[1], premise1[2])
|
||||
premise2_text = self._format_quantifier_statement(premise2[0], premise2[1], premise2[2])
|
||||
|
||||
# Randomly select which premise to use for inversion
|
||||
if rng.random() < 0.5:
|
||||
premise = premise1
|
||||
selected_premise_num = 1
|
||||
else:
|
||||
premise = premise2
|
||||
selected_premise_num = 2
|
||||
|
||||
# Decide whether to generate a valid or invalid inversion
|
||||
target_valid = rng.random() > self.config.invalid_ratio
|
||||
|
||||
# Get the quantifier and terms from the selected premise
|
||||
premise_quantifier, premise_term1, premise_term2 = premise
|
||||
|
||||
if target_valid:
|
||||
# Generate valid inversions
|
||||
if premise_quantifier == Quantifier.NO:
|
||||
conclusion = (premise_quantifier, premise_term2, premise_term1) # No B are A
|
||||
elif premise_quantifier == Quantifier.ALL:
|
||||
conclusion = (Quantifier.SOME, premise_term2, premise_term1) # Some B are A
|
||||
elif premise_quantifier == Quantifier.SOME:
|
||||
conclusion = (premise_quantifier, premise_term2, premise_term1) # Some B are A
|
||||
else: # SOME_NOT - try a different quantifier
|
||||
new_quantifier = rng.choice([q for q in quantifiers if q != Quantifier.SOME_NOT])
|
||||
# Update the premise with the new quantifier
|
||||
premise = (new_quantifier, premise_term1, premise_term2)
|
||||
premise_quantifier = new_quantifier # Update the quantifier for conclusion generation
|
||||
if selected_premise_num == 1:
|
||||
premise1 = premise
|
||||
premise1_text = self._format_quantifier_statement(premise[0], premise[1], premise[2])
|
||||
else:
|
||||
premise2 = premise
|
||||
premise2_text = self._format_quantifier_statement(premise[0], premise[1], premise[2])
|
||||
|
||||
# Handle the new quantifier
|
||||
if new_quantifier == Quantifier.NO:
|
||||
conclusion = (new_quantifier, premise_term2, premise_term1)
|
||||
elif new_quantifier == Quantifier.ALL:
|
||||
conclusion = (Quantifier.SOME, premise_term2, premise_term1)
|
||||
else: # SOME
|
||||
conclusion = (new_quantifier, premise_term2, premise_term1)
|
||||
else:
|
||||
# Generate invalid inversions by sampling from inappropriate quantifiers
|
||||
if premise_quantifier == Quantifier.NO:
|
||||
# For NO statements, use ALL or SOME
|
||||
conclusion = (rng.choice([Quantifier.ALL, Quantifier.SOME]), premise_term2, premise_term1)
|
||||
elif premise_quantifier == Quantifier.ALL:
|
||||
# For ALL statements, use ALL or NO
|
||||
conclusion = (rng.choice([Quantifier.ALL, Quantifier.NO]), premise_term2, premise_term1)
|
||||
elif premise_quantifier == Quantifier.SOME:
|
||||
# For SOME statements, use ALL or NO
|
||||
conclusion = (rng.choice([Quantifier.ALL, Quantifier.NO]), premise_term2, premise_term1)
|
||||
else: # SOME_NOT
|
||||
# For SOME_NOT statements, use any other quantifier
|
||||
conclusion = (
|
||||
rng.choice([q for q in quantifiers if q != Quantifier.SOME_NOT]),
|
||||
premise_term2,
|
||||
premise_term1,
|
||||
)
|
||||
|
||||
conclusion_text = self._format_quantifier_statement(conclusion[0], conclusion[1], conclusion[2])
|
||||
is_valid = self._check_logical_equivalence(premise, conclusion)
|
||||
|
||||
question = (
|
||||
f"Consider these statements:\n"
|
||||
f"1. {premise1_text}\n"
|
||||
f"2. {premise2_text}\n\n"
|
||||
f"Does it logically follow that:\n"
|
||||
f"{conclusion_text}?\n"
|
||||
f"(Answer Yes or No)"
|
||||
)
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": "Yes" if is_valid else "No",
|
||||
"metadata": {
|
||||
"premise1": premise1_text,
|
||||
"premise2": premise2_text,
|
||||
"selected_premise": selected_premise_num,
|
||||
"conclusion": conclusion_text,
|
||||
"is_valid": is_valid,
|
||||
"type": "inversion",
|
||||
},
|
||||
}
|
||||
|
||||
# Traditional syllogism generation
|
||||
target_valid = rng.random() > self.config.invalid_ratio # Invert ratio to match meaning
|
||||
max_attempts = 100
|
||||
attempts = 0
|
||||
|
|
@ -294,6 +430,7 @@ class SyllogismDataset(ProceduralDataset):
|
|||
"premise2": premise2_text,
|
||||
"conclusion": conclusion_text,
|
||||
"is_valid": is_valid,
|
||||
"type": "syllogism",
|
||||
},
|
||||
}
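For reference (not part of the commit), the immediate-inference rules that `_check_logical_equivalence` encodes can be written as a small truth table; the names below are illustrative, not the module's:

```python
from enum import Enum


class Q(Enum):
    ALL = "All"
    NO = "No"
    SOME = "Some"
    SOME_NOT = "Some ... are not"


def valid_inversion(premise_q: Q, conclusion_q: Q) -> bool:
    """Converses that hold when subject and predicate are swapped:
    'No A are B' <-> 'No B are A'; 'All A are B' -> 'Some B are A'
    (assuming non-empty terms); 'Some A are B' <-> 'Some B are A';
    'Some A are not B' licenses nothing about B and A."""
    return (premise_q, conclusion_q) in {
        (Q.NO, Q.NO),
        (Q.ALL, Q.SOME),
        (Q.SOME, Q.SOME),
    }
```

Note that `ALL -> SOME` relies on existential import (every term denotes at least one thing), which is the traditional-logic convention this dataset follows.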

111  tests/test_ransom_note.py  Normal file

@ -0,0 +1,111 @@
"""Tests for Ransom Note questions generation"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.ransom_note import RansomNoteConfig, RansomNoteDataset
|
||||
|
||||
|
||||
def test_ransom_note_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_note_length=-1) # Negative not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_note_length=0) # Zero not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_magazine_length=-1) # Negative not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_magazine_length=0) # Zero not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_magazine_length=1) # One not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(
|
||||
max_note_length=3, max_magazine_length=2
|
||||
) # max_note_length must be less than max_magazine_length
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(p_solvable=-0.01) # p_solvable must be between 0 and 1
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(p_solvable=1.01) # p_solvable must be between 0 and 1
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_ransom_note_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = RansomNoteConfig(seed=42, size=10)
|
||||
dataset1 = RansomNoteDataset(config)
|
||||
dataset2 = RansomNoteDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_group_anagrams_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = RansomNoteConfig(max_note_length=10, max_magazine_length=30, size=10, seed=42)
|
||||
dataset = RansomNoteDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "ransom_note" in item["metadata"]
|
||||
assert "magazine" in item["metadata"]
|
||||
assert "solution" in item["metadata"]
|
||||
assert "solvable" in item["metadata"]
|
||||
|
||||
ransom_note = item["metadata"]["ransom_note"]
|
||||
magazine = item["metadata"]["magazine"]
|
||||
solution = item["metadata"]["solution"]
|
||||
solvable = item["metadata"]["solvable"]
|
||||
|
||||
# Verify dimensions
|
||||
assert len(ransom_note) <= config.max_note_length
|
||||
assert len(ransom_note) <= len(magazine)
|
||||
assert len(magazine) <= config.max_magazine_length
|
||||
assert solution == solvable
|
||||
|
||||
|
||||
def test_ransom_note_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = RansomNoteConfig(size=5, seed=42)
|
||||
dataset = RansomNoteDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
||||
|
||||
def test_ransom_note_answer():
|
||||
"""Test the _can_construct method"""
|
||||
config = RansomNoteConfig(seed=42)
|
||||
dataset = RansomNoteDataset(config)
|
||||
|
||||
# Correct solution
|
||||
ransom_note, magazine = "ab", "badhergh"
|
||||
assert dataset._can_construct(ransom_note, magazine) == True
|
||||
|
||||
# Inorrect solution
|
||||
ransom_note, magazine = "az", "badhergh"
|
||||
assert dataset._can_construct(ransom_note, magazine) == False
|
||||
|
|
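The `test_ransom_note_answer` assertions above imply that `_can_construct` checks whether every letter of the note is available in the magazine, counting multiplicities. A minimal sketch of such a check (hypothetical; `can_construct` here is an illustrative standalone function, not the repository's actual implementation) using `collections.Counter`:

```python
from collections import Counter

def can_construct(ransom_note: str, magazine: str) -> bool:
    """Return True if every character of ransom_note can be taken
    from magazine, respecting character multiplicities."""
    need = Counter(ransom_note)
    have = Counter(magazine)
    # Every required character must appear at least as often in the magazine
    return all(have[ch] >= cnt for ch, cnt in need.items())

# Mirrors the test cases: "ab" fits in "badhergh", "az" does not
print(can_construct("ab", "badhergh"))  # True
print(can_construct("az", "badhergh"))  # False
```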
@@ -50,9 +50,13 @@ def test_syllogism_dataset_items():

         # Check metadata
         assert "premise1" in item["metadata"]
-        assert "premise2" in item["metadata"]
         assert "conclusion" in item["metadata"]
         assert "is_valid" in item["metadata"]
+        assert "type" in item["metadata"]
+
+        # For traditional syllogisms, check for premise2
+        if item["metadata"]["type"] == "syllogism":
+            assert "premise2" in item["metadata"]

         # Verify answer format
         assert item["answer"] in ("Yes", "No")

@@ -60,7 +64,8 @@ def test_syllogism_dataset_items():
         # Verify question format
         assert "Consider these statements:" in item["question"]
         assert "1." in item["question"]
-        assert "2." in item["question"]
+        if item["metadata"]["type"] == "syllogism":
+            assert "2." in item["question"]
         assert "Does it logically follow that:" in item["question"]

|
@@ -262,6 +267,67 @@ def test_valid_syllogism_forms():
     )


+def test_logical_equivalence():
+    """Test logical equivalence rules for inversions"""
+    config = SyllogismConfig(size=1, seed=42)
+    dataset = SyllogismDataset(config)
+
+    # Create test terms
+    A = Term("student", "students")
+    B = Term("human", "humans")
+
+    # Test direct inversion of NO statements
+    assert dataset._check_logical_equivalence(
+        (Quantifier.NO, A, B),  # No students are humans
+        (Quantifier.NO, B, A),  # No humans are students
+    )
+
+    # Test particular inversion of ALL statements
+    assert dataset._check_logical_equivalence(
+        (Quantifier.ALL, A, B),  # All students are humans
+        (Quantifier.SOME, B, A),  # Some humans are students
+    )
+
+    # Test direct inversion of SOME statements
+    assert dataset._check_logical_equivalence(
+        (Quantifier.SOME, A, B),  # Some students are humans
+        (Quantifier.SOME, B, A),  # Some humans are students
+    )
+
+    # Test invalid inversions
+    assert not dataset._check_logical_equivalence(
+        (Quantifier.SOME_NOT, A, B),  # Some students are not humans
+        (Quantifier.SOME_NOT, B, A),  # Some humans are not students (invalid)
+    )
+
+    assert not dataset._check_logical_equivalence(
+        (Quantifier.ALL, A, B),  # All students are humans
+        (Quantifier.ALL, B, A),  # All humans are students (invalid)
+    )
+
+
+def test_inversion_generation():
+    """Test generation of inversion problems"""
+    # Force inversion problems by setting probability to 1.0
+    config = SyllogismConfig(size=10, seed=42, inversion_probability=1.0)
+    dataset = SyllogismDataset(config)
+
+    for item in dataset:
+        # Check type is marked as inversion
+        assert item["metadata"]["type"] == "inversion"
+        # Check both premises and selection
+        assert "premise1" in item["metadata"]
+        assert "premise2" in item["metadata"]
+        assert "selected_premise" in item["metadata"]
+        assert item["metadata"]["selected_premise"] in (1, 2)
+        # Check format
+        assert item["answer"] in ("Yes", "No")
+        assert "Consider these statements:" in item["question"]
+        assert "1." in item["question"]
+        assert "2." in item["question"]  # Inversion questions now show both premises
+        assert "Does it logically follow that:" in item["question"]
+
+
 def test_syllogism_dataset_iteration():
     """Test that iteration respects dataset size"""
     config = SyllogismConfig(size=5, seed=42)
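The equivalences exercised in `test_logical_equivalence` follow the classical conversion rules for categorical statements: E ("No A are B") and I ("Some A are B") convert simply, while A ("All A are B") converts only per accidens to I ("Some B are A"). A hypothetical standalone sketch of such a check (the names `Quantifier` and `check_conversion` here are illustrative, not the repository's actual API):

```python
from enum import Enum

class Quantifier(Enum):
    ALL = "All"            # A-type statement
    NO = "No"              # E-type statement
    SOME = "Some"          # I-type statement
    SOME_NOT = "Some not"  # O-type statement

def check_conversion(stmt1, stmt2) -> bool:
    """Return True if stmt2 = (quantifier, subject, predicate) follows
    from stmt1 by swapping subject and predicate, per classical
    conversion: E and I convert simply; A converts to I; O never converts."""
    q1, a1, b1 = stmt1
    q2, a2, b2 = stmt2
    if (a2, b2) != (b1, a1):  # stmt2 must use the swapped terms
        return False
    if q1 == Quantifier.NO and q2 == Quantifier.NO:
        return True   # No A are B  ->  No B are A
    if q1 == Quantifier.SOME and q2 == Quantifier.SOME:
        return True   # Some A are B  ->  Some B are A
    if q1 == Quantifier.ALL and q2 == Quantifier.SOME:
        return True   # All A are B  ->  Some B are A (per accidens)
    return False      # e.g. All->All and Some-not conversions are invalid
```

This mirrors the valid/invalid cases asserted in the test: NO and SOME invert directly, ALL weakens to SOME, and SOME_NOT or ALL-to-ALL inversions are rejected.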