mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
90 lines
2.9 KiB
Python
90 lines
2.9 KiB
Python
"""List functions generators"""
|
|
|
|
from dataclasses import dataclass
|
|
from random import Random
|
|
from typing import Any, Callable, Optional
|
|
|
|
from reasoning_gym.factory import ProceduralDataset, register_dataset
|
|
|
|
# Name under which this dataset is registered below, and the value recorded as
# "source_dataset" in the metadata of every generated item.
DATASET_NAME = "list_functions"
|
|
|
|
|
|
@dataclass
class ListFunctionsDatasetConfig:
    """Settings consumed by the list-function dataset."""

    # Seed handed to the dataset; None leaves the choice to the consumer.
    seed: Optional[int] = None
    # Number of items the dataset exposes; must be strictly positive.
    size: int = 500

    def validate(self) -> None:
        """Check the configuration and fail with AssertionError if it is invalid."""
        requested_size = self.size
        assert requested_size > 0, "size must be positive"
|
|
|
|
|
|
# Task identifiers (0..16) from which one is sampled for every dataset item; each
# sampled id is later used as a key into the generator mapping.
# NOTE(review): presumably these match the numeric suffixes of the generate_<n>
# functions in the generators module — confirm when adding or removing a task.
tasks = list(range(17))
|
|
|
|
|
|
class ListFunctionsDataset(ProceduralDataset):
    """Procedural dataset of inductive "list function" puzzles.

    Every item presents a handful of input/output example pairs produced by one
    task generator and asks for the output that belongs to a final input.
    """

    def __init__(self, config: ListFunctionsDatasetConfig):
        """Set up task sampling and the prompt template.

        Args:
            config: Provides the ``seed`` and ``size`` forwarded to the base class.
        """
        super().__init__(config, config.seed, config.size)
        # Filled in on first access of the ``generators`` property.
        self._generators: Optional[dict[int, Callable[[Random, float], dict[str, Any]]]] = None
        # One task id per dataset position, drawn up front from an RNG seeded with
        # self.seed (self.seed / self.size are expected to be set by the base class).
        self.task_indices = Random(self.seed).choices(tasks, k=self.size)
        self.prompt_template = """You are an expert at inductive reasoning. Generate an output corresponding to the given input.
The output is generated by applying the same rule that maps input to output for the examples provided. Your answer should be a list of element/elements
Examples:
{examples}

Input: {input}
Output:
"""

    @property
    def generators(self) -> dict[int, Callable[[Random, float], dict[str, Any]]]:
        """Mapping from task id to generator function, built on first access."""
        if self._generators is None:
            self._generators = self._load_generators()
        return self._generators

    def _load_generators(self) -> dict[int, Callable[[Random, float], dict[str, Any]]]:
        """Collect the ``generate_<n>`` callables of the sibling ``generators`` module.

        Returns:
            A dict keyed by the integer ``<n>`` taken from each attribute name.

        Raises:
            ValueError: If an attribute starts with ``generate_`` but the rest of
                its name is not an integer.
        """
        # Imported here rather than at module level so the module is only loaded
        # when generators are actually needed.
        from . import generators as generators_module

        prefix = "generate_"
        return {
            int(name.removeprefix(prefix)): getattr(generators_module, name)
            for name in dir(generators_module)
            if name.startswith(prefix)
        }

    def __getitem__(self, idx: int) -> dict:
        """Generate the list-function induction item at position ``idx``.

        Args:
            idx: Position of the item; selects the task id and offsets the RNG seed.

        Returns:
            A dict with ``question`` (the filled-in prompt), ``answer`` (the output
            paired with the queried input) and ``metadata`` (dataset name and index).

        Raises:
            IndexError: If ``idx`` is outside ``task_indices``.
        """
        rng = Random(self.seed + idx)
        generator = self.generators[self.task_indices[idx]]
        # NOTE(review): the annotation declares a second float parameter, but only
        # the RNG is passed here — presumably that parameter has a default; confirm.
        examples = generator(rng)
        # dict.popitem() removes the most recently inserted pair; that pair becomes
        # the query and the remaining pairs are shown as worked examples.
        query, answer = examples.popitem()
        formatted_examples = "".join(
            f"Input {number}: {example_input}\nOutput {number}: {example_output}\n"
            for number, (example_input, example_output) in enumerate(examples.items(), start=1)
        )
        question = self.prompt_template.format(examples=formatted_examples, input=query)
        return {
            "question": question,
            "answer": answer,
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
            },
        }
|
|
|
|
|
|
# Register the dataset class together with its config class under DATASET_NAME.
register_dataset(DATASET_NAME, ListFunctionsDataset, ListFunctionsDatasetConfig)
|