feat: add scoring cascade for reducing false negatives (#526)

* feat: add scoring cascade for reducing false negatives in answer verification

* style: fix black and isort formatting

Run black and isort to satisfy pre-commit checks.

Made-with: Cursor

* docs: add scoring cascade example to Quickstart section

Mention the experimental scoring cascade feature at the end of the
Quickstart section with a disclaimer and complete usage examples
showing both the dataset method and standalone function.

Made-with: Cursor

* docs: shorten scoring cascade section in README

Trim to a concise standalone example per review feedback.

Made-with: Cursor

* docs: simplify scoring cascade description in README

Made-with: Cursor

* update readme

---------

Co-authored-by: Zafir Stojanovski <zaf.stojano@gmail.com>
This commit is contained in:
Ritvik Rastogi 2026-04-18 01:09:15 +05:30 committed by GitHub
parent 437e0b49c4
commit 49b07130b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 477 additions and 0 deletions

View file

@ -85,6 +85,14 @@ reasoning_gym.create_dataset('composite', size=10, seed=42, datasets=specs)
For the simplest way to get started training models with Reasoning Gym, we recommend using the `verifiers` library, which directly supports RG tasks. See `examples/verifiers` for details. However, RG data can be used with any major RL training framework.
The *cascade scorer* applies progressively lenient fallback matchers — string, numeric, and symbolic math — to reduce false negatives from formatting differences (LaTeX wrappers, casing, numeric representation). Install with `pip install reasoning-gym[scoring]` for symbolic math verification.
```python
from reasoning_gym import cascade_score
assert cascade_score(answer=r"\text{42}", expected="42") == 1.0
```
## 🔍 Evaluation

Instructions for running the evaluation scripts are provided in [eval/README.md](https://github.com/open-thought/reasoning-gym/blob/main/eval/README.md).

View file

@ -49,6 +49,9 @@ cli = [
    "pyyaml>=6.0.1",
    "httpx>=0.27.0",
]
scoring = [
    "math-verify>=0.7.0",
]
scripts = [
    "datasets>=3.5.0"
]

View file

@ -18,6 +18,7 @@ from . import (
probability, probability,
) )
from .factory import create_dataset, get_score_answer_fn, register_dataset from .factory import create_dataset, get_score_answer_fn, register_dataset
from .scoring import cascade_score, float_match, math_match, string_match, strip_latex
__version__ = "0.1.19" __version__ = "0.1.19"
__all__ = [ __all__ = [
@ -37,4 +38,9 @@ __all__ = [
"create_dataset", "create_dataset",
"register_dataset", "register_dataset",
"get_score_answer_fn", "get_score_answer_fn",
"cascade_score",
"strip_latex",
"string_match",
"float_match",
"math_match",
] ]

View file

@ -71,6 +71,21 @@ class ProceduralDataset(ABC, Sized, Iterable[dict[str, Any]]):
reward = len(oracle_answer) / len(answer) reward = len(oracle_answer) / len(answer)
return reward return reward
def score_answer_cascade(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score with fallback cascade (LaTeX stripping, string, float, math matching).
Runs this dataset's ``score_answer`` first, then progressively more
lenient matchers. The cascade can only upgrade, never downgrade.
Requires ``pip install reasoning-gym[scoring]`` for the ``math_match``
step (other steps work without extra dependencies).
"""
from .scoring import cascade_score
if answer is None:
return 0.0
return cascade_score(answer, entry.get("answer", ""), dataset=self, entry=entry)
T = TypeVar("T", bound="ProceduralDataset") T = TypeVar("T", bound="ProceduralDataset")
@ -127,3 +142,7 @@ class ReseedingDataset(Iterable[dict[str, Any]]):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Forward scoring to the wrapped dataset's implementation""" """Forward scoring to the wrapped dataset's implementation"""
return self.dataset.score_answer(answer, entry) return self.dataset.score_answer(answer, entry)
def score_answer_cascade(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Forward cascade scoring to the wrapped dataset's implementation"""
return self.dataset.score_answer_cascade(answer, entry)

190
reasoning_gym/scoring.py Normal file
View file

@ -0,0 +1,190 @@
"""Scoring cascade utilities for reducing false negatives in answer verification.
Provides a multi-step fallback cascade that wraps any dataset's ``score_answer``
with progressively more lenient matchers:
1. ``score_answer()`` -- environment's built-in verifier
1b. ``score_answer()`` -- retry after stripping LaTeX wrappers
2. ``string_match`` -- case-insensitive exact comparison
3. ``float_match`` -- numeric comparison with tolerance
4. ``math_match`` -- symbolic math via *math-verify*
The cascade can only *upgrade* a score, never downgrade it.
``math_match`` requires the optional ``math-verify`` package. When it is not
installed the step is silently skipped (returns 0.0). Install via::
pip install reasoning-gym[scoring]
"""
import re
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from .dataset import ProceduralDataset
# ---------------------------------------------------------------------------
# LaTeX normalisation
# ---------------------------------------------------------------------------
def strip_latex(s: str) -> str:
    """Remove common LaTeX wrappers and normalise whitespace.

    Handles ``\\(\\)`` delimiters, ``\\text{}`` / ``\\mathrm{}`` commands,
    double-backslash linebreaks, tildes, and stray backslashes.
    """
    out = s.strip()
    out = re.sub(r"^\\\((.*)\\\)$", r"\1", out)  # whole-string \( ... \)
    out = re.sub(r"\\\((.*?)\\\)", r"\1", out)  # inline \( ... \)
    out = re.sub(r"\\(?:text|mathrm)\{([^}]*)\}", r"\1", out)  # unwrap text/mathrm
    out = re.sub(r"\\\\+", " ", out)  # \\ linebreaks become spaces
    out = out.replace("~", " ")  # ties become spaces
    out = out.replace("\\", "")  # drop any remaining backslashes
    return re.sub(r"\s+", " ", out).strip()
# ---------------------------------------------------------------------------
# Individual matchers
# ---------------------------------------------------------------------------
def string_match(predicted: str, expected: str) -> float:
    """Case-insensitive exact string comparison after stripping whitespace."""
    try:
        same = predicted.strip().lower() == expected.strip().lower()
    except Exception:
        # Non-string inputs (e.g. None) simply fail to match.
        return 0.0
    return 1.0 if same else 0.0
def float_match(
    predicted: str,
    expected: str,
    rel_tol: float = 0.01,
    abs_tol: float = 0.01,
) -> float:
    """Numeric comparison with configurable tolerance.

    Accepts if ``|a - b| <= max(rel_tol * max(|a|, |b|), abs_tol)``.
    Returns 0.0 for non-numeric strings.
    """
    try:
        a, b = float(predicted), float(expected)
    except Exception:
        return 0.0
    tolerance = max(rel_tol * max(abs(a), abs(b)), abs_tol)
    return 1.0 if abs(a - b) <= tolerance else 0.0
def math_match(predicted: str, expected: str) -> float:
    """Symbolic math verification via *math-verify*, with numeric fallback.

    Strips dollar signs and common display-math delimiters before parsing.
    Falls back to a numeric comparison on the parsed values when symbolic
    ``verify`` returns ``False``.

    Returns 0.0 when ``math-verify`` is not installed.
    """
    try:
        from math_verify import parse, verify
    except ImportError:
        # Optional dependency missing: silently skip this cascade step.
        return 0.0

    def _unwrap(expr: str) -> str:
        # Drop surrounding $...$ and display-math delimiter pairs.
        expr = expr.strip("$")
        for left, right in ((r"\[", r"\]"), (r"\(", r"\)"), (r"\,", r"\,")):
            if expr.startswith(left) and expr.endswith(right):
                expr = expr[2:-2].strip()
        return expr

    try:
        a = _unwrap(expected)
        b = _unwrap(predicted)
        parsed_a = parse(f"${a}$")
        parsed_b = parse(f"${b}$")
        if verify(parsed_a, parsed_b):
            return 1.0
        # Numeric fallback on the first parsed element
        try:
            va = float(parsed_a[0])
            vb = float(parsed_b[0])
        except Exception:
            return 0.0
        return 1.0 if abs(va - vb) <= max(0.01 * max(abs(va), abs(vb)), 0.01) else 0.0
    except Exception:
        return 0.0
def _mathrm_to_text(s: str) -> str:
r"""Replace ``\mathrm{…}`` with ``\text{…}`` for a second math_match attempt."""
return re.sub(r"\\mathrm\{([^}]*)\}", r"\\text{\1}", s)
# ---------------------------------------------------------------------------
# Full cascade
# ---------------------------------------------------------------------------
def cascade_score(
    answer: str,
    expected: str,
    dataset: Optional["ProceduralDataset"] = None,
    entry: Optional[dict[str, Any]] = None,
) -> float:
    """Apply the multi-step scoring cascade.

    When *dataset* and *entry* are supplied the environment's own
    ``score_answer`` is tried first (steps 1 & 1b). The remaining steps
    use only the raw answer strings and never require a dataset instance.

    The cascade can only upgrade a score, never downgrade it; if any step
    returns a near-perfect score (>= 0.99) it is returned immediately.

    Args:
        answer: The model's predicted answer string.
        expected: The gold / oracle answer string.
        dataset: Optional :class:`ProceduralDataset` whose ``score_answer``
            should be tried first.
        entry: The dataset entry dict (must contain at least ``"answer"``).
            Required when *dataset* is provided.

    Returns:
        A score in ``[0.0, 1.0]``.
    """
    best = 0.0

    # Steps 1 & 1b: environment's built-in verifier, first on the raw answer
    # and then on the LaTeX-stripped answer (only when it actually differs).
    # Guarding on dataset/entry explicitly avoids relying on an AttributeError
    # being swallowed when no dataset was supplied.
    if dataset is not None and entry is not None:
        candidates = [answer]
        cleaned = strip_latex(answer)
        if cleaned != answer:
            candidates.append(cleaned)
        for candidate in candidates:
            try:
                score = float(dataset.score_answer(candidate, entry))
            except Exception:
                # A failing verifier must never block the fallback matchers.
                continue
            if score >= 0.99:
                return score
            best = max(best, score)

    # Steps 2-5: progressively more lenient string / float / math matchers.
    for score in (
        string_match(answer, expected),
        string_match(strip_latex(answer), strip_latex(expected)),
        float_match(answer, expected),
        math_match(answer, expected),
        math_match(_mathrm_to_text(answer), _mathrm_to_text(expected)),
    ):
        if score >= 0.99:
            return score
        best = max(best, score)
    return best

251
tests/test_scoring.py Normal file
View file

@ -0,0 +1,251 @@
import pytest
import reasoning_gym
from reasoning_gym.scoring import _mathrm_to_text, cascade_score, float_match, math_match, string_match, strip_latex
# ---------------------------------------------------------------------------
# strip_latex
# ---------------------------------------------------------------------------
class TestStripLatex:
    """Unit tests for the LaTeX-normalisation helper."""

    def test_inline_math_delimiters(self):
        result = strip_latex(r"\(42\)")
        assert result == "42"

    def test_inline_math_mid_string(self):
        result = strip_latex(r"the value is \(x + 1\) here")
        assert result == "the value is x + 1 here"

    def test_text_command(self):
        result = strip_latex(r"\text{hello world}")
        assert result == "hello world"

    def test_mathrm_command(self):
        result = strip_latex(r"\mathrm{cm}")
        assert result == "cm"

    def test_double_backslash(self):
        result = strip_latex(r"a \\ b")
        assert result == "a b"

    def test_tilde(self):
        result = strip_latex("a~b")
        assert result == "a b"

    def test_stray_backslashes(self):
        result = strip_latex(r"\alpha + \beta")
        assert result == "alpha + beta"

    def test_whitespace_normalisation(self):
        result = strip_latex("  a   b  ")
        assert result == "a b"

    def test_combined(self):
        result = strip_latex(r"\(\text{answer}\)")
        assert result == "answer"

    def test_plain_string_unchanged(self):
        assert strip_latex("42") == "42"
        assert strip_latex("hello") == "hello"
# ---------------------------------------------------------------------------
# string_match
# ---------------------------------------------------------------------------
class TestStringMatch:
    """Unit tests for case-insensitive string matching."""

    def test_exact(self):
        assert string_match("42", "42") == 1.0

    def test_case_insensitive(self):
        assert string_match("Hello", "hello") == 1.0
        assert string_match("TRUE", "true") == 1.0

    def test_whitespace_stripped(self):
        result = string_match(" 42 ", "42")
        assert result == 1.0

    def test_mismatch(self):
        result = string_match("42", "43")
        assert result == 0.0

    def test_empty_strings(self):
        result = string_match("", "")
        assert result == 1.0

    def test_non_string_graceful(self):
        # A None prediction must degrade to 0.0 rather than raise.
        assert string_match(None, "42") == 0.0
# ---------------------------------------------------------------------------
# float_match
# ---------------------------------------------------------------------------
class TestFloatMatch:
    """Unit tests for tolerance-based numeric matching."""

    def test_exact(self):
        assert float_match("3.14", "3.14") == 1.0

    def test_within_tolerance(self):
        result = float_match("100", "100.5")
        assert result == 1.0

    def test_outside_tolerance(self):
        result = float_match("100", "102")
        assert result == 0.0

    def test_zero_tolerance(self):
        # Near zero the absolute tolerance dominates.
        assert float_match("0", "0.005") == 1.0

    def test_negative(self):
        assert float_match("-5.0", "-5.0") == 1.0
        assert float_match("-5.0", "5.0") == 0.0

    def test_non_numeric(self):
        assert float_match("abc", "42") == 0.0
        assert float_match("42", "abc") == 0.0

    def test_custom_tolerance(self):
        assert float_match("100", "110", rel_tol=0.15) == 1.0
        assert float_match("100", "110", rel_tol=0.05) == 0.0
# ---------------------------------------------------------------------------
# math_match
# ---------------------------------------------------------------------------
class TestMathMatch:
    """Unit tests for symbolic math matching (math-verify is optional)."""

    def test_returns_zero_without_math_verify(self, monkeypatch):
        """When math-verify is not importable, math_match should return 0."""
        import builtins

        original_import = builtins.__import__

        def fake_import(name, *args, **kwargs):
            if name == "math_verify":
                raise ImportError("mocked")
            return original_import(name, *args, **kwargs)

        monkeypatch.setattr(builtins, "__import__", fake_import)
        assert math_match("42", "42") == 0.0

    def test_dollar_sign_stripping(self):
        result = math_match("$42$", "$42$")
        assert result >= 0.0  # at least doesn't crash

    def test_display_math_delimiters(self):
        result = math_match(r"\[42\]", r"\[42\]")
        assert result >= 0.0

    def test_non_parseable_returns_zero(self):
        result = math_match("not math at all ???", "also not math ???")
        assert result == 0.0
# ---------------------------------------------------------------------------
# _mathrm_to_text helper
# ---------------------------------------------------------------------------
class TestMathrmToText:
    """Unit tests for the \\mathrm -> \\text rewrite helper."""

    def test_replaces_mathrm(self):
        result = _mathrm_to_text(r"\mathrm{cm}")
        assert result == r"\text{cm}"

    def test_no_mathrm_unchanged(self):
        result = _mathrm_to_text("42")
        assert result == "42"

    def test_multiple_occurrences(self):
        s = r"\mathrm{kg} \cdot \mathrm{m}"
        assert _mathrm_to_text(s) == r"\text{kg} \cdot \text{m}"
# ---------------------------------------------------------------------------
# cascade_score — without dataset
# ---------------------------------------------------------------------------
class TestCascadeScoreStandalone:
    """Cascade scoring without any dataset involved."""

    def test_exact_string(self):
        result = cascade_score("42", "42")
        assert result >= 0.99

    def test_case_insensitive_string(self):
        result = cascade_score("True", "true")
        assert result >= 0.99

    def test_latex_wrapped_string(self):
        result = cascade_score(r"\text{42}", "42")
        assert result >= 0.99

    def test_numeric_tolerance(self):
        result = cascade_score("100.05", "100")
        assert result >= 0.99

    def test_mismatch(self):
        result = cascade_score("42", "99")
        assert result == 0.0

    def test_empty_answer(self):
        result = cascade_score("", "42")
        assert result == 0.0
# ---------------------------------------------------------------------------
# cascade_score — with a real dataset
# ---------------------------------------------------------------------------
class TestCascadeScoreWithDataset:
    """Cascade scoring exercised against a real procedural dataset."""

    def test_chain_sum_exact(self):
        ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
        entry = ds[0]
        gold = entry["answer"]
        assert cascade_score(gold, gold, dataset=ds, entry=entry) == 1.0

    def test_chain_sum_latex_wrapped(self):
        ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
        entry = ds[0]
        wrapped = rf"\text{{{entry['answer']}}}"
        result = cascade_score(wrapped, entry["answer"], dataset=ds, entry=entry)
        assert result >= 0.99

    def test_never_downgrades(self):
        """The cascade should never return less than score_answer itself."""
        ds = reasoning_gym.create_dataset("chain_sum", size=10, seed=123)
        for entry in ds:
            gold = entry["answer"]
            base = ds.score_answer(gold, entry)
            cascaded = cascade_score(gold, gold, dataset=ds, entry=entry)
            assert cascaded >= base
# ---------------------------------------------------------------------------
# ProceduralDataset.score_answer_cascade convenience method
# ---------------------------------------------------------------------------
class TestScoreAnswerCascadeMethod:
    """Tests for the ProceduralDataset.score_answer_cascade convenience method."""

    def test_method_exists(self):
        ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0)
        assert hasattr(ds, "score_answer_cascade")

    def test_oracle_answer_scores_one(self):
        ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
        for entry in ds:
            gold = entry["answer"]
            assert ds.score_answer_cascade(gold, entry) == 1.0

    def test_none_answer_scores_zero(self):
        ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0)
        first = ds[0]
        assert ds.score_answer_cascade(None, first) == 0.0

    def test_latex_wrapped_answer(self):
        ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
        entry = ds[0]
        wrapped = rf"\({entry['answer']}\)"
        result = ds.score_answer_cascade(wrapped, entry)
        assert result >= 0.99

    def test_never_less_than_score_answer(self):
        ds = reasoning_gym.create_dataset("chain_sum", size=10, seed=99)
        for entry in ds:
            gold = entry["answer"]
            base = ds.score_answer(gold, entry)
            cascaded = ds.score_answer_cascade(gold, entry)
            assert cascaded >= base
# ---------------------------------------------------------------------------
# Top-level imports
# ---------------------------------------------------------------------------
class TestTopLevelImports:
    """The scoring helpers must be importable from the package root."""

    def test_cascade_score_importable(self):
        from reasoning_gym import cascade_score as cs

        assert callable(cs)

    def test_matchers_importable(self):
        from reasoning_gym import float_match, math_match, string_match, strip_latex

        for fn in (string_match, float_match, math_match, strip_latex):
            assert callable(fn)