diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 72e589dd..2cc35e61 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -31,7 +31,7 @@ from .palindrome_partitioning import ( PalindromePartitioningDataset, ) from .pool_matrix import PoolMatrixConfig, PoolMatrixCurriculum, PoolMatrixDataset -from .ransom_note import RansomNoteConfig, RansomNoteDataset +from .ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDataset from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset @@ -90,6 +90,7 @@ __all__ = [ "SpiralMatrixCurriculum", "RansomNoteConfig", "RansomNoteDataset", + "RansomNoteCurriculum", "IsomorphicStringsConfig", "IsomorphicStringsDataset", "IsomorphicStringsCurriculum", diff --git a/reasoning_gym/algorithmic/ransom_note.py b/reasoning_gym/algorithmic/ransom_note.py index cf163467..108c5b64 100644 --- a/reasoning_gym/algorithmic/ransom_note.py +++ b/reasoning_gym/algorithmic/ransom_note.py @@ -7,13 +7,11 @@ https://leetcode.com/problems/ransom-note/description/ from collections import defaultdict from dataclasses import dataclass from random import Random -from typing import Any, Optional +from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset -MAX_NOTE_LENGTH = 100_000 -MAX_MAGAZINE_LENGTH = 100_001 - QUESTION_TEMPLATE = """Given two strings representing a ransom note and a magazine, return True if you can construct the ransom note using the letters in the magazine, and False otherwise. Each letter in the magazine string can only be used once in your ransom note. @@ -27,7 +25,9 @@ Magazine: {magazine} class RansomNoteConfig: """Configuration for Ransom Note dataset generation""" + min_note_length: int = 1 # Minimum length of the ransom note max_note_length: int = 10 # Maximum length of the ransom note + min_magazine_length: int = 2 # Minimum length of the magazine max_magazine_length: int = 30 # Maximum length of the magazine p_solvable: float = 0.5 # Probability that the ransom note can be constructed @@ -36,10 +36,15 @@ class RansomNoteConfig: def validate(self): """Validate configuration parameters""" - assert 1 <= self.max_note_length <= MAX_NOTE_LENGTH, "max_note_length must be between 1 and MAX_NOTE_LENGTH" + # assert 1 <= self.max_note_length <= MAX_NOTE_LENGTH, "max_note_length must be between 1 and MAX_NOTE_LENGTH" + assert 1 <= self.min_note_length, "min_note_length must be at least 1" assert ( - 2 <= self.max_magazine_length <= MAX_MAGAZINE_LENGTH - ), "max_magazine_length must be between 2 and MAX_MAGAZINE_LENGTH" + self.min_note_length <= self.max_note_length + ), "min_note_length must be less than or equal to max_note_length" + assert 2 <= self.min_magazine_length, "min_magazine_length must be at least 2" + assert ( + self.min_magazine_length <= self.max_magazine_length + ), "min_magazine_length must be less than or equal to max_magazine_length" assert self.max_note_length < self.max_magazine_length, "max_note_length must be less than max_magazine_length" assert 0 <= self.p_solvable <= 1, "p_solvable must be between 0 and 1" @@ -51,22 +56,18 @@ class RansomNoteDataset(ProceduralDataset): super().__init__(config=config, seed=config.seed, size=config.size) self.letters = {chr(i) for i in range(ord("a"), ord("z") + 1)} - def _get_inputs(self, rng: Random, solvable: bool) -> tuple[str, str]: + def _get_inputs(self, rng: Random, note_length: int, magazine_length: int, solvable: bool) -> tuple[str, str]: """Generate random ransom note and magazine""" - ransom_note_len = rng.randint(1, self.config.max_note_length) - ransom_note = [rng.choice(list(self.letters)) for _ in range(ransom_note_len)] - - magazine_len = rng.randint(ransom_note_len, self.config.max_magazine_length) + ransom_note = [rng.choice(list(self.letters)) for _ in range(note_length)] magazine = ransom_note.copy() if solvable: - magazine.extend([rng.choice(list(self.letters)) for _ in range(magazine_len - ransom_note_len)]) + magazine.extend([rng.choice(list(self.letters)) for _ in range(magazine_length - note_length)]) else: remove_letter = rng.choice(magazine) magazine.remove(remove_letter) magazine.extend( - [rng.choice(list(self.letters - {remove_letter})) for _ in range(magazine_len - ransom_note_len + 1)] + [rng.choice(list(self.letters - {remove_letter})) for _ in range(magazine_length - note_length + 1)] ) - rng.shuffle(ransom_note) rng.shuffle(magazine) return "".join(ransom_note), "".join(magazine) @@ -85,35 +86,58 @@ class RansomNoteDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single Group Anagrams question""" rng = Random(self.seed + idx) + + note_length = rng.randint(self.config.min_note_length, self.config.max_note_length) + magazine_length = rng.randint( + max(note_length, self.config.min_magazine_length), self.config.max_magazine_length + ) solvable = rng.random() < self.config.p_solvable - ransom_note, magazine = self._get_inputs(rng, solvable) + ransom_note, magazine = self._get_inputs(rng, note_length, magazine_length, solvable) answer = self._can_construct(ransom_note, magazine) return { "question": QUESTION_TEMPLATE.format(ransom_note=ransom_note, magazine=magazine), "answer": str(answer), - "metadata": {"ransom_note": ransom_note, "magazine": magazine, "solution": answer, "solvable": solvable}, + "metadata": { + "ransom_note": ransom_note, + "magazine": magazine, + "solution": answer, + "solvable": solvable, + "difficulty": { + "note_length": note_length, + "magazine_length": magazine_length, + }, + }, } - def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - """Determine if the solution provided solves this task. - The function awards 1.0 for a correct answer. +class RansomNoteCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(RansomNoteCurriculum.__name__, RansomNoteConfig) - Args: - answer (Optional[str]): The user's answer. - entry (dict[str, Any]): The original dataset entry containing the correct answer. - - Returns: - float: The computed score between 0.0 and 1.0. - """ - - if isinstance(answer, str): - s_answer = answer.strip() - if s_answer == str(entry["answer"]): - return 1.0 - - return 0.0 + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="note_length", + levels=[10, 50, 100, 500], + default_level=0, + description="Length of the ransom note", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_note_length", + upper_field_name="max_note_length", + ), + RangeAttributeDefinition( + name="magazine_length", + levels=[50, 100, 500, 1000], + default_level=0, + description="Length of the magazine", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_magazine_length", + upper_field_name="max_magazine_length", + ), + ) -register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig) +register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig, RansomNoteCurriculum) diff --git a/tests/test_ransom_note.py b/tests/test_ransom_note.py index bf340424..609711f8 100644 --- a/tests/test_ransom_note.py +++ b/tests/test_ransom_note.py @@ -4,34 +4,34 @@ import json import pytest -from reasoning_gym.algorithmic.ransom_note import RansomNoteConfig, RansomNoteDataset +from reasoning_gym.algorithmic.ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDataset def test_ransom_note_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): - config = RansomNoteConfig(max_note_length=-1) # Negative not allowed - config.validate() - - with pytest.raises(AssertionError): - config = RansomNoteConfig(max_note_length=0) # Zero not allowed - config.validate() - - with pytest.raises(AssertionError): - config = RansomNoteConfig(max_magazine_length=-1) # Negative not allowed - config.validate() - - with pytest.raises(AssertionError): - config = RansomNoteConfig(max_magazine_length=0) # Zero not allowed - config.validate() - - with pytest.raises(AssertionError): - config = RansomNoteConfig(max_magazine_length=1) # One not allowed + config = RansomNoteConfig(min_note_length=0) # min_note_length must be at least 1 config.validate() with pytest.raises(AssertionError): config = RansomNoteConfig( - max_note_length=3, max_magazine_length=2 + min_note_length=5, max_note_length=4 + ) # min_note_length must be less than or equal to max_note_length + config.validate() + + with pytest.raises(AssertionError): + config = RansomNoteConfig(min_magazine_length=1) # min_magazine_length must be at least 2 + config.validate() + + with pytest.raises(AssertionError): + config = RansomNoteConfig( + min_magazine_length=5, max_magazine_length=4 + ) # min_magazine_length must be less than or equal to max_magazine_length + config.validate() + + with pytest.raises(AssertionError): + config = RansomNoteConfig( + max_note_length=5, max_magazine_length=5 ) # max_note_length must be less than max_magazine_length config.validate() @@ -56,7 +56,9 @@ def test_ransom_note_dataset_deterministic(): def test_group_anagrams_dataset_items(): """Test basic properties of generated items""" - config = RansomNoteConfig(max_note_length=10, max_magazine_length=30, size=10, seed=42) + config = RansomNoteConfig( + min_note_length=1, max_note_length=10, min_magazine_length=2, max_magazine_length=30, size=10, seed=42 + ) dataset = RansomNoteDataset(config) for i in range(len(dataset)): @@ -114,3 +116,28 @@ def test_ransom_note_answer(): # Inorrect solution ransom_note, magazine = "az", "badhergh" assert dataset._can_construct(ransom_note, magazine) == False + + +def test_ransom_note_curriculum(): + curriculum = RansomNoteCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: RansomNoteConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_note_length == 10 and base_cfg.max_note_length == 10 + assert base_cfg.min_magazine_length == 50 and base_cfg.max_magazine_length == 50 + + # test incrementing attribute levels + curriculum.increment_attr_level("note_length") + curriculum.increment_attr_level("magazine_length") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_note_length == 10 and increased_cfg.max_note_length == 50 + assert increased_cfg.min_magazine_length == 50 and increased_cfg.max_magazine_length == 100 + + # test decrementing attribute level for note_length again + curriculum.decrement_attr_level("note_length") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_note_length == 10 and partially_decreased_cfg.max_note_length == 10 + assert partially_decreased_cfg.min_magazine_length == 50 and partially_decreased_cfg.max_magazine_length == 100