diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py
index ad76199f..dd3c6758 100644
--- a/reasoning_gym/algorithmic/__init__.py
+++ b/reasoning_gym/algorithmic/__init__.py
@@ -25,6 +25,7 @@ from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
 from .word_ladder import WordLadderConfig, WordLadderDataset
 from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
 from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
+from .count_primes import CountPrimesConfig, CountPrimesDataset
 
 __all__ = [
     "SpellBackwardConfig",
@@ -66,4 +67,6 @@ __all__ = [
     "ManipulateMatrixDataset",
     "BinaryMatrixConfig",
     "BinaryMatrixDataset",
+    "CountPrimesConfig",
+    "CountPrimesDataset",
 ]
diff --git a/reasoning_gym/algorithmic/count_primes.py b/reasoning_gym/algorithmic/count_primes.py
new file mode 100644
index 00000000..3335371c
--- /dev/null
+++ b/reasoning_gym/algorithmic/count_primes.py
@@ -0,0 +1,60 @@
+"""Count prime numbers in a given interval.
+
+Solution obtained with Sieve of Eratosthenes:
+https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
+"""
+
+import math
+from dataclasses import dataclass
+from random import Random
+from typing import Optional
+
+from ..factory import ProceduralDataset, register_dataset
+
+QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive)?"""
+
+@dataclass
+class CountPrimesConfig:
+    """Configuration for Count Primes dataset generation"""
+
+    max_n: int = 10_000  # Upper bound for the interval
+
+    size: int = 500  # Virtual dataset size
+    seed: Optional[int] = None
+
+    def validate(self):
+        """Validate configuration parameters"""
+        assert 1 <= self.max_n, "max_n must be at least 1"
+
+class CountPrimesDataset(ProceduralDataset):
+    """Generates Count Primes exercises with configurable difficulty"""
+
+    def __init__(self, config: CountPrimesConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+        self.primes = self._get_primes(config.max_n + 1)
+
+    def _get_primes(self, n: int) -> list[bool]:  # Sieve of Eratosthenes: flags[k] is True iff k is prime, for k in [0, n)
+        if n <= 1:
+            return []
+        primes = [True] * n
+        primes[0] = primes[1] = False
+        for i in range(2, math.isqrt(n) + 1):  # isqrt is exact; int(math.sqrt(n)) can round wrong for large n
+            if primes[i]:
+                for j in range(i * i, n, i):  # multiples below i*i were already marked by smaller primes
+                    primes[j] = False
+        return primes
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single Count Primes question"""
+        rng = Random(self.seed + idx)
+        start = rng.randint(1, self.config.max_n)
+        end = rng.randint(start, self.config.max_n)
+        primes = self.primes[start:end + 1]  # primality flags for the inclusive interval [start, end]
+        answer = sum(primes)
+        return {
+            "question": QUESTION_TEMPLATE.format(start=start, end=end),
+            "answer": str(answer),
+            "metadata": {"start": start, "end": end, "primes": primes, "solution": answer},
+        }
+
+register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig)
diff --git a/tests/test_count_primes.py b/tests/test_count_primes.py
new file mode 100644
index 00000000..d88981e9
--- /dev/null
+++ b/tests/test_count_primes.py
@@ -0,0 +1,89 @@
+"""Tests for Count Primes questions generation"""
+
+import pytest
+
+from reasoning_gym.algorithmic.count_primes import CountPrimesConfig, CountPrimesDataset
+
+
+def test_count_primes_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+    with pytest.raises(AssertionError):
+        config = CountPrimesConfig(max_n=-1)  # Negative not allowed
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = CountPrimesConfig(max_n=0)  # Zero not allowed
+        config.validate()
+
+
+
+def test_count_primes_dataset_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = CountPrimesConfig(seed=42, size=10)
+    dataset1 = CountPrimesDataset(config)
+    dataset2 = CountPrimesDataset(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_count_primes_dataset_items():
+    """Test basic properties of generated items"""
+    config = CountPrimesConfig(max_n=10, size=10, seed=42)
+    dataset = CountPrimesDataset(config)
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+        # Check item structure
+        assert isinstance(item, dict)
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Check metadata
+        assert "start" in item["metadata"]
+        assert "end" in item["metadata"]
+        assert "primes" in item["metadata"]
+        assert "solution" in item["metadata"]
+
+        start = item["metadata"]["start"]
+        end = item["metadata"]["end"]
+        primes = item["metadata"]["primes"]
+
+        assert start <= end
+        assert len(primes) == end - start + 1
+
+
+def test_count_primes_dataset_iteration():
+    """Test that iteration respects dataset size"""
+    config = CountPrimesConfig(size=5, seed=42)
+    dataset = CountPrimesDataset(config)
+
+    items = list(dataset)
+    assert len(items) == config.size
+
+    # Test multiple iterations yield same items
+    assert items == list(dataset)
+
+
+def test_count_primes_answer():
+    """Test the _get_primes method"""
+    config = CountPrimesConfig(seed=42)
+    dataset = CountPrimesDataset(config)
+
+    # Base cases
+    assert dataset._get_primes(n=0) == []
+    assert dataset._get_primes(n=1) == []
+    assert dataset._get_primes(n=2) == [False, False]
+
+    # Test primes up to 10
+    primes = dataset._get_primes(n=11)
+    assert primes[2] is True
+    assert primes[3] is True
+    assert primes[4] is False
+    assert primes[5] is True
+    assert primes[6] is False
+    assert primes[7] is True
+    assert primes[8] is False
+    assert primes[9] is False
+    assert primes[10] is False