diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index dd3c6758..bf17ebe8 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -9,6 +9,7 @@ Algorithmic tasks for training reasoning capabilities: from .base_conversion import BaseConversionConfig, BaseConversionDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset +from .count_primes import CountPrimesConfig, CountPrimesDataset from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset from .letter_counting import LetterCountingConfig, LetterCountingDataset @@ -25,7 +26,6 @@ from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset -from .count_primes import CountPrimesConfig, CountPrimesDataset __all__ = [ "SpellBackwardConfig", diff --git a/reasoning_gym/algorithmic/count_primes.py b/reasoning_gym/algorithmic/count_primes.py index 3335371c..0a553c7f 100644 --- a/reasoning_gym/algorithmic/count_primes.py +++ b/reasoning_gym/algorithmic/count_primes.py @@ -13,6 +13,7 @@ from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Count how many prime numbers there are between {start} and {end} (inclusive) ?""" + @dataclass class CountPrimesConfig: """Configuration for Count Primes dataset generation""" @@ -26,21 +27,22 @@ class CountPrimesConfig: """Validate configuration parameters""" assert 1 <= self.max_n, "max_n must be at least 1" + class CountPrimesDataset(ProceduralDataset): """Generates Count Primes exercises with configurable difficulty""" def __init__(self, config: CountPrimesConfig): super().__init__(config=config, seed=config.seed, size=config.size) self.primes = self._get_primes(config.max_n + 1) - + def _get_primes(self, n: int) -> list[bool]: if n <= 1: return [] primes = [True] * n primes[0] = primes[1] = False - for i in range(2, int(math.sqrt(n))+1): + for i in range(2, int(math.sqrt(n)) + 1): if primes[i]: - for j in range(2*i, n, i): + for j in range(2 * i, n, i): primes[j] = False return primes @@ -49,7 +51,7 @@ class CountPrimesDataset(ProceduralDataset): rng = Random(self.seed + idx) start = rng.randint(1, self.config.max_n) end = rng.randint(start, self.config.max_n) - primes = self.primes[start:end+1] + primes = self.primes[start : end + 1] answer = sum(primes) return { "question": QUESTION_TEMPLATE.format(start=start, end=end), @@ -57,4 +59,5 @@ class CountPrimesDataset(ProceduralDataset): "metadata": {"start": start, "end": end, "primes": primes, "solution": answer}, } + register_dataset("count_primes", CountPrimesDataset, CountPrimesConfig) diff --git a/tests/test_count_primes.py b/tests/test_count_primes.py index d88981e9..f131647b 100644 --- a/tests/test_count_primes.py +++ b/tests/test_count_primes.py @@ -16,7 +16,6 @@ def test_count_primes_config_validation(): config.validate() - def test_count_primes_dataset_deterministic(): """Test that dataset generates same items with same seed""" config = CountPrimesConfig(seed=42, size=10) @@ -86,4 +85,4 @@ def test_count_primes_answer(): assert primes[7] == True assert primes[8] == False assert primes[9] == False - assert primes[10] == False \ No newline at end of file + assert primes[10] == False