mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
ransom note curriculum (#300)
Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
This commit is contained in:
parent
bfa3a58829
commit
2fca962847
3 changed files with 108 additions and 56 deletions
|
|
@ -7,13 +7,11 @@ https://leetcode.com/problems/ransom-note/description/
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
MAX_NOTE_LENGTH = 100_000
|
||||
MAX_MAGAZINE_LENGTH = 100_001
|
||||
|
||||
QUESTION_TEMPLATE = """Given two strings representing a ransom note and a magazine, return True if you can construct the ransom note using the letters in the magazine, and False otherwise.
|
||||
|
||||
Each letter in the magazine string can only be used once in your ransom note.
|
||||
|
|
@ -27,7 +25,9 @@ Magazine: {magazine}
|
|||
class RansomNoteConfig:
|
||||
"""Configuration for Ransom Note dataset generation"""
|
||||
|
||||
min_note_length: int = 1 # Minimum length of the ransom note
|
||||
max_note_length: int = 10 # Maximum length of the ransom note
|
||||
min_magazine_length: int = 2 # Minimum length of the magazine
|
||||
max_magazine_length: int = 30 # Maximum length of the magazine
|
||||
p_solvable: float = 0.5 # Probability that the ransom note can be constructed
|
||||
|
||||
|
|
@ -36,10 +36,15 @@ class RansomNoteConfig:
|
|||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert 1 <= self.max_note_length <= MAX_NOTE_LENGTH, "max_note_length must be between 1 and MAX_NOTE_LENGTH"
|
||||
# assert 1 <= self.max_note_length <= MAX_NOTE_LENGTH, "max_note_length must be between 1 and MAX_NOTE_LENGTH"
|
||||
assert 1 <= self.min_note_length, "min_note_length must be at least 1"
|
||||
assert (
|
||||
2 <= self.max_magazine_length <= MAX_MAGAZINE_LENGTH
|
||||
), "max_magazine_length must be between 2 and MAX_MAGAZINE_LENGTH"
|
||||
self.min_note_length <= self.max_note_length
|
||||
), "min_note_length must be less than or equal to max_note_length"
|
||||
assert 2 <= self.min_magazine_length, "min_magazine_length must be at least 2"
|
||||
assert (
|
||||
self.min_magazine_length <= self.max_magazine_length
|
||||
), "min_magazine_length must be less than or equal to max_magazine_length"
|
||||
assert self.max_note_length < self.max_magazine_length, "max_note_length must be less than max_magazine_length"
|
||||
assert 0 <= self.p_solvable <= 1, "p_solvable must be between 0 and 1"
|
||||
|
||||
|
|
@ -51,22 +56,18 @@ class RansomNoteDataset(ProceduralDataset):
|
|||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.letters = {chr(i) for i in range(ord("a"), ord("z") + 1)}
|
||||
|
||||
def _get_inputs(self, rng: Random, solvable: bool) -> tuple[str, str]:
|
||||
def _get_inputs(self, rng: Random, note_length: int, magazine_length: int, solvable: bool) -> tuple[str, str]:
|
||||
"""Generate random ransom note and magazine"""
|
||||
ransom_note_len = rng.randint(1, self.config.max_note_length)
|
||||
ransom_note = [rng.choice(list(self.letters)) for _ in range(ransom_note_len)]
|
||||
|
||||
magazine_len = rng.randint(ransom_note_len, self.config.max_magazine_length)
|
||||
ransom_note = [rng.choice(list(self.letters)) for _ in range(note_length)]
|
||||
magazine = ransom_note.copy()
|
||||
if solvable:
|
||||
magazine.extend([rng.choice(list(self.letters)) for _ in range(magazine_len - ransom_note_len)])
|
||||
magazine.extend([rng.choice(list(self.letters)) for _ in range(magazine_length - note_length)])
|
||||
else:
|
||||
remove_letter = rng.choice(magazine)
|
||||
magazine.remove(remove_letter)
|
||||
magazine.extend(
|
||||
[rng.choice(list(self.letters - {remove_letter})) for _ in range(magazine_len - ransom_note_len + 1)]
|
||||
[rng.choice(list(self.letters - {remove_letter})) for _ in range(magazine_length - note_length + 1)]
|
||||
)
|
||||
|
||||
rng.shuffle(ransom_note)
|
||||
rng.shuffle(magazine)
|
||||
return "".join(ransom_note), "".join(magazine)
|
||||
|
|
@ -85,35 +86,58 @@ class RansomNoteDataset(ProceduralDataset):
|
|||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single Group Anagrams question"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
note_length = rng.randint(self.config.min_note_length, self.config.max_note_length)
|
||||
magazine_length = rng.randint(
|
||||
max(note_length, self.config.min_magazine_length), self.config.max_magazine_length
|
||||
)
|
||||
solvable = rng.random() < self.config.p_solvable
|
||||
ransom_note, magazine = self._get_inputs(rng, solvable)
|
||||
ransom_note, magazine = self._get_inputs(rng, note_length, magazine_length, solvable)
|
||||
answer = self._can_construct(ransom_note, magazine)
|
||||
|
||||
return {
|
||||
"question": QUESTION_TEMPLATE.format(ransom_note=ransom_note, magazine=magazine),
|
||||
"answer": str(answer),
|
||||
"metadata": {"ransom_note": ransom_note, "magazine": magazine, "solution": answer, "solvable": solvable},
|
||||
"metadata": {
|
||||
"ransom_note": ransom_note,
|
||||
"magazine": magazine,
|
||||
"solution": answer,
|
||||
"solvable": solvable,
|
||||
"difficulty": {
|
||||
"note_length": note_length,
|
||||
"magazine_length": magazine_length,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves this task.
|
||||
|
||||
The function awards 1.0 for a correct answer.
|
||||
class RansomNoteCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(RansomNoteCurriculum.__name__, RansomNoteConfig)
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer.
|
||||
entry (dict[str, Any]): The original dataset entry containing the correct answer.
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
if isinstance(answer, str):
|
||||
s_answer = answer.strip()
|
||||
if s_answer == str(entry["answer"]):
|
||||
return 1.0
|
||||
|
||||
return 0.0
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="note_length",
|
||||
levels=[10, 50, 100, 500],
|
||||
default_level=0,
|
||||
description="Length of the ransom note",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=1,
|
||||
lower_field_name="min_note_length",
|
||||
upper_field_name="max_note_length",
|
||||
),
|
||||
RangeAttributeDefinition(
|
||||
name="magazine_length",
|
||||
levels=[50, 100, 500, 1000],
|
||||
default_level=0,
|
||||
description="Length of the magazine",
|
||||
attr_type=AttributeType.APPEND,
|
||||
min_value=2,
|
||||
lower_field_name="min_magazine_length",
|
||||
upper_field_name="max_magazine_length",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig)
|
||||
register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig, RansomNoteCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue