diff --git a/README.md b/README.md
index 923beb2c..d7e3d8e0 100644
--- a/README.md
+++ b/README.md
@@ -85,6 +85,14 @@ reasoning_gym.create_dataset('composite', size=10, seed=42, datasets=specs)
 
 For the simplest way to get started training models with Reasoning Gym, we recommend using the `verifiers` library, which directly supports RG tasks. See `examples/verifiers` for details. However, RG data can be used with any major RL training framework.
 
+The *cascade scorer* applies progressively lenient fallback matchers — string, numeric, and symbolic math — to reduce false negatives from formatting differences (LaTeX wrappers, casing, numeric representation). Install with `pip install reasoning-gym[scoring]` for symbolic math verification.
+
+```python
+from reasoning_gym import cascade_score
+
+assert cascade_score(answer=r"\text{42}", expected="42") == 1.0
+```
+
 ## 🔍 Evaluation
 
 Instructions for running the evaluation scripts are provided in [eval/README.md](https://github.com/open-thought/reasoning-gym/blob/main/eval/README.md).
diff --git a/pyproject.toml b/pyproject.toml
index c3551afa..94513b1a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,6 +49,9 @@ cli = [
     "pyyaml>=6.0.1",
     "httpx>=0.27.0",
 ]
+scoring = [
+    "math-verify>=0.7.0",
+]
 scripts = [
     "datasets>=3.5.0"
 ]
diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py
index f98e49af..b0dbca4f 100644
--- a/reasoning_gym/__init__.py
+++ b/reasoning_gym/__init__.py
@@ -18,6 +18,7 @@ from . import (
     probability,
 )
 from .factory import create_dataset, get_score_answer_fn, register_dataset
+from .scoring import cascade_score, float_match, math_match, string_match, strip_latex
 
 __version__ = "0.1.19"
 __all__ = [
@@ -37,4 +38,9 @@ __all__ = [
     "create_dataset",
     "register_dataset",
     "get_score_answer_fn",
+    "cascade_score",
+    "strip_latex",
+    "string_match",
+    "float_match",
+    "math_match",
 ]
diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py
index 29aea330..7cef17ce 100644
--- a/reasoning_gym/dataset.py
+++ b/reasoning_gym/dataset.py
@@ -71,6 +71,21 @@ class ProceduralDataset(ABC, Sized, Iterable[dict[str, Any]]):
                 reward = len(oracle_answer) / len(answer)
         return reward
 
+    def score_answer_cascade(self, answer: Optional[str], entry: dict[str, Any]) -> float:
+        """Score with fallback cascade (LaTeX stripping, string, float, math matching).
+
+        Runs this dataset's ``score_answer`` first, then progressively more
+        lenient matchers.  The cascade can only upgrade, never downgrade.
+
+        Requires ``pip install reasoning-gym[scoring]`` for the ``math_match``
+        step (other steps work without extra dependencies).
+        """
+        from .scoring import cascade_score
+
+        if answer is None:
+            return 0.0
+        return cascade_score(answer, entry.get("answer", ""), dataset=self, entry=entry)
+
 
 T = TypeVar("T", bound="ProceduralDataset")
 
@@ -127,3 +142,7 @@ class ReseedingDataset(Iterable[dict[str, Any]]):
     def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
         """Forward scoring to the wrapped dataset's implementation"""
         return self.dataset.score_answer(answer, entry)
+
+    def score_answer_cascade(self, answer: Optional[str], entry: dict[str, Any]) -> float:
+        """Forward cascade scoring to the wrapped dataset's implementation"""
+        return self.dataset.score_answer_cascade(answer, entry)
diff --git a/reasoning_gym/scoring.py b/reasoning_gym/scoring.py
new file mode 100644
index 00000000..2bd12e64
--- /dev/null
+++ b/reasoning_gym/scoring.py
@@ -0,0 +1,221 @@
+"""Scoring cascade utilities for reducing false negatives in answer verification.
+
+Provides a multi-step fallback cascade that wraps any dataset's ``score_answer``
+with progressively more lenient matchers:
+
+    1.  ``score_answer()`` -- environment's built-in verifier
+    1b. ``score_answer()`` -- retry after stripping LaTeX wrappers
+    2.  ``string_match``   -- case-insensitive exact comparison (raw, then LaTeX-stripped)
+    3.  ``float_match``    -- numeric comparison with tolerance (raw, then LaTeX-stripped)
+    4.  ``math_match``     -- symbolic math via *math-verify*
+
+The cascade can only *upgrade* a score, never downgrade it.
+
+``math_match`` requires the optional ``math-verify`` package.  When it is not
+installed the step is silently skipped (returns 0.0).  Install via::
+
+    pip install reasoning-gym[scoring]
+"""
+
+import math
+import re
+from typing import TYPE_CHECKING, Any, Optional
+
+if TYPE_CHECKING:
+    from .dataset import ProceduralDataset
+
+# Scores at or above this value count as a pass and end the cascade early.
+_PASS_THRESHOLD = 0.99
+
+# (opening, closing) math delimiters removed by ``math_match`` before parsing.
+_MATH_DELIMITERS = ((r"\[", r"\]"), (r"\(", r"\)"), (r"\,", r"\,"))
+
+
+# ---------------------------------------------------------------------------
+# LaTeX normalisation
+# ---------------------------------------------------------------------------
+
+
+def strip_latex(s: str) -> str:
+    """Remove common LaTeX wrappers and normalise whitespace.
+
+    Handles ``\\(…\\)``, ``\\text{}``, ``\\mathrm{}``, double-backslash
+    linebreaks, tildes, and stray backslashes.
+    """
+    # Unwrap each inline span first (non-greedy).  Running the anchored, greedy
+    # pattern first would treat r"\(a\) and \(b\)" as one span and leave stray
+    # parentheses behind.
+    s = re.sub(r"\\\((.*?)\\\)", r"\1", s.strip())
+    # A wrapper that still encloses the whole string (e.g. nested delimiters).
+    s = re.sub(r"^\\\((.*)\\\)$", r"\1", s)
+    s = re.sub(r"\\(?:text|mathrm)\{([^}]*)\}", r"\1", s)
+    s = re.sub(r"\\\\+", " ", s)
+    s = re.sub(r"~", " ", s)
+    s = re.sub(r"\\", "", s)
+    return re.sub(r"\s+", " ", s).strip()
+
+
+# ---------------------------------------------------------------------------
+# Individual matchers
+# ---------------------------------------------------------------------------
+
+
+def string_match(predicted: str, expected: str) -> float:
+    """Case-insensitive exact string comparison after stripping whitespace.
+
+    Returns 0.0 (instead of raising) when either argument is not a string.
+    """
+    try:
+        return 1.0 if predicted.lower().strip() == expected.lower().strip() else 0.0
+    except AttributeError:
+        return 0.0
+
+
+def _within_tolerance(a: float, b: float, rel_tol: float = 0.01, abs_tol: float = 0.01) -> bool:
+    """Return True when two finite numbers agree within the given tolerances.
+
+    Non-finite values never match: with ``inf`` on either side the relative
+    tolerance itself becomes ``inf`` and the comparison would always succeed.
+    """
+    if not (math.isfinite(a) and math.isfinite(b)):
+        return False
+    return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
+
+
+def float_match(
+    predicted: str,
+    expected: str,
+    rel_tol: float = 0.01,
+    abs_tol: float = 0.01,
+) -> float:
+    """Numeric comparison with configurable tolerance.
+
+    Accepts if ``|a - b| <= max(rel_tol * max(|a|, |b|), abs_tol)``.
+    Returns 0.0 for non-numeric strings and for non-finite values
+    (``inf``, ``nan``).
+    """
+    try:
+        a = float(predicted)
+        b = float(expected)
+    except (TypeError, ValueError, OverflowError):
+        return 0.0
+    return 1.0 if _within_tolerance(a, b, rel_tol, abs_tol) else 0.0
+
+
+def _strip_math_delimiters(s: str) -> str:
+    """Strip surrounding dollar signs and display/inline math delimiters from *s*."""
+    s = s.strip("$")
+    for opener, closer in _MATH_DELIMITERS:
+        # The length check stops a lone delimiter (e.g. r"\,") from counting
+        # as both the opening and the closing marker.
+        if len(s) >= len(opener) + len(closer) and s.startswith(opener) and s.endswith(closer):
+            s = s[len(opener) : -len(closer)].strip()
+    return s
+
+
+def math_match(predicted: str, expected: str) -> float:
+    """Symbolic math verification via *math-verify*, with numeric fallback.
+
+    Strips dollar signs and common display-math delimiters before parsing.
+    Falls back to a tolerance comparison of the first parsed values when
+    symbolic ``verify`` returns ``False``.
+
+    Returns 0.0 when ``math-verify`` is not installed or the input cannot be
+    parsed.
+    """
+    try:
+        from math_verify import parse, verify
+    except ImportError:
+        return 0.0
+
+    try:
+        gold = parse(f"${_strip_math_delimiters(expected)}$")
+        pred = parse(f"${_strip_math_delimiters(predicted)}$")
+
+        if verify(gold, pred):
+            return 1.0
+
+        # Numeric fallback on the first parsed element.
+        return 1.0 if _within_tolerance(float(gold[0]), float(pred[0])) else 0.0
+    except Exception:
+        # Best effort: any failure inside the third-party parser/verifier is
+        # treated as "no match" so a failed matcher never breaks the cascade.
+        return 0.0
+
+
+def _mathrm_to_text(s: str) -> str:
+    r"""Replace ``\mathrm{…}`` with ``\text{…}`` for a second math_match attempt."""
+    return re.sub(r"\\mathrm\{([^}]*)\}", r"\\text{\1}", s)
+
+
+# ---------------------------------------------------------------------------
+# Full cascade
+# ---------------------------------------------------------------------------
+
+
+def cascade_score(
+    answer: str,
+    expected: str,
+    dataset: Optional["ProceduralDataset"] = None,
+    entry: Optional[dict[str, Any]] = None,
+) -> float:
+    """Apply the multi-step scoring cascade.
+
+    When *dataset* and *entry* are supplied the environment's own
+    ``score_answer`` is tried first (steps 1 & 1b).  The remaining steps
+    use only the raw answer strings and never require a dataset instance.
+
+    The cascade can only upgrade — if an earlier step already returned
+    a near-perfect score (>= 0.99) it is returned immediately and the
+    later, more expensive matchers are not run.
+
+    Args:
+        answer: The model's predicted answer string.  ``None`` scores 0.0.
+        expected: The gold / oracle answer string.
+        dataset: Optional :class:`ProceduralDataset` whose ``score_answer``
+            should be tried first.
+        entry: The dataset entry dict (must contain at least ``"answer"``).
+            Required when *dataset* is provided.
+
+    Returns:
+        A score in ``[0.0, 1.0]``.
+    """
+    if answer is None:
+        return 0.0
+
+    best = 0.0
+    cleaned = strip_latex(answer)
+
+    # Steps 1 & 1b: environment's built-in verifier, on the raw answer and
+    # (when it differs) on the LaTeX-stripped answer.
+    if dataset is not None and entry is not None:
+        candidates = [answer] if cleaned == answer else [answer, cleaned]
+        for candidate in candidates:
+            try:
+                score = float(dataset.score_answer(candidate, entry))
+            except Exception:
+                # Best effort: a verifier that rejects the input by raising
+                # simply contributes nothing to the cascade.
+                continue
+            if score >= _PASS_THRESHOLD:
+                return score
+            best = max(best, score)
+
+    # Steps 2-4: string / float / math matchers.  They are wrapped in lambdas
+    # so each one runs only if every earlier matcher failed.
+    cleaned_expected = strip_latex(expected)
+    matchers = (
+        lambda: string_match(answer, expected),
+        lambda: string_match(cleaned, cleaned_expected),
+        lambda: float_match(answer, expected),
+        lambda: float_match(cleaned, cleaned_expected),
+        lambda: math_match(answer, expected),
+        lambda: math_match(_mathrm_to_text(answer), _mathrm_to_text(expected)),
+    )
+    for matcher in matchers:
+        score = matcher()
+        if score >= _PASS_THRESHOLD:
+            return score
+        best = max(best, score)
+
+    return best
diff --git a/tests/test_scoring.py b/tests/test_scoring.py
new file mode 100644
index 00000000..c8b1540b
--- /dev/null
+++ b/tests/test_scoring.py
@@ -0,0 +1,265 @@
+import pytest
+
+import reasoning_gym
+from reasoning_gym.scoring import _mathrm_to_text, cascade_score, float_match, math_match, string_match, strip_latex
+
+# ---------------------------------------------------------------------------
+# strip_latex
+# ---------------------------------------------------------------------------
+
+
+class TestStripLatex:
+    def test_inline_math_delimiters(self):
+        assert strip_latex(r"\(42\)") == "42"
+
+    def test_inline_math_mid_string(self):
+        assert strip_latex(r"the value is \(x + 1\) here") == "the value is x + 1 here"
+
+    def test_multiple_inline_spans(self):
+        assert strip_latex(r"\(a\) and \(b\)") == "a and b"
+
+    def test_text_command(self):
+        assert strip_latex(r"\text{hello world}") == "hello world"
+
+    def test_mathrm_command(self):
+        assert strip_latex(r"\mathrm{cm}") == "cm"
+
+    def test_double_backslash(self):
+        assert strip_latex(r"a \\ b") == "a b"
+
+    def test_tilde(self):
+        assert strip_latex("a~b") == "a b"
+
+    def test_stray_backslashes(self):
+        assert strip_latex(r"\alpha + \beta") == "alpha + beta"
+
+    def test_whitespace_normalisation(self):
+        assert strip_latex(" a b ") == "a b"
+
+    def test_combined(self):
+        assert strip_latex(r"\(\text{answer}\)") == "answer"
+
+    def test_plain_string_unchanged(self):
+        assert strip_latex("42") == "42"
+        assert strip_latex("hello") == "hello"
+
+
+# ---------------------------------------------------------------------------
+# string_match
+# ---------------------------------------------------------------------------
+
+
+class TestStringMatch:
+    def test_exact(self):
+        assert string_match("42", "42") == 1.0
+
+    def test_case_insensitive(self):
+        assert string_match("Hello", "hello") == 1.0
+        assert string_match("TRUE", "true") == 1.0
+
+    def test_whitespace_stripped(self):
+        assert string_match(" 42 ", "42") == 1.0
+
+    def test_mismatch(self):
+        assert string_match("42", "43") == 0.0
+
+    def test_empty_strings(self):
+        assert string_match("", "") == 1.0
+
+    def test_non_string_graceful(self):
+        assert string_match(None, "42") == 0.0
+
+
+# ---------------------------------------------------------------------------
+# float_match
+# ---------------------------------------------------------------------------
+
+
+class TestFloatMatch:
+    def test_exact(self):
+        assert float_match("3.14", "3.14") == 1.0
+
+    def test_within_tolerance(self):
+        assert float_match("100", "100.5") == 1.0
+
+    def test_outside_tolerance(self):
+        assert float_match("100", "102") == 0.0
+
+    def test_zero_tolerance(self):
+        assert float_match("0", "0.005") == 1.0
+
+    def test_negative(self):
+        assert float_match("-5.0", "-5.0") == 1.0
+        assert float_match("-5.0", "5.0") == 0.0
+
+    def test_non_numeric(self):
+        assert float_match("abc", "42") == 0.0
+        assert float_match("42", "abc") == 0.0
+
+    def test_non_finite_rejected(self):
+        assert float_match("inf", "42") == 0.0
+        assert float_match("42", "-inf") == 0.0
+        assert float_match("nan", "nan") == 0.0
+
+    def test_custom_tolerance(self):
+        assert float_match("100", "110", rel_tol=0.15) == 1.0
+        assert float_match("100", "110", rel_tol=0.05) == 0.0
+
+
+# ---------------------------------------------------------------------------
+# math_match
+# ---------------------------------------------------------------------------
+
+
+class TestMathMatch:
+    def test_returns_zero_without_math_verify(self, monkeypatch):
+        """When math-verify is not importable, math_match should return 0."""
+        import builtins
+
+        real_import = builtins.__import__
+
+        def mock_import(name, *args, **kwargs):
+            if name == "math_verify":
+                raise ImportError("mocked")
+            return real_import(name, *args, **kwargs)
+
+        monkeypatch.setattr(builtins, "__import__", mock_import)
+        assert math_match("42", "42") == 0.0
+
+    def test_dollar_sign_stripping(self):
+        result = math_match("$42$", "$42$")
+        assert result >= 0.0  # at least doesn't crash
+
+    def test_display_math_delimiters(self):
+        result = math_match(r"\[42\]", r"\[42\]")
+        assert result >= 0.0
+
+    def test_non_parseable_returns_zero(self):
+        assert math_match("not math at all ???", "also not math ???") == 0.0
+
+
+# ---------------------------------------------------------------------------
+# _mathrm_to_text helper
+# ---------------------------------------------------------------------------
+
+
+class TestMathrmToText:
+    def test_replaces_mathrm(self):
+        assert _mathrm_to_text(r"\mathrm{cm}") == r"\text{cm}"
+
+    def test_no_mathrm_unchanged(self):
+        assert _mathrm_to_text("42") == "42"
+
+    def test_multiple_occurrences(self):
+        s = r"\mathrm{kg} \cdot \mathrm{m}"
+        assert _mathrm_to_text(s) == r"\text{kg} \cdot \text{m}"
+
+
+# ---------------------------------------------------------------------------
+# cascade_score — without dataset
+# ---------------------------------------------------------------------------
+
+
+class TestCascadeScoreStandalone:
+    def test_exact_string(self):
+        assert cascade_score("42", "42") >= 0.99
+
+    def test_case_insensitive_string(self):
+        assert cascade_score("True", "true") >= 0.99
+
+    def test_latex_wrapped_string(self):
+        assert cascade_score(r"\text{42}", "42") >= 0.99
+
+    def test_numeric_tolerance(self):
+        assert cascade_score("100.05", "100") >= 0.99
+
+    def test_mismatch(self):
+        assert cascade_score("42", "99") == 0.0
+
+    def test_empty_answer(self):
+        assert cascade_score("", "42") == 0.0
+
+    def test_none_answer(self):
+        assert cascade_score(None, "42") == 0.0
+
+    def test_latex_wrapped_numeric(self):
+        assert cascade_score(r"\(100.0\)", "100") >= 0.99
+
+
+# ---------------------------------------------------------------------------
+# cascade_score — with a real dataset
+# ---------------------------------------------------------------------------
+
+
+class TestCascadeScoreWithDataset:
+    def test_chain_sum_exact(self):
+        ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
+        entry = ds[0]
+        score = cascade_score(entry["answer"], entry["answer"], dataset=ds, entry=entry)
+        assert score == 1.0
+
+    def test_chain_sum_latex_wrapped(self):
+        ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
+        entry = ds[0]
+        wrapped = rf"\text{{{entry['answer']}}}"
+        score = cascade_score(wrapped, entry["answer"], dataset=ds, entry=entry)
+        assert score >= 0.99
+
+    def test_never_downgrades(self):
+        """The cascade should never return less than score_answer itself."""
+        ds = reasoning_gym.create_dataset("chain_sum", size=10, seed=123)
+        for entry in ds:
+            base = ds.score_answer(entry["answer"], entry)
+            cascaded = cascade_score(entry["answer"], entry["answer"], dataset=ds, entry=entry)
+            assert cascaded >= base
+
+
+# ---------------------------------------------------------------------------
+# ProceduralDataset.score_answer_cascade convenience method
+# ---------------------------------------------------------------------------
+
+
+class TestScoreAnswerCascadeMethod:
+    def test_method_exists(self):
+        ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0)
+        assert hasattr(ds, "score_answer_cascade")
+
+    def test_oracle_answer_scores_one(self):
+        ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
+        for entry in ds:
+            assert ds.score_answer_cascade(entry["answer"], entry) == 1.0
+
+    def test_none_answer_scores_zero(self):
+        ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0)
+        entry = ds[0]
+        assert ds.score_answer_cascade(None, entry) == 0.0
+
+    def test_latex_wrapped_answer(self):
+        ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
+        entry = ds[0]
+        wrapped = rf"\({entry['answer']}\)"
+        assert ds.score_answer_cascade(wrapped, entry) >= 0.99
+
+    def test_never_less_than_score_answer(self):
+        ds = reasoning_gym.create_dataset("chain_sum", size=10, seed=99)
+        for entry in ds:
+            base = ds.score_answer(entry["answer"], entry)
+            cascaded = ds.score_answer_cascade(entry["answer"], entry)
+            assert cascaded >= base
+
+
+# ---------------------------------------------------------------------------
+# Top-level imports
+# ---------------------------------------------------------------------------
+
+
+class TestTopLevelImports:
+    def test_cascade_score_importable(self):
+        from reasoning_gym import cascade_score as cs
+
+        assert callable(cs)
+
+    def test_matchers_importable(self):
+        from reasoning_gym import float_match, math_match, string_match, strip_latex
+
+        assert all(callable(f) for f in [string_match, float_match, math_match, strip_latex])