diff --git a/GALLERY.md b/GALLERY.md index c842b17f..10d020be 100644 --- a/GALLERY.md +++ b/GALLERY.md @@ -41,6 +41,9 @@ This gallery shows examples from all available datasets using their default conf - [prime_factorization](#prime_factorization) - [propositional_logic](#propositional_logic) - [quantum_lock](#quantum_lock) +- [ransom_note](#ransom_note) +- [rearc](#rearc) +- [rotate_matrix](#rotate_matrix) - [rubiks_cube](#rubiks_cube) - [self_reference](#self_reference) - [sentence_reordering](#sentence_reordering) @@ -1005,7 +1008,6 @@ Metadata: {'words': ['eagerest', 'granitite', 'helium', 'nizam', 'nazim', 'strip ```` - ### gsm_symbolic Default configuration: ```python @@ -1959,6 +1961,417 @@ Metadata: {'difficulty': 10, 'solution_path': ['B', 'B', 'B', 'B', 'B', 'B', 'B' ```` +### ransom_note +Generates Ransom Note exercises with configurable difficulty + +Default configuration: +```python +max_note_length = 10 +max_magazine_length = 30 +p_solvable = 0.5 +size = 500 +seed = 42 +``` + +Example tasks: +```` +Example 1: +Question: Given two strings representing a ransom note and a magazine, return True if you can construct the ransom note using the letters in the magazine, and False otherwise. + +Each letter in the magazine string can only be used once in your ransom note. + +Ransom note: c +Magazine: kjjfnerbv + +Answer: False +Metadata: {'ransom_note': 'c', 'magazine': 'kjjfnerbv', 'solution': False, 'solvable': False} + +Example 2: +Question: Given two strings representing a ransom note and a magazine, return True if you can construct the ransom note using the letters in the magazine, and False otherwise. + +Each letter in the magazine string can only be used once in your ransom note. 
+ +Ransom note: pan +Magazine: pipmrxluyrkumtnaynmqosywf + +Answer: True +Metadata: {'ransom_note': 'pan', 'magazine': 'pipmrxluyrkumtnaynmqosywf', 'solution': True, 'solvable': True} + +Example 3: +Question: Given two strings representing a ransom note and a magazine, return True if you can construct the ransom note using the letters in the magazine, and False otherwise. + +Each letter in the magazine string can only be used once in your ransom note. + +Ransom note: yuothygge +Magazine: gpfslbehhhhagoutvejfoytuuyy + +Answer: True +Metadata: {'ransom_note': 'yuothygge', 'magazine': 'gpfslbehhhhagoutvejfoytuuyy', 'solution': True, 'solvable': True} + +```` + +### rearc +Default configuration: +```python +min_examples = 3 +max_examples = 5 +diff_lb = 0 +diff_ub = 0.2 +board_format_opts = BoardFormattingOptions(alphabet=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], col_delimiter=' ', row_delimiter='\n', array_brackets=False) +seed = 42 +size = 500 +``` + +Example tasks: +```` +Example 1: +Question: Find the common rule that maps an input grid to an output grid, given the examples below. + +Example 1: + +Input: +1 1 1 1 +1 1 1 1 +1 1 1 1 +1 1 1 1 +1 1 1 1 +1 1 1 1 +1 1 1 9 +Output: +9 9 9 9 +1 1 1 1 +9 9 9 9 +1 1 1 1 +1 9 9 9 +1 9 1 1 +1 9 1 9 + +Example 2: + +Input: +4 8 8 8 8 8 8 +8 8 8 8 8 8 8 +8 8 8 8 8 8 8 +8 8 8 8 8 8 8 +8 8 8 8 8 8 8 +Output: +4 8 4 8 4 8 4 +8 8 4 8 4 8 4 +4 4 4 8 4 8 4 +8 8 8 8 4 8 4 +4 4 4 4 4 8 4 + +Example 3: + +Input: +2 2 2 2 +2 2 2 2 +2 2 2 2 +2 2 2 2 +2 2 2 2 +2 2 2 2 +2 2 2 2 +5 2 2 2 +Output: +2 2 2 2 +5 5 5 5 +2 2 2 2 +5 5 5 5 +2 2 2 2 +5 5 5 2 +2 2 5 2 +5 2 5 2 + + + +Below is a test input grid. Predict the corresponding output grid by applying the rule you found. +Your final answer should just be the text output grid itself. 
+ +Input: +3 3 3 3 3 3 3 9 +3 3 3 3 3 3 3 3 +3 3 3 3 3 3 3 3 +3 3 3 3 3 3 3 3 +3 3 3 3 3 3 3 3 + +Answer: ((3, 9, 3, 9, 3, 9, 3, 9), (3, 9, 3, 9, 3, 9, 3, 3), (3, 9, 3, 9, 3, 9, 9, 9), (3, 9, 3, 9, 3, 3, 3, 3), (3, 9, 3, 9, 9, 9, 9, 9)) +Metadata: {'input': ((3, 3, 3, 3, 3, 3, 3, 9), (3, 3, 3, 3, 3, 3, 3, 3), (3, 3, 3, 3, 3, 3, 3, 3), (3, 3, 3, 3, 3, 3, 3, 3), (3, 3, 3, 3, 3, 3, 3, 3)), 'output': ((3, 9, 3, 9, 3, 9, 3, 9), (3, 9, 3, 9, 3, 9, 3, 3), (3, 9, 3, 9, 3, 9, 9, 9), (3, 9, 3, 9, 3, 3, 3, 3), (3, 9, 3, 9, 9, 9, 9, 9)), 'task_id': 'd22278a0', 'difficulty': {'rng': 0.07173948707162241, 'pso': 0.12314814814814816}} + +Example 2: +Question: Find the common rule that maps an input grid to an output grid, given the examples below. + +Example 1: + +Input: +6 6 6 6 6 6 6 6 +6 6 6 6 6 6 6 6 +6 6 9 6 6 6 9 6 +6 6 6 9 6 9 6 6 +6 6 6 6 9 6 6 6 +6 6 6 9 6 9 6 6 +6 6 9 6 6 6 9 6 +6 6 6 6 6 6 6 6 +6 6 6 6 6 6 6 6 +6 6 6 6 6 6 6 6 +6 6 6 6 6 6 6 6 +Output: +6 6 6 6 6 6 6 6 +6 6 6 6 6 6 6 6 +6 6 9 6 6 6 9 6 +6 6 6 9 6 9 6 6 +6 6 6 6 9 6 6 6 +6 6 6 9 6 9 6 6 +6 6 9 6 6 6 9 6 +6 6 6 6 6 6 6 6 +6 6 6 6 6 6 6 6 +6 6 6 6 6 6 6 6 +6 6 6 6 6 6 6 6 + +Example 2: + +Input: +5 5 5 5 5 5 5 5 5 5 +5 5 8 5 8 5 8 5 5 5 +5 5 5 5 5 5 5 5 5 5 +5 5 8 5 2 5 8 5 5 5 +5 5 5 5 5 5 5 5 5 5 +5 5 8 5 8 5 8 5 5 5 +5 5 5 5 5 5 5 5 5 5 +Output: +5 5 5 5 5 5 5 5 5 5 +5 5 8 5 8 5 8 5 5 5 +5 5 5 5 5 5 5 5 5 5 +5 5 8 5 2 5 8 5 5 5 +5 5 5 5 5 5 5 5 5 5 +5 5 8 5 8 5 8 5 5 5 +5 5 5 5 5 5 5 5 5 5 + +Example 3: + +Input: +1 1 1 1 1 1 1 1 1 +1 1 1 1 1 1 1 1 1 +1 1 1 2 1 2 1 1 1 +1 1 1 1 2 1 1 1 1 +1 1 1 2 1 2 1 1 1 +1 1 1 1 1 1 1 1 1 +1 1 1 1 1 1 1 1 1 +1 1 1 1 1 1 1 1 1 +1 1 1 1 1 1 1 1 1 +Output: +1 1 1 1 1 1 1 1 1 +1 1 1 1 1 1 1 1 1 +1 1 1 2 1 2 1 1 1 +1 1 1 1 2 1 1 1 1 +1 1 1 2 1 2 1 1 1 +1 1 1 1 1 1 1 1 1 +1 1 1 1 1 1 1 1 1 +1 1 1 1 1 1 1 1 1 +1 1 1 1 1 1 1 1 1 + +Example 4: + +Input: +7 7 7 7 7 7 7 7 7 7 +7 7 7 1 7 1 7 1 7 7 +7 7 7 7 7 7 7 7 7 7 +7 7 7 1 7 1 7 1 7 7 +7 7 7 7 7 7 7 7 7 7 +7 7 7 1 7 1 7 1 7 
7 +7 7 7 7 7 7 7 7 7 7 +7 7 7 7 7 7 7 7 7 7 +Output: +7 7 7 7 7 7 7 7 7 7 +7 7 7 1 7 1 7 1 7 7 +7 7 7 7 7 7 7 7 7 7 +7 7 7 1 7 1 7 1 7 7 +7 7 7 7 7 7 7 7 7 7 +7 7 7 1 7 1 7 1 7 7 +7 7 7 7 7 7 7 7 7 7 +7 7 7 7 7 7 7 7 7 7 + +Example 5: + +Input: +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 6 3 3 3 6 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 3 3 6 3 3 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 6 3 3 3 6 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 3 3 3 3 3 3 3 +Output: +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 6 3 3 3 6 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 3 3 6 3 3 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 6 3 3 3 6 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 3 3 3 3 3 3 3 +3 3 3 3 3 3 3 3 3 3 3 + + + +Below is a test input grid. Predict the corresponding output grid by applying the rule you found. +Your final answer should just be the text output grid itself. + +Input: +7 7 7 7 7 7 7 7 7 7 7 +7 7 7 7 7 7 7 8 7 7 7 +7 7 7 7 7 7 8 7 8 7 7 +7 7 7 7 7 8 7 8 7 8 7 +7 7 7 7 7 7 8 7 8 7 7 +7 7 7 7 7 7 7 8 7 7 7 +7 7 7 7 7 7 7 7 7 7 7 +7 7 7 7 7 7 7 7 7 7 7 + +Answer: ((7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7), (7, 7, 7, 7, 7, 7, 7, 8, 7, 7, 7), (7, 7, 7, 7, 7, 7, 8, 7, 8, 7, 7), (7, 7, 7, 7, 7, 8, 7, 8, 7, 8, 7), (7, 7, 7, 7, 7, 7, 8, 7, 8, 7, 7), (7, 7, 7, 7, 7, 7, 7, 8, 7, 7, 7), (7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7), (7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)) +Metadata: {'input': ((7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7), (7, 7, 7, 7, 7, 7, 7, 8, 7, 7, 7), (7, 7, 7, 7, 7, 7, 8, 7, 8, 7, 7), (7, 7, 7, 7, 7, 8, 7, 8, 7, 8, 7), (7, 7, 7, 7, 7, 7, 8, 7, 8, 7, 7), (7, 7, 7, 7, 7, 7, 7, 8, 7, 7, 7), (7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7), (7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)), 'output': ((7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7), (7, 7, 7, 7, 7, 7, 7, 8, 7, 7, 7), (7, 7, 7, 7, 7, 7, 8, 7, 8, 7, 7), (7, 7, 7, 7, 7, 8, 7, 8, 7, 8, 7), (7, 7, 7, 7, 7, 7, 8, 7, 8, 7, 7), (7, 7, 7, 7, 7, 7, 7, 8, 7, 7, 7), (7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7), (7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7)), 'task_id': '11852cab', 'difficulty': 
{'rng': 0.09651305327452808, 'pso': 0.15228956228956228}} + +Example 3: +Question: Find the common rule that maps an input grid to an output grid, given the examples below. + +Example 1: + +Input: +9 9 +9 9 +Output: +9 9 +9 9 +9 9 +9 9 + +Example 2: + +Input: +4 4 4 6 +Output: +4 4 4 6 +4 4 4 6 + +Example 3: + +Input: +4 1 1 +4 4 4 +Output: +4 1 1 +4 4 4 +4 4 4 +4 1 1 + + + +Below is a test input grid. Predict the corresponding output grid by applying the rule you found. +Your final answer should just be the text output grid itself. + +Input: +1 1 1 1 1 +1 1 1 1 1 + +Answer: ((1, 1, 1, 1, 1), (1, 1, 1, 1, 1), (1, 1, 1, 1, 1), (1, 1, 1, 1, 1)) +Metadata: {'input': ((1, 1, 1, 1, 1), (1, 1, 1, 1, 1)), 'output': ((1, 1, 1, 1, 1), (1, 1, 1, 1, 1), (1, 1, 1, 1, 1), (1, 1, 1, 1, 1)), 'task_id': '8be77c9e', 'difficulty': {'rng': 0.09322002370336528, 'pso': 0.0638888888888889}} + +```` + +### rotate_matrix +Generates Rotate Matrix exercises with configurable difficulty + +Default configuration: +```python +max_n = 10 +max_rotations = 4 +size = 500 +seed = 42 +``` + +Example tasks: +```` +Example 1: +Question: Given a square matrix, your job is to rotate it clockwise. + +Example: + +Input: Rotate the matrix below by 90 degrees clockwise: +1 2 3 +4 5 6 +7 8 9 + +Output: +7 4 1 +8 5 2 +9 6 3 + +Rotate the matrix below by 90 degrees clockwise: +3 1 +2 0 + +Answer: 2 3 +0 1 +Metadata: {'matrix': [[3, 1], [2, 0]], 'num_rotations': 1, 'solution': [[2, 3], [0, 1]]} + +Example 2: +Question: Given a square matrix, your job is to rotate it clockwise. + +Example: + +Input: Rotate the matrix below by 90 degrees clockwise: +1 2 3 +4 5 6 +7 8 9 + +Output: +7 4 1 +8 5 2 +9 6 3 + +Rotate the matrix below by 180 degrees clockwise: +0 + +Answer: 0 +Metadata: {'matrix': [[0]], 'num_rotations': 2, 'solution': [[0]]} + +Example 3: +Question: Given a square matrix, your job is to rotate it clockwise. 
+ +Example: + +Input: Rotate the matrix below by 90 degrees clockwise: +1 2 3 +4 5 6 +7 8 9 + +Output: +7 4 1 +8 5 2 +9 6 3 + +Rotate the matrix below by 180 degrees clockwise: +28 17 38 29 8 15 26 +35 13 37 39 27 40 20 +4 30 23 16 3 5 48 +9 25 2 46 47 21 22 +31 12 41 43 19 32 10 +6 0 36 45 42 1 18 +14 24 11 7 44 34 33 + +Answer: 33 34 44 7 11 24 14 +18 1 42 45 36 0 6 +10 32 19 43 41 12 31 +22 21 47 46 2 25 9 +48 5 3 16 23 30 4 +20 40 27 39 37 13 35 +26 15 8 29 38 17 28 +Metadata: {'matrix': [[28, 17, 38, 29, 8, 15, 26], [35, 13, 37, 39, 27, 40, 20], [4, 30, 23, 16, 3, 5, 48], [9, 25, 2, 46, 47, 21, 22], [31, 12, 41, 43, 19, 32, 10], [6, 0, 36, 45, 42, 1, 18], [14, 24, 11, 7, 44, 34, 33]], 'num_rotations': 2, 'solution': [[33, 34, 44, 7, 11, 24, 14], [18, 1, 42, 45, 36, 0, 6], [10, 32, 19, 43, 41, 12, 31], [22, 21, 47, 46, 2, 25, 9], [48, 5, 3, 16, 23, 30, 4], [20, 40, 27, 39, 37, 13, 35], [26, 15, 8, 29, 38, 17, 28]]} + +```` + ### rubiks_cube Generates RubiksCube tasks @@ -2220,14 +2633,14 @@ Generates Sokoban games with configurable parameters Default configuration: ```python -seed = 42 -size = 500 min_w = 6 min_h = 6 max_w = 10 max_h = 10 min_boxes = 6 max_boxes = 10 +seed = 42 +size = 500 ``` Example tasks: @@ -2436,6 +2849,7 @@ allow_no = True allow_some = True allow_some_not = True invalid_ratio = 0.3 +inversion_probability = 0.3 seed = 42 size = 500 ``` @@ -2448,32 +2862,32 @@ Question: Consider these statements: 2. All humans are chefs Does it logically follow that: -All students are chefs? +Some chefs are humans? (Answer Yes or No) -Answer: No -Metadata: {'premise1': 'No students are humans', 'premise2': 'All humans are chefs', 'conclusion': 'All students are chefs', 'is_valid': False} +Answer: Yes +Metadata: {'premise1': 'No students are humans', 'premise2': 'All humans are chefs', 'selected_premise': 2, 'conclusion': 'Some chefs are humans', 'is_valid': True, 'type': 'inversion'} Example 2: Question: Consider these statements: 1. 
All children are animals -2. No animals are doctors +2. Some animals are not doctors Does it logically follow that: Some children are not doctors? (Answer Yes or No) Answer: Yes -Metadata: {'premise1': 'All children are animals', 'premise2': 'No animals are doctors', 'conclusion': 'Some children are not doctors', 'is_valid': True} +Metadata: {'premise1': 'All children are animals', 'premise2': 'Some animals are not doctors', 'conclusion': 'Some children are not doctors', 'is_valid': True, 'type': 'syllogism'} Example 3: Question: Consider these statements: -1. All butterflies are tigers +1. Some butterflies are not tigers 2. No tigers are whales Does it logically follow that: -Some butterflies are not whales? +Some butterflies are whales? (Answer Yes or No) -Answer: Yes -Metadata: {'premise1': 'All butterflies are tigers', 'premise2': 'No tigers are whales', 'conclusion': 'Some butterflies are not whales', 'is_valid': True} +Answer: No +Metadata: {'premise1': 'Some butterflies are not tigers', 'premise2': 'No tigers are whales', 'conclusion': 'Some butterflies are whales', 'is_valid': False, 'type': 'syllogism'} ```` @@ -2522,7 +2936,7 @@ min_disks = 3 max_disks = 7 min_pegs = 3 max_pegs = 4 -size = 50 +size = 500 seed = 42 visualize = False ``` @@ -2584,7 +2998,7 @@ Default configuration: min_board_size = 9 max_board_size = 13 max_stones = 15 -size = 100 +size = 500 seed = 42 ``` diff --git a/reasoning_gym/arc/rearc_board_format.py b/reasoning_gym/arc/board_format.py similarity index 91% rename from reasoning_gym/arc/rearc_board_format.py rename to reasoning_gym/arc/board_format.py index 8d3f54ca..1360fcd1 100644 --- a/reasoning_gym/arc/rearc_board_format.py +++ b/reasoning_gym/arc/board_format.py @@ -1,22 +1,13 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Tuple @dataclass class BoardFormattingOptions: - alphabet: list[str] - col_delimiter: str - row_delimiter: str - array_brackets: bool - - -def 
default_board_format_opts() -> BoardFormattingOptions: - return BoardFormattingOptions( - alphabet=[str(i) for i in range(10)], - col_delimiter=" ", - row_delimiter="\n", - array_brackets=False, - ) + alphabet: list[str] = field(default_factory=lambda: [str(i) for i in range(10)]) + col_delimiter: str = " " + row_delimiter: str = "\n" + array_brackets: bool = False def format_arc_task( diff --git a/reasoning_gym/arc/rearc.py b/reasoning_gym/arc/rearc.py index ebb7deb0..0f0a3b27 100644 --- a/reasoning_gym/arc/rearc.py +++ b/reasoning_gym/arc/rearc.py @@ -3,69 +3,34 @@ from random import Random from typing import Any, Callable, Dict, Optional from ..factory import ProceduralDataset, register_dataset -from .rearc_board_format import ( - BoardFormattingOptions, - default_board_format_opts, - format_board, - format_board_pair, - parse_board, -) -from .rearc_utils import generators, verifiers -from .rearc_utils.dsl import * -from .rearc_utils.utils import * +from .board_format import BoardFormattingOptions, format_board, format_board_pair, parse_board -_REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below +_REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below. -Examples: {examples} Below is a test input grid. Predict the corresponding output grid by applying the rule you found. Your final answer should just be the text output grid itself. 
- -Input Grid: +Input: {input_grid} - -Output Grid:""" - -_COLOUR_MAP = ListedColormap( - ["#000", "#0074D9", "#FF4136", "#2ECC40", "#FFDC00", "#AAAAAA", "#F012BE", "#FF851B", "#7FDBFF", "#870C25"] -) - - -def strip_prefix(string: str, prefix: str) -> str: - """ - removes prefix - """ - return string[len(prefix) :] - - -def get_generators() -> dict: - """ - returns mapper from task identifiers (keys) to example generator functions - """ - prefix = "generate_" - return {strip_prefix(n, prefix): getattr(generators, n) for n in dir(generators) if n.startswith(prefix)} - - -def get_verifiers() -> dict: - """ - returns mapper from task identifiers (keys) to example verifier functions - """ - prefix = "verify_" - return {strip_prefix(n, prefix): getattr(verifiers, n) for n in dir(verifiers) if n.startswith(prefix)} +""" @dataclass class ReArcConfig: + min_examples: int = 3 # minimum number of board pairs shown + max_examples: int = 5 # maximum number of board pairs shown diff_lb: int = 0 - diff_ub: int = 1 - board_format_opts: BoardFormattingOptions = field(default_factory=default_board_format_opts) + diff_ub: int = 0.2 + board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions()) seed: Optional[int] = None size: int = 500 def validate(self): - assert self.diff_lb < self.diff_ub, "diff_lb must be < diff_ub." + assert self.min_examples > 0, "min_examples must be positive" + assert self.min_examples <= self.max_examples, "min_examples must be <= max_examples" + assert self.diff_lb <= self.diff_ub, "diff_lb must be <= diff_ub." assert self.size > 0, "Size of dataset must be positive." 
@@ -76,8 +41,13 @@ class ReArcDataset(ProceduralDataset): self._prompt_templates = _REARC_PROMPT_TEMPLATES self.diff_lb = config.diff_lb self.diff_ub = config.diff_ub - self._generators = get_generators() - self._verifiers = get_verifiers() + + # lazy import of re-arc dsl & generators + from .rearc_utils import generators + from .rearc_utils.utils import get_generators, get_pso_difficulty + + self._generators = get_generators(generators) + self.get_pso_difficulty = get_pso_difficulty @staticmethod def get_rng_difficulty(rng: Random) -> float: @@ -88,57 +58,22 @@ class ReArcDataset(ProceduralDataset): rng.difficulty_samples = [] return avg - @staticmethod - def get_pso_difficulty(example: dict) -> float: - """ - PSO-Difficulty: proxy measure for example difficulty, defined as weighted sum of #Pixels, #Symbols, #Objects - """ - i, o = example["input"], example["output"] - hwi = height(i) * width(i) - hwo = height(o) * width(o) - pix_pct = (hwi + hwo) / 1800 - col_pct = len(palette(i) | palette(o)) / 10 - obj_dens = (len(objects(i, T, F, F)) / hwi + len(objects(o, T, F, F)) / hwo) / 2 - return (pix_pct + col_pct + obj_dens) / 3 - def __len__(self) -> int: return self.size - @staticmethod - def visualise_pair(example: Dict[str, Any]) -> None: - """ - Visualise a ReArc task pair - """ - norm = Normalize(vmin=0, vmax=9) - args = {"cmap": _COLOUR_MAP, "norm": norm} - - # Change to 1 row, 2 columns - height = 1 - width = 2 - figure_size = (3 * width * 3, height * 3) - figure, axes = plt.subplots(height, width, figsize=figure_size) - - # Plot input and output side by side - axes[0].imshow(example["metadata"]["input"], **args) - axes[1].imshow(example["metadata"]["output"], **args) - - # Add titles to distinguish the plots - axes[0].set_title("Input") - axes[1].set_title("Output") - - def format_rearc_input(self, idx: int, task: dict, generator: Callable) -> str: + def format_rearc_input(self, rng: Random, task: dict, generator: Callable) -> str: """ Format a ReArc task 
input with multiple examples and test input. """ - example_1 = generator(Random((self.seed + idx) * 1 * self.size), self.diff_lb, self.diff_ub) - example_2 = generator(Random((self.seed + idx) * 2 * self.size), self.diff_lb, self.diff_ub) - example_3 = generator(Random((self.seed + idx) * 3 * self.size), self.diff_lb, self.diff_ub) - examples = ( - format_board_pair(1, example_1, self.board_format_opts) - + format_board_pair(2, example_2, self.board_format_opts) - + format_board_pair(3, example_3, self.board_format_opts) - ) + num_examples = rng.randint(self.config.min_examples, self.config.max_examples) + examples = [ + format_board_pair( + i + 1, generator(rng, self.diff_lb, self.diff_ub), formatting_options=self.config.board_format_opts + ) + for i in range(num_examples) + ] + examples = "".join(examples) input_grid = format_board(task["input"], self.board_format_opts) return self._prompt_templates.format(examples=examples, input_grid=input_grid) @@ -154,7 +89,7 @@ class ReArcDataset(ProceduralDataset): rng_difficulty = self.get_rng_difficulty(rng) pso_difficulty = self.get_pso_difficulty(task) - input_prompt = self.format_rearc_input(idx, task, generator) + input_prompt = self.format_rearc_input(rng, task, generator) return { "question": input_prompt, @@ -163,8 +98,10 @@ class ReArcDataset(ProceduralDataset): "input": task["input"], "output": task["output"], "task_id": task_id, - "rng": rng_difficulty, - "pso": pso_difficulty, + "difficulty": { + "rng": rng_difficulty, + "pso": pso_difficulty, + }, }, } diff --git a/reasoning_gym/arc/rearc_utils/dsl.py b/reasoning_gym/arc/rearc_utils/dsl.py index f040140d..a42b4a49 100644 --- a/reasoning_gym/arc/rearc_utils/dsl.py +++ b/reasoning_gym/arc/rearc_utils/dsl.py @@ -1,7 +1,7 @@ # types -from typing import Any, Callable, Container, FrozenSet, Iterable, List, Tuple, Union +from typing import Any, Callable, Container, FrozenSet, Tuple, Union Boolean = bool Integer = int diff --git 
a/reasoning_gym/arc/rearc_utils/utils.py b/reasoning_gym/arc/rearc_utils/utils.py index 39f263a1..a6fe0ff4 100644 --- a/reasoning_gym/arc/rearc_utils/utils.py +++ b/reasoning_gym/arc/rearc_utils/utils.py @@ -1,12 +1,45 @@ import random from typing import Any, List, Tuple -import matplotlib.pyplot as plt -from matplotlib.colors import ListedColormap, Normalize - from .dsl import * +def strip_prefix(string: str, prefix: str) -> str: + """ + removes the first len(prefix) characters of string (does not check that string actually starts with prefix) + """ + return string[len(prefix) :] + + +def get_generators(generators) -> dict: + """ + returns mapper from task identifiers (keys) to example generator functions + """ + prefix = "generate_" + return {strip_prefix(n, prefix): getattr(generators, n) for n in dir(generators) if n.startswith(prefix)} + + +def get_verifiers(verifiers) -> dict: + """ + returns mapper from task identifiers (keys) to example verifier functions + """ + prefix = "verify_" + return {strip_prefix(n, prefix): getattr(verifiers, n) for n in dir(verifiers) if n.startswith(prefix)} + + +def get_pso_difficulty(example: dict) -> float: + """ + PSO-Difficulty: proxy measure for example difficulty, defined as the unweighted mean of three terms: #Pixels (input + output cell count / 1800), #Symbols (distinct palette values / 10), #Objects (objects per cell, averaged over input and output) + """ + i, o = example["input"], example["output"] + hwi = height(i) * width(i) + hwo = height(o) * width(o) + pix_pct = (hwi + hwo) / 1800 + col_pct = len(palette(i) | palette(o)) / 10 + obj_dens = (len(objects(i, T, F, F)) / hwi + len(objects(o, T, F, F)) / hwo) / 2 + return (pix_pct + col_pct + obj_dens) / 3 + + def unifint(rng: random.Random, diff_lb: float, diff_ub: float, bounds: Tuple[int, int]) -> int: """ rng @@ -74,30 +107,6 @@ def format_task(task: dict) -> dict: } -def plot_task(task: List[dict], title: str = None) -> None: - """ - displays a task - """ - cmap = ListedColormap( - ["#000", "#0074D9", "#FF4136", "#2ECC40", "#FFDC00", "#AAAAAA", "#F012BE", "#FF851B", "#7FDBFF", "#870C25"] - ) - norm = Normalize(vmin=0, vmax=9) - args = {"cmap": cmap, "norm": norm} - height = 2 - width = 
len(task) - figure_size = (width * 3, height * 3) - figure, axes = plt.subplots(height, width, figsize=figure_size) - for column, example in enumerate(task): - axes[0, column].imshow(example["input"], **args) - axes[1, column].imshow(example["output"], **args) - axes[0, column].axis("off") - axes[1, column].axis("off") - if title is not None: - figure.suptitle(title, fontsize=20) - plt.subplots_adjust(wspace=0.1, hspace=0.1) - plt.show() - - def fix_bugs(dataset: dict) -> None: """ fixes bugs in the original ARC training dataset diff --git a/tests/test_rearc.py b/tests/test_rearc.py index 3e10ee2f..aa43e64d 100644 --- a/tests/test_rearc.py +++ b/tests/test_rearc.py @@ -1,8 +1,7 @@ import pytest -from reasoning_gym import create_dataset +from reasoning_gym.arc.board_format import format_board from reasoning_gym.arc.rearc import ReArcConfig, ReArcDataset -from reasoning_gym.arc.rearc_board_format import format_board def test_rearc_config_validation(): @@ -16,7 +15,7 @@ def test_rearc_config_validation(): def test_rearc_deterministic(): """Test dataset reproducibility with fixed seed""" - config = ReArcConfig(seed=42, size=500, diff_lb=0, diff_ub=1) + config = ReArcConfig(seed=42, size=100, diff_lb=0, diff_ub=1) ds1 = ReArcDataset(config) ds2 = ReArcDataset(config) @@ -26,7 +25,7 @@ def test_rearc_deterministic(): def test_rearc_items(): """Test basic structure and metadata of generated items""" - config = ReArcConfig(seed=42, size=500, diff_lb=0, diff_ub=1) + config = ReArcConfig(seed=42, size=100, diff_lb=0, diff_ub=1) dataset = ReArcDataset(config) for item in dataset: @@ -39,17 +38,17 @@ def test_rearc_items(): assert "input" in meta assert "output" in meta assert "task_id" in meta - assert "rng" in meta - assert "pso" in meta + assert "rng" in meta["difficulty"] + assert "pso" in meta["difficulty"] # Validate difficulty bounds - assert config.diff_lb <= meta["rng"] <= config.diff_ub - assert config.diff_lb <= meta["pso"] <= config.diff_ub + assert config.diff_lb 
<= meta["difficulty"]["rng"] <= config.diff_ub + assert config.diff_lb <= meta["difficulty"]["pso"] <= config.diff_ub def test_rearc_solution_validation(): """Test solution verification and scoring""" - config = ReArcConfig(size=500, seed=123) + config = ReArcConfig(size=100, seed=123) dataset = ReArcDataset(config) for item in dataset: @@ -57,16 +56,22 @@ def test_rearc_solution_validation(): correct = format_board(item["metadata"]["output"], dataset.board_format_opts) assert dataset.score_answer(correct, item["metadata"]) == 1.0 - ## Test invalid format - # assert dataset.score_answer("invalid_grid", item["metadata"]) == 0.05 + # Test invalid format + invalid_grid = """ +9 9 9 +1 2 1 +7 8 7 +0 0 0 +""" + assert dataset.score_answer(invalid_grid, item["metadata"]) == 0.05 # Test empty answer - # assert dataset.score_answer(None, item["metadata"]) == 0.0 + assert dataset.score_answer(None, item["metadata"]) == 0.0 def test_rearc_scoring_edge_cases(): """Test scoring for partial and malformed answers""" - config = ReArcConfig(size=500, seed=456) + config = ReArcConfig(size=100, seed=456) dataset = ReArcDataset(config) for item in dataset: @@ -80,13 +85,3 @@ def test_rearc_scoring_edge_cases(): # Case sensitivity answer = format_board(item["metadata"]["output"], dataset.board_format_opts).lower() assert dataset.score_answer(answer, item["metadata"]) == 1.0 - - -def test_rearc_visualization(): - """Test visualization function runs without errors""" - config = ReArcConfig(size=2, seed=789) - dataset = ReArcDataset(config) - - for item in dataset: - dataset.visualise_pair(item) - # No assertion needed - just verify no exceptions