diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index b31c4dc8..8bf7ae71 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,4 +1,8 @@ -name: Pre-commit +name: Pre-commit Checks + +permissions: + contents: read + pull-requests: write on: pull_request: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9e97239d..157c993d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,6 +9,10 @@ on: jobs: test: runs-on: ubuntu-latest + permissions: + contents: read + issues: write + pull-requests: write strategy: matrix: python-version: ["3.11", "3.12"] @@ -19,12 +23,12 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - + - name: Install dependencies run: | python -m pip install --upgrade pip pip install ".[test]" - + - name: Run tests run: | pytest diff --git a/GALLERY.md b/GALLERY.md index f52bf50b..56e5836f 100644 --- a/GALLERY.md +++ b/GALLERY.md @@ -2,10 +2,14 @@ This gallery shows examples from all available datasets using their default configurations. ## Available Datasets +- [advanced_geometry](#advanced_geometry) +- [aiw](#aiw) +- [arc_1d](#arc_1d) - [base_conversion](#base_conversion) - [basic_arithmetic](#basic_arithmetic) - [bf](#bf) - [caesar_cipher](#caesar_cipher) +- [calendar_arithmetic](#calendar_arithmetic) - [chain_sum](#chain_sum) - [color_cube_rotation](#color_cube_rotation) - [countdown](#countdown) @@ -14,12 +18,14 @@ This gallery shows examples from all available datasets using their default conf - [fraction_simplification](#fraction_simplification) - [game_of_life](#game_of_life) - [gcd](#gcd) +- [intermediate_integration](#intermediate_integration) - [lcm](#lcm) - [leg_counting](#leg_counting) - [letter_counting](#letter_counting) - [letter_jumble](#letter_jumble) - [maze](#maze) - [mini_sudoku](#mini_sudoku) +- [n_queens](#n_queens) - [number_filtering](#number_filtering) - [number_sequence](#number_sequence) - [number_sorting](#number_sorting) @@ -30,43 +36,204 @@ This gallery shows examples from all available datasets using their default conf - [rubiks_cube](#rubiks_cube) - [sentence_reordering](#sentence_reordering) - [simple_equations](#simple_equations) +- [simple_geometry](#simple_geometry) +- [simple_integration](#simple_integration) - [spell_backward](#spell_backward) - [sudoku](#sudoku) - [syllogism](#syllogism) +- [time_intervals](#time_intervals) +- [tower_of_hanoi](#tower_of_hanoi) - [word_ladder](#word_ladder) - [word_sequence_reversal](#word_sequence_reversal) - [word_sorting](#word_sorting) ## Dataset Examples +### advanced_geometry +A dataset for advanced geometry tasks using coordinate geometry. + +Default configuration: +```python +min_coord = -10 +max_coord = 10 +size = 50 +seed = 42 +task_types = ['orthocenter', 'incircle_radius', 'angle_measure'] +``` + +Example tasks: +```` +Example 1: +Question: In triangle ABC with coordinates A=(-7, -10), B=(-2, -3), and C=(-3, -6), find the measure (in degrees) of angle ABC. +Answer: 17.10° +Metadata: {'A': (-7, -10), 'B': (-2, -3), 'C': (-3, -6), 'angle_ABC_degrees': 17.10272896905237} + +Example 2: +Question: For triangle with vertices A=(-1, -6), B=(4, 1), and C=(-7, 4), determine the orthocenter (intersection of altitudes). +Answer: (0.304, -1.217) +Metadata: {'A': (-1, -6), 'B': (4, 1), 'C': (-7, 4), 'orthocenter_exact': ('7/23', '-28/23'), 'orthocenter_approx': (0.30434782608695654, -1.2173913043478262)} + +Example 3: +Question: Find the incircle radius of triangle ABC whose vertices are A=(6, 7), B=(-7, -5), and C=(2, -3). +Answer: 2.176 +Metadata: {'A': (6, 7), 'B': (-7, -5), 'C': (2, -3), 'incircle_radius_exact': 'sqrt(-sqrt(29) + sqrt(85)/2 + sqrt(313)/2)*sqrt(-sqrt(313)/2 + sqrt(85)/2 + sqrt(29))*sqrt(-sqrt(85)/2 + sqrt(29) + sqrt(313)/2)/sqrt(sqrt(85)/2 + sqrt(29) + sqrt(313)/2)', 'incircle_radius_approx': 2.176123777286009} + +```` + +### aiw +A procedural dataset inspired by the "Alice in Wonderland" paper. + + The dataset is inspired by the following paper: + @inproceedings{nezhurina2024alice, + title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and + Basic Reasoning Deficits in State-Of-the-Art Large Language Models}, + author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and + Jenia Jitsev}, + booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding + Deep Learning}, + year={2024}, + url={https://openreview.net/forum?id=Mkl7dzjYiW} + } + +Default configuration: +```python +male_names = ['James', 'John', 'Robert', 'Michael', 'William', 'David', 'Richard', 'Joseph', 'Thomas', 'Charles', 'Bob'] +female_names = ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah', 'Margaret', 'Alice'] +task_types = [, , ] +seed = 42 +size = 10 +max_entities = 6 +``` + +Example tasks: +```` +Example 1: +Question: Patricia has 6 male colleagues and she also has 3 female colleagues. These are all colleagues that Patricia has. All these mentioned persons around Patricia are colleagues of each other. James has 2 male colleagues and 2 female colleagues in total. All these mentioned persons around James are colleagues of each other. The people in the circle around James do not have other colleagues aside - with the only exception of Matilda. She is colleague of James and she is also colleague of Patricia, being part of Patricia's circle. How many female colleagues does Matilda have? +Answer: 4 +Metadata: {'task_type': 'colleagues'} + +Example 2: +Question: Elizabeth has 4 brothers and she also has 3 sisters. How many sisters does Elizabeth's brother have? +Answer: 4 +Metadata: {'task_type': 'siblings'} + +Example 3: +Question: Sarah has 6 male friends and she also has 1 female friends. They all are friends with each other and have no other friends aside. How many female friends does Thomas, a male friend of Sarah, have? +Answer: 2 +Metadata: {'task_type': 'friends'} + +```` + +### arc_1d +Generates ARC 1D tasks by randomly selecting from available task generators + +Default configuration: +```python +min_size = 10 +max_size = 30 +num_train = 3 +seed = 42 +size = 500 +``` + +Example tasks: +```` +Example 1: +Question: Find the common rule that maps an input grid to an output grid, given the examples below. + +Example 1: +Input: 7 1 0 0 5 5 0 5 5 0 0 0 0 +Output: 7 1 0 0 7 7 0 1 1 0 0 0 0 + +Example 2: +Input: 5 1 0 5 5 0 5 5 0 0 0 0 0 +Output: 5 1 0 5 5 0 1 1 0 0 0 0 0 + +Example 3: +Input: 2 6 0 0 5 5 0 5 5 0 0 0 0 +Output: 2 6 0 0 2 2 0 6 6 0 0 0 0 + +Below is a test input grid. Predict the corresponding output grid by applying the rule you found. Describe how you derived the rule and your overall reasoning process in detail before you submit your answer. Your final answer must be placed in tags and should be just be the text output grid itself. + +Input: +6 0 0 0 0 0 0 5 5 5 0 0 0 +Answer: 6 0 0 0 0 0 0 6 6 6 0 0 0 +Metadata: {'task_name': 'recolor_blocks_from_palette', 'size': 13, 'train_examples': [{'input': [7, 1, 0, 0, 5, 5, 0, 5, 5, 0, 0, 0, 0], 'output': [7, 1, 0, 0, 7, 7, 0, 1, 1, 0, 0, 0, 0]}, {'input': [5, 1, 0, 5, 5, 0, 5, 5, 0, 0, 0, 0, 0], 'output': [5, 1, 0, 5, 5, 0, 1, 1, 0, 0, 0, 0, 0]}, {'input': [2, 6, 0, 0, 5, 5, 0, 5, 5, 0, 0, 0, 0], 'output': [2, 6, 0, 0, 2, 2, 0, 6, 6, 0, 0, 0, 0]}], 'test_example': {'input': [6, 0, 0, 0, 0, 0, 0, 5, 5, 5, 0, 0, 0], 'output': [6, 0, 0, 0, 0, 0, 0, 6, 6, 6, 0, 0, 0]}} + +Example 2: +Question: Find the common rule that maps an input grid to an output grid, given the examples below. + +Example 1: +Input: 0 8 8 8 8 8 8 8 8 8 8 8 8 0 0 0 0 0 0 +Output: 0 0 0 0 8 8 8 8 8 8 8 8 8 8 8 8 0 0 0 + +Example 2: +Input: 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 0 0 0 +Output: 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 + +Example 3: +Input: 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 +Output: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 + +Below is a test input grid. Predict the corresponding output grid by applying the rule you found. Describe how you derived the rule and your overall reasoning process in detail before you submit your answer. Your final answer must be placed in tags and should be just be the text output grid itself. + +Input: +0 0 0 0 0 0 6 6 6 6 6 6 6 6 6 0 0 0 0 +Answer: 0 0 0 0 0 0 0 0 0 6 6 6 6 6 6 6 6 6 0 +Metadata: {'task_name': 'move_3pix_solid', 'size': 19, 'train_examples': [{'input': [0, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 0, 0, 0, 0, 0], 'output': [0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 0, 0]}, {'input': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0], 'output': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2]}, {'input': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], 'output': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0]}], 'test_example': {'input': [0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 0, 0, 0, 0], 'output': [0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 0]}} + +Example 3: +Question: Find the common rule that maps an input grid to an output grid, given the examples below. + +Example 1: +Input: 0 0 0 0 0 0 0 2 0 0 4 4 4 4 4 4 4 4 4 4 4 4 4 0 0 0 +Output: 0 0 0 0 0 0 0 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 0 0 0 + +Example 2: +Input: 0 0 0 2 0 0 0 0 0 0 0 0 0 3 3 3 3 3 3 3 3 0 0 0 0 0 +Output: 0 0 0 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 0 0 0 0 0 + +Example 3: +Input: 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 3 3 3 3 0 0 0 0 +Output: 0 0 0 0 0 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 0 0 0 0 + +Below is a test input grid. Predict the corresponding output grid by applying the rule you found. Describe how you derived the rule and your overall reasoning process in detail before you submit your answer. Your final answer must be placed in tags and should be just be the text output grid itself. + +Input: +0 0 0 0 0 0 0 0 0 0 0 7 7 7 7 7 7 7 7 7 7 7 7 7 2 0 +Answer: 0 0 0 0 0 0 0 0 0 0 0 7 7 7 7 7 7 7 7 7 7 7 7 7 7 0 +Metadata: {'task_name': 'block_scale_to_dot', 'size': 26, 'train_examples': [{'input': [0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0], 'output': [0, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0]}, {'input': [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0], 'output': [0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0]}, {'input': [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 0, 0, 0, 0], 'output': [0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0]}], 'test_example': {'input': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 2, 0], 'output': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 0]}} + +```` + ### base_conversion Generates base conversion tasks Default configuration: -````python +```python min_base = 2 max_base = 16 min_value = 0 max_value = 1000 seed = 42 size = 500 -```` +``` Example tasks: ```` Example 1: -Question: Convert the base-3 number 28e to binary +Question: Convert the base-3 number 220020 to binary Answer: 1010001110 -Metadata: {'decimal_value': 654, 'source_base': 3, 'target_base': 2, 'source_repr': '28e', 'target_repr': '1010001110'} +Metadata: {'decimal_value': 654, 'source_base': 3, 'target_base': 2, 'source_repr': '220020', 'target_repr': '1010001110'} Example 2: -Question: Convert the base-6 number 27 to base-13 (use lowercase letters a-z for digits above 9) -Answer: 27 -Metadata: {'decimal_value': 39, 'source_base': 6, 'target_base': 13, 'source_repr': '27', 'target_repr': '27'} +Question: Convert the base-6 number 103 to base-13 (use lowercase letters a-z for digits above 9) +Answer: 30 +Metadata: {'decimal_value': 39, 'source_base': 6, 'target_base': 13, 'source_repr': '103', 'target_repr': '30'} Example 3: -Question: Convert the base-10 number 1a2 to base-13 (use lowercase letters a-z for digits above 9) -Answer: 1a2 -Metadata: {'decimal_value': 418, 'source_base': 10, 'target_base': 13, 'source_repr': '1a2', 'target_repr': '1a2'} +Question: Convert the base-10 number 418 to base-13 (use lowercase letters a-z for digits above 9) +Answer: 262 +Metadata: {'decimal_value': 418, 'source_base': 10, 'target_base': 13, 'source_repr': '418', 'target_repr': '262'} ```` @@ -74,7 +241,7 @@ Metadata: {'decimal_value': 418, 'source_base': 10, 'target_base': 13, 'source_r Dataset that generates basic arithmetic tasks with configurable complexity Default configuration: -````python +```python min_terms = 2 max_terms = 6 min_digits = 1 @@ -86,7 +253,7 @@ seed = 42 size = 500 format_style = simple whitespace = single -```` +``` Example tasks: ```` @@ -111,11 +278,11 @@ Metadata: {'num_terms': 5, 'num_digits': 1, 'expression': '0 + -2 + -4 * 0 * 3'} Generates BF tasks Default configuration: -````python +```python seed = 42 size = 500 difficulty = 1 -```` +``` Example tasks: ```` @@ -146,7 +313,7 @@ Metadata: {'bfit_code': '\nint main() {\n print("under");\n}\n', 'bf_program' Generates Caesar cipher encryption/decryption tasks Default configuration: -````python +```python delimiter = . min_words = 3 max_words = 20 @@ -154,7 +321,7 @@ min_rotation = 1 max_rotation = 25 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -175,11 +342,41 @@ Metadata: {'rotation': 17, 'cipher_text': 'ZW PFLI JKFDRTY ZJ FLK FW ZK DLJK SV ```` +### calendar_arithmetic +Default configuration: +```python +year = 2022 +tasks = ['weekday_offset', 'weekday_of_date', 'weekday_of_date_from_first_day', 'recurring_event_day', 'count_days', 'count_business_days', 'is_leap_year'] +offset_upper_bound = 100 +leap_year_range = 200 +seed = 42 +size = 500 +``` + +Example tasks: +```` +Example 1: +Question: Between Sunday, February 27, 2022 and Wednesday, March 2, 2022 (counting both dates), what's the total count of business days (Monday through Friday)? Give the count numerically. +Answer: 3 +Metadata: {'task': 'count_business_days', 'start_date': '2022-02-27', 'end_date': '2022-03-02'} + +Example 2: +Question: Starting from Monday, May 23, 2022, which weekday was it 98 days before? Write out the full weekday name. +Answer: Monday +Metadata: {'task': 'weekday_offset', 'start_date': '2022-05-23', 'offset_days': -98, 'target_date': '2022-02-14'} + +Example 3: +Question: If a meeting is scheduled on the last Saturday of September 2022, on which day of the month does it occur? Respond with just the number. Answer with -1 if the ordinal does not exist in the month. +Answer: 24 +Metadata: {'task': 'recurring_event_day', 'year': 2022, 'month': 9, 'ordinal': 'last', 'weekday': 'Saturday'} + +```` + ### chain_sum Generates simple arithmetic tasks using only + and - operators Default configuration: -````python +```python min_terms = 2 max_terms = 6 min_digits = 1 @@ -187,7 +384,7 @@ max_digits = 4 allow_negation = False seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -212,12 +409,12 @@ Metadata: {'num_terms': 5, 'num_digits': 1, 'expression': '2 + 6 + 3 + 4 + 0'} Generates color cube rotation reasoning tasks Default configuration: -````python +```python min_rotations = 1 max_rotations = 3 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -280,7 +477,7 @@ Metadata: {'initial_state': {'top': 'orange', 'right': 'cyan', 'front': 'violet' Generates Countdown Number Game tasks Default configuration: -````python +```python min_numbers = 4 max_numbers = 6 min_value = 1 @@ -291,7 +488,7 @@ operators = ('+', '-', '*', '/') shuffle = True seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -319,14 +516,14 @@ Metadata: {'numbers': [5, 41, 38, 81, 14], 'target': 450, 'expression': '41*14 - Generates family relationship reasoning tasks Default configuration: -````python +```python min_family_size = 4 max_family_size = 8 male_names = ['James', 'John', 'Robert', 'Michael', 'William', 'David', 'Richard', 'Joseph', 'Thomas', 'Charles', 'Peter', 'Daniel', 'Matthew', 'Christopher', 'Andrew', 'George', 'Edward', 'Benjamin', 'Henry', 'Samuel', 'Alexander', 'Oliver', 'Jack', 'Harry', 'Jacob', 'Noah', 'Ethan', 'Lucas', 'Mason', 'Logan', 'Sebastian', 'Theodore', 'Owen', 'Liam', 'Aiden', 'Kai', 'Jayden', 'Zion', 'Phoenix', 'Atlas', 'Axel', 'Ryder', 'Finn'] female_names = ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah', 'Karen', 'Emma', 'Lisa', 'Anna', 'Margaret', 'Victoria', 'Charlotte', 'Sophia', 'Isabella', 'Olivia', 'Ava', 'Mia', 'Emily', 'Abigail', 'Amelia', 'Eleanor', 'Grace', 'Alice', 'Lucy', 'Chloe', 'Sophie', 'Lily', 'Hannah', 'Zoe', 'Luna', 'Nova', 'Aria', 'Willow', 'Aurora', 'Sage', 'River', 'Winter', 'Sky', 'Rain'] seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -357,13 +554,13 @@ Metadata: {'person1': 'Liam', 'person2': 'Noah', 'relationship': 'father', 'fami Generates FigletFont tasks Default configuration: -````python +```python static_word = None static_font = None space_letters = True seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -421,7 +618,7 @@ Metadata: {'font': 'xcourb', 'space_letters': True} Generates fraction simplification tasks Default configuration: -````python +```python min_value = 1 max_value = 1000 min_factor = 1 @@ -429,7 +626,7 @@ max_factor = 100 styles = ('plain', 'latex_inline', 'latex_frac', 'latex_dfrac') seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -454,14 +651,14 @@ Metadata: {'numerator': 29330, 'denominator': 37310, 'simplified_numerator': 419 Generates Game of Life games with configurable parameters Default configuration: -````python +```python grid_size_x = 20 grid_size_y = 20 filled_cells = 100 simulation_steps = 1 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -606,14 +803,14 @@ Metadata: {'grid_size_x': 20, 'grid_size_y': 20, 'filled_cells': 100, 'simulatio Generates Greatest Common Divisor (GCD) tasks Default configuration: -````python +```python min_numbers = 2 max_numbers = 2 min_value = 1 max_value = 1000 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -634,18 +831,60 @@ Metadata: {'numbers': [297, 30], 'result': 3} ```` +### intermediate_integration +Generates intermediate integration problem - either + by substitution or by parts + +Default configuration: +```python +problem_types = ('substitution', 'by_parts') +substitution_types = ('linear', 'trigonometric', 'exponential', 'radical') +by_parts_types = ('polynomial_exp_trig', 'log_inverse_trig', 'cyclic', 'repeated_parts') +seed = 42 +size = 500 +linear_lower_bound = 1 +linear_upper_bound = 10 +min_linear_degree = 2 +max_linear_degree = 4 +outer_constant_min = 1 +outer_constant_max = 3 +min_poly_degree = 1 +max_poly_degree = 3 +symbols = ('x', 'X') +operators = ('+', '-') +``` + +Example tasks: +```` +Example 1: +Question: Find the indefinite integral: ∫ -3*exp(3*x + 9) dx +Answer: -exp(3*x + 9) + C +Metadata: {'integrand': '-3*exp(3*x + 9)', 'problem_type': 'substitution', 'variable': 'x', 'type': 'exponential', 'expected_answer_expression': -exp(3*x + 9)} + +Example 2: +Question: Evaluate the indefinite integral: ∫ -6*sin(2*X + 10)*cos(2*X + 10)**4 dx +Answer: 3*cos(2*X + 10)**5/5 + C +Metadata: {'integrand': '-6*sin(2*X + 10)*cos(2*X + 10)**4', 'problem_type': 'substitution', 'variable': 'X', 'type': 'trigonometric', 'expected_answer_expression': 3*cos(2*X + 10)**5/5} + +Example 3: +Question: Find the indefinite integral: ∫ 2*asin(x) dx +Answer: 2*Integral(asin(x), x) + C +Metadata: {'integrand': '2*asin(x)', 'problem_type': 'by_parts', 'variable': 'x', 'type': 'log_inverse_trig', 'expected_answer_expression': 2*Integral(asin(x), x)} + +```` + ### lcm Generates Least Common Multiple (LCM) tasks Default configuration: -````python +```python min_numbers = 2 max_numbers = 2 min_value = 1 max_value = 100 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -670,13 +909,13 @@ Metadata: {'numbers': [38, 4], 'result': 76} Generates leg counting arithmetic tasks Default configuration: -````python +```python min_animals = 2 max_animals = 5 max_instances = 3 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -701,12 +940,12 @@ Metadata: {'animals': {'crab': 1, 'lobster': 2, 'human': 1, 'cow': 1, 'bee': 1}, Generates letter counting tasks from text spans Default configuration: -````python +```python min_words = 5 max_words = 15 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -731,7 +970,7 @@ Metadata: {'span_length': 11, 'target_letter': 't', 'span': ['readable', 'form', Generates word letter jumbling tasks Default configuration: -````python +```python min_word_len = 1 max_word_len = 64 min_words = 3 @@ -741,7 +980,7 @@ max_corruption_level = 0.9 consecutive_words = True seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -767,14 +1006,14 @@ Generates mazes with guaranteed shortest path distance from start to goal within [min_dist, max_dist]. Default configuration: -````python +```python min_dist = 5 max_dist = 10 min_grid_size = 5 max_grid_size = 10 seed = 42 size = 50 -```` +``` Example tasks: ```` @@ -840,12 +1079,12 @@ Metadata: {'grid_size': 7, 'grid': ['QQQQQQQ', 'QQ%%%%Q', 'QQ`%Q%Q', 'Q%%Q%%Q', Generates 4x4 sudoku puzzles with configurable difficulty Default configuration: -````python +```python min_empty = 8 max_empty = 12 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -887,11 +1126,108 @@ Metadata: {'puzzle': [[0, 0, 0, 0], [1, 3, 4, 0], [3, 1, 2, 4], [4, 0, 0, 0]], ' ```` +### n_queens +Generates N Queens puzzles with configurable difficulty + +Default configuration: +```python +n = 8 +min_remove = 1 +max_remove = 7 +size = 500 +seed = 42 +``` + +Example tasks: +```` +Example 1: +Question: Solve this N Queens puzzle: +_ _ _ _ _ _ Q _ +_ Q _ _ _ _ _ _ +_ _ _ Q _ _ _ _ +_ _ _ _ _ _ _ _ +_ _ _ _ _ _ _ Q +_ _ _ _ Q _ _ _ +_ _ Q _ _ _ _ _ +_ _ _ _ _ Q _ _ + +The board size is 8x8 and your job is to place 1 queen(s) on the board such that no two queens attack each other. + +No two queens attack each other if they are not in the same row, column, or diagonal. + +Place a queen by replacing an underscore (_) with a Q. + +Answer: _ _ _ _ _ _ Q _ +_ Q _ _ _ _ _ _ +_ _ _ Q _ _ _ _ +Q _ _ _ _ _ _ _ +_ _ _ _ _ _ _ Q +_ _ _ _ Q _ _ _ +_ _ Q _ _ _ _ _ +_ _ _ _ _ Q _ _ +Metadata: {'puzzle': [['_', '_', '_', '_', '_', '_', 'Q', '_'], ['_', 'Q', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', 'Q', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', 'Q'], ['_', '_', '_', '_', 'Q', '_', '_', '_'], ['_', '_', 'Q', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', 'Q', '_', '_']], 'solutions': [[['_', '_', '_', '_', '_', '_', 'Q', '_'], ['_', 'Q', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', 'Q', '_', '_', '_', '_'], ['Q', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', 'Q'], ['_', '_', '_', '_', 'Q', '_', '_', '_'], ['_', '_', 'Q', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', 'Q', '_', '_']]], 'num_removed': 1, 'valid_answers': ['_ _ _ _ _ _ Q _\n_ Q _ _ _ _ _ _\n_ _ _ Q _ _ _ _\nQ _ _ _ _ _ _ _\n_ _ _ _ _ _ _ Q\n_ _ _ _ Q _ _ _\n_ _ Q _ _ _ _ _\n_ _ _ _ _ Q _ _']} + +Example 2: +Question: Solve this N Queens puzzle: +_ Q _ _ _ _ _ _ +_ _ _ _ _ _ _ _ +_ _ _ _ _ Q _ _ +_ _ _ _ _ _ _ Q +_ _ _ _ _ _ _ _ +_ _ _ _ _ _ _ _ +_ _ _ _ _ _ Q _ +_ _ _ _ Q _ _ _ + +The board size is 8x8 and your job is to place 3 queen(s) on the board such that no two queens attack each other. + +No two queens attack each other if they are not in the same row, column, or diagonal. + +Place a queen by replacing an underscore (_) with a Q. + +Answer: _ Q _ _ _ _ _ _ +_ _ _ Q _ _ _ _ +_ _ _ _ _ Q _ _ +_ _ _ _ _ _ _ Q +_ _ Q _ _ _ _ _ +Q _ _ _ _ _ _ _ +_ _ _ _ _ _ Q _ +_ _ _ _ Q _ _ _ +Metadata: {'puzzle': [['_', 'Q', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', 'Q', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', 'Q'], ['_', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', 'Q', '_'], ['_', '_', '_', '_', 'Q', '_', '_', '_']], 'solutions': [[['_', 'Q', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', 'Q', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', 'Q', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', 'Q'], ['_', '_', 'Q', '_', '_', '_', '_', '_'], ['Q', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', 'Q', '_'], ['_', '_', '_', '_', 'Q', '_', '_', '_']]], 'num_removed': 3, 'valid_answers': ['_ Q _ _ _ _ _ _\n_ _ _ Q _ _ _ _\n_ _ _ _ _ Q _ _\n_ _ _ _ _ _ _ Q\n_ _ Q _ _ _ _ _\nQ _ _ _ _ _ _ _\n_ _ _ _ _ _ Q _\n_ _ _ _ Q _ _ _']} + +Example 3: +Question: Solve this N Queens puzzle: +_ _ _ _ _ _ _ _ +_ Q _ _ _ _ _ _ +_ _ _ _ _ _ _ _ +Q _ _ _ _ _ _ _ +_ _ _ _ _ _ _ _ +_ _ _ _ _ _ _ _ +_ _ _ _ _ _ _ _ +_ _ _ _ _ Q _ _ + +The board size is 8x8 and your job is to place 5 queen(s) on the board such that no two queens attack each other. + +No two queens attack each other if they are not in the same row, column, or diagonal. + +Place a queen by replacing an underscore (_) with a Q. + +Answer: _ _ _ _ Q _ _ _ +_ Q _ _ _ _ _ _ +_ _ _ _ _ _ _ Q +Q _ _ _ _ _ _ _ +_ _ _ Q _ _ _ _ +_ _ _ _ _ _ Q _ +_ _ Q _ _ _ _ _ +_ _ _ _ _ Q _ _ +Metadata: {'puzzle': [['_', '_', '_', '_', '_', '_', '_', '_'], ['_', 'Q', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', '_'], ['Q', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', 'Q', '_', '_']], 'solutions': [[['_', '_', '_', '_', 'Q', '_', '_', '_'], ['_', 'Q', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', 'Q'], ['Q', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', 'Q', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', 'Q', '_'], ['_', '_', 'Q', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', 'Q', '_', '_']], [['_', '_', '_', '_', '_', '_', 'Q', '_'], ['_', 'Q', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', 'Q', '_', '_', '_', '_'], ['Q', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', 'Q'], ['_', '_', '_', '_', 'Q', '_', '_', '_'], ['_', '_', 'Q', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', 'Q', '_', '_']], [['_', '_', '_', '_', '_', '_', '_', 'Q'], ['_', 'Q', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', 'Q', '_', '_', '_', '_'], ['Q', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', 'Q', '_'], ['_', '_', '_', '_', 'Q', '_', '_', '_'], ['_', '_', 'Q', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', 'Q', '_', '_']]], 'num_removed': 5, 'valid_answers': ['_ _ _ _ Q _ _ _\n_ Q _ _ _ _ _ _\n_ _ _ _ _ _ _ Q\nQ _ _ _ _ _ _ _\n_ _ _ Q _ _ _ _\n_ _ _ _ _ _ Q _\n_ _ Q _ _ _ _ _\n_ _ _ _ _ Q _ _', '_ _ _ _ _ _ Q _\n_ Q _ _ _ _ _ _\n_ _ _ Q _ _ _ _\nQ _ _ _ _ _ _ _\n_ _ _ _ _ _ _ Q\n_ _ _ _ Q _ _ _\n_ _ Q _ _ _ _ _\n_ _ _ _ _ Q _ _', '_ _ _ _ _ _ _ Q\n_ Q _ _ _ _ _ _\n_ _ _ Q _ _ _ _\nQ _ _ _ _ _ _ _\n_ _ _ _ _ _ Q _\n_ _ _ _ Q _ _ _\n_ _ Q _ _ _ _ _\n_ _ _ _ _ Q _ _']} + +```` + ### number_filtering Generates number filtering tasks Default configuration: -````python +```python min_numbers = 3 max_numbers = 10 min_decimals = 0 @@ -900,7 +1236,7 @@ min_value = -100.0 max_value = 100.0 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -925,7 +1261,7 @@ Metadata: {'original_numbers': ['4', '-64.7', '-42.1', '-77', '-79.9640', '37.76 Generates number sequence completion tasks with dynamic pattern generation Default configuration: -````python +```python min_terms = 4 max_terms = 8 min_value = -100 @@ -933,7 +1269,7 @@ max_value = 100 max_complexity = 3 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -958,7 +1294,7 @@ Metadata: {'rule': 'halve', 'complexity': 2, 'sequence': [8, 4, 2, 1, 0, 0, 0, 0 Generates number sorting tasks Default configuration: -````python +```python min_numbers = 3 max_numbers = 10 min_decimals = 0 @@ -967,7 +1303,7 @@ min_value = -100.0 max_value = 100.0 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -995,7 +1331,7 @@ Generates random polynomial equations of degree in [min_degree, max_degree]. - The solution may be real or complex; we filter real solutions by default for simplicity. Default configuration: -````python +```python min_terms = 2 max_terms = 4 min_value = 1 @@ -1005,7 +1341,7 @@ max_degree = 3 operators = ('+', '-') seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1030,12 +1366,12 @@ Metadata: {'polynomial_expr': '71*n**3 - 2*n - 29', 'variable': 'n', 'degree': 3 Generates prime factorization tasks Default configuration: -````python +```python min_value = 2 max_value = 1000 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1060,7 +1396,7 @@ Metadata: {'number': 420, 'factors': [2, 2, 3, 5, 7]} Generates propositional logic reasoning tasks Default configuration: -````python +```python min_vars = 2 max_vars = 4 min_statements = 2 @@ -1068,7 +1404,7 @@ max_statements = 4 max_complexity = 3 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1105,17 +1441,17 @@ Metadata: {'premises': ['((Q ∨ P) ∧ ¬P)', 'P', '((P ∧ R) ∧ ¬R)', '((Q Generates QuantumLock tasks Default configuration: -````python +```python difficulty = 10 seed = 42 size = 500 -```` +``` Example tasks: ```` Example 1: Question: In front of you are some buttons, a light, and a number. The light will toggle between red and green whenever you press a button. Each button performs a mathematical operation to the number, but the operation may depend on the state of the light. -You must press the shortest correct sequence of buttons to reach the target value. +You must press the shortest correct sequence of buttons to reach the target value. Your answer should be a sequence of buttons separated by '→', for example: A → B → C Start: 0 (red) Target: 46 @@ -1128,7 +1464,7 @@ Metadata: {'difficulty': 10, 'solution_path': ['A', 'B', 'C', 'C', 'A', 'C'], 't Example 2: Question: In front of you are some buttons, a light, and a number. The light will toggle between red and green whenever you press a button. Each button performs a mathematical operation to the number, but the operation may depend on the state of the light. -You must press the shortest correct sequence of buttons to reach the target value. +You must press the shortest correct sequence of buttons to reach the target value. Your answer should be a sequence of buttons separated by '→', for example: A → B → C Start: 0 (red) Target: 30 @@ -1141,7 +1477,7 @@ Metadata: {'difficulty': 10, 'solution_path': ['C', 'A', 'C', 'A', 'C', 'A', 'C' Example 3: Question: In front of you are some buttons, a light, and a number. The light will toggle between red and green whenever you press a button. Each button performs a mathematical operation to the number, but the operation may depend on the state of the light. -You must press the shortest correct sequence of buttons to reach the target value. +You must press the shortest correct sequence of buttons to reach the target value. Your answer should be a sequence of buttons separated by '→', for example: A → B → C Start: 0 (red) Target: 45 @@ -1158,13 +1494,13 @@ Metadata: {'difficulty': 10, 'solution_path': ['B', 'B', 'B', 'B', 'B', 'B', 'B' Generates RubiksCube tasks Default configuration: -````python +```python scramble_steps = 3 cube_size = 3 remove_ansi = True seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1228,12 +1564,12 @@ Metadata: {'cube_size': 3, 'scramble_steps': 3, 'scramble_moves': "U R' R'", 'ex Generates sentence reordering tasks from text spans Default configuration: -````python +```python min_words_in_sentence = 3 max_words_in_sentence = 20 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1258,7 +1594,7 @@ Metadata: {'word_count': 10} Generates simple equations with one variable to solve Default configuration: -````python +```python min_terms = 2 max_terms = 4 min_value = 1 @@ -1266,7 +1602,7 @@ max_value = 100 operators = ('+', '-', '*') seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1287,15 +1623,88 @@ Metadata: {'equation': '29*n - 5 = 430', 'variable': 'n'} ```` +### simple_geometry +A dataset for simple polygon angle-finding tasks. + We randomly choose the number of sides N within [min_sides, max_sides]. + We then generate (N-1) random angles (in degrees), ensuring their sum is + strictly less than the total sum for an (N)-sided convex polygon (which is 180*(N-2)). + The question asks for the missing angle; the answer is computed by subtracting the + sum of known angles from 180*(N-2). + +Default configuration: +```python +min_sides = 3 +max_sides = 6 +min_angle = 10 +max_angle = 170 +seed = 42 +size = 100 +``` + +Example tasks: +```` +Example 1: +Question: Given a convex polygon with 3 sides, its first 2 interior angles are: 16.0°, 80.0°. What is the measure of the remaining interior angle (in degrees)? +Answer: 84 +Metadata: {'n_sides': 3, 'known_angles': [16.0, 80.0], 'sum_of_known_angles': 96.0, 'missing_angle_raw': 84.0, 'missing_angle_rounded': 84, 'total_interior_sum': 180} + +Example 2: +Question: A convex polygon has 3 sides. The measures of the first 2 interior angles are: 83.0°, 46.0°. Find the measure of the last interior angle. +Answer: 51 +Metadata: {'n_sides': 3, 'known_angles': [83.0, 46.0], 'sum_of_known_angles': 129.0, 'missing_angle_raw': 51.0, 'missing_angle_rounded': 51, 'total_interior_sum': 180} + +Example 3: +Question: Given a convex polygon with 6 sides, its first 5 interior angles are: 143.0°, 148.0°, 39.0°, 55.0°, 107.0°. What is the measure of the remaining interior angle (in degrees)? +Answer: 228 +Metadata: {'n_sides': 6, 'known_angles': [143.0, 148.0, 39.0, 55.0, 107.0], 'sum_of_known_angles': 492.0, 'missing_angle_raw': 228.0, 'missing_angle_rounded': 228, 'total_interior_sum': 720} + +```` + +### simple_integration +Generates simple integration problems with one variable + +Default configuration: +```python +min_terms = 2 +max_terms = 5 +min_degree = 1 +max_degree = 10 +min_bounds = 1 +max_bounds = 10 +operators = ('+', '-') +symbols = ('x', 'X') +seed = 42 +size = 500 +``` + +Example tasks: +```` +Example 1: +Question: Find the indefinite integral: ∫ 70*x**6 + 12*x**2/5 dx +Answer: 10*x**7 + 4*x**3/5 + C +Metadata: {'integrand': '70*x**6 + 12*x**2/5', 'variable': 'x', 'expected_answer_expression': 10*x**7 + 4*x**3/5} + +Example 2: +Question: Find the indefinite integral: ∫ 49*x**6/10 + 48*x**5 - 4*x - 10/9 dx +Answer: 7*x**7/10 + 8*x**6 - 2*x**2 - 10*x/9 + C +Metadata: {'integrand': '49*x**6/10 + 48*x**5 - 4*x - 10/9', 'variable': 'x', 'expected_answer_expression': 7*x**7/10 + 8*x**6 - 2*x**2 - 10*x/9} + +Example 3: +Question: Find the indefinite integral: ∫ -28*X**3 + 8*X dx +Answer: -7*X**4 + 4*X**2 + C +Metadata: {'integrand': '-28*X**3 + 8*X', 'variable': 'X', 'expected_answer_expression': -7*X**4 + 4*X**2} + +```` + ### spell_backward Generates tasks to spell words backward Default configuration: -````python +```python min_word_len = 3 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1320,12 +1729,12 @@ Metadata: {'word': 'One', 'word_len': 3} Generates sudoku puzzles with configurable difficulty Default configuration: -````python +```python min_empty = 30 max_empty = 50 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1401,7 +1810,7 @@ Metadata: {'puzzle': [[0, 0, 1, 2, 3, 0, 0, 0, 9], [3, 0, 0, 1, 8, 5, 6, 7, 2], Generates syllogism reasoning tasks Default configuration: -````python +```python terms = None allow_all = True allow_no = True @@ -1411,7 +1820,7 @@ include_invalid = True invalid_ratio = 0.3 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1428,14 +1837,14 @@ Metadata: {'premise1': 'No students are humans', 'premise2': 'No humans are chef Example 2: Question: Consider these statements: -1. Some ... are not children are animals +1. Some children are not animals 2. Some animals are doctors Does it logically follow that: All children are doctors? (Answer Yes or No) Answer: Yes -Metadata: {'premise1': 'Some ... are not children are animals', 'premise2': 'Some animals are doctors', 'conclusion': 'All children are doctors', 'is_valid': True} +Metadata: {'premise1': 'Some children are not animals', 'premise2': 'Some animals are doctors', 'conclusion': 'All children are doctors', 'is_valid': True} Example 3: Question: Consider these statements: @@ -1443,10 +1852,109 @@ Question: Consider these statements: 2. No tigers are whales Does it logically follow that: -Some ... are not butterflies are whales? +Some butterflies are not whales? (Answer Yes or No) Answer: No -Metadata: {'premise1': 'All butterflies are tigers', 'premise2': 'No tigers are whales', 'conclusion': 'Some ... are not butterflies are whales', 'is_valid': False} +Metadata: {'premise1': 'All butterflies are tigers', 'premise2': 'No tigers are whales', 'conclusion': 'Some butterflies are not whales', 'is_valid': False} + +```` + +### time_intervals +Generates time interval calculation tasks with various formats and complexities + +Default configuration: +```python +min_time = 00:00:00 +max_time = 23:59:59.999999 +max_time_difference_seconds = 86400 +min_date = 1900-01-01 +max_date = 3000-01-01 +max_date_difference_days = 100 +task_types = ['time', 'time_seconds', 'time_ms', 'date', 'datetime', 'datetime_tz'] +seed = 42 +size = 500 +``` + +Example tasks: +```` +Example 1: +Question: A system backup started at 2964-06-17 08:15:14 and completed at 2964-07-04 11:59:09. What was the total backup duration? Answer in D days, HH:MM. +Answer: 17 days, 03:43 +Metadata: {'task_type': 'datetime_tz', 'start_time': datetime.datetime(2964, 6, 17, 8, 15, 14), 'end_time': datetime.datetime(2964, 7, 4, 11, 59, 9), 'format': '%Y-%m-%d %H:%M:%S', 'expected_format': 'D days, HH:MM'} + +Example 2: +Question: A video call started at 09:44 and ended at 12:22. How long was the call? Answer in HH:MM. +Answer: 02:38 +Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 2, 9, 44), 'end_time': datetime.datetime(2025, 2, 2, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'} + +Example 3: +Question: Calculate the time difference between Sat Dec 22 2677 and Thu Mar 21 2678. Express the result in D days. +Answer: 89 days +Metadata: {'task_type': 'date', 'start_time': datetime.datetime(2677, 12, 22, 0, 0), 'end_time': datetime.datetime(2678, 3, 21, 0, 0), 'format': '%a %b %d %Y', 'expected_format': 'D days'} + +```` + +### tower_of_hanoi +Generates Tower of Hanoi problems with solutions. + Supports variable number of pegs using the optimized Frame-Stewart algorithm with Peg State Tracking. + +Default configuration: +```python +min_disks = 3 +max_disks = 7 +min_pegs = 3 +max_pegs = 4 +size = 50 +seed = 42 +visualize = False +``` + +Example tasks: +```` +Example 1: +Question: Solve the Tower of Hanoi problem with 3 disks and 3 pegs. +Move all disks from Peg 3 to Peg 2 following the rules: +- Only one disk can be moved at a time. +- A larger disk cannot be placed on top of a smaller disk. +- All disks must be on a peg at all times. +Example: +Move disk 1 from Peg 1 to Peg 3 +Move disk 2 from Peg 1 to Peg 2 +Move disk 1 from Peg 3 to Peg 2 + +Provide the sequence of moves. +Answer: ['Move disk 1 from Peg 3 to Peg 2', 'Move disk 2 from Peg 3 to Peg 1', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 3 from Peg 3 to Peg 2', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 2 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2'] +Metadata: {'num_disks': 3, 'num_pegs': 3, 'start_peg': 3, 'target_peg': 2, 'auxiliary_pegs': [1], 'solution_length': 7} + +Example 2: +Question: Solve the Tower of Hanoi problem with 3 disks and 4 pegs. +Move all disks from Peg 2 to Peg 4 following the rules: +- Only one disk can be moved at a time. +- A larger disk cannot be placed on top of a smaller disk. +- All disks must be on a peg at all times. +Example: +Move disk 1 from Peg 1 to Peg 3 +Move disk 2 from Peg 1 to Peg 2 +Move disk 1 from Peg 3 to Peg 2 + +Provide the sequence of moves. +Answer: ['Move disk 1 from Peg 2 to Peg 1', 'Move disk 2 from Peg 2 to Peg 3', 'Move disk 3 from Peg 2 to Peg 4', 'Move disk 2 from Peg 3 to Peg 4', 'Move disk 1 from Peg 1 to Peg 4'] +Metadata: {'num_disks': 3, 'num_pegs': 4, 'start_peg': 2, 'target_peg': 4, 'auxiliary_pegs': [1, 3], 'solution_length': 5} + +Example 3: +Question: Solve the Tower of Hanoi problem with 6 disks and 3 pegs. +Move all disks from Peg 1 to Peg 2 following the rules: +- Only one disk can be moved at a time. +- A larger disk cannot be placed on top of a smaller disk. +- All disks must be on a peg at all times. +Example: +Move disk 1 from Peg 1 to Peg 3 +Move disk 2 from Peg 1 to Peg 2 +Move disk 1 from Peg 3 to Peg 2 + +Provide the sequence of moves. +Answer: ['Move disk 1 from Peg 1 to Peg 3', 'Move disk 2 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 3 from Peg 1 to Peg 3', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 2 from Peg 2 to Peg 3', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 4 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 2 from Peg 3 to Peg 1', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 3 from Peg 3 to Peg 2', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 2 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 5 from Peg 1 to Peg 3', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 2 from Peg 2 to Peg 3', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 3 from Peg 2 to Peg 1', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 2 from Peg 3 to Peg 1', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 4 from Peg 2 to Peg 3', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 2 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 3 from Peg 1 to Peg 3', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 2 from Peg 2 to Peg 3', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 6 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 2 from Peg 3 to Peg 1', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 3 from Peg 3 to Peg 2', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 2 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 4 from Peg 3 to Peg 1', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 2 from Peg 2 to Peg 3', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 3 from Peg 2 to Peg 1', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 2 from Peg 3 to Peg 1', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 5 from Peg 3 to Peg 2', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 2 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 3 from Peg 1 to Peg 3', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 2 from Peg 2 to Peg 3', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 4 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2', 'Move disk 2 from Peg 3 to Peg 1', 'Move disk 1 from Peg 2 to Peg 1', 'Move disk 3 from Peg 3 to Peg 2', 'Move disk 1 from Peg 1 to Peg 3', 'Move disk 2 from Peg 1 to Peg 2', 'Move disk 1 from Peg 3 to Peg 2'] +Metadata: {'num_disks': 6, 'num_pegs': 3, 'start_peg': 1, 'target_peg': 2, 'auxiliary_pegs': [3], 'solution_length': 63} ```` @@ -1454,14 +1962,14 @@ Metadata: {'premise1': 'All butterflies are tigers', 'premise2': 'No tigers are Generates word ladder transformation tasks Default configuration: -````python +```python min_word_length = 3 max_word_length = 5 min_chain_length = -1 max_chain_length = -1 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1486,12 +1994,12 @@ Metadata: {'start_word': 'SAUT', 'end_word': 'SKER', 'word_length': 4, 'chain_le Generates word sequence reversal tasks from text spans Default configuration: -````python +```python min_words = 3 max_words = 8 seed = 42 size = 500 -```` +``` Example tasks: ```` @@ -1516,7 +2024,7 @@ Metadata: {'num_words': 6, 'words': ['readable', 'to', 'he', 'that', 'to', 'poss Generates word sorting tasks Default configuration: -````python +```python min_words = 3 max_words = 10 min_word_length = 3 @@ -1524,7 +2032,7 @@ max_word_length = 12 transformation = original seed = 42 size = 500 -```` +``` Example tasks: ```` diff --git a/README.md b/README.md index ff4fe17b..a1aa3791 100644 --- a/README.md +++ b/README.md @@ -76,12 +76,14 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets ### Arithmetic Tasks - `BasicArithmeticDataset`: Generate arithmetic expressions with configurable complexity and operators (+, -, \*, /) +- `CalendarArithmeticDatset`: Generate arithmetic problems around calendar navigation logic - `ChainSum`: Generate addition/subtraction chains with configurable length and digit counts - `FractionSimplificationDataset`: Generate fraction simplification tasks with configurable complexity - `GCDDataset`: Generate Greatest Common Divisor problems with configurable number of integers - `LCMDataset`: Generate Least Common Multiple problems with configurable number of integers - `LegCountingDataset`: Generate animal leg counting word problems with various animals - `PrimeFactorizationDataset`: Generate prime factorization tasks with configurable number ranges +- `TimeIntervalsDataset`: Generate time interval calculation tasks with various formats (time, date, datetime) and complexities ### Algorithmic Tasks @@ -111,7 +113,8 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets ### Logic Tasks - `PropositionalLogicDataset`: Generate propositional logic reasoning problems - +- `SyllogismDataset`: Generates a [syllogism](https://en.wikipedia.org/wiki/Syllogism) reasoning dataset +- `AliceInWonderlandDataset`: Generates [AIW](https://openreview.net/forum?id=Mkl7dzjYiW) (Alice In Wonderland) problems with a few variations ### Graph Tasks - `FamilyRelationshipsDataset`: Generate family relationship reasoning tasks with family trees @@ -123,6 +126,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets - `MiniSudokuDataset`: Generate 4x4 Mini Sudoku puzzles with configurable difficulty - `MazeDataset`: Generate a maze with a start and a goal - `CountdownDataset`: Generate number game tasks where numbers and operators must be combined to reach a target value +- `NQueensDataset`: Generate N-Queens puzzles with configurable board size and number of starting queens ## Future Generator Ideas @@ -134,4 +138,4 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets ## Call for Contributions -If you have ideas for additional procedural dataset generators please create an issue here or contact us in the `#arc-agi-2` channel of the [GPU-Mode discord server](https://discord.gg/gpumode). +If you have ideas for additional procedural dataset generators please create an issue here or contact us in the `#reasoning-gym` channel of the [GPU-Mode discord server](https://discord.gg/gpumode). diff --git a/examples/veRL/.gitignore b/examples/veRL/.gitignore new file mode 100644 index 00000000..c54a47c0 --- /dev/null +++ b/examples/veRL/.gitignore @@ -0,0 +1,3 @@ +outputs/ +wandb/ +verl_output.log diff --git a/examples/veRL/README.md b/examples/veRL/README.md new file mode 100644 index 00000000..e2ec6e50 --- /dev/null +++ b/examples/veRL/README.md @@ -0,0 +1,19 @@ +### env setup + +``` +conda create --name verl python=3.12 -y +conda activate verl + +pip install flash-attn --no-build-isolation +pip install vllm==0.7.0 ray wandb +``` + +### clone and install veRL + +tested with verl HEAD a65c9157bc0b85b64cd753de19f94e80a11bd871 + +``` +git clone https://github.com/volcengine/verl.git +cd verl +pip install -e . +``` diff --git a/examples/veRL/config/ppo_trainer.yaml b/examples/veRL/config/ppo_trainer.yaml new file mode 100644 index 00000000..b294a7cb --- /dev/null +++ b/examples/veRL/config/ppo_trainer.yaml @@ -0,0 +1,167 @@ +data: + tokenizer: null + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: 1312 + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: False # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + grad_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # could get higher throughput + # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 1 # > 1 for grpo + +critic: + strategy: fsdp + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + grad_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + use_remove_padding: False + fsdp_config: + min_num_params: 0 + param_offload: False + fsdp_size: -1 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: null # set a number + max_length: null + ulysses_sequence_parallel_size: 1 # sp size + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + +trainer: + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: [ 'console', 'wandb' ] + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} diff --git a/examples/veRL/launch_on_4gpu.sh b/examples/veRL/launch_on_4gpu.sh new file mode 100755 index 00000000..0a51f68c --- /dev/null +++ b/examples/veRL/launch_on_4gpu.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +export N_GPUS=4 +export BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct +export ROLLOUT_TP_SIZE=2 +export EXPERIMENT_NAME=chain_sum_llama +export VLLM_ATTENTION_BACKEND=XFORMERS + +bash ./train.sh diff --git a/examples/veRL/main_ppo_custom_reward.py b/examples/veRL/main_ppo_custom_reward.py new file mode 100644 index 00000000..2addb8e9 --- /dev/null +++ b/examples/veRL/main_ppo_custom_reward.py @@ -0,0 +1,285 @@ +# This example is an adapted version of Bytedance's code: +# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py +from typing import Optional + +import hydra +import ray +import torch +import verl.utils.torch_functional as verl_F +from omegaconf import OmegaConf, open_dict +from torch.utils.data import DataLoader, Dataset +from transformers import PreTrainedTokenizer +from verl import DataProto +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.utils.dataset.rl_dataset import collate_fn +from verl.utils.model import compute_position_id_with_mask + +import reasoning_gym +import reasoning_gym.utils +from reasoning_gym.utils import extract_answer + + +class ReasoningGymDataset(Dataset): + def __init__( + self, + tokenizer: PreTrainedTokenizer, + dataset_name: str, + seed: int, + size: int, + developer_prompt: Optional[str] = None, + developer_role: str = "system", + max_prompt_length: int = 2048, + truncation: str = "error", ## ['left', 'right', 'error'] + return_raw_chat: bool = False, + ): + self.tokenizer = tokenizer + self.dataset_name = dataset_name + self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size) + self.developer_prompt = developer_prompt + self.developer_role = developer_role + self.max_prompt_length = max_prompt_length + self.truncation = truncation + self.return_raw_chat = return_raw_chat + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, index): + row_dict = self.data[index].copy() + q = row_dict["question"] + + chat = [] + if self.developer_prompt is not None: + chat.append({"role": self.developer_role, "content": self.developer_prompt}) + chat.append({"role": "user", "content": q}) + + prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + + input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( + prompt=prompt, + tokenizer=self.tokenizer, + max_length=self.max_prompt_length, + pad_token_id=self.tokenizer.pad_token_id, + left_pad=True, + truncation=self.truncation, + ) + + position_ids = compute_position_id_with_mask(attention_mask) + + row_dict["data_source"] = "reasoning_gym/" + self.dataset_name + row_dict["input_ids"] = input_ids[0] + row_dict["attention_mask"] = attention_mask[0] + row_dict["position_ids"] = position_ids[0] + + # encode prompts without chat template + if self.return_raw_chat: + row_dict["raw_prompt"] = chat.tolist() + + # add index for each prompt + # index = row_dict.get("extra_info", {}).get("index", 0) + row_dict["index"] = index + + return row_dict + + +class RayPPOTrainerCustom(RayPPOTrainer): + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict, + resource_pool_manager, + ray_worker_group_cls, + dataset_name: str = "chain_sum", + dataset_size: int = 10000, + ): + self.dataset_name = dataset_name + self.dataset_size = dataset_size + + developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"] + self.train_dataset = ReasoningGymDataset( + tokenizer=tokenizer, + dataset_name=self.dataset_name, + seed=1, + size=self.dataset_size, + developer_prompt=developer_prompt, + ) + + self.val_dataset = ReasoningGymDataset( + tokenizer=tokenizer, + dataset_name=self.dataset_name, + seed=2, + size=self.dataset_size, + developer_prompt=developer_prompt, + ) + + train_reward_fn = lambda data: self._score_output(data, num_examine=0) + val_reward_fn = lambda data: self._score_output(data, num_examine=1) + + super().__init__( + config, + tokenizer, + role_worker_mapping, + resource_pool_manager, + ray_worker_group_cls, + train_reward_fn, + val_reward_fn, + ) + + def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor: + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) + + num_printed = 0 + for i in range(len(data)): + data_item = data[i] # DataProtoItem + + prompt_ids = data_item.batch["prompts"] # tokenized prompts + prompt_length = prompt_ids.shape[-1] + + valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + + response_ids = data_item.batch["responses"] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + sequences = torch.cat((valid_prompt_ids, valid_response_ids)) + sequences_str = self.tokenizer.decode(sequences) + + index = data_item.non_tensor_batch["index"] + + score = self._compute_score( + solution_str=sequences_str, + index=index, + ) + reward_tensor[i, valid_response_length - 1] = score + + if num_printed < num_examine: + print(f"reward={score}, seq={sequences_str}") + num_printed += 1 + + return reward_tensor + + def _compute_score(self, solution_str: str, index: int) -> float: + found_answer = extract_answer(solution_str, tag_name="answer") + entry = self.train_dataset.data[index] + reward = self.train_dataset.data.score_answer(found_answer, entry=entry) + # print(f"found answer={found_answer}; reward: {reward};") + return reward + + def _create_dataloader(self): + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.train_batch_size, + shuffle=True, + drop_last=True, + collate_fn=collate_fn, + ) + + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=len(self.val_dataset), + shuffle=True, + drop_last=True, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1 + assert len(self.val_dataloader) >= 1 + + print(f"Size of train dataloader: {len(self.train_dataloader)}") + print(f"Size of val dataloader: {len(self.val_dataloader)}") + + # inject total_training_steps to actor/critic optim_config. This is hacky. + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.critic.optim.total_training_steps = total_training_steps + + +@ray.remote +def main_task(config): + # print initial config + from pprint import pprint + + from verl.utils import hf_tokenizer + from verl.utils.fs import copy_local_path_from_hdfs + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + tokenizer = hf_tokenizer(local_path) + + # define worker classes + if config.actor_rollout_ref.actor.strategy == "fsdp": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + Role.RefPolicy: ray.remote(ActorRolloutRefWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + Role.RefPolicy: global_pool_id, + } + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayPPOTrainerCustom( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + ) + trainer.init_workers() + trainer.fit() + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + if not ray.is_initialized(): + # this is for local ray cluster + ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) + + ray.get(main_task.remote(config)) + + +if __name__ == "__main__": + main() diff --git a/examples/veRL/train.sh b/examples/veRL/train.sh new file mode 100755 index 00000000..92ed0b84 --- /dev/null +++ b/examples/veRL/train.sh @@ -0,0 +1,30 @@ +#!/bin/bash +python3 -u main_ppo_custom_reward.py \ +data.train_files=$DATA_DIR/train.parquet \ +data.val_files=$DATA_DIR/test.parquet \ +data.train_batch_size=256 \ +data.val_batch_size=1312 \ +data.max_prompt_length=256 \ +data.max_response_length=1024 \ +actor_rollout_ref.model.path=$BASE_MODEL \ +actor_rollout_ref.actor.optim.lr=1e-6 \ +actor_rollout_ref.actor.ppo_mini_batch_size=128 \ +actor_rollout_ref.actor.ppo_micro_batch_size=8 \ +actor_rollout_ref.rollout.log_prob_micro_batch_size=8 \ +actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \ +actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ +actor_rollout_ref.ref.log_prob_micro_batch_size=4 \ +critic.optim.lr=1e-5 \ +critic.model.path=$BASE_MODEL \ +critic.ppo_micro_batch_size=8 \ +algorithm.kl_ctrl.kl_coef=0.001 \ +trainer.logger=['wandb'] \ ++trainer.val_before_train=False \ +trainer.default_hdfs_dir=null \ +trainer.n_gpus_per_node=$N_GPUS \ +trainer.nnodes=1 \ +trainer.save_freq=100 \ +trainer.test_freq=100 \ +trainer.project_name=verl_chain_sum \ +trainer.experiment_name=$EXPERIMENT_NAME \ +trainer.total_epochs=15 2>&1 | tee verl_output.log diff --git a/pyproject.toml b/pyproject.toml index 80ad4865..964c0b4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "reasoning_gym" -version = "0.1.2" +version = "0.1.3" authors = [ { name = "Open-Thought community", email = "andreas.koepf@xamla.com" }, ] @@ -16,19 +16,8 @@ dependencies = [ "cellpylib==2.4.0", "sympy>=1.13.1", "magiccube==0.3.0", - "pyfiglet==1.0.2" -] - -[project.optional-dependencies] -test = [ - "pytest>=7.0.0", - "pytest-cov>=4.0.0", -] - -[tool.pytest.ini_options] -addopts = "-ra -q --cov=reasoning_gym" -testpaths = [ - "tests", + "pyfiglet==1.0.2", + "pytz>=2024.1" ] classifiers = [ "Programming Language :: Python :: 3", @@ -38,6 +27,12 @@ classifiers = [ license = "Apache-2.0" license-files = ["LICENSE*"] +[project.optional-dependencies] +test = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", +] + [project.urls] "Homepage" = "https://github.com/open-thought/reasoning-gym" "Bug Tracker" = "https://github.com/open-thought/reasoning-gym/issues" @@ -57,3 +52,9 @@ include = '\.pyi?$' profile = "black" multi_line_output = 3 line_length = 120 + +[tool.pytest.ini_options] +addopts = "-ra -q" +testpaths = [ + "tests", +] diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py index b25eb134..054cbd95 100644 --- a/reasoning_gym/__init__.py +++ b/reasoning_gym/__init__.py @@ -2,17 +2,18 @@ Reasoning Gym - A library of procedural dataset generators for training reasoning models """ -from . import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic +from . import algebra, algorithmic, arithmetic, cognition, data, games, geometry, graphs, logic from .factory import create_dataset, register_dataset -__version__ = "0.1.1" +__version__ = "0.1.3" __all__ = [ - "arithmetic", - "algorithmic", "algebra", + "algorithmic", + "arithmetic", "cognition", "data", "games", + "geometry", "graphs", "logic", "create_dataset", diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py index 69d4b91e..fc7a867a 100644 --- a/reasoning_gym/algebra/__init__.py +++ b/reasoning_gym/algebra/__init__.py @@ -1,9 +1,15 @@ +from .intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset +from .simple_integration import SimpleIntegrationConfig, SimpleIntegrationDataset __all__ = [ - "SimpleEquationsDataset", - "SimpleEquationsConfig", + "IntermediateIntegrationConfig", + "IntermediateIntegrationDataset", "PolynomialEquationsConfig", "PolynomialEquationsDataset", + "SimpleEquationsDataset", + "SimpleEquationsConfig", + "SimpleIntegrationConfig", + "SimpleIntegrationDataset", ] diff --git a/reasoning_gym/algebra/intermediate_integration.py b/reasoning_gym/algebra/intermediate_integration.py new file mode 100644 index 00000000..5d0b139c --- /dev/null +++ b/reasoning_gym/algebra/intermediate_integration.py @@ -0,0 +1,263 @@ +import random +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import sympy + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class IntermediateIntegrationConfig: + problem_types: tuple = ("substitution", "by_parts") + substitution_types: tuple = ( + "linear", # (ax + b)^n + "trigonometric", # sin**2(x)cos(x) + "exponential", # 2xe^x**2 + "radical", # x (3x + 2)^1/2 + ) + + # Integration by parts problem categories + by_parts_types: tuple = ( + "polynomial_exp_trig", # e.g. x^2*e^x + "log_inverse_trig", # e.g. ln(x)/arctan(x) + "cyclic", # e.g. e^x*sinx requiring cyclic integration + "repeated_parts", # Requires multiple integration by parts + ) + seed: Optional[int] = None + size: int = 500 + + linear_lower_bound: int = 1 # coefficient of linear expression + linear_upper_bound: int = 10 + min_linear_degree: int = 2 # degree of linear expression in substitution problem + max_linear_degree: int = 4 + outer_constant_min: int = 1 # multiplicative constant to multiply integrand by + outer_constant_max: int = 3 + min_poly_degree: int = 1 # degree of polynomial in by parts problem + max_poly_degree: int = 3 + symbols: tuple = ("x", "X") + operators: tuple = ( + "+", + "-", + ) + + def validate(self) -> None: + """Validate the configuration parameters of the integral problem""" + assert self.size > 0, "size must be positive" + assert self.linear_lower_bound > 0, "linear_lower_bound must be positive" + assert self.linear_upper_bound >= self.linear_lower_bound, "linear_upper_bound must be >= linear_lower_bound" + assert self.min_linear_degree > 0, "min_linear_degree must be positive" + assert self.max_linear_degree >= self.min_linear_degree, "max_linear_degree must be >= min_linear_degree" + assert self.outer_constant_min > 0, "outer_constant_min must be positive" + assert self.outer_constant_max >= self.outer_constant_min, "outer_constant_max must be >= outer_constant_min" + assert self.min_poly_degree > 0, "min_poly_degree must be positive" + assert self.max_poly_degree >= self.min_poly_degree, "max_poly_degree must be >= min_poly_degree" + assert all(op in ("+", "-") for op in self.operators), "invalid operator specified" + assert all(symbols in ("x", "X") for symbols in self.symbols), "invalid symbol specified" + assert all(t in ("substitution", "by_parts") for t in self.problem_types), "invalid problem type" + assert all( + t in ("linear", "trigonometric", "exponential", "radical") for t in self.substitution_types + ), "invalid substitution type" + assert all( + t in ("polynomial_exp_trig", "log_inverse_trig", "cyclic", "repeated_parts") for t in self.by_parts_types + ), "invalid by_parts type" + + +class IntermediateIntegrationDataset(ProceduralDataset): + """Generates intermediate integration problem - either + by substitution or by parts""" + + """Add multiplicative constant""" + + def __init__(self, config: IntermediateIntegrationConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.prompt_template = [ + "Find the indefinite integral: ∫ {integrand} dx", + "Calculate the antiderivative: ∫ {integrand} dx", + "Evaluate the indefinite integral: ∫ {integrand} dx", + ] + + def _get_outer_constant(self, rng: random.Random) -> int: + """Helper to generate signed outer constant from config""" + value = rng.randint(self.config.outer_constant_min, self.config.outer_constant_max) + return -value if rng.choice(self.config.operators) == "-" else value + + def _generate_linear_substitution_problem(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate a linear substitution problem with outer constant""" + a = rng.randint(self.config.linear_lower_bound, self.config.linear_upper_bound) + b = rng.randint(self.config.linear_lower_bound, self.config.linear_upper_bound) + + linear_function = a * x + (b if rng.choice(self.config.operators) == "+" else -b) + degree = rng.randint(self.config.min_linear_degree, self.config.max_linear_degree) + + return self._get_outer_constant(rng) * linear_function**degree + + def _generate_exponential_substitution(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate exponential substitution problem with outer constant""" + exponent_type = rng.choice(["linear", "quadratic"]) + + # Generate terms with signs + num_terms = 2 if exponent_type == "linear" else 3 + terms = [ + (-1 if rng.choice(self.config.operators) == "-" else 1) + * rng.randint(self.config.linear_lower_bound, self.config.linear_upper_bound) + for _ in range(num_terms) + ] + + if exponent_type == "linear": + u = terms[0] * x + terms[1] + du_dx = terms[0] + else: # Quadratic + u = terms[0] * x**2 + terms[1] * x + terms[2] + du_dx = 2 * terms[0] * x + terms[1] + + return self._get_outer_constant(rng) * du_dx * sympy.exp(u) + + def _generate_radical_substitution(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate radical substitution problem with outer constant""" + + # Generate linear expression under radical: ax + b with possible negative coefficients + a = (-1 if rng.choice(self.config.operators) == "-" else 1) * rng.randint( + self.config.linear_lower_bound, self.config.linear_upper_bound + ) + b = (-1 if rng.choice(self.config.operators) == "-" else 1) * rng.randint( + self.config.linear_lower_bound, self.config.linear_upper_bound + ) + + u = a * x + b + derivative = a # du/dx + + integrand = derivative * sympy.sqrt(u) + return self._get_outer_constant(rng) * integrand + + def _generate_trigonometric_substitution(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate trigonometric substitution with outer constant""" + trig_func = rng.choice(["sin", "cos"]) + + # Generate signed coefficients + a = (-1 if rng.choice(self.config.operators) == "-" else 1) * rng.randint( + self.config.linear_lower_bound, self.config.linear_upper_bound + ) + b = (-1 if rng.choice(self.config.operators) == "-" else 1) * rng.randint( + self.config.linear_lower_bound, self.config.linear_upper_bound + ) + + inner = a * x + b + power = rng.randint(1, 4) + if trig_func == "sin": + integrand = a * sympy.cos(inner) * sympy.sin(inner) ** power + else: + integrand = -a * sympy.sin(inner) * sympy.cos(inner) ** power + return self._get_outer_constant(rng) * integrand + + def _generate_polynomial_exp_trig(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate polynomial × exponential/trigonometric integrand""" + poly_degree = rng.randint(self.config.min_poly_degree, self.config.max_poly_degree) + + func_type = rng.choice(["exp", "sin", "cos"]) + if func_type == "exp": + transcendental = sympy.exp(x) + else: + coefficient = rng.randint(1, 3) + transcendental = sympy.Function(func_type)(coefficient * x) + + polynomial = x**poly_degree + integrand = polynomial * transcendental + return self._get_outer_constant(rng) * integrand + + def _generate_log_inverse_trig(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate logarithmic or inverse trigonometric integrand""" + func_type = rng.choice(["log", "asin", "atan"]) + + if func_type == "log": + log_arg = x if rng.random() < 0.8 else x ** rng.randint(2, 3) + func = sympy.ln(log_arg) + else: + coefficient = rng.randint(1, 3) + func = sympy.Function(func_type)(coefficient * x) + + return self._get_outer_constant(rng) * func + + def _generate_cyclic_integral(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate cyclic integral (e.g., e^x * sinx)""" + func_pair = rng.choice( + [(sympy.exp(x), sympy.sin(x)), (sympy.exp(x), sympy.cos(x)), (sympy.sin(x), sympy.cos(x))] + ) + integrand = func_pair[0] * func_pair[1] + return self._get_outer_constant(rng) * integrand + + def _generate_repeated_parts(self, rng: random.Random, x: sympy.Symbol): + """Generate problem requiring multiple integration by parts""" + poly_degree = rng.randint(3, self.config.max_poly_degree) + transcendental = rng.choice([sympy.sin(x), sympy.cos(x), sympy.exp(x)]) + integrand = x**poly_degree * transcendental + return self._get_outer_constant(rng) * integrand + + def __getitem__(self, index: int): + """Generate either substitution or by-parts problem""" + rng = random.Random(self.seed + index) + problem_type = rng.choice(self.config.problem_types) + x = sympy.Symbol(rng.choice(self.config.symbols)) + + if problem_type == "substitution": + substitution_type = rng.choice(self.config.substitution_types) + if substitution_type == "linear": + integrand = self._generate_linear_substitution_problem(rng, x) + elif substitution_type == "trigonometric": + integrand = self._generate_trigonometric_substitution(rng, x) + elif substitution_type == "exponential": + integrand = self._generate_exponential_substitution(rng, x) + elif substitution_type == "radical": + integrand = self._generate_radical_substitution(rng, x) + else: + parts_type = rng.choice(self.config.by_parts_types) + if parts_type == "polynomial_exp_trig": + integrand = self._generate_polynomial_exp_trig(rng, x) + elif parts_type == "log_inverse_trig": + integrand = self._generate_log_inverse_trig(rng, x) + elif parts_type == "cyclic": + integrand = self._generate_cyclic_integral(rng, x) + elif parts_type == "repeated_parts": + integrand = self._generate_repeated_parts(rng, x) + + answer = sympy.integrate(integrand, x) + answer_str = str(answer) + " + C" + + return { + "question": rng.choice(self.prompt_template).format(integrand=integrand), + "answer": answer_str, + "metadata": { + "integrand": str(integrand), + "problem_type": problem_type, + "variable": str(x), + "type": substitution_type if problem_type == "substitution" else parts_type, + "expected_answer_expression": answer, + }, + } + + def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float: + """Determine if the solution provided solves the problem""" + reward = 0.0 + if answer is not None: + try: + var = metadata["variable"] + x = sympy.Symbol(var) + # Parse answer while allowing integration constant 'C' + user_expr = sympy.parse_expr(answer, local_dict={var: x, "C": sympy.Symbol("C")}) + # Compute derivative of student's answer + derivative = sympy.diff(user_expr, x) + integrand = sympy.parse_expr(metadata["integrand"], local_dict={var: x}) + + # Check mathematical equivalence through simplification + if sympy.simplify(derivative - integrand) == 0: + reward = 1.0 + elif answer.strip(): + reward = 0.05 + else: + reward = 0.01 + except: + reward = 0.01 + return reward + + +register_dataset("intermediate_integration", IntermediateIntegrationDataset, IntermediateIntegrationConfig) diff --git a/reasoning_gym/algebra/simple_integration.py b/reasoning_gym/algebra/simple_integration.py new file mode 100644 index 00000000..1da32004 --- /dev/null +++ b/reasoning_gym/algebra/simple_integration.py @@ -0,0 +1,108 @@ +import random +from dataclasses import dataclass +from fractions import Fraction +from typing import Any, Dict, Optional + +import sympy + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class SimpleIntegrationConfig: + min_terms: int = 2 + max_terms: int = 5 + min_degree: int = 1 + max_degree: int = 10 + min_bounds: int = 1 + max_bounds: int = 10 + operators: tuple = ("+", "-") + symbols: tuple = ("x", "X") + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate the configuration parameters of the integral proble""" + assert self.min_bounds > 0, "min_bounds must be positive" + assert self.max_bounds >= self.min_bounds, "max_bounds must be >= min_bounds" + assert self.min_terms >= 0, "min_terms must be positive" + assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" + assert self.min_degree >= -10, "min_degree must be >= -10" + assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree" + assert all(op in ("+", "-") for op in self.operators), "invalid operator specified" + + +class SimpleIntegrationDataset(ProceduralDataset): + """Generates simple integration problems with one variable""" + + def __init__(self, config: SimpleIntegrationConfig): + self._prompt_templates = [ + "Find the indefinite integral: ∫ {integrand} dx", + "Calculate the antiderivative: ∫ {integrand} dx", + "Evaluate the indefinite integral: ∫ {integrand} dx", + ] + super().__init__(config=config, seed=config.seed, size=config.size) + + def _generate_coefficient(self, rng: random.Random) -> Fraction: + """Generate a random coefficient for the polynomial""" + if rng.choice([True, False]): # 50% chance for integer + return Fraction(rng.randint(self.config.min_bounds, self.config.max_bounds), 1) + denominator = rng.randint(2, 10) + return Fraction(rng.randint(self.config.min_bounds, self.config.max_bounds), denominator) + + def _generate_polynomial(self, rng: random.Random) -> tuple[sympy.Symbol, sympy.Expr]: + """Generate a random polynomial with one variable""" + terms = [] + x = sympy.Symbol(rng.choice(self.config.symbols)) + + for _ in range(rng.randint(self.config.min_terms, self.config.max_terms)): + coefficient = self._generate_coefficient(rng) + degree = rng.randint(self.config.min_degree, self.config.max_degree) + operator = rng.choice(self.config.operators) + term = coefficient * x**degree + if operator == "-": + term = -term + terms.append(term) + return x, sum(terms) + + def __getitem__(self, idx: int) -> dict: + rng = random.Random(self.seed + idx) + symbol, polynomial = self._generate_polynomial(rng) + derivative = sympy.diff(polynomial, symbol) + + return { + "question": rng.choice(self._prompt_templates).format(integrand=derivative), + "answer": str(polynomial) + " + C", + "metadata": { + "integrand": str(derivative), + "variable": str(symbol), + "expected_answer_expression": polynomial, + }, + } + + def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float: + """Determine if the solution provided solves the problem""" + reward = 0.0 + if answer is not None: + try: + var = metadata["variable"] + x = sympy.Symbol(var) + # Parse answer while allowing integration constant 'C' + user_expr = sympy.parse_expr(answer, local_dict={var: x, "C": sympy.Symbol("C")}) + # Compute derivative of student's answer + derivative = sympy.diff(user_expr, x) + integrand = sympy.parse_expr(metadata["integrand"], local_dict={var: x}) + + # Check mathematical equivalence through simplification + if sympy.simplify(derivative - integrand) == 0: + reward = 1.0 + elif answer.strip(): + reward = 0.05 + else: + reward = 0.01 + except: + reward = 0.01 + return reward + + +register_dataset("simple_integration", SimpleIntegrationDataset, SimpleIntegrationConfig) diff --git a/reasoning_gym/algorithmic/base_conversion.py b/reasoning_gym/algorithmic/base_conversion.py index eb0978bd..afa6200a 100644 --- a/reasoning_gym/algorithmic/base_conversion.py +++ b/reasoning_gym/algorithmic/base_conversion.py @@ -60,14 +60,32 @@ class BaseConversionDataset(ProceduralDataset): value, source_base, target_base = self._generate_conversion(rng) # Convert decimal to source base representation - source_repr = format(value, f"x" if source_base == 16 else f"b" if source_base == 2 else "").strip() - if source_base not in (2, 16): - source_repr = format(value, f"{source_base}x").lower().strip() + if source_base == 16: + source_repr = format(value, "x") + elif source_base == 2: + source_repr = format(value, "b") + else: + # Manual conversion for other bases + n = value + digits = [] + while n: + digits.append(int(n % source_base)) + n //= source_base + source_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0]) # Convert decimal to target base for answer - target_repr = format(value, f"x" if target_base == 16 else f"b" if target_base == 2 else "").strip() - if target_base not in (2, 16): - target_repr = format(value, f"{target_base}x").lower().strip() + if target_base == 16: + target_repr = format(value, "x") + elif target_base == 2: + target_repr = format(value, "b") + else: + # Manual conversion for other bases + n = value + digits = [] + while n: + digits.append(int(n % target_base)) + n //= target_base + target_repr = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0]) source_name = self._format_base_name(source_base) target_name = self._format_base_name(target_base) diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 64a9a805..9047d55f 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -4,9 +4,11 @@ Arithmetic tasks for training reasoning capabilities: - Chain sums - Word problems - Leg counting +- Time intervals """ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig +from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSum, ChainSumConfig from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .gcd import GCDConfig, GCDDataset @@ -14,6 +16,7 @@ from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .gsm_symbolic.gsm_symbolic_datasets import GSMSymbolicDataset, GSMSymbolicDatasetConfig +from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset __all__ = [ "BasicArithmeticDataset", @@ -21,6 +24,10 @@ __all__ = [ "basic_arithmetic_dataset", "ChainSum", "ChainSumConfig", + "CalendarArithmeticConfig", + "CalendarArithmeticDataset", + "Weekday", + "CalendarTask", "FractionSimplificationConfig", "FractionSimplificationDataset", "GCDConfig", @@ -33,4 +40,6 @@ __all__ = [ "PrimeFactorizationDataset", "GSMSymbolicDatasetConfig", "GSMSymbolicDataset", + "TimeIntervalsConfig", + "TimeIntervalsDataset", ] diff --git a/reasoning_gym/arithmetic/calendar_arithmetic.py b/reasoning_gym/arithmetic/calendar_arithmetic.py new file mode 100644 index 00000000..78c42df8 --- /dev/null +++ b/reasoning_gym/arithmetic/calendar_arithmetic.py @@ -0,0 +1,490 @@ +import calendar +import math +import random +from dataclasses import dataclass +from datetime import date, timedelta +from enum import Enum, auto +from typing import Any, Dict, List, Optional, Tuple + +from ..factory import ProceduralDataset, register_dataset + + +class Weekday(Enum): + MONDAY = auto() + TUESDAY = auto() + WEDNESDAY = auto() + THURSDAY = auto() + FRIDAY = auto() + SATURDAY = auto() + SUNDAY = auto() + + @classmethod + def from_date(cls, d: date) -> "Weekday": + return list(cls)[d.weekday()] + + @classmethod + def random(cls, rng: random.Random) -> "Weekday": + return list(cls)[rng.randint(0, 6)] + + @classmethod + def __getitem__(cls, idx) -> "Weekday": + return list(cls)[idx] + + @property + def index(self) -> int: + return self.value - 1 + + def __str__(self) -> str: + return self.name.capitalize() + + +class CalendarTask(Enum): + WEEKDAY_OFFSET = "weekday_offset" + WEEKDAY_OF_DATE = "weekday_of_date" + WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_day" + RECURRING_EVENT_CALCULATIONS = "recurring_event_day" + COUNT_DAYS = "count_days" + COUNT_BUSINESS_DAYS = "count_business_days" + IS_LEAP_YEAR = "is_leap_year" + + +@dataclass +class CalendarArithmeticConfig: + year: int = 2022 + tasks: Optional[List[str]] = None + offset_upper_bound: int = 100 + leap_year_range: int = 200 + seed: Optional[int] = 42 + size: int = 500 + + def __post_init__(self): + if self.tasks is None: + self.tasks = [task.value for task in CalendarTask] + else: + self.tasks = [task.lower() for task in self.tasks] + valid_tasks = {task.value for task in CalendarTask} + invalid_tasks = set(self.tasks) - valid_tasks + if invalid_tasks: + valid_task_list = ", ".join(sorted(valid_tasks)) + raise ValueError( + f"Invalid tasks: {', '.join(sorted(invalid_tasks))}. " f"Valid tasks are: {valid_task_list}" + ) + + def validate(self) -> None: + """Validate the configuration parameters.""" + if not isinstance(self.year, int) or self.year <= 0: + raise ValueError(f"year must be a positive integer, got {self.year}") + + if self.seed is not None and not isinstance(self.seed, int): + raise ValueError(f"seed must be an integer or None, got {type(self.seed)}") + + if not isinstance(self.size, int) or self.size <= 0: + raise ValueError(f"size must be a positive integer, got {self.size}") + + +class CalendarArithmeticDataset(ProceduralDataset): + DAY_QUESTION_TEMPLATES = [ + "Answer with the weekday's name (e.g., Monday, Tuesday, etc.).", + "Provide the full name of the weekday.", + "State the weekday (Monday through Sunday).", + "Give the weekday name in full.", + "Reply with just the weekday name.", + "Write out the full weekday name.", + "Respond with the weekday (Monday-Sunday).", + "Answer using the complete weekday name.", + "Name the day of the week in full.", + ] + + COUNT_QUESTION_TEMPLATES = [ + "Answer with a number.", + "Provide the count as a number.", + "Respond with just the number.", + "Write the total number.", + "Give the count numerically.", + "State the amount as a number.", + "Reply with the numerical value.", + "Express your answer as a number.", + ] + + def __init__(self, config: CalendarArithmeticConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + self.task_handlers = { + CalendarTask.WEEKDAY_OFFSET.value: self._weekday_offset, + CalendarTask.WEEKDAY_OF_DATE.value: self._weekday_of_date, + CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value: self._weekday_of_date_from_first_day, + CalendarTask.RECURRING_EVENT_CALCULATIONS.value: self._recurring_event_day, + CalendarTask.COUNT_DAYS.value: self._count_days, + CalendarTask.COUNT_BUSINESS_DAYS.value: self._count_business_days, + CalendarTask.IS_LEAP_YEAR.value: self._is_leap_year, + } + + self.tasks = [self.task_handlers[task] for task in self.config.tasks] + + def __getitem__(self, idx: int) -> dict: + item_rng = random.Random(self.seed + idx) + task = item_rng.choice(self.tasks) + question, answer, metadata = task(item_rng) + return { + "question": question, + "answer": str(answer), + "metadata": metadata, + } + + def _weekday_offset(self, rng: random.Random) -> Tuple[str, str, dict]: + """ + Task: Given a starting date and a day offset (which may be positive or negative), + ask what day of the week it will be. + Examples: + - "If today is Wednesday, March 13, 2024, what day of the week will it be in 10 days? Answer with the weekday's name." + - "If today is Wednesday, March 13, 2024, what day of the week was it 10 days ago? Answer with the weekday's name." + """ + year = self.config.year + start_date = self._random_date_for_year(rng, year) + offset = rng.randint(1, self.config.offset_upper_bound) + sign = rng.choice([-1, 1]) + offset_days = sign * offset + target_date = start_date + timedelta(days=offset_days) + target_weekday = target_date.strftime("%A") + + date_str = f"{start_date.strftime('%A')}, {start_date.strftime('%B')} {start_date.day}, {start_date.year}" + if offset_days >= 0: + templates = [ + f"If today is {date_str}, what day of the week will it be in {offset_days} days? ", + f"Starting from {date_str}, which weekday falls after a {offset_days}-day jump? ", + f"Count forward {offset_days} days from {date_str} - what's the weekday? ", + ] + else: + templates = [ + f"If today is {date_str}, what day of the week was it {abs(offset_days)} days ago? ", + f"Starting from {date_str}, which weekday was it {abs(offset_days)} days before? ", + f"Count backward {abs(offset_days)} days from {date_str} - what's the weekday? ", + ] + + question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES) + metadata = { + "task": CalendarTask.WEEKDAY_OFFSET.value, + "start_date": start_date.isoformat(), + "offset_days": offset_days, + "target_date": target_date.isoformat(), + } + return question, target_weekday, metadata + + def _weekday_of_date(self, rng: random.Random) -> Tuple[str, str, dict]: + """ + task: Ask what day of the week a given date was. + example: + "What day of the week was January 15, 2024? + Answer with the weekday's name." + """ + year = self.config.year + target_date = self._random_date_for_year(rng, year) + answer_weekday = target_date.strftime("%A") + templates = [ + f"What day of the week was {target_date.strftime('%B')} {target_date.day}, {year}?", + f"On which weekday did {target_date.strftime('%B')} {target_date.day}, {year} fall?", + f"Name the day of the week for {target_date.strftime('%m/%d/%Y')}.", + ] + + question = f"{rng.choice(templates)} {rng.choice(self.DAY_QUESTION_TEMPLATES)}" + metadata = { + "task": CalendarTask.WEEKDAY_OF_DATE.value, + "target_date": target_date.isoformat(), + } + return question, answer_weekday, metadata + + def _weekday_of_date_from_first_day(self, rng: random.Random) -> Tuple[str, str, dict]: + """ + task: Given an hypothetical weekday for January 1, ask what weekday a later date in the year falls on. + example: + "If the first day of the year was a Monday, what day of the week will December 31 be? + Answer with the weekday's name." + """ + year = self.config.year + first_day = Weekday.random(rng) + first_day_index = first_day.index + # Ensure target date is not January 1. + year_start = date(year, 1, 1) + year_end = date(year, 12, 31) + max_delta = timedelta(days=self.config.offset_upper_bound) + max_date = min(year_start + max_delta, year_end) + while True: + target_date = self._random_date_between(rng, year_start, max_date) + if target_date != date(year, 1, 1): + break + delta_days = (target_date - date(year, 1, 1)).days + answer_index = (first_day_index + delta_days) % 7 + answer_weekday = Weekday(answer_index + 1) + + templates = [ + f"If the first day of the year was a {first_day}, what day of the week will " + f"{target_date.strftime('%B')} {target_date.day} be? ", + f"Given that January 1 fell on a {first_day}, which weekday occurs on " + f"{target_date.strftime('%B')} {target_date.day}? ", + f"In a year where {first_day} is January 1st, name the weekday of " + f"{target_date.strftime('%B')} {target_date.day}. ", + ] + + question = rng.choice(templates) + rng.choice(self.DAY_QUESTION_TEMPLATES) + metadata = { + "task": CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value, + "year": year, + "first_day": str(first_day), + "target_date": target_date.isoformat(), + "delta_days": delta_days, + } + return question, answer_weekday, metadata + + def _recurring_event_day(self, rng: random.Random) -> Tuple[str, str, dict]: + """ + task: For a recurring event defined by an ordinal weekday pattern in a month, + ask on which day of the month the event occurs. + example: + "If a meeting is scheduled on the second Tuesday of May 2024, on which day does it fall? + Answer with a number." + """ + year = self.config.year + month = rng.randint(1, 12) + ordinals = ["first", "second", "third", "fourth", "last"] + ordinal = rng.choice(ordinals) + weekday = Weekday.random(rng) + month_name = calendar.month_name[month] + _, last_day = calendar.monthrange(year, month) + + if ordinal != "last": + ordinal_number = {"first": 1, "second": 2, "third": 3, "fourth": 4}[ordinal] + count = 0 + event_day = None + for day in range(1, last_day + 1): + d = date(year, month, day) + if d.strftime("%A") == str(weekday): + count += 1 + if count == ordinal_number: + event_day = day + break + if event_day is None: + # This should rarely happen but in some months the ordinal may not exist. + event_day = -1 + else: + event_day = None + for day in range(last_day, 0, -1): + d = date(year, month, day) + if d.strftime("%A") == str(weekday): + event_day = day + break + if event_day is None: + event_day = -1 + + templates = [ + f"If a meeting is scheduled on the {ordinal} {weekday} of {month_name} {year}, on which day of the month does it occur? ", + f"In {month_name} {year}, if an event recurs on the {ordinal} {weekday}, what is the date (day of the month) of the event? ", + f"Determine the day of the month for the {ordinal} {weekday} in {month_name} {year}. ", + ] + question = ( + rng.choice(templates) + + rng.choice(self.COUNT_QUESTION_TEMPLATES) + + " Answer with -1 if the ordinal does not exist in the month." + ) + metadata = { + "task": CalendarTask.RECURRING_EVENT_CALCULATIONS.value, + "year": year, + "month": month, + "ordinal": ordinal, + "weekday": str(weekday), + } + return question, str(event_day), metadata + + def _count_days(self, rng: random.Random) -> Tuple[str, str, dict]: + """ + task: Ask how many times a given weekday occurs in a specified range. + example: + "How many days are there between March 1, 2024 and March 15, 2024? + Answer with a number." + """ + year = self.config.year + year_start = date(year, 1, 1) + year_end = date(year, 12, 31) + start_date = self._random_date_between(rng, year_start, year_end) + max_delta = timedelta(days=self.config.offset_upper_bound) + end_date = self._random_date_between(rng, start_date, min(year_end, start_date + max_delta)) + weekday = Weekday.random(rng) + + def count_weekday_between(d1: date, d2: date, weekday: str) -> int: + days = (d2 - d1).days + 1 + return sum(1 for i in range(days) if (d1 + timedelta(days=i)).strftime("%A") == weekday) + + count = count_weekday_between(start_date, end_date, str(weekday)) + + templates = [ + f"How many {weekday}s are there from {start_date.strftime('%A, %B')} {start_date.day}, {year} to " + f"{end_date.strftime('%A, %B')} {end_date.day}, {year} (inclusive of both dates)? ", + f"Count the occurrences of {weekday} from {start_date.strftime('%A, %B')} {start_date.day} " + f"to {end_date.strftime('%A, %B')} {end_date.day}, {year} (including both start and end dates). ", + f"Between {start_date.strftime('%A, %B')} {start_date.day}, {year} and " + f"{end_date.strftime('%A, %B')} {end_date.day}, {year} " + f"(counting both dates), how many times does {weekday} occur? ", + ] + + question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES) + metadata = { + "task": CalendarTask.COUNT_DAYS.value, + "year": year, + "start_date": start_date.isoformat(), + "end_date": end_date.isoformat(), + } + return question, str(count), metadata + + def _count_business_days(self, rng: random.Random) -> Tuple[str, str, dict]: + """ + task: Count the number of business days (Monday-Friday) between two dates. + example: + "How many business days (Monday-Friday) are there between March 1, 2024 and March 15, 2024? + Answer with a number." + """ + year = self.config.year + year_start = date(year, 1, 1) + year_end = date(year, 12, 31) + start_date = self._random_date_between(rng, year_start, year_end) + max_delta = timedelta(days=self.config.offset_upper_bound) + end_date = self._random_date_between(rng, start_date, start_date + max_delta) + + count = 0 + + def business_days_between(d1: date, d2: date) -> int: + days = (d2 - d1).days + 1 + weeks, remainder = divmod(days, 7) + count = weeks * 5 + start_weekday = d1.weekday() + for i in range(remainder): + if (start_weekday + i) % 7 < 5: + count += 1 + return count + + count = business_days_between(start_date, end_date) + + templates = [ + f"How many business days (Monday-Friday) are there from " + f"{start_date.strftime('%A, %B')} {start_date.day}, {year} to " + f"{end_date.strftime('%A, %B')} {end_date.day}, {year} " + f"(inclusive of both dates)? ", + f"Count the weekdays (excluding weekends) from " + f"{start_date.strftime('%A, %B')} {start_date.day} to " + f"{end_date.strftime('%A, %B')} {end_date.day}, {year} " + f"(including both start and end dates). ", + f"Between {start_date.strftime('%A, %B')} {start_date.day}, {year} and " + f"{end_date.strftime('%A, %B')} {end_date.day}, {year} " + f"(counting both dates), what's the total count of business days " + f"(Monday through Friday)? ", + ] + + question = rng.choice(templates) + rng.choice(self.COUNT_QUESTION_TEMPLATES) + metadata = { + "task": CalendarTask.COUNT_BUSINESS_DAYS.value, + "start_date": start_date.isoformat(), + "end_date": end_date.isoformat(), + } + return question, str(count), metadata + + def _is_leap_year(self, rng: random.Random) -> Tuple[str, str, dict]: + """ + task: Given a year, determine whether it is a leap year. + example: + "Is 2024 a leap year? Answer with Yes or No." + """ + semirange = self.config.leap_year_range // 2 + year = rng.randint(self.config.year - semirange, self.config.year + semirange) + is_leap = calendar.isleap(year) + answer = "Yes" if is_leap else "No" + templates = [ + f"Determine if the year {year} is a leap year. ", + f"Is {year} a leap year? ", + f"Tell me whether {year} is a leap year. ", + ] + question = rng.choice(templates) + "Answer with Yes or No." + metadata = { + "task": CalendarTask.IS_LEAP_YEAR.value, + "year": year, + "is_leap": is_leap, + } + return question, answer, metadata + + def _random_date_for_year(self, rng: random.Random, year: int) -> date: + """Return a random date within the given year.""" + month = rng.randint(1, 12) + _, last_day = calendar.monthrange(year, month) + day = rng.randint(1, last_day) + return date(year, month, day) + + def _random_date_between(self, rng: random.Random, start_date: date, end_date: date) -> date: + """ + Return a random date between start_date and end_date (inclusive). + Assumes start_date <= end_date. + """ + if start_date > end_date: + raise ValueError("start_date must be <= end_date") + delta = (end_date - start_date).days + random_days = rng.randint(0, delta) + return start_date + timedelta(days=random_days) + + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + # we suppose the answer is the last occurence of the expected answer type + if answer is None: + return 0.0 + + oracle_answer = entry["answer"] + task = entry["metadata"]["task"] + + if task in { + CalendarTask.WEEKDAY_OFFSET.value, + CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value, + CalendarTask.WEEKDAY_OF_DATE.value, + }: + if not answer: + return 0.0 + + answer = answer.strip() + oracle_answer = oracle_answer + weekdays = {d.name.title() for d in Weekday} + + if answer == oracle_answer: + return 1.0 + + if answer in weekdays: + return 0.1 + + if answer.title() in weekdays: + return 0.05 + + if answer.title() not in weekdays: + return 0.0 + + return 0.0 + + # denser reward for numerical tasks + elif task in { + CalendarTask.COUNT_BUSINESS_DAYS.value, + CalendarTask.COUNT_DAYS.value, + CalendarTask.RECURRING_EVENT_CALCULATIONS.value, + }: + try: + ans_num = int(answer.strip()) + oracle_num = int(oracle_answer.strip()) + + if oracle_num == 0: + return 1.0 if ans_num == 0 else 0.0 + + relative_error = abs(ans_num - oracle_num) / oracle_num + return max(0.0, math.exp(-5 * relative_error)) + + except (ValueError, AttributeError): + return 0.0 + + elif task == CalendarTask.IS_LEAP_YEAR.value: + if answer.strip().lower() == oracle_answer.lower(): + return 1.0 + return 0.0 + + return 0.0 + + +register_dataset("calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig) diff --git a/reasoning_gym/arithmetic/time_intervals.py b/reasoning_gym/arithmetic/time_intervals.py new file mode 100644 index 00000000..1b296d02 --- /dev/null +++ b/reasoning_gym/arithmetic/time_intervals.py @@ -0,0 +1,323 @@ +import random +from dataclasses import dataclass, field +from datetime import date, datetime, time, timedelta +from typing import List, Optional + +import pytz +from dateutil import parser + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class TimeIntervalsConfig: + """Configuration for time interval calculation tasks""" + + min_time: time = time.min + max_time: time = time.max + max_time_difference_seconds: int = 24 * 60 * 60 + min_date: date = date(1900, 1, 1) + max_date: date = date(3000, 1, 1) + max_date_difference_days: int = 100 + task_types: List[str] = field( + default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"] + ) + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.size > 0, "size must be positive" + assert self.max_time_difference_seconds > 0, "max_time_difference_seconds must be positive" + assert self.max_date_difference_days > 0, "max_date_difference_days must be positive" + assert self.min_date < self.max_date, "min_date must be before max_date" + + +class TimeIntervalsDataset(ProceduralDataset): + """Generates time interval calculation tasks with various formats and complexities""" + + TEMPLATES = [ + "What is the duration between {start} and {end}? Please answer in {format}.", + "Calculate the time difference between {start} and {end}. Express the result in {format}.", + "How much time elapsed from {start} to {end}? Give your answer in {format}.", + "A meeting started at {start} and ended at {end}. How long was the meeting? Answer in {format}.", + "A system operation started at {start} and completed at {end}. What was the operation duration? Answer in {format}.", + "A database query started at {start} and ended at {end}. How long did the query take? Answer in {format}.", + "A flight departed at {start} and arrived at {end}. How long was the flight? Answer in {format}.", + "A video call started at {start} and ended at {end}. How long was the call? Answer in {format}.", + "A system backup started at {start} and completed at {end}. What was the total backup duration? Answer in {format}.", + "A conference call began at {start} and ended at {end}. How long was the conference? Answer in {format}.", + ] + + TIME_FORMATS = [ + "%H:%M", + "%H:%M:%S", + "%H:%M:%S.%f", + ] + + DATE_FORMATS = [ + "%Y-%m-%d", + "%B %d, %Y", + "%m/%d/%Y", + "%A, %B %d, %Y", # e.g. Monday, January 15, 2024 + "%a %b %d %Y", # e.g. Mon Jan 15 2024 + "%d %B %Y", # e.g. 15 January 2024 + "%Y-%m-%d (%A)", # e.g. 2024-01-15 (Monday) + ] + + DATETIME_FORMATS = [ + "%Y-%m-%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M %z", # For UTC offset format + "%Y-%m-%d %H:%M:%S %z", # For UTC offset with seconds + "%A, %B %d, %Y at %H:%M", # e.g. Monday, January 15, 2024 at 14:30 + "%a %b %d %Y %H:%M:%S", # e.g. Mon Jan 15 2024 14:30:45 + "%d %B %Y, %H:%M", # e.g. 15 January 2024, 14:30 + "%d %B %Y, %H:%M %z", # e.g. 15 January 2024, 14:30 +0000 + "%Y-%m-%d (%A) %H:%M:%S %z", # e.g. 2024-01-15 (Monday) 14:30:45 +0000 + ] + + def __init__(self, config: TimeIntervalsConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single time interval calculation task""" + item_rng = random.Random(self.seed + idx) + + # Randomly choose task type from config + task_type = item_rng.choice(self.config.task_types) + + start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type) + + template = item_rng.choice(self.TEMPLATES) + question = template.format(start=start_time, end=end_time, format=expected_format) + + # Calculate the actual difference + if isinstance(start_time, str): + # Handle datetime strings with weekday names in parentheses + start_time = start_time.split(" (")[0] # Remove (Weekday) if present + end_time = end_time.split(" (")[0] + # Parse with UTC offset handling + start_dt = parser.parse(start_time) + end_dt = parser.parse(end_time) + else: + start_dt = start_time + end_dt = end_time + + difference = end_dt - start_dt + + # Format the answer according to expected_format + if expected_format == "HH:MM": + total_seconds = difference.total_seconds() + answer = f"{int(total_seconds // 3600):02d}:{int((total_seconds % 3600) // 60):02d}" + elif expected_format == "HH:MM:SS": + total_seconds = difference.total_seconds() + answer = f"{int(total_seconds // 3600):02d}:{int((total_seconds % 3600) // 60):02d}:{int(total_seconds % 60):02d}" + elif expected_format == "HH:MM:SS.mmm": + total_seconds = difference.total_seconds() + ms = int((total_seconds % 1) * 1000) + answer = f"{int(total_seconds // 3600):02d}:{int((total_seconds % 3600) // 60):02d}:{int(total_seconds % 60):02d}.{ms:03d}" + elif expected_format == "D days": + answer = f"{difference.days} days" + else: # "D days, HH:MM" or "D days, HH:MM:SS" + days = difference.days + hours = difference.seconds // 3600 + minutes = (difference.seconds % 3600) // 60 + seconds = difference.seconds % 60 + if expected_format == "D days, HH:MM:SS": + answer = f"{days} days, {hours:02d}:{minutes:02d}:{seconds:02d}" + else: # "D days, HH:MM" + answer = f"{days} days, {hours:02d}:{minutes:02d}" + + return { + "question": question, + "answer": answer, + "metadata": { + "task_type": task_type, + "start_time": start_dt, + "end_time": end_dt, + "format": format_str, + "expected_format": expected_format, + }, + } + + def _generate_times(self, rng: random.Random, task_type: str): + """Generate start and end times based on task type""" + if task_type.startswith("time"): + if task_type == "time_ms": + format_str = self.TIME_FORMATS[2] # Get milliseconds format + expected_format = "HH:MM:SS.mmm" + else: + format_str = next(f for f in self.TIME_FORMATS if f.count(":") == (2 if "seconds" in task_type else 1)) + expected_format = "HH:MM:SS" if "seconds" in task_type else "HH:MM" + + # Generate random start time + start_hour = rng.randint(0, 23) + start_minute = rng.randint(0, 59) + start_second = rng.randint(0, 59) + base = datetime.combine(date.today(), time(start_hour, start_minute, start_second)) + + # Calculate seconds remaining until midnight + seconds_until_midnight = ((24 - start_hour) * 3600) - (start_minute * 60) - start_second + # Use the minimum of config max and seconds until midnight + max_seconds = min(self.config.max_time_difference_seconds, seconds_until_midnight) + diff_seconds = rng.randint(1, max_seconds) if max_seconds > 0 else 0 + + if task_type == "time_ms": + # Add microseconds for millisecond precision + base = base.replace(microsecond=rng.randint(0, 999) * 1000) + end_time = base + timedelta(seconds=diff_seconds, microseconds=rng.randint(0, 999) * 1000) + # Format with exactly 3 decimal places for milliseconds + start_time = base.strftime(format_str)[:-3] # Remove extra microsecond digits + end_time = end_time.strftime(format_str)[:-3] # Remove extra microsecond digits + else: + start_time = base.strftime(format_str) + end_time = (base + timedelta(seconds=diff_seconds)).strftime(format_str) + + elif task_type == "date": + format_str = rng.choice(self.DATE_FORMATS) + expected_format = "D days" # Always return number of days for date tasks + + # Generate random start date within configured range, leaving room for end date + max_date_difference_days = min( + self.config.max_date_difference_days, (self.config.max_date - self.config.min_date).days + ) + max_start_days = (self.config.max_date - self.config.min_date).days - max_date_difference_days + start_days = rng.randint(0, max_start_days - 1) + start_date = self.config.min_date + timedelta(days=start_days) + + # Ensure positive difference between dates + diff_days = rng.randint(0, max_date_difference_days) + end_date = start_date + timedelta(days=diff_days) + + start_time = start_date.strftime(format_str) + end_time = end_date.strftime(format_str) + + else: # datetime or datetime_tz + format_str = rng.choice(self.DATETIME_FORMATS) + # Choose between HH:MM and HH:MM:SS format for datetime answers + expected_format = rng.choice(["D days, HH:MM", "D days, HH:MM:SS"]) + + # Generate random start datetime + days_range = (self.config.max_date - self.config.min_date).days + start_days = rng.randint(0, days_range) + start_hour = rng.randint(0, 23) + start_minute = rng.randint(0, 59) + start_second = rng.randint(0, 59) + + # Generate random time differences first + diff_days = rng.randint(0, self.config.max_date_difference_days) + diff_seconds = rng.randint(1, self.config.max_time_difference_seconds) + + if "%z" in format_str: + # Use simpler timezone format with offset + base = datetime.combine( + self.config.min_date + timedelta(days=start_days), time(start_hour, start_minute, start_second) + ) + # Generate timezone offsets + start_offset = rng.randint(-12, 12) + end_offset = rng.randint(-12, 12) + + # Apply start timezone + base = base.replace(tzinfo=pytz.FixedOffset(start_offset * 60)) + start_format = format_str.replace("%z", "%+05d" % (start_offset * 100)) + + # Calculate end time and convert to end timezone + end_dt = base + timedelta(days=diff_days, seconds=diff_seconds) + end_dt = end_dt.replace(tzinfo=pytz.FixedOffset(end_offset * 60)) + end_format = format_str.replace("%z", "%+05d" % (end_offset * 100)) + + # Format times with their respective timezone offsets + start_time = base.strftime(start_format).rstrip() + end_time = end_dt.strftime(end_format).rstrip() + else: + base = datetime.combine( + self.config.min_date + timedelta(days=start_days), time(start_hour, start_minute, start_second) + ) + # For non-timezone aware times, both use same format + start_time = base.strftime(format_str).rstrip() + end_time = (base + timedelta(days=diff_days, seconds=diff_seconds)).strftime(format_str).rstrip() + + return start_time, end_time, format_str, expected_format + + def score_answer(self, answer: Optional[str], entry: dict) -> float: + """Score an answer based on how close it is to the expected duration + + Returns a score between 0 and 1, with partial credit for answers that are + close to correct in the appropriate units/format + """ + if not answer: + return 0.0 + + expected = entry["answer"] + task_type = entry["metadata"]["task_type"] + + try: + if task_type == "date": + # Parse "X days" format + try: + actual = int(answer.strip().split()[0]) # Get number before "days" + expected = int(expected.strip().split()[0]) + if actual == expected: + return 1.0 + # Partial credit based on how close the day count is + max_diff = self.config.max_date_difference_days + diff = abs(actual - expected) + return max(0.0, 1.0 - (diff / max_diff)) + except (ValueError, IndexError): + return 0.0 + + elif task_type.startswith("time"): + # Parse times into total seconds for comparison + def parse_time(t): + parts = t.strip().split(":") + seconds = int(parts[0]) * 3600 + int(parts[1]) * 60 + if len(parts) > 2: + if "." in parts[2]: # Has milliseconds + s, ms = parts[2].split(".") + seconds += int(s) + int(ms) / 1000 + else: + seconds += int(parts[2]) + return seconds + + actual_seconds = parse_time(answer) + expected_seconds = parse_time(expected) + + if actual_seconds == expected_seconds: + return 1.0 + + # Partial credit based on how close the times are + max_diff = self.config.max_time_difference_seconds + diff = abs(actual_seconds - expected_seconds) + return max(0.0, 1.0 - (diff / max_diff)) + + else: # datetime or datetime_tz + # Parse the complex format "X days, HH:MM" or "X days, HH:MM:SS" + def parse_datetime(t): + days = int(t.split(" days,")[0]) + time_part = t.split(",")[1].strip() + parts = time_part.split(":") + seconds = int(parts[0]) * 3600 + int(parts[1]) * 60 + if len(parts) > 2: + seconds += int(parts[2]) + return days * 86400 + seconds + + actual_seconds = parse_datetime(answer) + expected_seconds = parse_datetime(expected) + + if actual_seconds == expected_seconds: + return 1.0 + + # Partial credit based on total time difference + max_diff = self.config.max_date_difference_days * 86400 + diff = abs(actual_seconds - expected_seconds) + return max(0.0, 1.0 - (diff / max_diff)) + + except (ValueError, IndexError): + return 0.0 # Invalid format + + return 0.0 + + +# Register the dataset +register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig) diff --git a/reasoning_gym/cognition/__init__.py b/reasoning_gym/cognition/__init__.py index fddd97b1..38baf31b 100644 --- a/reasoning_gym/cognition/__init__.py +++ b/reasoning_gym/cognition/__init__.py @@ -6,18 +6,21 @@ Cognition tasks for training reasoning capabilities: - Working memory """ +from .arc_1d import Arc1DConfig, Arc1DDataset from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset from .figlet_fonts import FigletFontConfig, FigletFontDataset from .number_sequences import NumberSequenceConfig, NumberSequenceDataset from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset __all__ = [ - "NumberSequenceConfig", - "NumberSequenceDataset", + "Arc1DConfig", + "Arc1DDataset", "ColorCubeRotationConfig", "ColorCubeRotationDataset", - "RubiksCubeConfig", - "RubiksCubeDataset", "FigletFontConfig", "FigletFontDataset", + "NumberSequenceConfig", + "NumberSequenceDataset", + "RubiksCubeConfig", + "RubiksCubeDataset", ] diff --git a/reasoning_gym/cognition/arc_1d.py b/reasoning_gym/cognition/arc_1d.py new file mode 100644 index 00000000..7e399f20 --- /dev/null +++ b/reasoning_gym/cognition/arc_1d.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..dataset import ProceduralDataset +from ..factory import register_dataset + + +@dataclass +class Arc1DConfig: + """Configuration for ARC 1D task generation""" + + min_size: int = 10 # Minimum grid size + max_size: int = 30 # Maximum grid size + num_train: int = 3 # Number of training examples + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.min_size > 0, "min_size must be positive" + assert self.max_size >= self.min_size, "max_size must be >= min_size" + assert self.num_train > 0, "num_train must be positive" + assert self.size > 0, "size must be positive" + + +class Arc1DDataset(ProceduralDataset): + """ + Generates ARC 1D tasks by randomly selecting from available task generators + + This dataset is a procedural variant of the 1D-ARC dataset which is described in the paper: + `LLMs and the Abstraction and Reasoning Corpus: Successes, Failures, and the Importance + of Object-based Representations` (https://arxiv.org/abs/2305.18354) + + Ilya Sheprut (optozorax) created rust generators for most of the ARC 1d tasks. For + reasoning-gym rust tasks were machine-converted to python via Sonnet. + + Ilya's original rust code can be found here: https://github.com/optozorax/arc_1d/ + """ + + def __init__(self, config: Arc1DConfig): + from .arc_1d_tasks import ARC_1D_TASKS + + super().__init__(config=config, seed=config.seed, size=config.size) + self.ARC_1D_TASKS = ARC_1D_TASKS + self.task_names = list(ARC_1D_TASKS.keys()) + + def __getitem__(self, idx: int) -> dict: + """Generate a single ARC 1D task with training examples + + Args: + idx: Index of the item to generate + + Returns: + dict with keys: + - question: str, the task description and examples + - answer: str, the expected output format + - metadata: dict with generation parameters + """ + # Create deterministic RNG from base seed and idx + item_rng = Random(self.seed + idx) + + # Select random task + task_name = item_rng.choice(self.task_names) + task_func, task_kwargs = self.ARC_1D_TASKS[task_name] + + # Generate training examples + train_examples = [] + size = item_rng.randint(self.config.min_size, self.config.max_size) + + for _ in range(self.config.num_train): + example = None + while example is None: + example = task_func(item_rng, size, **task_kwargs) + + train_examples.append(example) + + # Generate test example + test_example = None + while test_example is None: + test_example = task_func(item_rng, size, **task_kwargs) + + # Format question + question = "Find the common rule that maps an input grid to an output grid, given the examples below.\n\n" + + # Add training examples + for i, example in enumerate(train_examples, 1): + question += f"Example {i}:\n" + question += "Input: " + " ".join(str(x) for x in example["input"]) + "\n" + question += "Output: " + " ".join(str(x) for x in example["output"]) + "\n\n" + + # Add test input + question += "Below is a test input grid. Predict the corresponding output grid by applying the rule you found. " + question += "Describe how you derived the rule and your overall reasoning process in detail before you submit your answer. " + question += "Your final answer must be placed in tags and should be just be the text output grid itself.\n\n" + question += "Input:\n" + question += " ".join(str(x) for x in test_example["input"]) + + return { + "question": question, + "answer": " ".join(str(x) for x in test_example["output"]), + "metadata": { + "task_name": task_name, + "size": size, + "train_examples": train_examples, + "test_example": test_example, + }, + } + + +# Register the dataset +register_dataset("arc_1d", Arc1DDataset, Arc1DConfig) diff --git a/reasoning_gym/cognition/arc_1d_tasks.py b/reasoning_gym/cognition/arc_1d_tasks.py new file mode 100644 index 00000000..61151b34 --- /dev/null +++ b/reasoning_gym/cognition/arc_1d_tasks.py @@ -0,0 +1,1227 @@ +from random import Random +from typing import Dict, List, Optional + + +def gen_field(size: int, color: int = 0) -> List[int]: + """Generate a field of given size filled with specified color (default 0).""" + return [color] * size + + +def write_block(pos: int, block: List[int], field: List[int]) -> List[int]: + """Write a block into a field at given position.""" + result = field.copy() + for i, color in enumerate(block): + result[pos + i] = color + return result + + +def task_move_n_pix(rng: Random, size: int, move_pix: int, solid: bool) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block is moved to the right by move_pix pixels.""" + if size <= move_pix + 1: + return None + + block_size = rng.randint(1, size - move_pix - 1) + block_pos = rng.randint(0, size - block_size - move_pix) + + if solid: + color = rng.randint(1, 9) + block = [color] * block_size + else: + block = [rng.randint(1, 9) for _ in range(block_size)] + + question = write_block(block_pos, block, gen_field(size)) + answer = write_block(block_pos + move_pix, block, gen_field(size)) + + return {"input": question, "output": answer} + + +def task_move_n_pix_wrapped(rng: Random, size: int, move_pix: int, solid: bool) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block is moved to the right by move_pix pixels with wrapping.""" + block_size = rng.randint(1, size) + block_pos = rng.randint(0, size) + + if solid: + color = rng.randint(1, 9) + block = [color] * block_size + else: + block = [rng.randint(1, 9) for _ in range(block_size)] + + question = gen_field(size) + answer = gen_field(size) + + for i, color in enumerate(block): + question[(block_pos + i) % size] = color + answer[(block_pos + move_pix + i) % size] = color + + return {"input": question, "output": answer} + + +def task_gravity(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where all non-zero elements are attracted to the left.""" + density = 0.5 + question = [rng.randint(1, 9) if rng.random() < density else 0 for _ in range(size)] + + non_zero = [x for x in question if x != 0] + answer = non_zero + [0] * (size - len(non_zero)) + + return {"input": question, "output": answer} + + +def task_gravity_counting(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where non-zero elements are counted and represented as a sequence of 1s.""" + density = 0.5 + question = [rng.randint(1, 9) if rng.random() < density else 0 for _ in range(size)] + + count = sum(1 for x in question if x != 0) + answer = [1] * count + [0] * (size - count) + + return {"input": question, "output": answer} + + +def task_gravity_antigravity(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where color 1 moves right and color 2 moves left.""" + density = 0.5 + question = [rng.randint(1, 2) if rng.random() < density else 0 for _ in range(size)] + + color1 = [x for x in question if x == 1] + color2 = [x for x in question if x == 2] + answer = [2] * len(color2) + [0] * (size - len(color1) - len(color2)) + [1] * len(color1) + + return {"input": question, "output": answer} + + +def task_block_touch_dot(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block moves to touch (but not cover) a dot.""" + dot_color = 1 + block_color = rng.randint(2, 9) + + block_size = rng.randint(1, size) + dot_pos = rng.randint(0, size) + + can_place_left = dot_pos >= block_size + can_place_right = dot_pos + block_size < size + + if not (can_place_left or can_place_right): + return None + + if can_place_left and can_place_right: + side = rng.choice(["left", "right"]) + elif can_place_left: + side = "left" + else: + side = "right" + + if side == "left": + q_block_pos = rng.randint(0, dot_pos - block_size) + a_block_pos = dot_pos - block_size + else: + q_block_pos = rng.randint(dot_pos + 1, size - block_size) + a_block_pos = dot_pos + 1 + + question = gen_field(size) + question[dot_pos] = dot_color + question = write_block(q_block_pos, [block_color] * block_size, question) + + answer = gen_field(size) + answer[dot_pos] = dot_color + answer = write_block(a_block_pos, [block_color] * block_size, answer) + + return {"input": question, "output": answer} + + +def task_block_touch_dot_n_pix(rng: Random, size: int, move_pix: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block moves move_pix pixels toward a dot.""" + dot_color = 2 + block_color = rng.randint(3, 9) + + block_size = rng.randint(1, size) + dot_pos = rng.randint(0, size) + + can_place_left = dot_pos >= block_size + can_place_right = dot_pos + block_size < size + + if not (can_place_left or can_place_right): + return None + + if can_place_left and can_place_right: + side = rng.choice(["left", "right"]) + elif can_place_left: + side = "left" + else: + side = "right" + + if side == "left": + q_block_pos = rng.randint(0, dot_pos - block_size) + distance = (dot_pos - block_size) - q_block_pos + move = min(distance, move_pix) + a_block_pos = q_block_pos + move + else: + q_block_pos = rng.randint(dot_pos + 1, size - block_size) + distance = q_block_pos - (dot_pos + 1) + move = min(distance, move_pix) + a_block_pos = q_block_pos - move + + question = gen_field(size) + question[dot_pos] = dot_color + question = write_block(q_block_pos, [block_color] * block_size, question) + + answer = gen_field(size) + answer[dot_pos] = dot_color + answer = write_block(a_block_pos, [block_color] * block_size, answer) + + return {"input": question, "output": answer} + + +def task_block_scale_to_dot(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block scales to touch a dot (keeping one end fixed).""" + dot_color = 2 + block_color = rng.randint(3, 9) + + block_size = rng.randint(1, size) + dot_pos = rng.randint(0, size) + + can_place_left = dot_pos >= block_size + can_place_right = dot_pos + block_size < size + + if not (can_place_left or can_place_right): + return None + + if can_place_left and can_place_right: + side = rng.choice(["left", "right"]) + elif can_place_left: + side = "left" + else: + side = "right" + + if side == "left": + q_block_pos = rng.randint(0, dot_pos - block_size) + new_size = dot_pos - q_block_pos + 1 + a_block_pos = q_block_pos + else: + q_block_pos = rng.randint(dot_pos + 1, size - block_size) + new_size = (q_block_pos + block_size) - dot_pos + a_block_pos = dot_pos + + question = gen_field(size) + question[dot_pos] = dot_color + question = write_block(q_block_pos, [block_color] * block_size, question) + + answer = gen_field(size) + answer[dot_pos] = dot_color + answer = write_block(a_block_pos, [block_color] * new_size, answer) + + return {"input": question, "output": answer} + + +def task_two_points_and_fill(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where space between two points of same color is filled with that color.""" + color = rng.randint(1, 9) + + pos1 = rng.randint(0, size - 1) + pos2 = rng.randint(0, size - 1) + if pos1 == pos2: + return None + + pos1, pos2 = min(pos1, pos2), max(pos1, pos2) + + question = gen_field(size) + question[pos1] = color + question[pos2] = color + + answer = question.copy() + for i in range(pos1, pos2 + 1): + answer[i] = color + + return {"input": question, "output": answer} + + +def task_reflect_block_with_border_pixel(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block with a border pixel is reflected.""" + block_size = rng.randint(2, size) + if block_size > size: + return None + + c1 = rng.randint(1, 9) + c2 = rng.randint(1, 9) + if c1 == c2: + return None + + side = "left" if rng.random() < 0.5 else "right" + pos = rng.randint(0, size - block_size) + + block = [c1] * block_size + if side == "left": + block[0] = c2 + else: + block[block_size - 1] = c2 + + question = write_block(pos, block, gen_field(size)) + reversed_block = block[::-1] # Reverse the block + answer = write_block(pos, reversed_block, gen_field(size)) + + return {"input": question, "output": answer} + + +def task_reflect_block_with_border_pixel_random(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a random-colored block with a border pixel is reflected.""" + block_size = rng.randint(2, size) + if block_size > size: + return None + + side = "left" if rng.random() < 0.5 else "right" + pos = rng.randint(0, size - block_size) + + block = [rng.randint(1, 9) for _ in range(block_size)] + border_color = rng.randint(1, 9) + + if side == "left": + if block[0] == border_color: + return None + block[0] = border_color + else: + if block[block_size - 1] == border_color: + return None + block[block_size - 1] = border_color + + question = write_block(pos, block, gen_field(size)) + reversed_block = block[::-1] # Reverse the block + answer = write_block(pos, reversed_block, gen_field(size)) + + return {"input": question, "output": answer} + + +def task_reflect_block_around_dot(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block is reflected around a dot.""" + dot_color = 2 + + dot_pos = rng.randint(0, size) + block_size = rng.randint(1, size) + block_pos = rng.randint(0, size - block_size) + block_end = block_pos + block_size - 1 + + # Check if block is strictly to left or right of dot + strictly_left = block_end < dot_pos + strictly_right = block_pos > dot_pos + + if not (strictly_left or strictly_right): + return None + + block_color = rng.randint(3, 9) # Different from dot color + block = [block_color] * block_size + + # Calculate reflection bounds + min_reflect = 2 * dot_pos - block_end + max_reflect = 2 * dot_pos - block_pos + if min_reflect < 0 or max_reflect >= size: + return None + + question = gen_field(size) + question = write_block(block_pos, block, question) + question[dot_pos] = dot_color + + answer = gen_field(size) + answer[dot_pos] = dot_color + for i in range(block_size): + reflect_idx = 2 * dot_pos - (block_pos + i) + answer[reflect_idx] = block[i] + + return {"input": question, "output": answer} + + +def task_block_and_noise_remove(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where noise around a block needs to be removed.""" + block_size = rng.randint(2, size) + if block_size > size: + return None + + block_pos = rng.randint(0, size - block_size) + color = rng.randint(1, 9) + + # Create field with block + field = gen_field(size) + for i in range(block_size): + field[block_pos + i] = color + + # Track forbidden positions for noise + forbidden = [False] * size + for i in range(block_pos, block_pos + block_size): + forbidden[i] = True + if block_pos > 0: + forbidden[block_pos - 1] = True + if block_pos + block_size < size: + forbidden[block_pos + block_size] = True + + # Add noise + noise_count = rng.randint(1, 3) + noise_positions = [] + + for _ in range(noise_count): + allowed = [i for i in range(size) if not forbidden[i]] + if not allowed: + break + noise_pos = rng.choice(allowed) + noise_positions.append(noise_pos) + field[noise_pos] = color + forbidden[noise_pos] = True + if noise_pos > 0: + forbidden[noise_pos - 1] = True + if noise_pos + 1 < size: + forbidden[noise_pos + 1] = True + + if len(noise_positions) < noise_count: + return None + + question = field + answer = field.copy() + for pos in noise_positions: + answer[pos] = 0 + + return {"input": question, "output": answer} + + +def task_block_and_noise_remove_inside(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where noise inside a block needs to be removed.""" + if size <= 6: + return None + + block_size = rng.randint(6, size) + if block_size > size: + return None + + block_pos = rng.randint(0, size - block_size) + color = rng.randint(1, 9) + + # Create field with block + field = gen_field(size) + for i in range(block_size): + field[block_pos + i] = color + + # Add noise inside block + max_noise = max(1, (block_size // 2) - 1) + noise_count = rng.randint(1, max_noise) + + positions = list(range(block_size)) + rng.shuffle(positions) + noise_positions = positions[:noise_count] + + for offset in noise_positions: + pos = block_pos + offset + noise_color = rng.randint(1, 9) + while noise_color == color: + noise_color = rng.randint(1, 9) + field[pos] = noise_color + + question = field + answer = field.copy() + for offset in noise_positions: + answer[block_pos + offset] = color + + return {"input": question, "output": answer} + + +def task_copy_block_to_dots(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block pattern is copied to dot positions.""" + block_size = 3 if rng.random() < 0.5 else 5 + if block_size >= size: + return None + + color = rng.randint(1, 9) + block = [color] * block_size + + # Generate dots with minimum distance to prevent overlap + min_gap = block_size + dot_positions = [] + pos = block_size + block_size // 2 + 1 + + while pos <= size - block_size: + if rng.random() < 0.5: # Control dot density + dot_positions.append(pos) + pos += min_gap + pos += 1 + + if not dot_positions: + return None + + question = gen_field(size) + question = write_block(0, block, question) + for pos in dot_positions: + question[pos] = color + + answer = gen_field(size) + answer = write_block(0, block, answer) + for pos in dot_positions: + block_start = pos - block_size // 2 + answer = write_block(block_start, block, answer) + + return {"input": question, "output": answer} + + +def task_copy_block_to_dots_colors(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block pattern is copied to dot positions with matching colors.""" + block_size = 3 if rng.random() < 0.5 else 5 + if block_size >= size: + return None + + block_color = rng.randint(1, 9) + block = [block_color] * block_size + + # Generate dots with minimum distance to prevent overlap + min_gap = block_size + dot_positions = [] + dot_colors = [] + pos = block_size + block_size // 2 + 1 + + while pos < size - block_size: + if rng.random() < 0.5: + dot_color = rng.randint(1, 9) + dot_positions.append(pos) + dot_colors.append(dot_color) + pos += min_gap + pos += 1 + + if not dot_positions: + return None + + question = gen_field(size) + question = write_block(0, block, question) + for i, pos in enumerate(dot_positions): + question[pos] = dot_colors[i] + + answer = gen_field(size) + answer = write_block(0, block, answer) + for i, pos in enumerate(dot_positions): + block_start = pos - block_size // 2 + colored_block = [dot_colors[i]] * block_size + answer = write_block(block_start, colored_block, answer) + + return {"input": question, "output": answer} + + +def task_paint_biggest_block(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where the largest block is painted a different color.""" + target_color = 1 + initial_color = rng.randint(2, 9) + + # Generate random blocks + question = gen_field(size) + blocks = [] + pos = 0 + + while pos < size: + if rng.random() < 0.4 and size - pos >= 2: + block_size = rng.randint(2, min(size - pos, 6)) + blocks.append((pos, block_size)) + for i in range(block_size): + question[pos + i] = initial_color + pos += block_size + 1 + else: + pos += 1 + + if len(blocks) < 2: + return None + + # Find biggest block + biggest_pos, biggest_size = max(blocks, key=lambda x: x[1]) + + # Check if there are multiple blocks of the same size + biggest_count = sum(1 for _, size in blocks if size == biggest_size) + if biggest_count > 1: + return None + + answer = question.copy() + for i in range(biggest_size): + answer[biggest_pos + i] = target_color + + return {"input": question, "output": answer} + + +def task_sort_blocks_by_size(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where blocks are sorted by size with 1 pixel gaps.""" + color = rng.randint(1, 9) + blocks = [] + pos = 0 + + # Generate random blocks with random sizes + while pos < size: + if rng.random() < 0.4 and size - pos >= 2: + block_size = rng.randint(1, min(size - pos, 6)) + blocks.append((pos, block_size)) + pos += block_size + rng.randint(1, 4) # Random gaps + else: + pos += 1 + + if len(blocks) < 2: + return None + + # Create input field + question = gen_field(size) + for pos, block_size in blocks: + for i in range(block_size): + question[pos + i] = color + + # Sort blocks by size + blocks.sort(key=lambda x: x[1]) + + # Check if sorted blocks fit with gaps + total_space = sum(size for _, size in blocks) + len(blocks) - 1 + if total_space > size: + return None + + # Create answer field with sorted blocks + answer = gen_field(size) + current_pos = 0 + + for _, block_size in blocks: + for i in range(block_size): + answer[current_pos + i] = color + current_pos += block_size + 1 # One pixel gap + + return {"input": question, "output": answer} + + +def task_sort_complete_sequence(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a complete sequence of block sizes is sorted.""" + # Calculate max possible block size given total array size + max_size = 1 + total_space = 0 + while total_space + max_size + 1 <= size: + total_space += max_size + 1 + max_size += 1 + max_size -= 1 + + if max_size < 2: + return None + + color = rng.randint(1, 9) + + # Create sequence of all sizes from 1 to max_size + blocks = list(range(1, max_size + 1)) + rng.shuffle(blocks) + + # Create input field with shuffled blocks + question = gen_field(size) + pos = 0 + for block_size in blocks: + for i in range(block_size): + question[pos + i] = color + pos += block_size + 1 + + # Create answer field with sorted blocks + answer = gen_field(size) + pos = 0 + for block_size in range(1, max_size + 1): + for i in range(block_size): + answer[pos + i] = color + pos += block_size + 1 + + return {"input": question, "output": answer} + + +def task_recolor_blocks_by_size(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where two blocks are recolored based on their size.""" + # Generate two different random sizes + size1 = rng.randint(2, 8) + size2 = rng.randint(2, 8) + while size2 == size1: + size2 = rng.randint(2, 8) + + # Ensure both blocks fit with at least 1 gap + if size1 + size2 + 1 > size: + return None + + # Place blocks with gap + pos1 = rng.randint(0, size - (size1 + size2 + 1)) + pos2 = rng.randint(pos1 + size1 + 1, size - size2) + + # Create input field with both blocks color 3 + question = gen_field(size) + for i in range(size1): + question[pos1 + i] = 3 + for i in range(size2): + question[pos2 + i] = 3 + + # Create answer field with recolored blocks + answer = question.copy() + if size1 > size2: + for i in range(size1): + answer[pos1 + i] = 1 + for i in range(size2): + answer[pos2 + i] = 2 + else: + for i in range(size1): + answer[pos1 + i] = 2 + for i in range(size2): + answer[pos2 + i] = 1 + + return {"input": question, "output": answer} + + +def task_gravity_one_step(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where non-zero elements move one step left if possible.""" + question = [rng.randint(1, 9) if rng.random() < 0.5 else 0 for _ in range(size)] + answer = question.copy() + + # Move each non-zero pixel one step left if possible + for i in range(1, size): + if answer[i] != 0 and answer[i - 1] == 0: + answer[i - 1] = answer[i] + answer[i] = 0 + + return {"input": question, "output": answer} + + +def task_move_block_by_own_size(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block moves right by its own size.""" + block_size = rng.randint(1, size // 2) # Ensure space for movement + pos = rng.randint(0, size - block_size * 2) # Space for block and movement + color = rng.randint(1, 9) + + question = gen_field(size) + block = [color] * block_size + question = write_block(pos, block, question) + + answer = write_block(pos + block_size, block, gen_field(size)) + + return {"input": question, "output": answer} + + +def task_change_to_five(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where all non-zero colors change to 5.""" + density = 0.5 + question = [rng.randint(1, 9) if rng.random() < density else 0 for _ in range(size)] + answer = [5 if x != 0 else 0 for x in question] + + return {"input": question, "output": answer} + + +def task_recolor_blocks_from_palette(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where blocks are recolored using a color palette.""" + # Generate blocks of same size + block_size = rng.randint(2, 4) + blocks = [] + pos = 0 + + while pos + block_size <= size: + if rng.random() < 0.4: + blocks.append(pos) + pos += block_size + 1 + else: + pos += 1 + + # Ensure we have space for palette + while blocks and blocks[-1] + block_size + len(blocks) + 1 >= size: + blocks.pop() + + if not blocks: + return None + + # Shift blocks right to make room for palette + palette_size = len(blocks) + blocks = [pos + palette_size + 1 for pos in blocks] + + # Generate color palette + colors = [] + for _ in range(len(blocks)): + while True: + color = rng.randint(1, 9) + if color not in colors: + colors.append(color) + break + + # Create question with color palette and blocks + question = gen_field(size) + + # Place color palette at start + for i, color in enumerate(colors): + question[i] = color + + # Place blocks of color 5 + for block_pos in blocks: + for i in range(block_size): + question[block_pos + i] = 5 + + # Create answer with recolored blocks + answer = question.copy() + for block_idx, block_pos in enumerate(blocks): + color = colors[block_idx] + for i in range(block_size): + answer[block_pos + i] = color + + return {"input": question, "output": answer} + + +def task_duplicate_block_from_seeds(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a block is duplicated from seed pixels.""" + block_size = rng.randint(2, 4) + if block_size + 1 >= size: + return None + if size <= 3 + block_size: + return None + + # Position block with space for seeds + block_pos = rng.randint(2, size - block_size - 1) + + # Decide seed placement + left_seed = rng.random() < 0.5 + right_seed = rng.random() < 0.5 + if not (left_seed or right_seed): + return None + + # Create input + question = gen_field(size) + + # Place main block + for i in range(block_size): + question[block_pos + i] = 1 + + # Place seeds with gaps + seeds = [] + if left_seed: + color = rng.randint(1, 9) + question[block_pos - 2] = color + seeds.append(("left", block_pos - 2, color)) + if right_seed: + color = rng.randint(1, 9) + question[block_pos + block_size + 1] = color + seeds.append(("right", block_pos + block_size + 1, color)) + + # Create answer with duplicated blocks + answer = question.copy() + + for side, seed_pos, color in seeds: + if side == "left": + # For left seed, blocks end at seed + end_pos = seed_pos + while end_pos >= 0: + start_pos = end_pos - block_size + 1 + for pos in range(max(0, start_pos), end_pos + 1): + answer[pos] = color + if start_pos < 1: + break + end_pos = start_pos - 2 # -1 for gap + else: # side == "right" + # For right seed, blocks start at seed + start_pos = seed_pos + while start_pos < size: + for offset in range(min(block_size, size - start_pos)): + answer[start_pos + offset] = color + if start_pos + block_size + 1 >= size: + break + start_pos = start_pos + block_size + 1 # +1 for gap + + return {"input": question, "output": answer} + + +def task_fill_from_pixel(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a pixel fills in one direction until hitting another pixel.""" + block_size = rng.randint(3, 6) + if block_size >= size - 2: + return None + + # Position block with space for seed + block_pos = rng.randint(1, size - block_size - 1) + + # Create input + question = gen_field(size) + + # Place main block + block_color = rng.randint(1, 9) + for i in range(block_size): + question[block_pos + i] = block_color + + # Place seed pixel and determine fill direction + seed_color = rng.randint(1, 9) + while seed_color == block_color: + seed_color = rng.randint(1, 9) + + is_left = rng.random() < 0.5 + + if is_left: + question[block_pos - 1] = seed_color + else: + question[block_pos + block_size] = seed_color + + # Create answer with fill + answer = question.copy() + + if is_left: + # Fill from seed to left border + for i in range(block_pos): + answer[i] = seed_color + else: + # Fill from seed to right border + for i in range(block_pos + block_size, size): + answer[i] = seed_color + + return {"input": question, "output": answer} + + +def task_mark_size_two_blocks(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where size-2 blocks are marked with surrounding pixels.""" + blocks = [] + pos = 0 + + # Generate blocks with minimum gap of 2 + while pos < size: + if rng.random() < 0.4: + block_size = rng.randint(1, 3) + # Check if we have space for block and potential markers + needed_space = block_size + (2 if block_size == 2 else 0) + if pos + needed_space < size: + blocks.append((pos, block_size)) + pos += block_size + 2 # Minimum gap of 2 + + pos += 1 + + if len(blocks) < 2: + return None + + # Verify gaps between blocks (including markers) + valid = True + for i in range(len(blocks) - 1): + pos1, size1 = blocks[i] + pos2, _ = blocks[i + 1] + needed_gap = 3 if size1 == 2 else 2 + if pos2 - (pos1 + size1) < needed_gap: + valid = False + break + if not valid: + return None + + # Create input with blocks + question = gen_field(size) + for pos, block_size in blocks: + # Place block + for i in range(block_size): + question[pos + i] = 1 + + # Create answer with markers + answer = question.copy() + for pos, block_size in blocks: + if block_size == 2: + # Add markers for size 2 blocks + if pos > 0: + answer[pos - 1] = 3 + if pos + block_size < size: + answer[pos + block_size] = 3 + + return {"input": question, "output": answer} + + +def task_fill_until_collision(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where pixels fill empty space until collision.""" + # At least 4 positions for meaningful puzzle + if size < 4: + return None + + is_left = rng.random() < 0.5 + question = gen_field(size) + + # Place the side marker + if is_left: + question[0] = 5 + else: + question[size - 1] = 5 + + # Place 2-4 random pixels + num_pixels = rng.randint(2, 4) + positions = [] + + if is_left: + # Skip first position + for _ in range(num_pixels): + while True: + pos = rng.randint(1, size - 1) + if pos not in positions: + positions.append(pos) + break + else: + # Skip last position + for _ in range(num_pixels): + while True: + pos = rng.randint(0, size - 2) + if pos not in positions: + positions.append(pos) + break + + # Color random pixels + for pos in positions: + question[pos] = rng.randint(1, 9) + + positions.sort() + + # Create answer + answer = question.copy() + + if is_left: + # Fill right from each pixel + prev_pos = 0 # Start from marker + for pos in positions: + color = question[pos] + # Fill from previous position to current + for i in range(prev_pos + 1, pos): + answer[i] = color + prev_pos = pos + else: + # Fill left from each pixel + prev_pos = size - 1 # Start from marker + for pos in reversed(positions): + color = question[pos] + # Fill from current position to previous + for i in range(pos + 1, prev_pos): + answer[i] = color + prev_pos = pos + + return {"input": question, "output": answer} + + +def task_repeat_pattern_full(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where a pattern is repeated to fill the space.""" + # Generate initial pattern + pattern_size = rng.randint(2, 5) + pattern = [rng.randint(1, 9) for _ in range(pattern_size)] + + # Calculate total size needed for 2 repetitions + double_size = pattern_size * 2 + if double_size >= size: + return None + + # Create input with 2 repetitions + question = gen_field(size) + for i in range(pattern_size): + question[i] = pattern[i] + question[i + pattern_size] = pattern[i] + + # Create answer with maximum repetitions + answer = gen_field(size) + pos = 0 + while pos + pattern_size <= size: + for i in range(pattern_size): + answer[pos + i] = pattern[i] + pos += pattern_size + + # Fill remaining space (if any) with pattern elements + for i in range(pos, size): + answer[i] = pattern[i - pos] + + return {"input": question, "output": answer} + + +def task_gravity_weighted_colors(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where color 2 is heavier than color 1 in gravity.""" + # Generate random field with only colors 1 and 2 + question = [rng.randint(1, 2) if rng.random() < 0.5 else 0 for _ in range(size)] + + # Count colors + count_1 = sum(1 for x in question if x == 1) + count_2 = sum(1 for x in question if x == 2) + + # Create answer with sorted colors + answer = gen_field(size) + + # Place heavier color 2 first + for i in range(count_2): + answer[i] = 2 + + # Then place color 1 + for i in range(count_1): + answer[count_2 + i] = 1 + + return {"input": question, "output": answer} + + +def task_color_left_half_blocks(rng: Random, size: int) -> Optional[Dict[str, List[int]]]: + """Generate a task where left half of blocks are colored differently.""" + pos = 0 + question = gen_field(size) + blocks = [] + + # Generate blocks with gap 1 + while pos < size: + if rng.random() < 0.4: + block_size = rng.randint(2, 8) + if pos + block_size >= size: + break + + blocks.append((pos, block_size)) + for i in range(block_size): + question[pos + i] = 2 + pos += block_size + 1 # block size + gap + else: + pos += 1 + + if len(blocks) < 2: + return None + + # Create answer with half-colored blocks + answer = question.copy() + for pos, block_size in blocks: + half_size = block_size // 2 + for i in range(half_size): + answer[pos + i] = 8 + + return {"input": question, "output": answer} + + +def task_mirror(task_result: Optional[Dict[str, List[int]]]) -> Optional[Dict[str, List[int]]]: + """Mirror the input and output arrays of a task result.""" + if task_result is None: + return None + return {"input": list(reversed(task_result["input"])), "output": list(reversed(task_result["output"]))} + + +def task_inverse(task_result: Optional[Dict[str, List[int]]]) -> Optional[Dict[str, List[int]]]: + """Swap the input and output arrays of a task result.""" + if task_result is None: + return None + return {"input": task_result["output"], "output": task_result["input"]} + + +def task_identity(task_result: Optional[Dict[str, List[int]]]) -> Optional[Dict[str, List[int]]]: + """Return the task result unchanged.""" + return task_result + + +# Table of all ARC 1D task functions with their parameters +ARC_1D_TASKS = { + # Move tasks - right direction + "move_1pix_solid_right": (task_move_n_pix, {"move_pix": 1, "solid": True}), + "move_2pix_solid_right": (task_move_n_pix, {"move_pix": 2, "solid": True}), + "move_3pix_solid_right": (task_move_n_pix, {"move_pix": 3, "solid": True}), + "move_4pix_solid_right": (task_move_n_pix, {"move_pix": 4, "solid": True}), + "move_1pix_colorful_right": (task_move_n_pix, {"move_pix": 1, "solid": False}), + "move_2pix_colorful_right": (task_move_n_pix, {"move_pix": 2, "solid": False}), + "move_3pix_colorful_right": (task_move_n_pix, {"move_pix": 3, "solid": False}), + "move_4pix_colorful_right": (task_move_n_pix, {"move_pix": 4, "solid": False}), + # Move tasks - left direction (mirrored) + "move_1pix_solid_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix(rng, size, **kwargs)), + {"move_pix": 1, "solid": True}, + ), + "move_2pix_solid_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix(rng, size, **kwargs)), + {"move_pix": 2, "solid": True}, + ), + "move_3pix_solid_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix(rng, size, **kwargs)), + {"move_pix": 3, "solid": True}, + ), + "move_4pix_solid_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix(rng, size, **kwargs)), + {"move_pix": 4, "solid": True}, + ), + "move_1pix_colorful_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix(rng, size, **kwargs)), + {"move_pix": 1, "solid": False}, + ), + "move_2pix_colorful_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix(rng, size, **kwargs)), + {"move_pix": 2, "solid": False}, + ), + "move_3pix_colorful_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix(rng, size, **kwargs)), + {"move_pix": 3, "solid": False}, + ), + "move_4pix_colorful_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix(rng, size, **kwargs)), + {"move_pix": 4, "solid": False}, + ), + # Move wrapped tasks - right direction + "move_1pix_solid_wrapped_right": (task_move_n_pix_wrapped, {"move_pix": 1, "solid": True}), + "move_2pix_solid_wrapped_right": (task_move_n_pix_wrapped, {"move_pix": 2, "solid": True}), + "move_3pix_solid_wrapped_right": (task_move_n_pix_wrapped, {"move_pix": 3, "solid": True}), + "move_4pix_solid_wrapped_right": (task_move_n_pix_wrapped, {"move_pix": 4, "solid": True}), + "move_1pix_colorful_wrapped_right": (task_move_n_pix_wrapped, {"move_pix": 1, "solid": False}), + "move_2pix_colorful_wrapped_right": (task_move_n_pix_wrapped, {"move_pix": 2, "solid": False}), + "move_3pix_colorful_wrapped_right": (task_move_n_pix_wrapped, {"move_pix": 3, "solid": False}), + "move_4pix_colorful_wrapped_right": (task_move_n_pix_wrapped, {"move_pix": 4, "solid": False}), + # Move wrapped tasks - left direction (mirrored) + "move_1pix_solid_wrapped_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix_wrapped(rng, size, **kwargs)), + {"move_pix": 1, "solid": True}, + ), + "move_2pix_solid_wrapped_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix_wrapped(rng, size, **kwargs)), + {"move_pix": 2, "solid": True}, + ), + "move_3pix_solid_wrapped_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix_wrapped(rng, size, **kwargs)), + {"move_pix": 3, "solid": True}, + ), + "move_4pix_solid_wrapped_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix_wrapped(rng, size, **kwargs)), + {"move_pix": 4, "solid": True}, + ), + "move_1pix_colorful_wrapped_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix_wrapped(rng, size, **kwargs)), + {"move_pix": 1, "solid": False}, + ), + "move_2pix_colorful_wrapped_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix_wrapped(rng, size, **kwargs)), + {"move_pix": 2, "solid": False}, + ), + "move_3pix_colorful_wrapped_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix_wrapped(rng, size, **kwargs)), + {"move_pix": 3, "solid": False}, + ), + "move_4pix_colorful_wrapped_left": ( + lambda rng, size, **kwargs: task_mirror(task_move_n_pix_wrapped(rng, size, **kwargs)), + {"move_pix": 4, "solid": False}, + ), + # Gravity tasks - right direction + "gravity_right": (task_gravity, {}), + "gravity_counting_right": (task_gravity_counting, {}), + "gravity_antigravity_right": (task_gravity_antigravity, {}), + "gravity_one_step_right": (task_gravity_one_step, {}), + "gravity_weighted_colors_right": (task_gravity_weighted_colors, {}), + # Gravity tasks - left direction (mirrored) + "gravity_left": (lambda rng, size, **kwargs: task_mirror(task_gravity(rng, size, **kwargs)), {}), + "gravity_counting_left": (lambda rng, size, **kwargs: task_mirror(task_gravity_counting(rng, size, **kwargs)), {}), + "gravity_antigravity_left": ( + lambda rng, size, **kwargs: task_mirror(task_gravity_antigravity(rng, size, **kwargs)), + {}, + ), + "gravity_one_step_left": (lambda rng, size, **kwargs: task_mirror(task_gravity_one_step(rng, size, **kwargs)), {}), + "gravity_weighted_colors_left": ( + lambda rng, size, **kwargs: task_mirror(task_gravity_weighted_colors(rng, size, **kwargs)), + {}, + ), + # Block tasks + "block_touch_dot": (task_block_touch_dot, {}), + "block_touch_dot_1pix": (task_block_touch_dot_n_pix, {"move_pix": 1}), + "block_touch_dot_2pix": (task_block_touch_dot_n_pix, {"move_pix": 2}), + "block_touch_dot_3pix": (task_block_touch_dot_n_pix, {"move_pix": 3}), + "block_touch_dot_4pix": (task_block_touch_dot_n_pix, {"move_pix": 4}), + "block_scale_to_dot": (task_block_scale_to_dot, {}), + "block_and_noise_remove": (task_block_and_noise_remove, {}), + "block_and_noise_remove_inside": (task_block_and_noise_remove_inside, {}), + "move_block_by_own_size": (task_move_block_by_own_size, {}), + # Pattern tasks + "two_points_and_fill": (task_two_points_and_fill, {}), + "two_points_and_fill_inv": ( + lambda rng, size, **kwargs: task_inverse(task_two_points_and_fill(rng, size, **kwargs)), + {}, + ), + "copy_block_to_dots": (task_copy_block_to_dots, {}), + "copy_block_to_dots_colors": (task_copy_block_to_dots_colors, {}), + "repeat_pattern_full": (task_repeat_pattern_full, {}), + # Reflection tasks + "reflect_block_with_border_pixel": (task_reflect_block_with_border_pixel, {}), + "reflect_block_random": (task_reflect_block_with_border_pixel_random, {}), + "reflect_block_around_dot": (task_reflect_block_around_dot, {}), + # Color tasks + "paint_biggest_block": (task_paint_biggest_block, {}), + "recolor_blocks_by_size": (task_recolor_blocks_by_size, {}), + "change_to_five": (task_change_to_five, {}), + "recolor_blocks_from_palette": (task_recolor_blocks_from_palette, {}), + "color_left_half_blocks": (task_color_left_half_blocks, {}), + # Sorting tasks + "sort_blocks_by_size": (task_sort_blocks_by_size, {}), + "sort_complete_sequence": (task_sort_complete_sequence, {}), + # Fill tasks + "duplicate_block_from_seeds": (task_duplicate_block_from_seeds, {}), + "fill_from_pixel": (task_fill_from_pixel, {}), + "fill_until_collision": (task_fill_until_collision, {}), + # Marking tasks + "mark_size_two_blocks": (task_mark_size_two_blocks, {}), +} diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index a801c6e4..8e4e32d6 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -10,7 +10,9 @@ from .countdown import CountdownConfig, CountdownDataset from .game_of_life import GameOfLifeConfig, GameOfLifeDataset from .maze import MazeConfig, MazeDataset from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset +from .n_queens import NQueensDataset from .sudoku import SudokuConfig, SudokuDataset +from .tower_of_hanoi import HanoiConfig, HanoiDataset __all__ = [ "CountdownConfig", @@ -23,4 +25,7 @@ __all__ = [ "MazeDataset", "GameOfLifeConfig", "GameOfLifeDataset", + "HanoiConfig", + "HanoiDataset", + "NQueensDataset", ] diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py index 4721844d..38a60c4f 100644 --- a/reasoning_gym/games/countdown.py +++ b/reasoning_gym/games/countdown.py @@ -1,9 +1,10 @@ from dataclasses import dataclass from random import Random -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import sympy from sympy import Symbol, symbols +from sympy.parsing.sympy_parser import parse_expr from ..factory import ProceduralDataset, register_dataset @@ -158,6 +159,23 @@ class CountdownDataset(ProceduralDataset): raise ValueError(f"Failed to generate valid expression after {max_attempts} attempts") + def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float: + """Determine if the solution provided solves the problem""" + reward = 0.0 + if answer is not None: + try: + user_answer = int(parse_expr(answer)) + solved = user_answer == metadata["target"] + if solved: + reward = 1.0 + elif len(answer.strip()) > 0: # encourage partial solutions + reward = 0.05 + else: + reward = 0.01 + except: + reward = 0.01 + return reward + # Register the dataset register_dataset("countdown", CountdownDataset, CountdownConfig) diff --git a/reasoning_gym/games/n_queens.py b/reasoning_gym/games/n_queens.py new file mode 100644 index 00000000..1fef6c62 --- /dev/null +++ b/reasoning_gym/games/n_queens.py @@ -0,0 +1,163 @@ +"""N Queens puzzle generator + +A generalization of the 8-queens puzzle to any board size. +https://en.wikipedia.org/wiki/Eight_queens_puzzle +""" + +from copy import deepcopy +from dataclasses import dataclass +from random import Random +from typing import Dict, List, Optional + +from ..factory import ProceduralDataset, register_dataset + +MIN_BOARD_SIZE = 4 +MAX_BOARD_SIZE = 12 + +QUESTION_TEMPLATE = """Solve this N Queens puzzle: +{puzzle} + +The board size is {n}x{n} and your job is to place {num_removed} queen(s) on the board such that no two queens attack each other. + +No two queens attack each other if they are not in the same row, column, or diagonal. + +Place a queen by replacing an underscore (_) with a Q. +""" + + +@dataclass +class NQueensConfig: + """Configuration for N Queens puzzle generation""" + + n: int = 8 # Board size + min_remove: int = 1 # Minimum number of queens to remove from solved board + max_remove: int = 7 # Maximum number of queens to remove from solved board + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert MIN_BOARD_SIZE <= self.n <= MAX_BOARD_SIZE, f"n must be between {MIN_BOARD_SIZE} and {MAX_BOARD_SIZE}" + assert 1 <= self.min_remove <= self.max_remove, "min_remove must be between 1 and max_remove" + assert self.min_remove <= self.max_remove <= self.n, "max_remove must be between min_remove and n" + + +class NQueensDataset(ProceduralDataset): + """Generates N Queens puzzles with configurable difficulty""" + + def __init__(self, config: NQueensConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self._solutions = self._get_all_solutions(config.n) + + def __len__(self) -> int: + return self.config.size + + def __iter__(self): + self._current_idx = 0 + return self + + def __next__(self): + if self._current_idx >= self.config.size: + raise StopIteration + item = self[self._current_idx] + self._current_idx += 1 + return item + + def _get_all_solutions(self, n: int) -> List[List[List[str]]]: + """Get all solutions for the N Queens puzzle""" + + visited_cols = set() + visited_pos_diag = set() + visited_neg_diag = set() + + res = [] + board = [["_"] * n for _ in range(n)] + + def backtrack(row: int): + if row == n: + res.append(deepcopy(board)) + return + + for col in range(n): + if col in visited_cols or (row + col) in visited_pos_diag or (row - col) in visited_neg_diag: + continue + + visited_cols.add(col) + visited_pos_diag.add(row + col) + visited_neg_diag.add(row - col) + board[row][col] = "Q" + backtrack(row + 1) + visited_cols.remove(col) + visited_pos_diag.remove(row + col) + visited_neg_diag.remove(row - col) + board[row][col] = "_" + + backtrack(0) + return res + + def _create_puzzle(self, solved_board: List[List[str]], num_removed: int, rng: Random) -> List[List[str]]: + """Create puzzle by removing queens from solved board""" + puzzle = deepcopy(solved_board) + queens = [(i, j) for i in range(len(puzzle)) for j in range(len(puzzle)) if puzzle[i][j] == "Q"] + rng.shuffle(queens) + for i in range(num_removed): + x, y = queens[i] + puzzle[x][y] = "_" + return puzzle + + def _board_to_string(self, board: List[List[str]]) -> str: + """Convert board to string representation""" + return "\n".join(" ".join(x for x in row) for row in board) + + def _string_to_board(self, board_str: str) -> List[List[str]]: + """Convert string representation to board""" + return [list(row.split()) for row in board_str.strip().split("\n")] + + def _is_tractable_solution(self, puzzle: List[List[str]], solution: List[List[str]]) -> bool: + """Check if a solution is achievable from the starting state of the puzzle""" + for r in range(len(puzzle)): + for c in range(len(puzzle)): + if puzzle[r][c] == "Q" and solution[r][c] != "Q": + return False + return True + + def __getitem__(self, idx: int) -> dict: + """Generate a single N Queens puzzle""" + rng = Random(self.seed + idx) + + # Randomly select a valid solution + solved_board = rng.choice(self._solutions) + + # Create puzzle by removing queens + num_removed = rng.randint(self.config.min_remove, self.config.max_remove) + puzzle = self._create_puzzle(solved_board, num_removed, rng) + puzzle_str = self._board_to_string(puzzle) + + # Filter all solutions that are intractable from the puzzle's starting state + valid_solutions = [board for board in self._solutions if self._is_tractable_solution(puzzle, board)] + valid_solutions_str = sorted({self._board_to_string(board) for board in valid_solutions}) + + return { + "question": QUESTION_TEMPLATE.format(puzzle=puzzle_str, n=len(puzzle), num_removed=num_removed), + "answer": rng.choice(valid_solutions_str), # choose arbitary answer (e.g. for SFT) + "metadata": { + "puzzle": puzzle, + "solutions": valid_solutions, + "num_removed": num_removed, + "valid_answers": valid_solutions_str, + }, + } + + def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: + valid_solutions = entry["metadata"]["valid_answers"] + reward = 0.0 + if answer is not None: + if answer in valid_solutions: + reward = 1.0 + else: + reward = 0.01 + return reward + + +register_dataset("n_queens", NQueensDataset, NQueensConfig) diff --git a/reasoning_gym/games/tower_of_hanoi.py b/reasoning_gym/games/tower_of_hanoi.py new file mode 100644 index 00000000..df902300 --- /dev/null +++ b/reasoning_gym/games/tower_of_hanoi.py @@ -0,0 +1,373 @@ +# reasoning_gym/games/tower_of_hanoi.py + +import math +import random +import re +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class HanoiConfig: + """ + Configuration for the Tower of Hanoi task. + + - min_disks: Minimum number of disks in the puzzle. + - max_disks: Maximum number of disks in the puzzle. + - min_pegs: Minimum number of pegs (minimum 3). + - max_pegs: Maximum number of pegs. + - size: Number of problem instances in the dataset. + - seed: Optional seed for reproducibility. + - visualize: Whether to include a visualization of the initial state. + """ + + min_disks: int = 3 + max_disks: int = 7 + min_pegs: int = 3 + max_pegs: int = 4 + size: int = 50 + seed: Optional[int] = None + visualize: bool = False # New parameter + + def validate(self) -> None: + """Validate configuration parameters.""" + assert self.min_disks >= 1, "min_disks must be at least 1" + assert self.max_disks >= self.min_disks, "max_disks must be >= min_disks" + assert self.min_pegs >= 3, "min_pegs must be at least 3" + assert self.max_pegs >= self.min_pegs, "max_pegs must be >= min_pegs" + + +class MoveGenerator: + """ + Helper class to generate valid move sequences for Tower of Hanoi using the Frame-Stewart algorithm. + It maintains the current state of all pegs to ensure move validity. + """ + + def __init__(self, num_disks: int, pegs: List[int], start: int, target: int): + self.num_disks = num_disks + self.pegs = pegs + self.start = start + self.target = target + self.auxiliary_pegs = [peg for peg in pegs if peg not in (start, target)] + self.pegs_state: Dict[int, List[int]] = {peg: [] for peg in pegs} + for disk in range(num_disks, 0, -1): # Largest disk at the bottom + self.pegs_state[start].append(disk) + self.moves: List[str] = [] + self.memo: Dict[Tuple[int, int], int] = {} # Memoization for T(n, k) + + def generate_moves(self) -> List[str]: + self.move(n=self.num_disks, source=self.start, target=self.target, auxiliary_pegs=self.auxiliary_pegs) + return self.moves + + def move(self, n: int, source: int, target: int, auxiliary_pegs: List[int]): + if n == 0: + return + if n == 1: + self._move_disk(source, target) + return + + k = len(auxiliary_pegs) + 2 # Total number of pegs including source and target + + if k < 3: + raise ValueError("At least 3 pegs are required.") + + if k == 3: + # Classic Tower of Hanoi solution + aux = auxiliary_pegs[0] + self.move(n - 1, source, aux, [target]) + self._move_disk(source, target) + self.move(n - 1, aux, target, [source]) + return + + # For k > 3, apply Frame-Stewart algorithm + # Find m that minimizes 2*T(m, k) + T(n - m, k - 1) + min_moves = math.inf + best_m = 1 + for m in range(1, n): + moves_m = self._compute_T(m, k) + moves_n_minus_m = self._compute_T(n - m, k - 1) + total_moves = 2 * moves_m + moves_n_minus_m + if total_moves < min_moves: + min_moves = total_moves + best_m = m + + # Select a temporary peg to hold m disks + temp_peg = auxiliary_pegs[0] + new_auxiliary = [peg for peg in auxiliary_pegs if peg != temp_peg] + + # Step 1: Move top m disks to temp_peg using all pegs + self.move(n=best_m, source=source, target=temp_peg, auxiliary_pegs=auxiliary_pegs[1:] + [target]) + + # Step 2: Move remaining n - m disks to target using k - 1 pegs + self.move(n=n - best_m, source=source, target=target, auxiliary_pegs=new_auxiliary) + + # Step 3: Move m disks from temp_peg to target using all pegs + self.move(n=best_m, source=temp_peg, target=target, auxiliary_pegs=auxiliary_pegs[1:] + [source]) + + def _move_disk(self, from_peg: int, to_peg: int): + if not self.pegs_state[from_peg]: + raise ValueError(f"No disks to move from Peg {from_peg}.") + disk = self.pegs_state[from_peg][-1] + self.pegs_state[from_peg].pop() + self.pegs_state[to_peg].append(disk) + self.moves.append(f"Move disk {disk} from Peg {from_peg} to Peg {to_peg}") + + def _compute_T(self, n: int, k: int) -> int: + """ + Compute the minimal number of moves (T(n, k)) required to move n disks using k pegs. + Utilizes memoization to store previously computed results. + """ + if n == 0: + return 0 + if n == 1: + return 1 + if k == 3: + return 2**n - 1 + if (n, k) in self.memo: + return self.memo[(n, k)] + + min_moves = math.inf + for m in range(1, n): + moves = 2 * self._compute_T(m, k) + self._compute_T(n - m, k - 1) + if moves < min_moves: + min_moves = moves + self.memo[(n, k)] = min_moves + return min_moves + + +class HanoiDataset(ProceduralDataset): + """ + Generates Tower of Hanoi problems with solutions. + Supports variable number of pegs using the optimized Frame-Stewart algorithm with Peg State Tracking. + """ + + def __init__(self, config: HanoiConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.min_pegs = config.min_pegs + self.max_pegs = config.max_pegs + self.min_disks = config.min_disks + self.max_disks = config.max_disks + self.visualize = config.visualize # Initialize the visualize attribute + + def __getitem__(self, idx: int) -> dict: + """ + Generate a Tower of Hanoi problem instance. + + Returns: + dict with: + - "question": Text describing the problem setup. + - "answer": List of moves to solve the puzzle. + - "metadata": Configuration and solution details. + - "initial_state": (Optional) ASCII visualization of the initial pegs. + - "states": (Optional) List of ASCII visualizations after each move. + """ + rng = random.Random(self.seed + idx if self.seed is not None else None) + + # Randomly select number of disks and pegs within the specified ranges + num_disks = rng.randint(self.min_disks, self.max_disks) + num_pegs = rng.randint(self.min_pegs, self.max_pegs) + + # Assign unique peg identifiers (e.g., integers starting from 1) + pegs = list(range(1, num_pegs + 1)) + + """ #Debug: Print current instance configuration + print(f"\n--- Generating Instance {idx} ---") + print(f"Number of Disks: {num_disks}") + print(f"Number of Pegs: {num_pegs}") + print(f"Pegs: {pegs}") + """ + + # Randomly select start and target pegs + start_peg, target_peg = rng.sample(pegs, 2) + + # Auxiliary pegs are the remaining pegs + auxiliary_pegs = [peg for peg in pegs if peg not in (start_peg, target_peg)] + + """ # Debug: Print start, target, and auxiliary pegs + print(f"Start Peg: {start_peg}") + print(f"Target Peg: {target_peg}") + print(f"Auxiliary Pegs: {auxiliary_pegs}") + """ + + # Initialize the MoveGenerator and generate moves + move_gen = MoveGenerator(num_disks, pegs, start_peg, target_peg) + try: + solution = move_gen.generate_moves() + except ValueError as ve: + # print(f"Error during move generation: {ve}") + raise ve + + """ # Debug: Print the solution moves + print(f"Solution Length: {len(solution)}") + print("Solution Moves:") + for move_num, move in enumerate(solution, start=1): + print(f" Move {move_num}: {move}") + """ + + # Initialize pegs_state: all disks start on the start peg + pegs_state = {peg: [] for peg in pegs} + for disk in range(num_disks, 0, -1): # Largest disk at the bottom + pegs_state[start_peg].append(disk) + + # Generate initial state visualization if requested + initial_state_str = None + if self.visualize: + initial_state_str = self._visualize_state(pegs_state) + + # Apply moves to track state changes + states = [] + if self.visualize: + states.append(initial_state_str) # Initial state + for move in solution: + # Parse the move string using regex + try: + disk, from_peg, to_peg = self._parse_move(move) + except ValueError as ve: + # print(f"Error parsing move: {ve}") + raise ve + + # Validate the move + if not self._validate_move(pegs_state, move): + # print(f"Invalid move detected: {move}") + # print(f"Current Pegs State: {pegs_state}") + raise ValueError(f"Invalid move detected: {move}") + + # Move the disk + pegs_state[from_peg].pop() + pegs_state[to_peg].append(disk) + + # Visualize the new state + new_state_str = self._visualize_state(pegs_state) + states.append(new_state_str) + + # Peg labels + peg_labels = {peg: f"Peg {peg}" for peg in pegs} + + question_str = ( + f"Solve the Tower of Hanoi problem with {num_disks} disks and {num_pegs} pegs.\n" + f"Move all disks from {peg_labels[start_peg]} to {peg_labels[target_peg]} following the rules:\n" + "- Only one disk can be moved at a time.\n" + "- A larger disk cannot be placed on top of a smaller disk.\n" + "- All disks must be on a peg at all times.\n" + "Example:\n" + "Move disk 1 from Peg 1 to Peg 3\n" + "Move disk 2 from Peg 1 to Peg 2\n" + "Move disk 1 from Peg 3 to Peg 2\n" + "\n" + "Provide the sequence of moves." + ) + + result = { + "question": question_str, + "answer": solution, + "metadata": { + "num_disks": num_disks, + "num_pegs": num_pegs, + "start_peg": start_peg, + "target_peg": target_peg, + "auxiliary_pegs": auxiliary_pegs, + "solution_length": len(solution), + }, + } + + if self.visualize: + result["initial_state"] = initial_state_str + result["states"] = states # List of all states including initial and after each move + + return result + + def _visualize_state(self, pegs_state: Dict[int, List[int]]) -> str: + """ + Create an ASCII visualization of the current state of the pegs. + Adapts to variable number of pegs. + + Args: + pegs_state (dict): Dictionary mapping peg numbers to lists of disks. + + Returns: + str: ASCII art representing the pegs and disks. + """ + # Determine the number of levels based on the maximum number of disks on any peg + max_height = max(len(disks) for disks in pegs_state.values()) + pegs = sorted(pegs_state.keys()) + + visualization = "" + for level in range(max_height, 0, -1): + for peg in pegs: + if len(pegs_state[peg]) >= level: + disk_size = pegs_state[peg][level - 1] + disk_str = f"[{'*' * disk_size}]" + else: + disk_str = "[ ]" + visualization += disk_str.center(7) # Adjust spacing as needed + visualization += "\n" + + # Add the base and peg numbers + visualization += "-" * (7 * len(pegs)) + "\n" + for peg in pegs: + peg_label = f"P{peg}".center(7) + visualization += peg_label + visualization += "\n" + + return visualization + + def _validate_move(self, pegs_state: Dict[int, List[int]], move: str) -> bool: + """ + Validate that a move adheres to the Tower of Hanoi rules. + + Args: + pegs_state (dict): Current state of the pegs. + move (str): Move instruction, e.g., "Move disk 2 from Peg 1 to Peg 3". + + Returns: + bool: True if the move is valid, False otherwise. + """ + try: + parts = move.split() + if len(parts) != 9: + # print(f"Unexpected move format: '{move}'") + return False + disk = int(parts[2]) + from_peg = int(parts[5]) + to_peg = int(parts[8]) + + # Check if the disk to move is the top disk on the from_peg + if not pegs_state[from_peg] or pegs_state[from_peg][-1] != disk: + # print(f"Disk {disk} is not on top of Peg {from_peg}. Current state: {pegs_state[from_peg]}") + return False + + # Check if placing the disk on the to_peg violates size constraints + if pegs_state[to_peg] and pegs_state[to_peg][-1] < disk: + # print(f"Cannot place disk {disk} on top of smaller disk {pegs_state[to_peg][-1]} on Peg {to_peg}.") + return False + + return True + except Exception as e: + print(f"Error validating move '{move}': {e}") + return False + + def _parse_move(self, move: str) -> Tuple[int, int, int]: + """ + Parse a move string and extract disk number, from peg, and to peg. + + Args: + move (str): Move instruction, e.g., "Move disk 2 from Peg 1 to Peg 3". + + Returns: + tuple: (disk, from_peg, to_peg) + """ + pattern = r"Move disk (\d+) from Peg (\d+) to Peg (\d+)" + match = re.match(pattern, move) + if not match: + raise ValueError(f"Unexpected move format: '{move}'") + + disk = int(match.group(1)) + from_peg = int(match.group(2)) + to_peg = int(match.group(3)) + return disk, from_peg, to_peg + + +# Register the dataset +register_dataset("tower_of_hanoi", HanoiDataset, HanoiConfig) diff --git a/reasoning_gym/geometry/__init__.py b/reasoning_gym/geometry/__init__.py new file mode 100644 index 00000000..6e4e2d1a --- /dev/null +++ b/reasoning_gym/geometry/__init__.py @@ -0,0 +1,9 @@ +from .advanced_geometry import AdvancedGeometryConfig, AdvancedGeometryDataset +from .simple_geometry import SimpleGeometryConfig, SimpleGeometryDataset + +__all__ = [ + "SimpleGeometryConfig", + "SimpleGeometryDataset", + "AdvancedGeometryConfig", + "AdvancedGeometryDataset", +] diff --git a/reasoning_gym/geometry/advanced_geometry.py b/reasoning_gym/geometry/advanced_geometry.py new file mode 100644 index 00000000..ac8797b9 --- /dev/null +++ b/reasoning_gym/geometry/advanced_geometry.py @@ -0,0 +1,216 @@ +import random +from dataclasses import dataclass, field +from typing import List, Optional + +import sympy +from sympy.geometry import Point, Segment, Triangle + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class AdvancedGeometryConfig: + """ + Configuration for generating advanced geometry tasks. + """ + + min_coord: int = -10 # Minimum x/y coordinate + max_coord: int = 10 # Maximum x/y coordinate + size: int = 50 # Number of problems to generate + seed: Optional[int] = None + + # Probability or list of tasks we want to generate + # For demonstration, we have three categories: + task_types: List[str] = field( + default_factory=lambda: [ + "orthocenter", + "incircle_radius", + "angle_measure", + ] + ) + + def validate(self): + assert self.min_coord < self.max_coord, "min_coord must be < max_coord." + assert self.size > 0, "Size of dataset must be positive." + assert len(self.task_types) > 0, "Must specify at least one task type." + + +class AdvancedGeometryDataset(ProceduralDataset): + """ + A dataset for advanced geometry tasks using coordinate geometry. + """ + + def __init__(self, config: AdvancedGeometryConfig): + self._prompt_templates = { + "orthocenter": [ + "Given triangle ABC with coordinates A={A}, B={B}, and C={C}, find the coordinates of its orthocenter.", + "For triangle with vertices A={A}, B={B}, and C={C}, determine the orthocenter (intersection of altitudes).", + ], + "incircle_radius": [ + "Consider triangle ABC with coordinates A={A}, B={B}, and C={C}. Compute the radius of its incircle.", + "Find the incircle radius of triangle ABC whose vertices are A={A}, B={B}, and C={C}.", + ], + "angle_measure": [ + "In triangle ABC with coordinates A={A}, B={B}, and C={C}, find the measure (in degrees) of angle ABC.", + "Given a triangle with vertices A={A}, B={B}, C={C}, determine the angle at B in degrees.", + ], + } + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """ + Generate a single advanced geometry item based on the config's task types. + """ + rng = random.Random(self.seed + idx) + task_type = rng.choice(self.config.task_types) + + # Randomly generate coordinates for a triangle + A, B, C = self._generate_non_degenerate_triangle(rng) + + # Build a question and compute the solution + if task_type == "orthocenter": + question, answer, metadata = self._build_orthocenter_task(rng, A, B, C) + elif task_type == "incircle_radius": + question, answer, metadata = self._build_incircle_radius_task(rng, A, B, C) + elif task_type == "angle_measure": + question, answer, metadata = self._build_angle_measure_task(rng, A, B, C) + else: + raise ValueError(f"Unknown task_type: {task_type}") + + return { + "question": question, + "answer": answer, + "metadata": metadata, + } + + def _generate_non_degenerate_triangle(self, rng: random.Random): + """ + Generate a random non-degenerate triangle with integer coordinates + in [min_coord, max_coord] x [min_coord, max_coord]. + """ + max_attempts = 100 + for _ in range(max_attempts): + # Generate points with integer coordinates + points = [] + for _ in range(3): + x = rng.randint(self.config.min_coord, self.config.max_coord) + y = rng.randint(self.config.min_coord, self.config.max_coord) + points.append(Point(x, y)) + + A, B, C = points + + # Calculate signed area to check for non-degeneracy + # Using the formula: 1/2 * |x1(y2 - y3) + x2(y3 - y1) + x3(y1 - y2)| + area = abs(A.x * (B.y - C.y) + B.x * (C.y - A.y) + C.x * (A.y - B.y)) / 2 + + if area > 0: + return A, B, C + + raise ValueError(f"Failed to generate a non-degenerate triangle after {max_attempts} attempts.") + + def _build_orthocenter_task(self, rng: random.Random, A: Point, B: Point, C: Point): + """ + Build a question about finding the orthocenter of triangle ABC. + """ + # Convert segments to lines + BC_line = sympy.Line(B, C) + CA_line = sympy.Line(C, A) + + # Calculate altitudes by creating lines perpendicular from each vertex + alt_A = BC_line.perpendicular_line(A) + alt_B = CA_line.perpendicular_line(B) + + # Find orthocenter (intersection of any two altitudes, e.g. alt_A and alt_B) + ortho = alt_A.intersection(alt_B)[0] + + x_ortho_approx = float(ortho.x.evalf()) + y_ortho_approx = float(ortho.y.evalf()) + + question_template = rng.choice(self._prompt_templates["orthocenter"]) + question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y)) + answer_str = f"({x_ortho_approx:.3f}, {y_ortho_approx:.3f})" + + metadata = { + "A": (A.x, A.y), + "B": (B.x, B.y), + "C": (C.x, C.y), + "orthocenter_exact": (str(ortho.x), str(ortho.y)), + "orthocenter_approx": (x_ortho_approx, y_ortho_approx), + } + return question, answer_str, metadata + + def _build_incircle_radius_task(self, rng: random.Random, A: Point, B: Point, C: Point): + """ + Build a question about finding the incircle radius of triangle ABC. + """ + # Calculate side lengths + a = B.distance(C) + b = C.distance(A) + c = A.distance(B) + + # Semi-perimeter + s = (a + b + c) / 2 + + # Area using Heron's formula + area = sympy.sqrt(s * (s - a) * (s - b) * (s - c)) + + # Radius of incircle = Area / Semi-perimeter + radius = area / s + + # Convert to float for final answer + radius_approx = float(radius.evalf()) + + question_template = rng.choice(self._prompt_templates["incircle_radius"]) + question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y)) + answer_str = f"{radius_approx:.3f}" + + metadata = { + "A": (A.x, A.y), + "B": (B.x, B.y), + "C": (C.x, C.y), + "incircle_radius_exact": str(radius), + "incircle_radius_approx": radius_approx, + } + return question, answer_str, metadata + + def _build_angle_measure_task(self, rng: random.Random, A: Point, B: Point, C: Point): + """ + Build a question about finding the measure of angle ABC in degrees. + """ + # Angle at B means the angle ∠ABC + # Vector BA = A - B, BC = C - B + BA = A - B + BC = C - B + + # Use vector dot product to find angle between BA and BC + # angle = arccos((BA · BC) / (|BA| * |BC|)) + dot_val = BA.dot(BC) + mag_ba = BA.distance(Point(0, 0)) + mag_bc = BC.distance(Point(0, 0)) + + # numerical check + if mag_ba == 0 or mag_bc == 0: + # degenerate, but theoretically we forced a non-degenerate triangle + angle_deg = 0 + else: + cos_theta = dot_val / (mag_ba * mag_bc) + # clamp cos_theta to [-1, 1] to avoid floating rounding errors + cos_theta = max(-1, min(1, cos_theta)) + angle_rad = sympy.acos(cos_theta) + angle_deg = float(angle_rad.evalf() * 180 / sympy.pi) + + question_template = rng.choice(self._prompt_templates["angle_measure"]) + question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y)) + + answer_str = f"{angle_deg:.2f}°" + metadata = { + "A": (A.x, A.y), + "B": (B.x, B.y), + "C": (C.x, C.y), + "angle_ABC_degrees": angle_deg, + } + return question, answer_str, metadata + + +# Register the dataset +register_dataset("advanced_geometry", AdvancedGeometryDataset, AdvancedGeometryConfig) diff --git a/reasoning_gym/geometry/simple_geometry.py b/reasoning_gym/geometry/simple_geometry.py new file mode 100644 index 00000000..d04912d7 --- /dev/null +++ b/reasoning_gym/geometry/simple_geometry.py @@ -0,0 +1,140 @@ +import random +from dataclasses import dataclass +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class SimpleGeometryConfig: + """ + Configuration for generating basic geometry (angle-finding) tasks. + Produces a random convex polygon with N sides, random angles + for the first (N-1) sides, and asks the solver to find the last angle. + """ + + min_sides: int = 3 # Minimum number of sides (e.g. triangle) + max_sides: int = 6 # Maximum number of sides (e.g. hexagon) + min_angle: int = 10 # Minimum angle (in degrees) for each of the first (N-1) angles + max_angle: int = 170 # Maximum angle (in degrees) for each of the first (N-1) angles + seed: Optional[int] = None # Random seed + size: int = 100 # Number of geometry tasks to generate + + def validate(self) -> None: + """ + Validate configuration parameters. + """ + assert self.min_sides >= 3, "min_sides must be at least 3 (triangle)." + assert self.max_sides >= self.min_sides, "max_sides must be >= min_sides." + assert 0 < self.min_angle < 180, "min_angle must be in (0, 180)." + assert self.max_angle <= 179, "max_angle should be less than 180." + assert self.max_angle >= self.min_angle, "max_angle must be >= min_angle." + + +class SimpleGeometryDataset(ProceduralDataset): + """ + A dataset for simple polygon angle-finding tasks. + We randomly choose the number of sides N within [min_sides, max_sides]. + We then generate (N-1) random angles (in degrees), ensuring their sum is + strictly less than the total sum for an (N)-sided convex polygon (which is 180*(N-2)). + The question asks for the missing angle; the answer is computed by subtracting the + sum of known angles from 180*(N-2). + """ + + def __init__(self, config: SimpleGeometryConfig): + self._prompt_templates = [ + ( + "Given a convex polygon with {n_sides} sides, its first {n_minus_1} interior angles " + "are: {angle_list}. What is the measure of the remaining interior angle (in degrees)?" + ), + ( + "A convex polygon has {n_sides} sides. The measures of " + "the first {n_minus_1} interior angles are: {angle_list}. " + "Find the measure of the last interior angle." + ), + ( + "Consider a convex {n_sides}-gon whose first {n_minus_1} interior angles " + "are: {angle_list}. Determine the measure of the remaining angle." + ), + ] + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """ + Generate a single geometry angle-finding item. + + Returns: + A dict with: + - question: str + - answer: str (the missing angle, as an integer or float in degrees) + - metadata: dict (n_sides, angles, sum_of_known, missing_angle, etc.) + """ + rng = random.Random(self.seed + idx) + + # Randomly pick the number of sides + n_sides = rng.randint(self.config.min_sides, self.config.max_sides) + + # Total interior angle sum for a convex n_sides-gon + total_sum = 180 * (n_sides - 2) + + # Generate (n_sides - 1) random angles, ensuring their sum < total_sum + known_angles = self._generate_valid_angles(rng, n_sides, total_sum) + + # Missing angle + missing_angle = total_sum - sum(known_angles) + + # Build the question string + angle_list_str = ", ".join(f"{a:.1f}°" for a in known_angles) + prompt = rng.choice(self._prompt_templates).format( + n_sides=n_sides, n_minus_1=n_sides - 1, angle_list=angle_list_str + ) + + # Round the missing angle to one decimal place or integer if it is very close to an integer + # so that the answer remains consistent and clean + missing_angle_rounded = round(missing_angle, 1) + if abs(missing_angle_rounded - round(missing_angle_rounded)) < 1e-6: + # If it is effectively an integer, keep it as int + missing_angle_rounded = int(missing_angle_rounded) + + answer_str = str(missing_angle_rounded) + + return { + "question": prompt, + "answer": answer_str, + "metadata": { + "n_sides": n_sides, + "known_angles": known_angles, + "sum_of_known_angles": sum(known_angles), + "missing_angle_raw": missing_angle, + "missing_angle_rounded": missing_angle_rounded, + "total_interior_sum": total_sum, + }, + } + + def _generate_valid_angles(self, rng: random.Random, n_sides: int, total_sum: int): + """ + Generate (n_sides - 1) random angles in [min_angle, max_angle], + ensuring the sum is strictly less than total_sum to keep a valid missing angle. + We keep retrying until we find a valid set or reach a max attempt limit. + """ + max_attempts = 100 + for _ in range(max_attempts): + angles = [] + # We choose angles one by one + for _ in range(n_sides - 1): + angle = rng.randint(self.config.min_angle, self.config.max_angle) + angles.append(float(angle)) + + # Check if the sum is strictly less than total_sum + if sum(angles) < total_sum: + return angles + + # If we fail after max_attempts, raise an error + raise ValueError( + f"Could not generate valid angles for an {n_sides}-gon " + f"with total sum {total_sum} within {max_attempts} attempts." + ) + + +# Register the dataset so it can be accessed similarly to the others +register_dataset("simple_geometry", SimpleGeometryDataset, SimpleGeometryConfig) diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 6ba042a8..ee278b33 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -1,5 +1,5 @@ import random -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import StrEnum from itertools import count from typing import List, Optional, Set, Tuple @@ -37,12 +37,8 @@ class Person: gender: Gender id: int spouse: Optional["Person"] = None - parents: List["Person"] = None - children: List["Person"] = None - - def __post_init__(self): - self.parents = self.parents or [] - self.children = self.children or [] + parents: List["Person"] = field(default_factory=list) + children: List["Person"] = field(default_factory=list) def __hash__(self): return self.id @@ -69,14 +65,8 @@ class FamilyRelationshipsConfig: min_family_size: int = 4 max_family_size: int = 8 - male_names: List[str] = None - female_names: List[str] = None - seed: Optional[int] = None - size: int = 500 - - def __post_init__(self): - # Default name lists if none provided - default_male_names = [ + male_names: List[str] = field( + default_factory=lambda: [ "James", "John", "Robert", @@ -121,7 +111,9 @@ class FamilyRelationshipsConfig: "Ryder", "Finn", ] - default_female_names = [ + ) + female_names: List[str] = field( + default_factory=lambda: [ "Mary", "Patricia", "Jennifer", @@ -166,11 +158,9 @@ class FamilyRelationshipsConfig: "Sky", "Rain", ] - - if self.male_names is None: - self.male_names = default_male_names - if self.female_names is None: - self.female_names = default_female_names + ) + seed: Optional[int] = None + size: int = 500 def validate(self) -> None: """Validate configuration parameters""" diff --git a/reasoning_gym/graphs/quantum_lock.py b/reasoning_gym/graphs/quantum_lock.py index 402b6f0c..5863a5bf 100644 --- a/reasoning_gym/graphs/quantum_lock.py +++ b/reasoning_gym/graphs/quantum_lock.py @@ -28,7 +28,7 @@ class QuantumLockDataset(ProceduralDataset): self._prompt_templates = [ """\ In front of you are some buttons, a light, and a number. The light will toggle between red and green whenever you press a button. Each button performs a mathematical operation to the number, but the operation may depend on the state of the light. -You must press the shortest correct sequence of buttons to reach the target value. +You must press the shortest correct sequence of buttons to reach the target value. Your answer should be a sequence of buttons separated by '→', for example: A → B → C Start: {initial_value} ({initial_state}) Target: {target_value} diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py index c2c07625..38307647 100644 --- a/reasoning_gym/logic/__init__.py +++ b/reasoning_gym/logic/__init__.py @@ -6,10 +6,13 @@ Logic tasks for training reasoning capabilities: - Syllogisms """ +from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset from .syllogisms import SyllogismConfig, SyllogismDataset, Term __all__ = [ + "AliceInWonderlandConfig", + "AliceInWonderlandDataset", "PropositionalLogicConfig", "PropositionalLogicDataset", "SyllogismConfig", diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py new file mode 100644 index 00000000..0c864cc4 --- /dev/null +++ b/reasoning_gym/logic/aiw.py @@ -0,0 +1,197 @@ +from dataclasses import dataclass, field +from enum import Enum +from random import Random +from string import Template +from typing import List, Optional + +from ..factory import ProceduralDataset, register_dataset + + +class TaskType(Enum): + """Defines the type of task for the Alice in Wonderland dataset.""" + + SIBLINGS = "siblings" + FRIENDS = "friends" + COLLEAGUES = "colleagues" # Added colleagues task + + +@dataclass +class AliceInWonderlandConfig: + """Configuration options for the Alice in Wonderland dataset. + + Attributes: + male_names (List[str]): List of male names to use in questions. + female_names (List[str]): List of female names to use in questions. Must include 'Alice'. + task_types (List[TaskType]): List of task types to include in dataset. + seed (Optional[int]): Seed for random number generation. + size (int): Number of samples in the dataset. + max_entities (int): Max number of siblings/friends/colleagues in questions. + """ + + male_names: List[str] = field( + default_factory=lambda: [ + "James", + "John", + "Robert", + "Michael", + "William", + "David", + "Richard", + "Joseph", + "Thomas", + "Charles", + "Bob", + ] + ) + female_names: List[str] = field( + default_factory=lambda: [ + "Mary", + "Patricia", + "Jennifer", + "Linda", + "Elizabeth", + "Barbara", + "Susan", + "Jessica", + "Sarah", + "Margaret", + "Alice", + ] + ) + task_types: List[TaskType] = field( + default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues + ) + seed: Optional[int] = None + size: int = 10 + max_entities: int = 6 # Added max_entities + + def validate(self) -> None: + """Validates the configuration parameters.""" + assert len(self.male_names) > 0, "must provide male names" + assert len(self.female_names) > 0, "must provide female names" + assert "Alice" in self.female_names, "'Alice' must be in female names" + assert len(self.task_types) > 0, "must provide at least one task type" + assert self.max_entities > 0, "max_entities must be positive" + + +class AliceInWonderlandDataset(ProceduralDataset): + """ + A procedural dataset inspired by the "Alice in Wonderland" paper. + + The dataset is inspired by the following paper: + @inproceedings{nezhurina2024alice, + title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and + Basic Reasoning Deficits in State-Of-the-Art Large Language Models}, + author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and + Jenia Jitsev}, + booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding + Deep Learning}, + year={2024}, + url={https://openreview.net/forum?id=Mkl7dzjYiW} + } + + """ + + def __init__(self, config: AliceInWonderlandConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.templates = { + TaskType.SIBLINGS: [ + Template( + "$female_name has $num_brothers brothers and she also has " + "$num_sisters sisters. How many sisters does " + "$female_name's brother have?" + ), + Template( + "$female_name has $num_sisters sisters and she also has " + "$num_brothers brothers. How many sisters does " + "$male_name's brother have?" + ), + ], + TaskType.FRIENDS: [ + Template( + "$female_name has $num_male male friends and she also has " + "$num_female female friends. They all are friends with each " + "other and have no other friends aside. How many female " + "friends does $male_name, a male friend of $female_name, " + "have?" + ) + ], + TaskType.COLLEAGUES: [ # New colleagues templates + Template( + "$female_name has $num_male_colleagues_alice_circle male colleagues and she also has " + "$num_female_colleagues_alice_circle female colleagues. These are all colleagues that $female_name has. " + "All these mentioned persons around $female_name are colleagues of each other. " + "$male_name has $num_male_colleagues_bob_circle male colleagues " + "and $num_female_colleagues_bob_circle female colleagues in total. " + "All these mentioned persons around $male_name are colleagues of each other. " + "The people in the circle around $male_name do not have " + "other colleagues aside - with the only exception of Matilda. " + "She is colleague of $male_name and she is also colleague of $female_name, " + "being part of $female_name's circle. How many female colleagues does Matilda have?" + ), + ], + } + + def _get_aiw(self, rng: Random) -> dict: + """Generates a single Alice in Wonderland question. + + Args: + rng (Random): Random number generator. + + Returns: + dict: A dictionary containing the generated question, the right answer + and a description of the example. + """ + task_type = rng.choice(self.config.task_types) + female_name = rng.choice(self.config.female_names) + male_name = rng.choice(self.config.male_names) + + if task_type == TaskType.SIBLINGS: + num_brothers = rng.randint(1, self.config.max_entities) + num_sisters = rng.randint(1, self.config.max_entities) + + answer = num_sisters + 1 + template = rng.choice(self.templates[TaskType.SIBLINGS]) + question = template.substitute( + female_name=female_name, + male_name=male_name, + num_brothers=num_brothers, + num_sisters=num_sisters, + ) + elif task_type == TaskType.FRIENDS: + num_male = rng.randint(1, self.config.max_entities) + num_female = rng.randint(1, self.config.max_entities) + + answer = num_female + 1 + template = rng.choice(self.templates[TaskType.FRIENDS]) + question = template.substitute( + female_name=female_name, + male_name=male_name, + num_male=num_male, + num_female=num_female, + ) + elif task_type == TaskType.COLLEAGUES: + num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities) + num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities) + num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities) + num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities) + + answer = num_female_colleagues_alice_circle + 1 + template = rng.choice(self.templates[TaskType.COLLEAGUES]) + question = template.substitute( + female_name=female_name, + male_name=male_name, + num_male_colleagues_alice_circle=num_male_colleagues_alice_circle, + num_female_colleagues_alice_circle=num_female_colleagues_alice_circle, + num_male_colleagues_bob_circle=num_male_colleagues_bob_circle, + num_female_colleagues_bob_circle=num_female_colleagues_bob_circle, + ) + + return {"question": question, "answer": answer, "metadata": {"task_type": task_type.value}} + + def __getitem__(self, idx: int) -> dict: + rng = Random(self.seed + idx) + return self._get_aiw(rng) + + +register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig) diff --git a/reasoning_gym/logic/syllogisms.py b/reasoning_gym/logic/syllogisms.py index 0af9d1b2..a5bbb219 100644 --- a/reasoning_gym/logic/syllogisms.py +++ b/reasoning_gym/logic/syllogisms.py @@ -206,6 +206,13 @@ class SyllogismDataset(ProceduralDataset): return False + def _format_quantifier_statement(self, quantifier: Quantifier, subject: Term, predicate: Term) -> str: + """Format a quantified statement in natural language""" + if quantifier == Quantifier.SOME_NOT: + return f"Some {subject.plural} are not {predicate.plural}" + else: + return f"{quantifier.value} {subject.plural} are {predicate.plural}" + def _generate_syllogism(self, rng: Random) -> dict: """Generate a single syllogism problem""" # Select three different terms @@ -226,9 +233,9 @@ class SyllogismDataset(ProceduralDataset): conclusion = (rng.choice(quantifiers), terms[0], terms[2]) # Format the syllogism as text - premise1_text = f"{premise1[0].value} {premise1[1].plural} are {premise1[2].plural}" - premise2_text = f"{premise2[0].value} {premise2[1].plural} are {premise2[2].plural}" - conclusion_text = f"{conclusion[0].value} {conclusion[1].plural} are {conclusion[2].plural}" + premise1_text = self._format_quantifier_statement(premise1[0], premise1[1], premise1[2]) + premise2_text = self._format_quantifier_statement(premise2[0], premise2[1], premise2[2]) + conclusion_text = self._format_quantifier_statement(conclusion[0], conclusion[1], conclusion[2]) question = ( f"Consider these statements:\n" diff --git a/requirements-dev.txt b/requirements-dev.txt index 8b1c25a4..b96fc1c2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,5 @@ pytest>=8.3.4 +pytest-cov>=6.0.0 black>=24.10.0 isort>=5.13.2 flake8>=7.1.1 diff --git a/tests/test_advanced_geometry.py b/tests/test_advanced_geometry.py new file mode 100644 index 00000000..9eec1b36 --- /dev/null +++ b/tests/test_advanced_geometry.py @@ -0,0 +1,83 @@ +import pytest + +from reasoning_gym.geometry.advanced_geometry import AdvancedGeometryConfig, AdvancedGeometryDataset + + +def test_advanced_geometry_config_validation(): + """Test that invalid configs raise appropriate errors.""" + # min_coord >= max_coord + with pytest.raises(AssertionError): + config = AdvancedGeometryConfig(min_coord=5, max_coord=5) + config.validate() + + with pytest.raises(AssertionError): + config = AdvancedGeometryConfig(min_coord=10, max_coord=0) + config.validate() + + # size <= 0 + with pytest.raises(AssertionError): + config = AdvancedGeometryConfig(size=0) + config.validate() + + # Empty task_types + with pytest.raises(AssertionError): + config = AdvancedGeometryConfig(task_types=[]) + config.validate() + + +def test_advanced_geometry_dataset_deterministic(): + """Test the dataset generates the same items with the same seed.""" + config = AdvancedGeometryConfig(min_coord=-5, max_coord=5, size=5, seed=42) + dataset1 = AdvancedGeometryDataset(config) + dataset2 = AdvancedGeometryDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i], ( + f"Item mismatch at index {i} for same seed. " f"Dataset1: {dataset1[i]} vs Dataset2: {dataset2[i]}" + ) + + +def test_advanced_geometry_dataset_items(): + """Test basic properties of generated items.""" + config = AdvancedGeometryConfig(min_coord=-3, max_coord=3, size=5, seed=123) + dataset = AdvancedGeometryDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check structure + assert isinstance(item, dict), "Generated item must be a dictionary." + assert "question" in item, "Item must contain a 'question' key." + assert "answer" in item, "Item must contain an 'answer' key." + assert "metadata" in item, "Item must contain a 'metadata' key." + + # Basic metadata checks + metadata = item["metadata"] + assert ( + "A" in metadata and "B" in metadata and "C" in metadata + ), "Metadata should contain coordinates for points A, B, and C." + + # Check answer format depending on task type + # For angle measure tasks, answer should end with '°' + if "angle_measure" in item["question"].lower() or "angle at" in item["question"].lower(): + assert item["answer"].endswith("°"), f"Expected angle measure in degrees, got {item['answer']}" + + +def test_advanced_geometry_dataset_iteration(): + """Test that iteration respects dataset size and is repeatable.""" + config = AdvancedGeometryConfig(min_coord=-2, max_coord=2, size=3, seed=999) + dataset = AdvancedGeometryDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly 'size' items." + + # Test list conversion + items_list = list(dataset) + assert len(items_list) == config.size, "List conversion should yield exactly 'size' items." + + # Test multiple iterations produce the same results + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield the same items." diff --git a/tests/test_aiw.py b/tests/test_aiw.py new file mode 100644 index 00000000..5a2fb454 --- /dev/null +++ b/tests/test_aiw.py @@ -0,0 +1,96 @@ +import pytest + +from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType + + +def test_aiw_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(male_names=[]) # Empty male names + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(female_names=[]) # Empty female names + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(female_names=["Mary", "Jane"]) # No Alice + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(task_types=[]) # No task types + config.validate() + + +def test_aiw_deterministic(): + """Test that dataset generates same items with same seed""" + config = AliceInWonderlandConfig(seed=42, size=10) + dataset1 = AliceInWonderlandDataset(config) + dataset2 = AliceInWonderlandDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_aiw_items(): + """Test basic properties of generated items""" + config = AliceInWonderlandConfig(size=50, seed=42) + dataset = AliceInWonderlandDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Verify answer is numeric and positive + answer = int(item["answer"]) + assert answer > 0 + + # Verify question contains at least one female name + female_names = config.female_names + assert any(name in item["question"] for name in female_names) + + # Verify question task type characteristics + task_type = item["metadata"]["task_type"] + if task_type == TaskType.SIBLINGS.value: + assert any(phrase in item["question"] for phrase in ["brothers", "sisters"]) + elif task_type == TaskType.FRIENDS.value: + assert "friends" in item["question"] + elif task_type == TaskType.COLLEAGUES: + assert "colleagues" in item["question"] + + +def test_aiw_iteration(): + """Test that iteration works correctly""" + config = AliceInWonderlandConfig(size=5, seed=42) + dataset = AliceInWonderlandDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size + + # Test list conversion + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same results + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items + + +def test_aiw_random_ranges(): + """Test that generated numbers stay within expected ranges""" + config = AliceInWonderlandConfig(size=30, seed=42, max_entities=12) + dataset = AliceInWonderlandDataset(config) + + for item in dataset: + question = item["question"] + numbers = [int(n) for n in question.split() if n.isdigit()] + + # Check all numbers are in reasonable range (1-6 as per implementation) + assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}" diff --git a/tests/test_arc_1d.py b/tests/test_arc_1d.py new file mode 100644 index 00000000..1679fb50 --- /dev/null +++ b/tests/test_arc_1d.py @@ -0,0 +1,107 @@ +import pytest + +from reasoning_gym.cognition import Arc1DConfig, Arc1DDataset + + +def test_arc_1d_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = Arc1DConfig(min_size=0) + config.validate() + + with pytest.raises(AssertionError): + config = Arc1DConfig(min_size=30, max_size=20) + config.validate() + + with pytest.raises(AssertionError): + config = Arc1DConfig(num_train=0) + config.validate() + + +def test_arc_1d_deterministic(): + """Test that dataset generates same items with same seed""" + config = Arc1DConfig(seed=42, size=10) + dataset1 = Arc1DDataset(config) + dataset2 = Arc1DDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_arc_1d_items(): + """Test basic properties of generated items""" + config = Arc1DConfig(min_size=10, max_size=15, num_train=2, size=50, seed=42) + dataset = Arc1DDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata contents + metadata = item["metadata"] + assert "task_name" in metadata + assert "size" in metadata + assert "train_examples" in metadata + assert "test_example" in metadata + + # Verify size constraints + assert config.min_size <= metadata["size"] <= config.max_size + + # Check training examples + train_examples = metadata["train_examples"] + assert len(train_examples) == config.num_train + for example in train_examples: + assert "input" in example + assert "output" in example + assert len(example["input"]) == metadata["size"] + assert len(example["output"]) == metadata["size"] + + # Check test example + test_example = metadata["test_example"] + assert "input" in test_example + assert "output" in test_example + assert len(test_example["input"]) == metadata["size"] + assert len(test_example["output"]) == metadata["size"] + + +def test_arc_1d_iteration(): + """Test that iteration respects dataset size""" + config = Arc1DConfig(size=5, seed=42) # Small size for testing + dataset = Arc1DDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test list conversion + items = list(dataset) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test multiple iterations + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield same items" + + +def test_arc_1d_scoring(): + """Test answer scoring logic""" + config = Arc1DConfig(size=1, seed=42) + dataset = Arc1DDataset(config) + entry = dataset[0] + + # Test exact match + assert dataset.score_answer(entry["answer"], entry) == 1.0 + + # Test partial match (answer contained within response) + assert dataset.score_answer(f"The answer is: {entry['answer']}", entry) == 0.5 + + # Test incorrect answer + assert dataset.score_answer("wrong answer", entry) == 0.01 + + # Test None answer + assert dataset.score_answer(None, entry) == 0.0 diff --git a/tests/test_arc_1d_tasks.py b/tests/test_arc_1d_tasks.py new file mode 100644 index 00000000..98d9d2ef --- /dev/null +++ b/tests/test_arc_1d_tasks.py @@ -0,0 +1,122 @@ +import random + +import pytest + +from reasoning_gym.cognition.arc_1d_tasks import ( + task_block_and_noise_remove, + task_block_and_noise_remove_inside, + task_block_scale_to_dot, + task_block_touch_dot, + task_block_touch_dot_n_pix, + task_change_to_five, + task_color_left_half_blocks, + task_copy_block_to_dots, + task_copy_block_to_dots_colors, + task_duplicate_block_from_seeds, + task_fill_from_pixel, + task_fill_until_collision, + task_gravity, + task_gravity_antigravity, + task_gravity_counting, + task_gravity_one_step, + task_gravity_weighted_colors, + task_identity, + task_inverse, + task_mark_size_two_blocks, + task_mirror, + task_move_block_by_own_size, + task_move_n_pix, + task_move_n_pix_wrapped, + task_paint_biggest_block, + task_recolor_blocks_by_size, + task_recolor_blocks_from_palette, + task_reflect_block_around_dot, + task_reflect_block_with_border_pixel, + task_reflect_block_with_border_pixel_random, + task_repeat_pattern_full, + task_sort_blocks_by_size, + task_sort_complete_sequence, + task_two_points_and_fill, +) + + +def test_all_arc_1d_tasks(): + """Test that all ARC 1D task functions can be executed without exceptions.""" + rng = random.Random(42) # Fixed seed for reproducibility + size = 20 # Reasonable size for testing + + # Test all task functions + # Fixed move_pix value for testing + move_pix = 2 + + # Test task augmentation functions + base_task = task_move_n_pix(rng, size, move_pix, True) + assert base_task is not None + + mirrored = task_mirror(base_task) + assert mirrored is not None + assert mirrored["input"] == list(reversed(base_task["input"])) + assert mirrored["output"] == list(reversed(base_task["output"])) + + inversed = task_inverse(base_task) + assert inversed is not None + assert inversed["input"] == base_task["output"] + assert inversed["output"] == base_task["input"] + + identical = task_identity(base_task) + assert identical is not None + assert identical == base_task + + tasks = [ + (task_move_n_pix, {"move_pix": move_pix, "solid": True}), + (task_move_n_pix_wrapped, {"move_pix": move_pix, "solid": True}), + (task_gravity, {}), + (task_gravity_counting, {}), + (task_gravity_antigravity, {}), + (task_block_touch_dot, {}), + (task_block_touch_dot_n_pix, {"move_pix": move_pix}), + (task_block_scale_to_dot, {}), + (task_two_points_and_fill, {}), + (task_reflect_block_with_border_pixel, {}), + (task_reflect_block_with_border_pixel_random, {}), + (task_reflect_block_around_dot, {}), + (task_block_and_noise_remove, {}), + (task_block_and_noise_remove_inside, {}), + (task_copy_block_to_dots, {}), + (task_copy_block_to_dots_colors, {}), + (task_paint_biggest_block, {}), + (task_sort_blocks_by_size, {}), + (task_sort_complete_sequence, {}), + (task_recolor_blocks_by_size, {}), + (task_gravity_one_step, {}), + (task_move_block_by_own_size, {}), + (task_change_to_five, {}), + (task_recolor_blocks_from_palette, {}), + (task_duplicate_block_from_seeds, {}), + (task_fill_from_pixel, {}), + (task_mark_size_two_blocks, {}), + (task_fill_until_collision, {}), + (task_repeat_pattern_full, {}), + (task_gravity_weighted_colors, {}), + (task_color_left_half_blocks, {}), + ] + + for task_func, kwargs in tasks: + # Try multiple times as some functions might return None for certain inputs + success = False + for _ in range(10): # Try up to 10 times + try: + result = task_func(rng, size, **kwargs) + if result is not None: + success = True + # Basic structure checks + assert isinstance(result, dict) + assert "input" in result + assert "output" in result + assert len(result["input"]) == size + assert len(result["output"]) == size + break + except Exception as e: + pytest.fail(f"Task {task_func.__name__} failed with error: {str(e)}") + + assert success, f"Task {task_func.__name__} always returned None in 10 attempts" diff --git a/tests/test_base_conversion.py b/tests/test_base_conversion.py index 7c8edf1e..8017d74a 100644 --- a/tests/test_base_conversion.py +++ b/tests/test_base_conversion.py @@ -65,9 +65,20 @@ def test_base_conversion_dataset_items(): # Verify conversion correctness decimal_value = item["metadata"]["decimal_value"] target_base = item["metadata"]["target_base"] - expected = format(decimal_value, "x" if target_base == 16 else "b" if target_base == 2 else "").strip() - if target_base not in (2, 16): - expected = format(decimal_value, f"{target_base}x").lower().strip() + + # Use same conversion logic as implementation + if target_base == 16: + expected = format(decimal_value, "x") + elif target_base == 2: + expected = format(decimal_value, "b") + else: + # Manual conversion for other bases + n = decimal_value + digits = [] + while n: + digits.append(int(n % target_base)) + n //= target_base + expected = "".join(str(d) if d < 10 else chr(ord("a") + d - 10) for d in reversed(digits) or [0]) assert item["answer"] == expected @@ -83,6 +94,25 @@ def test_base_conversion_dataset_iteration(): assert items == list(dataset) +def test_base_conversion_validity(): + """Test that generated numbers are valid for their bases""" + config = BaseConversionConfig(min_base=2, max_base=36, min_value=0, max_value=1000, size=100, seed=42) + dataset = BaseConversionDataset(config) + + def is_valid_for_base(num_str: str, base: int) -> bool: + valid_chars = "0123456789abcdefghijklmnopqrstuvwxyz"[:base] + return all(c in valid_chars for c in num_str.lower()) + + for i in range(len(dataset)): + item = dataset[i] + assert is_valid_for_base( + item["metadata"]["source_repr"], item["metadata"]["source_base"] + ), f"Invalid source number {item['metadata']['source_repr']} for base {item['metadata']['source_base']}" + assert is_valid_for_base( + item["metadata"]["target_repr"], item["metadata"]["target_base"] + ), f"Invalid target number {item['metadata']['target_repr']} for base {item['metadata']['target_base']}" + + def test_base_conversion_special_bases(): """Test conversion between special bases (binary, hex)""" config = BaseConversionConfig( diff --git a/tests/test_arithmetic.py b/tests/test_basic_arithmetic.py similarity index 100% rename from tests/test_arithmetic.py rename to tests/test_basic_arithmetic.py diff --git a/tests/test_calendar_arithmetic.py b/tests/test_calendar_arithmetic.py new file mode 100644 index 00000000..87f84781 --- /dev/null +++ b/tests/test_calendar_arithmetic.py @@ -0,0 +1,198 @@ +import calendar +import math +from datetime import date + +import pytest + +from reasoning_gym.arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset + +WEEKDAYS = [ + "Monday", + "Tuesday", + "Wednesday", + "Thursday", + "Friday", + "Saturday", + "Sunday", +] + +WEEKDAY_TASKS = { + "weekday_offset", + "weekday_of_date_from_first_day", + "weekday_of_date", +} +NUMERIC_TASKS = { + "count_days", + "count_business_days", +} +DAY_TASKS = {"recurring_event_day"} +BOOLEAN_TASKS = {"is_leap_year"} +CALENDAR_TASKS = WEEKDAY_TASKS | NUMERIC_TASKS | DAY_TASKS | BOOLEAN_TASKS + + +def test_calendar_config_validation(): + """Test that invalid CalendarArithmeticConfig parameters raise appropriate errors.""" + with pytest.raises(ValueError): + config = CalendarArithmeticConfig(year=0) + config.validate() + + with pytest.raises(ValueError): + config = CalendarArithmeticConfig(size=0) + config.validate() + + with pytest.raises(ValueError): + config = CalendarArithmeticConfig(seed="not_an_int") + config.validate() + + with pytest.raises(ValueError): + config = CalendarArithmeticConfig(tasks=["invalid_task"]) + + +def test_calendar_deterministic(): + """Test that a dataset with a fixed seed produces the same items.""" + config = CalendarArithmeticConfig(year=2024, seed=42, size=10) + ds1 = CalendarArithmeticDataset(config) + ds2 = CalendarArithmeticDataset(config) + + for i in range(len(ds1)): + assert ds1[i] == ds2[i] + + +def test_calendar_item_structure(): + """Test that dataset items have the correct structure and fields.""" + config = CalendarArithmeticConfig(year=2024, seed=42, size=50) + dataset = CalendarArithmeticDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert all(key in item for key in ["question", "answer", "metadata"]) + + assert isinstance(item["question"], str) and len(item["question"]) > 0 + assert isinstance(item["answer"], str) and len(item["answer"]) > 0 + assert "task" in item["metadata"] + assert item["metadata"]["task"] in CALENDAR_TASKS + + +def test_calendar_answer_format(): + """Test that answers have the correct format based on task type.""" + config = CalendarArithmeticConfig(year=2024, seed=42, size=100) + dataset = CalendarArithmeticDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + task = item["metadata"]["task"] + answer = item["answer"] + + if task in WEEKDAY_TASKS: + assert answer in WEEKDAYS + + elif task in NUMERIC_TASKS: + try: + num = int(answer) + assert num >= 0, f"task {task} produced a negative count: {num}" + except ValueError: + pytest.fail(f"task {task} produced a non-integer answer: {answer}") + + elif task in BOOLEAN_TASKS: + assert answer in ["Yes", "No"] + + elif task in DAY_TASKS: + try: + num = int(answer) + year = item["metadata"]["year"] + month = item["metadata"]["month"] + _, last_day = calendar.monthrange(year, month) + assert 1 <= num <= last_day + except ValueError: + pytest.fail(f"task {task} produced a day outside expected range (1-{last_day}): {answer}") + + +def test_scoring_function(): + """Test scoring function for different answer types.""" + config = CalendarArithmeticConfig(year=2024, seed=42, size=1) + dataset = CalendarArithmeticDataset(config) + + weekday_item = {"answer": "Monday", "metadata": {"task": "weekday_offset"}} + + assert dataset.score_answer("Monday", weekday_item) == 1.0 + assert dataset.score_answer("Tuesday", weekday_item) == 0.1 + assert dataset.score_answer("It is Monday", weekday_item) == 0.0 + assert dataset.score_answer("no weekday here", weekday_item) == 0.0 + assert dataset.score_answer(None, weekday_item) == 0.0 + + numeric_item = {"answer": "10", "metadata": {"task": "count_business_days"}} + assert dataset.score_answer("10", numeric_item) == 1.0 + assert dataset.score_answer("15", numeric_item) == pytest.approx(math.exp(-5 * 0.5)) + assert dataset.score_answer("no number", numeric_item) == 0.0 + assert dataset.score_answer(None, numeric_item) == 0.0 + + boolean_item = {"answer": "Yes", "metadata": {"task": "is_leap_year"}} + assert dataset.score_answer("Yes", boolean_item) == 1.0 + assert dataset.score_answer("yes", boolean_item) == 1.0 + assert dataset.score_answer("nyes", boolean_item) == 0.0 + assert dataset.score_answer(None, boolean_item) == 0.0 + + +def test_calendar_date_consistency(): + """Test that dates in metadata are consistent with config year.""" + config = CalendarArithmeticConfig(year=2024, seed=42, size=50) + dataset = CalendarArithmeticDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + task = item["metadata"]["task"] + + if task == "weekday_offset": + start_date = date.fromisoformat(item["metadata"]["start_date"]) + assert start_date.year == config.year + + elif task in {"weekday_of_date_from_first_day", "weekday_of_date"}: + target_date = date.fromisoformat(item["metadata"]["target_date"]) + assert target_date.year == config.year + + elif task in {"count_business_days", "count_days"}: + start_date = date.fromisoformat(item["metadata"]["start_date"]) + end_date = date.fromisoformat(item["metadata"]["end_date"]) + assert start_date.year == config.year + assert end_date.year == config.year + + elif task == "recurring_event_day": + meta_year = item["metadata"]["year"] + month = item["metadata"]["month"] + answer = int(item["answer"]) + assert meta_year == config.year + assert 1 <= month <= 12 + if answer != -1: + _, last_day = calendar.monthrange(meta_year, month) + assert 1 <= answer <= last_day + + elif task == "is_leap_year": + year = item["metadata"]["year"] + assert config.year - 200 <= year <= config.year + 200 + is_leap_metadata = item["metadata"]["is_leap"] + computed_is_leap = calendar.isleap(year) + assert is_leap_metadata == computed_is_leap + + +def test_calendar_iteration(): + """Test that dataset iteration works correctly and is deterministic.""" + config = CalendarArithmeticConfig(year=2024, seed=42, size=5) + dataset = CalendarArithmeticDataset(config) + + items = [item for item in dataset] + assert len(items) == config.size + + first_iter = list(dataset) + second_iter = list(dataset) + assert first_iter == second_iter + + +def test_task_case_sensitivity(): + """Test that task names are case-insensitive.""" + tasks = ["WEEKDAY_OFFSET", "Count_Business_Days"] + config = CalendarArithmeticConfig(tasks=tasks, size=10) + dataset = CalendarArithmeticDataset(config) + + for item in dataset: + assert item["metadata"]["task"] in [t.lower() for t in tasks] diff --git a/tests/test_countdown.py b/tests/test_countdown.py index e426caf2..e78a69ab 100644 --- a/tests/test_countdown.py +++ b/tests/test_countdown.py @@ -64,6 +64,16 @@ def test_countdown_game_items(): # Verify expression evaluates correctly expr = item["metadata"]["expression"] + + # check score + assert dataset.score_answer(answer=expr, metadata=item["metadata"]) == 1.0 # correct answer + assert dataset.score_answer(answer="45+2", metadata=item["metadata"]) == 0.05 # wrong answer but an attempt + assert ( + dataset.score_answer(answer="a wrong solution", metadata=item["metadata"]) == 0.01 + ) # wrong answer but incorrectly formatted + assert dataset.score_answer(answer="", metadata=item["metadata"]) == 0.01 # wrong answer but empty string + assert dataset.score_answer(answer=None, metadata=item["metadata"]) == 0.0 # no answer + try: result = eval(expr) # Safe here since we control expression generation assert result == item["metadata"]["target"] diff --git a/tests/test_intermediate_integration.py b/tests/test_intermediate_integration.py new file mode 100644 index 00000000..df62ea76 --- /dev/null +++ b/tests/test_intermediate_integration.py @@ -0,0 +1,144 @@ +"""Tests for intermediate integration task generation""" + +import pytest +import sympy +from sympy.parsing.sympy_parser import parse_expr + +from reasoning_gym.algebra.intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset + + +def test_intermediate_integration_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(problem_types=["invalid_problem_type"]) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(substitution_types=["invalid_substitution_type"]) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(by_parts_types=["invalid_by_parts_type"]) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(linear_lower_bound=2, linear_upper_bound=1) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(linear_lower_bound=0) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(min_linear_degree=5, max_linear_degree=1) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(min_linear_degree=0) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(outer_constant_min=5, outer_constant_max=1) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(outer_constant_min=0) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(min_poly_degree=5, max_poly_degree=1) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(min_poly_degree=0) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(symbols=("x", "y")) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(operators=("+", "-", "*", "/")) + config.validate() + + +def test_intermediate_integration_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = IntermediateIntegrationConfig(seed=42, size=10) + dataset1 = IntermediateIntegrationDataset(config) + dataset2 = IntermediateIntegrationDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_intermediate_integration_dataset_items(): + """Test that dataset items are valid""" + config = IntermediateIntegrationConfig(seed=42, size=10) + dataset = IntermediateIntegrationDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + assert "integrand" in item["metadata"] + assert "problem_type" in item["metadata"] + assert "variable" in item["metadata"] + assert "type" in item["metadata"] + + # verify answer is mathematical expression + answer = item["answer"] + answer = answer.replace(" + C", "") + assert isinstance(parse_expr(answer), sympy.Expr) + + +def test_verify_answer(): + config = IntermediateIntegrationConfig(seed=42) + dataset = IntermediateIntegrationDataset(config) + for i in range(len(dataset)): + item = dataset[i] + score = dataset.score_answer(item["answer"], item["metadata"]) + assert score == 1.0 + + +def test_score_answer_cases(): + """Test various answer scoring scenarios""" + config = IntermediateIntegrationConfig(seed=42) + dataset = IntermediateIntegrationDataset(config) + x = sympy.Symbol("x") + X = sympy.Symbol("X") + + # Test cases: (answer, metadata, expected_score) + test_cases = [ + # Correct answers + ("x**2 + C", {"variable": "x", "integrand": "2*x"}, 1.0), + ("X**3 - 5*X + C", {"variable": "X", "integrand": "3*X**2 - 5"}, 1.0), + ("sin(x) + C", {"variable": "x", "integrand": "cos(x)"}, 1.0), + # Correct without explicit constant + ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0), + ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0), + # Incorrect but properly formatted + ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05), + ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05), + # Malformed expressions + ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01), + ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01), + # Empty answer + ("", {"variable": "x", "integrand": "2*x"}, 0.01), + # Case sensitivity + ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05), + ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05), + # Alternative constant notation + ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0), + ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0), + # Simplification required + ("x**2 + C + 5 - 5", {"variable": "x", "integrand": "2*x"}, 1.0), + ("(x**3)/3 - 2*x + C", {"variable": "x", "integrand": "x**2 - 2"}, 1.0), + ] + + for answer, metadata, expected in test_cases: + score = dataset.score_answer(answer, metadata) + assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}" diff --git a/tests/test_n_queens.py b/tests/test_n_queens.py new file mode 100644 index 00000000..16911220 --- /dev/null +++ b/tests/test_n_queens.py @@ -0,0 +1,143 @@ +"""Tests for N Queens puzzle generation""" + +import pytest + +from reasoning_gym.games.n_queens import NQueensConfig, NQueensDataset + + +def test_nqueens_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = NQueensConfig(n=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = NQueensConfig(n=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = NQueensConfig(n=5, min_remove=5, max_remove=4) # max < min + config.validate() + + with pytest.raises(AssertionError): + config = NQueensConfig(n=5, min_remove=3, max_remove=6) # n < max + config.validate() + + +def test_nqueens_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = NQueensConfig(seed=42, size=10) + dataset1 = NQueensDataset(config) + dataset2 = NQueensDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_nqueens_dataset_items(): + """Test basic properties of generated items""" + config = NQueensConfig(n=8, min_remove=1, max_remove=7, size=10, seed=42) + dataset = NQueensDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "puzzle" in item["metadata"] + assert "solutions" in item["metadata"] + assert "num_removed" in item["metadata"] + + puzzle = item["metadata"]["puzzle"] + solutions = item["metadata"]["solutions"] + num_removed = item["metadata"]["num_removed"] + + # Verify board dimensions + assert len(puzzle) == 8 + assert all(len(row) == 8 for row in puzzle) + for board in solutions: + assert len(board) == 8 + assert all(len(row) == 8 for row in board) + + # Verify empty cell count + removed_count = len(puzzle) - sum(1 for row in puzzle for cell in row if cell == "Q") + assert config.min_remove <= removed_count <= config.max_remove + assert removed_count == num_removed + + # Verify solution validity + for board in solutions: + assert is_valid_solution(board) + + # Verify puzzle matches solution where filled + for i in range(8): + for j in range(8): + if puzzle[i][j] == "Q": + assert puzzle[i][j] == board[i][j] + + +def test_nqueens_dataset_iteration(): + """Test that iteration respects dataset size""" + config = NQueensConfig(size=5, seed=42) + dataset = NQueensDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_nqueens_board_generation(): + """Test that generated boards are valid""" + config = NQueensConfig(n=10, size=5, seed=42) + dataset = NQueensDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + for board in item["metadata"]["solutions"]: + assert is_valid_solution(board) + + +def test_nqueens_score_answer(): + """Test the score_answer method""" + config = NQueensConfig(n=8, size=10, seed=42) + dataset = NQueensDataset(config) + + # Test a few items + for i in range(len(dataset)): + item = dataset[i] + + # Test correct answer gets score 1.0 + valid_answer = item["metadata"]["valid_answers"][0] + assert dataset.score_answer(valid_answer, item) == 1.0 + + # Test invalid answer gets score 0.01 + invalid_answer = "_ _ _ _\n_ _ _ _\n_ _ _ _\n_ _ _ _" + assert dataset.score_answer(invalid_answer, item) == 0.01 + + # Test None answer gets score 0.0 + assert dataset.score_answer(None, item) == 0.0 + + +def is_valid_solution(board: list[list[str]]) -> bool: + """Helper function to verify N Queens solution validity""" + rows, cols, diags, off_diags = set(), set(), set(), set() + n = len(board) + num_queens = 0 + + for r in range(n): + for c in range(n): + if board[r][c] == "Q": + num_queens += 1 + if r in rows or c in cols or (r + c) in diags or (r - c) in off_diags: + return False + rows.add(r) + cols.add(c) + diags.add(r + c) + off_diags.add(r - c) + + return num_queens == n diff --git a/tests/test_simple_geometry.py b/tests/test_simple_geometry.py new file mode 100644 index 00000000..804cf15a --- /dev/null +++ b/tests/test_simple_geometry.py @@ -0,0 +1,80 @@ +import pytest + +from reasoning_gym.geometry.simple_geometry import SimpleGeometryConfig, SimpleGeometryDataset + + +def test_simple_geometry_config_validation(): + """Test invalid configs raise appropriate errors.""" + # min_sides < 3 + with pytest.raises(AssertionError): + config = SimpleGeometryConfig(min_sides=2, max_sides=5) + config.validate() + + # max_sides < min_sides + with pytest.raises(AssertionError): + config = SimpleGeometryConfig(min_sides=4, max_sides=3) + config.validate() + + # Invalid angles + with pytest.raises(AssertionError): + config = SimpleGeometryConfig(min_angle=-10) + config.validate() + + with pytest.raises(AssertionError): + config = SimpleGeometryConfig(min_angle=10, max_angle=5) + config.validate() + + +def test_simple_geometry_dataset_deterministic(): + """Test the dataset generates the same items with the same seed.""" + config = SimpleGeometryConfig(seed=42, size=5, min_sides=3, max_sides=4) + dataset1 = SimpleGeometryDataset(config) + dataset2 = SimpleGeometryDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i], ( + f"Item mismatch at index {i} for same seed. " f"Dataset1: {dataset1[i]} vs Dataset2: {dataset2[i]}" + ) + + +def test_simple_geometry_dataset_items(): + """Test basic properties of generated items.""" + config = SimpleGeometryConfig(min_sides=3, max_sides=5, min_angle=10, max_angle=120, size=10, seed=123) + dataset = SimpleGeometryDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check structure + assert isinstance(item, dict), "Generated item must be a dictionary." + assert "question" in item, "Item must contain a 'question' key." + assert "answer" in item, "Item must contain an 'answer' key." + assert "metadata" in item, "Item must contain a 'metadata' key." + + metadata = item["metadata"] + assert "n_sides" in metadata, "Metadata should contain 'n_sides'." + assert "missing_angle_rounded" in metadata, "Metadata should contain the computed 'missing_angle_rounded'." + + # Check that the missing angle is a valid float or integer + missing_angle = float(item["answer"]) + assert missing_angle > 0, f"Missing angle should be positive, found {missing_angle}" + + +def test_simple_geometry_dataset_iteration(): + """Test that iteration respects dataset size and is repeatable.""" + config = SimpleGeometryConfig(min_sides=3, max_sides=4, size=5, seed=42) + dataset = SimpleGeometryDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly 'size' items." + + # Test list conversion + items_list = list(dataset) + assert len(items_list) == config.size, "List conversion should yield exactly 'size' items." + + # Test multiple iterations produce the same results + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield the same items." diff --git a/tests/test_simple_integration.py b/tests/test_simple_integration.py new file mode 100644 index 00000000..0de8ab36 --- /dev/null +++ b/tests/test_simple_integration.py @@ -0,0 +1,117 @@ +import pytest +import sympy +from sympy.parsing.sympy_parser import parse_expr + +from reasoning_gym.algebra.simple_integration import SimpleIntegrationConfig, SimpleIntegrationDataset + + +def test_simple_integration_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = SimpleIntegrationConfig(min_bounds=0) + config.validate() + + with pytest.raises(AssertionError): + config = SimpleIntegrationConfig(max_bounds=5, min_bounds=10) + config.validate() + + with pytest.raises(AssertionError): + config = SimpleIntegrationConfig(min_terms=-1) + config.validate() + + with pytest.raises(AssertionError): + config = SimpleIntegrationConfig(max_terms=2, min_terms=5) + config.validate() + + with pytest.raises(AssertionError): + config = SimpleIntegrationConfig(min_degree=-11) + config.validate() + + with pytest.raises(AssertionError): + config = SimpleIntegrationConfig(max_degree=3, min_degree=5) + config.validate() + + with pytest.raises(AssertionError): + config = SimpleIntegrationConfig(operators=("+", "-", "*")) + config.validate() + + +def test_simple_integration_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = SimpleIntegrationConfig(seed=42, size=10) + dataset1 = SimpleIntegrationDataset(config) + dataset2 = SimpleIntegrationDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_simple_integration_dataset_items(): + """Test that dataset items are valid""" + config = SimpleIntegrationConfig(seed=42, size=10) + dataset = SimpleIntegrationDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + assert "integrand" in item["metadata"] + assert "variable" in item["metadata"] + assert "expected_answer_expression" in item["metadata"] + + # Verify answer is a mathematical expression + answer = item["answer"] + answer = answer.replace(" + C", "") + assert isinstance(parse_expr(answer), sympy.Expr) + + +def test_verify_answer(): + config = SimpleIntegrationConfig(seed=42) + dataset = SimpleIntegrationDataset(config) + for i in range(len(dataset)): + item = dataset[i] + score = dataset.score_answer(item["answer"], item["metadata"]) + assert score == 1.0 + + +def test_score_answer_cases(): + """Test various answer scoring scenarios""" + config = SimpleIntegrationConfig(seed=42) + dataset = SimpleIntegrationDataset(config) + x = sympy.Symbol("x") + X = sympy.Symbol("X") + + # Test cases: (answer, metadata, expected_score) + test_cases = [ + # Correct answers + ("x**2 + C", {"variable": "x", "integrand": "2*x"}, 1.0), + ("X**3 - 5*X + C", {"variable": "X", "integrand": "3*X**2 - 5"}, 1.0), + ("sin(x) + C", {"variable": "x", "integrand": "cos(x)"}, 1.0), + # Correct without explicit constant + ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0), + ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0), + # Incorrect but properly formatted + ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05), + ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05), + # Malformed expressions + ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01), + ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01), + # Empty answer + ("", {"variable": "x", "integrand": "2*x"}, 0.01), + # Case sensitivity + ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05), + ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05), + # Alternative constant notation + ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0), + ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0), + # Simplification required + ("x**2 + C + 5 - 5", {"variable": "x", "integrand": "2*x"}, 1.0), + ("(x**3)/3 - 2*x + C", {"variable": "x", "integrand": "x**2 - 2"}, 1.0), + ] + + for answer, metadata, expected in test_cases: + score = dataset.score_answer(answer, metadata) + assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}" diff --git a/tests/test_time_intervals.py b/tests/test_time_intervals.py new file mode 100644 index 00000000..4e95f778 --- /dev/null +++ b/tests/test_time_intervals.py @@ -0,0 +1,113 @@ +from datetime import date, datetime + +import pytest + +from reasoning_gym.arithmetic import TimeIntervalsConfig, TimeIntervalsDataset + + +def test_time_intervals_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = TimeIntervalsConfig(size=0) + config.validate() + + with pytest.raises(AssertionError): + config = TimeIntervalsConfig(max_time_difference_seconds=0) + config.validate() + + with pytest.raises(AssertionError): + config = TimeIntervalsConfig(max_date_difference_days=0) + config.validate() + + with pytest.raises(AssertionError): + config = TimeIntervalsConfig(min_date=date(2024, 1, 1), max_date=date(2023, 1, 1)) + config.validate() + + +def test_time_intervals_deterministic(): + """Test that dataset generates same items with same seed""" + config = TimeIntervalsConfig(seed=42, size=10) + dataset1 = TimeIntervalsDataset(config) + dataset2 = TimeIntervalsDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_time_intervals_items(): + """Test basic properties of generated items""" + config = TimeIntervalsConfig( + size=100, + seed=42, + max_time_difference_seconds=3600, # 1 hour max + max_date_difference_days=10, + ) + dataset = TimeIntervalsDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + assert "task_type" in item["metadata"] + assert "start_time" in item["metadata"] + assert "end_time" in item["metadata"] + + +def test_time_intervals_scoring(): + """Test the answer scoring functionality""" + config = TimeIntervalsConfig(seed=42) + dataset = TimeIntervalsDataset(config) + + # Generate a sample item + item = dataset[0] + + # Test exact match + assert dataset.score_answer(item["answer"], item) == 1.0 + + # Test empty/None answers + assert dataset.score_answer(None, item) == 0.0 + assert dataset.score_answer("", item) == 0.0 + + # Test invalid format + assert dataset.score_answer("invalid", item) == 0.0 + + # Test close but not exact answers + task_type = item["metadata"]["task_type"] + if task_type == "date": + expected = int(item["answer"]) + # Test answer off by 1 day + score = dataset.score_answer(str(expected + 1), item) + assert 0 < score < 1 + elif task_type.startswith("time"): + # Test answer off by a few minutes + if ":" in item["answer"]: + parts = item["answer"].split(":") + hours = int(parts[0]) + minutes = (int(parts[1]) + 5) % 60 # Add 5 minutes + modified = f"{hours:02d}:{minutes:02d}" + if len(parts) > 2: + modified += ":" + parts[2] + score = dataset.score_answer(modified, item) + assert 0 < score < 1 + + +def test_time_format_patterns(): + """Test that generated times match expected formats""" + config = TimeIntervalsConfig(seed=42, size=500) + dataset = TimeIntervalsDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + + start_dt = item["metadata"]["start_time"] + end_dt = item["metadata"]["end_time"] + + # Verify both are datetime objects + assert isinstance(start_dt, datetime) + assert isinstance(end_dt, datetime) + + # Verify end is after start + assert end_dt >= start_dt, item["question"] + assert dataset.score_answer(item["answer"], item) == 1.0 diff --git a/tests/test_tower_of_hanoi.py b/tests/test_tower_of_hanoi.py new file mode 100644 index 00000000..a4228bc3 --- /dev/null +++ b/tests/test_tower_of_hanoi.py @@ -0,0 +1,228 @@ +"""Tests for Tower of Hanoi puzzle generation""" + +import re + +import pytest + +from reasoning_gym.games.tower_of_hanoi import HanoiConfig, HanoiDataset + + +def test_toh_config_validation(): + """Test that invalid configurations raise appropriate errors.""" + # Test negative number of disks + with pytest.raises(AssertionError): + config = HanoiConfig(min_disks=0) # At least 1 disk required + config.validate() + + # Test max_disks less than min_disks + with pytest.raises(AssertionError): + config = HanoiConfig(min_disks=5, max_disks=3) + config.validate() + + # Test min_pegs less than 3 + with pytest.raises(AssertionError): + config = HanoiConfig(min_pegs=2) + config.validate() + + # Test max_pegs less than min_pegs + with pytest.raises(AssertionError): + config = HanoiConfig(min_pegs=3, max_pegs=2) + config.validate() + + # Test invalid move configurations if any (assuming such validations exist) + # Add more tests based on the actual validation logic in HanoiConfig + + +def test_toh_dataset_deterministic(): + """Test that dataset generates the same items with the same seed.""" + config = HanoiConfig(seed=42, size=10) + dataset1 = HanoiDataset(config) + dataset2 = HanoiDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i], f"Mismatch found in instance {i} with seed 42." + + +def test_toh_dataset_items(): + """Test basic properties of generated items.""" + config = HanoiConfig(min_disks=3, max_disks=5, min_pegs=3, max_pegs=4, size=10, seed=42) + dataset = HanoiDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + + # Check item structure + assert isinstance(item, dict), f"Item {i} is not a dictionary." + assert "question" in item, f"Item {i} missing 'question' key." + assert "answer" in item, f"Item {i} missing 'answer' key." + assert "metadata" in item, f"Item {i} missing 'metadata' key." + + # Check metadata + metadata = item["metadata"] + assert "num_disks" in metadata, f"Item {i} metadata missing 'num_disks'." + assert "num_pegs" in metadata, f"Item {i} metadata missing 'num_pegs'." + assert "start_peg" in metadata, f"Item {i} metadata missing 'start_peg'." + assert "target_peg" in metadata, f"Item {i} metadata missing 'target_peg'." + assert "auxiliary_pegs" in metadata, f"Item {i} metadata missing 'auxiliary_pegs'." + assert "solution_length" in metadata, f"Item {i} metadata missing 'solution_length'." + + num_disks = metadata["num_disks"] + num_pegs = metadata["num_pegs"] + start_peg = metadata["start_peg"] + target_peg = metadata["target_peg"] + auxiliary_pegs = metadata["auxiliary_pegs"] + solution_length = metadata["solution_length"] + + # Verify peg counts + assert num_pegs == len(metadata["auxiliary_pegs"]) + 2, f"Item {i} has inconsistent peg counts." + + # Verify solution_length consistency + assert solution_length == len( + item["answer"] + ), f"Item {i} metadata 'solution_length' does not match actual number of moves." + + # Optional: Additional checks like verifying that start and target pegs are distinct + assert start_peg != target_peg, f"Item {i} has identical start and target pegs." + + +def test_toh_move_validity(): + """Test that all moves in each problem instance are valid according to Tower of Hanoi rules.""" + config = HanoiConfig(min_disks=3, max_disks=5, min_pegs=3, max_pegs=4, size=10, seed=42) + dataset = HanoiDataset(config) + + for idx, instance in enumerate(dataset): + num_disks = instance["metadata"]["num_disks"] + num_pegs = instance["metadata"]["num_pegs"] + start_peg = instance["metadata"]["start_peg"] + target_peg = instance["metadata"]["target_peg"] + auxiliary_pegs = instance["metadata"]["auxiliary_pegs"] + pegs = list(range(1, num_pegs + 1)) + + # Initialize pegs_state: all disks start on the start peg + pegs_state = {peg: [] for peg in pegs} + for disk in range(num_disks, 0, -1): + pegs_state[start_peg].append(disk) + + # Iterate over each move and validate + for move_num, move in enumerate(instance["answer"], start=1): + disk, from_peg, to_peg = parse_move(move) + + # Check that from_peg exists + assert from_peg in pegs, f"Move {move_num} in Instance {idx} references non-existent from_peg {from_peg}." + + # Check that to_peg exists + assert to_peg in pegs, f"Move {move_num} in Instance {idx} references non-existent to_peg {to_peg}." + + # Check that from_peg is not empty + assert pegs_state[ + from_peg + ], f"Move {move_num} in Instance {idx} attempts to move from an empty Peg {from_peg}." + + # Check that the disk to move is on top of from_peg + top_disk = pegs_state[from_peg][-1] + assert disk == top_disk, ( + f"Move {move_num} in Instance {idx} attempts to move disk {disk} " + f"which is not on top of Peg {from_peg} (top disk: {top_disk})." + ) + + # Check that moving disk to to_peg does not violate size constraints + if pegs_state[to_peg]: + top_to_disk = pegs_state[to_peg][-1] + assert top_to_disk > disk, ( + f"Move {move_num} in Instance {idx} attempts to place disk {disk} " + f"on top of smaller disk {top_to_disk} on Peg {to_peg}." + ) + + # Perform the move + pegs_state[from_peg].pop() + pegs_state[to_peg].append(disk) + + +def test_toh_final_state_correct(): + """Test that the final state of each problem instance has all disks on the target peg in correct order.""" + config = HanoiConfig(min_disks=3, max_disks=5, min_pegs=3, max_pegs=4, size=10, seed=42) + dataset = HanoiDataset(config) + + for idx, instance in enumerate(dataset): + num_disks = instance["metadata"]["num_disks"] + num_pegs = instance["metadata"]["num_pegs"] + start_peg = instance["metadata"]["start_peg"] + target_peg = instance["metadata"]["target_peg"] + auxiliary_pegs = instance["metadata"]["auxiliary_pegs"] + pegs = list(range(1, num_pegs + 1)) + + # Initialize pegs_state: all disks start on the start peg + pegs_state = {peg: [] for peg in pegs} + for disk in range(num_disks, 0, -1): + pegs_state[start_peg].append(disk) + + # Perform all moves + for move in instance["answer"]: + disk, from_peg, to_peg = parse_move(move) + pegs_state[from_peg].pop() + pegs_state[to_peg].append(disk) + + # After all moves, all disks should be on target peg in descending order + final_pegs = pegs_state[target_peg] + assert len(final_pegs) == num_disks, f"Instance {idx} does not have all disks on the target Peg {target_peg}." + + # Verify that disks are in correct order on target peg + expected_final = list(range(num_disks, 0, -1)) + assert final_pegs == expected_final, f"Instance {idx} has disks on Peg {target_peg} in incorrect order." + + # Ensure all other pegs are empty + for peg in pegs: + if peg != target_peg: + assert ( + len(pegs_state[peg]) == 0 + ), f"Instance {idx} has disks remaining on Peg {peg}, which should be empty." + + +def test_toh_dataset_iteration(): + """Test that iteration respects dataset size and multiple iterations yield the same items.""" + config = HanoiConfig(min_disks=3, max_disks=5, min_pegs=3, max_pegs=4, size=5, seed=42) + dataset = HanoiDataset(config) + + # Test dataset size + assert len(dataset) == config.size, f"Dataset size mismatch: expected {config.size}, got {len(dataset)}." + + # Collect items + items = list(dataset) + + # Test multiple iterations yield the same items + assert items == list(dataset), "Multiple iterations over the dataset do not yield the same items." + + +def parse_move(move_str: str) -> tuple: + """Parse a move string and extract disk number, from peg, and to peg. + + Args: + move_str (str): Move instruction, e.g., "Move disk 2 from Peg 1 to Peg 3". + + Returns: + tuple: (disk, from_peg, to_peg) + """ + pattern = r"Move disk (\d+) from Peg (\d+) to Peg (\d+)" + match = re.match(pattern, move_str) + assert match is not None, f"Move string '{move_str}' does not match the expected format." + disk = int(match.group(1)) + from_peg = int(match.group(2)) + to_peg = int(match.group(3)) + return disk, from_peg, to_peg + + +def is_valid_final_state(pegs_state: dict, target_peg: int, num_disks: int) -> bool: + """Verify that all disks are on the target peg in descending order. + + Args: + pegs_state (dict): Current state of the pegs. + target_peg (int): The target peg number. + num_disks (int): Total number of disks. + + Returns: + bool: True if valid, False otherwise. + """ + target_stack = pegs_state[target_peg] + if len(target_stack) != num_disks: + return False + return target_stack == list(range(num_disks, 0, -1))