mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
173 lines
6.4 KiB
Python
173 lines
6.4 KiB
Python
import gzip
|
|
import json
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from random import Random
|
|
from typing import Any, Optional
|
|
|
|
from ..factory import ProceduralDataset, register_dataset
|
|
|
|
# Prompt for the output-prediction task: the solver is shown a concrete input
# and must predict the output. Positional placeholders for str.format():
#   {0} question text
#   {1} description of the input and output requirements
#   {2} the concrete input
#   {3} reference code snippet
OUTPUT_PREDICTION_PROMPT_TEMPLATE = """
You are given a question that requires some input and output variables as follows:

{0}

The input and output requirements are as follows:

{1}

Given the following input:

{2}

Can you predict the output without writing any code? Please think and then provide the exact output in the form of a JSON object as your final answer. The keys and values of the object should strictly match the output requirement as specified.

Tip: Here is a reference code snippet for this question. You can refer to this code to guide your reasoning but not copy spans of code directly.

{3}
"""
|
|
|
|
# Prompt for the input-prediction task: the solver is shown an output and must
# propose a feasible input. Positional placeholders for str.format():
#   {0} question text
#   {1} description of the input and output requirements
#   {2} the observed output
#   {3} reference code snippet
INPUT_PREDICTION_PROMPT_TEMPLATE = """
You are given a question that requires some input and output variables as follows:

{0}

The input and output requirements are as follows:

{1}

Given the following output:

{2}

Can you predict a feasible input without writing any code? Please reason and put your final answer in the form of a JSON object, even if the there is only one input variable, with keys strictly matching the input variables' names as specified.

Tip: Here is a reference code snippet for this question. You can refer to this code to guide your reasoning but not copy spans of code directly.

{3}
"""
|
|
|
|
|
|
@dataclass
class CodeIOConfig:
    """Configuration for CodeI/O reasoning task generation"""

    # Base seed handed to the dataset; None leaves the choice to the dataset base class.
    seed: Optional[int] = None
    # Number of items the dataset reports and iterates over.
    size: int = 500
    # Threshold in [0, 1] compared against a per-item random draw to choose
    # between the input-prediction and output-prediction prompts.
    input_prediction_probability: float = 0.5

    def validate(self) -> None:
        """Validate configuration parameters

        Raises:
            AssertionError: if ``input_prediction_probability`` is outside [0, 1].
        """
        assert 0.0 <= self.input_prediction_probability <= 1.0, "input_prediction_probability must be in [0, 1]"
|
|
|
|
|
|
class CodeIODataset(ProceduralDataset):
    """Dataset of CodeI/O input-prediction and output-prediction tasks.

    Each item is built from one record of the gzipped JSONL file that sits next
    to this module. The record's reference code is executed on a freshly
    generated input, and the resulting (input, output) pair is turned into
    either an "predict the input" or a "predict the output" question.
    """

    def __init__(self, config: CodeIOConfig):
        """Index the data file so single records can be looked up later.

        Args:
            config: Dataset configuration; ``seed`` and ``size`` are forwarded
                to the base class.
        """
        super().__init__(config=config, seed=config.seed, size=config.size)

        self._data_path = Path(__file__).parent / "contrib/codeio/data.jsonl.gz"
        self._offsets: list[int] = []

        # Record the stream position at the start of every line so that
        # __getitem__ can seek straight to a record. readline() is used rather
        # than iterating the file because tell() must be called between reads.
        with gzip.open(self._data_path, "rt", encoding="utf-8") as f:
            while True:
                offset = f.tell()
                if not f.readline():
                    break
                self._offsets.append(offset)

    def __len__(self) -> int:
        """Return the configured number of items."""
        return self.config.size

    def __iter__(self):
        """Restart iteration from the first item and return ``self``."""
        self._current_idx = 0
        return self

    def __next__(self):
        """Return the next item, raising ``StopIteration`` after ``config.size`` items."""
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def _generate_io_pairs(self, main_code: str, input_generator_code: str, num_pairs: int = 1):
        """Execute a record's code and collect ``(inputs, outputs)`` pairs.

        Args:
            main_code: Source that must define a callable named ``main``.
            input_generator_code: Source that must define a zero-argument
                callable named ``input_generator`` returning a dict of keyword
                arguments for ``main``.
            num_pairs: How many pairs to generate.

        Returns:
            A list of ``(inputs, outputs)`` tuples, one per generated pair.

        Raises:
            KeyError: if ``main`` or ``input_generator`` is not defined by the
                executed source.
        """
        # NOTE(security): exec() runs arbitrary code. The snippets come from the
        # data file shipped alongside this module; never feed untrusted records
        # through this method.
        #
        # A single dict serves as both globals and locals. With separate dicts
        # the snippet's top-level imports and helper functions land in the
        # locals mapping, while functions defined by the snippet resolve names
        # through the (empty) globals mapping, so ``main`` could not see them.
        namespace: dict[str, Any] = {}
        exec(main_code, namespace)
        exec(input_generator_code, namespace)

        io_pairs = []
        for _ in range(num_pairs):
            inputs = namespace["input_generator"]()
            outputs = namespace["main"](**inputs)
            io_pairs.append((inputs, outputs))
        return io_pairs

    def __getitem__(self, idx: int) -> dict:
        """Generate a single CodeI/O reasoning task

        Args:
            idx: Item index; combined with the dataset seed to seed the
                per-item random generator.

        Returns:
            A dict with the keys ``"question"`` (formatted prompt),
            ``"answer"`` (JSON-encoded expected answer) and ``"metadata"``
            (currently an empty dict).
        """
        rng = Random(self.seed + idx)

        # Pick one record and read only that line from the data file.
        random_offset = rng.choice(self._offsets)
        with gzip.open(self._data_path, "rt", encoding="utf-8") as f:
            f.seek(random_offset)
            json_data = json.loads(f.readline().strip())

        query = json_data["query"]
        parameters = json_data["parameters"]
        reference_code = json_data["reference_code"]
        input_generator_code = json_data["input_generator"]

        input_data, output_data = self._generate_io_pairs(reference_code, input_generator_code, num_pairs=1)[0]

        # A draw below input_prediction_probability selects the input-prediction
        # task (show the output, ask for a feasible input); otherwise the
        # output-prediction task is produced.
        if rng.random() < self.config.input_prediction_probability:
            question = INPUT_PREDICTION_PROMPT_TEMPLATE.format(query, parameters, output_data, reference_code)
            solution = json.dumps(input_data)
        else:
            question = OUTPUT_PREDICTION_PROMPT_TEMPLATE.format(query, parameters, input_data, reference_code)
            solution = json.dumps(output_data)

        return {
            "question": question,
            "answer": solution,
            "metadata": {},
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score a model answer against the expected answer of ``entry``.

        Reward tiers (both strings are stripped before comparison):

        * ``None`` or empty answer: 0.0
        * exact string match: 1.0
        * answer contains ``{`` and ``}``: see ``_score_json_answer``
        * expected answer appears verbatim inside the answer:
          ``len(expected) / len(answer)``, but at least 0.2
        * anything else: 0.01

        Args:
            answer: The model's answer, or ``None``.
            entry: Dataset item; only ``entry["answer"]`` is read.

        Returns:
            A reward between 0.0 and 1.0.
        """
        # TODO: this scoring could definitely be refined
        oracle_answer = entry["answer"].strip()
        if answer is None or len(answer) == 0:
            return 0.0

        answer = answer.strip()
        if answer == oracle_answer:
            return 1.0

        if "{" in answer and "}" in answer:
            return self._score_json_answer(answer, oracle_answer)

        if oracle_answer in answer:
            # max() to avoid penalising too heavily, since correct answers are short here
            return max(len(oracle_answer) / len(answer), 0.2)

        return 0.01

    @staticmethod
    def _score_json_answer(answer: str, oracle_answer: str) -> float:
        """Score an answer that has a ``{...}`` span embedded somewhere in it.

        The span from the first ``{`` to the last ``}`` is parsed as JSON:

        * equal to the parsed expected answer: length-penalised reward, at least 0.2
        * both are objects with the same keys: 0.1
        * any other valid JSON object: 0.05
        * span or expected answer is not valid JSON: ``len(expected) / len(answer)``
          if the expected string appears verbatim in the answer, else 0.0
        """
        ans_first_open, ans_last_close = answer.index("{"), answer.rindex("}")
        # Characters surrounding the JSON span count against the reward.
        extra_chars = len(answer[:ans_first_open]) + len(answer[ans_last_close + 1 :])

        try:
            answer_dict = json.loads(answer[ans_first_open : ans_last_close + 1])
            oracle_value = json.loads(oracle_answer)
        except json.JSONDecodeError:
            if oracle_answer in answer:
                return len(oracle_answer) / len(answer)
            return 0.0

        if answer_dict == oracle_value:
            # 0.5 is arbitrary here, but the answers are very short so it seems harsh to penalize too much
            # e.g. if oracle is {"steps": "3"} and answer is "The correct answer is: {"steps": "3"}"
            return max(len(oracle_answer) / (len(oracle_answer) + 0.5 * extra_chars), 0.2)

        # The expected answer is whatever json.dumps produced, so it may be a
        # list, number or string; only compare keys when both sides are objects.
        if (
            isinstance(answer_dict, dict)
            and isinstance(oracle_value, dict)
            and answer_dict.keys() == oracle_value.keys()
        ):
            # Wrong answer, but at least the right format
            return 0.1

        # At least we got a JSON object, I guess?
        return 0.05
|
|
|
|
|
|
# Register the dataset
|
|
# register_dataset("codeio", CodeIODataset, CodeIOConfig)
|