diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index c7c1cb45..8f167f09 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -16,7 +16,7 @@ from .cryptarithm import CryptarithmConfig, CryptarithmDataset from .game_of_life import GameOfLifeConfig, GameOfLifeDataset from .game_of_life_halting import GameOfLifeHaltingConfig, GameOfLifeHaltingDataset from .graph_color import GraphColorConfig, GraphColorDataset -from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset +from .group_anagrams import GroupAnagramsConfig, GroupAnagramsCurriculum, GroupAnagramsDataset from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset from .jugs import JugsConfig, JugsDataset from .letter_counting import LetterCountingConfig, LetterCountingDataset @@ -76,6 +76,7 @@ __all__ = [ "PalindromeDataset", "GroupAnagramsConfig", "GroupAnagramsDataset", + "GroupAnagramsCurriculum", "PalindromePartitioningConfig", "PalindromePartitioningDataset", "SpiralMatrixConfig", diff --git a/reasoning_gym/algorithmic/group_anagrams.py b/reasoning_gym/algorithmic/group_anagrams.py index 6c17eae5..7fe200cc 100644 --- a/reasoning_gym/algorithmic/group_anagrams.py +++ b/reasoning_gym/algorithmic/group_anagrams.py @@ -12,11 +12,10 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..data import get_data_file_path from ..factory import ProceduralDataset, register_dataset -MAX_ANAGRAM_GROUPS = 500 - QUESTION_TEMPLATE = """An anagram is a word formed by rearranging the letters of a different word, using all the original letters exactly once. Your job is to group the anagrams together. You can return the answer in any order. @@ -32,7 +31,9 @@ Group the following list of words into anagrams: class GroupAnagramsConfig: """Configuration for Group Anagrams dataset generation""" - anagram_groups: int = 10 # Groups of anagrams present in the input + min_anagram_groups: int = 2 # Minimum number of anagram groups present in the input + max_anagram_groups: int = 10 # Maximum number of anagram groups present in the input + min_words_per_group: int = 2 # Minimum number of words in a single anagram group max_words_per_group: int = 5 # Maximum number of words in a single anagram group size: int = 500 # Virtual dataset size @@ -40,10 +41,8 @@ class GroupAnagramsConfig: def validate(self): """Validate configuration parameters""" - assert ( - 1 <= self.anagram_groups <= MAX_ANAGRAM_GROUPS - ), f"anagram_groups must be between 1 and {MAX_ANAGRAM_GROUPS}" - assert 1 <= self.max_words_per_group, "max_words_per_group must be at least 1" + assert 2 <= self.min_anagram_groups <= self.max_anagram_groups, "Invalid number of anagram groups" + assert 2 <= self.min_words_per_group <= self.max_words_per_group, "Invalid number of words per group" class GroupAnagramsDataset(ProceduralDataset): @@ -54,11 +53,12 @@ class GroupAnagramsDataset(ProceduralDataset): with get_data_file_path("anagrams.jsonl").open() as f: self.anagrams = [json.loads(line)["words"] for line in f] - def _get_anagram_words(self, rng: Random) -> list[str]: + def _get_anagram_words(self, rng: Random, num_groups: int) -> list[str]: """Generate a list of words with anagrams""" words = [] - for sample in rng.sample(self.anagrams, self.config.anagram_groups): - anagrams = rng.sample(sample, rng.randint(1, min(len(sample), self.config.max_words_per_group))) + for sample in rng.sample(self.anagrams, num_groups): + num_words = min(len(sample), rng.randint(self.config.min_words_per_group, self.config.max_words_per_group)) + anagrams = rng.sample(sample, num_words) words.extend(anagrams) return words @@ -103,15 +103,54 @@ class GroupAnagramsDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single Group Anagrams question""" rng = Random(self.seed + idx) - words = self._get_anagram_words(rng) + + anagram_groups = min( + len(self.anagrams), rng.randint(self.config.min_anagram_groups, self.config.max_anagram_groups) + ) + words = self._get_anagram_words(rng, num_groups=anagram_groups) answer = self._group_anagrams(words) answer_str = json.dumps(answer) return { "question": QUESTION_TEMPLATE.format(words=json.dumps(words)), "answer": answer_str, - "metadata": {"words": words, "solution": answer}, + "metadata": { + "words": words, + "solution": answer, + "difficulty": { + "anagram_groups": anagram_groups, + }, + }, } -register_dataset("group_anagrams", GroupAnagramsDataset, GroupAnagramsConfig) +class GroupAnagramsCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(GroupAnagramsCurriculum.__name__, GroupAnagramsConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="anagram_groups", + levels=[10, 100, 1_000, 10_000], + default_level=0, + description="Number of anagram groups in the input", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_anagram_groups", + upper_field_name="max_anagram_groups", + ), + RangeAttributeDefinition( + name="words_per_group", + levels=[2, 5, 10, 20], + default_level=0, + description="Number of words in a single anagram group", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_words_per_group", + upper_field_name="max_words_per_group", + ), + ) + + +register_dataset("group_anagrams", GroupAnagramsDataset, GroupAnagramsConfig, GroupAnagramsCurriculum) diff --git a/tests/test_group_anagrams.py b/tests/test_group_anagrams.py index bac22412..076184dd 100644 --- a/tests/test_group_anagrams.py +++ b/tests/test_group_anagrams.py @@ -4,27 +4,35 @@ import json import pytest -from reasoning_gym.algorithmic.group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset +from reasoning_gym.algorithmic.group_anagrams import GroupAnagramsConfig, GroupAnagramsCurriculum, GroupAnagramsDataset def test_group_anagrams_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): - config = GroupAnagramsConfig(anagram_groups=-1) # Negative not allowed + config = GroupAnagramsConfig(min_anagram_groups=-1) # Negative not allowed config.validate() with pytest.raises(AssertionError): - config = GroupAnagramsConfig(anagram_groups=0) # Zero not allowed + config = GroupAnagramsConfig(min_anagram_groups=0) # Zero not allowed config.validate() with pytest.raises(AssertionError): - config = GroupAnagramsConfig(max_words_per_group=-1) # Negative not allowed + config = GroupAnagramsConfig(min_anagram_groups=5, max_anagram_groups=4) # Min > Max not allowed config.validate() with pytest.raises(AssertionError): - config = GroupAnagramsConfig(max_words_per_group=0) # Zero not allowed + config = GroupAnagramsConfig(min_words_per_group=-1) # Negative not allowed config.validate() + with pytest.raises(AssertionError): + config = GroupAnagramsConfig(min_words_per_group=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = GroupAnagramsConfig(min_words_per_group=5, max_words_per_group=4) # Min > Max not allowed + config.validate() # Min > Max not allowed + def test_group_anagrams_dataset_deterministic(): """Test that dataset generates same items with same seed""" @@ -38,7 +46,7 @@ def test_group_anagrams_dataset_deterministic(): def test_group_anagrams_dataset_items(): """Test basic properties of generated items""" - config = GroupAnagramsConfig(anagram_groups=5, max_words_per_group=3, size=10, seed=42) + config = GroupAnagramsConfig(max_anagram_groups=5, max_words_per_group=3, size=10, seed=42) dataset = GroupAnagramsDataset(config) for i in range(len(dataset)): @@ -57,8 +65,8 @@ def test_group_anagrams_dataset_items(): solution = item["metadata"]["solution"] # Verify list dimensions - assert len(words) > 5 - assert len(solution) == 5 + assert len(words) >= len(solution) + assert len(solution) <= 5 assert all(len(group) <= 3 for group in solution) @@ -119,3 +127,28 @@ def test_group_anagrams_score_answer(): answer = '["ate", "eat", "tea"], ["bat"], ["nat", "tan"]' item = {"metadata": {"solution": [["ate", "eat", "tea"], ["bat"], ["nat", "tan"]]}} assert dataset.score_answer(answer, item) == 0 + + +def test_group_anagrams_curriculum(): + curriculum = GroupAnagramsCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: GroupAnagramsConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_anagram_groups == 10 and base_cfg.max_anagram_groups == 10 + assert base_cfg.min_words_per_group == 2 and base_cfg.max_words_per_group == 2 + + # test incrementing attribute levels + curriculum.increment_attr_level("anagram_groups") + curriculum.increment_attr_level("words_per_group") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_anagram_groups == 10 and increased_cfg.max_anagram_groups == 100 + assert increased_cfg.min_words_per_group == 2 and increased_cfg.max_words_per_group == 5 + + # test decrementing attribute level partially + curriculum.decrement_attr_level("anagram_groups") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_anagram_groups == 10 and partially_decreased_cfg.max_anagram_groups == 10 + assert partially_decreased_cfg.min_words_per_group == 2 and partially_decreased_cfg.max_words_per_group == 5 diff --git a/tests/test_shortest_path.py b/tests/test_shortest_path.py index e363f655..ed6b22bc 100644 --- a/tests/test_shortest_path.py +++ b/tests/test_shortest_path.py @@ -181,7 +181,7 @@ def test_shortest_path_answer(): assert dataset.score_answer(None, entry) == 0.0 -def test_chain_sum_curriculum(): +def test_shortest_path_curriculum(): curriculum = ShortestPathCurriculum() base_value = {"size": 150, "seed": 1}