Refactor LetterCounting

This commit is contained in:
EduardDurech 2025-02-09 08:37:58 +00:00
parent ca0fb97884
commit 86215b7e5c
6 changed files with 566 additions and 117 deletions

View file

@@ -1,66 +1,63 @@
"""Letter counting task generator"""
"""Letter counting exercise that generates tasks to count letter occurrences in text."""
import re
from dataclasses import dataclass
from random import Random
from typing import List, Optional
from typing import Dict, Any
from reasoning_gym.data import read_data_file
class LetterCountingExercise:
"""Exercise generator for letter counting tasks."""
from ..factory import ProceduralDataset, register_dataset
def __init__(self):
self.curriculum = None
def generate(self, curriculum: Any) -> Dict[str, Any]:
"""
Generate a letter counting problem using the curriculum.
@dataclass
class LetterCountingConfig:
"""Configuration for letter counting task generation"""
Returns:
Dict containing:
- question: str (e.g. "How many times does 'a' appear in 'banana'?")
- answer: str (the count as a string)
- metadata: dict with details (text, target_letter, etc.)
"""
self.curriculum = curriculum
template = curriculum.get_template(curriculum.rng)
return template.eval(self, curriculum.rng)
min_words: int = 5 # Minimum words in span
max_words: int = 15 # Maximum words in span
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def _parse_expression(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Parse the template metadata into structured data.
def validate(self) -> None:
    """Validate configuration parameters.

    Raises:
        AssertionError: If ``min_words`` is not positive, or if
            ``max_words`` is smaller than ``min_words``.
    """
    # Raise explicitly instead of using `assert` statements: asserts are
    # stripped when Python runs with -O, which would silently disable these
    # checks. AssertionError is kept so existing callers see the same type.
    if not self.min_words > 0:
        raise AssertionError("min_words must be positive")
    if not self.max_words >= self.min_words:
        raise AssertionError("max_words must be >= min_words")
class LetterCountingDataset(ProceduralDataset):
"""Generates letter counting tasks from text spans"""
def __init__(self, config: LetterCountingConfig):
    """Set up the dataset and build the word list that spans are drawn from."""
    super().__init__(config=config, seed=config.seed, size=config.size)

    # Read the bundled source text.
    corpus = read_data_file("in_the_year_2889.txt")

    # Tokenize on word boundaries, then keep only purely alphanumeric tokens
    # (`\w` also matches "_", which `isalnum()` rejects).
    tokens = re.findall(r"\b\w+\b", corpus)
    self.words = [token for token in tokens if token.isalnum()]
def __getitem__(self, idx: int) -> dict:
"""Generate a single letter counting task"""
rng = Random(self.seed + idx)
# Select random span of words
span_length = rng.randint(self.config.min_words, self.config.max_words)
start_idx = rng.randint(0, len(self.words) - span_length)
span = self.words[start_idx : start_idx + span_length]
# Get all unique letters from span
letters = set("".join(span).lower())
if not letters:
letters = {"a"} # Fallback if span has no letters
# Select random letter that appears in the span
target_letter = rng.choice(sorted(letters))
# Count occurrences
count = sum(word.lower().count(target_letter) for word in span)
return {
"question": f'How many times does the letter "{target_letter}" appear in the text: "{" ".join(span)}"?',
"answer": str(count),
"metadata": {"span_length": span_length, "target_letter": target_letter, "span": span},
The metadata structure from the template system:
{
"text": {"text": str}, # The text span to analyze
"letter": {"letter": str}, # The letter to count
"case_sensitivity": {"sensitivity": str} # "sensitive" or "insensitive"
}
Returns:
Dictionary containing:
- text: str (the text to analyze)
- target_letter: str (the letter to count)
- case_sensitive: bool (whether to count case sensitively)
"""
return {
"text": metadata["text"]["text"],
"target_letter": metadata["letter"]["letter"],
"case_sensitive": metadata["case_sensitivity"]["sensitivity"] == "sensitive"
}
register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig)
def _evaluate_expression(self, parsed: Dict[str, Any]) -> str:
"""
Count occurrences of the target letter in the text.
Args:
parsed: Dictionary containing:
- text: str (the text to analyze)
- target_letter: str (the letter to count)
- case_sensitive: bool (whether to count case sensitively)
Returns:
String representation of the count
"""
if parsed["case_sensitive"]:
return str(parsed["text"].count(parsed["target_letter"]))
else:
return str(parsed["text"].lower().count(parsed["target_letter"].lower()))