Refactor LetterJumble

This commit is contained in:
EduardDurech 2025-02-09 12:36:07 +00:00
parent b8ce5a8a5d
commit 18b6e71fa9
6 changed files with 550 additions and 190 deletions

View file

@@ -1,103 +1,66 @@
"""Word letter jumbling task generator"""
"""Exercise definition for letter jumble exercises."""
import re
from dataclasses import dataclass
from random import Random
from typing import List, Optional
from typing import Dict, Any
from reasoning_gym.core.template import Template
from reasoning_gym.data import read_data_file
class LetterJumbleExercise:
"""Exercise generator for word jumbling tasks."""
from ..factory import ProceduralDataset, register_dataset
def __init__(self):
self.curriculum = None
def generate(self, curriculum: Any) -> Dict[str, Any]:
"""
Generate a word jumbling problem using the curriculum.
@dataclass
class LetterJumbleConfig:
"""Configuration for letter jumbling task generation"""
Returns:
Dict containing:
- question: str (e.g. "Unscramble these words: OLHEL DLWOR")
- answer: str (the original words)
- metadata: dict with details (scrambled_words, original_words, etc.)
"""
self.curriculum = curriculum
template = curriculum.get_template(curriculum.rng)
return template.eval(self, curriculum.rng)
min_word_len: int = 1 # Minimum word length
max_word_len: int = 64 # Maximum word length
min_words: int = 3 # Minimum words per task
max_words: int = 20 # Maximum words per task
min_corruption_level: float = 0.1 # Minimum fraction of characters to swap
max_corruption_level: float = 0.9 # Maximum fraction of characters to swap
consecutive_words: bool = True # Whether to select consecutive words from text
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def _parse_expression(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Parse the expression from the metadata.
def validate(self) -> None:
    """Check every configuration field against its constraints.

    Raises:
        AssertionError: on the first bound or ordering constraint that
            is violated, with a message naming the offending field.
    """
    # Ordered (condition, message) table; checked in the same order as
    # the individual assertions, so the first failure is identical.
    constraints = [
        (self.min_word_len > 0, "min_word_len must be positive"),
        (self.max_word_len >= self.min_word_len, "max_word_len must be >= min_word_len"),
        (self.min_words > 0, "min_words must be positive"),
        (self.max_words >= self.min_words, "max_words must be >= min_words"),
        (0 <= self.min_corruption_level <= 1, "min_corruption_level must be in [0,1]"),
        (0 <= self.max_corruption_level <= 1, "max_corruption_level must be in [0,1]"),
        (self.max_corruption_level >= self.min_corruption_level, "max_corruption_level must be >= min_corruption_level"),
    ]
    for ok, message in constraints:
        assert ok, message
class LetterJumbleDataset(ProceduralDataset):
"""Generates word letter jumbling tasks"""
def __init__(self, config: LetterJumbleConfig):
    """Initialize the dataset and build the filtered word pool.

    Loads the source text once and keeps only purely alphabetic words
    whose length lies within [min_word_len, max_word_len].
    """
    super().__init__(config=config, seed=config.seed, size=config.size)

    corpus = read_data_file("in_the_year_2889.txt")
    lo, hi = self.config.min_word_len, self.config.max_word_len

    def _keep(token: str) -> bool:
        # Alphabetic only, within the configured length bounds.
        return token.isalpha() and lo <= len(token) <= hi

    self.words = [token for token in re.findall(r"\b\w+\b", corpus) if _keep(token)]
def _scramble_word(self, word: str, corruption_level: float, rng: Random) -> str:
"""Scramble a word by swapping random pairs of characters"""
if len(word) < 2: # Can't scramble 1-character words
return word
word = list(word)
num_swaps = max(1, int(len(word) * corruption_level)) # Ensure at least one swap
for _ in range(num_swaps):
# Pick two different random positions
pos1, pos2 = rng.sample(range(len(word)), 2)
# Swap characters
word[pos1], word[pos2] = word[pos2], word[pos1]
return "".join(word)
def __getitem__(self, idx: int) -> dict:
"""Generate a single word jumbling task.

Returns a dict with:
- question: prompt listing the scrambled words
- answer: the original words, space-separated
- metadata: num_words, corruption_level, scrambled_words, original_words
"""
# Deterministic per-item RNG: the same (seed, idx) pair always
# reproduces the same task.
rng = Random(self.seed + idx)
# Select number of words and corruption level
num_words = rng.randint(self.config.min_words, self.config.max_words)
corruption_level = rng.uniform(self.config.min_corruption_level, self.config.max_corruption_level)
# Select words based on configuration
if self.config.consecutive_words:
# Select consecutive words from a random starting position
start_idx = rng.randint(0, len(self.words) - num_words)
selected_words = self.words[start_idx : start_idx + num_words]
else:
# Select random words
selected_words = rng.sample(self.words, num_words)
# Scramble each word
scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words]
return {
"question": f"Unscramble these words: {' '.join(scrambled_words)}",
"answer": " ".join(selected_words),
"metadata": {
"num_words": num_words,
"corruption_level": corruption_level,
"scrambled_words": scrambled_words,
"original_words": selected_words,
},
# NOTE(review): the diff rendering interleaves the replacement file at this
# point, cutting off the closing brace of this return statement.
The metadata structure from the template system:
{
"scrambled": {
"scrambled_words": str, # Space-separated scrambled words
"original_words": List[str] # List of original words
}
}
Args:
metadata: The metadata containing the expression information.
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)
Returns:
A dictionary containing:
- scrambled_words: List[str] of scrambled words
- original_words: List[str] of original words
"""
# Extract the scrambled and original words from metadata
template_data = metadata["scrambled"]
scrambled_words = template_data["scrambled_words"].split()
original_words = template_data["original_words"]
return {
"scrambled_words": scrambled_words,
"original_words": original_words
}
def _evaluate_expression(self, parsed_data: Dict[str, Any]) -> str:
"""Evaluate the expression using the parsed data.
Args:
parsed_data: Dictionary containing:
- scrambled_words: List[str] of scrambled words
- original_words: List[str] of original words
Returns:
The answer string (space-separated original words).
"""
return " ".join(parsed_data["original_words"])

View file

@@ -1,8 +1,10 @@
from .base_conversion_curriculum import BaseConversionCurriculum
from .caesar_cipher_curriculum import CaesarCipherCurriculum
from .letter_counting_curriculum import LetterCountingCurriculum
from .letter_jumble_curriculum import LetterJumbleCurriculum
__all__ = [
"BaseConversionCurriculum",
"CaesarCipherCurriculum",
"LetterCountingCurriculum"
"LetterCountingCurriculum",
"LetterJumbleCurriculum"
]

View file

@@ -0,0 +1,122 @@
"""
Curriculum definition for letter jumble exercises.
"""
from typing import Dict, Any
from reasoning_gym.core.base_curriculum import BaseCurriculum
from reasoning_gym.core.attributes import AttributeDefinition, AttributeType
from reasoning_gym.core.template import Template
from reasoning_gym.data import read_data_file
class LetterJumbleCurriculum(BaseCurriculum):
"""Curriculum configuration for letter-jumble exercises.

Declares the difficulty attributes (word length, word count, corruption
level, selection mode), the question templates, and the symbolic
generators that produce scrambled/original word pairs at runtime.
"""
def __init__(self):
"""Register the curriculum name and build the shared word pool."""
super().__init__("LetterJumbleCurriculum")
# Local import keeps `re` out of the module namespace; the word pool is
# every purely alphabetic token from the same corpus the dataset uses.
import re
self.words = [word for word in re.findall(r"\b\w+\b", read_data_file("in_the_year_2889.txt")) if word.isalpha()]
def _init_curriculum(self) -> None:
"""Initialize the letter jumble curriculum configuration"""
# Define valid attribute types
self._valid_types = {
AttributeType.STATIC, # For boolean flags
AttributeType.UBOUND, # For ranges like word length, num words
AttributeType.APPEND # For accumulating options
}
# Define attributes
self._attributes = {
"word_length": AttributeDefinition(
levels=[7, 12, 64], # From min_word_len/max_word_len
default_level=0,
description="Maximum word length",
attr_type=AttributeType.UBOUND,
min_value=1 # NOTE(review): prior comment claimed "at least 2 chars for scrambling" but the floor is 1 — confirm intent
),
"preserve_length": AttributeDefinition(
levels=[4, 2], # NOTE(review): levels descend (harder level preserves shorter words?) — confirm ordering
default_level=0,
description="Word length to preserve",
attr_type=AttributeType.STATIC
),
"num_words": AttributeDefinition(
levels=[3, 5, 20], # From min_words/max_words
default_level=0,
description="Number of words to scramble",
attr_type=AttributeType.UBOUND,
min_value=1 # Ensure at least 1 word
),
"corruption_level": AttributeDefinition(
levels=[0.1, 0.3, 0.9], # From min/max_corruption_level
default_level=0,
description="Fraction of characters to swap",
attr_type=AttributeType.UBOUND,
min_value=0.1
),
"consecutive_words": AttributeDefinition(
levels=[True, False],
default_level=0,
description="Whether to select consecutive words",
attr_type=AttributeType.APPEND # NOTE(review): boolean flag, yet STATIC is documented above as the type for boolean flags — confirm
)
}
# Define templates with symbolic placeholders
self._templates = [
Template(
template="Unscramble these words: \"{scrambled}\"",
parts={"scrambled": "word_list"}
),
Template(
template="What are the original words? \"{scrambled}\"",
parts={"scrambled": "word_list"}
),
Template(
template="Rearrange the letters to find the original words: \"{scrambled}\"",
parts={"scrambled": "word_list"}
)
]
# Define symbolic structure
self._symbolic = {
# Shared variables that need to be consistent across templates
"shared_vars": {
# Selected original words that will be scrambled.
# The tuple-plus-walrus pattern emulates let-bindings inside a
# lambda; the trailing [-1] yields the tuple's last element.
"selected_words": lambda refs: (
n_words := refs["num_words"](),
pool := self.words,
refs["dataset_rng"].sample(pool, n_words) if not refs["consecutive_words"]() else
(
# Consecutive mode: random window of n_words from the pool.
start := refs["dataset_rng"].randint(0, max(0, len(pool)-n_words)),
pool[start:start + n_words]
)[-1]
)[-1]
},
# Value generators for dynamic content
"generators": {
# Scramble a single word (as a list of chars, mutated in place)
# based on corruption level; the comprehension's element tuple
# performs the swap via paired __setitem__ calls.
"scramble_word": lambda refs: lambda lst: (
[
(i, j, lst.__setitem__(i, lst[j]), lst.__setitem__(j, temp)) # in-place swap of positions i and j
# NOTE(review): max(0, ...) allows zero swaps, unlike the
# dataset's _scramble_word which forces at least one — confirm intentional.
for _ in range(max(0, int(len(lst) * refs["corruption_level"]())))
for i, j in [refs["dataset_rng"].sample(range(len(lst)), 2)]
for temp in [lst[i]] # temp binds the original lst[i] before the swap tuple runs
],
"".join(lst)
)[-1],
# Generate scrambled version of all selected words; words at or
# below preserve_length pass through unscrambled.
"scramble_all": lambda refs: lambda: [
refs["scramble_word"](refs)(list(word)) if len(word) > refs["preserve_length"]() else word
for word in refs["selected_words"](refs)
]
},
# Template composition
"templates": {
# refs=refs default args pin the current refs mapping at
# definition time (avoids late-binding surprises).
"word_list": lambda refs: {
"template": "{scrambled_words}",
"parts": {
"scrambled_words": lambda refs=refs: " ".join(refs["scramble_all"](refs)()),
"original_words": lambda refs=refs: refs["selected_words"](refs)
}
}
}
}

View file

@@ -9,7 +9,7 @@ Algorithmic tasks for training reasoning capabilities:
from .base_conversion import BaseConversionExercise
from .caesar_cipher import CaesarCipherExercise
from .letter_counting import LetterCountingExercise
# from .letter_jumble import LetterJumbleExercise
from .letter_jumble import LetterJumbleExercise
# from .number_filtering import NumberFilteringExercise
# from .number_sorting import NumberSortingExercise
# from .sentence_reordering import SentenceReorderingExercise
@@ -23,7 +23,7 @@ __all__ = [
"BaseConversionExercise",
"CaesarCipherExercise",
"LetterCountingExercise",
# "LetterJumbleDataset",
"LetterJumbleExercise",
# "NumberFilteringDataset",
# "NumberSortingDataset",
# "SentenceReorderingDataset",