diff --git a/reasoning_gym/algorithmic/count_primes.py b/reasoning_gym/algorithmic/count_primes.py
index 0a553c7f..e505b91d 100644
--- a/reasoning_gym/algorithmic/count_primes.py
+++ b/reasoning_gym/algorithmic/count_primes.py
@@ -51,8 +51,8 @@ class CountPrimesDataset(ProceduralDataset):
         rng = Random(self.seed + idx)
         start = rng.randint(1, self.config.max_n)
         end = rng.randint(start, self.config.max_n)
-        primes = self.primes[start : end + 1]
-        answer = sum(primes)
+        primes = [i for i in range(start, end + 1) if self.primes[i]]
+        answer = len(primes)
         return {
             "question": QUESTION_TEMPLATE.format(start=start, end=end),
             "answer": str(answer),
diff --git a/tests/test_count_primes.py b/tests/test_count_primes.py
index f131647b..68c642f6 100644
--- a/tests/test_count_primes.py
+++ b/tests/test_count_primes.py
@@ -86,3 +86,18 @@ def test_count_primes_answer():
     assert primes[8] == False
     assert primes[9] == False
     assert primes[10] == False
+
+
+def test_count_primes_list():
+    """Test that list of primes was correctly generated"""
+    config = CountPrimesConfig(max_n=100, size=100, seed=42)
+    dataset = CountPrimesDataset(config)
+
+    for item in dataset:
+        start = item["metadata"]["start"]
+        end = item["metadata"]["end"]
+        primes = item["metadata"]["primes"]
+        for p in primes:
+            assert p >= start
+            assert p <= end
+            assert dataset.primes[p] == True