mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* update difficulty metadata for logic datasets * update difficulty metadata for graph datasets * update difficulty metadata for geometry datasets * update difficulty metadata for games datasets * update difficulty metadata for cognition datasets * update difficulty metadata for arithmetic datasets * update difficulty metadata for arc datasets * update difficulty metadata for algorithmic datasets * update difficulty metadata for algebra datasets * use tuples * update tests * update tests
156 lines
5.6 KiB
Python
156 lines
5.6 KiB
Python
from dataclasses import dataclass
|
|
from random import Random
|
|
from typing import Any, Optional
|
|
|
|
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
|
from ..factory import ProceduralDataset, register_dataset
|
|
|
|
|
|
@dataclass
|
|
class NeedleHaystackConfig:
|
|
"""Configuration for NeedleHaystack task generation"""
|
|
|
|
min_num_statements: int = 10
|
|
max_num_statements: int = 100
|
|
seed: Optional[int] = None
|
|
size: int = 500
|
|
|
|
def validate(self) -> None:
|
|
"""Validate configuration parameters"""
|
|
assert self.min_num_statements > 0, "min_num_statements must be greater than 0"
|
|
assert (
|
|
self.max_num_statements >= self.min_num_statements
|
|
), "max_num_statements must be greater than min_num_statements"
|
|
assert self.max_num_statements < 168387000, f"max_num_statements must be less than {168387000}"
|
|
|
|
|
|
def generate_unique_triplets(names: list[str], verbs: list[str], subjects: list[str], n: int, rng) -> dict[str, Any]:
|
|
"""
|
|
Generate n unique random triplets (name, verb, subject) without generating the full Cartesian product in memory.
|
|
|
|
Each triplet is selected based on a unique index derived from a range of
|
|
total possible combinations. Additionally, one of the generated triplets is
|
|
randomly chosen as the 'needle'.
|
|
|
|
Args:
|
|
names (list[str]): List of names.
|
|
verbs (list[str]): List of verbs.
|
|
subjects (list[str]): List of subjects.
|
|
n (int): Number of unique triplets to generate.
|
|
rng (random.Random): A pre-seeded random number generator.
|
|
|
|
Returns:
|
|
dict[str, Any]: A dictionary with:
|
|
- "triplets": a list of n unique triplets (tuples of (name, verb, subject)),
|
|
- "needle": one triplet randomly chosen from the list.
|
|
|
|
Raises:
|
|
ValueError: If n exceeds the total number of unique triplets possible.
|
|
"""
|
|
total_possible = len(names) * len(verbs) * len(subjects)
|
|
|
|
# Use a range for memory efficiency and sample n unique indices.
|
|
indices = rng.sample(range(total_possible), n)
|
|
triplets: list[tuple[str, str, str]] = []
|
|
|
|
num_verbs = len(verbs)
|
|
num_subjects = len(subjects)
|
|
|
|
for idx in indices:
|
|
# Compute the corresponding indices for names, verbs, and subjects.
|
|
name_index = idx // (num_verbs * num_subjects)
|
|
remainder = idx % (num_verbs * num_subjects)
|
|
verb_index = remainder // num_subjects
|
|
subject_index = remainder % num_subjects
|
|
|
|
triplet = (names[name_index], verbs[verb_index], subjects[subject_index])
|
|
triplets.append(triplet)
|
|
|
|
# Select one random triplet as the needle.
|
|
needle = rng.choice(triplets)
|
|
return {"triplets": triplets, "needle": needle}
|
|
|
|
|
|
class NeedleHaystackDataset(ProceduralDataset):
|
|
"""Generates "Needle in a Haystack tasks"""
|
|
|
|
def __init__(self, config: NeedleHaystackConfig):
|
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
|
|
|
def __getitem__(self, idx: int) -> dict:
|
|
"""Generate a single NeedleHaystack task
|
|
|
|
Returns:
|
|
dict with keys:
|
|
- question: str, the task description with cube string
|
|
- answer: None, indicating to use the dynamic evaluator
|
|
- metadata: dict with generation parameters and example solution
|
|
"""
|
|
from .needle_data import NAMES, SUBJECTS, VERBS
|
|
|
|
rng = Random(self.seed + idx)
|
|
|
|
num_statements = rng.randint(self.config.min_num_statements, self.config.max_num_statements)
|
|
stack = generate_unique_triplets(NAMES, VERBS, SUBJECTS, num_statements, rng)
|
|
|
|
stack_text = ""
|
|
for triplet in stack["triplets"]:
|
|
stack_text = stack_text + f"{triplet[0]} {triplet[1]} {triplet[2]}. "
|
|
question = f"Who {stack['needle'][1]} {stack['needle'][2]}? Reply only with a name."
|
|
|
|
full_text = stack_text + "\n" + question
|
|
|
|
return {
|
|
"question": full_text,
|
|
"answer": stack["needle"][0],
|
|
"metadata": {
|
|
"question": question,
|
|
"num_statements": num_statements,
|
|
"difficulty": {
|
|
"num_statements": (self.config.min_num_statements, self.config.max_num_statements),
|
|
},
|
|
},
|
|
}
|
|
|
|
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
|
"""Determine if the solution provided solves the task.
|
|
|
|
Args:
|
|
answer (Optional[str]): The user's answer.
|
|
entry (dict[str, Any]): The original dataset entry containing the correct answer.
|
|
|
|
Returns:
|
|
float: The computed score between 0.0 and 1.0.
|
|
"""
|
|
if isinstance(answer, str):
|
|
correct_word = entry["answer"]
|
|
|
|
# Normalize case
|
|
answer = answer.replace(" ", "").strip().lower()
|
|
correct_word = correct_word.strip().lower()
|
|
|
|
if answer == correct_word:
|
|
return 1.0 # Correct!
|
|
|
|
return 0.0
|
|
|
|
|
|
class NeedleHaystackCurriculum(BaseCurriculum):
|
|
def __init__(self):
|
|
super().__init__(NeedleHaystackCurriculum.__name__, NeedleHaystackConfig)
|
|
|
|
# Define attributes
|
|
self._define_attributes(
|
|
RangeAttributeDefinition(
|
|
name="num_statements",
|
|
levels=[10, 100, 500, 1_000, 5_000, 10_000, 50_000, 100_000],
|
|
description="Number of statements in the haystack",
|
|
lower_field_name="min_num_statements",
|
|
upper_field_name="max_num_statements",
|
|
ensure_interval=True,
|
|
),
|
|
)
|
|
|
|
|
|
# Register the dataset
|
|
register_dataset("needle_haystack", NeedleHaystackDataset, NeedleHaystackConfig, NeedleHaystackCurriculum)
|