mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
sentence reordering curriculum (#326)
This commit is contained in:
parent
3c39cbda40
commit
a23c8c3d4e
3 changed files with 52 additions and 4 deletions
|
|
@ -34,7 +34,7 @@ from .pool_matrix import PoolMatrixConfig, PoolMatrixCurriculum, PoolMatrixDatas
|
||||||
from .ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDataset
|
from .ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDataset
|
||||||
from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset
|
from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset
|
||||||
from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset
|
from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset
|
||||||
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingCurriculum, SentenceReorderingDataset
|
||||||
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
||||||
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
|
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
|
||||||
from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset
|
from .string_insertion import StringInsertionConfig, StringInsertionCurriculum, StringInsertionDataset
|
||||||
|
|
@ -77,6 +77,7 @@ __all__ = [
|
||||||
"NumberSortingCurriculum",
|
"NumberSortingCurriculum",
|
||||||
"SentenceReorderingConfig",
|
"SentenceReorderingConfig",
|
||||||
"SentenceReorderingDataset",
|
"SentenceReorderingDataset",
|
||||||
|
"SentenceReorderingCurriculum",
|
||||||
"WordSequenceReversalConfig",
|
"WordSequenceReversalConfig",
|
||||||
"WordSequenceReversalDataset",
|
"WordSequenceReversalDataset",
|
||||||
"WordSequenceReversalCurriculum",
|
"WordSequenceReversalCurriculum",
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||||
from ..data import read_data_file
|
from ..data import read_data_file
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
@ -89,7 +90,7 @@ class SentenceReorderingDataset(ProceduralDataset):
|
||||||
return {
|
return {
|
||||||
"question": f"Restore the correct order of words in the following sentence: {question}",
|
"question": f"Restore the correct order of words in the following sentence: {question}",
|
||||||
"answer": solved_sentence,
|
"answer": solved_sentence,
|
||||||
"metadata": {"word_count": word_count},
|
"metadata": {"word_count": word_count, "difficulty": {"words_in_sentence": word_count}},
|
||||||
}
|
}
|
||||||
|
|
||||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||||
|
|
@ -114,4 +115,25 @@ class SentenceReorderingDataset(ProceduralDataset):
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
|
|
||||||
register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig)
|
class SentenceReorderingCurriculum(BaseCurriculum):
    """Curriculum for the sentence reordering task.

    Defines a single range attribute, ``words_in_sentence``, declared with
    ``min_words_in_sentence`` / ``max_words_in_sentence`` as its lower and
    upper field names.
    """

    def __init__(self):
        """Initialise the base curriculum under this class's name and define its attributes."""
        super().__init__(SentenceReorderingCurriculum.__name__, SentenceReorderingConfig)

        # Sentence length is the only attribute this curriculum defines.
        sentence_length = RangeAttributeDefinition(
            name="words_in_sentence",
            description="Number of words in the sentence",
            attr_type=AttributeType.APPEND,
            levels=[5, 20, 50, 100],
            default_level=1,
            min_value=3,
            lower_field_name="min_words_in_sentence",
            upper_field_name="max_words_in_sentence",
        )
        self._define_attributes(sentence_length)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the dataset class together with its config and curriculum classes
# under the "sentence_reordering" name.
register_dataset(
    "sentence_reordering",
    SentenceReorderingDataset,
    SentenceReorderingConfig,
    SentenceReorderingCurriculum,
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,10 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.algorithmic.sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
from reasoning_gym.algorithmic.sentence_reordering import (
|
||||||
|
SentenceReorderingConfig,
|
||||||
|
SentenceReorderingCurriculum,
|
||||||
|
SentenceReorderingDataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -49,3 +53,24 @@ def test_key_error_in_getitem(dataset):
|
||||||
|
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
dataset[0]
|
dataset[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sentence_reordering_curriculum():
    """Check the default configuration and the effect of raising/lowering the attribute level."""
    curriculum = SentenceReorderingCurriculum()
    overrides = {"size": 150, "seed": 1}

    def word_bounds(cfg):
        # (lower, upper) word-count bounds carried by a generated config
        return cfg.min_words_in_sentence, cfg.max_words_in_sentence

    # configuration at the curriculum's default level
    default_cfg: SentenceReorderingConfig = curriculum.generate_configuration(overrides)
    assert default_cfg.seed == 1
    assert default_cfg.size == 150
    assert word_bounds(default_cfg) == (5, 20)

    # one level up: only the upper bound is expected to move
    curriculum.increment_attr_level("words_in_sentence")
    raised_cfg = curriculum.generate_configuration(overrides)
    assert word_bounds(raised_cfg) == (5, 50)

    # back down one level: bounds return to the default ones
    curriculum.decrement_attr_level("words_in_sentence")
    lowered_cfg = curriculum.generate_configuration(overrides)
    assert word_bounds(lowered_cfg) == (5, 20)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue