sentence reordering curriculum (#326)

This commit is contained in:
Zafir Stojanovski 2025-03-11 00:21:41 +01:00 committed by GitHub
parent 3c39cbda40
commit a23c8c3d4e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 52 additions and 4 deletions

View file

@ -34,7 +34,7 @@ from .pool_matrix import PoolMatrixConfig, PoolMatrixCurriculum, PoolMatrixDatas
from .ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDataset from .ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDataset
from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset
from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingCurriculum, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset
@ -77,6 +77,7 @@ __all__ = [
"NumberSortingCurriculum", "NumberSortingCurriculum",
"SentenceReorderingConfig", "SentenceReorderingConfig",
"SentenceReorderingDataset", "SentenceReorderingDataset",
"SentenceReorderingCurriculum",
"WordSequenceReversalConfig", "WordSequenceReversalConfig",
"WordSequenceReversalDataset", "WordSequenceReversalDataset",
"WordSequenceReversalCurriculum", "WordSequenceReversalCurriculum",

View file

@ -5,6 +5,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..data import read_data_file from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -89,7 +90,7 @@ class SentenceReorderingDataset(ProceduralDataset):
return { return {
"question": f"Restore the correct order of words in the following sentence: {question}", "question": f"Restore the correct order of words in the following sentence: {question}",
"answer": solved_sentence, "answer": solved_sentence,
"metadata": {"word_count": word_count}, "metadata": {"word_count": word_count, "difficulty": {"words_in_sentence": word_count}},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -114,4 +115,25 @@ class SentenceReorderingDataset(ProceduralDataset):
return reward return reward
class SentenceReorderingCurriculum(BaseCurriculum):
    """Curriculum for the sentence reordering dataset.

    Defines a single range attribute, ``words_in_sentence``, whose lower and
    upper bounds are written to the ``min_words_in_sentence`` and
    ``max_words_in_sentence`` fields of ``SentenceReorderingConfig``.
    """

    def __init__(self):
        super().__init__(SentenceReorderingCurriculum.__name__, SentenceReorderingConfig)

        # The only difficulty knob: how many words the sentence contains.
        words_in_sentence = RangeAttributeDefinition(
            name="words_in_sentence",
            description="Number of words in the sentence",
            levels=[5, 20, 50, 100],
            default_level=1,
            min_value=3,
            attr_type=AttributeType.APPEND,
            lower_field_name="min_words_in_sentence",
            upper_field_name="max_words_in_sentence",
        )
        self._define_attributes(words_in_sentence)


# Register the dataset together with its config and curriculum classes.
register_dataset(
    "sentence_reordering",
    SentenceReorderingDataset,
    SentenceReorderingConfig,
    SentenceReorderingCurriculum,
)

View file

@ -1,6 +1,10 @@
import pytest import pytest
from reasoning_gym.algorithmic.sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from reasoning_gym.algorithmic.sentence_reordering import (
SentenceReorderingConfig,
SentenceReorderingCurriculum,
SentenceReorderingDataset,
)
@pytest.fixture @pytest.fixture
@ -49,3 +53,24 @@ def test_key_error_in_getitem(dataset):
with pytest.raises(KeyError): with pytest.raises(KeyError):
dataset[0] dataset[0]
def test_sentence_reordering_curriculum():
    """Check the default config and that level changes move the word range."""
    curriculum = SentenceReorderingCurriculum()
    base_kwargs = {"size": 150, "seed": 1}

    def word_range(cfg: SentenceReorderingConfig) -> tuple:
        # Collapse the two bounds into one value so each check is a single comparison.
        return (cfg.min_words_in_sentence, cfg.max_words_in_sentence)

    # Default level: the base kwargs are passed through and the range is 5..20.
    default_cfg: SentenceReorderingConfig = curriculum.generate_configuration(base_kwargs)
    assert default_cfg.seed == 1
    assert default_cfg.size == 150
    assert word_range(default_cfg) == (5, 20)

    # Raising the level widens the upper bound while the lower bound stays put.
    curriculum.increment_attr_level("words_in_sentence")
    harder_cfg = curriculum.generate_configuration(base_kwargs)
    assert word_range(harder_cfg) == (5, 50)

    # Lowering the level again restores the default range.
    curriculum.decrement_attr_level("words_in_sentence")
    restored_cfg = curriculum.generate_configuration(base_kwargs)
    assert word_range(restored_cfg) == (5, 20)