sentence reordering curriculum (#326)

This commit is contained in:
Zafir Stojanovski 2025-03-11 00:21:41 +01:00 committed by GitHub
parent 3c39cbda40
commit a23c8c3d4e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 52 additions and 4 deletions

View file

@ -34,7 +34,7 @@ from .pool_matrix import PoolMatrixConfig, PoolMatrixCurriculum, PoolMatrixDatas
from .ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDataset
from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset
from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingCurriculum, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset
@ -77,6 +77,7 @@ __all__ = [
"NumberSortingCurriculum",
"SentenceReorderingConfig",
"SentenceReorderingDataset",
"SentenceReorderingCurriculum",
"WordSequenceReversalConfig",
"WordSequenceReversalDataset",
"WordSequenceReversalCurriculum",

View file

@ -5,6 +5,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@ -89,7 +90,7 @@ class SentenceReorderingDataset(ProceduralDataset):
return {
"question": f"Restore the correct order of words in the following sentence: {question}",
"answer": solved_sentence,
"metadata": {"word_count": word_count},
"metadata": {"word_count": word_count, "difficulty": {"words_in_sentence": word_count}},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -114,4 +115,25 @@ class SentenceReorderingDataset(ProceduralDataset):
return reward
register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig)
class SentenceReorderingCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(SentenceReorderingCurriculum.__name__, SentenceReorderingConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="words_in_sentence",
levels=[5, 20, 50, 100],
default_level=1,
description="Number of words in the sentence",
attr_type=AttributeType.APPEND,
min_value=3,
lower_field_name="min_words_in_sentence",
upper_field_name="max_words_in_sentence",
),
)
register_dataset(
"sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig, SentenceReorderingCurriculum
)

View file

@ -1,6 +1,10 @@
import pytest
from reasoning_gym.algorithmic.sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from reasoning_gym.algorithmic.sentence_reordering import (
SentenceReorderingConfig,
SentenceReorderingCurriculum,
SentenceReorderingDataset,
)
@pytest.fixture
@ -49,3 +53,24 @@ def test_key_error_in_getitem(dataset):
with pytest.raises(KeyError):
dataset[0]
def test_sentence_reordering_curriculum():
curriculum = SentenceReorderingCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: SentenceReorderingConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_words_in_sentence == 5 and base_cfg.max_words_in_sentence == 20
# test incrementing attribute levels
curriculum.increment_attr_level("words_in_sentence")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_words_in_sentence == 5 and increased_cfg.max_words_in_sentence == 50
# test decrementing attribute level
curriculum.decrement_attr_level("words_in_sentence")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_words_in_sentence == 5 and partially_decreased_cfg.max_words_in_sentence == 20