added word sort curriculum (#289)

This commit is contained in:
joesharratt1229 2025-03-08 01:50:13 +01:00 committed by GitHub
parent e8601a63b4
commit 88a3d065bd
3 changed files with 85 additions and 6 deletions

View file

@ -39,7 +39,7 @@ from .string_splitting import StringSplittingConfig, StringSplittingDataset
from .string_synthesis import StringSynthesisConfig, StringSynthesisDataset
from .word_ladder import WordLadderConfig, WordLadderDataset
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingCurriculum, WordSortingDataset
__all__ = [
"SpellBackwardConfig",
@ -67,6 +67,7 @@ __all__ = [
"SentenceReorderingDataset",
"WordSequenceReversalConfig",
"WordSequenceReversalDataset",
"WordSortingCurriculum",
"WordSortingConfig",
"WordSortingDataset",
"TextTransformation",

View file

@ -6,6 +6,7 @@ from enum import StrEnum
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@ -105,11 +106,14 @@ class WordSortingDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)),
"answer": ", ".join(answer),
"metadata": {
"difficulty": {
"num_words": len(original_words),
"word_length": max(len(word) for word in original_words),
},
"original_words": original_words,
"sorted_words": answer,
"transformed_words": transformed_words,
"direction": direction,
"transformation": self.config.transformation,
"sorted_words": answer,
},
}
@ -125,4 +129,32 @@ class WordSortingDataset(ProceduralDataset):
return 0.0
class WordSortingCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(WordSortingCurriculum.__name__, WordSortingConfig)
self._define_attributes(
RangeAttributeDefinition(
name="num_words",
levels=[5, 10, 20, 30],
default_level=0,
description="Number of words to sort",
attr_type=AttributeType.APPEND,
min_value=5,
lower_field_name="min_words",
upper_field_name="max_words",
),
RangeAttributeDefinition(
name="word_length",
levels=[3, 6, 9, 12],
default_level=0,
description="Length of words to sort",
attr_type=AttributeType.APPEND,
min_value=3,
lower_field_name="min_word_length",
upper_field_name="max_word_length",
),
)
register_dataset("word_sorting", WordSortingDataset, WordSortingConfig)