Merge branch 'main' into rich/ab

This commit is contained in:
Andreas Köpf 2025-02-11 23:34:48 +01:00 committed by GitHub
commit 27938ce13a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 759 additions and 12 deletions

View file

@ -10,6 +10,7 @@ from .ab import ABConfig, ABDataset
from .base_conversion import BaseConversionConfig, BaseConversionDataset
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
from .count_primes import CountPrimesConfig, CountPrimesDataset
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
from .letter_counting import LetterCountingConfig, LetterCountingDataset
@ -69,4 +70,6 @@ __all__ = [
"BinaryMatrixDataset",
"ABConfig",
"ABDataset",
"CountPrimesConfig",
"CountPrimesDataset",
]

View file

@ -0,0 +1,63 @@
"""Count prime numbers in a given interval.
Solution obtained with Sieve of Eratosthenes:
https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
"""
import math
from dataclasses import dataclass
from random import Random
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?"""
@dataclass
class CountPrimesConfig:
"""Configuration for Count Primes dataset generation"""
max_n: int = 10_000 # Upper bound for the interval
size: int = 500 # Virtual dataset size
seed: Optional[int] = None
def validate(self):
"""Validate configuration parameters"""
assert 1 <= self.max_n, "max_n must be at least 1"
class CountPrimesDataset(ProceduralDataset):
"""Generates Count Primes exercises with configurable difficulty"""
def __init__(self, config: CountPrimesConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
self.primes = self._get_primes(config.max_n + 1)
def _get_primes(self, n: int) -> list[bool]:
if n <= 1:
return []
primes = [True] * n
primes[0] = primes[1] = False
for i in range(2, int(math.sqrt(n)) + 1):
if primes[i]:
for j in range(2 * i, n, i):
primes[j] = False
return primes
def __getitem__(self, idx: int) -> dict:
"""Generate a single Count Primes question"""
rng = Random(self.seed + idx)
start = rng.randint(1, self.config.max_n)
end = rng.randint(start, self.config.max_n)
primes = self.primes[start : end + 1]
answer = sum(primes)
return {
"question": QUESTION_TEMPLATE.format(start=start, end=end),
"answer": str(answer),
"metadata": {"start": start, "end": end, "primes": primes, "solution": answer},
}
register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig)

View file

@ -60,22 +60,16 @@ class RotateMatrixDataset(ProceduralDataset):
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
return matrix
def _rot90(self, matrix: list[list[int]]) -> list[list[int]]:
"""quarter clockwise rotation"""
return [list(row) for row in zip(*matrix[::-1])]
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
"""Rotate the matrix K times by 90 degrees clockwise"""
num_rotations %= 4
n = len(matrix)
output = deepcopy(matrix)
for _ in range(num_rotations):
for l in range(n // 2):
for i in range(l, n - 1 - l):
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
output[n - 1 - i][l],
output[l][i],
output[i][n - 1 - l],
output[n - 1 - l][n - 1 - i],
)
output = self._rot90(output)
return output
def _matrix_to_str(self, matrix: list[list[int]]) -> str:

View file

@ -6,6 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSum, ChainSumConfig
from .count_bits import CountBitsConfig, CountBitsDataset
from .dice import DiceConfig, DiceDataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
@ -38,4 +39,6 @@ __all__ = [
"TimeIntervalsDataset",
"CountBitsConfig",
"CountBitsDataset",
"DiceConfig",
"DiceDataset",
]

View file

@ -2,6 +2,7 @@ import random
from dataclasses import dataclass
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -112,5 +113,36 @@ class ChainSum(ProceduralDataset):
return expression, result
class ChainSumCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(ChainSumCurriculum.__name__, ChainSumConfig)
# Define attributes
self._define_attributes(
(
RangeAttributeDefinition(
name="num_terms",
levels=[2, 3, 4, 5],
default_level=0, # Start with 2 terms
description="Maximum number of terms in the expression",
attr_type=AttributeType.APPEND,
min_value=2, # Ensure at least 2 terms
lower_field_name="min_terms",
upper_field_name="max_terms",
),
RangeAttributeDefinition(
name="num_digits",
levels=[1, 2, 4, 10],
default_level=0, # Start with 1-digit numbers
description="Number of digits in each operand",
attr_type=AttributeType.APPEND,
min_value=1, # Ensure numbers are at least 1 digit
lower_field_name="min_digits",
upper_field_name="max_digits",
),
)
)
# Register the dataset
register_dataset("chain_sum", ChainSum, ChainSumConfig)

View file

@ -0,0 +1,149 @@
from dataclasses import dataclass
from functools import reduce
from math import gcd
from random import Random
from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
def compute_probability(dice, target):
"""
Computes the probability of rolling a total of at least `target`
when rolling dice specified in the list `dice`. Each element in dice
is the number of sides on that die. The computation is done via dynamic programming.
Returns the probability as a fraction (numerator, denominator) and as a float.
"""
# dp[i][s] = number of ways to get sum s using the first i dice.
# We use only one dictionary for the current dp state.
dp = {0: 1}
for sides in dice:
new_dp = {}
for current_sum, count in dp.items():
# Each die gives a number from 1 to sides.
for face in range(1, sides + 1):
new_sum = current_sum + face
new_dp[new_sum] = new_dp.get(new_sum, 0) + count
dp = new_dp
total_outcomes = reduce(lambda a, b: a * b, dice, 1)
ways = sum(count for s, count in dp.items() if s >= target)
# Simplify the fraction (ways / total_outcomes)
def simplify(n, d):
common = gcd(n, d)
return n // common, d // common
frac = simplify(ways, total_outcomes)
return frac, ways / total_outcomes
def generate_puzzle(num_dice, max_dice_size, rng):
"""
Generates a puzzle:
- It forces one die to have max_dice_size.
- The other (num_dice-1) dice are chosen randomly between 2 and max_dice_size-1.
- The dice are then shuffled.
- The target total is chosen roughly in the middle (but you can adjust the method).
It then computes the probability of rolling a total at least the target.
Finally, it prints out the puzzle statement and the answer.
"""
# Guarantee one die is the maximum.
dice = [max_dice_size]
for _ in range(num_dice - 1):
# Choose a die size randomly from 2 up to max_dice_size-1.
# (If max_dice_size == 2 then all dice are 2-sided.)
if max_dice_size > 2:
die = rng.randint(2, max_dice_size - 1)
else:
die = 2
dice.append(die)
# Optionally, sort dice in descending order (as is common in puzzles)
dice.sort(reverse=True)
# Compute minimum and maximum possible totals.
min_total = num_dice # each die gives at least 1
max_total = sum(dice)
# Choose a target total. For an interesting puzzle,
# we choose a target somewhere in the middle third of the range.
low_target = min_total + (max_total - min_total) // 3
high_target = min_total + 2 * (max_total - min_total) // 3
target = rng.randint(low_target, high_target)
# Compute probability.
(num, den), prob = compute_probability(dice, target)
# Create a string representing the dice, e.g., "1d20, 1d17, 1d6" etc.
dice_str = ", ".join(f"1d{s}" for s in dice)
# Return the puzzle.
return {"dice_str": dice_str, "target": target, "num": num, "den": den}
@dataclass
class DiceConfig:
"""Configuration for dice puzzle generation"""
num_dice: int = 4
max_dice_size: int = 20
seed: Optional[int] = None
size: int = 500
def validate(self):
"""Validate configuration parameters"""
assert self.num_dice >= 1, "num_dice must be gte 1"
assert self.max_dice_size >= 2, "max_dice_size must be gte 2"
class DiceDataset(ProceduralDataset):
"""Generates Dice-based puzzles with configurable parameters"""
def __init__(self, config: DiceConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single Dice task
Returns:
dict with keys:
- question: str, the task description
- answer: str, a solution string
- metadata: dict with generation parameters
"""
rng = Random(self.seed + idx)
puzzle = generate_puzzle(self.config.num_dice, self.config.max_dice_size, rng)
puzzle_str = f"I have these dice: {puzzle['dice_str']}. What are the odds of rolling {puzzle['target']} or higher? (Assume that all dice are rolled at once, and that '1d6' represents one roll of a 6-sided dice.) Please respond with a reduced fraction representing the probability [ex., 1/60]."
answer_str = f"{puzzle['num']}/{puzzle['den']}"
return {
"question": puzzle_str,
"answer": answer_str,
"metadata": {},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Determine if the solution provided solves the Dice task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
return 0.01
else:
return 1.0 # Yay
register_dataset("dice", DiceDataset, DiceConfig)

View file

@ -0,0 +1,14 @@
from .attributes import AttributeDefinition, AttributeType, RangeAttributeDefinition
from .base_curriculum import BaseCurriculum
from .coach import Coach, GroupedScores, ScoreBoard, ScoreStats
__all__ = [
"AttributeType",
"AttributeDefinition",
"RangeAttributeDefinition",
"BaseCurriculum",
"Coach",
"ScoreBoard",
"GroupedScores",
"ScoreStats",
]

View file

@ -0,0 +1,73 @@
from collections import abc
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Optional
class AttributeType(StrEnum):
"""Defines how attribute levels should be interpreted"""
STATIC = "static" # Each level is independent
UBOUND = "ubound" # Each level is an upper bound
APPEND = "append" # Each level includes all previous levels
@dataclass(kw_only=True)
class AttributeDefinition:
name: str
levels: list
default_level: int
description: Optional[str] = None
attr_type: AttributeType = AttributeType.STATIC # Default to static
min_value: Optional[int | float] = None # Minimum value for numeric attributes
def validate_level(self, level: int, curriculum: str) -> None:
"""
Validate that a level is valid for an attribute.
Args:
level: Level to validate
curriculum: Name of the curriculum
Raises:
ValueError: If level is invalid
"""
# TODO: if > set as [-1], if <0 set as [0]
if not 0 <= level < len(self.levels):
raise ValueError(
f"Invalid level: {level} for attribute '{curriculum}.{self.name}'. "
f"Must be between 0 and {len(self.levels)-1}"
)
def get_level_value(self, level: int, curriculum: str) -> Any:
"""
Get the value for an attribute at a specific level based on its type.
Args:
attr: The attribute definition
level: Level to get value for
Returns:
Value for the attribute based on its level and type
"""
if self.attr_type == AttributeType.STATIC:
return self.levels[level]
elif self.attr_type == AttributeType.UBOUND:
return self.levels[level]
elif self.attr_type == AttributeType.APPEND:
return self.levels[: level + 1]
raise ValueError(f"Unknown attribute type: {self.attr_type} for attribute '{curriculum}.{self.name}'")
@dataclass(kw_only=True)
class ScalarAttributeDefinition(AttributeDefinition):
field_name: str
@dataclass(kw_only=True)
class RangeAttributeDefinition(AttributeDefinition):
lower_field_name: str
upper_field_name: str
def get_level_value(self, level: int, curriculum: str) -> Any:
v = super().get_level_value(level, curriculum)
if not isinstance(v, abc.Iterable):
return [v]
return v

View file

@ -0,0 +1,108 @@
from typing import Any, Iterable, Optional
from ..factory import ConfigT
from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition
class BaseCurriculum:
def __init__(self, name: str, config_cls: ConfigT):
self.name = name
self._config_cls = config_cls
self._attributes: dict[str, AttributeDefinition] = {}
self._current_levels: dict[str, int] = {}
def generate_configuration(self, defaults: Optional[dict[str, any]] = None) -> ConfigT:
config_args = defaults.copy() if defaults is not None else {}
for attr in self._attributes.values():
if isinstance(attr, RangeAttributeDefinition):
vals = self.get_attr_value(attr.name)
config_args[attr.lower_field_name] = min(vals)
config_args[attr.upper_field_name] = max(vals)
elif isinstance(attr, ScalarAttributeDefinition):
val = self.get_attr_value(attr.name)
config_args[attr.field_name] = val
print(config_args)
return self._config_cls(**config_args)
@property
def attributes(self) -> dict[str, AttributeDefinition]:
"""Get the curriculum's attributes"""
return self._attributes
def get_attribute(self, attr_name: str) -> AttributeDefinition:
if attr_name not in self._attributes:
raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist")
return self._attributes[attr_name]
def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None:
for attr in attrs:
if attr.name in self.attributes:
raise RuntimeError(f"Attribute with name {attr.name} is already defined.")
self.attributes[attr.name] = attr
def get_attr_level(self, attr_name: str) -> int:
"""
Get the current level for an attribute.
Args:
attr_name: Name of the attribute
Returns:
Current level index for the attribute
"""
attr = self.get_attribute(attr_name)
return self._current_levels.get(attr_name, attr.default_level)
def get_attr_value(self, attr_name: str) -> Any:
"""
Get the current value for an attribute based on its level.
Args:
attr_name: Name of the attribute
Returns:
Current value for the attribute based on its level and type
"""
attr = self.get_attribute(attr_name)
level = self.get_attr_level(attr_name)
return attr.get_level_value(level, curriculum=self.name)
def set_attr_level(self, attr_name: str, level: int) -> None:
"""
Set the level for an attribute.
Args:
attr_name: Name of the attribute
level: New level index
"""
attr = self.get_attribute(attr_name)
attr.validate_level(level, curriculum=self.name)
self._current_levels[attr_name] = level
def increment_attr_level(self, attr_name: str) -> bool:
"""
Increment the level of an attribute if possible.
Args:
attr_name: Name of the attribute to increment
Returns:
bool: True if level was incremented, False if already at max level
Raises:
KeyError: If attribute doesn't exist
"""
attr = self.get_attribute(attr_name)
current_level = self.get_attr_level(attr_name)
if current_level < len(attr.levels) - 1:
self.set_attr_level(attr_name, current_level + 1)
return True
return False
def decrement_attr_level(self, attr_name: str) -> bool:
"""
Decrement the level of an attribute if possible.
Args:
attr_name: Name of the attribute to decrement
Returns:
bool: True if level was decremented, False if already at min level
Raises:
KeyError: If attribute doesn't exist
"""
current_level = self.get_attr_level(attr_name)
if current_level > 0:
self.set_attr_level(attr_name, current_level - 1)
return True
return False

View file

@ -8,7 +8,7 @@ from pathlib import Path
from statistics import mean, stdev
from typing import Any, Dict, List, Optional, Tuple, Union
from .dataset import ProceduralDataset
from ..dataset import ProceduralDataset
@dataclass

View file

@ -5,6 +5,7 @@ Cognition tasks for training reasoning capabilities.
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
from .figlet_fonts import FigletFontConfig, FigletFontDataset
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset
from .rectangle_count import RectangleCountConfig, RectangleCountDataset
from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset
__all__ = [
@ -16,4 +17,6 @@ __all__ = [
"NumberSequenceDataset",
"RubiksCubeConfig",
"RubiksCubeDataset",
"RectangleCountConfig",
"RectangleCountDataset",
]

View file

@ -0,0 +1,135 @@
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
def draw_rectangles_with_overlap(n, width, height, rng):
# Create a grid that holds a count of how many times a cell is drawn.
grid = [[0 for _ in range(width)] for _ in range(height)]
rectangles = []
max_attempts = 100000 # Prevent infinite loops in case of a crowded grid
attempts = 0
while len(rectangles) < n and attempts < max_attempts:
attempts += 1
# Ensure minimum width and height of 3.
# For a rectangle to be at least 3 cells wide, right must be at least left + 2.
# Similarly, bottom must be at least top + 2.
left = rng.randint(0, width - 3)
right = rng.randint(left + 2, width - 1)
top = rng.randint(0, height - 3)
bottom = rng.randint(top + 2, height - 1)
# Prepare a list of all the cells that would be updated.
cells_to_update = []
# Top edge:
for col in range(left, right + 1):
cells_to_update.append((top, col))
# Bottom edge:
for col in range(left, right + 1):
cells_to_update.append((bottom, col))
# Left edge (excluding corners already drawn):
for row in range(top + 1, bottom):
cells_to_update.append((row, left))
# Right edge (excluding corners already drawn):
for row in range(top + 1, bottom):
cells_to_update.append((row, right))
# Check if drawing this rectangle would cause any cell to exceed a count of 2.
conflict = False
for r, c in cells_to_update:
if grid[r][c] >= 2:
conflict = True
break
if conflict:
continue # Skip this rectangle candidate
# No conflict: update the grid counts.
for r, c in cells_to_update:
grid[r][c] += 1
# Save the rectangle (stored as (left, right, top, bottom)).
rectangles.append((left, right, top, bottom))
if len(rectangles) < n:
print(f"Only placed {len(rectangles)} rectangles after {attempts} attempts.")
# Print the grid.
# Use ' ' for an untouched cell, '#' for a single hit, and '█' for exactly two hits.
lines = ""
for row in grid:
line = "".join(" " if count == 0 else ("#" if count == 1 else "") for count in row)
lines = lines + line + "\n"
return lines, len(rectangles)
@dataclass
class RectangleCountConfig:
"""Configuration for RectangleCount puzzle generation"""
max_rectangles: int = 10
width: int = 80
height: int = 80
seed: Optional[int] = None
size: int = 500
def validate(self):
"""Validate configuration parameters"""
assert self.width >= 10, "width must be gte 10"
assert self.height >= 10, "height must be gte 10"
class RectangleCountDataset(ProceduralDataset):
"""Generates [RectangleCount Puzzles](https://en.wikipedia.org/wiki/RectangleCount_Puzzle) with configurable parameters"""
def __init__(self, config: RectangleCountConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single RectangleCount task
Returns:
dict with keys:
- question: str, the task description
- answer: str, a solution string
- metadata: dict with generation parameters
"""
rng = Random(self.seed + idx)
target = rng.randint(1, self.config.max_rectangles)
puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng)
puzz = f"How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with ''. \n\n {puzzle}"
return {
"question": puzz,
"answer": str(answer),
"metadata": {},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Determine if the solution provided solves the RectangleCount task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
return 0.01
else:
return 1.0 # Yay
register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig)

View file

@ -1,6 +1,7 @@
import pytest
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum
def test_chain_sum_config_validation():
@ -127,3 +128,30 @@ def test_chain_sum_iteration():
first_items = list(dataset)
second_items = list(dataset)
assert first_items == second_items, "Multiple iterations should yield same items"
def test_chain_sum_curriculum():
curriculum = ChainSumCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: ChainSumConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1
assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2
# test incrementing attribute levels for num_terms & num_digits attributes
curriculum.increment_attr_level("num_terms")
curriculum.increment_attr_level("num_digits")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2
assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3
# test decrementing attribute level for num_digits again
curriculum.decrement_attr_level("num_digits")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1
assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3

View file

@ -0,0 +1,88 @@
"""Tests for Count Primes questions generation"""
import pytest
from reasoning_gym.algorithmic.count_primes import CountPrimesConfig, CountPrimesDataset
def test_count_primes_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = CountPrimesConfig(max_n=-1) # Negative not allowed
config.validate()
with pytest.raises(AssertionError):
config = CountPrimesConfig(max_n=0) # Zero not allowed
config.validate()
def test_count_primes_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
config = CountPrimesConfig(seed=42, size=10)
dataset1 = CountPrimesDataset(config)
dataset2 = CountPrimesDataset(config)
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
def test_count_primes_dataset_items():
"""Test basic properties of generated items"""
config = CountPrimesConfig(max_n=10, size=10, seed=42)
dataset = CountPrimesDataset(config)
for i in range(len(dataset)):
item = dataset[i]
# Check item structure
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata
assert "start" in item["metadata"]
assert "end" in item["metadata"]
assert "primes" in item["metadata"]
assert "solution" in item["metadata"]
start = item["metadata"]["start"]
end = item["metadata"]["end"]
primes = item["metadata"]["primes"]
assert start <= end
assert len(primes) <= end - start + 1
def test_count_primes_dataset_iteration():
"""Test that iteration respects dataset size"""
config = CountPrimesConfig(size=5, seed=42)
dataset = CountPrimesDataset(config)
items = list(dataset)
assert len(items) == config.size
# Test multiple iterations yield same items
assert items == list(dataset)
def test_count_primes_answer():
"""Test the _get_primes method"""
config = CountPrimesConfig(seed=42)
dataset = CountPrimesDataset(config)
# Base cases
assert dataset._get_primes(n=0) == []
assert dataset._get_primes(n=1) == []
assert dataset._get_primes(n=2) == [False, False]
# Test primes up to 10
primes = dataset._get_primes(n=11)
assert primes[2] == True
assert primes[3] == True
assert primes[4] == False
assert primes[5] == True
assert primes[6] == False
assert primes[7] == True
assert primes[8] == False
assert primes[9] == False
assert primes[10] == False

35
tests/test_dice.py Normal file
View file

@ -0,0 +1,35 @@
import pytest
from reasoning_gym.arithmetic.dice import DiceConfig, DiceDataset
def test_dice():
"""Test basic properties and solution of generated items"""
config = DiceConfig(seed=42, size=50, num_dice=8, max_dice_size=24)
dataset = DiceDataset(config)
for item in dataset:
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
# Easy
config = DiceConfig(seed=42, size=1, num_dice=1, max_dice_size=2)
dataset = DiceDataset(config)
for item in dataset:
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0
# Hard
config = DiceConfig(seed=42, size=1, num_dice=40, max_dice_size=40)
dataset = DiceDataset(config)
for item in dataset:
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0

View file

@ -0,0 +1,19 @@
import pytest
from reasoning_gym.cognition.rectangle_count import RectangleCountConfig, RectangleCountDataset
def test_dice():
"""Test basic properties and solution of generated items"""
config = RectangleCountConfig(seed=42, size=50, max_rectangles=15, width=40, height=40)
dataset = RectangleCountDataset(config)
for item in dataset:
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0