mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
update generation of input string
This commit is contained in:
parent
ed606631bb
commit
fd5c47d634
1 changed files with 53 additions and 37 deletions
|
|
@ -5,12 +5,11 @@ https://leetcode.com/problems/palindrome-partitioning/description/
|
|||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Given a string, partition it such that every substring is a palindrome.
|
||||
|
|
@ -20,11 +19,14 @@ A palindrome is a word that reads the same backward as forward.
|
|||
You may return all possible palindrome partitioning in any order.
|
||||
|
||||
Example:
|
||||
Input: "aab"
|
||||
Output: [["a","a","b"],["aa","b"]]
|
||||
- Input: Partition the following string into palindromes: aab
|
||||
- Output: [["a","a","b"],["aa","b"]]
|
||||
- Explanation:
|
||||
- One way to partition the string is "a" | "a" | "b", where each substring is a palindrome.
|
||||
- Another way to partition the string is "aa" | "b", where again each substring is a palindrome.
|
||||
- Therefore, the final result is a list of the two palindrome partitions.
|
||||
|
||||
Partition the following string into palindromes:
|
||||
{string}
|
||||
Partition the following string into palindromes: {string}
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -32,12 +34,19 @@ Partition the following string into palindromes:
|
|||
class PalindromePartitioningConfig:
|
||||
"""Configuration for Palindrome Partitioning dataset generation"""
|
||||
|
||||
min_string_len: int = 5
|
||||
max_string_len: int = 15
|
||||
max_substring_palindome_len: int = 5
|
||||
|
||||
size: int = 500 # Virtual dataset size
|
||||
seed: Optional[int] = None
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
pass
|
||||
assert 1 <= self.min_string_len, "Minimum string length must be at least 1"
|
||||
assert self.min_string_len <= self.max_string_len, "Minimum string length must be less than or equal to maximum"
|
||||
assert 1 <= self.max_substring_palindome_len, "Maximum substring palindrome length must be at least 1"
|
||||
assert self.max_substring_palindome_len <= self.max_string_len, "Maximum substring palindrome length must be less than or equal to maximum string length"
|
||||
|
||||
|
||||
class PalindromePartitioningDataset(ProceduralDataset):
|
||||
|
|
@ -45,27 +54,14 @@ class PalindromePartitioningDataset(ProceduralDataset):
|
|||
|
||||
def __init__(self, config: PalindromePartitioningConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.words = [
|
||||
re.sub(r"\W+", "", word.strip()) for word in read_data_file("in_the_year_2889.txt").split() if word.strip()
|
||||
]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.config.size
|
||||
|
||||
def __iter__(self):
|
||||
self._current_idx = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._current_idx >= self.config.size:
|
||||
raise StopIteration
|
||||
item = self[self._current_idx]
|
||||
self._current_idx += 1
|
||||
return item
|
||||
|
||||
def _sort_list(self, lst: list[list[str]]) -> list[list[str]]:
|
||||
"""Sort the list of palindrome partitions"""
|
||||
return sorted([sublist for sublist in lst], key=lambda x: x[0] if x else "")
|
||||
return sorted(lst, key=lambda x: x[0] if x else "")
|
||||
|
||||
def to_set_of_tuples(self, list_of_lists: list[list[str]]) -> set[tuple[str]]:
|
||||
"""Convert a list of lists to a set of tuples"""
|
||||
return {tuple(lst) for lst in list_of_lists}
|
||||
|
||||
def _palindrome_partitioning(self, string: str) -> list[list[str]]:
|
||||
"""Return all possible palindrome partitions of a string"""
|
||||
|
|
@ -97,25 +93,45 @@ class PalindromePartitioningDataset(ProceduralDataset):
|
|||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
"""Score a single Palindrome Partitioning question"""
|
||||
reward = 0
|
||||
if answer is not None:
|
||||
try:
|
||||
answer = json.loads(answer)
|
||||
oracle = entry["metadata"]["solution"]
|
||||
answer_str = json.dumps(self._sort_list(answer))
|
||||
oracle_str = json.dumps(self._sort_list(oracle))
|
||||
if answer_str == oracle_str:
|
||||
reward = 1
|
||||
else:
|
||||
reward = 0.01
|
||||
answer = self.to_set_of_tuples(json.loads(answer))
|
||||
oracle = self.to_set_of_tuples(entry["metadata"]["solution"])
|
||||
if answer == oracle:
|
||||
return 1.0
|
||||
return 0.01
|
||||
except Exception:
|
||||
reward = 0
|
||||
return reward
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
def _generate_palindrome_letters(self, rng: Random, length: int) -> list[str]:
|
||||
"""Generate a set of letters that can form a palindrome."""
|
||||
half_length = length // 2
|
||||
letters = rng.choices(string.ascii_lowercase, k=half_length)
|
||||
if length % 2 == 1:
|
||||
middle_letter = rng.choice(string.ascii_lowercase)
|
||||
return letters + [middle_letter] + letters[::-1]
|
||||
return letters + letters[::-1]
|
||||
|
||||
def _get_string(self, rng: Random) -> str:
|
||||
"""Generate a random string"""
|
||||
size = rng.randint(self.config.min_string_len, self.config.max_string_len)
|
||||
output = ""
|
||||
|
||||
while len(output) < size:
|
||||
palindrome_len = rng.randint(
|
||||
1,
|
||||
min(self.config.max_substring_palindome_len, size - len(output))
|
||||
)
|
||||
substring = "".join(self._generate_palindrome_letters(rng, palindrome_len))
|
||||
output += substring
|
||||
|
||||
return output
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single Palindrome Partitioning question"""
|
||||
rng = Random(self.seed + idx)
|
||||
string = rng.choice(self.words)
|
||||
string = self._get_string(rng)
|
||||
answer = self._palindrome_partitioning(string)
|
||||
answer_str = json.dumps(answer)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue