mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
* math prompt improvements * ignore brackets in complex_arithmetic results * improve additional instruction in prompt of polynomial_equations * more strict tests for score_answer in polynomial_equations * simplify special reward handling * fix test_intermediate_integration * fix sokoban dataset * add common dataset score_answer consistency test
514 lines
19 KiB
Python
514 lines
19 KiB
Python
import copy
|
|
import itertools
|
|
from dataclasses import dataclass
|
|
from random import Random
|
|
from typing import Any, Optional
|
|
|
|
import numpy as np
|
|
|
|
from reasoning_gym.factory import ProceduralDataset, register_dataset
|
|
|
|
# Pool of first names assigned to the inhabitants of a puzzle.
# The order is significant: names are drawn by index from a seeded RNG, so
# reordering this list would change every generated puzzle.
COMMON_NAMES = [
    "Emma", "Liam", "Olivia", "Noah", "Ava", "Ethan",
    "Sophia", "Mason", "Isabella", "William", "Mia", "James",
    "Charlotte", "Benjamin", "Amelia", "Lucas", "Harper", "Henry",
    "Evelyn", "Alexander", "Abigail", "Michael", "Emily", "Daniel",
    "Elizabeth", "Jacob", "Sofia", "Logan", "Avery", "Jackson",
    "Ella", "Sebastian", "Scarlett", "Jack", "Grace", "Aiden",
    "Chloe", "Owen", "Victoria", "Samuel", "Riley", "Matthew",
    "Aria", "Joseph", "Lily", "Luke", "Aurora", "David",
    "Zoey", "Oliver", "Penelope",
]
|
|
|
|
# Each entry is [truth-teller term, liar term], article included.
# The order is significant because a pair is drawn by index from a seeded RNG.
KNIGHT_KNAVE_PAIRS = [
    ["a knight", "a knave"],
    ["a pioneer", "a laggard"],
    ["a saint", "a sinner"],
    ["a hero", "a villain"],
    ["an angel", "a devil"],
    ["an altruist", "an egoist"],
    ["a sage", "a fool"],
]

# Bare role nouns (article stripped) recognised when parsing an answer.
VALID_ROLES = {term.split()[1] for pair in KNIGHT_KNAVE_PAIRS for term in pair}
|
|
|
|
# Opening sentences of every puzzle. Placeholders: {knight}/{knave} are the
# bare role nouns, {Knight} is the capitalised truth-teller noun.
PREFIX = (
    "A very special island is inhabited only by {knight}s and {knave}s. "
    "{Knight}s always tell the truth, and {knave}s always lie. "
)

# Closing question. Placeholders carry the article ("a knight", "an angel").
POSTFIX = "So who is {a_knight} and who is {a_knave}?"

# Sentence frames used to report what an inhabitant said.
# The order is significant because a frame is drawn by index from a seeded RNG.
TEMPLATES = [
    # indirect and direct speech
    "{name} said that {content}.",
    "{name} told you that {content}.",
    '{name} said, "{content}."',
    '{name} stated, "{content}".',
    'According to {name}, "{content}".',
    'In {name}\'s words: "{content}".',
    '{name} remarked, "{content}".',
    '"{content}," {name} declared.',
    '{name} was heard saying, "{content}".',
    "{name} expressed that {content}.",
    # attributions and quotations
    '"{content}" - {name}.',
    'As {name} put it, "{content}".',
    '{name} asserted: "{content}".',
    '"{content}," {name} mentioned.',
    '{name} commented, "{content}".',
    'In a statement by {name}: "{content}".',
    '{name} noted, "{content}".',
    '"{content}," {name} claimed.',
]
|
|
|
|
|
|
@dataclass
class KnightsKnavesConfig:
    """
    Settings that control knights and knaves puzzle generation.

    :param n_people: How many inhabitants appear in each puzzle
    :param depth_constraint: Upper bound on the nesting depth of a statement
    :param width_constraint: Upper bound on the number of branches of a statement
    :param size: Virtual size of dataset
    :param seed: Random seed
    """

    n_people: int = 2
    depth_constraint: int = 2
    width_constraint: int = 2
    size: int = 500
    seed: Optional[int] = None

    def validate(self) -> None:
        """Assert that each structural setting is at least 1."""
        assert self.n_people >= 1, "Number of people must be at least 1"
        assert self.depth_constraint >= 1, "Depth constraint must be at least 1"
        assert self.width_constraint >= 1, "Width constraint must be at least 1"
|
|
|
|
|
|
class KKProblemSampler:
    """
    Randomly samples knights and knaves problems as nested statement tuples.

    Statement encoding:
        ("telling-truth", i) / ("lying", i)   -- leaf about person ``i``
        ("not", s)                            -- negation
        ("and", s1, s2, ...) / ("or", ...)    -- n-ary connectives
        ("->", a, b) / ("<=>", a, b)          -- implication / biconditional
    """

    # Number of consecutive duplicate draws tolerated while de-duplicating
    # sub-statements. Once exceeded, a duplicate is accepted so that sampling
    # always terminates even when too few distinct sub-statements exist.
    _MAX_DEDUP_ATTEMPTS = 100

    def __init__(self, rand_seed: int, n_people: int, depth_constraint: int = 2, width_constraint: int = 2):
        """
        Args:
            rand_seed: Seed for the NumPy random generator
            n_people: Number of people; one statement is sampled per person
            depth_constraint: Maximum nesting depth of a sampled statement
            width_constraint: Maximum number of operands of an and/or statement
        """
        self.rng = np.random.default_rng(rand_seed)
        self.n_people = n_people
        self.depth_constraint = depth_constraint
        self.width_constraint = width_constraint

    def sample(self):
        """Sample one statement per person and return them as plain nested tuples."""
        statements = tuple(
            self._sample_statement(person_id, self.depth_constraint) for person_id in range(self.n_people)
        )
        return self._immutable_statements(statements)

    def sample_valid_problems(
        self,
        n_problems: int,
        max_retry: int = 1000,
        skip_no_solution: bool = True,
        skip_multiple_solutions: bool = True,
    ):
        """
        Sample up to ``n_problems`` distinct problems that pass the solution filters.

        Args:
            n_problems: Number of problems requested
            max_retry: Maximum sampling attempts spent on each requested problem
            skip_no_solution: Reject problems that have no consistent assignment
            skip_multiple_solutions: Reject problems with more than one consistent assignment

        Returns:
            List of dicts with keys "statements", "solution" (first solution or
            None) and "all_solutions". The list is shorter than ``n_problems``
            when the retry budget for a slot is exhausted.
        """
        problems = []
        unique_statements = set()
        for _ in range(n_problems):
            for _ in range(max_retry):
                statements = self.sample()
                if statements in unique_statements:
                    continue
                solutions = KnightsKnavesDataset.find_solution(statements)
                if skip_no_solution and len(solutions) == 0:
                    continue
                if skip_multiple_solutions and len(solutions) > 1:
                    continue
                problems.append(
                    {
                        "statements": statements,
                        "solution": solutions[0] if solutions else None,
                        "all_solutions": solutions,
                    }
                )
                unique_statements.add(statements)
                break
        return problems

    def _sample_statement(self, person_id: int, depth_constraint: int):
        """
        Sample a single statement made by ``person_id``.

        The die is always rolled first, even when the depth budget forces a
        leaf, so the random stream stays the same for a given seed.
        """
        dice = self.rng.integers(0, 6)
        if depth_constraint <= 1 or dice == 0:
            while True:
                knight_or_knave = self.rng.choice(["telling-truth", "lying"])
                person = self.rng.integers(0, self.n_people)
                # prevent the contradiction "I am lying"
                if not (knight_or_knave == "lying" and person == person_id):
                    return (knight_or_knave, person)
        if dice == 1:
            return ("not", self._sample_statement(person_id, depth_constraint - 1))
        if dice in (2, 3):
            operator = ("and", "or")[dice - 2]
            # and/or need at least two operands; clamping keeps width_constraint=1
            # (accepted by the config) from making integers(2, 2) raise.
            max_width = max(2, self.width_constraint)
            n_substatements = self.rng.integers(2, max_width + 1)
            return (operator,) + self._sample_substatements(person_id, depth_constraint, n_substatements)
        # dice is 4 or 5: binary implication / biconditional
        operator = ("->", "<=>")[dice - 4]
        return (operator,) + self._sample_substatements(person_id, depth_constraint, 2)

    def _sample_substatements(self, person_id: int, depth_constraint: int, count: int, dedup: bool = True):
        """
        Sample ``count`` operands, each one level shallower than ``depth_constraint``.

        With ``dedup`` enabled, duplicate operands are re-drawn. If
        ``_MAX_DEDUP_ATTEMPTS`` consecutive draws are all duplicates, there are
        probably not enough distinct operands available, so the next draw is
        accepted as-is instead of looping forever.
        """
        sub_statements = []
        seen = set()
        consecutive_duplicates = 0
        while len(sub_statements) < count:
            stmt = self._sample_statement(person_id, depth_constraint - 1)
            if dedup and consecutive_duplicates < self._MAX_DEDUP_ATTEMPTS:
                if stmt in seen:
                    consecutive_duplicates += 1
                    continue
                seen.add(stmt)
            consecutive_duplicates = 0
            sub_statements.append(stmt)
        return tuple(sub_statements)

    def _immutable_statements(self, mutable_statements):
        """Convert nested lists/tuples and NumPy scalars into plain hashable Python tuples."""

        def _make_immutable(x):
            if isinstance(x, (list, tuple)):
                return tuple(_make_immutable(child) for child in x)
            if isinstance(x, np.str_):
                return str(x)
            # np.integer covers int64 as well as any other NumPy integer width
            if isinstance(x, np.integer):
                return int(x)
            return x

        return tuple(_make_immutable(s) for s in mutable_statements)
|
|
|
|
|
|
class KKProblemFormatter:
    """Renders a sampled knights and knaves problem as natural-language text."""

    # Operators that denote a leaf statement about a single person.
    _PRIMITIVES = ("telling-truth", "lying")

    def __init__(self, rand_seed, problem):
        """
        Args:
            rand_seed: Seed for the NumPy generator that picks names, role terms and templates
            problem: Dict with at least "statements" and "solution" keys
        """
        self.rng = np.random.default_rng(rand_seed)
        self.problem = problem

    @staticmethod
    def _join_with_and(items):
        """
        Join phrases as "a, b, and c".

        A single item is returned unchanged (no dangling ", and "), and an
        empty sequence yields an empty string.
        """
        items = list(items)
        if len(items) <= 1:
            return "".join(items)
        return ", ".join(items[:-1]) + ", and " + items[-1]

    def format_problem(self):
        """
        Build the puzzle text and the textual solution.

        Returns:
            dict with keys "quiz", "names", "knight_knave", "solution" and "solution_text"
        """
        statements = copy.deepcopy(self.problem["statements"])
        n_people = len(statements)

        # The order of these random draws determines the output for a given seed.
        names = list(self.rng.choice(COMMON_NAMES, size=n_people, replace=False))
        pair = self.rng.choice(KNIGHT_KNAVE_PAIRS)
        knight_knave = {
            "knight": pair[0].split()[1],
            "knave": pair[1].split()[1],
            "a_knight": pair[0],
            "a_knave": pair[1],
        }
        knight_knave["Knight"] = knight_knave["knight"].capitalize()
        knight_knave["Knave"] = knight_knave["knave"].capitalize()

        text = PREFIX.format(**knight_knave)
        noun = "inhabitant" if n_people == 1 else "inhabitants"
        text += f"You meet {n_people} {noun}: "
        text += self._join_with_and(names) + "."

        text_statements = []
        for name, stmt in zip(names, statements):
            tmpl = self.rng.choice(TEMPLATES)
            content = self._format_statement(names, knight_knave, stmt)
            text_statements.append(" " + tmpl.format(name=name, content=content))

        text += "".join(text_statements)
        text += " " + POSTFIX.format(**knight_knave)

        # Answer-format hint, e.g. "Emma is a knight/knave, and Liam is a knight/knave"
        role_hint = f"{knight_knave['knight']}/{knight_knave['knave']}"
        answer_format = self._join_with_and([f"{name} is a {role_hint}" for name in names])
        text += f' (Format your answer like: "{answer_format}")'

        solution = self.problem["solution"]
        if solution is None:
            solution_text = "No valid solution exists."
        else:
            solution_stmts = [
                f"{name} is {knight_knave['a_knight'] if is_knight else knight_knave['a_knave']}"
                for name, is_knight in zip(names, solution)
            ]
            solution_text = self._join_with_and(solution_stmts) + "."

        return {
            "quiz": text,
            "names": names,
            "knight_knave": knight_knave,
            "solution": solution,
            "solution_text": solution_text,
        }

    def _format_operand(self, names, knight_knave, statement, depth):
        """Format an operand, wrapping it in parentheses unless it is a leaf statement."""
        content = self._format_statement(names, knight_knave, statement, depth)
        if statement[0] not in self._PRIMITIVES:
            content = f"({content})"
        return content

    def _format_statement(self, names, knight_knave, statement, depth=0):
        """
        Recursively format a logical statement with appropriate parentheses.

        Args:
            names: List of people's names
            knight_knave: Dictionary with knight/knave terminology
            statement: Logical statement tuple to format
            depth: Current nesting depth (0 = top level); passed down the recursion only

        Raises:
            ValueError: If the statement's operator is not recognised
        """
        operator = statement[0]

        # Base case: this is a primitive statement
        if operator in self._PRIMITIVES:
            return self._format_knight_knave(names, knight_knave, statement)

        if operator == "not":
            inner = statement[1]
            if inner[0] in self._PRIMITIVES:
                # Negating a leaf: state the complementary role directly
                flipped = "lying" if inner[0] == "telling-truth" else "telling-truth"
                return self._format_knight_knave(names, knight_knave, (flipped, inner[1]))
            # Compound statement: use the verbose form with parentheses
            return "it is not the case that " + self._format_operand(names, knight_knave, inner, depth + 1)

        if operator in ("and", "or"):
            operands = [self._format_operand(names, knight_knave, sub, depth + 1) for sub in statement[1:]]
            return f" {operator} ".join(operands)

        if operator == "->":
            antecedent = self._format_operand(names, knight_knave, statement[1], depth + 1)
            consequent = self._format_operand(names, knight_knave, statement[2], depth + 1)
            return f"if {antecedent} then {consequent}"

        if operator == "<=>":
            left = self._format_operand(names, knight_knave, statement[1], depth + 1)
            right = self._format_operand(names, knight_knave, statement[2], depth + 1)
            return f"{left} if and only if {right}"

        # This should not happen with well-formed statements
        raise ValueError(f"Unknown statement type: {statement[0]}")

    def _format_knight_knave(self, names, knight_knave, statement, negation=False):
        """Format a leaf as "<name> is [not ]<role>" using the chosen role terms."""
        assert statement[0] in ("telling-truth", "lying")
        text = names[statement[1]] + " is "
        if negation:
            text += "not "
        text += {"telling-truth": knight_knave["a_knight"], "lying": knight_knave["a_knave"]}[statement[0]]
        return text
|
|
|
|
|
|
class KnightsKnavesDataset(ProceduralDataset):
    """
    Generates random knights and knaves problems.

    This implementation is adapted from the Knights and Knaves problem generator in:
    https://github.com/AlphaPav/mem-kk-logic

    As described in the paper:
    @article{xie2024memorization,
    title={On Memorization of Large Language Models in Logical Reasoning},
    author={Chulin Xie and Yangsibo Huang and Chiyuan Zhang and Da Yu and Xinyun Chen and Bill Yuchen Lin and Bo Li and Badih Ghazi and Ravi Kumar},
    year={2024},
    eprint={2410.23123},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2410.23123},
    }
    """

    def __init__(self, config: KnightsKnavesConfig):
        super().__init__(config, seed=config.seed, size=config.size)

    @staticmethod
    def find_solution(statements):
        """
        Find all consistent truth assignments for a list of statements.

        Person ``i`` is a truth-teller exactly when statement ``i`` is true, so
        the puzzle is the conjunction of one biconditional per person. Every
        assignment is checked by brute force (2 ** n_people candidates).

        Returns:
            List of tuples of bools, one entry per person (True = truth-teller)
        """
        n_people = len(statements)
        single_statement = ("and",) + tuple(
            ("<=>", ("telling-truth", i), stmt) for i, stmt in enumerate(statements)
        )
        return [
            assignments
            for assignments in itertools.product([True, False], repeat=n_people)
            if KnightsKnavesDataset.test_satisfiability(single_statement, assignments)
        ]

    @staticmethod
    def test_satisfiability(statement, assignments):
        """
        Recursively test if a statement is satisfied under given assignments.

        Raises:
            KeyError: If the statement's operator is not recognised
        """
        evaluate = KnightsKnavesDataset.test_satisfiability
        operator = statement[0]
        if operator == "telling-truth":
            return assignments[statement[1]]
        if operator == "lying":
            return not assignments[statement[1]]
        if operator == "not":
            return not evaluate(statement[1], assignments)
        if operator == "and":
            # builtin all/any short-circuit and return a plain bool
            return all(evaluate(sub, assignments) for sub in statement[1:])
        if operator == "or":
            return any(evaluate(sub, assignments) for sub in statement[1:])
        if operator == "->":
            return (not evaluate(statement[1], assignments)) or bool(evaluate(statement[2], assignments))
        if operator == "<=>":
            return bool(evaluate(statement[1], assignments)) == bool(evaluate(statement[2], assignments))
        raise KeyError(f"Unknown statement: {statement}")

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """
        Generate a single knights and knaves problem item.

        Args:
            idx: Index of the item to generate

        Returns:
            dict containing at least:
            - question: str (the puzzle in natural language)
            - answer: str (the solution in text)
            - metadata: dict (additional problem details)
        """
        rng = Random(self.seed + idx if self.seed is not None else None)
        return self.__generate_problem(rng)

    def __generate_problem(self, rng: Random) -> dict[str, Any]:
        """
        Generate a single knights and knaves problem with a unique solution.

        Raises:
            RuntimeError: If the sampler exhausts its retries without finding
                a uniquely solvable problem
        """
        # The sampler seed is drawn before the formatter seed; keep this order
        # so the output for a given item seed does not change.
        sampler = KKProblemSampler(
            rand_seed=rng.randint(0, 2**32),
            n_people=self.config.n_people,
            depth_constraint=self.config.depth_constraint,
            width_constraint=self.config.width_constraint,
        )
        problems = sampler.sample_valid_problems(1, skip_no_solution=True, skip_multiple_solutions=True)
        if not problems:
            # Raise a descriptive error rather than an IndexError from inside __getitem__.
            raise RuntimeError(
                "Could not sample a knights and knaves problem with a unique solution "
                f"(n_people={self.config.n_people}, depth_constraint={self.config.depth_constraint}, "
                f"width_constraint={self.config.width_constraint})"
            )
        problem = problems[0]

        formatter = KKProblemFormatter(rand_seed=rng.randint(0, 2**32), problem=problem)
        formatted = formatter.format_problem()

        metadata = {
            "statements": problem["statements"],
            "solution": problem["solution"],
            "names": formatted["names"],
            "knight_knave_terms": formatted["knight_knave"],
        }
        return {"question": formatted["quiz"], "answer": formatted["solution_text"], "metadata": metadata}

    @staticmethod
    def _normalize_answer(answer: str) -> set[tuple[str, str]]:
        """
        Convert answer string into normalized set of (name, role) tuples.

        The text is lower-cased, the characters ``. , ( )`` become spaces, and
        " and " separators are dropped. Each recognised role word is paired
        with the most recent preceding token that is not "is", "a" or "an".
        """
        punctuation_to_space = str.maketrans({".": " ", ",": " ", "(": " ", ")": " "})
        cleaned = answer.lower().strip().translate(punctuation_to_space)
        tokens = cleaned.replace(" and ", " ").split()

        assignments = set()
        current_name = None
        for token in tokens:
            if token in ("is", "a", "an"):
                continue
            if token in VALID_ROLES:
                if current_name:
                    assignments.add((current_name, token))
                    current_name = None
            else:
                current_name = token
        return assignments

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """
        Score an answer against the oracle answer.

        Returns:
            1.0 when the (name, role) assignments match exactly regardless of
            order; ``0.3 + 0.7 * matching / total`` when the answer has the same
            number of assignments and at least one matches; otherwise 0.0.
        """
        if not isinstance(answer, str) or not answer:
            return 0.0

        # A missing or malformed oracle answer scores zero instead of raising.
        try:
            oracle_text = entry["answer"]
        except (KeyError, TypeError):
            return 0.0
        if not isinstance(oracle_text, str):
            return 0.0

        oracle_assignments = self._normalize_answer(oracle_text)
        answer_assignments = self._normalize_answer(answer)

        # Full credit for exact assignments regardless of order
        if oracle_assignments == answer_assignments:
            return 1.0

        # Partial credit when the assignment counts agree and some assignments match
        if len(oracle_assignments) == len(answer_assignments):
            matching = len(oracle_assignments & answer_assignments)
            if matching > 0:
                return 0.3 + (0.7 * matching / len(oracle_assignments))

        return 0.0
|
|
|
|
|
|
# Register the dataset class and its config class under the name "knights_knaves".
register_dataset("knights_knaves", KnightsKnavesDataset, KnightsKnavesConfig)
|