mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Add 13 new procedural datasets across 7 categories
New dataset categories: combinatorics, statistics, optimization, and formal languages. Extended existing algebra, arithmetic, probability, logic, and graphs packages with complex_advanced, linear_algebra, limits, number_theory, conditional_probability, set_operations, and job_scheduling. Each dataset includes config validation, deterministic seeding, custom scoring, curriculum support, and comprehensive unit tests (92 new tests).
This commit is contained in:
parent
49b07130b3
commit
6eb252ae32
36 changed files with 3705 additions and 1 deletions
|
|
@ -9,13 +9,17 @@ from . import (
|
|||
arithmetic,
|
||||
code,
|
||||
cognition,
|
||||
combinatorics,
|
||||
data,
|
||||
games,
|
||||
geometry,
|
||||
graphs,
|
||||
induction,
|
||||
languages,
|
||||
logic,
|
||||
optimization,
|
||||
probability,
|
||||
statistics,
|
||||
)
|
||||
from .factory import create_dataset, get_score_answer_fn, register_dataset
|
||||
from .scoring import cascade_score, float_match, math_match, string_match, strip_latex
|
||||
|
|
@ -28,13 +32,17 @@ __all__ = [
|
|||
"arithmetic",
|
||||
"code",
|
||||
"cognition",
|
||||
"combinatorics",
|
||||
"data",
|
||||
"games",
|
||||
"geometry",
|
||||
"graphs",
|
||||
"languages",
|
||||
"logic",
|
||||
"induction",
|
||||
"optimization",
|
||||
"probability",
|
||||
"statistics",
|
||||
"create_dataset",
|
||||
"register_dataset",
|
||||
"get_score_answer_fn",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from .complex_advanced import ComplexAdvancedConfig, ComplexAdvancedCurriculum, ComplexAdvancedDataset
|
||||
from .complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticCurriculum, ComplexArithmeticDataset
|
||||
from .intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset
|
||||
from .limits import LimitsConfig, LimitsCurriculum, LimitsDataset
|
||||
from .linear_algebra import LinearAlgebraConfig, LinearAlgebraCurriculum, LinearAlgebraDataset
|
||||
from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsCurriculum, PolynomialEquationsDataset
|
||||
from .polynomial_multiplication import (
|
||||
PolynomialMultiplicationConfig,
|
||||
|
|
@ -10,6 +13,9 @@ from .simple_equations import SimpleEquationsConfig, SimpleEquationsCurriculum,
|
|||
from .simple_integration import SimpleIntegrationConfig, SimpleIntegrationCurriculum, SimpleIntegrationDataset
|
||||
|
||||
__all__ = [
|
||||
"ComplexAdvancedConfig",
|
||||
"ComplexAdvancedDataset",
|
||||
"ComplexAdvancedCurriculum",
|
||||
"ComplexArithmeticConfig",
|
||||
"ComplexArithmeticDataset",
|
||||
"ComplexArithmeticCurriculum",
|
||||
|
|
@ -27,4 +33,10 @@ __all__ = [
|
|||
"PolynomialMultiplicationConfig",
|
||||
"PolynomialMultiplicationDataset",
|
||||
"PolynomialMultiplicationCurriculum",
|
||||
"LinearAlgebraConfig",
|
||||
"LinearAlgebraDataset",
|
||||
"LinearAlgebraCurriculum",
|
||||
"LimitsConfig",
|
||||
"LimitsDataset",
|
||||
"LimitsCurriculum",
|
||||
]
|
||||
|
|
|
|||
287
reasoning_gym/algebra/complex_advanced.py
Normal file
287
reasoning_gym/algebra/complex_advanced.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
import cmath
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
# Registry key for this dataset and the closed set of supported task kinds.
DATASET_NAME = "complex_advanced"

TASK_TYPES = ("polar", "euler", "inverse", "sqrt", "quadratic")


@dataclass
class ComplexAdvancedConfig:
    """Configuration for the complex_advanced procedural dataset.

    Real/imaginary part magnitudes are drawn from [min_real, max_real] and
    [min_imag, max_imag]; numeric answers are rounded to `decimal_places`.
    `task_weights` must be parallel to `task_types` (sampling weights).
    """

    min_real: int = 1
    max_real: int = 10
    min_imag: int = 1
    max_imag: int = 10
    decimal_places: int = 4
    task_types: tuple[str, ...] = TASK_TYPES
    # Parallel to task_types; uniform by default.
    task_weights: list[float] = field(default_factory=lambda: [0.2, 0.2, 0.2, 0.2, 0.2])
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert that every field holds a usable value (project convention)."""
        assert self.max_real >= self.min_real, "max_real must be >= min_real"
        assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag"
        assert self.min_real >= 1, "min_real must be >= 1"
        assert self.min_imag >= 1, "min_imag must be >= 1"
        assert self.decimal_places >= 1, "decimal_places must be >= 1"
        assert len(self.task_types) > 0, "must have at least one task type"
        assert all(t in TASK_TYPES for t in self.task_types), f"invalid task type, must be in {TASK_TYPES}"
        assert len(self.task_weights) == len(self.task_types), "task_weights must match task_types length"
        assert self.size > 0, "size must be positive"
|
||||
|
||||
|
||||
def _fmt(val: float, dp: int) -> str:
|
||||
return f"{val:.{dp}f}"
|
||||
|
||||
|
||||
def _fmt_complex(z: complex, dp: int) -> str:
|
||||
r, i = round(z.real, dp), round(z.imag, dp)
|
||||
if abs(i) < 10 ** (-(dp + 1)):
|
||||
return _fmt(r, dp)
|
||||
if abs(r) < 10 ** (-(dp + 1)):
|
||||
return f"{_fmt(i, dp)}i"
|
||||
sign = "+" if i >= 0 else "-"
|
||||
return f"{_fmt(r, dp)} {sign} {_fmt(abs(i), dp)}i"
|
||||
|
||||
|
||||
class ComplexAdvancedDataset(ProceduralDataset):
    """Procedural generator for advanced complex-number tasks.

    Task types: polar conversion, Euler/rectangular form, multiplicative
    inverse, square roots, and quadratics with complex roots. Items are
    deterministic: item *idx* is generated from Random(seed + idx).
    """

    def __init__(self, config: ComplexAdvancedConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _rand_complex(self, rng: random.Random, allow_neg: bool = True) -> complex:
        """Sample a complex number with integer parts in the configured ranges."""
        r = rng.randint(self.config.min_real, self.config.max_real)
        i = rng.randint(self.config.min_imag, self.config.max_imag)
        if allow_neg:
            r *= rng.choice([1, -1])
            i *= rng.choice([1, -1])
        return complex(r, i)

    def _make_polar(self, rng: random.Random) -> dict:
        """Ask for modulus/argument of a random rectangular-form number."""
        z = self._rand_complex(rng)
        dp = self.config.decimal_places
        r, theta = cmath.polar(z)
        answer = f"modulus={_fmt(r, dp)}, argument={_fmt(theta, dp)}"
        a, b = int(z.real), int(z.imag)
        sign = "+" if b >= 0 else "-"
        question = (
            f"Convert the complex number {a} {sign} {abs(b)}i to polar form. "
            f"Give the modulus and argument (in radians), both rounded to {dp} decimal places. "
            f"Format: modulus=<value>, argument=<value>"
        )
        return {"question": question, "answer": answer, "task_type": "polar", "z": (a, b)}

    def _make_euler(self, rng: random.Random) -> dict:
        """Ask to convert polar/trigonometric form back to a + bi."""
        z = self._rand_complex(rng)
        dp = self.config.decimal_places
        r, theta = cmath.polar(z)
        rect = cmath.rect(r, theta)
        answer = _fmt_complex(rect, dp)
        question = (
            f"Express {_fmt(r, dp)}(cos({_fmt(theta, dp)}) + i*sin({_fmt(theta, dp)})) "
            f"in rectangular form a + bi, rounded to {dp} decimal places."
        )
        return {"question": question, "answer": answer, "task_type": "euler", "r": r, "theta": theta}

    def _make_inverse(self, rng: random.Random) -> dict:
        """Ask for 1/z of a random Gaussian integer z (z != 0 since parts >= 1)."""
        z = self._rand_complex(rng)
        dp = self.config.decimal_places
        inv = 1.0 / z
        answer = _fmt_complex(inv, dp)
        a, b = int(z.real), int(z.imag)
        sign = "+" if b >= 0 else "-"
        question = (
            f"Find the multiplicative inverse of {a} {sign} {abs(b)}i. "
            f"Express your answer in the form a + bi, rounded to {dp} decimal places."
        )
        return {"question": question, "answer": answer, "task_type": "inverse", "z": (a, b)}

    def _make_sqrt(self, rng: random.Random) -> dict:
        """Pick a root w first and square it, so both roots are exact integers."""
        w = self._rand_complex(rng)
        z = w * w
        dp = self.config.decimal_places
        root1 = _fmt_complex(w, dp)
        root2 = _fmt_complex(-w, dp)
        answer = f"{root1}, {root2}"
        zr, zi = round(z.real, dp), round(z.imag, dp)
        sign = "+" if zi >= 0 else "-"
        question = (
            f"Find the two square roots of {_fmt(zr, dp)} {sign} {_fmt(abs(zi), dp)}i. "
            f"Give both roots rounded to {dp} decimal places, separated by a comma."
        )
        return {"question": question, "answer": answer, "task_type": "sqrt", "w": (int(w.real), int(w.imag))}

    def _make_quadratic(self, rng: random.Random) -> dict:
        """Build a monic quadratic from chosen roots (conjugate pair or two reals)."""
        dp = self.config.decimal_places
        use_complex = rng.choice([True, False])
        if use_complex:
            # Conjugate pair p +/- qi keeps the coefficients real.
            p = rng.randint(self.config.min_real, self.config.max_real) * rng.choice([1, -1])
            q = rng.randint(self.config.min_imag, self.config.max_imag)
            r1 = complex(p, q)
            r2 = complex(p, -q)
        else:
            r1 = complex(rng.randint(-self.config.max_real, self.config.max_real), 0)
            r2 = complex(rng.randint(-self.config.max_real, self.config.max_real), 0)

        # Vieta's formulas: x^2 - (r1+r2)x + r1*r2.
        a_coeff = 1
        b_coeff = -a_coeff * (r1 + r2)
        c_coeff = a_coeff * r1 * r2
        b_int, c_int = round(b_coeff.real), round(c_coeff.real)

        terms = ["x^2"]  # fixed: was f"x^2", an f-string with no placeholder
        if b_int > 0:
            terms.append(f"+ {b_int}x")
        elif b_int < 0:
            terms.append(f"- {abs(b_int)}x")
        if c_int > 0:
            terms.append(f"+ {c_int}")
        elif c_int < 0:
            terms.append(f"- {abs(c_int)}")
        eq_str = " ".join(terms)

        ans1 = _fmt_complex(r1, dp)
        ans2 = _fmt_complex(r2, dp)
        answer = f"{ans1}, {ans2}"

        question = (
            f"Solve the quadratic equation {eq_str} = 0. "
            f"Give both solutions rounded to {dp} decimal places, separated by a comma. "
            f"For complex solutions, use the form a + bi."
        )
        return {
            "question": question,
            "answer": answer,
            "task_type": "quadratic",
            "roots": [(r1.real, r1.imag), (r2.real, r2.imag)],
        }

    def __getitem__(self, idx: int) -> dict:
        """Deterministically generate item *idx* via a per-index seeded RNG."""
        rng = random.Random(self.seed + idx)
        task_type = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

        generators = {
            "polar": self._make_polar,
            "euler": self._make_euler,
            "inverse": self._make_inverse,
            "sqrt": self._make_sqrt,
            "quadratic": self._make_quadratic,
        }
        result = generators[task_type](rng)
        return {
            "question": result["question"],
            "answer": result["answer"],
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": result["task_type"],
                "difficulty": {
                    "min_real": self.config.min_real,
                    "max_real": self.config.max_real,
                    "min_imag": self.config.min_imag,
                    "max_imag": self.config.max_imag,
                },
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score 1.0 on exact match, else per-task partial credit; 0.0 on parse failure."""
        if answer is None:
            return 0.0
        oracle = entry["answer"]
        if answer.strip() == oracle.strip():
            return 1.0

        task_type = entry["metadata"]["task_type"]
        try:
            if task_type == "polar":
                return self._score_polar(answer, oracle)
            elif task_type in ("sqrt", "quadratic"):
                return self._score_pair(answer, oracle)
            else:
                return self._score_single_complex(answer, oracle)
        except Exception:
            # Any malformed answer scores zero rather than raising.
            return 0.0

    def _score_polar(self, answer: str, oracle: str) -> float:
        """Partial credit decays exponentially with modulus + argument error."""

        def parse_polar(s: str) -> tuple[float, float]:
            parts = {}
            for part in s.split(","):
                k, v = part.split("=")
                parts[k.strip()] = float(v.strip())
            return parts["modulus"], parts["argument"]

        am, aa = parse_polar(answer)
        om, oa = parse_polar(oracle)
        mod_err = abs(am - om)
        arg_err = abs(aa - oa)
        return min(1.0, math.exp(-(mod_err + arg_err)))

    def _score_single_complex(self, answer: str, oracle: str) -> float:
        """Partial credit from the distance between parsed complex values."""
        az = self._parse_complex(answer)
        oz = self._parse_complex(oracle)
        if az is None or oz is None:
            return 0.0
        return min(1.0, math.exp(-abs(az - oz)))

    def _score_pair(self, answer: str, oracle: str) -> float:
        """Order-insensitive scoring for two comma-separated complex values."""
        a_parts = [s.strip() for s in answer.split(",")]
        o_parts = [s.strip() for s in oracle.split(",")]
        if len(a_parts) < 2 or len(o_parts) < 2:
            return 0.0
        # fixed: removed a no-op concatenation that appended "" in both branches
        a_vals = [self._parse_complex(a_parts[0]), self._parse_complex(a_parts[1])]
        o_vals = [self._parse_complex(o_parts[0]), self._parse_complex(o_parts[1])]
        if any(v is None for v in a_vals + o_vals):
            return 0.0
        # Best of the two possible pairings of answer roots to oracle roots.
        d1 = min(
            abs(a_vals[0] - o_vals[0]) + abs(a_vals[1] - o_vals[1]),
            abs(a_vals[0] - o_vals[1]) + abs(a_vals[1] - o_vals[0]),
        )
        return min(1.0, math.exp(-d1))

    @staticmethod
    def _parse_complex(s: str) -> Optional[complex]:
        """Parse 'a + bi'-style text into a complex; return None when unparseable.

        Converts mathematical 'i' notation into Python's 'j' and handles bare
        'j'/'-j' and implicit unit coefficients before delegating to complex().
        """
        try:
            s = s.strip().replace(" ", "").replace("i", "j")
            if "j" not in s:
                return complex(float(s), 0)
            if s == "j":
                return 1j
            if s == "-j":
                return -1j
            if s.endswith("j") and "+" not in s[1:] and "-" not in s[1:]:
                coef = s[:-1] or "1"
                if coef == "-":
                    coef = "-1"
                return complex(0, float(coef))
            if "+j" in s:
                s = s.replace("+j", "+1j")
            if "-j" in s:
                s = s.replace("-j", "-1j")
            return complex(s)
        except (ValueError, TypeError):
            return None
|
||||
|
||||
|
||||
class ComplexAdvancedCurriculum(BaseCurriculum):
    """Difficulty schedule for complex_advanced: grows part magnitudes."""

    def __init__(self):
        super().__init__(ComplexAdvancedCurriculum.__name__, ComplexAdvancedConfig)
        real_levels = ScalarAttributeDefinition(
            name="max_real",
            field_name="max_real",
            levels=[5, 10, 50, 100],
            description="Maximum real part magnitude",
        )
        imag_levels = ScalarAttributeDefinition(
            name="max_imag",
            field_name="max_imag",
            levels=[5, 10, 50, 100],
            description="Maximum imaginary part magnitude",
        )
        self._define_attributes(real_levels, imag_levels)


register_dataset(DATASET_NAME, ComplexAdvancedDataset, ComplexAdvancedConfig, ComplexAdvancedCurriculum)
|
||||
206
reasoning_gym/algebra/limits.py
Normal file
206
reasoning_gym/algebra/limits.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from fractions import Fraction
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
# Registry key for this dataset and the closed set of supported task kinds.
DATASET_NAME = "limits"

TASK_TYPES = ("polynomial_cancel", "rational_infinity", "direct_sub", "squeeze")


@dataclass
class LimitsConfig:
    """Configuration for the limits procedural dataset.

    `max_coeff` bounds coefficient magnitudes and `max_degree` bounds the
    polynomial degrees used in generated expressions. `task_weights` must
    be parallel to `task_types` (sampling weights).
    """

    max_coeff: int = 10
    max_degree: int = 3
    task_types: tuple[str, ...] = TASK_TYPES
    # Parallel to task_types.
    task_weights: list[float] = field(default_factory=lambda: [0.3, 0.3, 0.2, 0.2])
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert that every field holds a usable value (project convention)."""
        assert self.size > 0, "size must be positive"
        assert self.max_coeff >= 1, "max_coeff must be >= 1"
        assert self.max_degree >= 1, "max_degree must be >= 1"
        assert len(self.task_types) > 0, "must have at least one task type"
        # fixed: message was an f-string with no placeholder; now names the
        # valid set, consistent with the other dataset configs
        assert all(t in TASK_TYPES for t in self.task_types), f"invalid task type, must be in {TASK_TYPES}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
|
||||
|
||||
|
||||
def _poly_str(coeffs: list[int], var: str = "x") -> str:
|
||||
"""coeffs[i] is coefficient of x^i"""
|
||||
parts = []
|
||||
for i in range(len(coeffs) - 1, -1, -1):
|
||||
c = coeffs[i]
|
||||
if c == 0:
|
||||
continue
|
||||
if i == 0:
|
||||
parts.append(str(c))
|
||||
elif i == 1:
|
||||
if c == 1:
|
||||
parts.append(var)
|
||||
elif c == -1:
|
||||
parts.append(f"-{var}")
|
||||
else:
|
||||
parts.append(f"{c}*{var}")
|
||||
else:
|
||||
if c == 1:
|
||||
parts.append(f"{var}^{i}")
|
||||
elif c == -1:
|
||||
parts.append(f"-{var}^{i}")
|
||||
else:
|
||||
parts.append(f"{c}*{var}^{i}")
|
||||
if not parts:
|
||||
return "0"
|
||||
result = parts[0]
|
||||
for p in parts[1:]:
|
||||
if p.startswith("-"):
|
||||
result += " - " + p[1:]
|
||||
else:
|
||||
result += " + " + p
|
||||
return result
|
||||
|
||||
|
||||
class LimitsDataset(ProceduralDataset):
    """Procedural limit problems.

    Task types: factor-and-cancel rational limits, limits at infinity of
    rational functions, direct substitution into polynomials, and
    squeeze-theorem setups. Item *idx* is generated deterministically from
    Random(seed + idx); restyling must not reorder rng calls.
    """

    def __init__(self, config: LimitsConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _make_polynomial_cancel(self, rng: random.Random) -> dict:
        """Limit of (bx + c)(x - a)/(x - a) as x -> a; answer is b*a + c."""
        a = rng.randint(-self.config.max_coeff, self.config.max_coeff)
        if a == 0:
            a = 1
        b = rng.randint(1, self.config.max_coeff)
        c = rng.randint(1, self.config.max_coeff) * rng.choice([1, -1])

        # After cancelling (x - a), the limit is just the linear factor at x = a.
        num_val = b * a + c
        answer_frac = Fraction(num_val, 1)
        answer = str(answer_frac)

        num_expr = f"{b}*x" if b != 1 else "x"
        if c > 0:
            num_expr += f" + {c}"
        elif c < 0:
            num_expr += f" - {abs(c)}"

        denom_expr = f"x - {a}" if a >= 0 else f"x + {abs(a)}"
        full_num = f"({num_expr}) * ({denom_expr})"

        question = (
            f"Find the limit as x approaches {a} of {full_num} / ({denom_expr}). "
            f"Give your answer as an integer or simplified fraction."
        )
        return {"question": question, "answer": answer, "task_type": "polynomial_cancel"}

    def _make_rational_infinity(self, rng: random.Random) -> dict:
        """Limit at infinity of equal-degree rationals: ratio of leading coeffs."""
        deg = rng.randint(1, self.config.max_degree)
        num_lead = rng.randint(1, self.config.max_coeff) * rng.choice([1, -1])
        den_lead = rng.randint(1, self.config.max_coeff) * rng.choice([1, -1])

        # Lower-order coefficients are arbitrary; only the leads matter.
        num_coeffs = [rng.randint(-self.config.max_coeff, self.config.max_coeff) for _ in range(deg)]
        num_coeffs.append(num_lead)
        den_coeffs = [rng.randint(-self.config.max_coeff, self.config.max_coeff) for _ in range(deg)]
        den_coeffs.append(den_lead)

        answer_frac = Fraction(num_lead, den_lead)
        answer = str(answer_frac)

        num_str = _poly_str(num_coeffs)
        den_str = _poly_str(den_coeffs)
        question = (
            f"Find the limit as x approaches infinity of ({num_str}) / ({den_str}). "
            f"Give your answer as an integer or simplified fraction."
        )
        return {"question": question, "answer": answer, "task_type": "rational_infinity"}

    def _make_direct_sub(self, rng: random.Random) -> dict:
        """Limit of a polynomial by direct substitution (continuity)."""
        a = rng.randint(1, 5)
        deg = rng.randint(1, self.config.max_degree)
        coeffs = [rng.randint(-self.config.max_coeff, self.config.max_coeff) for _ in range(deg + 1)]
        if coeffs[-1] == 0:
            # Keep the intended degree: force a nonzero leading coefficient.
            coeffs[-1] = 1

        val = sum(coeffs[i] * (a ** i) for i in range(len(coeffs)))
        answer = str(val)
        poly = _poly_str(coeffs)
        question = (
            f"Find the limit as x approaches {a} of ({poly}). "
            f"Give your answer as a single integer."
        )
        return {"question": question, "answer": answer, "task_type": "direct_sub"}

    def _make_squeeze(self, rng: random.Random) -> dict:
        """Squeeze theorem: both bounds tend to L as x -> a, so f does too."""
        L = rng.randint(-5, 5)
        a = rng.randint(-3, 3)
        k = rng.randint(1, 3)

        question = (
            f"Suppose that for all x near {a}, we have:\n"
            f" {L} - (x - {a})^{2 * k} ≤ f(x) ≤ {L} + (x - {a})^{2 * k}\n"
            f"Find the limit of f(x) as x approaches {a}. Give your answer as a single integer."
        )
        return {"question": question, "answer": str(L), "task_type": "squeeze"}

    def __getitem__(self, idx: int) -> dict:
        """Deterministically generate item *idx* via a per-index seeded RNG."""
        rng = random.Random(self.seed + idx)
        task_type = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

        generators = {
            "polynomial_cancel": self._make_polynomial_cancel,
            "rational_infinity": self._make_rational_infinity,
            "direct_sub": self._make_direct_sub,
            "squeeze": self._make_squeeze,
        }
        result = generators[task_type](rng)
        return {
            "question": result["question"],
            "answer": result["answer"],
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": result["task_type"],
                "difficulty": {
                    "max_coeff": self.config.max_coeff,
                    "max_degree": self.config.max_degree,
                },
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Exact string match, then Fraction equality, then float equality; else 0."""
        if answer is None:
            return 0.0
        oracle = entry["answer"]
        if answer.strip() == oracle.strip():
            return 1.0
        try:
            # Fraction accepts both "3" and "3/4" strings.
            a_frac = Fraction(answer.strip())
            o_frac = Fraction(oracle.strip())
            return 1.0 if a_frac == o_frac else 0.0
        except (ValueError, ZeroDivisionError):
            try:
                return 1.0 if float(answer.strip()) == float(oracle.strip()) else 0.0
            except ValueError:
                return 0.0
|
||||
|
||||
|
||||
class LimitsCurriculum(BaseCurriculum):
    """Difficulty schedule for limits: grows coefficient bounds and degree."""

    def __init__(self):
        super().__init__(LimitsCurriculum.__name__, LimitsConfig)
        coeff_levels = ScalarAttributeDefinition(
            name="max_coeff",
            field_name="max_coeff",
            levels=[5, 10, 20, 50],
            description="Maximum coefficient magnitude",
        )
        degree_levels = ScalarAttributeDefinition(
            name="max_degree",
            field_name="max_degree",
            levels=[1, 2, 3, 4],
            description="Maximum polynomial degree",
        )
        self._define_attributes(coeff_levels, degree_levels)


register_dataset(DATASET_NAME, LimitsDataset, LimitsConfig, LimitsCurriculum)
|
||||
267
reasoning_gym/algebra/linear_algebra.py
Normal file
267
reasoning_gym/algebra/linear_algebra.py
Normal file
|
|
@ -0,0 +1,267 @@
|
|||
import json
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
# Registry key for this dataset and the closed set of supported task kinds.
DATASET_NAME = "linear_algebra"

TASK_TYPES = ("matrix_multiply", "determinant", "inverse", "solve_system", "eigenvalues")


@dataclass
class LinearAlgebraConfig:
    """Configuration for the linear_algebra procedural dataset.

    Matrix dimensions are drawn from [min_dim, max_dim] (capped at 4) and
    entries from [min_value, max_value]. `task_weights` must be parallel to
    `task_types` (sampling weights).
    """

    min_dim: int = 2
    max_dim: int = 3
    min_value: int = -5
    max_value: int = 5
    task_types: tuple[str, ...] = TASK_TYPES
    # Parallel to task_types.
    task_weights: list[float] = field(default_factory=lambda: [0.25, 0.2, 0.2, 0.2, 0.15])
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert that every field holds a usable value (project convention)."""
        assert self.size > 0, "size must be positive"
        assert self.min_dim >= 2, "min_dim must be >= 2"
        assert self.max_dim >= self.min_dim, "max_dim must be >= min_dim"
        assert self.max_dim <= 4, "max_dim must be <= 4"
        assert len(self.task_types) > 0, "must have at least one task type"
        # fixed: message was an f-string with no placeholder; now names the
        # valid set, consistent with the other dataset configs
        assert all(t in TASK_TYPES for t in self.task_types), f"invalid task type, must be in {TASK_TYPES}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
|
||||
|
||||
|
||||
def _mat_str(m: list[list[int]]) -> str:
|
||||
rows = ["[" + ", ".join(str(x) for x in row) + "]" for row in m]
|
||||
return "[" + ", ".join(rows) + "]"
|
||||
|
||||
|
||||
def _mat_mult(a: list[list[int]], b: list[list[int]]) -> list[list[int]]:
|
||||
n, m, p = len(a), len(a[0]), len(b[0])
|
||||
result = [[0] * p for _ in range(n)]
|
||||
for i in range(n):
|
||||
for j in range(p):
|
||||
for k in range(m):
|
||||
result[i][j] += a[i][k] * b[k][j]
|
||||
return result
|
||||
|
||||
|
||||
def _det(m: list[list[int]]) -> int:
|
||||
n = len(m)
|
||||
if n == 1:
|
||||
return m[0][0]
|
||||
if n == 2:
|
||||
return m[0][0] * m[1][1] - m[0][1] * m[1][0]
|
||||
result = 0
|
||||
for j in range(n):
|
||||
sub = [[m[i][k] for k in range(n) if k != j] for i in range(1, n)]
|
||||
result += ((-1) ** j) * m[0][j] * _det(sub)
|
||||
return result
|
||||
|
||||
|
||||
def _adjugate_2x2(m: list[list[int]]) -> list[list[int]]:
|
||||
return [[m[1][1], -m[0][1]], [-m[1][0], m[0][0]]]
|
||||
|
||||
|
||||
class LinearAlgebraDataset(ProceduralDataset):
    """Procedural linear-algebra tasks on small integer matrices.

    Task types: matrix product, determinant, integer 2x2 inverse, 2x2
    linear systems, and eigenvalues of similarity-transformed diagonal
    matrices. Item *idx* comes from Random(seed + idx); restyling must
    not reorder rng calls.
    """

    def __init__(self, config: LinearAlgebraConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _gen_matrix(self, rng: random.Random, rows: int, cols: int) -> list[list[int]]:
        """Random rows x cols matrix with entries in [min_value, max_value]."""
        return [
            [rng.randint(self.config.min_value, self.config.max_value) for _ in range(cols)]
            for _ in range(rows)
        ]

    def _make_matrix_multiply(self, rng: random.Random) -> dict:
        """(n x m) * (m x p) product with dims drawn from the configured range."""
        n = rng.randint(self.config.min_dim, self.config.max_dim)
        m = rng.randint(self.config.min_dim, self.config.max_dim)
        p = rng.randint(self.config.min_dim, self.config.max_dim)
        a = self._gen_matrix(rng, n, m)
        b = self._gen_matrix(rng, m, p)
        result = _mat_mult(a, b)
        question = (
            f"Multiply the matrices A = {_mat_str(a)} and B = {_mat_str(b)}. "
            f"Give the result as a nested list, e.g. [[1, 2], [3, 4]]."
        )
        return {"question": question, "answer": _mat_str(result), "task_type": "matrix_multiply"}

    def _make_determinant(self, rng: random.Random) -> dict:
        """Determinant of a square matrix (capped at 3x3 for hand solvability)."""
        n = rng.randint(self.config.min_dim, min(self.config.max_dim, 3))
        m = self._gen_matrix(rng, n, n)
        result = _det(m)
        question = (
            f"Find the determinant of the matrix {_mat_str(m)}. "
            f"Give your answer as a single integer."
        )
        return {"question": question, "answer": str(result), "task_type": "determinant"}

    def _make_inverse(self, rng: random.Random) -> dict:
        """2x2 inverse, retrying until adj(m)/det(m) is an integer matrix."""
        for _ in range(100):
            m = self._gen_matrix(rng, 2, 2)
            d = _det(m)
            # Require every adjugate entry divisible by det so the inverse is integral.
            if d != 0 and all(x % d == 0 for row in _adjugate_2x2(m) for x in row):
                adj = _adjugate_2x2(m)
                inv = [[x // d for x in row] for row in adj]
                question = (
                    f"Find the inverse of the 2x2 matrix {_mat_str(m)}. "
                    f"Give the result as a nested list of integers, e.g. [[1, 2], [3, 4]]."
                )
                return {"question": question, "answer": _mat_str(inv), "task_type": "inverse"}
        # Fallback after 100 failed draws: identity (its own inverse).
        m = [[1, 0], [0, 1]]
        question = f"Find the inverse of the 2x2 matrix {_mat_str(m)}. Give the result as a nested list."
        return {"question": question, "answer": _mat_str(m), "task_type": "inverse"}

    def _make_solve_system(self, rng: random.Random) -> dict:
        """2x2 linear system built backward from a chosen integer solution."""
        n = 2
        x_sol = [rng.randint(-5, 5) for _ in range(n)]
        for _ in range(100):
            a = self._gen_matrix(rng, n, n)
            d = _det(a)
            if d != 0:
                break
        else:
            # No invertible draw in 100 tries: fall back to the identity.
            a = [[1, 0], [0, 1]]

        # b is computed from the solution, so the system is consistent by construction.
        b = [sum(a[i][j] * x_sol[j] for j in range(n)) for i in range(n)]
        vars_ = ["x", "y"]
        eqs = []
        for i in range(n):
            parts = []
            for j in range(n):
                coef = a[i][j]
                if coef == 0:
                    continue
                if coef == 1:
                    parts.append(f"{vars_[j]}")
                elif coef == -1:
                    parts.append(f"-{vars_[j]}")
                else:
                    parts.append(f"{coef}{vars_[j]}")
            # Normalize "+ -3y" into "- 3y" for readable equations.
            eq = " + ".join(parts).replace("+ -", "- ")
            eqs.append(f" {eq} = {b[i]}")
        eq_str = "\n".join(eqs)
        answer = ", ".join(f"{vars_[i]}={x_sol[i]}" for i in range(n))
        question = (
            f"Solve the following system of linear equations:\n{eq_str}\n"
            f"Give your answer in the format: x=<value>, y=<value>"
        )
        return {
            "question": question,
            "answer": answer,
            "task_type": "solve_system",
            "matrix": a,
            "b": b,
        }

    def _make_eigenvalues(self, rng: random.Random) -> dict:
        """Diagonal matrix with known eigenvalues, optionally similarity-transformed.

        Conjugating by a unimodular matrix (|det| = 1, so its inverse is
        integral) preserves the eigenvalues e1, e2 while hiding them.
        """
        e1 = rng.randint(-5, 5)
        e2 = rng.randint(-5, 5)
        m = [[e1, 0], [0, e2]]
        p_det = 1  # NOTE(review): unused — candidate for removal
        for _ in range(3):
            shear = self._gen_matrix(rng, 2, 2)
            d = _det(shear)
            if abs(d) == 1:
                # inv(shear) = adj(shear) * d when det is +/-1.
                inv_d = d
                adj = _adjugate_2x2(shear)
                shear_inv = [[x * inv_d for x in row] for row in adj]
                temp = _mat_mult(shear, m)
                m = _mat_mult(temp, shear_inv)
                break

        eigenvals = sorted([e1, e2])
        answer = ", ".join(str(e) for e in eigenvals)
        question = (
            f"Find the eigenvalues of the 2x2 matrix {_mat_str(m)}. "
            f"List them separated by commas in ascending order."
        )
        return {"question": question, "answer": answer, "task_type": "eigenvalues"}

    def __getitem__(self, idx: int) -> dict:
        """Deterministically generate item *idx* via a per-index seeded RNG."""
        rng = random.Random(self.seed + idx)
        task_type = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

        generators = {
            "matrix_multiply": self._make_matrix_multiply,
            "determinant": self._make_determinant,
            "inverse": self._make_inverse,
            "solve_system": self._make_solve_system,
            "eigenvalues": self._make_eigenvalues,
        }
        result = generators[task_type](rng)
        return {
            "question": result["question"],
            "answer": result["answer"],
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": result["task_type"],
                "difficulty": {
                    "min_dim": self.config.min_dim,
                    "max_dim": self.config.max_dim,
                },
                # solve_system items carry the system so scoring can verify
                # any valid solution, not just the generated one.
                **({"matrix": result["matrix"], "b": result["b"]} if "matrix" in result else {}),
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Exact match scores 1.0; otherwise parse per task type; 0.0 on failure."""
        if answer is None:
            return 0.0
        oracle = entry["answer"]
        if answer.strip() == oracle.strip():
            return 1.0
        task_type = entry["metadata"]["task_type"]

        if task_type == "solve_system":
            try:
                parts = {}
                for part in answer.strip().split(","):
                    k, v = part.split("=")
                    parts[k.strip()] = int(v.strip())
                x, y = parts["x"], parts["y"]
                mat = entry["metadata"]["matrix"]
                b = entry["metadata"]["b"]
                # Accept any (x, y) that satisfies both equations.
                if mat[0][0] * x + mat[0][1] * y == b[0] and mat[1][0] * x + mat[1][1] * y == b[1]:
                    return 1.0
                return 0.0
            except (ValueError, KeyError, TypeError):
                return 0.0

        if task_type in ("determinant",):
            try:
                return 1.0 if int(answer.strip()) == int(oracle.strip()) else 0.0
            except ValueError:
                return 0.0

        if task_type == "eigenvalues":
            try:
                # Order-insensitive integer comparison.
                a_vals = sorted(int(x.strip()) for x in answer.split(","))
                o_vals = sorted(int(x.strip()) for x in oracle.split(","))
                return 1.0 if a_vals == o_vals else 0.0
            except (ValueError, TypeError):
                return 0.0

        # matrix_multiply / inverse: compare as parsed nested lists.
        try:
            a_mat = json.loads(answer.strip())
            o_mat = json.loads(oracle.strip())
            return 1.0 if a_mat == o_mat else 0.0
        except (json.JSONDecodeError, ValueError):
            return 0.0
|
||||
|
||||
|
||||
class LinearAlgebraCurriculum(BaseCurriculum):
    """Difficulty schedule for linear_algebra: grows matrix dimension."""

    def __init__(self):
        super().__init__(LinearAlgebraCurriculum.__name__, LinearAlgebraConfig)
        dim_levels = ScalarAttributeDefinition(
            name="max_dim",
            field_name="max_dim",
            levels=[2, 3, 4],
            description="Maximum matrix dimension",
        )
        self._define_attributes(dim_levels)


register_dataset(DATASET_NAME, LinearAlgebraDataset, LinearAlgebraConfig, LinearAlgebraCurriculum)
|
||||
|
|
@ -20,6 +20,7 @@ from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetCon
|
|||
from .lcm import LCMConfig, LCMCurriculum, LCMDataset
|
||||
from .leg_counting import LegCountingConfig, LegCountingCurriculum, LegCountingDataset
|
||||
from .number_format import NumberFormatConfig, NumberFormatCurriculum, NumberFormatDataset
|
||||
from .number_theory import NumberTheoryConfig, NumberTheoryCurriculum, NumberTheoryDataset
|
||||
from .power_function import PowerFunctionConfig, PowerFunctionCurriculum, PowerFunctionDataset
|
||||
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationCurriculum, PrimeFactorizationDataset
|
||||
from .products import ProductsConfig, ProductsDataset
|
||||
|
|
@ -77,4 +78,7 @@ __all__ = [
|
|||
"BitwiseArithmeticConfig",
|
||||
"BitwiseArithmeticDataset",
|
||||
"BitwiseArithmeticCurriculum",
|
||||
"NumberTheoryConfig",
|
||||
"NumberTheoryDataset",
|
||||
"NumberTheoryCurriculum",
|
||||
]
|
||||
|
|
|
|||
203
reasoning_gym/arithmetic/number_theory.py
Normal file
203
reasoning_gym/arithmetic/number_theory.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
import math
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "number_theory"
|
||||
|
||||
TASK_TYPES = ("mod_arith", "mod_exp", "totient", "crt", "mod_inverse", "diophantine")
|
||||
|
||||
|
||||
def euler_totient(n: int) -> int:
    """Return Euler's totient φ(n): how many of 1..n are coprime to n.

    Uses the product formula phi(n) = n * prod(1 - 1/p) over the distinct
    prime factors of n, discovered by trial division up to sqrt(n).
    """
    phi = n
    remaining = n
    factor = 2
    while factor * factor <= remaining:
        if remaining % factor == 0:
            # factor is a new prime divisor: apply the (1 - 1/p) term once...
            phi -= phi // factor
            # ...then strip every power of it from the remaining cofactor.
            while remaining % factor == 0:
                remaining //= factor
        factor += 1
    # Whatever survives trial division is itself a prime factor (> sqrt).
    if remaining > 1:
        phi -= phi // remaining
    return phi
|
||||
|
||||
|
||||
def extended_gcd(a: int, b: int) -> tuple[int, int, int]:
    """Return (g, x, y) such that g = gcd(a, b) and a*x + b*y == g.

    Iterative extended Euclid: maintains two rows of Bezout coefficients
    alongside the remainder sequence instead of recursing.
    """
    if a == 0:
        return b, 0, 1
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
        old_y, y = y, old_y - q * y
    return old_r, old_x, old_y
|
||||
|
||||
|
||||
@dataclass
class NumberTheoryConfig:
    """Configuration for the number-theory dataset.

    Controls the operand range (min_value/max_value), the exponent bound
    for modular exponentiation (max_exp), the enabled task types and
    their sampling weights, plus RNG seed and dataset size.
    """

    min_value: int = 2
    max_value: int = 50
    max_exp: int = 20
    task_types: tuple[str, ...] = TASK_TYPES
    task_weights: list[float] = field(default_factory=lambda: [0.2, 0.2, 0.15, 0.15, 0.15, 0.15])
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert that the configuration is internally consistent.

        Raises AssertionError with a descriptive message on the first
        violated constraint.
        """
        assert self.size > 0, "size must be positive"
        assert self.min_value >= 2, "min_value must be >= 2"
        assert self.max_value >= self.min_value, "max_value must be >= min_value"
        assert self.max_exp >= 2, "max_exp must be >= 2"
        assert len(self.task_types) > 0, "must have at least one task type"
        # Name the offending entries instead of the former placeholder-less
        # f-string (lint F541), so a bad config is diagnosable.
        unknown = set(self.task_types) - set(TASK_TYPES)
        assert not unknown, f"invalid task types: {sorted(unknown)}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
|
||||
|
||||
|
||||
class NumberTheoryDataset(ProceduralDataset):
    """Procedural number-theory tasks.

    Generates six task types: plain modular arithmetic, modular
    exponentiation, Euler's totient, a two-congruence CRT instance,
    modular inverses, and linear Diophantine equations.  Each item is
    derived from a per-index RNG so the dataset is deterministic for a
    given seed.
    """

    def __init__(self, config: NumberTheoryConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _make_mod_arith(self, rng: random.Random) -> dict:
        """Generate a plain "a mod m" question."""
        # The dividend range is widened (x5) so a > m is common.
        a = rng.randint(self.config.min_value, self.config.max_value * 5)
        m = rng.randint(self.config.min_value, self.config.max_value)
        answer = a % m
        question = f"What is {a} mod {m}? Give your answer as a single integer."
        return {"question": question, "answer": str(answer), "task_type": "mod_arith"}

    def _make_mod_exp(self, rng: random.Random) -> dict:
        """Generate a modular exponentiation question (base^exp mod m)."""
        base = rng.randint(2, self.config.max_value)
        exp = rng.randint(2, self.config.max_exp)
        mod = rng.randint(self.config.min_value, self.config.max_value)
        # Three-argument pow does efficient modular exponentiation.
        answer = pow(base, exp, mod)
        question = f"What is {base}^{exp} mod {mod}? Give your answer as a single integer."
        return {"question": question, "answer": str(answer), "task_type": "mod_exp"}

    def _make_totient(self, rng: random.Random) -> dict:
        """Generate a Euler-totient question using the module-level helper."""
        n = rng.randint(self.config.min_value, self.config.max_value)
        answer = euler_totient(n)
        question = (
            f"Compute Euler's totient function φ({n}), i.e., the count of integers "
            f"from 1 to {n} that are coprime to {n}. Give your answer as a single integer."
        )
        return {"question": question, "answer": str(answer), "task_type": "totient"}

    def _make_crt(self, rng: random.Random) -> dict:
        """Generate a two-congruence Chinese Remainder Theorem instance.

        The moduli are forced coprime, so a solution is guaranteed to
        exist in [0, m1*m2); it is found by brute-force scan.
        """
        m1 = rng.randint(2, 10)
        m2 = rng.randint(2, 10)
        # Resample m2 until the moduli are coprime (CRT precondition).
        while math.gcd(m1, m2) != 1:
            m2 = rng.randint(2, 10)
        r1 = rng.randint(0, m1 - 1)
        r2 = rng.randint(0, m2 - 1)

        # CRT guarantees exactly one solution in this range, so the
        # break always fires and `answer` is always bound.
        for x in range(m1 * m2):
            if x % m1 == r1 and x % m2 == r2:
                answer = x
                break

        question = (
            f"Find the smallest non-negative integer x such that:\n"
            f" x ≡ {r1} (mod {m1})\n"
            f" x ≡ {r2} (mod {m2})\n"
            f"Give your answer as a single integer."
        )
        return {"question": question, "answer": str(answer), "task_type": "crt"}

    def _make_mod_inverse(self, rng: random.Random) -> dict:
        """Generate a modular-inverse question; a is resampled until invertible."""
        m = rng.randint(3, self.config.max_value)
        a = rng.randint(2, m - 1)
        # An inverse exists iff gcd(a, m) == 1.
        while math.gcd(a, m) != 1:
            a = rng.randint(2, m - 1)
        # pow(a, -1, m) computes the modular inverse (Python 3.8+).
        answer = pow(a, -1, m)
        question = (
            f"Find the modular inverse of {a} modulo {m}, i.e., find x such that "
            f"{a} * x ≡ 1 (mod {m}). Give x as a single integer (0 ≤ x < {m})."
        )
        return {"question": question, "answer": str(answer), "task_type": "mod_inverse"}

    def _make_diophantine(self, rng: random.Random) -> dict:
        """Generate a solvable linear Diophantine equation ax + by = c.

        c is constructed as a multiple of gcd(a, b), so a solution always
        exists; one particular solution is obtained by scaling the Bezout
        coefficients from extended_gcd.  (a, b, c) are returned so the
        scorer can accept ANY valid solution, not just this one.
        """
        a = rng.randint(2, self.config.max_value)
        b = rng.randint(2, self.config.max_value)
        g = math.gcd(a, b)
        c = g * rng.randint(1, 5)

        _, x0, y0 = extended_gcd(a, b)
        x0 *= c // g
        y0 *= c // g
        answer = f"x={x0}, y={y0}"
        question = (
            f"Find one integer solution (x, y) to the equation {a}x + {b}y = {c}. "
            f"Format your answer as: x=<value>, y=<value>"
        )
        return {"question": question, "answer": answer, "task_type": "diophantine", "a": a, "b": b, "c": c}

    def __getitem__(self, idx: int) -> dict:
        """Build item idx deterministically from a per-index RNG."""
        rng = random.Random(self.seed + idx)
        task_type = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

        generators = {
            "mod_arith": self._make_mod_arith,
            "mod_exp": self._make_mod_exp,
            "totient": self._make_totient,
            "crt": self._make_crt,
            "mod_inverse": self._make_mod_inverse,
            "diophantine": self._make_diophantine,
        }
        result = generators[task_type](rng)
        return {
            "question": result["question"],
            "answer": result["answer"],
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": result["task_type"],
                "difficulty": {
                    "min_value": self.config.min_value,
                    "max_value": self.config.max_value,
                },
                # Diophantine items carry (a, b, c) so scoring can verify
                # any solution, not only the generated one.
                **({"a": result["a"], "b": result["b"], "c": result["c"]} if "a" in result else {}),
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score an answer: 1.0 for correct, 0.0 otherwise.

        Diophantine answers are checked by substituting (x, y) into the
        equation, so any valid solution scores 1.0.  All other task types
        fall back to integer comparison against the oracle.
        """
        if answer is None:
            return 0.0
        oracle = entry["answer"]
        if answer.strip() == oracle.strip():
            return 1.0
        task_type = entry["metadata"]["task_type"]
        if task_type == "diophantine":
            try:
                parts = {}
                for part in answer.strip().split(","):
                    k, v = part.split("=")
                    parts[k.strip()] = int(v.strip())
                x, y = parts["x"], parts["y"]
                a = entry["metadata"]["a"]
                b = entry["metadata"]["b"]
                c = entry["metadata"]["c"]
                # Accept any (x, y) that satisfies the equation.
                if a * x + b * y == c:
                    return 1.0
                return 0.0
            except (ValueError, KeyError, TypeError):
                return 0.0
        try:
            # Tolerates formatting differences like "07" vs "7".
            return 1.0 if int(answer.strip()) == int(oracle.strip()) else 0.0
        except ValueError:
            return 0.0
|
||||
|
||||
|
||||
class NumberTheoryCurriculum(BaseCurriculum):
    """Curriculum for the number-theory dataset.

    Difficulty is scaled by widening the range of numbers appearing in
    the generated problems.
    """

    def __init__(self):
        super().__init__(NumberTheoryCurriculum.__name__, NumberTheoryConfig)
        self._define_attributes(
            RangeAttributeDefinition(
                name="value_range",
                levels=[10, 50, 100, 500],
                lower_field_name="min_value",
                upper_field_name="max_value",
                description="Range for numbers in problems",
            ),
        )


# Expose the dataset through the global factory registry.
register_dataset(DATASET_NAME, NumberTheoryDataset, NumberTheoryConfig, NumberTheoryCurriculum)
|
||||
7
reasoning_gym/combinatorics/__init__.py
Normal file
7
reasoning_gym/combinatorics/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""
|
||||
Combinatorics reasoning tasks.
|
||||
"""
|
||||
|
||||
from .combinatorics import CombinatoricsConfig, CombinatoricsCurriculum, CombinatoricsDataset
|
||||
|
||||
__all__ = ["CombinatoricsDataset", "CombinatoricsConfig", "CombinatoricsCurriculum"]
|
||||
164
reasoning_gym/combinatorics/combinatorics.py
Normal file
164
reasoning_gym/combinatorics/combinatorics.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
import math
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "combinatorics"
|
||||
|
||||
TASK_TYPES = ("ncr", "npr", "permutations_repetition", "inclusion_exclusion", "stars_and_bars", "pigeonhole")
|
||||
|
||||
|
||||
@dataclass
class CombinatoricsConfig:
    """Configuration for the combinatorics dataset.

    Controls the n range used by most task types, the enabled task types
    and their sampling weights, plus RNG seed and dataset size.
    """

    min_n: int = 5
    max_n: int = 15
    task_types: tuple[str, ...] = TASK_TYPES
    task_weights: list[float] = field(
        default_factory=lambda: [0.2, 0.15, 0.2, 0.2, 0.15, 0.1]
    )
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert that the configuration is internally consistent.

        Raises AssertionError with a descriptive message on the first
        violated constraint.
        """
        assert self.size > 0, "size must be positive"
        assert self.min_n >= 2, "min_n must be >= 2"
        assert self.max_n >= self.min_n, "max_n must be >= min_n"
        assert len(self.task_types) > 0, "must have at least one task type"
        # Name the offending entries instead of the former placeholder-less
        # f-string (lint F541), so a bad config is diagnosable.
        unknown = set(self.task_types) - set(TASK_TYPES)
        assert not unknown, f"invalid task types: {sorted(unknown)}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
|
||||
|
||||
|
||||
class CombinatoricsDataset(ProceduralDataset):
    """Procedural combinatorics tasks.

    Generates six task types: nCr, nPr, multiset permutations,
    two-set inclusion-exclusion, stars-and-bars, and pigeonhole.
    Items are derived from a per-index RNG for determinism.
    """

    def __init__(self, config: CombinatoricsConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _make_ncr(self, rng: random.Random) -> dict:
        """Generate a combinations question: C(n, k)."""
        n = rng.randint(self.config.min_n, self.config.max_n)
        k = rng.randint(1, n - 1)
        answer = math.comb(n, k)
        question = (
            f"How many ways can you choose {k} items from a set of {n} items? "
            f"Give your answer as a single integer."
        )
        return {"question": question, "answer": str(answer), "task_type": "ncr"}

    def _make_npr(self, rng: random.Random) -> dict:
        """Generate a permutations question: P(n, k)."""
        n = rng.randint(self.config.min_n, self.config.max_n)
        # k is capped at 6 to keep answers at a readable magnitude.
        k = rng.randint(1, min(n - 1, 6))
        answer = math.perm(n, k)
        question = (
            f"How many ways can you arrange {k} items chosen from {n} distinct items "
            f"(order matters)? Give your answer as a single integer."
        )
        return {"question": question, "answer": str(answer), "task_type": "npr"}

    def _make_permutations_repetition(self, rng: random.Random) -> dict:
        """Generate a multiset-permutation question.

        Builds a word with repeated letters; the answer is the
        multinomial len(word)! / prod(count_i!).
        """
        letters = []
        num_distinct = rng.randint(2, 4)
        pool = "ABCDEFGH"
        chosen = rng.sample(pool, num_distinct)
        counts = {}
        for ch in chosen:
            c = rng.randint(1, 4)
            counts[ch] = c
            letters.extend([ch] * c)
        rng.shuffle(letters)
        word = "".join(letters)

        numerator = math.factorial(len(word))
        denominator = 1
        for c in counts.values():
            denominator *= math.factorial(c)
        # Integer division is exact: the multinomial is always an integer.
        answer = numerator // denominator

        count_desc = ", ".join(f"'{k}' appears {v} time(s)" for k, v in sorted(counts.items()))
        question = (
            f"How many distinct arrangements can be made from the letters of '{word}'? "
            f"({count_desc}) Give your answer as a single integer."
        )
        return {"question": question, "answer": str(answer), "task_type": "permutations_repetition"}

    def _make_inclusion_exclusion(self, rng: random.Random) -> dict:
        """Generate a two-set inclusion-exclusion word problem.

        `both` is sampled within [max(0, |A|+|B|-total), min(|A|, |B|)],
        which guarantees the "neither" count is non-negative.
        """
        total = rng.randint(50, 200)
        a_count = rng.randint(total // 4, total * 3 // 4)
        b_count = rng.randint(total // 4, total * 3 // 4)
        max_both = min(a_count, b_count, total)
        min_both = max(0, a_count + b_count - total)
        both = rng.randint(min_both, max_both)
        # neither = total - |A ∪ B| by inclusion-exclusion.
        neither = total - (a_count + b_count - both)

        activity_a = rng.choice(["play soccer", "like tea", "study math", "read fiction"])
        activity_b = rng.choice(["play chess", "like coffee", "study science", "read poetry"])

        question = (
            f"In a group of {total} people, {a_count} {activity_a}, {b_count} {activity_b}, "
            f"and {both} do both. How many people do neither? "
            f"Give your answer as a single integer."
        )
        return {"question": question, "answer": str(neither), "task_type": "inclusion_exclusion"}

    def _make_stars_and_bars(self, rng: random.Random) -> dict:
        """Generate a stars-and-bars question: C(n + k - 1, k - 1)."""
        n = rng.randint(self.config.min_n, self.config.max_n)
        k = rng.randint(2, 5)
        answer = math.comb(n + k - 1, k - 1)
        question = (
            f"How many ways can you distribute {n} identical balls into {k} distinct boxes "
            f"(each box can hold any number of balls)? Give your answer as a single integer."
        )
        return {"question": question, "answer": str(answer), "task_type": "stars_and_bars"}

    def _make_pigeonhole(self, rng: random.Random) -> dict:
        """Generate a pigeonhole question: ceil(items / boxes).

        items is built as boxes*extra + r with 1 <= r < boxes, so the
        answer is always extra + 1 (never an exact multiple).
        """
        boxes = rng.randint(3, 20)
        extra = rng.randint(1, 10)
        items = boxes * extra + rng.randint(1, boxes - 1)
        answer = (items + boxes - 1) // boxes  # ceiling division

        question = (
            f"If {items} items are placed into {boxes} boxes, what is the minimum number of items "
            f"that must be in at least one box? Give your answer as a single integer."
        )
        return {"question": question, "answer": str(answer), "task_type": "pigeonhole"}

    def __getitem__(self, idx: int) -> dict:
        """Build item idx deterministically from a per-index RNG."""
        rng = random.Random(self.seed + idx)
        task_type = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

        generators = {
            "ncr": self._make_ncr,
            "npr": self._make_npr,
            "permutations_repetition": self._make_permutations_repetition,
            "inclusion_exclusion": self._make_inclusion_exclusion,
            "stars_and_bars": self._make_stars_and_bars,
            "pigeonhole": self._make_pigeonhole,
        }
        result = generators[task_type](rng)
        return {
            "question": result["question"],
            "answer": result["answer"],
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": result["task_type"],
                "difficulty": {"min_n": self.config.min_n, "max_n": self.config.max_n},
            },
        }
|
||||
|
||||
|
||||
class CombinatoricsCurriculum(BaseCurriculum):
    """Curriculum for the combinatorics dataset.

    Difficulty is scaled by widening the range of n used in the
    generated problems.
    """

    def __init__(self):
        super().__init__(CombinatoricsCurriculum.__name__, CombinatoricsConfig)
        self._define_attributes(
            RangeAttributeDefinition(
                name="n_range",
                levels=[5, 10, 20, 30],
                lower_field_name="min_n",
                upper_field_name="max_n",
                description="Range for n in combinatorial problems",
            ),
        )


# Expose the dataset through the global factory registry.
register_dataset(DATASET_NAME, CombinatoricsDataset, CombinatoricsConfig, CombinatoricsCurriculum)
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset
|
||||
from .job_scheduling import JobSchedulingConfig, JobSchedulingCurriculum, JobSchedulingDataset
|
||||
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsCurriculum, FamilyRelationshipsDataset
|
||||
from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset
|
||||
from .path_star import PathStarConfig, PathStarCurriculum, PathStarDataset
|
||||
|
|
@ -24,4 +25,7 @@ __all__ = [
|
|||
"ShortestPathConfig",
|
||||
"ShortestPathDataset",
|
||||
"ShortestPathCurriculum",
|
||||
"JobSchedulingConfig",
|
||||
"JobSchedulingDataset",
|
||||
"JobSchedulingCurriculum",
|
||||
]
|
||||
|
|
|
|||
224
reasoning_gym/graphs/job_scheduling.py
Normal file
224
reasoning_gym/graphs/job_scheduling.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
import random
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "job_scheduling"
|
||||
|
||||
TASK_TYPES = ("critical_path", "interval_scheduling", "task_ordering")
|
||||
|
||||
|
||||
@dataclass
class JobSchedulingConfig:
    """Configuration for the job-scheduling dataset.

    Controls the job-count and duration ranges, the enabled task types
    and their sampling weights, plus RNG seed and dataset size.
    """

    min_jobs: int = 4
    max_jobs: int = 7
    min_duration: int = 1
    max_duration: int = 10
    task_types: tuple[str, ...] = TASK_TYPES
    task_weights: list[float] = field(default_factory=lambda: [0.34, 0.33, 0.33])
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert that the configuration is internally consistent.

        Raises AssertionError with a descriptive message on the first
        violated constraint.
        """
        assert self.size > 0, "size must be positive"
        assert self.min_jobs >= 3, "min_jobs must be >= 3"
        assert self.max_jobs >= self.min_jobs, "max_jobs must be >= min_jobs"
        assert self.min_duration >= 1, "min_duration must be >= 1"
        assert self.max_duration >= self.min_duration, "max_duration must be >= min_duration"
        assert len(self.task_types) > 0, "must have at least one task type"
        # Name the offending entries instead of the former placeholder-less
        # f-string (lint F541), so a bad config is diagnosable.
        unknown = set(self.task_types) - set(TASK_TYPES)
        assert not unknown, f"invalid task types: {sorted(unknown)}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
|
||||
|
||||
|
||||
def _critical_path(jobs: dict, deps: dict) -> int:
    """Return the minimum makespan of *jobs* under dependency map *deps*.

    jobs maps name -> duration; deps maps name -> list of prerequisite
    names.  Jobs without prerequisites start at time 0 and parallelism is
    unlimited, so the answer is the length of the longest dependency
    chain (the critical path), computed by Kahn-style topological
    relaxation of earliest start times.  Assumes the graph is acyclic.
    """
    earliest = {}  # earliest possible start time per job
    in_deg = defaultdict(int)
    for j in jobs:
        in_deg[j] = 0
    for j, prereqs in deps.items():
        in_deg[j] = len(prereqs)

    # Seed the queue with jobs that can start immediately.
    queue = deque()
    for j in jobs:
        if in_deg[j] == 0:
            earliest[j] = 0
            queue.append(j)

    # Forward adjacency: prerequisite -> dependents.
    adj = defaultdict(list)
    for j, prereqs in deps.items():
        for p in prereqs:
            adj[p].append(j)

    while queue:
        j = queue.popleft()
        for nxt in adj[j]:
            start = earliest[j] + jobs[j]
            # A job starts only after its latest-finishing prerequisite.
            earliest[nxt] = max(earliest.get(nxt, 0), start)
            in_deg[nxt] -= 1
            if in_deg[nxt] == 0:
                queue.append(nxt)

    # Makespan = latest finish time over all jobs.
    return max(earliest[j] + jobs[j] for j in jobs)
|
||||
|
||||
|
||||
def _topo_sort(jobs: list, deps: dict) -> list:
    """Return a deterministic topological order of *jobs*.

    Kahn's algorithm with the ready set processed in sorted (alphabetic)
    order, so the result is stable for a given dependency graph.
    deps maps job -> list of prerequisite jobs.
    """
    in_deg = {j: 0 for j in jobs}
    adj = defaultdict(list)
    for j, prereqs in deps.items():
        in_deg[j] = len(prereqs)
        for p in prereqs:
            adj[p].append(j)

    # Start from all jobs with no prerequisites, alphabetically.
    queue = deque(sorted(j for j in jobs if in_deg[j] == 0))
    result = []
    while queue:
        j = queue.popleft()
        result.append(j)
        # Sorted iteration keeps the output deterministic.
        for nxt in sorted(adj[j]):
            in_deg[nxt] -= 1
            if in_deg[nxt] == 0:
                queue.append(nxt)
    return result
|
||||
|
||||
|
||||
class JobSchedulingDataset(ProceduralDataset):
    """Procedural scheduling tasks over small job DAGs.

    Three task types: critical-path makespan, greedy interval
    scheduling, and topological task ordering.  Items are derived from a
    per-index RNG for determinism.
    """

    def __init__(self, config: JobSchedulingConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _make_critical_path(self, rng: random.Random) -> dict:
        """Generate a minimum-makespan question over a random job DAG.

        Dependencies only point to earlier-named jobs, so the graph is
        acyclic by construction.
        """
        n = rng.randint(self.config.min_jobs, self.config.max_jobs)
        names = [chr(65 + i) for i in range(n)]  # "A", "B", ...
        jobs = {name: rng.randint(self.config.min_duration, self.config.max_duration) for name in names}
        deps = {}
        for i in range(1, n):
            # Each job depends on at most 2 earlier jobs.
            num_deps = rng.randint(0, min(2, i))
            if num_deps > 0:
                deps[names[i]] = rng.sample(names[:i], num_deps)

        cp = _critical_path(jobs, deps)
        job_desc = ", ".join(f"{name}(duration={d})" for name, d in jobs.items())
        dep_desc = "; ".join(f"{j} depends on {', '.join(p)}" for j, p in deps.items()) or "no dependencies"
        question = (
            f"Given jobs: {job_desc}. Dependencies: {dep_desc}. "
            f"All jobs without dependencies can start immediately and run in parallel. "
            f"What is the minimum total time to complete all jobs? "
            f"Give your answer as a single integer."
        )
        return {"question": question, "answer": str(cp), "task_type": "critical_path"}

    def _make_interval_scheduling(self, rng: random.Random) -> dict:
        """Generate a maximum non-overlapping intervals question.

        The optimum is computed with the classic earliest-finish greedy;
        intervals that merely touch (s == previous end) count as
        non-overlapping.
        """
        n = rng.randint(self.config.min_jobs, self.config.max_jobs + 3)
        intervals = []
        for _ in range(n):
            start = rng.randint(0, 20)
            end = start + rng.randint(1, 8)
            intervals.append((start, end))
        # Earliest-finish-first order is what makes the greedy optimal.
        intervals.sort(key=lambda x: x[1])

        selected = []
        last_end = -1
        for s, e in intervals:
            if s >= last_end:
                selected.append((s, e))
                last_end = e

        # Shuffle so the question does not leak the greedy ordering.
        rng.shuffle(intervals)
        intervals_str = ", ".join(f"({s}, {e})" for s, e in intervals)
        question = (
            f"Given the following intervals (start, end): [{intervals_str}]. "
            f"What is the maximum number of non-overlapping intervals you can select? "
            f"Give your answer as a single integer."
        )
        return {"question": question, "answer": str(len(selected)), "task_type": "interval_scheduling"}

    def _make_task_ordering(self, rng: random.Random) -> dict:
        """Generate a topological-ordering question over a random task DAG.

        deps and names are returned so the scorer can accept ANY valid
        ordering, not just the canonical one produced by _topo_sort.
        """
        n = rng.randint(self.config.min_jobs, self.config.max_jobs)
        names = [chr(65 + i) for i in range(n)]
        deps = {}
        for i in range(1, n):
            num_deps = rng.randint(0, min(2, i))
            if num_deps > 0:
                deps[names[i]] = rng.sample(names[:i], num_deps)

        order = _topo_sort(names, deps)
        dep_desc = "; ".join(f"{j} depends on {', '.join(p)}" for j, p in deps.items()) or "no dependencies"
        answer = ", ".join(order)
        question = (
            f"Given tasks: {', '.join(names)}. Dependencies: {dep_desc}. "
            f"Give a valid execution order that respects all dependencies. "
            f"List the tasks separated by commas."
        )
        return {"question": question, "answer": answer, "task_type": "task_ordering", "deps": deps, "names": names}

    def __getitem__(self, idx: int) -> dict:
        """Build item idx deterministically from a per-index RNG."""
        rng = random.Random(self.seed + idx)
        task_type = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

        generators = {
            "critical_path": self._make_critical_path,
            "interval_scheduling": self._make_interval_scheduling,
            "task_ordering": self._make_task_ordering,
        }
        result = generators[task_type](rng)
        return {
            "question": result["question"],
            "answer": result["answer"],
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": result["task_type"],
                "difficulty": {
                    "min_jobs": self.config.min_jobs,
                    "max_jobs": self.config.max_jobs,
                },
                # Ordering items carry the DAG so scoring can validate
                # any topological order.
                **({"deps": result["deps"], "names": result["names"]} if "deps" in result else {}),
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score an answer: 1.0 for correct, 0.0 otherwise.

        Task orderings are validated structurally (every task present,
        every prerequisite strictly before its dependent), so any valid
        topological order scores 1.0.  Other task types fall back to
        integer comparison against the oracle.
        """
        if answer is None:
            return 0.0
        oracle = entry["answer"]
        if answer.strip() == oracle.strip():
            return 1.0
        task_type = entry["metadata"]["task_type"]

        if task_type == "task_ordering":
            try:
                order = [x.strip() for x in answer.split(",")]
                deps = entry["metadata"]["deps"]
                names = entry["metadata"]["names"]
                # Must be a permutation of all task names.
                if set(order) != set(names):
                    return 0.0
                pos = {name: i for i, name in enumerate(order)}
                for j, prereqs in deps.items():
                    for p in prereqs:
                        # Prerequisite must appear strictly before its dependent.
                        if pos.get(p, float("inf")) >= pos.get(j, -1):
                            return 0.0
                return 1.0
            except (ValueError, TypeError):
                return 0.0

        try:
            return 1.0 if int(answer.strip()) == int(oracle.strip()) else 0.0
        except ValueError:
            return 0.0
|
||||
|
||||
|
||||
class JobSchedulingCurriculum(BaseCurriculum):
    """Curriculum for the job-scheduling dataset.

    Difficulty is scaled through the maximum number of jobs per problem.
    """

    def __init__(self):
        super().__init__(JobSchedulingCurriculum.__name__, JobSchedulingConfig)
        self._define_attributes(
            ScalarAttributeDefinition(
                name="max_jobs",
                field_name="max_jobs",
                levels=[4, 7, 10, 15],
                description="Maximum number of jobs",
            ),
        )


# Expose the dataset through the global factory registry.
register_dataset(DATASET_NAME, JobSchedulingDataset, JobSchedulingConfig, JobSchedulingCurriculum)
|
||||
7
reasoning_gym/languages/__init__.py
Normal file
7
reasoning_gym/languages/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""
|
||||
Formal languages and regex reasoning tasks.
|
||||
"""
|
||||
|
||||
from .regex_puzzles import RegexPuzzlesConfig, RegexPuzzlesCurriculum, RegexPuzzlesDataset
|
||||
|
||||
__all__ = ["RegexPuzzlesDataset", "RegexPuzzlesConfig", "RegexPuzzlesCurriculum"]
|
||||
257
reasoning_gym/languages/regex_puzzles.py
Normal file
257
reasoning_gym/languages/regex_puzzles.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
import random
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "regex_puzzles"
|
||||
|
||||
TASK_TYPES = ("string_generation", "extraction", "dfa_state", "dfa_prefix")
|
||||
|
||||
# (pattern, human-readable description) pairs used by the string-generation
# task.  Patterns are restricted to the constructs that the generator
# _gen_matching_string understands: character classes with ranges, {n} /
# {lo,hi} quantifiers, '+', and the escapes \d, \. and \$.
REGEX_PATTERNS = [
    (r"[a-c]{2}[0-9]{3}", "two lowercase letters (a-c) followed by three digits"),
    (r"[A-Z]{3}[0-9]{2}", "three uppercase letters followed by two digits"),
    (r"[0-9]{2}-[0-9]{2}-[0-9]{4}", "a date in DD-MM-YYYY format (digits only)"),
    (r"[a-z]+@[a-z]+\.[a-z]{2,3}", "a simple email like name@domain.com"),
    (r"[01]{4}", "a 4-digit binary string"),
    (r"[A-Z][a-z]{2,5}", "a capitalized word (3-6 letters)"),
    (r"[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}", "an IP-address-like pattern"),
    (r"#[0-9a-f]{6}", "a hex color code like #a1b2c3"),
]
|
||||
|
||||
|
||||
@dataclass
class RegexPuzzlesConfig:
    """Configuration for the regex/DFA puzzle dataset.

    Controls the DFA state-count range, the enabled task types and their
    sampling weights, plus RNG seed and dataset size.
    """

    min_dfa_states: int = 3
    max_dfa_states: int = 5
    task_types: tuple[str, ...] = TASK_TYPES
    task_weights: list[float] = field(default_factory=lambda: [0.3, 0.25, 0.25, 0.2])
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert that the configuration is internally consistent.

        Raises AssertionError with a descriptive message on the first
        violated constraint.
        """
        assert self.size > 0, "size must be positive"
        assert self.min_dfa_states >= 2, "min_dfa_states must be >= 2"
        assert self.max_dfa_states >= self.min_dfa_states, "max_dfa_states must be >= min_dfa_states"
        assert len(self.task_types) > 0, "must have at least one task type"
        # Name the offending entries instead of the former placeholder-less
        # f-string (lint F541), so a bad config is diagnosable.
        unknown = set(self.task_types) - set(TASK_TYPES)
        assert not unknown, f"invalid task types: {sorted(unknown)}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
|
||||
|
||||
|
||||
def _gen_matching_string(pattern: str, rng: random.Random) -> str:
    r"""Generate a random string matching a simple regex *pattern*.

    Supports only the constructs used in REGEX_PATTERNS: character
    classes with ranges (e.g. [a-c0-9]), {n} and {lo,hi} quantifiers,
    '+' (expanded to 1-5 repetitions), the escapes \d, \. and \$, and
    literal characters.  Any other character is emitted verbatim.

    Note: the previously present `import string` was unused and has been
    removed; behavior is unchanged.
    """
    result = []
    i = 0
    while i < len(pattern):
        if pattern[i] == "[":
            # Parse the character-class body up to the closing bracket.
            end = pattern.index("]", i)
            char_class = pattern[i + 1 : end]
            i = end + 1
            reps = 1
            if i < len(pattern) and pattern[i] == "{":
                end_brace = pattern.index("}", i)
                quant = pattern[i + 1 : end_brace]
                if "," in quant:
                    lo, hi = quant.split(",")
                    reps = rng.randint(int(lo), int(hi))
                else:
                    reps = int(quant)
                i = end_brace + 1
            elif i < len(pattern) and pattern[i] == "+":
                # '+' means one-or-more; sample a small repetition count.
                reps = rng.randint(1, 5)
                i += 1

            # Expand ranges like 'a-z' into their member characters.
            chars = []
            j = 0
            while j < len(char_class):
                if j + 2 < len(char_class) and char_class[j + 1] == "-":
                    chars.extend(chr(c) for c in range(ord(char_class[j]), ord(char_class[j + 2]) + 1))
                    j += 3
                else:
                    chars.append(char_class[j])
                    j += 1
            for _ in range(reps):
                result.append(rng.choice(chars))
        elif pattern[i] == "\\":
            # Escapes: only \d, \. and \$ appear in the supported patterns.
            i += 1
            if pattern[i] == "d":
                result.append(str(rng.randint(0, 9)))
            elif pattern[i] == ".":
                result.append(".")
            elif pattern[i] == "$":
                result.append("$")
            i += 1
        else:
            # Literal character.
            result.append(pattern[i])
            i += 1
    return "".join(result)
|
||||
|
||||
|
||||
class RegexPuzzlesDataset(ProceduralDataset):
|
||||
def __init__(self, config: RegexPuzzlesConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _make_string_generation(self, rng: random.Random) -> dict:
|
||||
pattern, desc = rng.choice(REGEX_PATTERNS)
|
||||
answer = _gen_matching_string(pattern, rng)
|
||||
question = (
|
||||
f"Generate a string that matches the regex pattern `{pattern}` "
|
||||
f"(i.e., {desc}). Give only the string, nothing else."
|
||||
)
|
||||
return {"question": question, "answer": answer, "task_type": "string_generation", "pattern": pattern}
|
||||
|
||||
def _make_extraction(self, rng: random.Random) -> dict:
|
||||
pattern_str = r"\$\d+\.\d{2}"
|
||||
n = rng.randint(2, 4)
|
||||
prices = [f"${rng.randint(1, 999)}.{rng.randint(10, 99):02d}" for _ in range(n)]
|
||||
words = ["The price is", "costs", "for", "and", "total", "you pay", "item at"]
|
||||
text_parts = []
|
||||
for p in prices:
|
||||
text_parts.append(rng.choice(words))
|
||||
text_parts.append(p)
|
||||
text_parts.append(rng.choice(["today", "now", "in total"]))
|
||||
text = " ".join(text_parts)
|
||||
matches = re.findall(pattern_str, text)
|
||||
answer = ", ".join(matches)
|
||||
question = (
|
||||
f"Extract all dollar amounts (matching the pattern $X.XX) from the following text:\n"
|
||||
f"'{text}'\n"
|
||||
f"List them separated by commas in the order they appear."
|
||||
)
|
||||
return {"question": question, "answer": answer, "task_type": "extraction"}
|
||||
|
||||
def _make_dfa(self, rng: random.Random) -> tuple[dict, list, str, list, str]:
|
||||
n = rng.randint(self.config.min_dfa_states, self.config.max_dfa_states)
|
||||
states = [f"q{i}" for i in range(n)]
|
||||
alphabet = ["a", "b"]
|
||||
transitions = {}
|
||||
for s in states:
|
||||
for c in alphabet:
|
||||
transitions[(s, c)] = rng.choice(states)
|
||||
accept = rng.sample(states, rng.randint(1, max(1, n // 2)))
|
||||
return transitions, states, states[0], accept, alphabet
|
||||
|
||||
def _run_dfa(self, transitions: dict, start: str, input_str: str) -> str:
|
||||
state = start
|
||||
for c in input_str:
|
||||
state = transitions.get((state, c), state)
|
||||
return state
|
||||
|
||||
def _make_dfa_state(self, rng: random.Random) -> dict:
|
||||
transitions, states, start, accept, alphabet = self._make_dfa(rng)
|
||||
input_len = rng.randint(3, 6)
|
||||
input_str = "".join(rng.choice(alphabet) for _ in range(input_len))
|
||||
final_state = self._run_dfa(transitions, start, input_str)
|
||||
|
||||
trans_str = ", ".join(f"δ({s},{c})={transitions[(s,c)]}" for s in states for c in alphabet)
|
||||
question = (
|
||||
f"A DFA has states {{{', '.join(states)}}}, alphabet {{a, b}}, start state {start}.\n"
|
||||
f"Transitions: {trans_str}\n"
|
||||
f"After processing the input '{input_str}', what state is the DFA in? "
|
||||
f"Give only the state name."
|
||||
)
|
||||
return {"question": question, "answer": final_state, "task_type": "dfa_state"}
|
||||
|
||||
def _make_dfa_prefix(self, rng: random.Random) -> dict:
|
||||
transitions, states, start, accept, alphabet = self._make_dfa(rng)
|
||||
input_len = rng.randint(4, 8)
|
||||
input_str = "".join(rng.choice(alphabet) for _ in range(input_len))
|
||||
|
||||
longest_prefix = ""
|
||||
state = start
|
||||
for i, c in enumerate(input_str):
|
||||
state = transitions.get((state, c), state)
|
||||
if state in accept:
|
||||
longest_prefix = input_str[: i + 1]
|
||||
|
||||
if not longest_prefix:
|
||||
if start in accept:
|
||||
longest_prefix = ""
|
||||
else:
|
||||
longest_prefix = "NONE"
|
||||
|
||||
trans_str = ", ".join(f"δ({s},{c})={transitions[(s,c)]}" for s in states for c in alphabet)
|
||||
accept_str = ", ".join(accept)
|
||||
question = (
|
||||
f"A DFA has states {{{', '.join(states)}}}, alphabet {{a, b}}, "
|
||||
f"start state {start}, accept states {{{accept_str}}}.\n"
|
||||
f"Transitions: {trans_str}\n"
|
||||
f"What is the longest prefix of '{input_str}' that is accepted by this DFA? "
|
||||
f"If no prefix is accepted, answer 'NONE'."
|
||||
)
|
||||
return {"question": question, "answer": longest_prefix, "task_type": "dfa_prefix"}
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
    """Deterministically generate the idx-th item from (seed + idx)."""
    rng = random.Random(self.seed + idx)
    # Weighted draw over the configured task types.
    chosen = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

    dispatch = {
        "string_generation": self._make_string_generation,
        "extraction": self._make_extraction,
        "dfa_state": self._make_dfa_state,
        "dfa_prefix": self._make_dfa_prefix,
    }
    item = dispatch[chosen](rng)

    metadata = {
        "source_dataset": DATASET_NAME,
        "source_index": idx,
        "task_type": item["task_type"],
        "difficulty": {
            "min_dfa_states": self.config.min_dfa_states,
            "max_dfa_states": self.config.max_dfa_states,
        },
    }
    # Only string-generation tasks carry a regex pattern (used for re-scoring).
    if "pattern" in item:
        metadata["pattern"] = item["pattern"]

    return {"question": item["question"], "answer": item["answer"], "metadata": metadata}
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score an answer: 1.0 for correct, 0.0 otherwise.

    An exact (whitespace-stripped) match always scores 1.0.  Two task types
    get lenient re-checks: string_generation accepts any string fully
    matching the stored regex, and extraction ignores spacing around commas.
    """
    if answer is None:
        return 0.0
    expected = entry["answer"]
    if answer.strip() == expected.strip():
        return 1.0

    kind = entry["metadata"]["task_type"]

    if kind == "string_generation":
        # Any string matching the generating pattern is a valid answer.
        try:
            return 1.0 if re.fullmatch(entry["metadata"]["pattern"], answer.strip()) else 0.0
        except re.error:
            return 0.0

    if kind == "extraction":
        # Compare comma-separated parts with surrounding whitespace ignored.
        try:
            given = [piece.strip() for piece in answer.split(",")]
            wanted = [piece.strip() for piece in expected.split(",")]
            return 1.0 if given == wanted else 0.0
        except (ValueError, TypeError):
            return 0.0

    return 0.0
||||
|
||||
class RegexPuzzlesCurriculum(BaseCurriculum):
    """Difficulty curriculum for regex puzzles: scales the DFA size over four levels."""

    def __init__(self):
        super().__init__(RegexPuzzlesCurriculum.__name__, RegexPuzzlesConfig)
        dfa_size = ScalarAttributeDefinition(
            name="max_dfa_states",
            field_name="max_dfa_states",
            levels=[3, 5, 7, 10],
            description="Maximum DFA states",
        )
        self._define_attributes(dfa_size)


register_dataset(DATASET_NAME, RegexPuzzlesDataset, RegexPuzzlesConfig, RegexPuzzlesCurriculum)
||||
|
|
@ -7,6 +7,7 @@ from .circuit_logic import CircuitLogicConfig, CircuitLogicCurriculum, CircuitLo
|
|||
from .knights_knaves import KnightsKnavesConfig, KnightsKnavesCurriculum, KnightsKnavesDataset
|
||||
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicCurriculum, PropositionalLogicDataset
|
||||
from .self_reference import SelfReferenceConfig, SelfReferenceCurriculum, SelfReferenceDataset
|
||||
from .set_operations import SetOperationsConfig, SetOperationsCurriculum, SetOperationsDataset
|
||||
from .syllogisms import SyllogismConfig, SyllogismCurriculum, SyllogismDataset
|
||||
from .zebra_puzzles import ZebraConfig, ZebraCurriculum, ZebraDataset
|
||||
|
||||
|
|
@ -33,4 +34,7 @@ __all__ = [
|
|||
"KnightsKnavesConfig",
|
||||
"KnightsKnavesDataset",
|
||||
"KnightsKnavesCurriculum",
|
||||
"SetOperationsConfig",
|
||||
"SetOperationsDataset",
|
||||
"SetOperationsCurriculum",
|
||||
]
|
||||
|
|
|
|||
191
reasoning_gym/logic/set_operations.py
Normal file
191
reasoning_gym/logic/set_operations.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "set_operations"

# Closed set of task kinds this dataset can emit; config.task_types must be a subset.
TASK_TYPES = ("union", "intersection", "difference", "symmetric_difference", "cardinality", "power_set_size", "complement", "chained")


@dataclass
class SetOperationsConfig:
    """Configuration for SetOperationsDataset.

    Element sets are drawn from [min_value, max_value] with sizes in
    [min_set_size, max_set_size]; task_weights parallels task_types.
    """

    min_set_size: int = 3
    max_set_size: int = 8
    min_value: int = 1
    max_value: int = 20
    task_types: tuple[str, ...] = TASK_TYPES
    task_weights: list[float] = field(default_factory=lambda: [0.15, 0.15, 0.12, 0.12, 0.12, 0.1, 0.12, 0.12])
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert the configuration is internally consistent (raises AssertionError)."""
        assert self.size > 0, "size must be positive"
        assert self.min_set_size >= 1, "min_set_size must be >= 1"
        assert self.max_set_size >= self.min_set_size, "max_set_size must be >= min_set_size"
        assert self.max_value > self.min_value, "max_value must be > min_value"
        # Fix: the value range must hold at least max_set_size distinct elements,
        # otherwise random.sample() in the generator raises ValueError at runtime.
        assert self.max_value - self.min_value + 1 >= self.max_set_size, (
            "value range must contain at least max_set_size distinct values"
        )
        assert len(self.task_types) > 0, "must have at least one task type"
        # Fix: the message was a placeholder-less f-string; name the bad entries.
        invalid = [t for t in self.task_types if t not in TASK_TYPES]
        assert not invalid, f"invalid task type(s): {invalid}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
|
||||
def _fmt_set(s: set) -> str:
    """Render a set as '{a, b, c}' with elements in ascending order."""
    inner = ", ".join(map(str, sorted(s)))
    return "{" + inner + "}"
|
||||
class SetOperationsDataset(ProceduralDataset):
    """Procedurally generated set-theory exercises (union, intersection, ...)."""

    def __init__(self, config: SetOperationsConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _gen_set(self, rng: random.Random) -> set:
        """Draw a random set of distinct ints from the configured value range."""
        count = rng.randint(self.config.min_set_size, self.config.max_set_size)
        return set(rng.sample(range(self.config.min_value, self.config.max_value + 1), count))

    def _make_union(self, rng: random.Random) -> dict:
        left, right = self._gen_set(rng), self._gen_set(rng)
        merged = left | right
        prompt = f"Given A = {_fmt_set(left)} and B = {_fmt_set(right)}, find A ∪ B (the union)."
        return {"question": prompt, "answer": _fmt_set(merged), "task_type": "union"}

    def _make_intersection(self, rng: random.Random) -> dict:
        left, right = self._gen_set(rng), self._gen_set(rng)
        common = left & right
        prompt = f"Given A = {_fmt_set(left)} and B = {_fmt_set(right)}, find A ∩ B (the intersection)."
        return {"question": prompt, "answer": _fmt_set(common), "task_type": "intersection"}

    def _make_difference(self, rng: random.Random) -> dict:
        left, right = self._gen_set(rng), self._gen_set(rng)
        only_left = left - right
        prompt = f"Given A = {_fmt_set(left)} and B = {_fmt_set(right)}, find A \\ B (elements in A but not in B)."
        return {"question": prompt, "answer": _fmt_set(only_left), "task_type": "difference"}

    def _make_symmetric_difference(self, rng: random.Random) -> dict:
        left, right = self._gen_set(rng), self._gen_set(rng)
        exclusive = left ^ right
        prompt = f"Given A = {_fmt_set(left)} and B = {_fmt_set(right)}, find A △ B (the symmetric difference)."
        return {"question": prompt, "answer": _fmt_set(exclusive), "task_type": "symmetric_difference"}

    def _make_cardinality(self, rng: random.Random) -> dict:
        # Inclusion-exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|.
        size_a = rng.randint(5, 30)
        size_b = rng.randint(5, 30)
        overlap = rng.randint(0, min(size_a, size_b))
        union_size = size_a + size_b - overlap
        prompt = (
            f"If |A| = {size_a}, |B| = {size_b}, and |A ∩ B| = {overlap}, what is |A ∪ B|? "
            f"Give your answer as a single integer."
        )
        return {"question": prompt, "answer": str(union_size), "task_type": "cardinality"}

    def _make_power_set_size(self, rng: random.Random) -> dict:
        # A set with n elements has exactly 2**n subsets.
        n = rng.randint(2, 8)
        prompt = f"How many subsets does a set with {n} elements have? Give your answer as a single integer."
        return {"question": prompt, "answer": str(2 ** n), "task_type": "power_set_size"}

    def _make_complement(self, rng: random.Random) -> dict:
        upper = rng.randint(8, 15)
        universe = set(range(1, upper + 1))
        subset = set(rng.sample(sorted(universe), rng.randint(2, upper - 2)))
        prompt = (
            f"If the universal set U = {_fmt_set(universe)} and A = {_fmt_set(subset)}, "
            f"find A' (the complement of A in U)."
        )
        return {"question": prompt, "answer": _fmt_set(universe - subset), "task_type": "complement"}

    def _make_chained(self, rng: random.Random) -> dict:
        s1, s2, s3 = self._gen_set(rng), self._gen_set(rng), self._gen_set(rng)
        op1 = rng.choice(["union", "intersection"])
        op2 = rng.choice(["union", "intersection"])
        sym1 = "∪" if op1 == "union" else "∩"
        sym2 = "∪" if op2 == "union" else "∩"

        # Apply the two chosen operations left-to-right: (A op1 B) op2 C.
        first = set.union if op1 == "union" else set.intersection
        second = set.union if op2 == "union" else set.intersection
        combined = second(first(s1, s2), s3)

        prompt = (
            f"Given A = {_fmt_set(s1)}, B = {_fmt_set(s2)}, C = {_fmt_set(s3)}, "
            f"find (A {sym1} B) {sym2} C."
        )
        return {"question": prompt, "answer": _fmt_set(combined), "task_type": "chained"}

    def __getitem__(self, idx: int) -> dict:
        """Deterministically generate the idx-th item from (seed + idx)."""
        rng = random.Random(self.seed + idx)
        chosen = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

        dispatch = {
            "union": self._make_union,
            "intersection": self._make_intersection,
            "difference": self._make_difference,
            "symmetric_difference": self._make_symmetric_difference,
            "cardinality": self._make_cardinality,
            "power_set_size": self._make_power_set_size,
            "complement": self._make_complement,
            "chained": self._make_chained,
        }
        item = dispatch[chosen](rng)
        return {
            "question": item["question"],
            "answer": item["answer"],
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": item["task_type"],
                "difficulty": {
                    "min_set_size": self.config.min_set_size,
                    "max_set_size": self.config.max_set_size,
                },
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Return 1.0 for a correct answer, else 0.0; tolerant of formatting.

        Numeric tasks compare as ints; set-valued tasks parse '{1, 2}' style
        text and compare as sets, so element order and spacing do not matter.
        """
        if answer is None:
            return 0.0
        expected = entry["answer"]
        if answer.strip() == expected.strip():
            return 1.0
        kind = entry["metadata"]["task_type"]
        if kind in ("cardinality", "power_set_size"):
            try:
                return 1.0 if int(answer.strip()) == int(expected.strip()) else 0.0
            except ValueError:
                return 0.0

        def parse(text: str) -> set:
            # '{1, 2, 3}' -> {1, 2, 3}; an empty brace pair parses to set().
            body = text.strip().strip("{}")
            return {int(piece.strip()) for piece in body.split(",")} if body else set()

        try:
            return 1.0 if parse(answer) == parse(expected) else 0.0
        except (ValueError, TypeError):
            return 0.0
|
||||
class SetOperationsCurriculum(BaseCurriculum):
    """Curriculum scaling the size of the generated sets."""

    def __init__(self):
        super().__init__(SetOperationsCurriculum.__name__, SetOperationsConfig)
        size_attr = RangeAttributeDefinition(
            name="set_size",
            levels=[3, 6, 10, 15],
            lower_field_name="min_set_size",
            upper_field_name="max_set_size",
            description="Size of generated sets",
        )
        self._define_attributes(size_attr)


register_dataset(DATASET_NAME, SetOperationsDataset, SetOperationsConfig, SetOperationsCurriculum)
||||
19
reasoning_gym/optimization/__init__.py
Normal file
19
reasoning_gym/optimization/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""
|
||||
Optimization reasoning tasks.
|
||||
"""
|
||||
|
||||
from .dynamic_programming import DynamicProgrammingConfig, DynamicProgrammingCurriculum, DynamicProgrammingDataset
|
||||
from .knapsack import KnapsackConfig, KnapsackCurriculum, KnapsackDataset
|
||||
from .linear_programming import LinearProgrammingConfig, LinearProgrammingCurriculum, LinearProgrammingDataset
|
||||
|
||||
__all__ = [
|
||||
"LinearProgrammingDataset",
|
||||
"LinearProgrammingConfig",
|
||||
"LinearProgrammingCurriculum",
|
||||
"KnapsackDataset",
|
||||
"KnapsackConfig",
|
||||
"KnapsackCurriculum",
|
||||
"DynamicProgrammingDataset",
|
||||
"DynamicProgrammingConfig",
|
||||
"DynamicProgrammingCurriculum",
|
||||
]
|
||||
199
reasoning_gym/optimization/dynamic_programming.py
Normal file
199
reasoning_gym/optimization/dynamic_programming.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "dynamic_programming"
|
||||
|
||||
# Closed set of DP task kinds; config.task_types must be a subset of these.
TASK_TYPES = ("lcs", "coin_change", "lis", "edit_distance", "staircase")
|
||||
def _lcs_length(s1: str, s2: str) -> tuple[int, str]:
    """Return (length, witness) of a longest common subsequence of s1 and s2.

    Classic O(m*n) table, but each cell stores an actual LCS string so a
    witness subsequence comes out for free.  Ties prefer the dp[i-1][j] side.
    """
    rows, cols = len(s1), len(s2)
    table = [[""] * (cols + 1) for _ in range(rows + 1)]
    for r in range(1, rows + 1):
        for c in range(1, cols + 1):
            if s1[r - 1] == s2[c - 1]:
                table[r][c] = table[r - 1][c - 1] + s1[r - 1]
            elif len(table[r - 1][c]) >= len(table[r][c - 1]):
                table[r][c] = table[r - 1][c]
            else:
                table[r][c] = table[r][c - 1]
    best = table[rows][cols]
    return len(best), best
|
||||
def _coin_change(coins: list[int], amount: int) -> int:
    """Minimum number of coins summing to amount (unbounded supply), or -1 if impossible."""
    unreachable = float("inf")
    best = [unreachable] * (amount + 1)
    best[0] = 0  # zero coins make zero
    for denom in coins:
        for total in range(denom, amount + 1):
            candidate = best[total - denom] + 1
            if candidate < best[total]:
                best[total] = candidate
    return -1 if best[amount] == unreachable else best[amount]
||||
|
||||
def _lis_length(arr: list[int]) -> int:
    """Length of the longest strictly increasing subsequence of arr (0 if empty)."""
    if not arr:
        return 0
    # best_ending_at[i] = LIS length among subsequences ending exactly at arr[i].
    best_ending_at = [1] * len(arr)
    for i, value in enumerate(arr):
        for j in range(i):
            if arr[j] < value and best_ending_at[j] + 1 > best_ending_at[i]:
                best_ending_at[i] = best_ending_at[j] + 1
    return max(best_ending_at)
||||
|
||||
def _edit_distance(s1: str, s2: str) -> int:
    """Levenshtein distance between s1 and s2 (insert/delete/substitute, unit cost)."""
    rows, cols = len(s1), len(s2)
    # dist[i][j] = edit distance between s1[:i] and s2[:j].
    dist = [[0] * (cols + 1) for _ in range(rows + 1)]
    for i in range(rows + 1):
        dist[i][0] = i  # delete all of s1[:i]
    for j in range(cols + 1):
        dist[0][j] = j  # insert all of s2[:j]
    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            if s1[i - 1] == s2[j - 1]:
                dist[i][j] = dist[i - 1][j - 1]
            else:
                dist[i][j] = 1 + min(dist[i - 1][j], dist[i][j - 1], dist[i - 1][j - 1])
    return dist[rows][cols]
||||
|
||||
@dataclass
class DynamicProgrammingConfig:
    """Configuration for DynamicProgrammingDataset (string/array lengths, staircase height)."""

    min_str_len: int = 4
    max_str_len: int = 8
    min_arr_len: int = 5
    max_arr_len: int = 10
    max_staircase: int = 15
    task_types: tuple[str, ...] = TASK_TYPES
    task_weights: list[float] = field(default_factory=lambda: [0.2, 0.2, 0.2, 0.2, 0.2])
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert internal consistency; raises AssertionError on a bad config."""
        assert self.size > 0, "size must be positive"
        assert self.min_str_len >= 2, "min_str_len must be >= 2"
        assert self.max_str_len >= self.min_str_len, "max_str_len must be >= min_str_len"
        assert self.min_arr_len >= 3, "min_arr_len must be >= 3"
        assert self.max_arr_len >= self.min_arr_len, "max_arr_len must be >= min_arr_len"
        # Fix: the staircase generator draws rng.randint(3, max_staircase), so
        # the previously-allowed value 2 would crash at generation time.
        assert self.max_staircase >= 3, "max_staircase must be >= 3"
        assert len(self.task_types) > 0, "must have at least one task type"
        # Fix: the message was a placeholder-less f-string; name the bad entries.
        invalid = [t for t in self.task_types if t not in TASK_TYPES]
        assert not invalid, f"invalid task type(s): {invalid}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
||||
|
||||
class DynamicProgrammingDataset(ProceduralDataset):
    """Classic DP exercises: LCS, coin change, LIS, edit distance, staircase."""

    def __init__(self, config: DynamicProgrammingConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _make_lcs(self, rng: random.Random) -> dict:
        len_a = rng.randint(self.config.min_str_len, self.config.max_str_len)
        len_b = rng.randint(self.config.min_str_len, self.config.max_str_len)
        alphabet = "ABCDEFGH"
        word_a = "".join(rng.choice(alphabet) for _ in range(len_a))
        word_b = "".join(rng.choice(alphabet) for _ in range(len_b))
        lcs_len, _ = _lcs_length(word_a, word_b)
        prompt = (
            f"Find the length of the longest common subsequence (LCS) of '{word_a}' and '{word_b}'. "
            f"Give your answer as a single integer."
        )
        return {"question": prompt, "answer": str(lcs_len), "task_type": "lcs"}

    def _make_coin_change(self, rng: random.Random) -> dict:
        num_coins = rng.randint(2, 4)
        # Over-draw, dedupe and sort, then keep at most num_coins denominations.
        draws = set(rng.randint(1, 10) for _ in range(num_coins + 2))
        coins = sorted(draws)[:num_coins]
        if 1 not in coins:
            coins = [1] + coins  # guarantee every amount is makeable
        amount = rng.randint(5, 25)
        fewest = _coin_change(coins, amount)
        prompt = (
            f"What is the minimum number of coins needed to make {amount} "
            f"using coins of denominations {coins}? Each denomination can be used unlimited times. "
            f"Give your answer as a single integer."
        )
        return {"question": prompt, "answer": str(fewest), "task_type": "coin_change"}

    def _make_lis(self, rng: random.Random) -> dict:
        length = rng.randint(self.config.min_arr_len, self.config.max_arr_len)
        values = [rng.randint(1, 50) for _ in range(length)]
        lis_len = _lis_length(values)
        prompt = (
            f"Find the length of the longest strictly increasing subsequence in {values}. "
            f"Give your answer as a single integer."
        )
        return {"question": prompt, "answer": str(lis_len), "task_type": "lis"}

    def _make_edit_distance(self, rng: random.Random) -> dict:
        len_a = rng.randint(self.config.min_str_len, self.config.max_str_len)
        len_b = rng.randint(self.config.min_str_len, self.config.max_str_len)
        alphabet = "abcdefgh"
        word_a = "".join(rng.choice(alphabet) for _ in range(len_a))
        word_b = "".join(rng.choice(alphabet) for _ in range(len_b))
        distance = _edit_distance(word_a, word_b)
        prompt = (
            f"What is the minimum edit distance (Levenshtein distance) between "
            f"'{word_a}' and '{word_b}'? Operations: insert, delete, or substitute a character. "
            f"Give your answer as a single integer."
        )
        return {"question": prompt, "answer": str(distance), "task_type": "edit_distance"}

    def _make_staircase(self, rng: random.Random) -> dict:
        steps = rng.randint(3, self.config.max_staircase)
        # Fibonacci-style recurrence ways(i) = ways(i-1) + ways(i-2), kept in
        # two rolling accumulators instead of a full table.
        two_back, one_back = 1, 1  # ways(0), ways(1)
        for _ in range(2, steps + 1):
            two_back, one_back = one_back, one_back + two_back
        prompt = (
            f"You are climbing a staircase with {steps} steps. Each time you can climb 1 or 2 steps. "
            f"How many distinct ways can you reach the top? Give your answer as a single integer."
        )
        return {"question": prompt, "answer": str(one_back), "task_type": "staircase"}

    def __getitem__(self, idx: int) -> dict:
        """Deterministically generate the idx-th item from (seed + idx)."""
        rng = random.Random(self.seed + idx)
        chosen = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

        dispatch = {
            "lcs": self._make_lcs,
            "coin_change": self._make_coin_change,
            "lis": self._make_lis,
            "edit_distance": self._make_edit_distance,
            "staircase": self._make_staircase,
        }
        item = dispatch[chosen](rng)
        return {
            "question": item["question"],
            "answer": item["answer"],
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": item["task_type"],
                "difficulty": {
                    "max_str_len": self.config.max_str_len,
                    "max_arr_len": self.config.max_arr_len,
                },
            },
        }
||||
|
||||
class DynamicProgrammingCurriculum(BaseCurriculum):
    """Curriculum scaling the string and array sizes fed to the DP tasks."""

    def __init__(self):
        super().__init__(DynamicProgrammingCurriculum.__name__, DynamicProgrammingConfig)
        str_len = ScalarAttributeDefinition(
            name="max_str_len",
            field_name="max_str_len",
            levels=[5, 8, 12, 15],
            description="Maximum string length",
        )
        arr_len = ScalarAttributeDefinition(
            name="max_arr_len",
            field_name="max_arr_len",
            levels=[5, 10, 15, 20],
            description="Maximum array length",
        )
        self._define_attributes(str_len, arr_len)


register_dataset(DATASET_NAME, DynamicProgrammingDataset, DynamicProgrammingConfig, DynamicProgrammingCurriculum)
||||
108
reasoning_gym/optimization/knapsack.py
Normal file
108
reasoning_gym/optimization/knapsack.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "knapsack"


@dataclass
class KnapsackConfig:
    """Configuration for KnapsackDataset: item count, weight/value ranges, capacity range."""

    min_items: int = 3
    max_items: int = 6
    min_weight: int = 1
    max_weight: int = 15
    min_value: int = 5
    max_value: int = 50
    min_capacity: int = 10
    max_capacity: int = 30
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert each configured range is well-formed; raises AssertionError otherwise."""
        assert self.size > 0, "size must be positive"
        assert self.min_items >= 2, "min_items must be >= 2"
        assert self.max_items >= self.min_items, "max_items must be >= min_items"
        assert self.min_weight >= 1, "min_weight must be >= 1"
        assert self.max_weight >= self.min_weight, "max_weight must be >= min_weight"
        assert self.min_value >= 1, "min_value must be >= 1"
        assert self.max_value >= self.min_value, "max_value must be >= min_value"
        assert self.min_capacity >= 1, "min_capacity must be >= 1"
        assert self.max_capacity >= self.min_capacity, "max_capacity must be >= min_capacity"
||||
|
||||
def _solve_knapsack(weights: list[int], values: list[int], capacity: int) -> int:
    """Maximum total value of a 0/1 knapsack, via O(n * capacity) DP.

    Uses a single rolling row: best[w] is the optimum over the items seen so
    far with capacity w.  Iterating capacity downward ensures each item is
    taken at most once — the classic space-optimized form of the 2-D table.
    """
    best = [0] * (capacity + 1)
    for wt, val in zip(weights, values):
        for w in range(capacity, wt - 1, -1):
            candidate = best[w - wt] + val
            if candidate > best[w]:
                best[w] = candidate
    return best[capacity]
||||
|
||||
class KnapsackDataset(ProceduralDataset):
    """0/1 knapsack instances whose optimal value is computed exactly by DP."""

    def __init__(self, config: KnapsackConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Deterministically generate the idx-th knapsack instance from (seed + idx)."""
        rng = random.Random(self.seed + idx)

        count = rng.randint(self.config.min_items, self.config.max_items)
        weights = [rng.randint(self.config.min_weight, self.config.max_weight) for _ in range(count)]
        values = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(count)]
        capacity = rng.randint(self.config.min_capacity, self.config.max_capacity)

        best_value = _solve_knapsack(weights, values, capacity)

        listing = ", ".join(f"(weight={w}, value={v})" for w, v in zip(weights, values))
        prompt = (
            f"You have a knapsack with capacity {capacity}. "
            f"You have the following items: {listing}. "
            f"Each item can be used at most once. "
            f"What is the maximum total value you can carry? "
            f"Give your answer as a single integer."
        )

        return {
            "question": prompt,
            "answer": str(best_value),
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "weights": weights,
                "values": values,
                "capacity": capacity,
                "difficulty": {
                    "min_items": self.config.min_items,
                    "max_items": self.config.max_items,
                },
            },
        }
|
||||
|
||||
class KnapsackCurriculum(BaseCurriculum):
    """Curriculum scaling both the item count and the knapsack capacity."""

    def __init__(self):
        super().__init__(KnapsackCurriculum.__name__, KnapsackConfig)
        item_count = RangeAttributeDefinition(
            name="item_count",
            levels=[3, 6, 10, 15],
            lower_field_name="min_items",
            upper_field_name="max_items",
            description="Number of items",
        )
        capacity = RangeAttributeDefinition(
            name="capacity",
            levels=[10, 30, 50, 100],
            lower_field_name="min_capacity",
            upper_field_name="max_capacity",
            description="Knapsack capacity range",
        )
        self._define_attributes(item_count, capacity)


register_dataset(DATASET_NAME, KnapsackDataset, KnapsackConfig, KnapsackCurriculum)
||||
108
reasoning_gym/optimization/linear_programming.py
Normal file
108
reasoning_gym/optimization/linear_programming.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "linear_programming"


@dataclass
class LinearProgrammingConfig:
    """Configuration for LinearProgrammingDataset (coefficient range, constraint count)."""

    min_coeff: int = 1
    max_coeff: int = 10
    num_constraints: int = 3
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert the configuration is internally consistent; raises AssertionError."""
        assert self.size > 0, "size must be positive"
        assert self.min_coeff >= 1, "min_coeff must be >= 1"
        assert self.max_coeff >= self.min_coeff, "max_coeff must be >= min_coeff"
        assert 2 <= self.num_constraints <= 6, "num_constraints must be between 2 and 6"
|
||||
class LinearProgrammingDataset(ProceduralDataset):
    """2-variable LP problems with backward construction from a known optimal vertex.

    The objective is generated as a positive multiple of the binding
    constraint's normal vector, which makes the claimed optimum provably the
    true maximum (see __getitem__ for the argument).
    """

    def __init__(self, config: LinearProgrammingConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Deterministically generate the idx-th LP instance from (seed + idx).

        Construction: pick a point (x_opt, y_opt), make one constraint tight
        through it, keep all other constraints slack there, and choose the
        objective parallel to the tight constraint.  Then for any feasible
        (x, y):  c1*x + c2*y = scale*(tight_a*x + tight_b*y) <= scale*tight_rhs,
        with equality at (x_opt, y_opt) — so the stated answer is exact.
        """
        rng = random.Random(self.seed + idx)

        x_opt = rng.randint(1, self.config.max_coeff)
        y_opt = rng.randint(1, self.config.max_coeff)

        # Binding constraint through (x_opt, y_opt).
        tight_a = rng.randint(self.config.min_coeff, self.config.max_coeff)
        tight_b = rng.randint(self.config.min_coeff, self.config.max_coeff)
        tight_rhs = tight_a * x_opt + tight_b * y_opt

        # BUG FIX: the objective (c1, c2) used to be drawn independently of the
        # constraints, so the claimed optimum c1*x_opt + c2*y_opt was often NOT
        # the true maximum (e.g. an axis vertex of the tight constraint could
        # score strictly higher).  Tying the objective to the binding
        # constraint's normal restores correctness and keeps integer answers.
        scale = rng.randint(1, 3)
        c1 = scale * tight_a
        c2 = scale * tight_b
        opt_val = c1 * x_opt + c2 * y_opt

        constraints = [(tight_a, tight_b, tight_rhs)]
        for _ in range(self.config.num_constraints - 1):
            a = rng.randint(self.config.min_coeff, self.config.max_coeff)
            b = rng.randint(self.config.min_coeff, self.config.max_coeff)
            # Non-negative slack keeps (x_opt, y_opt) feasible for every row.
            rhs = a * x_opt + b * y_opt + rng.randint(0, 5)
            constraints.append((a, b, rhs))
        rng.shuffle(constraints)  # don't always list the binding constraint first

        constraint_strs = [f" {a}x + {b}y <= {rhs}" for a, b, rhs in constraints]
        constraint_strs.append(" x >= 0")
        constraint_strs.append(" y >= 0")
        constraints_text = "\n".join(constraint_strs)

        question = (
            f"Maximize {c1}x + {c2}y subject to:\n{constraints_text}\n"
            f"What is the maximum value of the objective function? "
            f"Give your answer as a single integer."
        )

        return {
            "question": question,
            "answer": str(opt_val),
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "optimal_point": (x_opt, y_opt),
                "difficulty": {
                    "num_constraints": self.config.num_constraints,
                    "max_coeff": self.config.max_coeff,
                },
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Return 1.0 iff the answer parses as an int equal to the stored optimum."""
        if answer is None:
            return 0.0
        try:
            return 1.0 if int(answer.strip()) == int(entry["answer"]) else 0.0
        except (ValueError, TypeError):
            return 0.0
||||
|
||||
class LinearProgrammingCurriculum(BaseCurriculum):
    """Curriculum scaling the number of constraints and the coefficient magnitudes."""

    def __init__(self):
        super().__init__(LinearProgrammingCurriculum.__name__, LinearProgrammingConfig)
        n_constraints = ScalarAttributeDefinition(
            name="num_constraints",
            field_name="num_constraints",
            levels=[2, 3, 4, 5],
            description="Number of inequality constraints",
        )
        coeff = ScalarAttributeDefinition(
            name="max_coeff",
            field_name="max_coeff",
            levels=[5, 10, 20, 50],
            description="Maximum coefficient value",
        )
        self._define_attributes(n_constraints, coeff)


register_dataset(DATASET_NAME, LinearProgrammingDataset, LinearProgrammingConfig, LinearProgrammingCurriculum)
|
||||
|
|
@ -3,5 +3,17 @@ Probability reasoning tasks.
|
|||
"""
|
||||
|
||||
from .coin_flip import CoinFlipConfig, CoinFlipCurriculum, CoinFlipDataset
|
||||
from .conditional_probability import (
|
||||
ConditionalProbabilityConfig,
|
||||
ConditionalProbabilityCurriculum,
|
||||
ConditionalProbabilityDataset,
|
||||
)
|
||||
|
||||
__all__ = ["CoinFlipDataset", "CoinFlipConfig", "CoinFlipCurriculum"]
|
||||
__all__ = [
|
||||
"CoinFlipDataset",
|
||||
"CoinFlipConfig",
|
||||
"CoinFlipCurriculum",
|
||||
"ConditionalProbabilityDataset",
|
||||
"ConditionalProbabilityConfig",
|
||||
"ConditionalProbabilityCurriculum",
|
||||
]
|
||||
|
|
|
|||
158
reasoning_gym/probability/conditional_probability.py
Normal file
158
reasoning_gym/probability/conditional_probability.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from fractions import Fraction
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "conditional_probability"

# Closed set of task kinds; config.task_types must be a subset of these.
TASK_TYPES = ("bayes", "dependent_draws", "contingency_table")


@dataclass
class ConditionalProbabilityConfig:
    """Configuration for ConditionalProbabilityDataset."""

    task_types: tuple[str, ...] = TASK_TYPES
    task_weights: list[float] = field(default_factory=lambda: [0.34, 0.33, 0.33])
    min_total_items: int = 5
    max_total_items: int = 20
    min_table_cell: int = 5
    max_table_cell: int = 50
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Assert internal consistency; raises AssertionError on a bad config."""
        assert self.size > 0, "size must be positive"
        assert len(self.task_types) > 0, "must have at least one task type"
        # Fix: the message was a placeholder-less f-string; name the bad entries.
        invalid = [t for t in self.task_types if t not in TASK_TYPES]
        assert not invalid, f"invalid task type(s): {invalid}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
        # Fix: dependent-draws needs total >= 3 (it draws color_a_count from
        # [2, total - 1]), so the previously-allowed value 2 would crash at
        # generation time.
        assert self.min_total_items >= 3, "min_total_items must be >= 3"
        assert self.max_total_items >= self.min_total_items, "max_total_items must be >= min_total_items"
        # Fix: table-cell bounds were never validated; randint needs a valid range.
        assert self.min_table_cell >= 1, "min_table_cell must be >= 1"
        assert self.max_table_cell >= self.min_table_cell, "max_table_cell must be >= min_table_cell"
||||
class ConditionalProbabilityDataset(ProceduralDataset):
|
||||
def __init__(self, config: ConditionalProbabilityConfig):
    """Initialize the dataset, forwarding seed and size from the config to the base class."""
    super().__init__(config=config, seed=config.seed, size=config.size)
||||
def _make_bayes(self, rng: random.Random) -> dict:
    """Bayes' theorem task: P(disease | positive test) from test characteristics."""
    sensitivity = Fraction(rng.randint(70, 99), 100)  # P(+ | disease)
    specificity = Fraction(rng.randint(70, 99), 100)  # P(- | healthy)
    prevalence = Fraction(rng.randint(1, 15), 100)    # P(disease)

    # Law of total probability for P(+), then Bayes' rule — exact with Fractions.
    p_positive = sensitivity * prevalence + (1 - specificity) * (1 - prevalence)
    posterior = (sensitivity * prevalence) / p_positive

    prompt = (
        f"A medical test has a sensitivity (true positive rate) of {sensitivity} "
        f"and a specificity (true negative rate) of {specificity}. "
        f"The prevalence of the disease in the population is {prevalence}. "
        f"If a person tests positive, what is the probability they actually have the disease? "
        f"Give your answer as a simplified fraction."
    )
    return {"question": prompt, "answer": str(posterior), "task_type": "bayes"}
||||
def _make_dependent_draws(self, rng: random.Random) -> dict:
|
||||
total = rng.randint(self.config.min_total_items, self.config.max_total_items)
|
||||
color_a_count = rng.randint(2, total - 1)
|
||||
color_b_count = total - color_a_count
|
||||
draws = rng.randint(2, min(3, color_a_count))
|
||||
|
||||
color_a = rng.choice(["red", "blue", "green", "white", "black"])
|
||||
color_b = rng.choice([c for c in ["red", "blue", "green", "white", "black"] if c != color_a])
|
||||
|
||||
prob = Fraction(1, 1)
|
||||
for i in range(draws):
|
||||
prob *= Fraction(color_a_count - i, total - i)
|
||||
|
||||
question = (
|
||||
f"A bag contains {color_a_count} {color_a} balls and {color_b_count} {color_b} balls. "
|
||||
f"You draw {draws} balls without replacement. "
|
||||
f"What is the probability that all {draws} balls are {color_a}? "
|
||||
f"Give your answer as a simplified fraction."
|
||||
)
|
||||
return {"question": question, "answer": str(prob), "task_type": "dependent_draws"}
|
||||
|
||||
def _make_contingency(self, rng: random.Random) -> dict:
|
||||
a = rng.randint(self.config.min_table_cell, self.config.max_table_cell)
|
||||
b = rng.randint(self.config.min_table_cell, self.config.max_table_cell)
|
||||
c = rng.randint(self.config.min_table_cell, self.config.max_table_cell)
|
||||
d = rng.randint(self.config.min_table_cell, self.config.max_table_cell)
|
||||
|
||||
total = a + b + c + d
|
||||
row1_total = a + b
|
||||
prob = Fraction(a, row1_total)
|
||||
|
||||
question = (
|
||||
f"Consider the following contingency table:\n\n"
|
||||
f" | Event B | Not B | Total\n"
|
||||
f" Event A | {a:>4} | {b:>4} | {row1_total:>4}\n"
|
||||
f" Not A | {c:>4} | {d:>4} | {c + d:>4}\n"
|
||||
f" Total | {a + c:>4} | {b + d:>4} | {total:>4}\n\n"
|
||||
f"Given that Event A occurred, what is the probability of Event B? "
|
||||
f"Give your answer as a simplified fraction."
|
||||
)
|
||||
return {"question": question, "answer": str(prob), "task_type": "contingency_table"}
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
rng = random.Random(self.seed + idx)
|
||||
task_type = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]
|
||||
|
||||
generators = {
|
||||
"bayes": self._make_bayes,
|
||||
"dependent_draws": self._make_dependent_draws,
|
||||
"contingency_table": self._make_contingency,
|
||||
}
|
||||
result = generators[task_type](rng)
|
||||
return {
|
||||
"question": result["question"],
|
||||
"answer": result["answer"],
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"task_type": result["task_type"],
|
||||
"difficulty": {
|
||||
"min_total_items": self.config.min_total_items,
|
||||
"max_total_items": self.config.max_total_items,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
if answer is None:
|
||||
return 0.0
|
||||
oracle = entry["answer"]
|
||||
if answer.strip() == oracle.strip():
|
||||
return 1.0
|
||||
try:
|
||||
ans_frac = Fraction(answer.strip())
|
||||
oracle_frac = Fraction(oracle.strip())
|
||||
if ans_frac == oracle_frac:
|
||||
return 1.0
|
||||
diff = abs(float(ans_frac) - float(oracle_frac))
|
||||
if diff < 1e-4:
|
||||
return 0.9
|
||||
if diff < 1e-2:
|
||||
return 0.5
|
||||
return 0.0
|
||||
except (ValueError, ZeroDivisionError):
|
||||
return 0.0
|
||||
|
||||
|
||||
class ConditionalProbabilityCurriculum(BaseCurriculum):
    """Curriculum that scales the item-count range used by draw problems."""

    def __init__(self):
        super().__init__(ConditionalProbabilityCurriculum.__name__, ConditionalProbabilityConfig)
        # Single difficulty knob: how many balls are in the bag.
        total_items_attr = RangeAttributeDefinition(
            name="total_items",
            levels=[5, 10, 20, 50],
            lower_field_name="min_total_items",
            upper_field_name="max_total_items",
            description="Total items for draw problems",
        )
        self._define_attributes(total_items_attr)
||||
|
||||
# Make the dataset discoverable through the reasoning_gym factory registry.
register_dataset(
    DATASET_NAME, ConditionalProbabilityDataset, ConditionalProbabilityConfig, ConditionalProbabilityCurriculum
)
|
||||
7
reasoning_gym/statistics/__init__.py
Normal file
7
reasoning_gym/statistics/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""
|
||||
Statistics reasoning tasks.
|
||||
"""
|
||||
|
||||
from .descriptive_stats import DescriptiveStatsConfig, DescriptiveStatsCurriculum, DescriptiveStatsDataset
|
||||
|
||||
__all__ = ["DescriptiveStatsDataset", "DescriptiveStatsConfig", "DescriptiveStatsCurriculum"]
|
||||
208
reasoning_gym/statistics/descriptive_stats.py
Normal file
208
reasoning_gym/statistics/descriptive_stats.py
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
import math
|
||||
import random
|
||||
import statistics as stats_module
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "descriptive_stats"

# Task families generated by DescriptiveStatsDataset.
TASK_TYPES = ("mean", "median", "mode", "weighted_mean", "std_dev", "percentile", "z_score")


@dataclass
class DescriptiveStatsConfig:
    """Configuration for :class:`DescriptiveStatsDataset`.

    Controls sample sizes, value ranges, required decimal precision,
    and the mix of task types.
    """

    min_data_size: int = 5  # smallest generated sample
    max_data_size: int = 10
    min_value: int = 1  # inclusive bounds for generated data points
    max_value: int = 100
    decimal_places: int = 2  # precision the answer must be rounded to
    task_types: tuple[str, ...] = TASK_TYPES
    task_weights: list[float] = field(default_factory=lambda: [0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.1])
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Check configuration invariants.

        Raises:
            AssertionError: describing the first violated invariant.
        """
        assert self.size > 0, "size must be positive"
        assert self.min_data_size >= 3, "min_data_size must be >= 3"
        assert self.max_data_size >= self.min_data_size, "max_data_size must be >= min_data_size"
        assert self.min_value < self.max_value, "min_value must be < max_value"
        # A negative precision would produce an invalid format spec in _fmt.
        assert self.decimal_places >= 0, "decimal_places must be non-negative"
        assert len(self.task_types) > 0, "must have at least one task type"
        # Report exactly which task types are unknown instead of a bare,
        # placeholder-less message.
        unknown = [t for t in self.task_types if t not in TASK_TYPES]
        assert not unknown, f"invalid task type(s): {unknown}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
||||
|
||||
class DescriptiveStatsDataset(ProceduralDataset):
    """Procedurally generated descriptive-statistics questions.

    Task families: mean, median, mode, weighted mean, population standard
    deviation, percentile (linear interpolation), and z-score.  Answers
    are decimal strings rounded to ``config.decimal_places`` (mode answers
    are comma-separated integers).  Items are deterministic per
    (seed, index); generation is sensitive to the order of RNG calls.
    """

    def __init__(self, config: DescriptiveStatsConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _gen_data(self, rng: random.Random) -> list[int]:
        """Draw a random-length sample of random integers per the config."""
        n = rng.randint(self.config.min_data_size, self.config.max_data_size)
        return [rng.randint(self.config.min_value, self.config.max_value) for _ in range(n)]

    def _fmt(self, val: float) -> str:
        """Format a value to the configured number of decimal places."""
        return f"{val:.{self.config.decimal_places}f}"

    def _make_mean(self, rng: random.Random) -> dict:
        """Arithmetic mean of a random sample."""
        data = self._gen_data(rng)
        answer = self._fmt(stats_module.mean(data))
        question = (
            f"Find the mean (average) of the following numbers: {data}. "
            f"Round your answer to {self.config.decimal_places} decimal places."
        )
        return {"question": question, "answer": answer, "task_type": "mean"}

    def _make_median(self, rng: random.Random) -> dict:
        """Median of a random sample."""
        data = self._gen_data(rng)
        answer = self._fmt(stats_module.median(data))
        question = (
            f"Find the median of the following numbers: {data}. "
            f"Round your answer to {self.config.decimal_places} decimal places."
        )
        return {"question": question, "answer": answer, "task_type": "median"}

    def _make_mode(self, rng: random.Random) -> dict:
        """Mode(s) of a sample; one value is duplicated to bias toward a mode."""
        data = self._gen_data(rng)
        # Duplicate one value so at least one element appears twice; the
        # actual modes are still computed from the final counts, so ties
        # with other repeated values are handled correctly.
        val = rng.choice(data)
        data.append(val)
        rng.shuffle(data)
        counts = Counter(data)
        max_count = max(counts.values())
        modes = sorted([k for k, v in counts.items() if v == max_count])
        answer = ", ".join(str(m) for m in modes)
        question = (
            f"Find the mode(s) of the following numbers: {data}. "
            f"If there are multiple modes, list them separated by commas in ascending order."
        )
        return {"question": question, "answer": answer, "task_type": "mode"}

    def _make_weighted_mean(self, rng: random.Random) -> dict:
        """Weighted mean with weights displayed exactly.

        Fix: the question previously printed weights rounded to two
        decimals while the oracle answer used full precision, so a solver
        working from the printed weights could be wrong at the required
        precision.  Weights are now shown exactly as ``raw/total``
        fractions; the oracle computation (and thus the answer) is
        unchanged.
        """
        n = rng.randint(3, 5)
        values = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(n)]
        raw_weights = [rng.randint(1, 10) for _ in range(n)]
        total_w = sum(raw_weights)
        weights = [w / total_w for w in raw_weights]

        result = sum(v * w for v, w in zip(values, weights))
        answer = self._fmt(result)

        # Display exact weights so the question is solvable to full precision.
        pairs = ", ".join(f"value={v} weight={rw}/{total_w}" for v, rw in zip(values, raw_weights))
        question = (
            f"Calculate the weighted mean of the following: {pairs}. "
            f"Round your answer to {self.config.decimal_places} decimal places."
        )
        return {"question": question, "answer": answer, "task_type": "weighted_mean"}

    def _make_std_dev(self, rng: random.Random) -> dict:
        """Population (not sample) standard deviation of a random sample."""
        data = self._gen_data(rng)
        answer = self._fmt(stats_module.pstdev(data))
        question = (
            f"Find the population standard deviation of the following numbers: {data}. "
            f"Round your answer to {self.config.decimal_places} decimal places."
        )
        return {"question": question, "answer": answer, "task_type": "std_dev"}

    def _make_percentile(self, rng: random.Random) -> dict:
        """Percentile of a sorted sample using linear interpolation."""
        data = sorted(self._gen_data(rng))
        p = rng.choice([25, 50, 75, 90])
        n = len(data)
        # Rank in [0, n-1]; split into integer index and fractional part.
        rank = (p / 100) * (n - 1)
        lower = int(rank)
        frac = rank - lower
        if lower + 1 < n:
            val = data[lower] + frac * (data[lower + 1] - data[lower])
        else:
            # Defensive: unreachable for p < 100 since rank < n - 1.
            val = data[lower]
        answer = self._fmt(val)
        question = (
            f"Find the {p}th percentile of the following numbers: {data}. "
            f"Use linear interpolation. Round to {self.config.decimal_places} decimal places."
        )
        return {"question": question, "answer": answer, "task_type": "percentile"}

    def _make_z_score(self, rng: random.Random) -> dict:
        """z-score of a value drawn within roughly +/- 3 sigma of the mean."""
        mean = rng.randint(50, 150)
        std = rng.randint(5, 30)
        x = mean + rng.randint(-3, 3) * std + rng.randint(-std, std)
        z = (x - mean) / std
        answer = self._fmt(z)
        question = (
            f"A dataset has a mean of {mean} and a standard deviation of {std}. "
            f"What is the z-score of the value {x}? "
            f"Round your answer to {self.config.decimal_places} decimal places."
        )
        return {"question": question, "answer": answer, "task_type": "z_score"}

    def __getitem__(self, idx: int) -> dict:
        """Generate the idx-th item deterministically from (seed, idx)."""
        rng = random.Random(self.seed + idx)
        task_type = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]

        generators = {
            "mean": self._make_mean,
            "median": self._make_median,
            "mode": self._make_mode,
            "weighted_mean": self._make_weighted_mean,
            "std_dev": self._make_std_dev,
            "percentile": self._make_percentile,
            "z_score": self._make_z_score,
        }
        result = generators[task_type](rng)
        return {
            "question": result["question"],
            "answer": result["answer"],
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "task_type": result["task_type"],
                "difficulty": {
                    "min_data_size": self.config.min_data_size,
                    "max_data_size": self.config.max_data_size,
                },
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score an answer against the oracle.

        Exact string match -> 1.0.  Otherwise both strings are parsed as
        comma-separated floats (handles multi-mode answers): max absolute
        error below 10**-decimal_places -> 1.0, below 0.1 -> 0.5, else 0.0.
        Unparseable or wrong-length answers score 0.0.
        """
        if answer is None:
            return 0.0
        oracle = entry["answer"]
        if answer.strip() == oracle.strip():
            return 1.0
        try:
            a_parts = [float(x.strip()) for x in answer.split(",")]
            o_parts = [float(x.strip()) for x in oracle.split(",")]
            if len(a_parts) != len(o_parts):
                return 0.0
            max_err = max(abs(a - o) for a, o in zip(a_parts, o_parts))
            if max_err < 10 ** (-(self.config.decimal_places)):
                return 1.0
            if max_err < 0.1:
                return 0.5
            return 0.0
        except (ValueError, TypeError):
            return 0.0
||||
|
||||
|
||||
class DescriptiveStatsCurriculum(BaseCurriculum):
    """Curriculum scaling sample size and required decimal precision."""

    def __init__(self):
        super().__init__(DescriptiveStatsCurriculum.__name__, DescriptiveStatsConfig)
        # Difficulty knob 1: how many data points each question uses.
        size_attr = RangeAttributeDefinition(
            name="data_size",
            levels=[5, 10, 20, 50],
            lower_field_name="min_data_size",
            upper_field_name="max_data_size",
            description="Size of data sets",
        )
        # Difficulty knob 2: how many decimal places the answer requires.
        precision_attr = ScalarAttributeDefinition(
            name="decimal_places",
            field_name="decimal_places",
            levels=[1, 2, 3, 4],
            description="Decimal precision required",
        )
        self._define_attributes(size_attr, precision_attr)
|
||||
|
||||
# Make the dataset discoverable through the reasoning_gym factory registry.
register_dataset(DATASET_NAME, DescriptiveStatsDataset, DescriptiveStatsConfig, DescriptiveStatsCurriculum)
|
||||
76
tests/test_combinatorics.py
Normal file
76
tests/test_combinatorics.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
import pytest

from reasoning_gym.combinatorics.combinatorics import (
    CombinatoricsConfig,
    CombinatoricsCurriculum,
    CombinatoricsDataset,
)


def test_config_validation():
    """Invalid configurations must be rejected by validate()."""
    with pytest.raises(AssertionError):
        config = CombinatoricsConfig(min_n=1)
        config.validate()

    with pytest.raises(AssertionError):
        config = CombinatoricsConfig(min_n=10, max_n=5)
        config.validate()


def test_deterministic():
    """Two datasets built from the same config yield identical items."""
    config = CombinatoricsConfig(seed=42, size=10)
    ds1 = CombinatoricsDataset(config)
    ds2 = CombinatoricsDataset(config)
    for i in range(len(ds1)):
        assert ds1[i] == ds2[i]


def test_item_structure():
    """Every item is a dict exposing question, answer and metadata."""
    config = CombinatoricsConfig(seed=42, size=50)
    ds = CombinatoricsDataset(config)
    for i in range(len(ds)):
        item = ds[i]
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item
        assert item["metadata"]["source_dataset"] == "combinatorics"


def test_answer_correctness():
    """All combinatorics answers are integer strings."""
    config = CombinatoricsConfig(seed=42, size=100)
    ds = CombinatoricsDataset(config)
    for i in range(len(ds)):
        item = ds[i]
        answer = item["answer"]
        assert answer.lstrip("-").isdigit(), f"Item {i}: answer '{answer}' is not an integer"


def test_score_wrong_answer():
    """Oracle answer scores 1.0; a missing answer scores 0.0.

    Fix: this test previously only scored the oracle answer and never
    exercised a wrong/missing answer despite its name.  The None check
    matches the convention used by the sibling dataset test suites.
    """
    config = CombinatoricsConfig(seed=42, size=10)
    ds = CombinatoricsDataset(config)
    item = ds[0]
    score = ds.score_answer(item["answer"], item)
    assert score == 1.0
    assert ds.score_answer(None, item) == 0.0


def test_curriculum():
    """Attribute levels move the generated configuration."""
    curriculum = CombinatoricsCurriculum()
    base_value = {"size": 50, "seed": 1}
    base_cfg = curriculum.generate_configuration(base_value)
    assert base_cfg.seed == 1

    curriculum.increment_attr_level("n_range")
    increased_cfg = curriculum.generate_configuration(base_value)
    assert increased_cfg.max_n >= base_cfg.max_n


def test_task_types():
    """Restricting task_types to one type yields only that type."""
    for task_type in ("ncr", "npr", "permutations_repetition", "inclusion_exclusion", "stars_and_bars", "pigeonhole"):
        config = CombinatoricsConfig(
            seed=42, size=10, task_types=(task_type,), task_weights=[1.0]
        )
        ds = CombinatoricsDataset(config)
        for i in range(len(ds)):
            item = ds[i]
            assert item["metadata"]["task_type"] == task_type
||||
87
tests/test_complex_advanced.py
Normal file
87
tests/test_complex_advanced.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
import pytest

from reasoning_gym.algebra.complex_advanced import (
    ComplexAdvancedConfig,
    ComplexAdvancedCurriculum,
    ComplexAdvancedDataset,
)


def test_config_validation():
    """Invalid configurations must be rejected by validate()."""
    with pytest.raises(AssertionError):
        ComplexAdvancedConfig(min_real=0).validate()

    with pytest.raises(AssertionError):
        ComplexAdvancedConfig(min_real=10, max_real=5).validate()

    with pytest.raises(AssertionError):
        ComplexAdvancedConfig(task_types=("invalid",)).validate()


def test_deterministic():
    """Two datasets built from the same config yield identical items."""
    cfg = ComplexAdvancedConfig(seed=42, size=10)
    first = ComplexAdvancedDataset(cfg)
    second = ComplexAdvancedDataset(cfg)
    for idx in range(len(first)):
        assert first[idx] == second[idx]


def test_item_structure():
    """Every item is a dict exposing question, answer and metadata."""
    dataset = ComplexAdvancedDataset(ComplexAdvancedConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert isinstance(sample, dict)
        assert "question" in sample
        assert "answer" in sample
        assert "metadata" in sample
        assert sample["metadata"]["source_dataset"] == "complex_advanced"


def test_answer_correctness():
    """The oracle answer always receives a full score."""
    dataset = ComplexAdvancedDataset(ComplexAdvancedConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        score = dataset.score_answer(sample["answer"], sample)
        assert score >= 1.0, f"Item {idx}: oracle answer scored {score}, expected 1.0"


def test_score_wrong_answer():
    """None and nonsense answers score zero."""
    dataset = ComplexAdvancedDataset(ComplexAdvancedConfig(seed=42, size=10))
    sample = dataset[0]
    assert dataset.score_answer(None, sample) == 0.0
    assert dataset.score_answer("completely wrong", sample) == 0.0


def test_curriculum():
    """Attribute levels move the generated configuration up and back down."""
    curriculum = ComplexAdvancedCurriculum()
    overrides = {"size": 50, "seed": 1}
    base_cfg = curriculum.generate_configuration(overrides)
    assert base_cfg.seed == 1
    assert base_cfg.size == 50

    curriculum.increment_attr_level("max_real")
    bumped_cfg = curriculum.generate_configuration(overrides)
    assert bumped_cfg.max_real >= base_cfg.max_real

    curriculum.decrement_attr_level("max_real")
    restored_cfg = curriculum.generate_configuration(overrides)
    assert restored_cfg.max_real == base_cfg.max_real


def test_task_types():
    """Each task type can be selected exclusively and scores fully."""
    for task_type in ("polar", "euler", "inverse", "sqrt", "quadratic"):
        dataset = ComplexAdvancedDataset(
            ComplexAdvancedConfig(seed=42, size=10, task_types=(task_type,), task_weights=[1.0])
        )
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample["metadata"]["task_type"] == task_type
            score = dataset.score_answer(sample["answer"], sample)
            assert score >= 1.0, f"Task {task_type}, item {idx}: oracle scored {score}"
||||
79
tests/test_conditional_probability.py
Normal file
79
tests/test_conditional_probability.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
import pytest

from reasoning_gym.probability.conditional_probability import (
    ConditionalProbabilityConfig,
    ConditionalProbabilityCurriculum,
    ConditionalProbabilityDataset,
)


def test_config_validation():
    """Invalid configurations must be rejected by validate()."""
    with pytest.raises(AssertionError):
        ConditionalProbabilityConfig(min_total_items=1).validate()

    with pytest.raises(AssertionError):
        ConditionalProbabilityConfig(min_total_items=20, max_total_items=5).validate()


def test_deterministic():
    """Two datasets built from the same config yield identical items."""
    cfg = ConditionalProbabilityConfig(seed=42, size=10)
    first = ConditionalProbabilityDataset(cfg)
    second = ConditionalProbabilityDataset(cfg)
    for idx in range(len(first)):
        assert first[idx] == second[idx]


def test_item_structure():
    """Every item is a dict exposing question, answer and metadata."""
    dataset = ConditionalProbabilityDataset(ConditionalProbabilityConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert isinstance(sample, dict)
        assert "question" in sample
        assert "answer" in sample
        assert "metadata" in sample
        assert sample["metadata"]["source_dataset"] == "conditional_probability"


def test_answer_correctness():
    """The oracle answer always receives a full score."""
    dataset = ConditionalProbabilityDataset(ConditionalProbabilityConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        score = dataset.score_answer(sample["answer"], sample)
        assert score >= 1.0, f"Item {idx}: oracle answer scored {score}"


def test_score_wrong_answer():
    """None and unparseable answers score zero."""
    dataset = ConditionalProbabilityDataset(ConditionalProbabilityConfig(seed=42, size=10))
    sample = dataset[0]
    assert dataset.score_answer(None, sample) == 0.0
    assert dataset.score_answer("not a fraction", sample) == 0.0


def test_curriculum():
    """Incrementing total_items widens the configured range."""
    curriculum = ConditionalProbabilityCurriculum()
    overrides = {"size": 50, "seed": 1}
    base_cfg = curriculum.generate_configuration(overrides)
    assert base_cfg.seed == 1
    assert base_cfg.size == 50

    curriculum.increment_attr_level("total_items")
    bumped_cfg = curriculum.generate_configuration(overrides)
    assert bumped_cfg.max_total_items >= base_cfg.max_total_items


def test_task_types():
    """Each task type can be selected exclusively and scores fully."""
    for task_type in ("bayes", "dependent_draws", "contingency_table"):
        dataset = ConditionalProbabilityDataset(
            ConditionalProbabilityConfig(seed=42, size=10, task_types=(task_type,), task_weights=[1.0])
        )
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample["metadata"]["task_type"] == task_type
            score = dataset.score_answer(sample["answer"], sample)
            assert score >= 1.0, f"Task {task_type}, item {idx}: oracle scored {score}"
78
tests/test_descriptive_stats.py
Normal file
78
tests/test_descriptive_stats.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
import pytest

from reasoning_gym.statistics.descriptive_stats import (
    DescriptiveStatsConfig,
    DescriptiveStatsCurriculum,
    DescriptiveStatsDataset,
)


def test_config_validation():
    """Invalid configurations must be rejected by validate()."""
    with pytest.raises(AssertionError):
        DescriptiveStatsConfig(min_data_size=2).validate()

    with pytest.raises(AssertionError):
        DescriptiveStatsConfig(min_data_size=10, max_data_size=5).validate()


def test_deterministic():
    """Two datasets built from the same config yield identical items."""
    cfg = DescriptiveStatsConfig(seed=42, size=10)
    first = DescriptiveStatsDataset(cfg)
    second = DescriptiveStatsDataset(cfg)
    for idx in range(len(first)):
        assert first[idx] == second[idx]


def test_item_structure():
    """Every item is a dict exposing question, answer and metadata."""
    dataset = DescriptiveStatsDataset(DescriptiveStatsConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert isinstance(sample, dict)
        assert "question" in sample
        assert "answer" in sample
        assert "metadata" in sample
        assert sample["metadata"]["source_dataset"] == "descriptive_stats"


def test_answer_correctness():
    """The oracle answer always receives a full score."""
    dataset = DescriptiveStatsDataset(DescriptiveStatsConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        score = dataset.score_answer(sample["answer"], sample)
        assert score >= 1.0, f"Item {idx}: oracle answer scored {score}"


def test_score_wrong_answer():
    """None and unparseable answers score zero."""
    dataset = DescriptiveStatsDataset(DescriptiveStatsConfig(seed=42, size=10))
    sample = dataset[0]
    assert dataset.score_answer(None, sample) == 0.0
    assert dataset.score_answer("definitely wrong", sample) == 0.0


def test_curriculum():
    """Incrementing data_size widens the configured range."""
    curriculum = DescriptiveStatsCurriculum()
    overrides = {"size": 50, "seed": 1}
    base_cfg = curriculum.generate_configuration(overrides)
    assert base_cfg.seed == 1

    curriculum.increment_attr_level("data_size")
    bumped_cfg = curriculum.generate_configuration(overrides)
    assert bumped_cfg.max_data_size >= base_cfg.max_data_size


def test_task_types():
    """Each task type can be selected exclusively and scores fully."""
    for task_type in ("mean", "median", "mode", "weighted_mean", "std_dev", "percentile", "z_score"):
        dataset = DescriptiveStatsDataset(
            DescriptiveStatsConfig(seed=42, size=10, task_types=(task_type,), task_weights=[1.0])
        )
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample["metadata"]["task_type"] == task_type
            score = dataset.score_answer(sample["answer"], sample)
            assert score >= 1.0, f"Task {task_type}, item {idx}: oracle scored {score}"
||||
77
tests/test_dynamic_programming.py
Normal file
77
tests/test_dynamic_programming.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import pytest

from reasoning_gym.optimization.dynamic_programming import (
    DynamicProgrammingConfig,
    DynamicProgrammingCurriculum,
    DynamicProgrammingDataset,
)


def test_config_validation():
    """Invalid configurations must be rejected by validate()."""
    with pytest.raises(AssertionError):
        DynamicProgrammingConfig(min_str_len=1).validate()

    with pytest.raises(AssertionError):
        DynamicProgrammingConfig(min_arr_len=2).validate()


def test_deterministic():
    """Two datasets built from the same config yield identical items."""
    cfg = DynamicProgrammingConfig(seed=42, size=10)
    first = DynamicProgrammingDataset(cfg)
    second = DynamicProgrammingDataset(cfg)
    for idx in range(len(first)):
        assert first[idx] == second[idx]


def test_item_structure():
    """Every item is a dict exposing question, answer and metadata."""
    dataset = DynamicProgrammingDataset(DynamicProgrammingConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert isinstance(sample, dict)
        assert "question" in sample
        assert "answer" in sample
        assert "metadata" in sample
        assert sample["metadata"]["source_dataset"] == "dynamic_programming"


def test_answer_correctness():
    """All dynamic-programming answers are integer strings."""
    dataset = DynamicProgrammingDataset(DynamicProgrammingConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        answer = sample["answer"]
        assert answer.lstrip("-").isdigit(), f"Item {idx}: answer '{answer}' is not an integer"


def test_score_answer():
    """The oracle answer always receives a full score."""
    dataset = DynamicProgrammingDataset(DynamicProgrammingConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert dataset.score_answer(sample["answer"], sample) == 1.0


def test_curriculum():
    """Incrementing max_str_len raises the configured bound."""
    curriculum = DynamicProgrammingCurriculum()
    overrides = {"size": 50, "seed": 1}
    base_cfg = curriculum.generate_configuration(overrides)
    assert base_cfg.seed == 1

    curriculum.increment_attr_level("max_str_len")
    bumped_cfg = curriculum.generate_configuration(overrides)
    assert bumped_cfg.max_str_len >= base_cfg.max_str_len


def test_task_types():
    """Each task type can be selected exclusively."""
    for task_type in ("lcs", "coin_change", "lis", "edit_distance", "staircase"):
        dataset = DynamicProgrammingDataset(
            DynamicProgrammingConfig(seed=42, size=10, task_types=(task_type,), task_weights=[1.0])
        )
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample["metadata"]["task_type"] == task_type
92
tests/test_job_scheduling.py
Normal file
92
tests/test_job_scheduling.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import pytest

from reasoning_gym.graphs.job_scheduling import (
    JobSchedulingConfig,
    JobSchedulingCurriculum,
    JobSchedulingDataset,
)


def test_config_validation():
    """Invalid configurations must be rejected by validate()."""
    with pytest.raises(AssertionError):
        JobSchedulingConfig(min_jobs=2).validate()

    with pytest.raises(AssertionError):
        JobSchedulingConfig(min_jobs=10, max_jobs=5).validate()


def test_deterministic():
    """Two datasets built from the same config yield identical items."""
    cfg = JobSchedulingConfig(seed=42, size=10)
    first = JobSchedulingDataset(cfg)
    second = JobSchedulingDataset(cfg)
    for idx in range(len(first)):
        assert first[idx] == second[idx]


def test_item_structure():
    """Every item is a dict exposing question, answer and metadata."""
    dataset = JobSchedulingDataset(JobSchedulingConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert isinstance(sample, dict)
        assert "question" in sample
        assert "answer" in sample
        assert "metadata" in sample
        assert sample["metadata"]["source_dataset"] == "job_scheduling"


def test_answer_correctness():
    """The oracle answer always receives a full score."""
    dataset = JobSchedulingDataset(JobSchedulingConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        score = dataset.score_answer(sample["answer"], sample)
        assert score >= 1.0, f"Item {idx}: oracle answer scored {score}"


def test_task_ordering_verification():
    """Oracle orderings for task_ordering respect every dependency edge."""
    cfg = JobSchedulingConfig(seed=42, size=20, task_types=("task_ordering",), task_weights=[1.0])
    dataset = JobSchedulingDataset(cfg)
    for idx in range(len(dataset)):
        sample = dataset[idx]
        order = sample["answer"].split(", ")
        deps = sample["metadata"]["deps"]
        pos = {name: j for j, name in enumerate(order)}
        for job, prereqs in deps.items():
            for prereq in prereqs:
                assert pos[prereq] < pos[job], f"Dependency {prereq} -> {job} violated"


def test_score_wrong_answer():
    """A missing answer scores zero."""
    dataset = JobSchedulingDataset(JobSchedulingConfig(seed=42, size=10))
    assert dataset.score_answer(None, dataset[0]) == 0.0


def test_curriculum():
    """Incrementing max_jobs raises the configured bound."""
    curriculum = JobSchedulingCurriculum()
    overrides = {"size": 50, "seed": 1}
    base_cfg = curriculum.generate_configuration(overrides)
    assert base_cfg.seed == 1

    curriculum.increment_attr_level("max_jobs")
    bumped_cfg = curriculum.generate_configuration(overrides)
    assert bumped_cfg.max_jobs >= base_cfg.max_jobs


def test_task_types():
    """Each task type can be selected exclusively and scores fully."""
    for task_type in ("critical_path", "interval_scheduling", "task_ordering"):
        dataset = JobSchedulingDataset(
            JobSchedulingConfig(seed=42, size=10, task_types=(task_type,), task_weights=[1.0])
        )
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample["metadata"]["task_type"] == task_type
            score = dataset.score_answer(sample["answer"], sample)
            assert score >= 1.0, f"Task {task_type}, item {idx}: oracle scored {score}"
||||
70
tests/test_knapsack.py
Normal file
70
tests/test_knapsack.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
import pytest

from reasoning_gym.optimization.knapsack import (
    KnapsackConfig,
    KnapsackCurriculum,
    KnapsackDataset,
    _solve_knapsack,
)


def test_config_validation():
    """Invalid configurations must be rejected by validate()."""
    with pytest.raises(AssertionError):
        KnapsackConfig(min_items=1).validate()

    with pytest.raises(AssertionError):
        KnapsackConfig(min_items=10, max_items=5).validate()


def test_deterministic():
    """Two datasets built from the same config yield identical items."""
    cfg = KnapsackConfig(seed=42, size=10)
    first = KnapsackDataset(cfg)
    second = KnapsackDataset(cfg)
    for idx in range(len(first)):
        assert first[idx] == second[idx]


def test_item_structure():
    """Every item is a dict exposing question, answer and metadata."""
    dataset = KnapsackDataset(KnapsackConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert isinstance(sample, dict)
        assert "question" in sample
        assert "answer" in sample
        assert "metadata" in sample
        assert sample["metadata"]["source_dataset"] == "knapsack"


def test_answer_correctness():
    """Each oracle answer matches an independent DP solve of the instance."""
    dataset = KnapsackDataset(KnapsackConfig(seed=42, size=50))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        meta = sample["metadata"]
        expected = _solve_knapsack(meta["weights"], meta["values"], meta["capacity"])
        assert int(sample["answer"]) == expected, f"Item {idx}: answer mismatch"


def test_score_answer():
    """The oracle answer always receives a full score."""
    dataset = KnapsackDataset(KnapsackConfig(seed=42, size=10))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert dataset.score_answer(sample["answer"], sample) == 1.0


def test_curriculum():
    """Incrementing item_count raises the configured bound."""
    curriculum = KnapsackCurriculum()
    overrides = {"size": 50, "seed": 1}
    base_cfg = curriculum.generate_configuration(overrides)
    assert base_cfg.seed == 1

    curriculum.increment_attr_level("item_count")
    bumped_cfg = curriculum.generate_configuration(overrides)
    assert bumped_cfg.max_items >= base_cfg.max_items
78
tests/test_limits.py
Normal file
78
tests/test_limits.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""Tests for the limits (calculus) algebra dataset."""

import pytest

from reasoning_gym.algebra.limits import (
    LimitsConfig,
    LimitsCurriculum,
    LimitsDataset,
)


def test_config_validation():
    """Invalid configurations must fail validation with an AssertionError."""
    with pytest.raises(AssertionError):
        LimitsConfig(max_coeff=0).validate()

    with pytest.raises(AssertionError):
        LimitsConfig(max_degree=0).validate()


def test_deterministic():
    """Two datasets built from the same config must be item-for-item equal."""
    config = LimitsConfig(seed=42, size=10)
    first = LimitsDataset(config)
    second = LimitsDataset(config)
    for i in range(len(first)):
        assert first[i] == second[i]


def test_item_structure():
    """Every generated item must carry the standard question/answer/metadata keys."""
    dataset = LimitsDataset(LimitsConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        assert item["metadata"]["source_dataset"] == "limits"


def test_answer_correctness():
    """The dataset's own reference answer must earn a full score."""
    dataset = LimitsDataset(LimitsConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        score = dataset.score_answer(item["answer"], item)
        assert score >= 1.0, f"Item {i}: oracle answer scored {score}"


def test_score_wrong_answer():
    """Missing or plainly wrong answers must score zero."""
    dataset = LimitsDataset(LimitsConfig(seed=42, size=10))
    item = dataset[0]
    assert dataset.score_answer(None, item) == 0.0
    assert dataset.score_answer("wrong", item) == 0.0


def test_curriculum():
    """Raising the max_coeff level must not shrink the max_coeff bound."""
    curriculum = LimitsCurriculum()
    base_value = {"size": 50, "seed": 1}
    initial = curriculum.generate_configuration(base_value)
    assert initial.seed == 1

    curriculum.increment_attr_level("max_coeff")
    bumped = curriculum.generate_configuration(base_value)
    assert bumped.max_coeff >= initial.max_coeff


def test_task_types():
    """Restricting the config to one task type must yield only that type."""
    for task_type in ("polynomial_cancel", "rational_infinity", "direct_sub", "squeeze"):
        config = LimitsConfig(
            seed=42, size=10, task_types=(task_type,), task_weights=[1.0]
        )
        dataset = LimitsDataset(config)
        for i in range(len(dataset)):
            item = dataset[i]
            assert item["metadata"]["task_type"] == task_type
            score = dataset.score_answer(item["answer"], item)
            assert score >= 1.0, f"Task {task_type}, item {i}: oracle scored {score}"
|
||||
89
tests/test_linear_algebra.py
Normal file
89
tests/test_linear_algebra.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Tests for the linear algebra dataset."""

import pytest

from reasoning_gym.algebra.linear_algebra import (
    LinearAlgebraConfig,
    LinearAlgebraCurriculum,
    LinearAlgebraDataset,
)


def test_config_validation():
    """Invalid configurations must fail validation with an AssertionError."""
    with pytest.raises(AssertionError):
        LinearAlgebraConfig(min_dim=1).validate()

    with pytest.raises(AssertionError):
        LinearAlgebraConfig(max_dim=5).validate()


def test_deterministic():
    """Two datasets built from the same config must be item-for-item equal."""
    config = LinearAlgebraConfig(seed=42, size=10)
    first = LinearAlgebraDataset(config)
    second = LinearAlgebraDataset(config)
    for i in range(len(first)):
        assert first[i] == second[i]


def test_item_structure():
    """Every generated item must carry the standard question/answer/metadata keys."""
    dataset = LinearAlgebraDataset(LinearAlgebraConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        assert item["metadata"]["source_dataset"] == "linear_algebra"


def test_answer_correctness():
    """The dataset's own reference answer must earn a full score."""
    dataset = LinearAlgebraDataset(LinearAlgebraConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        score = dataset.score_answer(item["answer"], item)
        assert score >= 1.0, f"Item {i}: oracle answer scored {score}"


def test_solve_system_verification():
    """solve_system items must accept their own reference solution."""
    config = LinearAlgebraConfig(
        seed=42, size=20, task_types=("solve_system",), task_weights=[1.0]
    )
    dataset = LinearAlgebraDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        assert dataset.score_answer(item["answer"], item) >= 1.0


def test_score_wrong_answer():
    """Missing or plainly wrong answers must score zero."""
    dataset = LinearAlgebraDataset(LinearAlgebraConfig(seed=42, size=10))
    item = dataset[0]
    assert dataset.score_answer(None, item) == 0.0
    assert dataset.score_answer("totally wrong", item) == 0.0


def test_curriculum():
    """Raising the max_dim level must not shrink the max_dim bound."""
    curriculum = LinearAlgebraCurriculum()
    base_value = {"size": 50, "seed": 1}
    initial = curriculum.generate_configuration(base_value)
    assert initial.seed == 1

    curriculum.increment_attr_level("max_dim")
    bumped = curriculum.generate_configuration(base_value)
    assert bumped.max_dim >= initial.max_dim


def test_task_types():
    """Restricting the config to one task type must yield only that type."""
    for task_type in ("matrix_multiply", "determinant", "inverse", "solve_system", "eigenvalues"):
        config = LinearAlgebraConfig(
            seed=42, size=10, task_types=(task_type,), task_weights=[1.0]
        )
        dataset = LinearAlgebraDataset(config)
        for i in range(len(dataset)):
            item = dataset[i]
            assert item["metadata"]["task_type"] == task_type
            score = dataset.score_answer(item["answer"], item)
            assert score >= 1.0, f"Task {task_type}, item {i}: oracle scored {score}"
|
||||
65
tests/test_linear_programming.py
Normal file
65
tests/test_linear_programming.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
"""Tests for the linear programming optimization dataset."""

import pytest

from reasoning_gym.optimization.linear_programming import (
    LinearProgrammingConfig,
    LinearProgrammingCurriculum,
    LinearProgrammingDataset,
)


def test_config_validation():
    """Invalid configurations must fail validation with an AssertionError."""
    with pytest.raises(AssertionError):
        LinearProgrammingConfig(min_coeff=0).validate()

    with pytest.raises(AssertionError):
        LinearProgrammingConfig(num_constraints=1).validate()


def test_deterministic():
    """Two datasets built from the same config must be item-for-item equal."""
    config = LinearProgrammingConfig(seed=42, size=10)
    first = LinearProgrammingDataset(config)
    second = LinearProgrammingDataset(config)
    for i in range(len(first)):
        assert first[i] == second[i]


def test_item_structure():
    """Every generated item must carry the standard question/answer/metadata keys."""
    dataset = LinearProgrammingDataset(LinearProgrammingConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        assert item["metadata"]["source_dataset"] == "linear_programming"


def test_answer_correctness():
    """The dataset's own reference answer must earn a full score."""
    dataset = LinearProgrammingDataset(LinearProgrammingConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        score = dataset.score_answer(item["answer"], item)
        assert score >= 1.0, f"Item {i}: oracle answer scored {score}"


def test_score_wrong_answer():
    """Missing or non-numeric answers must score zero."""
    dataset = LinearProgrammingDataset(LinearProgrammingConfig(seed=42, size=10))
    item = dataset[0]
    assert dataset.score_answer(None, item) == 0.0
    assert dataset.score_answer("not a number", item) == 0.0


def test_curriculum():
    """Raising the num_constraints level must not shrink the constraint count."""
    curriculum = LinearProgrammingCurriculum()
    base_value = {"size": 50, "seed": 1}
    initial = curriculum.generate_configuration(base_value)
    assert initial.seed == 1

    curriculum.increment_attr_level("num_constraints")
    bumped = curriculum.generate_configuration(base_value)
    assert bumped.num_constraints >= initial.num_constraints
|
||||
94
tests/test_number_theory.py
Normal file
94
tests/test_number_theory.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
"""Tests for the number theory arithmetic dataset."""

import pytest

from reasoning_gym.arithmetic.number_theory import (
    NumberTheoryConfig,
    NumberTheoryCurriculum,
    NumberTheoryDataset,
)


def test_config_validation():
    """Invalid configurations must fail validation with an AssertionError."""
    with pytest.raises(AssertionError):
        NumberTheoryConfig(min_value=1).validate()

    with pytest.raises(AssertionError):
        NumberTheoryConfig(min_value=50, max_value=10).validate()


def test_deterministic():
    """Two datasets built from the same config must be item-for-item equal."""
    config = NumberTheoryConfig(seed=42, size=10)
    first = NumberTheoryDataset(config)
    second = NumberTheoryDataset(config)
    for i in range(len(first)):
        assert first[i] == second[i]


def test_item_structure():
    """Every generated item must carry the standard question/answer/metadata keys."""
    dataset = NumberTheoryDataset(NumberTheoryConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        assert item["metadata"]["source_dataset"] == "number_theory"


def test_answer_correctness():
    """The dataset's own reference answer must earn a full score."""
    dataset = NumberTheoryDataset(NumberTheoryConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        score = dataset.score_answer(item["answer"], item)
        assert score >= 1.0, f"Item {i}: oracle answer scored {score}"


def test_diophantine_verification():
    """The reference (x, y) in a diophantine answer must satisfy a*x + b*y == c."""
    config = NumberTheoryConfig(
        seed=42, size=20, task_types=("diophantine",), task_weights=[1.0]
    )
    dataset = NumberTheoryDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        a = item["metadata"]["a"]
        b = item["metadata"]["b"]
        c = item["metadata"]["c"]
        # Answer has the form "x = <int>, y = <int>"; parse it into a mapping.
        solution = {
            name.strip(): int(value.strip())
            for name, value in (chunk.split("=") for chunk in item["answer"].split(","))
        }
        assert a * solution["x"] + b * solution["y"] == c


def test_score_wrong_answer():
    """A missing answer must score zero."""
    dataset = NumberTheoryDataset(NumberTheoryConfig(seed=42, size=10))
    item = dataset[0]
    assert dataset.score_answer(None, item) == 0.0


def test_curriculum():
    """Raising the value_range level must not shrink the max_value bound."""
    curriculum = NumberTheoryCurriculum()
    base_value = {"size": 50, "seed": 1}
    initial = curriculum.generate_configuration(base_value)
    assert initial.seed == 1

    curriculum.increment_attr_level("value_range")
    bumped = curriculum.generate_configuration(base_value)
    assert bumped.max_value >= initial.max_value


def test_task_types():
    """Restricting the config to one task type must yield only that type."""
    for task_type in ("mod_arith", "mod_exp", "totient", "crt", "mod_inverse", "diophantine"):
        config = NumberTheoryConfig(
            seed=42, size=10, task_types=(task_type,), task_weights=[1.0]
        )
        dataset = NumberTheoryDataset(config)
        for i in range(len(dataset)):
            item = dataset[i]
            assert item["metadata"]["task_type"] == task_type
            score = dataset.score_answer(item["answer"], item)
            assert score >= 1.0, f"Task {task_type}, item {i}: oracle scored {score}"
|
||||
77
tests/test_regex_puzzles.py
Normal file
77
tests/test_regex_puzzles.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Tests for the regex puzzles formal-languages dataset."""

import pytest

from reasoning_gym.languages.regex_puzzles import (
    RegexPuzzlesConfig,
    RegexPuzzlesCurriculum,
    RegexPuzzlesDataset,
)


def test_config_validation():
    """Invalid configurations must fail validation with an AssertionError."""
    with pytest.raises(AssertionError):
        RegexPuzzlesConfig(min_dfa_states=1).validate()

    with pytest.raises(AssertionError):
        RegexPuzzlesConfig(min_dfa_states=10, max_dfa_states=5).validate()


def test_deterministic():
    """Two datasets built from the same config must be item-for-item equal."""
    config = RegexPuzzlesConfig(seed=42, size=10)
    first = RegexPuzzlesDataset(config)
    second = RegexPuzzlesDataset(config)
    for i in range(len(first)):
        assert first[i] == second[i]


def test_item_structure():
    """Every generated item must carry the standard question/answer/metadata keys."""
    dataset = RegexPuzzlesDataset(RegexPuzzlesConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        assert item["metadata"]["source_dataset"] == "regex_puzzles"


def test_answer_correctness():
    """The dataset's own reference answer must earn a full score."""
    dataset = RegexPuzzlesDataset(RegexPuzzlesConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        score = dataset.score_answer(item["answer"], item)
        assert score >= 1.0, f"Item {i}: oracle answer scored {score}"


def test_score_wrong_answer():
    """A missing answer must score zero."""
    dataset = RegexPuzzlesDataset(RegexPuzzlesConfig(seed=42, size=10))
    item = dataset[0]
    assert dataset.score_answer(None, item) == 0.0


def test_curriculum():
    """Raising the max_dfa_states level must not shrink the state bound."""
    curriculum = RegexPuzzlesCurriculum()
    base_value = {"size": 50, "seed": 1}
    initial = curriculum.generate_configuration(base_value)
    assert initial.seed == 1

    curriculum.increment_attr_level("max_dfa_states")
    bumped = curriculum.generate_configuration(base_value)
    assert bumped.max_dfa_states >= initial.max_dfa_states


def test_task_types():
    """Restricting the config to one task type must yield only that type."""
    for task_type in ("string_generation", "extraction", "dfa_state", "dfa_prefix"):
        config = RegexPuzzlesConfig(
            seed=42, size=10, task_types=(task_type,), task_weights=[1.0]
        )
        dataset = RegexPuzzlesDataset(config)
        for i in range(len(dataset)):
            item = dataset[i]
            assert item["metadata"]["task_type"] == task_type
            score = dataset.score_answer(item["answer"], item)
            assert score >= 1.0, f"Task {task_type}, item {i}: oracle scored {score}"
|
||||
78
tests/test_set_operations.py
Normal file
78
tests/test_set_operations.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""Tests for the set operations logic dataset."""

import pytest

from reasoning_gym.logic.set_operations import (
    SetOperationsConfig,
    SetOperationsCurriculum,
    SetOperationsDataset,
)


def test_config_validation():
    """Invalid configurations must fail validation with an AssertionError."""
    with pytest.raises(AssertionError):
        SetOperationsConfig(min_set_size=0).validate()

    with pytest.raises(AssertionError):
        SetOperationsConfig(min_set_size=10, max_set_size=5).validate()


def test_deterministic():
    """Two datasets built from the same config must be item-for-item equal."""
    config = SetOperationsConfig(seed=42, size=10)
    first = SetOperationsDataset(config)
    second = SetOperationsDataset(config)
    for i in range(len(first)):
        assert first[i] == second[i]


def test_item_structure():
    """Every generated item must carry the standard question/answer/metadata keys."""
    dataset = SetOperationsDataset(SetOperationsConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item
        assert item["metadata"]["source_dataset"] == "set_operations"


def test_answer_correctness():
    """The dataset's own reference answer must earn a full score."""
    dataset = SetOperationsDataset(SetOperationsConfig(seed=42, size=50))
    for i in range(len(dataset)):
        item = dataset[i]
        score = dataset.score_answer(item["answer"], item)
        assert score >= 1.0, f"Item {i}: oracle answer scored {score}"


def test_score_wrong_answer():
    """Missing or incorrect set answers must score zero."""
    dataset = SetOperationsDataset(SetOperationsConfig(seed=42, size=10))
    item = dataset[0]
    assert dataset.score_answer(None, item) == 0.0
    assert dataset.score_answer("{999, 998, 997}", item) == 0.0


def test_curriculum():
    """Raising the set_size level must not shrink the max_set_size bound."""
    curriculum = SetOperationsCurriculum()
    base_value = {"size": 50, "seed": 1}
    initial = curriculum.generate_configuration(base_value)
    assert initial.seed == 1

    curriculum.increment_attr_level("set_size")
    bumped = curriculum.generate_configuration(base_value)
    assert bumped.max_set_size >= initial.max_set_size


def test_task_types():
    """Restricting the config to one task type must yield only that type."""
    for task_type in ("union", "intersection", "difference", "symmetric_difference", "cardinality", "power_set_size", "complement", "chained"):
        config = SetOperationsConfig(
            seed=42, size=10, task_types=(task_type,), task_weights=[1.0]
        )
        dataset = SetOperationsDataset(config)
        for i in range(len(dataset)):
            item = dataset[i]
            assert item["metadata"]["task_type"] == task_type
            score = dataset.score_answer(item["answer"], item)
            assert score >= 1.0, f"Task {task_type}, item {i}: oracle scored {score}"
|
||||
Loading…
Add table
Add a link
Reference in a new issue