Merge branch 'main' into rich/ab

This commit is contained in:
Andreas Köpf 2025-02-11 23:34:48 +01:00 committed by GitHub
commit 27938ce13a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 759 additions and 12 deletions

View file

@ -10,6 +10,7 @@ from .ab import ABConfig, ABDataset
from .base_conversion import BaseConversionConfig, BaseConversionDataset from .base_conversion import BaseConversionConfig, BaseConversionDataset
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
from .count_primes import CountPrimesConfig, CountPrimesDataset
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
from .letter_counting import LetterCountingConfig, LetterCountingDataset from .letter_counting import LetterCountingConfig, LetterCountingDataset
@ -69,4 +70,6 @@ __all__ = [
"BinaryMatrixDataset", "BinaryMatrixDataset",
"ABConfig", "ABConfig",
"ABDataset", "ABDataset",
"CountPrimesConfig",
"CountPrimesDataset",
] ]

View file

@ -0,0 +1,63 @@
"""Count prime numbers in a given interval.
Solution obtained with Sieve of Eratosthenes:
https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
"""
import math
from dataclasses import dataclass
from random import Random
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?"""
@dataclass
class CountPrimesConfig:
    """Configuration for Count Primes dataset generation"""

    max_n: int = 10_000  # inclusive upper bound of the sampled interval
    size: int = 500  # virtual dataset size
    seed: Optional[int] = None  # RNG seed; None lets the dataset choose

    def validate(self):
        """Validate configuration parameters"""
        assert self.max_n >= 1, "max_n must be at least 1"
class CountPrimesDataset(ProceduralDataset):
    """Generates Count Primes exercises with configurable difficulty"""

    def __init__(self, config: CountPrimesConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)
        # Sieve once up front; each __getitem__ then only slices the flag list.
        self.primes = self._get_primes(config.max_n + 1)

    def _get_primes(self, n: int) -> list[bool]:
        """Sieve of Eratosthenes: index i is True iff i is prime, for 0 <= i < n."""
        if n <= 1:
            return []
        primes = [True] * n
        primes[0] = primes[1] = False
        for i in range(2, int(math.sqrt(n)) + 1):
            if primes[i]:
                # Start at i*i: smaller multiples of i were already struck
                # out by smaller prime factors.
                for j in range(i * i, n, i):
                    primes[j] = False
        return primes

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Count Primes question

        Returns:
            dict with question/answer strings and metadata holding the
            interval bounds, the primality flags of the interval, and the count.
        """
        # Per-index RNG keeps items deterministic regardless of access order.
        rng = Random(self.seed + idx)
        start = rng.randint(1, self.config.max_n)
        end = rng.randint(start, self.config.max_n)
        primes = self.primes[start : end + 1]
        answer = sum(primes)  # each True flag counts as 1
        return {
            "question": QUESTION_TEMPLATE.format(start=start, end=end),
            "answer": str(answer),
            "metadata": {"start": start, "end": end, "primes": primes, "solution": answer},
        }


register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig)

View file

@ -60,22 +60,16 @@ class RotateMatrixDataset(ProceduralDataset):
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)] matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
return matrix return matrix
def _rot90(self, matrix: list[list[int]]) -> list[list[int]]:
"""quarter clockwise rotation"""
return [list(row) for row in zip(*matrix[::-1])]
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]: def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
"""Rotate the matrix K times by 90 degrees clockwise""" """Rotate the matrix K times by 90 degrees clockwise"""
num_rotations %= 4 num_rotations %= 4
n = len(matrix)
output = deepcopy(matrix) output = deepcopy(matrix)
for _ in range(num_rotations): for _ in range(num_rotations):
for l in range(n // 2): output = self._rot90(output)
for i in range(l, n - 1 - l):
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
output[n - 1 - i][l],
output[l][i],
output[i][n - 1 - l],
output[n - 1 - l][n - 1 - i],
)
return output return output
def _matrix_to_str(self, matrix: list[list[int]]) -> str: def _matrix_to_str(self, matrix: list[list[int]]) -> str:

View file

@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSum, ChainSumConfig from .chain_sum import ChainSum, ChainSumConfig
from .count_bits import CountBitsConfig, CountBitsDataset from .count_bits import CountBitsConfig, CountBitsDataset
from .dice import DiceConfig, DiceDataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset from .gcd import GCDConfig, GCDDataset
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
@ -38,4 +39,6 @@ __all__ = [
"TimeIntervalsDataset", "TimeIntervalsDataset",
"CountBitsConfig", "CountBitsConfig",
"CountBitsDataset", "CountBitsDataset",
"DiceConfig",
"DiceDataset",
] ]

View file

@ -2,6 +2,7 @@ import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -112,5 +113,36 @@ class ChainSum(ProceduralDataset):
return expression, result return expression, result
class ChainSumCurriculum(BaseCurriculum):
    """Curriculum scaling ChainSum difficulty via term count and digit width."""

    def __init__(self):
        super().__init__(ChainSumCurriculum.__name__, ChainSumConfig)

        # num_terms drives the min_terms/max_terms config range.
        terms_attr = RangeAttributeDefinition(
            name="num_terms",
            levels=[2, 3, 4, 5],
            default_level=0,  # start with 2 terms
            description="Maximum number of terms in the expression",
            attr_type=AttributeType.APPEND,
            min_value=2,  # ensure at least 2 terms
            lower_field_name="min_terms",
            upper_field_name="max_terms",
        )
        # num_digits drives the min_digits/max_digits config range.
        digits_attr = RangeAttributeDefinition(
            name="num_digits",
            levels=[1, 2, 4, 10],
            default_level=0,  # start with 1-digit numbers
            description="Number of digits in each operand",
            attr_type=AttributeType.APPEND,
            min_value=1,  # operands are at least 1 digit wide
            lower_field_name="min_digits",
            upper_field_name="max_digits",
        )
        self._define_attributes((terms_attr, digits_attr))
# Register the dataset # Register the dataset
register_dataset("chain_sum", ChainSum, ChainSumConfig) register_dataset("chain_sum", ChainSum, ChainSumConfig)

View file

@ -0,0 +1,149 @@
from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random
from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
def compute_probability(dice, target):
"""
Computes the probability of rolling a total of at least `target`
when rolling dice specified in the list `dice`. Each element in dice
is the number of sides on that die. The computation is done via dynamic programming.
Returns the probability as a fraction (numerator, denominator) and as a float.
"""
# dp[i][s] = number of ways to get sum s using the first i dice.
# We use only one dictionary for the current dp state.
dp = {0: 1}
for sides in dice:
new_dp = {}
for current_sum, count in dp.items():
# Each die gives a number from 1 to sides.
for face in range(1, sides + 1):
new_sum = current_sum + face
new_dp[new_sum] = new_dp.get(new_sum, 0) + count
dp = new_dp
total_outcomes = reduce(lambda a, b: a * b, dice, 1)
ways = sum(count for s, count in dp.items() if s >= target)
# Simplify the fraction (ways / total_outcomes)
def simplify(n, d):
common = gcd(n, d)
return n // common, d // common
frac = simplify(ways, total_outcomes)
return frac, ways / total_outcomes
def generate_puzzle(num_dice, max_dice_size, rng):
"""
Generates a puzzle:
- It forces one die to have max_dice_size.
- The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1.
- The dice are then shuffled.
- The target total is chosen roughly in the middle (but you can adjust the method).
It then computes the probability of rolling a total at least the target.
Finally, it prints out the puzzle statement and the answer.
"""
# Guarantee one die is the maximum.
dice = [max_dice_size]
for _ in range(num_dice - 1):
# Choose a die size randomly from 2 up to max_dice_size-1.
# (If max_dice_size == 2 then all dice are 2-sided.)
if max_dice_size > 2:
die = rng.randint(2, max_dice_size - 1)
else:
die = 2
dice.append(die)
# Optionally, sort dice in descending order (as is common in puzzles)
dice.sort(reverse=True)
# Compute minimum and maximum possible totals.
min_total = num_dice # each die gives at least 1
max_total = sum(dice)
# Choose a target total. For an interesting puzzle,
# we choose a target somewhere in the middle third of the range.
low_target = min_total + (max_total - min_total) // 3
high_target = min_total + 2 * (max_total - min_total) // 3
target = rng.randint(low_target, high_target)
# Compute probability.
(num, den), prob = compute_probability(dice, target)
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
dice_str = ", ".join(f"1d{s}" for s in dice)
# Return the puzzle.
return {"dice_str": dice_str, "target": target, "num": num, "den": den}
@dataclass
class DiceConfig:
    """Configuration for dice puzzle generation"""

    num_dice: int = 4  # how many dice are rolled per puzzle
    max_dice_size: int = 20  # largest die; one die always has this many sides
    seed: Optional[int] = None
    size: int = 500  # virtual dataset size

    def validate(self):
        """Validate configuration parameters"""
        checks = (
            (self.num_dice >= 1, "num_dice must be gte 1"),
            (self.max_dice_size >= 2, "max_dice_size must be gte 2"),
        )
        for ok, message in checks:
            assert ok, message
class DiceDataset(ProceduralDataset):
    """Generates Dice-based puzzles with configurable parameters"""

    def __init__(self, config: DiceConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Dice task

        Returns:
            dict with keys:
                - question: str, the task description
                - answer: str, the reduced fraction, e.g. "3/8"
                - metadata: dict with generation parameters
        """
        # Per-index RNG keeps items deterministic regardless of access order.
        rng = Random(self.seed + idx)
        puzzle = generate_puzzle(self.config.num_dice, self.config.max_dice_size, rng)
        puzzle_str = f"I have these dice: {puzzle['dice_str']}. What are the odds of rolling {puzzle['target']} or higher? (Assume that all dice are rolled at once, and that '1d6' represents one roll of a 6-sided dice.) Please respond with a reduced fraction representing the probability [ex., 1/60]."
        answer_str = f"{puzzle['num']}/{puzzle['den']}"
        # NOTE(review): metadata is intentionally empty; consider recording the
        # dice and target so consumers need not reparse the question text.
        return {
            "question": puzzle_str,
            "answer": answer_str,
            "metadata": {},
        }

    def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
        """Determine if the solution provided solves the Dice task.

        Awards 1.0 for an exact (case- and newline-insensitive) match with the
        reference answer, 0.01 for any other non-None answer, and 0.0 for None.

        Args:
            answer (Optional[str]): The user's answer.
            entry (Dict[str, any]): The original dataset entry containing the correct answer.

        Returns:
            float: The computed score between 0.0 and 1.0.
        """
        if answer is None:  # was `answer == None`; identity check is the idiom
            return 0.0

        def normalize(text: str) -> str:
            return text.lower().replace("\n", "")

        # Partial credit (0.01) signals a well-formed but wrong attempt.
        return 1.0 if normalize(answer) == normalize(entry["answer"]) else 0.01


register_dataset("dice", DiceDataset, DiceConfig)

View file

@ -0,0 +1,14 @@
from .attributes import AttributeDefinition, AttributeType, RangeAttributeDefinition
from .base_curriculum import BaseCurriculum
from .coach import Coach, GroupedScores, ScoreBoard, ScoreStats
__all__ = [
"AttributeType",
"AttributeDefinition",
"RangeAttributeDefinition",
"BaseCurriculum",
"Coach",
"ScoreBoard",
"GroupedScores",
"ScoreStats",
]

View file

@ -0,0 +1,73 @@
from collections import abc
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Optional
class AttributeType(StrEnum):
"""Defines how attribute levels should be interpreted"""
STATIC = "static" # Each level is independent
UBOUND = "ubound" # Each level is an upper bound
APPEND = "append" # Each level includes all previous levels
@dataclass(kw_only=True)
class AttributeDefinition:
name: str
levels: list
default_level: int
description: Optional[str] = None
attr_type: AttributeType = AttributeType.STATIC # Default to static
min_value: Optional[int | float] = None # Minimum value for numeric attributes
def validate_level(self, level: int, curriculum: str) -> None:
"""
Validate that a level is valid for an attribute.
Args:
level: Level to validate
curriculum: Name of the curriculum
Raises:
ValueError: If level is invalid
"""
# TODO: if > set as [-1], if <0 set as [0]
if not 0 <= level < len(self.levels):
raise ValueError(
f"Invalid level: {level} for attribute '{curriculum}.{self.name}'. "
f"Must be between 0 and {len(self.levels)-1}"
)
def get_level_value(self, level: int, curriculum: str) -> Any:
"""
Get the value for an attribute at a specific level based on its type.
Args:
attr: The attribute definition
level: Level to get value for
Returns:
Value for the attribute based on its level and type
"""
if self.attr_type == AttributeType.STATIC:
return self.levels[level]
elif self.attr_type == AttributeType.UBOUND:
return self.levels[level]
elif self.attr_type == AttributeType.APPEND:
return self.levels[: level + 1]
raise ValueError(f"Unknown attribute type: {self.attr_type} for attribute '{curriculum}.{self.name}'")
@dataclass(kw_only=True)
class ScalarAttributeDefinition(AttributeDefinition):
    # Name of the single config field this attribute's value is written to.
    field_name: str
@dataclass(kw_only=True)
class RangeAttributeDefinition(AttributeDefinition):
    """Attribute mapped onto a (lower, upper) pair of config fields."""

    lower_field_name: str  # config field receiving the range minimum
    upper_field_name: str  # config field receiving the range maximum

    def get_level_value(self, level: int, curriculum: str) -> Any:
        """Return the level value, normalized to an iterable of candidates."""
        value = super().get_level_value(level, curriculum)
        # Scalar levels (e.g. STATIC/UBOUND types) are wrapped so callers can
        # always take min()/max() of the result.
        return value if isinstance(value, abc.Iterable) else [value]

View file

@ -0,0 +1,108 @@
from typing import Any, Iterable, Optional
from ..factory import ConfigT
from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition
class BaseCurriculum:
    """Tracks per-attribute difficulty levels and materializes dataset
    configurations from them.

    Subclasses call `_define_attributes` in their `__init__`; callers then move
    difficulty via `increment_attr_level` / `decrement_attr_level` and build
    configs with `generate_configuration`.
    """

    def __init__(self, name: str, config_cls: ConfigT):
        """
        Args:
            name: Curriculum name, used to qualify attribute names in errors
            config_cls: Config class instantiated by `generate_configuration`
        """
        self.name = name
        self._config_cls = config_cls
        self._attributes: dict[str, AttributeDefinition] = {}
        self._current_levels: dict[str, int] = {}

    def generate_configuration(self, defaults: Optional[dict[str, Any]] = None) -> ConfigT:
        """Build a config instance reflecting the current attribute levels.

        Args:
            defaults: Base constructor arguments (e.g. size/seed); a copy is
                extended with the attribute-derived fields.

        Returns:
            A new instance of the curriculum's config class.
        """
        config_args = defaults.copy() if defaults is not None else {}
        for attr in self._attributes.values():
            if isinstance(attr, RangeAttributeDefinition):
                # Range attributes fill a (lower, upper) pair of config fields.
                vals = self.get_attr_value(attr.name)
                config_args[attr.lower_field_name] = min(vals)
                config_args[attr.upper_field_name] = max(vals)
            elif isinstance(attr, ScalarAttributeDefinition):
                val = self.get_attr_value(attr.name)
                config_args[attr.field_name] = val
        # (removed a leftover debug `print(config_args)` here)
        return self._config_cls(**config_args)

    @property
    def attributes(self) -> dict[str, AttributeDefinition]:
        """Get the curriculum's attributes"""
        return self._attributes

    def get_attribute(self, attr_name: str) -> AttributeDefinition:
        """Look up an attribute definition by name.

        Raises:
            KeyError: If the attribute does not exist
        """
        if attr_name not in self._attributes:
            raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist")
        return self._attributes[attr_name]

    def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None:
        """Register attribute definitions; names must be unique.

        Raises:
            RuntimeError: If an attribute name is already defined
        """
        for attr in attrs:
            if attr.name in self.attributes:
                raise RuntimeError(f"Attribute with name {attr.name} is already defined.")
            self.attributes[attr.name] = attr

    def get_attr_level(self, attr_name: str) -> int:
        """
        Get the current level for an attribute.

        Args:
            attr_name: Name of the attribute

        Returns:
            Current level index for the attribute; the attribute's
            default_level when no level was set explicitly
        """
        attr = self.get_attribute(attr_name)
        return self._current_levels.get(attr_name, attr.default_level)

    def get_attr_value(self, attr_name: str) -> Any:
        """
        Get the current value for an attribute based on its level.

        Args:
            attr_name: Name of the attribute

        Returns:
            Current value for the attribute based on its level and type
        """
        attr = self.get_attribute(attr_name)
        level = self.get_attr_level(attr_name)
        return attr.get_level_value(level, curriculum=self.name)

    def set_attr_level(self, attr_name: str, level: int) -> None:
        """
        Set the level for an attribute.

        Args:
            attr_name: Name of the attribute
            level: New level index

        Raises:
            KeyError: If the attribute doesn't exist
            ValueError: If the level index is out of range
        """
        attr = self.get_attribute(attr_name)
        attr.validate_level(level, curriculum=self.name)
        self._current_levels[attr_name] = level

    def increment_attr_level(self, attr_name: str) -> bool:
        """
        Increment the level of an attribute if possible.

        Args:
            attr_name: Name of the attribute to increment

        Returns:
            bool: True if level was incremented, False if already at max level

        Raises:
            KeyError: If attribute doesn't exist
        """
        attr = self.get_attribute(attr_name)
        current_level = self.get_attr_level(attr_name)
        if current_level < len(attr.levels) - 1:
            self.set_attr_level(attr_name, current_level + 1)
            return True
        return False

    def decrement_attr_level(self, attr_name: str) -> bool:
        """
        Decrement the level of an attribute if possible.

        Args:
            attr_name: Name of the attribute to decrement

        Returns:
            bool: True if level was decremented, False if already at min level

        Raises:
            KeyError: If attribute doesn't exist
        """
        current_level = self.get_attr_level(attr_name)
        if current_level > 0:
            self.set_attr_level(attr_name, current_level - 1)
            return True
        return False

View file

@ -8,7 +8,7 @@ from pathlib import Path
from statistics import mean, stdev from statistics import mean, stdev
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from .dataset import ProceduralDataset from ..dataset import ProceduralDataset
@dataclass @dataclass

View file

@ -5,6 +5,7 @@ Cognition tasks for training reasoning capabilities.
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
from .figlet_fonts import FigletFontConfig, FigletFontDataset from .figlet_fonts import FigletFontConfig, FigletFontDataset
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset from .number_sequences import NumberSequenceConfig, NumberSequenceDataset
from .rectangle_count import RectangleCountConfig, RectangleCountDataset
from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset
__all__ = [ __all__ = [
@ -16,4 +17,6 @@ __all__ = [
"NumberSequenceDataset", "NumberSequenceDataset",
"RubiksCubeConfig", "RubiksCubeConfig",
"RubiksCubeDataset", "RubiksCubeDataset",
"RectangleCountConfig",
"RectangleCountDataset",
] ]

View file

@ -0,0 +1,135 @@
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
def draw_rectangles_with_overlap(n, width, height, rng):
# Create a grid that holds a count of how many times a cell is drawn.
grid = [[0 for _ in range(width)] for _ in range(height)]
rectangles = []
max_attempts = 100000 # Prevent infinite loops in case of a crowded grid
attempts = 0
while len(rectangles) < n and attempts < max_attempts:
attempts += 1
# Ensure minimum width and height of 3.
# For a rectangle to be at least 3 cells wide, right must be at least left + 2.
# Similarly, bottom must be at least top + 2.
left = rng.randint(0, width - 3)
right = rng.randint(left + 2, width - 1)
top = rng.randint(0, height - 3)
bottom = rng.randint(top + 2, height - 1)
# Prepare a list of all the cells that would be updated.
cells_to_update = []
# Top edge:
for col in range(left, right + 1):
cells_to_update.append((top, col))
# Bottom edge:
for col in range(left, right + 1):
cells_to_update.append((bottom, col))
# Left edge (excluding corners already drawn):
for row in range(top + 1, bottom):
cells_to_update.append((row, left))
# Right edge (excluding corners already drawn):
for row in range(top + 1, bottom):
cells_to_update.append((row, right))
# Check if drawing this rectangle would cause any cell to exceed a count of 2.
conflict = False
for r, c in cells_to_update:
if grid[r][c] >= 2:
conflict = True
break
if conflict:
continue # Skip this rectangle candidate
# No conflict: update the grid counts.
for r, c in cells_to_update:
grid[r][c] += 1
# Save the rectangle (stored as (left, right, top, bottom)).
rectangles.append((left, right, top, bottom))
if len(rectangles) < n:
print(f"Only placed {len(rectangles)} rectangles after {attempts} attempts.")
# Print the grid.
# Use ' ' for an untouched cell, '#' for a single hit, and '█' for exactly two hits.
lines = ""
for row in grid:
line = "".join(" " if count == 0 else ("#" if count == 1 else "") for count in row)
lines = lines + line + "\n"
return lines, len(rectangles)
@dataclass
class RectangleCountConfig:
    """Configuration for RectangleCount puzzle generation"""

    max_rectangles: int = 10  # upper bound for the sampled rectangle count
    width: int = 80  # grid width in cells
    height: int = 80  # grid height in cells
    seed: Optional[int] = None
    size: int = 500  # virtual dataset size

    def validate(self):
        """Validate configuration parameters"""
        # max_rectangles feeds randint(1, max_rectangles), so it must be >= 1.
        assert self.max_rectangles >= 1, "max_rectangles must be gte 1"
        assert self.width >= 10, "width must be gte 10"
        assert self.height >= 10, "height must be gte 10"
class RectangleCountDataset(ProceduralDataset):
    """Generates ASCII rectangle-counting puzzles with configurable parameters"""

    def __init__(self, config: RectangleCountConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single RectangleCount task

        Returns:
            dict with keys:
                - question: str, the task description with the rendered grid
                - answer: str, the number of rectangles actually drawn
                - metadata: dict with generation parameters
        """
        # Per-index RNG keeps items deterministic regardless of access order.
        rng = Random(self.seed + idx)
        target = rng.randint(1, self.config.max_rectangles)
        # `answer` may be smaller than `target` if the grid got too crowded.
        puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng)
        # NOTE(review): the overlap glyph quoted below should match the two-hit
        # character emitted by draw_rectangles_with_overlap — confirm they agree.
        puzz = f"How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with ''. \n\n {puzzle}"
        return {
            "question": puzz,
            "answer": str(answer),
            "metadata": {},
        }

    def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
        """Determine if the solution provided solves the RectangleCount task.

        Awards 1.0 for an exact (case- and newline-insensitive) match with the
        reference answer, 0.01 for any other non-None answer, and 0.0 for None.

        Args:
            answer (Optional[str]): The user's answer.
            entry (Dict[str, any]): The original dataset entry containing the correct answer.

        Returns:
            float: The computed score between 0.0 and 1.0.
        """
        if answer is None:  # was `answer == None`; identity check is the idiom
            return 0.0

        def normalize(text: str) -> str:
            return text.lower().replace("\n", "")

        # Partial credit (0.01) signals a well-formed but wrong attempt.
        return 1.0 if normalize(answer) == normalize(entry["answer"]) else 0.01


register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig)

View file

@ -1,6 +1,7 @@
import pytest import pytest
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum
def test_chain_sum_config_validation(): def test_chain_sum_config_validation():
@ -127,3 +128,30 @@ def test_chain_sum_iteration():
first_items = list(dataset) first_items = list(dataset)
second_items = list(dataset) second_items = list(dataset)
assert first_items == second_items, "Multiple iterations should yield same items" assert first_items == second_items, "Multiple iterations should yield same items"
def test_chain_sum_curriculum():
    """Levels of the ChainSum curriculum map onto config term/digit ranges."""
    curriculum = ChainSumCurriculum()
    defaults = {"size": 150, "seed": 1}

    cfg: ChainSumConfig = curriculum.generate_configuration(defaults)
    assert (cfg.seed, cfg.size) == (1, 150)
    assert (cfg.min_digits, cfg.max_digits) == (1, 1)
    assert (cfg.min_terms, cfg.max_terms) == (2, 2)

    # Raising both attribute levels widens both ranges.
    for attr_name in ("num_terms", "num_digits"):
        curriculum.increment_attr_level(attr_name)
    cfg = curriculum.generate_configuration(defaults)
    assert (cfg.min_digits, cfg.max_digits) == (1, 2)
    assert (cfg.min_terms, cfg.max_terms) == (2, 3)

    # Lowering only num_digits leaves the num_terms range untouched.
    curriculum.decrement_attr_level("num_digits")
    cfg = curriculum.generate_configuration(defaults)
    assert (cfg.min_digits, cfg.max_digits) == (1, 1)
    assert (cfg.min_terms, cfg.max_terms) == (2, 3)

View file

@ -0,0 +1,88 @@
"""Tests for Count Primes questions generation"""
import pytest
from reasoning_gym.algorithmic.count_primes import CountPrimesConfig, CountPrimesDataset
def test_count_primes_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Both negative and zero upper bounds are rejected.
    for bad_max_n in (-1, 0):
        with pytest.raises(AssertionError):
            CountPrimesConfig(max_n=bad_max_n).validate()
def test_count_primes_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = CountPrimesConfig(seed=42, size=10)
    first = CountPrimesDataset(config)
    second = CountPrimesDataset(config)
    for idx in range(len(first)):
        assert first[idx] == second[idx]
def test_count_primes_dataset_items():
    """Test basic properties of generated items"""
    dataset = CountPrimesDataset(CountPrimesConfig(max_n=10, size=10, seed=42))
    for idx in range(len(dataset)):
        item = dataset[idx]
        # Every item is a dict with the standard keys.
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        # Metadata records the interval and its primality flags.
        metadata = item["metadata"]
        for key in ("start", "end", "primes", "solution"):
            assert key in metadata
        assert metadata["start"] <= metadata["end"]
        assert len(metadata["primes"]) <= metadata["end"] - metadata["start"] + 1
def test_count_primes_dataset_iteration():
    """Test that iteration respects dataset size"""
    config = CountPrimesConfig(size=5, seed=42)
    dataset = CountPrimesDataset(config)
    first_pass = list(dataset)
    assert len(first_pass) == config.size
    # Repeated iteration yields exactly the same items.
    assert first_pass == list(dataset)
def test_count_primes_answer():
    """Test the _get_primes sieve helper"""
    config = CountPrimesConfig(seed=42)
    dataset = CountPrimesDataset(config)

    # Base cases: no primes below 2
    assert dataset._get_primes(n=0) == []
    assert dataset._get_primes(n=1) == []
    assert dataset._get_primes(n=2) == [False, False]

    # Primality flags for 2..10 (avoid `== True` comparisons, flake8 E712)
    primes = dataset._get_primes(n=11)
    expected = {2: True, 3: True, 4: False, 5: True, 6: False, 7: True, 8: False, 9: False, 10: False}
    for value, is_prime in expected.items():
        assert primes[value] is is_prime

35
tests/test_dice.py Normal file
View file

@ -0,0 +1,35 @@
import pytest
from reasoning_gym.arithmetic.dice import DiceConfig, DiceDataset
def test_dice():
    """Test basic properties and solution of generated items"""
    scenarios = (
        dict(seed=42, size=50, num_dice=8, max_dice_size=24),  # typical
        dict(seed=42, size=1, num_dice=1, max_dice_size=2),  # easy
        dict(seed=42, size=1, num_dice=40, max_dice_size=40),  # hard
    )
    for kwargs in scenarios:
        dataset = DiceDataset(DiceConfig(**kwargs))
        for item in dataset:
            assert isinstance(item, dict)
            for key in ("question", "answer", "metadata"):
                assert key in item
            # The reference answer scores 1.0; a missing answer scores 0.0.
            assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
            assert dataset.score_answer(answer=None, entry=item) == 0.0

View file

@ -0,0 +1,19 @@
import pytest
from reasoning_gym.cognition.rectangle_count import RectangleCountConfig, RectangleCountDataset
def test_rectangle_count():
    """Test basic properties and scoring of generated rectangle-count items"""
    # Renamed from `test_dice` — a copy-paste leftover from tests/test_dice.py
    # that mislabelled failures and would shadow on any future file merge.
    config = RectangleCountConfig(seed=42, size=50, max_rectangles=15, width=40, height=40)
    dataset = RectangleCountDataset(config)
    for item in dataset:
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item
        # Test the scoring
        assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
        assert dataset.score_answer(answer=None, entry=item) == 0.0