mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Merge branch 'main' into rich/ab
This commit is contained in:
commit
27938ce13a
16 changed files with 759 additions and 12 deletions
|
|
@ -10,6 +10,7 @@ from .ab import ABConfig, ABDataset
|
|||
from .base_conversion import BaseConversionConfig, BaseConversionDataset
|
||||
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
|
||||
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
|
||||
from .count_primes import CountPrimesConfig, CountPrimesDataset
|
||||
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
|
||||
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
|
||||
from .letter_counting import LetterCountingConfig, LetterCountingDataset
|
||||
|
|
@ -69,4 +70,6 @@ __all__ = [
|
|||
"BinaryMatrixDataset",
|
||||
"ABConfig",
|
||||
"ABDataset",
|
||||
"CountPrimesConfig",
|
||||
"CountPrimesDataset",
|
||||
]
|
||||
|
|
|
|||
63
reasoning_gym/algorithmic/count_primes.py
Normal file
63
reasoning_gym/algorithmic/count_primes.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
"""Count prime numbers in a given interval.
|
||||
|
||||
Solution obtained with Sieve of Eratosthenes:
|
||||
https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
# Prompt shown to the solver; both interval endpoints are inclusive.
QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?"""


@dataclass
class CountPrimesConfig:
    """Configuration for Count Primes dataset generation"""

    max_n: int = 10_000  # Upper bound for the interval

    size: int = 500  # Virtual dataset size
    seed: Optional[int] = None  # RNG seed; None means nondeterministic seeding

    def validate(self):
        """Validate configuration parameters"""
        # NOTE: plain assert — stripped when running under `python -O`.
        assert 1 <= self.max_n, "max_n must be at least 1"
|
||||
|
||||
|
||||
class CountPrimesDataset(ProceduralDataset):
    """Generates Count Primes exercises with configurable difficulty"""

    def __init__(self, config: CountPrimesConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)
        # The sieve is computed once up front; every item only slices it.
        self.primes = self._get_primes(config.max_n + 1)

    def _get_primes(self, n: int) -> list[bool]:
        """Sieve of Eratosthenes.

        Returns:
            list[bool] of length n where flags[i] is True iff i is prime;
            empty list for n <= 1.
        """
        if n <= 1:
            return []
        primes = [True] * n
        primes[0] = primes[1] = False
        # math.isqrt avoids the float-rounding risk of int(math.sqrt(n)).
        for i in range(2, math.isqrt(n) + 1):
            if primes[i]:
                # Multiples below i * i were already crossed out by smaller factors.
                for j in range(i * i, n, i):
                    primes[j] = False
        return primes

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Count Primes question"""
        rng = Random(self.seed + idx)
        # Inclusive interval [start, end] within [1, max_n]
        start = rng.randint(1, self.config.max_n)
        end = rng.randint(start, self.config.max_n)
        primes = self.primes[start : end + 1]
        answer = sum(primes)  # each True flag counts as 1
        return {
            "question": QUESTION_TEMPLATE.format(start=start, end=end),
            "answer": str(answer),
            "metadata": {"start": start, "end": end, "primes": primes, "solution": answer},
        }


register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig)
|
||||
|
|
@ -60,22 +60,16 @@ class RotateMatrixDataset(ProceduralDataset):
|
|||
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
|
||||
return matrix
|
||||
|
||||
def _rot90(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
"""quarter clockwise rotation"""
|
||||
return [list(row) for row in zip(*matrix[::-1])]
|
||||
|
||||
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
    """Rotate the matrix `num_rotations` times by 90 degrees clockwise.

    NOTE(review): the source as scraped contained both an in-place four-way
    swap AND a `_rot90` call inside the loop (a diff-merge artifact), which
    would rotate twice per iteration; only the `_rot90` path is kept.
    """
    num_rotations %= 4  # four quarter turns are the identity
    output = deepcopy(matrix)

    for _ in range(num_rotations):
        output = self._rot90(output)
    return output
|
||||
|
||||
def _matrix_to_str(self, matrix: list[list[int]]) -> str:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf
|
|||
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
||||
from .chain_sum import ChainSum, ChainSumConfig
|
||||
from .count_bits import CountBitsConfig, CountBitsDataset
|
||||
from .dice import DiceConfig, DiceDataset
|
||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||
from .gcd import GCDConfig, GCDDataset
|
||||
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
||||
|
|
@ -38,4 +39,6 @@ __all__ = [
|
|||
"TimeIntervalsDataset",
|
||||
"CountBitsConfig",
|
||||
"CountBitsDataset",
|
||||
"DiceConfig",
|
||||
"DiceDataset",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import random
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -112,5 +113,36 @@ class ChainSum(ProceduralDataset):
|
|||
return expression, result
|
||||
|
||||
|
||||
class ChainSumCurriculum(BaseCurriculum):
    """Curriculum scaling ChainSum difficulty via term count and digit width."""

    def __init__(self):
        super().__init__(ChainSumCurriculum.__name__, ChainSumConfig)

        # Number of terms in the generated expression (levels accumulate).
        terms_attribute = RangeAttributeDefinition(
            name="num_terms",
            levels=[2, 3, 4, 5],
            default_level=0,  # Start with 2 terms
            description="Maximum number of terms in the expression",
            attr_type=AttributeType.APPEND,
            min_value=2,  # Ensure at least 2 terms
            lower_field_name="min_terms",
            upper_field_name="max_terms",
        )
        # Width of each operand in digits (levels accumulate).
        digits_attribute = RangeAttributeDefinition(
            name="num_digits",
            levels=[1, 2, 4, 10],
            default_level=0,  # Start with 1-digit numbers
            description="Number of digits in each operand",
            attr_type=AttributeType.APPEND,
            min_value=1,  # Ensure numbers are at least 1 digit
            lower_field_name="min_digits",
            upper_field_name="max_digits",
        )
        self._define_attributes((terms_attribute, digits_attribute))


# Register the dataset
register_dataset("chain_sum", ChainSum, ChainSumConfig)
|
||||
|
|
|
|||
149
reasoning_gym/arithmetic/dice.py
Normal file
149
reasoning_gym/arithmetic/dice.py
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from math import gcd
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
def compute_probability(dice, target):
    """Probability of rolling a total of at least ``target`` with the given dice.

    ``dice`` lists the number of sides of each die; every die rolls a value
    in 1..sides. The exact outcome counts are computed by dynamic programming
    over achievable sums.

    Returns:
        ((numerator, denominator), probability) — the reduced fraction and
        its float value.
    """
    # ways_by_sum[s] = number of outcome sequences totalling s with the dice
    # processed so far; a single dict is carried forward per die.
    ways_by_sum = {0: 1}
    for sides in dice:
        next_ways = {}
        for subtotal, count in ways_by_sum.items():
            # Each die contributes a face value from 1 to `sides`.
            for face in range(1, sides + 1):
                key = subtotal + face
                next_ways[key] = next_ways.get(key, 0) + count
        ways_by_sum = next_ways

    total_outcomes = reduce(lambda acc, sides: acc * sides, dice, 1)
    favourable = sum(count for total, count in ways_by_sum.items() if total >= target)

    # Reduce favourable / total_outcomes to lowest terms.
    divisor = gcd(favourable, total_outcomes)
    reduced = (favourable // divisor, total_outcomes // divisor)
    return reduced, favourable / total_outcomes
|
||||
|
||||
|
||||
def generate_puzzle(num_dice, max_dice_size, rng):
    """Build one dice puzzle.

    One die is guaranteed to have ``max_dice_size`` sides; each of the other
    ``num_dice - 1`` dice gets a random size in [2, max_dice_size - 1]
    (all dice are 2-sided when ``max_dice_size == 2``). A target total is
    drawn from the middle third of the achievable range, and the exact
    probability of rolling at least that target is computed.

    Returns:
        dict with keys ``dice_str``, ``target``, ``num``, ``den``.
    """
    # Guarantee one die is the maximum size.
    dice = [max_dice_size]
    for _ in range(num_dice - 1):
        extra = rng.randint(2, max_dice_size - 1) if max_dice_size > 2 else 2
        dice.append(extra)

    # Largest dice first, as is conventional in puzzle notation.
    dice.sort(reverse=True)

    lowest_total = num_dice  # every die shows at least 1
    highest_total = sum(dice)

    # Pick a target in the middle third of the achievable range so the
    # question is neither trivial nor impossible.
    span = highest_total - lowest_total
    target = rng.randint(lowest_total + span // 3, lowest_total + 2 * span // 3)

    (numerator, denominator), _probability = compute_probability(dice, target)

    # e.g. "1d20, 1d17, 1d6"
    dice_str = ", ".join(f"1d{sides}" for sides in dice)

    return {"dice_str": dice_str, "target": target, "num": numerator, "den": denominator}
|
||||
|
||||
|
||||
@dataclass
class DiceConfig:
    """Configuration for dice puzzle generation"""

    num_dice: int = 4  # Number of dice rolled in each puzzle
    max_dice_size: int = 20  # Maximum number of sides per die
    seed: Optional[int] = None  # RNG seed; None means nondeterministic seeding
    size: int = 500  # Virtual dataset size

    def validate(self):
        """Validate configuration parameters"""
        # NOTE: plain asserts — stripped when running under `python -O`.
        assert self.num_dice >= 1, "num_dice must be gte 1"
        assert self.max_dice_size >= 2, "max_dice_size must be gte 2"
|
||||
|
||||
|
||||
class DiceDataset(ProceduralDataset):
    """Generates Dice-based puzzles with configurable parameters"""

    def __init__(self, config: DiceConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Dice task

        Returns:
            dict with keys:
            - question: str, the task description
            - answer: str, a solution string of the form "num/den"
            - metadata: dict with generation parameters
        """
        # Per-item RNG keeps generation deterministic for a given seed/index.
        rng = Random(self.seed + idx)
        puzzle = generate_puzzle(self.config.num_dice, self.config.max_dice_size, rng)
        puzzle_str = f"I have these dice: {puzzle['dice_str']}. What are the odds of rolling {puzzle['target']} or higher? (Assume that all dice are rolled at once, and that '1d6' represents one roll of a 6-sided dice.) Please respond with a reduced fraction representing the probability [ex., 1/60]."
        answer_str = f"{puzzle['num']}/{puzzle['den']}"

        return {
            "question": puzzle_str,
            "answer": answer_str,
            "metadata": {},
        }

    def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
        """Determine if the solution provided solves the Dice task.

        Awards 1.0 for a correct answer, 0.01 for an incorrect one,
        and 0.0 when no answer is given.

        Args:
            answer (Optional[str]): The user's answer.
            entry (Dict[str, any]): The original dataset entry containing the correct answer.

        Returns:
            float: The computed score between 0.0 and 1.0.
        """
        if answer is None:  # fixed: identity check instead of `== None`
            return 0.0
        # Comparison ignores case and embedded newlines.
        if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
            return 0.01
        return 1.0


register_dataset("dice", DiceDataset, DiceConfig)
|
||||
14
reasoning_gym/coaching/__init__.py
Normal file
14
reasoning_gym/coaching/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from .attributes import AttributeDefinition, AttributeType, RangeAttributeDefinition
|
||||
from .base_curriculum import BaseCurriculum
|
||||
from .coach import Coach, GroupedScores, ScoreBoard, ScoreStats
|
||||
|
||||
__all__ = [
|
||||
"AttributeType",
|
||||
"AttributeDefinition",
|
||||
"RangeAttributeDefinition",
|
||||
"BaseCurriculum",
|
||||
"Coach",
|
||||
"ScoreBoard",
|
||||
"GroupedScores",
|
||||
"ScoreStats",
|
||||
]
|
||||
73
reasoning_gym/coaching/attributes.py
Normal file
73
reasoning_gym/coaching/attributes.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from collections import abc
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class AttributeType(StrEnum):
|
||||
"""Defines how attribute levels should be interpreted"""
|
||||
|
||||
STATIC = "static" # Each level is independent
|
||||
UBOUND = "ubound" # Each level is an upper bound
|
||||
APPEND = "append" # Each level includes all previous levels
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class AttributeDefinition:
|
||||
name: str
|
||||
levels: list
|
||||
default_level: int
|
||||
description: Optional[str] = None
|
||||
attr_type: AttributeType = AttributeType.STATIC # Default to static
|
||||
min_value: Optional[int | float] = None # Minimum value for numeric attributes
|
||||
|
||||
def validate_level(self, level: int, curriculum: str) -> None:
|
||||
"""
|
||||
Validate that a level is valid for an attribute.
|
||||
Args:
|
||||
level: Level to validate
|
||||
curriculum: Name of the curriculum
|
||||
Raises:
|
||||
ValueError: If level is invalid
|
||||
"""
|
||||
# TODO: if > set as [-1], if <0 set as [0]
|
||||
if not 0 <= level < len(self.levels):
|
||||
raise ValueError(
|
||||
f"Invalid level: {level} for attribute '{curriculum}.{self.name}'. "
|
||||
f"Must be between 0 and {len(self.levels)-1}"
|
||||
)
|
||||
|
||||
def get_level_value(self, level: int, curriculum: str) -> Any:
|
||||
"""
|
||||
Get the value for an attribute at a specific level based on its type.
|
||||
Args:
|
||||
attr: The attribute definition
|
||||
level: Level to get value for
|
||||
Returns:
|
||||
Value for the attribute based on its level and type
|
||||
"""
|
||||
if self.attr_type == AttributeType.STATIC:
|
||||
return self.levels[level]
|
||||
elif self.attr_type == AttributeType.UBOUND:
|
||||
return self.levels[level]
|
||||
elif self.attr_type == AttributeType.APPEND:
|
||||
return self.levels[: level + 1]
|
||||
|
||||
raise ValueError(f"Unknown attribute type: {self.attr_type} for attribute '{curriculum}.{self.name}'")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class ScalarAttributeDefinition(AttributeDefinition):
    """Attribute whose level value maps onto a single config field."""

    field_name: str


@dataclass(kw_only=True)
class RangeAttributeDefinition(AttributeDefinition):
    """Attribute whose level value maps onto a lower/upper pair of config fields."""

    lower_field_name: str
    upper_field_name: str

    def get_level_value(self, level: int, curriculum: str) -> Any:
        """Resolve the level value, wrapping scalars so an iterable is always returned."""
        value = super().get_level_value(level, curriculum)
        return value if isinstance(value, abc.Iterable) else [value]
|
||||
108
reasoning_gym/coaching/base_curriculum.py
Normal file
108
reasoning_gym/coaching/base_curriculum.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
from typing import Any, Iterable, Optional
|
||||
|
||||
from ..factory import ConfigT
|
||||
from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||
|
||||
|
||||
class BaseCurriculum:
    """Base class for curricula: tracks a difficulty level per attribute and
    materializes dataset configuration objects from the current levels."""

    def __init__(self, name: str, config_cls: ConfigT):
        self.name = name
        self._config_cls = config_cls
        self._attributes: dict[str, AttributeDefinition] = {}
        self._current_levels: dict[str, int] = {}  # attr name -> explicitly set level

    def generate_configuration(self, defaults: Optional[dict[str, Any]] = None) -> ConfigT:
        """Build a config object from the current attribute levels.

        Args:
            defaults: Extra constructor arguments; attribute-derived values
                overwrite them on field-name collision.

        Returns:
            A new instance of the configured config class.
        """
        config_args = defaults.copy() if defaults is not None else {}
        for attr in self._attributes.values():
            if isinstance(attr, RangeAttributeDefinition):
                # Range attributes fill both the lower and upper config fields.
                vals = self.get_attr_value(attr.name)
                config_args[attr.lower_field_name] = min(vals)
                config_args[attr.upper_field_name] = max(vals)
            elif isinstance(attr, ScalarAttributeDefinition):
                config_args[attr.field_name] = self.get_attr_value(attr.name)
        # fixed: removed leftover debug `print(config_args)`
        return self._config_cls(**config_args)

    @property
    def attributes(self) -> dict[str, AttributeDefinition]:
        """Get the curriculum's attributes"""
        return self._attributes

    def get_attribute(self, attr_name: str) -> AttributeDefinition:
        """Look up an attribute definition; raises KeyError if undefined."""
        if attr_name not in self._attributes:
            raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist")
        return self._attributes[attr_name]

    def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None:
        """Register attribute definitions; duplicate names are an error."""
        for attr in attrs:
            if attr.name in self.attributes:
                raise RuntimeError(f"Attribute with name {attr.name} is already defined.")
            self.attributes[attr.name] = attr

    def get_attr_level(self, attr_name: str) -> int:
        """
        Get the current level for an attribute.

        Args:
            attr_name: Name of the attribute

        Returns:
            Current level index for the attribute (the attribute's
            default_level when no level has been set explicitly)
        """
        attr = self.get_attribute(attr_name)
        return self._current_levels.get(attr_name, attr.default_level)

    def get_attr_value(self, attr_name: str) -> Any:
        """
        Get the current value for an attribute based on its level.

        Args:
            attr_name: Name of the attribute

        Returns:
            Current value for the attribute based on its level and type
        """
        attr = self.get_attribute(attr_name)
        level = self.get_attr_level(attr_name)
        return attr.get_level_value(level, curriculum=self.name)

    def set_attr_level(self, attr_name: str, level: int) -> None:
        """
        Set the level for an attribute.

        Args:
            attr_name: Name of the attribute
            level: New level index

        Raises:
            KeyError: If the attribute doesn't exist
            ValueError: If the level is out of range
        """
        attr = self.get_attribute(attr_name)
        attr.validate_level(level, curriculum=self.name)
        self._current_levels[attr_name] = level

    def increment_attr_level(self, attr_name: str) -> bool:
        """
        Increment the level of an attribute if possible.

        Args:
            attr_name: Name of the attribute to increment

        Returns:
            bool: True if level was incremented, False if already at max level

        Raises:
            KeyError: If attribute doesn't exist
        """
        attr = self.get_attribute(attr_name)
        current_level = self.get_attr_level(attr_name)
        if current_level < len(attr.levels) - 1:
            self.set_attr_level(attr_name, current_level + 1)
            return True
        return False

    def decrement_attr_level(self, attr_name: str) -> bool:
        """
        Decrement the level of an attribute if possible.

        Args:
            attr_name: Name of the attribute to decrement

        Returns:
            bool: True if level was decremented, False if already at min level

        Raises:
            KeyError: If attribute doesn't exist
        """
        current_level = self.get_attr_level(attr_name)
        if current_level > 0:
            self.set_attr_level(attr_name, current_level - 1)
            return True
        return False
|
||||
|
|
@ -8,7 +8,7 @@ from pathlib import Path
|
|||
from statistics import mean, stdev
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from .dataset import ProceduralDataset
|
||||
from ..dataset import ProceduralDataset
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -5,6 +5,7 @@ Cognition tasks for training reasoning capabilities.
|
|||
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
|
||||
from .figlet_fonts import FigletFontConfig, FigletFontDataset
|
||||
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset
|
||||
from .rectangle_count import RectangleCountConfig, RectangleCountDataset
|
||||
from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -16,4 +17,6 @@ __all__ = [
|
|||
"NumberSequenceDataset",
|
||||
"RubiksCubeConfig",
|
||||
"RubiksCubeDataset",
|
||||
"RectangleCountConfig",
|
||||
"RectangleCountDataset",
|
||||
]
|
||||
|
|
|
|||
135
reasoning_gym/cognition/rectangle_count.py
Normal file
135
reasoning_gym/cognition/rectangle_count.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
def draw_rectangles_with_overlap(n, width, height, rng):
    """Render up to ``n`` axis-aligned rectangle outlines on a width x height grid.

    Rectangles are at least 3 cells wide and 3 cells tall. A candidate is
    rejected if any of its border cells is already covered twice, so no cell
    is ever drawn more than twice. Cells are rendered as ' ' (untouched),
    '#' (covered once) or '█' (covered twice).

    Returns:
        (ascii_art, placed_count) — the rendered grid (one trailing newline
        per row) and the number of rectangles actually placed.
    """
    # hits[r][c] counts how many rectangle borders pass through the cell.
    hits = [[0] * width for _ in range(height)]
    placed = []

    max_attempts = 100000  # guard against an unplaceable / crowded grid
    attempts = 0

    while len(placed) < n and attempts < max_attempts:
        attempts += 1
        # right >= left + 2 and bottom >= top + 2 give the 3-cell minimum size.
        left = rng.randint(0, width - 3)
        right = rng.randint(left + 2, width - 1)
        top = rng.randint(0, height - 3)
        bottom = rng.randint(top + 2, height - 1)

        # Collect every border cell of the candidate rectangle.
        border = []
        for col in range(left, right + 1):
            border.append((top, col))
            border.append((bottom, col))
        for row in range(top + 1, bottom):  # corners already covered above
            border.append((row, left))
            border.append((row, right))

        # Reject candidates that would push any cell past two overlaps.
        if any(hits[r][c] >= 2 for r, c in border):
            continue

        for r, c in border:
            hits[r][c] += 1
        placed.append((left, right, top, bottom))

    if len(placed) < n:
        print(f"Only placed {len(placed)} rectangles after {attempts} attempts.")

    def glyph(count):
        if count == 0:
            return " "
        return "#" if count == 1 else "█"

    rows = ("".join(glyph(count) for count in row) for row in hits)
    art = "".join(row_text + "\n" for row_text in rows)
    return art, len(placed)
|
||||
|
||||
|
||||
@dataclass
class RectangleCountConfig:
    """Configuration for RectangleCount puzzle generation"""

    max_rectangles: int = 10  # Maximum number of rectangles per puzzle
    width: int = 80  # Grid width in characters
    height: int = 80  # Grid height in characters
    seed: Optional[int] = None  # RNG seed; None means nondeterministic seeding
    size: int = 500  # Virtual dataset size

    def validate(self):
        """Validate configuration parameters"""
        # NOTE: plain asserts — stripped when running under `python -O`.
        assert self.width >= 10, "width must be gte 10"
        assert self.height >= 10, "height must be gte 10"
||||
|
||||
|
||||
class RectangleCountDataset(ProceduralDataset):
    """Generates ASCII rectangle-counting puzzles with configurable parameters"""

    def __init__(self, config: RectangleCountConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single RectangleCount task

        Returns:
            dict with keys:
            - question: str, the task description
            - answer: str, the number of rectangles drawn
            - metadata: dict with generation parameters
        """
        # Per-item RNG keeps generation deterministic for a given seed/index.
        rng = Random(self.seed + idx)

        target = rng.randint(1, self.config.max_rectangles)
        puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng)

        puzz = f"How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. \n\n {puzzle}"

        return {
            "question": puzz,
            "answer": str(answer),
            "metadata": {},
        }

    def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
        """Determine if the solution provided solves the RectangleCount task.

        Awards 1.0 for a correct answer, 0.01 for an incorrect one,
        and 0.0 when no answer is given.

        Args:
            answer (Optional[str]): The user's answer.
            entry (Dict[str, any]): The original dataset entry containing the correct answer.

        Returns:
            float: The computed score between 0.0 and 1.0.
        """
        if answer is None:  # fixed: identity check instead of `== None`
            return 0.0
        # Comparison ignores case and embedded newlines.
        if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
            return 0.01
        return 1.0


register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig)
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
|
||||
from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum
|
||||
|
||||
|
||||
def test_chain_sum_config_validation():
|
||||
|
|
@ -127,3 +128,30 @@ def test_chain_sum_iteration():
|
|||
first_items = list(dataset)
|
||||
second_items = list(dataset)
|
||||
assert first_items == second_items, "Multiple iterations should yield same items"
|
||||
|
||||
|
||||
def test_chain_sum_curriculum():
    """Exercise level increments/decrements on the ChainSum curriculum."""
    curriculum = ChainSumCurriculum()
    defaults = {"size": 150, "seed": 1}

    cfg: ChainSumConfig = curriculum.generate_configuration(defaults)
    assert (cfg.seed, cfg.size) == (1, 150)
    assert (cfg.min_digits, cfg.max_digits) == (1, 1)
    assert (cfg.min_terms, cfg.max_terms) == (2, 2)

    # Raising both attribute levels widens the upper bounds.
    curriculum.increment_attr_level("num_terms")
    curriculum.increment_attr_level("num_digits")
    cfg = curriculum.generate_configuration(defaults)
    assert (cfg.min_digits, cfg.max_digits) == (1, 2)
    assert (cfg.min_terms, cfg.max_terms) == (2, 3)

    # Lowering only num_digits restores its original bound.
    curriculum.decrement_attr_level("num_digits")
    cfg = curriculum.generate_configuration(defaults)
    assert (cfg.min_digits, cfg.max_digits) == (1, 1)
    assert (cfg.min_terms, cfg.max_terms) == (2, 3)
|
||||
|
|
|
|||
88
tests/test_count_primes.py
Normal file
88
tests/test_count_primes.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
"""Tests for Count Primes questions generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.count_primes import CountPrimesConfig, CountPrimesDataset
|
||||
|
||||
|
||||
def test_count_primes_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Both a negative and a zero upper bound must be rejected.
    for bad_max_n in (-1, 0):
        config = CountPrimesConfig(max_n=bad_max_n)
        with pytest.raises(AssertionError):
            config.validate()
|
||||
|
||||
|
||||
def test_count_primes_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = CountPrimesConfig(seed=42, size=10)
    dataset1 = CountPrimesDataset(config)
    dataset2 = CountPrimesDataset(config)

    # Two datasets built from the same config must agree item-for-item.
    for idx in range(len(dataset1)):
        assert dataset1[idx] == dataset2[idx]
|
||||
|
||||
|
||||
def test_count_primes_dataset_items():
    """Test basic properties of generated items"""
    dataset = CountPrimesDataset(CountPrimesConfig(max_n=10, size=10, seed=42))

    for idx in range(len(dataset)):
        item = dataset[idx]

        # Item structure
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item

        # Metadata structure
        for key in ("start", "end", "primes", "solution"):
            assert key in item["metadata"]

        meta = item["metadata"]
        # Interval must be well-formed and the flag slice must fit inside it.
        assert meta["start"] <= meta["end"]
        assert len(meta["primes"]) <= meta["end"] - meta["start"] + 1
|
||||
|
||||
|
||||
def test_count_primes_dataset_iteration():
    """Test that iteration respects dataset size"""
    config = CountPrimesConfig(size=5, seed=42)
    dataset = CountPrimesDataset(config)

    first_pass = list(dataset)
    assert len(first_pass) == config.size

    # Repeated iteration must yield exactly the same items.
    assert first_pass == list(dataset)
|
||||
|
||||
|
||||
def test_count_primes_answer():
    """Test the _get_primes method"""
    config = CountPrimesConfig(seed=42)
    dataset = CountPrimesDataset(config)

    # Base cases: sieve sizes too small to contain any primes
    assert dataset._get_primes(n=0) == []
    assert dataset._get_primes(n=1) == []
    assert dataset._get_primes(n=2) == [False, False]

    # Primality flags for 2..10 (fixed: `== True`/`== False` comparisons
    # replaced with identity checks against the expected bool)
    primes = dataset._get_primes(n=11)
    expected = {2: True, 3: True, 4: False, 5: True, 6: False, 7: True, 8: False, 9: False, 10: False}
    for value, is_prime in expected.items():
        assert primes[value] is is_prime
|
||||
35
tests/test_dice.py
Normal file
35
tests/test_dice.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic.dice import DiceConfig, DiceDataset
|
||||
|
||||
|
||||
def test_dice():
    """Test basic properties and solution of generated items.

    Covers a typical, an easy (single 2-sided die) and a hard (many large
    dice) configuration; fixed: the three copy-pasted scenario blocks are
    collapsed into one data-driven loop.
    """
    # (num_dice, max_dice_size, size)
    scenarios = [
        (8, 24, 50),  # typical
        (1, 2, 1),  # easy
        (40, 40, 1),  # hard
    ]
    for num_dice, max_dice_size, size in scenarios:
        config = DiceConfig(seed=42, size=size, num_dice=num_dice, max_dice_size=max_dice_size)
        dataset = DiceDataset(config)

        for item in dataset:
            assert isinstance(item, dict)
            assert "question" in item
            assert "answer" in item
            assert "metadata" in item

            # Test the scoring
            assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
            assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
19
tests/test_rectangle_count.py
Normal file
19
tests/test_rectangle_count.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.cognition.rectangle_count import RectangleCountConfig, RectangleCountDataset
|
||||
|
||||
|
||||
def test_rectangle_count():
    """Test basic properties and solution of generated items.

    Fixed: the function was copy-pasted from tests/test_dice.py and still
    named `test_dice`; renamed to match the module under test.
    """
    config = RectangleCountConfig(seed=42, size=50, max_rectangles=15, width=40, height=40)
    dataset = RectangleCountDataset(config)

    for item in dataset:
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Test the scoring
        assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
        assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||
Loading…
Add table
Add a link
Reference in a new issue