mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Merge branch 'main' into rich/ab
This commit is contained in:
commit
27938ce13a
16 changed files with 759 additions and 12 deletions
|
|
@ -10,6 +10,7 @@ from .ab import ABConfig, ABDataset
|
||||||
from .base_conversion import BaseConversionConfig, BaseConversionDataset
|
from .base_conversion import BaseConversionConfig, BaseConversionDataset
|
||||||
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
|
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
|
||||||
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
|
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
|
||||||
|
from .count_primes import CountPrimesConfig, CountPrimesDataset
|
||||||
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
|
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
|
||||||
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
|
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
|
||||||
from .letter_counting import LetterCountingConfig, LetterCountingDataset
|
from .letter_counting import LetterCountingConfig, LetterCountingDataset
|
||||||
|
|
@ -69,4 +70,6 @@ __all__ = [
|
||||||
"BinaryMatrixDataset",
|
"BinaryMatrixDataset",
|
||||||
"ABConfig",
|
"ABConfig",
|
||||||
"ABDataset",
|
"ABDataset",
|
||||||
|
"CountPrimesConfig",
|
||||||
|
"CountPrimesDataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
63
reasoning_gym/algorithmic/count_primes.py
Normal file
63
reasoning_gym/algorithmic/count_primes.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Count prime numbers in a given interval.
|
||||||
|
|
||||||
|
Solution obtained with Sieve of Eratosthenes:
|
||||||
|
https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from random import Random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CountPrimesConfig:
|
||||||
|
"""Configuration for Count Primes dataset generation"""
|
||||||
|
|
||||||
|
max_n: int = 10_000 # Upper bound for the interval
|
||||||
|
|
||||||
|
size: int = 500 # Virtual dataset size
|
||||||
|
seed: Optional[int] = None
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert 1 <= self.max_n, "max_n must be at least 1"
|
||||||
|
|
||||||
|
|
||||||
|
class CountPrimesDataset(ProceduralDataset):
|
||||||
|
"""Generates Count Primes exercises with configurable difficulty"""
|
||||||
|
|
||||||
|
def __init__(self, config: CountPrimesConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
self.primes = self._get_primes(config.max_n + 1)
|
||||||
|
|
||||||
|
def _get_primes(self, n: int) -> list[bool]:
|
||||||
|
if n <= 1:
|
||||||
|
return []
|
||||||
|
primes = [True] * n
|
||||||
|
primes[0] = primes[1] = False
|
||||||
|
for i in range(2, int(math.sqrt(n)) + 1):
|
||||||
|
if primes[i]:
|
||||||
|
for j in range(2 * i, n, i):
|
||||||
|
primes[j] = False
|
||||||
|
return primes
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
"""Generate a single Count Primes question"""
|
||||||
|
rng = Random(self.seed + idx)
|
||||||
|
start = rng.randint(1, self.config.max_n)
|
||||||
|
end = rng.randint(start, self.config.max_n)
|
||||||
|
primes = self.primes[start : end + 1]
|
||||||
|
answer = sum(primes)
|
||||||
|
return {
|
||||||
|
"question": QUESTION_TEMPLATE.format(start=start, end=end),
|
||||||
|
"answer": str(answer),
|
||||||
|
"metadata": {"start": start, "end": end, "primes": primes, "solution": answer},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig)
|
||||||
|
|
@ -60,22 +60,16 @@ class RotateMatrixDataset(ProceduralDataset):
|
||||||
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
|
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
|
||||||
return matrix
|
return matrix
|
||||||
|
|
||||||
|
def _rot90(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||||
|
"""quarter clockwise rotation"""
|
||||||
|
return [list(row) for row in zip(*matrix[::-1])]
|
||||||
|
|
||||||
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
|
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
|
||||||
"""Rotate the matrix K times by 90 degrees clockwise"""
|
"""Rotate the matrix K times by 90 degrees clockwise"""
|
||||||
num_rotations %= 4
|
num_rotations %= 4
|
||||||
n = len(matrix)
|
|
||||||
output = deepcopy(matrix)
|
output = deepcopy(matrix)
|
||||||
|
|
||||||
for _ in range(num_rotations):
|
for _ in range(num_rotations):
|
||||||
for l in range(n // 2):
|
output = self._rot90(output)
|
||||||
for i in range(l, n - 1 - l):
|
|
||||||
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
|
|
||||||
output[n - 1 - i][l],
|
|
||||||
output[l][i],
|
|
||||||
output[i][n - 1 - l],
|
|
||||||
output[n - 1 - l][n - 1 - i],
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _matrix_to_str(self, matrix: list[list[int]]) -> str:
|
def _matrix_to_str(self, matrix: list[list[int]]) -> str:
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf
|
||||||
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
||||||
from .chain_sum import ChainSum, ChainSumConfig
|
from .chain_sum import ChainSum, ChainSumConfig
|
||||||
from .count_bits import CountBitsConfig, CountBitsDataset
|
from .count_bits import CountBitsConfig, CountBitsDataset
|
||||||
|
from .dice import DiceConfig, DiceDataset
|
||||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||||
from .gcd import GCDConfig, GCDDataset
|
from .gcd import GCDConfig, GCDDataset
|
||||||
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
||||||
|
|
@ -38,4 +39,6 @@ __all__ = [
|
||||||
"TimeIntervalsDataset",
|
"TimeIntervalsDataset",
|
||||||
"CountBitsConfig",
|
"CountBitsConfig",
|
||||||
"CountBitsDataset",
|
"CountBitsDataset",
|
||||||
|
"DiceConfig",
|
||||||
|
"DiceDataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -112,5 +113,36 @@ class ChainSum(ProceduralDataset):
|
||||||
return expression, result
|
return expression, result
|
||||||
|
|
||||||
|
|
||||||
|
class ChainSumCurriculum(BaseCurriculum):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(ChainSumCurriculum.__name__, ChainSumConfig)
|
||||||
|
|
||||||
|
# Define attributes
|
||||||
|
self._define_attributes(
|
||||||
|
(
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="num_terms",
|
||||||
|
levels=[2, 3, 4, 5],
|
||||||
|
default_level=0, # Start with 2 terms
|
||||||
|
description="Maximum number of terms in the expression",
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
min_value=2, # Ensure at least 2 terms
|
||||||
|
lower_field_name="min_terms",
|
||||||
|
upper_field_name="max_terms",
|
||||||
|
),
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="num_digits",
|
||||||
|
levels=[1, 2, 4, 10],
|
||||||
|
default_level=0, # Start with 1-digit numbers
|
||||||
|
description="Number of digits in each operand",
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
min_value=1, # Ensure numbers are at least 1 digit
|
||||||
|
lower_field_name="min_digits",
|
||||||
|
upper_field_name="max_digits",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Register the dataset
|
# Register the dataset
|
||||||
register_dataset("chain_sum", ChainSum, ChainSumConfig)
|
register_dataset("chain_sum", ChainSum, ChainSumConfig)
|
||||||
|
|
|
||||||
149
reasoning_gym/arithmetic/dice.py
Normal file
149
reasoning_gym/arithmetic/dice.py
Normal file
|
|
@ -0,0 +1,149 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import reduce
|
||||||
|
from math import gcd
|
||||||
|
from random import Random
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def compute_probability(dice, target):
|
||||||
|
"""
|
||||||
|
Computes the probability of rolling a total of at least `target`
|
||||||
|
when rolling dice specified in the list `dice`. Each element in dice
|
||||||
|
is the number of sides on that die. The computation is done via dynamic programming.
|
||||||
|
Returns the probability as a fraction (numerator, denominator) and as a float.
|
||||||
|
"""
|
||||||
|
# dp[i][s] = number of ways to get sum s using the first i dice.
|
||||||
|
# We use only one dictionary for the current dp state.
|
||||||
|
dp = {0: 1}
|
||||||
|
for sides in dice:
|
||||||
|
new_dp = {}
|
||||||
|
for current_sum, count in dp.items():
|
||||||
|
# Each die gives a number from 1 to sides.
|
||||||
|
for face in range(1, sides + 1):
|
||||||
|
new_sum = current_sum + face
|
||||||
|
new_dp[new_sum] = new_dp.get(new_sum, 0) + count
|
||||||
|
dp = new_dp
|
||||||
|
|
||||||
|
total_outcomes = reduce(lambda a, b: a * b, dice, 1)
|
||||||
|
ways = sum(count for s, count in dp.items() if s >= target)
|
||||||
|
|
||||||
|
# Simplify the fraction (ways / total_outcomes)
|
||||||
|
def simplify(n, d):
|
||||||
|
common = gcd(n, d)
|
||||||
|
return n // common, d // common
|
||||||
|
|
||||||
|
frac = simplify(ways, total_outcomes)
|
||||||
|
return frac, ways / total_outcomes
|
||||||
|
|
||||||
|
|
||||||
|
def generate_puzzle(num_dice, max_dice_size, rng):
|
||||||
|
"""
|
||||||
|
Generates a puzzle:
|
||||||
|
- It forces one die to have max_dice_size.
|
||||||
|
- The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1.
|
||||||
|
- The dice are then shuffled.
|
||||||
|
- The target total is chosen roughly in the middle (but you can adjust the method).
|
||||||
|
|
||||||
|
It then computes the probability of rolling a total at least the target.
|
||||||
|
Finally, it prints out the puzzle statement and the answer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Guarantee one die is the maximum.
|
||||||
|
dice = [max_dice_size]
|
||||||
|
for _ in range(num_dice - 1):
|
||||||
|
# Choose a die size randomly from 2 up to max_dice_size-1.
|
||||||
|
# (If max_dice_size == 2 then all dice are 2-sided.)
|
||||||
|
if max_dice_size > 2:
|
||||||
|
die = rng.randint(2, max_dice_size - 1)
|
||||||
|
else:
|
||||||
|
die = 2
|
||||||
|
dice.append(die)
|
||||||
|
|
||||||
|
# Optionally, sort dice in descending order (as is common in puzzles)
|
||||||
|
dice.sort(reverse=True)
|
||||||
|
|
||||||
|
# Compute minimum and maximum possible totals.
|
||||||
|
min_total = num_dice # each die gives at least 1
|
||||||
|
max_total = sum(dice)
|
||||||
|
|
||||||
|
# Choose a target total. For an interesting puzzle,
|
||||||
|
# we choose a target somewhere in the middle third of the range.
|
||||||
|
low_target = min_total + (max_total - min_total) // 3
|
||||||
|
high_target = min_total + 2 * (max_total - min_total) // 3
|
||||||
|
target = rng.randint(low_target, high_target)
|
||||||
|
|
||||||
|
# Compute probability.
|
||||||
|
(num, den), prob = compute_probability(dice, target)
|
||||||
|
|
||||||
|
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
|
||||||
|
dice_str = ", ".join(f"1d{s}" for s in dice)
|
||||||
|
|
||||||
|
# Return the puzzle.
|
||||||
|
return {"dice_str": dice_str, "target": target, "num": num, "den": den}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DiceConfig:
|
||||||
|
"""Configuration for dice puzzle generation"""
|
||||||
|
|
||||||
|
num_dice: int = 4
|
||||||
|
max_dice_size: int = 20
|
||||||
|
seed: Optional[int] = None
|
||||||
|
size: int = 500
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert self.num_dice >= 1, "num_dice must be gte 1"
|
||||||
|
assert self.max_dice_size >= 2, "max_dice_size must be gte 2"
|
||||||
|
|
||||||
|
|
||||||
|
class DiceDataset(ProceduralDataset):
|
||||||
|
"""Generates Dice-based puzzles with configurable parameters"""
|
||||||
|
|
||||||
|
def __init__(self, config: DiceConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
"""Generate a single Dice task
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys:
|
||||||
|
- question: str, the task description
|
||||||
|
- answer: str, a solution string
|
||||||
|
- metadata: dict with generation parameters
|
||||||
|
"""
|
||||||
|
rng = Random(self.seed + idx)
|
||||||
|
puzzle = generate_puzzle(self.config.num_dice, self.config.max_dice_size, rng)
|
||||||
|
puzzle_str = f"I have these dice: {puzzle['dice_str']}. What are the odds of rolling {puzzle['target']} or higher? (Assume that all dice are rolled at once, and that '1d6' represents one roll of a 6-sided dice.) Please respond with a reduced fraction representing the probability [ex., 1/60]."
|
||||||
|
answer_str = f"{puzzle['num']}/{puzzle['den']}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": puzzle_str,
|
||||||
|
"answer": answer_str,
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
|
"""Determine if the solution provided solves the Dice task.
|
||||||
|
|
||||||
|
The function awards 1.0 for a correct answer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
answer (Optional[str]): The user's answer.
|
||||||
|
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The computed score between 0.0 and 1.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if answer == None:
|
||||||
|
return 0.0
|
||||||
|
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
|
||||||
|
return 0.01
|
||||||
|
else:
|
||||||
|
return 1.0 # Yay
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("dice", DiceDataset, DiceConfig)
|
||||||
14
reasoning_gym/coaching/__init__.py
Normal file
14
reasoning_gym/coaching/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
from .attributes import AttributeDefinition, AttributeType, RangeAttributeDefinition
|
||||||
|
from .base_curriculum import BaseCurriculum
|
||||||
|
from .coach import Coach, GroupedScores, ScoreBoard, ScoreStats
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AttributeType",
|
||||||
|
"AttributeDefinition",
|
||||||
|
"RangeAttributeDefinition",
|
||||||
|
"BaseCurriculum",
|
||||||
|
"Coach",
|
||||||
|
"ScoreBoard",
|
||||||
|
"GroupedScores",
|
||||||
|
"ScoreStats",
|
||||||
|
]
|
||||||
73
reasoning_gym/coaching/attributes.py
Normal file
73
reasoning_gym/coaching/attributes.py
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
from collections import abc
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class AttributeType(StrEnum):
|
||||||
|
"""Defines how attribute levels should be interpreted"""
|
||||||
|
|
||||||
|
STATIC = "static" # Each level is independent
|
||||||
|
UBOUND = "ubound" # Each level is an upper bound
|
||||||
|
APPEND = "append" # Each level includes all previous levels
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
|
class AttributeDefinition:
|
||||||
|
name: str
|
||||||
|
levels: list
|
||||||
|
default_level: int
|
||||||
|
description: Optional[str] = None
|
||||||
|
attr_type: AttributeType = AttributeType.STATIC # Default to static
|
||||||
|
min_value: Optional[int | float] = None # Minimum value for numeric attributes
|
||||||
|
|
||||||
|
def validate_level(self, level: int, curriculum: str) -> None:
|
||||||
|
"""
|
||||||
|
Validate that a level is valid for an attribute.
|
||||||
|
Args:
|
||||||
|
level: Level to validate
|
||||||
|
curriculum: Name of the curriculum
|
||||||
|
Raises:
|
||||||
|
ValueError: If level is invalid
|
||||||
|
"""
|
||||||
|
# TODO: if > set as [-1], if <0 set as [0]
|
||||||
|
if not 0 <= level < len(self.levels):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid level: {level} for attribute '{curriculum}.{self.name}'. "
|
||||||
|
f"Must be between 0 and {len(self.levels)-1}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_level_value(self, level: int, curriculum: str) -> Any:
|
||||||
|
"""
|
||||||
|
Get the value for an attribute at a specific level based on its type.
|
||||||
|
Args:
|
||||||
|
attr: The attribute definition
|
||||||
|
level: Level to get value for
|
||||||
|
Returns:
|
||||||
|
Value for the attribute based on its level and type
|
||||||
|
"""
|
||||||
|
if self.attr_type == AttributeType.STATIC:
|
||||||
|
return self.levels[level]
|
||||||
|
elif self.attr_type == AttributeType.UBOUND:
|
||||||
|
return self.levels[level]
|
||||||
|
elif self.attr_type == AttributeType.APPEND:
|
||||||
|
return self.levels[: level + 1]
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown attribute type: {self.attr_type} for attribute '{curriculum}.{self.name}'")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
|
class ScalarAttributeDefinition(AttributeDefinition):
|
||||||
|
field_name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
|
class RangeAttributeDefinition(AttributeDefinition):
|
||||||
|
lower_field_name: str
|
||||||
|
upper_field_name: str
|
||||||
|
|
||||||
|
def get_level_value(self, level: int, curriculum: str) -> Any:
|
||||||
|
v = super().get_level_value(level, curriculum)
|
||||||
|
if not isinstance(v, abc.Iterable):
|
||||||
|
return [v]
|
||||||
|
return v
|
||||||
108
reasoning_gym/coaching/base_curriculum.py
Normal file
108
reasoning_gym/coaching/base_curriculum.py
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
from typing import Any, Iterable, Optional
|
||||||
|
|
||||||
|
from ..factory import ConfigT
|
||||||
|
from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCurriculum:
|
||||||
|
def __init__(self, name: str, config_cls: ConfigT):
|
||||||
|
self.name = name
|
||||||
|
self._config_cls = config_cls
|
||||||
|
self._attributes: dict[str, AttributeDefinition] = {}
|
||||||
|
self._current_levels: dict[str, int] = {}
|
||||||
|
|
||||||
|
def generate_configuration(self, defaults: Optional[dict[str, any]] = None) -> ConfigT:
|
||||||
|
config_args = defaults.copy() if defaults is not None else {}
|
||||||
|
for attr in self._attributes.values():
|
||||||
|
if isinstance(attr, RangeAttributeDefinition):
|
||||||
|
vals = self.get_attr_value(attr.name)
|
||||||
|
config_args[attr.lower_field_name] = min(vals)
|
||||||
|
config_args[attr.upper_field_name] = max(vals)
|
||||||
|
elif isinstance(attr, ScalarAttributeDefinition):
|
||||||
|
val = self.get_attr_value(attr.name)
|
||||||
|
config_args[attr.field_name] = val
|
||||||
|
print(config_args)
|
||||||
|
return self._config_cls(**config_args)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attributes(self) -> dict[str, AttributeDefinition]:
|
||||||
|
"""Get the curriculum's attributes"""
|
||||||
|
return self._attributes
|
||||||
|
|
||||||
|
def get_attribute(self, attr_name: str) -> AttributeDefinition:
|
||||||
|
if attr_name not in self._attributes:
|
||||||
|
raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist")
|
||||||
|
return self._attributes[attr_name]
|
||||||
|
|
||||||
|
def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None:
|
||||||
|
for attr in attrs:
|
||||||
|
if attr.name in self.attributes:
|
||||||
|
raise RuntimeError(f"Attribute with name {attr.name} is already defined.")
|
||||||
|
self.attributes[attr.name] = attr
|
||||||
|
|
||||||
|
def get_attr_level(self, attr_name: str) -> int:
|
||||||
|
"""
|
||||||
|
Get the current level for an attribute.
|
||||||
|
Args:
|
||||||
|
attr_name: Name of the attribute
|
||||||
|
Returns:
|
||||||
|
Current level index for the attribute
|
||||||
|
"""
|
||||||
|
attr = self.get_attribute(attr_name)
|
||||||
|
return self._current_levels.get(attr_name, attr.default_level)
|
||||||
|
|
||||||
|
def get_attr_value(self, attr_name: str) -> Any:
|
||||||
|
"""
|
||||||
|
Get the current value for an attribute based on its level.
|
||||||
|
Args:
|
||||||
|
attr_name: Name of the attribute
|
||||||
|
Returns:
|
||||||
|
Current value for the attribute based on its level and type
|
||||||
|
"""
|
||||||
|
attr = self.get_attribute(attr_name)
|
||||||
|
level = self.get_attr_level(attr_name)
|
||||||
|
return attr.get_level_value(level, curriculum=self.name)
|
||||||
|
|
||||||
|
def set_attr_level(self, attr_name: str, level: int) -> None:
|
||||||
|
"""
|
||||||
|
Set the level for an attribute.
|
||||||
|
Args:
|
||||||
|
attr_name: Name of the attribute
|
||||||
|
level: New level index
|
||||||
|
"""
|
||||||
|
attr = self.get_attribute(attr_name)
|
||||||
|
attr.validate_level(level, curriculum=self.name)
|
||||||
|
self._current_levels[attr_name] = level
|
||||||
|
|
||||||
|
def increment_attr_level(self, attr_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Increment the level of an attribute if possible.
|
||||||
|
Args:
|
||||||
|
attr_name: Name of the attribute to increment
|
||||||
|
Returns:
|
||||||
|
bool: True if level was incremented, False if already at max level
|
||||||
|
Raises:
|
||||||
|
KeyError: If attribute doesn't exist
|
||||||
|
"""
|
||||||
|
attr = self.get_attribute(attr_name)
|
||||||
|
current_level = self.get_attr_level(attr_name)
|
||||||
|
if current_level < len(attr.levels) - 1:
|
||||||
|
self.set_attr_level(attr_name, current_level + 1)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def decrement_attr_level(self, attr_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Decrement the level of an attribute if possible.
|
||||||
|
Args:
|
||||||
|
attr_name: Name of the attribute to decrement
|
||||||
|
Returns:
|
||||||
|
bool: True if level was decremented, False if already at min level
|
||||||
|
Raises:
|
||||||
|
KeyError: If attribute doesn't exist
|
||||||
|
"""
|
||||||
|
current_level = self.get_attr_level(attr_name)
|
||||||
|
if current_level > 0:
|
||||||
|
self.set_attr_level(attr_name, current_level - 1)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||||
from statistics import mean, stdev
|
from statistics import mean, stdev
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from .dataset import ProceduralDataset
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -5,6 +5,7 @@ Cognition tasks for training reasoning capabilities.
|
||||||
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
|
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
|
||||||
from .figlet_fonts import FigletFontConfig, FigletFontDataset
|
from .figlet_fonts import FigletFontConfig, FigletFontDataset
|
||||||
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset
|
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset
|
||||||
|
from .rectangle_count import RectangleCountConfig, RectangleCountDataset
|
||||||
from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset
|
from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -16,4 +17,6 @@ __all__ = [
|
||||||
"NumberSequenceDataset",
|
"NumberSequenceDataset",
|
||||||
"RubiksCubeConfig",
|
"RubiksCubeConfig",
|
||||||
"RubiksCubeDataset",
|
"RubiksCubeDataset",
|
||||||
|
"RectangleCountConfig",
|
||||||
|
"RectangleCountDataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
135
reasoning_gym/cognition/rectangle_count.py
Normal file
135
reasoning_gym/cognition/rectangle_count.py
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from random import Random
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def draw_rectangles_with_overlap(n, width, height, rng):
|
||||||
|
# Create a grid that holds a count of how many times a cell is drawn.
|
||||||
|
grid = [[0 for _ in range(width)] for _ in range(height)]
|
||||||
|
rectangles = []
|
||||||
|
|
||||||
|
max_attempts = 100000 # Prevent infinite loops in case of a crowded grid
|
||||||
|
attempts = 0
|
||||||
|
|
||||||
|
while len(rectangles) < n and attempts < max_attempts:
|
||||||
|
attempts += 1
|
||||||
|
# Ensure minimum width and height of 3.
|
||||||
|
# For a rectangle to be at least 3 cells wide, right must be at least left + 2.
|
||||||
|
# Similarly, bottom must be at least top + 2.
|
||||||
|
left = rng.randint(0, width - 3)
|
||||||
|
right = rng.randint(left + 2, width - 1)
|
||||||
|
top = rng.randint(0, height - 3)
|
||||||
|
bottom = rng.randint(top + 2, height - 1)
|
||||||
|
|
||||||
|
# Prepare a list of all the cells that would be updated.
|
||||||
|
cells_to_update = []
|
||||||
|
|
||||||
|
# Top edge:
|
||||||
|
for col in range(left, right + 1):
|
||||||
|
cells_to_update.append((top, col))
|
||||||
|
# Bottom edge:
|
||||||
|
for col in range(left, right + 1):
|
||||||
|
cells_to_update.append((bottom, col))
|
||||||
|
# Left edge (excluding corners already drawn):
|
||||||
|
for row in range(top + 1, bottom):
|
||||||
|
cells_to_update.append((row, left))
|
||||||
|
# Right edge (excluding corners already drawn):
|
||||||
|
for row in range(top + 1, bottom):
|
||||||
|
cells_to_update.append((row, right))
|
||||||
|
|
||||||
|
# Check if drawing this rectangle would cause any cell to exceed a count of 2.
|
||||||
|
conflict = False
|
||||||
|
for r, c in cells_to_update:
|
||||||
|
if grid[r][c] >= 2:
|
||||||
|
conflict = True
|
||||||
|
break
|
||||||
|
if conflict:
|
||||||
|
continue # Skip this rectangle candidate
|
||||||
|
|
||||||
|
# No conflict: update the grid counts.
|
||||||
|
for r, c in cells_to_update:
|
||||||
|
grid[r][c] += 1
|
||||||
|
|
||||||
|
# Save the rectangle (stored as (left, right, top, bottom)).
|
||||||
|
rectangles.append((left, right, top, bottom))
|
||||||
|
|
||||||
|
if len(rectangles) < n:
|
||||||
|
print(f"Only placed {len(rectangles)} rectangles after {attempts} attempts.")
|
||||||
|
|
||||||
|
# Print the grid.
|
||||||
|
# Use ' ' for an untouched cell, '#' for a single hit, and '█' for exactly two hits.
|
||||||
|
lines = ""
|
||||||
|
for row in grid:
|
||||||
|
line = "".join(" " if count == 0 else ("#" if count == 1 else "█") for count in row)
|
||||||
|
lines = lines + line + "\n"
|
||||||
|
return lines, len(rectangles)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RectangleCountConfig:
|
||||||
|
"""Configuration for RectangleCount puzzle generation"""
|
||||||
|
|
||||||
|
max_rectangles: int = 10
|
||||||
|
width: int = 80
|
||||||
|
height: int = 80
|
||||||
|
seed: Optional[int] = None
|
||||||
|
size: int = 500
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert self.width >= 10, "width must be gte 10"
|
||||||
|
assert self.height >= 10, "height must be gte 10"
|
||||||
|
|
||||||
|
|
||||||
|
class RectangleCountDataset(ProceduralDataset):
|
||||||
|
"""Generates [RectangleCount Puzzles](https://en.wikipedia.org/wiki/RectangleCount_Puzzle) with configurable parameters"""
|
||||||
|
|
||||||
|
def __init__(self, config: RectangleCountConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
"""Generate a single RectangleCount task
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys:
|
||||||
|
- question: str, the task description
|
||||||
|
- answer: str, a solution string
|
||||||
|
- metadata: dict with generation parameters
|
||||||
|
"""
|
||||||
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
|
target = rng.randint(1, self.config.max_rectangles)
|
||||||
|
puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng)
|
||||||
|
|
||||||
|
puzz = f"How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. \n\n {puzzle}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": puzz,
|
||||||
|
"answer": str(answer),
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
|
"""Determine if the solution provided solves the RectangleCount task.
|
||||||
|
|
||||||
|
The function awards 1.0 for a correct answer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
answer (Optional[str]): The user's answer.
|
||||||
|
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The computed score between 0.0 and 1.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if answer == None:
|
||||||
|
return 0.0
|
||||||
|
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
|
||||||
|
return 0.01
|
||||||
|
else:
|
||||||
|
return 1.0 # Yay
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig)
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
|
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
|
||||||
|
from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum
|
||||||
|
|
||||||
|
|
||||||
def test_chain_sum_config_validation():
|
def test_chain_sum_config_validation():
|
||||||
|
|
@ -127,3 +128,30 @@ def test_chain_sum_iteration():
|
||||||
first_items = list(dataset)
|
first_items = list(dataset)
|
||||||
second_items = list(dataset)
|
second_items = list(dataset)
|
||||||
assert first_items == second_items, "Multiple iterations should yield same items"
|
assert first_items == second_items, "Multiple iterations should yield same items"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_sum_curriculum():
|
||||||
|
curriculum = ChainSumCurriculum()
|
||||||
|
|
||||||
|
base_value = {"size": 150, "seed": 1}
|
||||||
|
|
||||||
|
base_cfg: ChainSumConfig = curriculum.generate_configuration(base_value)
|
||||||
|
assert base_cfg.seed == 1
|
||||||
|
assert base_cfg.size == 150
|
||||||
|
assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1
|
||||||
|
assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2
|
||||||
|
|
||||||
|
# test incrementing attribute levels for num_terms & num_digits attributes
|
||||||
|
curriculum.increment_attr_level("num_terms")
|
||||||
|
curriculum.increment_attr_level("num_digits")
|
||||||
|
|
||||||
|
increased_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2
|
||||||
|
assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3
|
||||||
|
|
||||||
|
# test decrementing attribute level for num_digits again
|
||||||
|
curriculum.decrement_attr_level("num_digits")
|
||||||
|
|
||||||
|
partially_decreased_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1
|
||||||
|
assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3
|
||||||
|
|
|
||||||
88
tests/test_count_primes.py
Normal file
88
tests/test_count_primes.py
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
"""Tests for Count Primes questions generation"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.algorithmic.count_primes import CountPrimesConfig, CountPrimesDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_count_primes_config_validation():
|
||||||
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = CountPrimesConfig(max_n=-1) # Negative not allowed
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = CountPrimesConfig(max_n=0) # Zero not allowed
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_count_primes_dataset_deterministic():
|
||||||
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
config = CountPrimesConfig(seed=42, size=10)
|
||||||
|
dataset1 = CountPrimesDataset(config)
|
||||||
|
dataset2 = CountPrimesDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset1)):
|
||||||
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_count_primes_dataset_items():
|
||||||
|
"""Test basic properties of generated items"""
|
||||||
|
config = CountPrimesConfig(max_n=10, size=10, seed=42)
|
||||||
|
dataset = CountPrimesDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
# Check item structure
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Check metadata
|
||||||
|
assert "start" in item["metadata"]
|
||||||
|
assert "end" in item["metadata"]
|
||||||
|
assert "primes" in item["metadata"]
|
||||||
|
assert "solution" in item["metadata"]
|
||||||
|
|
||||||
|
start = item["metadata"]["start"]
|
||||||
|
end = item["metadata"]["end"]
|
||||||
|
primes = item["metadata"]["primes"]
|
||||||
|
|
||||||
|
assert start <= end
|
||||||
|
assert len(primes) <= end - start + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_count_primes_dataset_iteration():
|
||||||
|
"""Test that iteration respects dataset size"""
|
||||||
|
config = CountPrimesConfig(size=5, seed=42)
|
||||||
|
dataset = CountPrimesDataset(config)
|
||||||
|
|
||||||
|
items = list(dataset)
|
||||||
|
assert len(items) == config.size
|
||||||
|
|
||||||
|
# Test multiple iterations yield same items
|
||||||
|
assert items == list(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_count_primes_answer():
|
||||||
|
"""Test the _get_primes method"""
|
||||||
|
config = CountPrimesConfig(seed=42)
|
||||||
|
dataset = CountPrimesDataset(config)
|
||||||
|
|
||||||
|
# Base cases
|
||||||
|
assert dataset._get_primes(n=0) == []
|
||||||
|
assert dataset._get_primes(n=1) == []
|
||||||
|
assert dataset._get_primes(n=2) == [False, False]
|
||||||
|
|
||||||
|
# Test primes up to 10
|
||||||
|
primes = dataset._get_primes(n=11)
|
||||||
|
assert primes[2] == True
|
||||||
|
assert primes[3] == True
|
||||||
|
assert primes[4] == False
|
||||||
|
assert primes[5] == True
|
||||||
|
assert primes[6] == False
|
||||||
|
assert primes[7] == True
|
||||||
|
assert primes[8] == False
|
||||||
|
assert primes[9] == False
|
||||||
|
assert primes[10] == False
|
||||||
35
tests/test_dice.py
Normal file
35
tests/test_dice.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.arithmetic.dice import DiceConfig, DiceDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_dice():
|
||||||
|
"""Test basic properties and solution of generated items"""
|
||||||
|
config = DiceConfig(seed=42, size=50, num_dice=8, max_dice_size=24)
|
||||||
|
dataset = DiceDataset(config)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Test the scoring
|
||||||
|
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||||
|
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||||
|
|
||||||
|
# Easy
|
||||||
|
config = DiceConfig(seed=42, size=1, num_dice=1, max_dice_size=2)
|
||||||
|
dataset = DiceDataset(config)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||||
|
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||||
|
|
||||||
|
# Hard
|
||||||
|
config = DiceConfig(seed=42, size=1, num_dice=40, max_dice_size=40)
|
||||||
|
dataset = DiceDataset(config)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||||
|
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||||
19
tests/test_rectangle_count.py
Normal file
19
tests/test_rectangle_count.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.cognition.rectangle_count import RectangleCountConfig, RectangleCountDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_dice():
|
||||||
|
"""Test basic properties and solution of generated items"""
|
||||||
|
config = RectangleCountConfig(seed=42, size=50, max_rectangles=15, width=40, height=40)
|
||||||
|
dataset = RectangleCountDataset(config)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Test the scoring
|
||||||
|
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||||
|
assert dataset.score_answer(answer=None, entry=item) == 0.0
|
||||||
Loading…
Add table
Add a link
Reference in a new issue