reasoning-gym/reasoning_gym/graphs/quantum_lock.py
Zafir Stojanovski ce0a6c4878
fix(envs): Add source dataset and index to metadata (#388)
* add source dataset and index to metadata

* fix typo

* fix coach class and its test
2025-03-20 11:12:14 +00:00

256 lines
9.1 KiB
Python

import re
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "quantum_lock"
@dataclass
class QuantumLockConfig:
"""Configuration for QuantumLock task generation"""
difficulty: int = 10
seed: Optional[int] = None
size: int = 500
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.difficulty > 0, "difficulty must be positive"
assert self.size > 0, "size must be positive"
class QuantumLockDataset(ProceduralDataset):
"""Generates QuantumLock tasks"""
def __init__(self, config: QuantumLockConfig):
self._prompt_templates = [
"""\
In front of you are some buttons, a light, and a number. The light will toggle between red and green whenever you press a button. Each button performs a mathematical operation to the number, but the operation may depend on the state of the light.
You must press the shortest correct sequence of buttons to reach the target value. Your answer should be a sequence of buttons separated by '', for example: A → B → C
Start: {initial_value} ({initial_state})
Target: {target_value}
Buttons:
{buttons}"""
]
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single QuantumLock task
Returns:
dict with keys:
- question: str, the task description
- answer: str, a solution string
- metadata: dict with generation parameters
"""
rng = Random(self.seed + idx)
difficulty = rng.randint(1, self.config.difficulty)
puzzle_data = self.generate_quantum_puzzle(rng, difficulty)
return {
"question": self.format_puzzle(rng.choice(self._prompt_templates), puzzle=puzzle_data),
"answer": "".join(puzzle_data["solution"]),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"solution_path": puzzle_data["solution"],
"target_value": puzzle_data["target_value"],
"buttons": puzzle_data["buttons"],
"initial_state": puzzle_data["initial_state"],
"initial_value": puzzle_data["initial_value"],
"difficulty": {"difficulty": difficulty},
},
}
def generate_quantum_puzzle(self, rng: Random, difficulty: int = 1) -> dict[str, Any]:
"""
Generates a Quantum Lock puzzle with configurable difficulty.
Returns a dictionary containing puzzle parameters and solution.
"""
# Define operation parameters based on difficulty
base_values = {
"add": [2, 3] if difficulty >= 5 else [1, 2],
"subtract": [2, 3] if difficulty >= 5 else [1, 2],
"multiply": [2, 3] if difficulty >= 7 else [2],
}
operations = [
{"type": "add", "values": base_values["add"]},
{"type": "subtract", "values": base_values["subtract"]},
{"type": "multiply", "values": base_values["multiply"]},
]
# Generate unique buttons with collision protection
buttons = []
used_combinations = set()
while len(buttons) < 3:
op = rng.choice(operations)
btn_value = rng.choice(op["values"])
# State selection with weighted probabilities
state_weights = {"any": 4, "green": 2, "red": 1}
active_state = rng.choices(list(state_weights.keys()), weights=state_weights.values(), k=1)[0]
# Create unique combination check
combo = (op["type"], btn_value, active_state)
if combo in used_combinations:
continue
# Prevent duplicate button effects
if any(
b["type"] == op["type"] and b["value"] == btn_value and b["active_state"] == active_state
for b in buttons
):
continue
buttons.append(
{"name": chr(65 + len(buttons)), "type": op["type"], "value": btn_value, "active_state": active_state}
)
used_combinations.add(combo)
# Dynamic target scaling with non-linear progression
base_target = 5 + (difficulty**1.5)
variance = rng.randint(-int(base_target * 0.2), int(base_target * 0.3))
target = max(8, int(base_target + variance))
# Create puzzle structure
puzzle = {
"initial_value": 0,
"initial_state": "red",
"target_value": target,
"buttons": buttons,
"max_steps": min(15, 6 + int(difficulty * 1.5)),
"solution": None,
}
# Find shortest solution using BFS
queue = deque([(0, "red", [])])
visited = set()
while queue:
val, state, path = queue.popleft()
if val == puzzle["target_value"]:
puzzle["solution"] = path
return puzzle
if len(path) >= puzzle["max_steps"] or (val, state) in visited:
continue
visited.add((val, state))
for btn in buttons:
next_state = "green" if state == "red" else "red"
# Check if button is usable
if btn["active_state"] not in [state, "any"]:
continue
# Calculate new value
try:
if btn["type"] == "add":
new_val = val + btn["value"]
elif btn["type"] == "subtract":
new_val = val - btn["value"]
elif btn["type"] == "multiply":
new_val = val * btn["value"]
except:
continue # Handle overflows if needed
queue.append((new_val, next_state, path + [btn["name"]]))
# If no solution found, regenerate
return self.generate_quantum_puzzle(rng, difficulty)
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the task.
The function awards 1.0 for a correct answer and less otherwise.
"""
if not isinstance(answer, str):
return 0.0
# Normalize both answers
def normalize_seq(seq: str) -> list[str]:
return [c.upper() for c in re.findall(r"[A-C]", seq.upper())]
user_sequence = normalize_seq(answer)
target_sequence = normalize_seq(entry["answer"])
# Exact sequence match required
if user_sequence == target_sequence:
return 1.0
# Partial credit for reaching target (optional)
final_state = self.simulate_sequence(entry["metadata"], user_sequence)
if final_state == entry["metadata"]["target_value"]:
if len(user_sequence) == len(target_sequence):
return 1.0 # Different answer, but qually correct
return 0.5 # Alternative scoring - you're correct, but not optimal
return 0.0
def simulate_sequence(self, metadata: dict, sequence: list[str]) -> int:
"""Simulate button presses to verify solutions"""
state = metadata["initial_value"]
current_color = metadata["initial_state"]
buttons = {btn["name"]: btn for btn in metadata["buttons"]}
for btn_char in sequence:
btn = buttons.get(btn_char.upper())
if not btn:
continue
# Check button availability
if btn["active_state"] not in [current_color, "any"]:
continue
# Apply operation
if btn["type"] == "add":
state += btn["value"]
elif btn["type"] == "subtract":
state -= btn["value"]
elif btn["type"] == "multiply":
state *= btn["value"]
# Toggle color state
current_color = "green" if current_color == "red" else "red"
return state
def format_puzzle(self, template, puzzle: dict) -> str:
return template.format(
initial_value=puzzle["initial_value"],
initial_state=puzzle["initial_state"],
target_value=puzzle["target_value"],
buttons="\n".join(
f"{btn['name']}: {btn['type'].title()} {btn['value']} (when {btn['active_state']})"
for btn in puzzle["buttons"]
),
)
class QuantumLockCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(QuantumLockCurriculum.__name__, QuantumLockConfig)
self._define_attributes(
ScalarAttributeDefinition(
name="difficulty",
field_name="difficulty",
levels=list(range(1, 11)),
description="The difficulty of the puzzle",
)
)
# Register the dataset
register_dataset(DATASET_NAME, QuantumLockDataset, QuantumLockConfig, QuantumLockCurriculum)