diff --git a/reasoning_gym/cognition/__init__.py b/reasoning_gym/cognition/__init__.py index 3b708d47..b4e23932 100644 --- a/reasoning_gym/cognition/__init__.py +++ b/reasoning_gym/cognition/__init__.py @@ -5,7 +5,7 @@ Cognition tasks for training reasoning capabilities. from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationCurriculum, ColorCubeRotationDataset from .figlet_fonts import FigletFontConfig, FigletFontDataset from .modulo_grid import ModuloGridConfig, ModuloGridDataset -from .needle_haystack import NeedleHaystackConfig, NeedleHaystackDataset +from .needle_haystack import NeedleHaystackConfig, NeedleHaystackCurriculum, NeedleHaystackDataset from .number_sequences import NumberSequenceConfig, NumberSequenceCurriculum, NumberSequenceDataset from .rectangle_count import RectangleCountConfig, RectangleCountCurriculum, RectangleCountDataset from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset @@ -26,6 +26,7 @@ __all__ = [ "RectangleCountDataset", "NeedleHaystackConfig", "NeedleHaystackDataset", + "NeedleHaystackCurriculum", "ModuloGridConfig", "ModuloGridDataset", ] diff --git a/reasoning_gym/cognition/needle_haystack.py b/reasoning_gym/cognition/needle_haystack.py index e5adf741..6fe4f55a 100644 --- a/reasoning_gym/cognition/needle_haystack.py +++ b/reasoning_gym/cognition/needle_haystack.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -9,14 +10,18 @@ from ..factory import ProceduralDataset, register_dataset class NeedleHaystackConfig: """Configuration for NeedleHaystack task generation""" - num_statements: int = 50 + min_num_statements: int = 10 + max_num_statements: int = 100 seed: Optional[int] = None size: int = 500 def validate(self) -> None: """Validate configuration parameters""" - assert self.num_statements > 0, "num_statements must be greater than 0" - assert self.num_statements < 168387000, f"num_statements must be less than {168387000}" + assert self.min_num_statements > 0, "min_num_statements must be greater than 0" + assert ( + self.max_num_statements >= self.min_num_statements + ), "max_num_statements must be greater than min_num_statements" + assert self.max_num_statements < 168387000, f"max_num_statements must be less than {168387000}" def generate_unique_triplets(names: list[str], verbs: list[str], subjects: list[str], n: int, rng) -> dict[str, Any]: @@ -85,7 +90,8 @@ class NeedleHaystackDataset(ProceduralDataset): rng = Random(self.seed + idx) - stack = generate_unique_triplets(NAMES, VERBS, SUBJECTS, self.config.num_statements, rng) + num_statements = rng.randint(self.config.min_num_statements, self.config.max_num_statements) + stack = generate_unique_triplets(NAMES, VERBS, SUBJECTS, num_statements, rng) stack_text = "" for triplet in stack["triplets"]: @@ -97,7 +103,7 @@ class NeedleHaystackDataset(ProceduralDataset): return { "question": full_text, "answer": stack["needle"][0], - "metadata": {"question": question}, + "metadata": {"question": question, "difficulty": {"num_statements": num_statements}}, } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -123,5 +129,24 @@ class NeedleHaystackDataset(ProceduralDataset): return 0.0 +class NeedleHaystackCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(NeedleHaystackCurriculum.__name__, NeedleHaystackConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="num_statements", + levels=[10, 100, 1_000, 10_000, 100_000, 1_000_000, 168_386_000], + default_level=1, + description="Number of statements in the haystack", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_num_statements", + upper_field_name="max_num_statements", + ), + ) + + # Register the dataset -register_dataset("needle_haystack", NeedleHaystackDataset, NeedleHaystackConfig) +register_dataset("needle_haystack", NeedleHaystackDataset, NeedleHaystackConfig, NeedleHaystackCurriculum) diff --git a/tests/test_needle_haystack.py b/tests/test_needle_haystack.py index 672d16b3..845b7cd2 100644 --- a/tests/test_needle_haystack.py +++ b/tests/test_needle_haystack.py @@ -1,11 +1,15 @@ import pytest -from reasoning_gym.cognition.needle_haystack import NeedleHaystackConfig, NeedleHaystackDataset +from reasoning_gym.cognition.needle_haystack import ( + NeedleHaystackConfig, + NeedleHaystackCurriculum, + NeedleHaystackDataset, +) def test_needle_haystack(): """Test basic properties and solution of generated items""" - config = NeedleHaystackConfig(seed=42, size=50, num_statements=50) + config = NeedleHaystackConfig(seed=42, size=50, min_num_statements=50, max_num_statements=50) dataset = NeedleHaystackDataset(config) for item in dataset: @@ -19,7 +23,7 @@ def test_needle_haystack(): assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 - config = NeedleHaystackConfig(seed=42, size=1, num_statements=500) + config = NeedleHaystackConfig(seed=42, size=1, min_num_statements=500, max_num_statements=500) dataset = NeedleHaystackDataset(config) for item in dataset: @@ -32,7 +36,7 @@ def test_needle_haystack(): assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 - config = NeedleHaystackConfig(seed=42, size=1, num_statements=5000) + config = NeedleHaystackConfig(seed=42, size=1, min_num_statements=5000, max_num_statements=5000) dataset = NeedleHaystackDataset(config) for item in dataset: @@ -45,7 +49,7 @@ def test_needle_haystack(): assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 - config = NeedleHaystackConfig(seed=42, size=1, num_statements=50000) + config = NeedleHaystackConfig(seed=42, size=1, min_num_statements=50000, max_num_statements=50000) dataset = NeedleHaystackDataset(config) for item in dataset: @@ -58,7 +62,7 @@ def test_needle_haystack(): assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 - config = NeedleHaystackConfig(seed=42, size=1, num_statements=500000) + config = NeedleHaystackConfig(seed=42, size=1, min_num_statements=500000, max_num_statements=500000) dataset = NeedleHaystackDataset(config) for item in dataset: @@ -70,3 +74,24 @@ def test_needle_haystack(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 + + +def test_needle_haystack_curriculum(): + curriculum = NeedleHaystackCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: NeedleHaystackConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_num_statements == 10 and base_cfg.max_num_statements == 100 + + # test incrementing attribute levels + curriculum.increment_attr_level("num_statements") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_num_statements == 10 and increased_cfg.max_num_statements == 1000 + + # test decrementing attribute level for num_statements again + curriculum.decrement_attr_level("num_statements") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_num_statements == 10 and partially_decreased_cfg.max_num_statements == 100