mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: add scoring cascade for reducing false negatives (#526)
* feat: add scoring cascade for reducing false negatives in answer verification * style: fix black and isort formatting Run black and isort to satisfy pre-commit checks. Made-with: Cursor * docs: add scoring cascade example to Quickstart section Mention the experimental scoring cascade feature at the end of the Quickstart section with a disclaimer and complete usage examples showing both the dataset method and standalone function. Made-with: Cursor * docs: shorten scoring cascade section in README Trim to a concise standalone example per review feedback. Made-with: Cursor * docs: simplify scoring cascade description in README Made-with: Cursor * update readme --------- Co-authored-by: Zafir Stojanovski <zaf.stojano@gmail.com>
This commit is contained in:
parent
437e0b49c4
commit
49b07130b3
6 changed files with 477 additions and 0 deletions
|
|
@ -85,6 +85,14 @@ reasoning_gym.create_dataset('composite', size=10, seed=42, datasets=specs)
|
||||||
|
|
||||||
For the simplest way to get started training models with Reasoning Gym, we recommend using the `verifiers` library, which directly supports RG tasks. See `examples/verifiers` for details. However, RG data can be used with any major RL training framework.
|
For the simplest way to get started training models with Reasoning Gym, we recommend using the `verifiers` library, which directly supports RG tasks. See `examples/verifiers` for details. However, RG data can be used with any major RL training framework.
|
||||||
|
|
||||||
|
The *cascade scorer* applies progressively lenient fallback matchers — string, numeric, and symbolic math — to reduce false negatives from formatting differences (LaTeX wrappers, casing, numeric representation). Install with `pip install reasoning-gym[scoring]` for symbolic math verification.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from reasoning_gym import cascade_score
|
||||||
|
|
||||||
|
assert cascade_score(answer=r"\text{42}", expected="42") == 1.0
|
||||||
|
```
|
||||||
|
|
||||||
## 🔍 Evaluation
|
## 🔍 Evaluation
|
||||||
|
|
||||||
Instructions for running the evaluation scripts are provided in [eval/README.md](https://github.com/open-thought/reasoning-gym/blob/main/eval/README.md).
|
Instructions for running the evaluation scripts are provided in [eval/README.md](https://github.com/open-thought/reasoning-gym/blob/main/eval/README.md).
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,9 @@ cli = [
|
||||||
"pyyaml>=6.0.1",
|
"pyyaml>=6.0.1",
|
||||||
"httpx>=0.27.0",
|
"httpx>=0.27.0",
|
||||||
]
|
]
|
||||||
|
scoring = [
|
||||||
|
"math-verify>=0.7.0",
|
||||||
|
]
|
||||||
scripts = [
|
scripts = [
|
||||||
"datasets>=3.5.0"
|
"datasets>=3.5.0"
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from . import (
|
||||||
probability,
|
probability,
|
||||||
)
|
)
|
||||||
from .factory import create_dataset, get_score_answer_fn, register_dataset
|
from .factory import create_dataset, get_score_answer_fn, register_dataset
|
||||||
|
from .scoring import cascade_score, float_match, math_match, string_match, strip_latex
|
||||||
|
|
||||||
__version__ = "0.1.19"
|
__version__ = "0.1.19"
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -37,4 +38,9 @@ __all__ = [
|
||||||
"create_dataset",
|
"create_dataset",
|
||||||
"register_dataset",
|
"register_dataset",
|
||||||
"get_score_answer_fn",
|
"get_score_answer_fn",
|
||||||
|
"cascade_score",
|
||||||
|
"strip_latex",
|
||||||
|
"string_match",
|
||||||
|
"float_match",
|
||||||
|
"math_match",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,21 @@ class ProceduralDataset(ABC, Sized, Iterable[dict[str, Any]]):
|
||||||
reward = len(oracle_answer) / len(answer)
|
reward = len(oracle_answer) / len(answer)
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
|
def score_answer_cascade(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||||
|
"""Score with fallback cascade (LaTeX stripping, string, float, math matching).
|
||||||
|
|
||||||
|
Runs this dataset's ``score_answer`` first, then progressively more
|
||||||
|
lenient matchers. The cascade can only upgrade, never downgrade.
|
||||||
|
|
||||||
|
Requires ``pip install reasoning-gym[scoring]`` for the ``math_match``
|
||||||
|
step (other steps work without extra dependencies).
|
||||||
|
"""
|
||||||
|
from .scoring import cascade_score
|
||||||
|
|
||||||
|
if answer is None:
|
||||||
|
return 0.0
|
||||||
|
return cascade_score(answer, entry.get("answer", ""), dataset=self, entry=entry)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="ProceduralDataset")
|
T = TypeVar("T", bound="ProceduralDataset")
|
||||||
|
|
||||||
|
|
@ -127,3 +142,7 @@ class ReseedingDataset(Iterable[dict[str, Any]]):
|
||||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||||
"""Forward scoring to the wrapped dataset's implementation"""
|
"""Forward scoring to the wrapped dataset's implementation"""
|
||||||
return self.dataset.score_answer(answer, entry)
|
return self.dataset.score_answer(answer, entry)
|
||||||
|
|
||||||
|
def score_answer_cascade(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||||
|
"""Forward cascade scoring to the wrapped dataset's implementation"""
|
||||||
|
return self.dataset.score_answer_cascade(answer, entry)
|
||||||
|
|
|
||||||
190
reasoning_gym/scoring.py
Normal file
190
reasoning_gym/scoring.py
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
"""Scoring cascade utilities for reducing false negatives in answer verification.
|
||||||
|
|
||||||
|
Provides a multi-step fallback cascade that wraps any dataset's ``score_answer``
|
||||||
|
with progressively more lenient matchers:
|
||||||
|
|
||||||
|
1. ``score_answer()`` -- environment's built-in verifier
|
||||||
|
1b. ``score_answer()`` -- retry after stripping LaTeX wrappers
|
||||||
|
2. ``string_match`` -- case-insensitive exact comparison
|
||||||
|
3. ``float_match`` -- numeric comparison with tolerance
|
||||||
|
4. ``math_match`` -- symbolic math via *math-verify*
|
||||||
|
|
||||||
|
The cascade can only *upgrade* a score, never downgrade it.
|
||||||
|
|
||||||
|
``math_match`` requires the optional ``math-verify`` package. When it is not
|
||||||
|
installed the step is silently skipped (returns 0.0). Install via::
|
||||||
|
|
||||||
|
pip install reasoning-gym[scoring]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LaTeX normalisation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def strip_latex(s: str) -> str:
|
||||||
|
"""Remove common LaTeX wrappers and normalise whitespace.
|
||||||
|
|
||||||
|
Handles ``\\(…\\)``, ``\\text{}``, ``\\mathrm{}``, double-backslash
|
||||||
|
linebreaks, tildes, and stray backslashes.
|
||||||
|
"""
|
||||||
|
s = re.sub(r"^\\\((.*)\\\)$", r"\1", s.strip())
|
||||||
|
s = re.sub(r"\\\((.*?)\\\)", r"\1", s)
|
||||||
|
s = re.sub(r"\\(?:text|mathrm)\{([^}]*)\}", r"\1", s)
|
||||||
|
s = re.sub(r"\\\\+", " ", s)
|
||||||
|
s = re.sub(r"~", " ", s)
|
||||||
|
s = re.sub(r"\\", "", s)
|
||||||
|
return re.sub(r"\s+", " ", s).strip()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Individual matchers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def string_match(predicted: str, expected: str) -> float:
|
||||||
|
"""Case-insensitive exact string comparison after stripping whitespace."""
|
||||||
|
try:
|
||||||
|
return 1.0 if predicted.lower().strip() == expected.lower().strip() else 0.0
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def float_match(
|
||||||
|
predicted: str,
|
||||||
|
expected: str,
|
||||||
|
rel_tol: float = 0.01,
|
||||||
|
abs_tol: float = 0.01,
|
||||||
|
) -> float:
|
||||||
|
"""Numeric comparison with configurable tolerance.
|
||||||
|
|
||||||
|
Accepts if ``|a - b| <= max(rel_tol * max(|a|, |b|), abs_tol)``.
|
||||||
|
Returns 0.0 for non-numeric strings.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
a = float(predicted)
|
||||||
|
b = float(expected)
|
||||||
|
return 1.0 if abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) else 0.0
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def math_match(predicted: str, expected: str) -> float:
|
||||||
|
"""Symbolic math verification via *math-verify*, with numeric fallback.
|
||||||
|
|
||||||
|
Strips dollar signs and common display-math delimiters before parsing.
|
||||||
|
Falls back to :func:`float_match` on the parsed numeric values when
|
||||||
|
symbolic ``verify`` returns ``False``.
|
||||||
|
|
||||||
|
Returns 0.0 when ``math-verify`` is not installed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from math_verify import parse, verify
|
||||||
|
except ImportError:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
a = expected.strip("$")
|
||||||
|
b = predicted.strip("$")
|
||||||
|
|
||||||
|
for delim_open, delim_close in [(r"\[", r"\]"), (r"\(", r"\)"), (r"\,", r"\,")]:
|
||||||
|
if a.startswith(delim_open) and a.endswith(delim_close):
|
||||||
|
a = a[2:-2].strip()
|
||||||
|
if b.startswith(delim_open) and b.endswith(delim_close):
|
||||||
|
b = b[2:-2].strip()
|
||||||
|
|
||||||
|
pa = parse(f"${a}$")
|
||||||
|
pb = parse(f"${b}$")
|
||||||
|
|
||||||
|
if verify(pa, pb):
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
# Numeric fallback on the first parsed element
|
||||||
|
try:
|
||||||
|
va, vb = float(pa[0]), float(pb[0])
|
||||||
|
return 1.0 if abs(va - vb) <= max(0.01 * max(abs(va), abs(vb)), 0.01) else 0.0
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _mathrm_to_text(s: str) -> str:
|
||||||
|
r"""Replace ``\mathrm{…}`` with ``\text{…}`` for a second math_match attempt."""
|
||||||
|
return re.sub(r"\\mathrm\{([^}]*)\}", r"\\text{\1}", s)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Full cascade
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def cascade_score(
|
||||||
|
answer: str,
|
||||||
|
expected: str,
|
||||||
|
dataset: Optional["ProceduralDataset"] = None,
|
||||||
|
entry: Optional[dict[str, Any]] = None,
|
||||||
|
) -> float:
|
||||||
|
"""Apply the multi-step scoring cascade.
|
||||||
|
|
||||||
|
When *dataset* and *entry* are supplied the environment's own
|
||||||
|
``score_answer`` is tried first (steps 1 & 1b). The remaining steps
|
||||||
|
use only the raw answer strings and never require a dataset instance.
|
||||||
|
|
||||||
|
The cascade can only upgrade — if an earlier step already returned
|
||||||
|
a near-perfect score (>= 0.99) it is returned immediately.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
answer: The model's predicted answer string.
|
||||||
|
expected: The gold / oracle answer string.
|
||||||
|
dataset: Optional :class:`ProceduralDataset` whose ``score_answer``
|
||||||
|
should be tried first.
|
||||||
|
entry: The dataset entry dict (must contain at least ``"answer"``).
|
||||||
|
Required when *dataset* is provided.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A score in ``[0.0, 1.0]``.
|
||||||
|
"""
|
||||||
|
best = 0.0
|
||||||
|
|
||||||
|
# Step 1: environment's built-in verifier
|
||||||
|
if dataset is not None and entry is not None:
|
||||||
|
try:
|
||||||
|
score = float(dataset.score_answer(answer, entry))
|
||||||
|
if score >= 0.99:
|
||||||
|
return score
|
||||||
|
best = max(best, score)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Step 1b: retry after stripping LaTeX
|
||||||
|
cleaned = strip_latex(answer)
|
||||||
|
if cleaned != answer:
|
||||||
|
try:
|
||||||
|
score = float(dataset.score_answer(cleaned, entry))
|
||||||
|
if score >= 0.99:
|
||||||
|
return score
|
||||||
|
best = max(best, score)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Steps 2-5: string / float / math cascade
|
||||||
|
for score in (
|
||||||
|
string_match(answer, expected),
|
||||||
|
string_match(strip_latex(answer), strip_latex(expected)),
|
||||||
|
float_match(answer, expected),
|
||||||
|
math_match(answer, expected),
|
||||||
|
math_match(_mathrm_to_text(answer), _mathrm_to_text(expected)),
|
||||||
|
):
|
||||||
|
if score >= 0.99:
|
||||||
|
return score
|
||||||
|
best = max(best, score)
|
||||||
|
|
||||||
|
return best
|
||||||
251
tests/test_scoring.py
Normal file
251
tests/test_scoring.py
Normal file
|
|
@ -0,0 +1,251 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import reasoning_gym
|
||||||
|
from reasoning_gym.scoring import _mathrm_to_text, cascade_score, float_match, math_match, string_match, strip_latex
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# strip_latex
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStripLatex:
|
||||||
|
def test_inline_math_delimiters(self):
|
||||||
|
assert strip_latex(r"\(42\)") == "42"
|
||||||
|
|
||||||
|
def test_inline_math_mid_string(self):
|
||||||
|
assert strip_latex(r"the value is \(x + 1\) here") == "the value is x + 1 here"
|
||||||
|
|
||||||
|
def test_text_command(self):
|
||||||
|
assert strip_latex(r"\text{hello world}") == "hello world"
|
||||||
|
|
||||||
|
def test_mathrm_command(self):
|
||||||
|
assert strip_latex(r"\mathrm{cm}") == "cm"
|
||||||
|
|
||||||
|
def test_double_backslash(self):
|
||||||
|
assert strip_latex(r"a \\ b") == "a b"
|
||||||
|
|
||||||
|
def test_tilde(self):
|
||||||
|
assert strip_latex("a~b") == "a b"
|
||||||
|
|
||||||
|
def test_stray_backslashes(self):
|
||||||
|
assert strip_latex(r"\alpha + \beta") == "alpha + beta"
|
||||||
|
|
||||||
|
def test_whitespace_normalisation(self):
|
||||||
|
assert strip_latex(" a b ") == "a b"
|
||||||
|
|
||||||
|
def test_combined(self):
|
||||||
|
assert strip_latex(r"\(\text{answer}\)") == "answer"
|
||||||
|
|
||||||
|
def test_plain_string_unchanged(self):
|
||||||
|
assert strip_latex("42") == "42"
|
||||||
|
assert strip_latex("hello") == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# string_match
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStringMatch:
|
||||||
|
def test_exact(self):
|
||||||
|
assert string_match("42", "42") == 1.0
|
||||||
|
|
||||||
|
def test_case_insensitive(self):
|
||||||
|
assert string_match("Hello", "hello") == 1.0
|
||||||
|
assert string_match("TRUE", "true") == 1.0
|
||||||
|
|
||||||
|
def test_whitespace_stripped(self):
|
||||||
|
assert string_match(" 42 ", "42") == 1.0
|
||||||
|
|
||||||
|
def test_mismatch(self):
|
||||||
|
assert string_match("42", "43") == 0.0
|
||||||
|
|
||||||
|
def test_empty_strings(self):
|
||||||
|
assert string_match("", "") == 1.0
|
||||||
|
|
||||||
|
def test_non_string_graceful(self):
|
||||||
|
assert string_match(None, "42") == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# float_match
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFloatMatch:
|
||||||
|
def test_exact(self):
|
||||||
|
assert float_match("3.14", "3.14") == 1.0
|
||||||
|
|
||||||
|
def test_within_tolerance(self):
|
||||||
|
assert float_match("100", "100.5") == 1.0
|
||||||
|
|
||||||
|
def test_outside_tolerance(self):
|
||||||
|
assert float_match("100", "102") == 0.0
|
||||||
|
|
||||||
|
def test_zero_tolerance(self):
|
||||||
|
assert float_match("0", "0.005") == 1.0
|
||||||
|
|
||||||
|
def test_negative(self):
|
||||||
|
assert float_match("-5.0", "-5.0") == 1.0
|
||||||
|
assert float_match("-5.0", "5.0") == 0.0
|
||||||
|
|
||||||
|
def test_non_numeric(self):
|
||||||
|
assert float_match("abc", "42") == 0.0
|
||||||
|
assert float_match("42", "abc") == 0.0
|
||||||
|
|
||||||
|
def test_custom_tolerance(self):
|
||||||
|
assert float_match("100", "110", rel_tol=0.15) == 1.0
|
||||||
|
assert float_match("100", "110", rel_tol=0.05) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# math_match
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMathMatch:
|
||||||
|
def test_returns_zero_without_math_verify(self, monkeypatch):
|
||||||
|
"""When math-verify is not importable, math_match should return 0."""
|
||||||
|
import builtins
|
||||||
|
|
||||||
|
real_import = builtins.__import__
|
||||||
|
|
||||||
|
def mock_import(name, *args, **kwargs):
|
||||||
|
if name == "math_verify":
|
||||||
|
raise ImportError("mocked")
|
||||||
|
return real_import(name, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(builtins, "__import__", mock_import)
|
||||||
|
assert math_match("42", "42") == 0.0
|
||||||
|
|
||||||
|
def test_dollar_sign_stripping(self):
|
||||||
|
result = math_match("$42$", "$42$")
|
||||||
|
assert result >= 0.0 # at least doesn't crash
|
||||||
|
|
||||||
|
def test_display_math_delimiters(self):
|
||||||
|
result = math_match(r"\[42\]", r"\[42\]")
|
||||||
|
assert result >= 0.0
|
||||||
|
|
||||||
|
def test_non_parseable_returns_zero(self):
|
||||||
|
assert math_match("not math at all ???", "also not math ???") == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _mathrm_to_text helper
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMathrmToText:
|
||||||
|
def test_replaces_mathrm(self):
|
||||||
|
assert _mathrm_to_text(r"\mathrm{cm}") == r"\text{cm}"
|
||||||
|
|
||||||
|
def test_no_mathrm_unchanged(self):
|
||||||
|
assert _mathrm_to_text("42") == "42"
|
||||||
|
|
||||||
|
def test_multiple_occurrences(self):
|
||||||
|
s = r"\mathrm{kg} \cdot \mathrm{m}"
|
||||||
|
assert _mathrm_to_text(s) == r"\text{kg} \cdot \text{m}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# cascade_score — without dataset
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCascadeScoreStandalone:
|
||||||
|
def test_exact_string(self):
|
||||||
|
assert cascade_score("42", "42") >= 0.99
|
||||||
|
|
||||||
|
def test_case_insensitive_string(self):
|
||||||
|
assert cascade_score("True", "true") >= 0.99
|
||||||
|
|
||||||
|
def test_latex_wrapped_string(self):
|
||||||
|
assert cascade_score(r"\text{42}", "42") >= 0.99
|
||||||
|
|
||||||
|
def test_numeric_tolerance(self):
|
||||||
|
assert cascade_score("100.05", "100") >= 0.99
|
||||||
|
|
||||||
|
def test_mismatch(self):
|
||||||
|
assert cascade_score("42", "99") == 0.0
|
||||||
|
|
||||||
|
def test_empty_answer(self):
|
||||||
|
assert cascade_score("", "42") == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# cascade_score — with a real dataset
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCascadeScoreWithDataset:
|
||||||
|
def test_chain_sum_exact(self):
|
||||||
|
ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
|
||||||
|
entry = ds[0]
|
||||||
|
score = cascade_score(entry["answer"], entry["answer"], dataset=ds, entry=entry)
|
||||||
|
assert score == 1.0
|
||||||
|
|
||||||
|
def test_chain_sum_latex_wrapped(self):
|
||||||
|
ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
|
||||||
|
entry = ds[0]
|
||||||
|
wrapped = rf"\text{{{entry['answer']}}}"
|
||||||
|
score = cascade_score(wrapped, entry["answer"], dataset=ds, entry=entry)
|
||||||
|
assert score >= 0.99
|
||||||
|
|
||||||
|
def test_never_downgrades(self):
|
||||||
|
"""The cascade should never return less than score_answer itself."""
|
||||||
|
ds = reasoning_gym.create_dataset("chain_sum", size=10, seed=123)
|
||||||
|
for entry in ds:
|
||||||
|
base = ds.score_answer(entry["answer"], entry)
|
||||||
|
cascaded = cascade_score(entry["answer"], entry["answer"], dataset=ds, entry=entry)
|
||||||
|
assert cascaded >= base
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ProceduralDataset.score_answer_cascade convenience method
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoreAnswerCascadeMethod:
|
||||||
|
def test_method_exists(self):
|
||||||
|
ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0)
|
||||||
|
assert hasattr(ds, "score_answer_cascade")
|
||||||
|
|
||||||
|
def test_oracle_answer_scores_one(self):
|
||||||
|
ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
|
||||||
|
for entry in ds:
|
||||||
|
assert ds.score_answer_cascade(entry["answer"], entry) == 1.0
|
||||||
|
|
||||||
|
def test_none_answer_scores_zero(self):
|
||||||
|
ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0)
|
||||||
|
entry = ds[0]
|
||||||
|
assert ds.score_answer_cascade(None, entry) == 0.0
|
||||||
|
|
||||||
|
def test_latex_wrapped_answer(self):
|
||||||
|
ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
|
||||||
|
entry = ds[0]
|
||||||
|
wrapped = rf"\({entry['answer']}\)"
|
||||||
|
assert ds.score_answer_cascade(wrapped, entry) >= 0.99
|
||||||
|
|
||||||
|
def test_never_less_than_score_answer(self):
|
||||||
|
ds = reasoning_gym.create_dataset("chain_sum", size=10, seed=99)
|
||||||
|
for entry in ds:
|
||||||
|
base = ds.score_answer(entry["answer"], entry)
|
||||||
|
cascaded = ds.score_answer_cascade(entry["answer"], entry)
|
||||||
|
assert cascaded >= base
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Top-level imports
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTopLevelImports:
|
||||||
|
def test_cascade_score_importable(self):
|
||||||
|
from reasoning_gym import cascade_score as cs
|
||||||
|
|
||||||
|
assert callable(cs)
|
||||||
|
|
||||||
|
def test_matchers_importable(self):
|
||||||
|
from reasoning_gym import float_match, math_match, string_match, strip_latex
|
||||||
|
|
||||||
|
assert all(callable(f) for f in [string_match, float_match, math_match, strip_latex])
|
||||||
Loading…
Add table
Add a link
Reference in a new issue