reasoning-gym/reasoning_gym/algorithmic/palindrome_generation.py
Zafir Stojanovski ce0a6c4878
fix(envs): Add source dataset and index to metadata (#388)
* add source dataset and index to metadata

* fix typo

* fix coach class and its test
2025-03-20 11:12:14 +00:00

146 lines
5.1 KiB
Python

import random
import string
from dataclasses import dataclass
from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
# Prompt shown to the model. ``{letters}`` is filled with the scrambled letters
# joined by ", ".
QUESTION_TEMPLATE = """Your task is, given a list of letters, to form a valid palindrome.
A palindrome is a phrase that reads the same forwards and backwards.
If there are multiple possible answers, only respond with one of them. You must use all the letters provided.
Your output should be a single string, with no spaces or punctuation.
Now, form a valid palindrome using the following letters: {letters}
"""
# Backward-compatible alias for the original, misspelled constant name; existing
# references to QUESTION_TEMPALTE keep working.
QUESTION_TEMPALTE = QUESTION_TEMPLATE

# Name the dataset is registered under; also stored as "source_dataset" in each
# item's metadata.
DATASET_NAME = "palindrome_generation"
@dataclass
class PalindromeConfig:
    """
    Settings for generating palindrome tasks.

    - min_length: Smallest palindrome length that may be drawn (inclusive).
    - max_length: Largest palindrome length that may be drawn (inclusive).
    - seed: Optional seed so the generated items are reproducible.
    - size: How many items the virtual dataset contains.
    """

    min_length: int = 3
    max_length: int = 10
    seed: Optional[int] = None
    size: int = 50

    def validate(self) -> None:
        """Check that the length bounds form a valid range that starts at 1 or above."""
        lower, upper = self.min_length, self.max_length
        assert lower >= 1, "min_length must be >= 1"
        assert upper >= lower, "max_length must be >= min_length"
class PalindromeDataset(ProceduralDataset):
    """
    Generates a set of letters that can be assembled into a palindrome.

    Each item is built from an RNG seeded with ``self.seed + idx``, so a given
    index always produces the same letters and reference answer.
    """

    def __init__(self, config: PalindromeConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """
        Generate a single palindrome task.

        Args:
            idx: Item index; added to the dataset seed to seed this item's RNG.

        Returns:
            dict with:
            - "question": Prompt listing the scrambled letters.
            - "answer": One valid palindrome that uses exactly those letters.
            - "metadata": Source dataset/index, the scrambled letters, the
              reference palindrome, its length, and the configured length range.
        """
        rng = random.Random(self.seed + idx)
        length = rng.randint(self.config.min_length, self.config.max_length)
        letters = self._generate_palindrome_letters(rng, length)
        # Shuffle a copy so the prompt does not reveal the palindrome order.
        scrambled_letters = rng.sample(letters, len(letters))
        palindrome = self._assemble_palindrome(letters)
        return {
            "question": QUESTION_TEMPALTE.format(letters=", ".join(scrambled_letters)),
            "answer": palindrome,
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "letters": scrambled_letters,
                "generated_palindrome": palindrome,
                "length": length,
                "difficulty": {
                    "length": (self.config.min_length, self.config.max_length),
                },
            },
        }

    def _generate_palindrome_letters(self, rng: random.Random, length: int) -> list[str]:
        """
        Return ``length`` lowercase letters arranged in palindrome order.

        A random first half is mirrored. For odd lengths, a random middle letter
        is inserted between the half and its reverse.
        """
        half_length = length // 2
        letters = rng.choices(string.ascii_lowercase, k=half_length)
        if length % 2 == 1:
            middle_letter = rng.choice(string.ascii_lowercase)
            return letters + [middle_letter] + letters[::-1]
        return letters + letters[::-1]

    def _assemble_palindrome(self, letters: list[str]) -> str:
        """Join letters that are already in palindrome order into a single string."""
        return "".join(letters)

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Determine if the solution provided is a valid palindrome.

        The answer is expected to be a single string. Leading and trailing
        whitespace is stripped and the comparison is case-insensitive.

        Expected behavior:
        - Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0
        - An answer that is a palindrome, but not with the same letters as provided, gives 0.05
        - An answer that is a string, but not a palindrome gives 0.02
        - An empty (or whitespace-only) string gives 0.0
        - None or any non-string gives 0.0.
        """
        if not isinstance(answer, str):
            return 0.0  # No answer given (None or wrong type)

        # Normalise before the emptiness check. Otherwise a whitespace-only answer
        # strips to "", counts as a palindrome, and would score 0.05 instead of 0.0.
        answer = answer.strip().lower()
        if not answer:
            return 0.0

        expected_letters = entry["metadata"]["letters"]

        # Check if the answer is a palindrome
        if answer != answer[::-1]:
            return 0.02

        # Check if answer contains the same letters as provided (ignoring order)
        if sorted(answer) != sorted(expected_letters):
            return 0.05

        return 1.0  # Correct solution
class PalindromeCurriculum(BaseCurriculum):
    """Curriculum for palindrome generation, driven by a single ``length`` attribute."""

    def __init__(self):
        super().__init__(PalindromeCurriculum.__name__, PalindromeConfig)

        # The length attribute is tied to PalindromeConfig's min_length/max_length
        # fields through lower_field_name/upper_field_name.
        length_attribute = RangeAttributeDefinition(
            name="length",
            levels=[10, 50, 100, 500],
            description="Length of the generated palindrome.",
            lower_field_name="min_length",
            upper_field_name="max_length",
            ensure_interval=True,
        )
        self._define_attributes(length_attribute)
# Register the dataset class, its config class, and its curriculum under DATASET_NAME.
register_dataset(DATASET_NAME, PalindromeDataset, PalindromeConfig, PalindromeCurriculum)