diff --git a/reasoning_gym/scoring.py b/reasoning_gym/scoring.py index ac2d6972..2bd12e64 100644 --- a/reasoning_gym/scoring.py +++ b/reasoning_gym/scoring.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: # LaTeX normalisation # --------------------------------------------------------------------------- + def strip_latex(s: str) -> str: """Remove common LaTeX wrappers and normalise whitespace. @@ -47,6 +48,7 @@ def strip_latex(s: str) -> str: # Individual matchers # --------------------------------------------------------------------------- + def string_match(predicted: str, expected: str) -> float: """Case-insensitive exact string comparison after stripping whitespace.""" try: @@ -123,6 +125,7 @@ def _mathrm_to_text(s: str) -> str: # Full cascade # --------------------------------------------------------------------------- + def cascade_score( answer: str, expected: str, diff --git a/tests/test_scoring.py b/tests/test_scoring.py index 1c7968fb..c8b1540b 100644 --- a/tests/test_scoring.py +++ b/tests/test_scoring.py @@ -1,20 +1,13 @@ import pytest import reasoning_gym -from reasoning_gym.scoring import ( - cascade_score, - float_match, - math_match, - string_match, - strip_latex, - _mathrm_to_text, -) - +from reasoning_gym.scoring import _mathrm_to_text, cascade_score, float_match, math_match, string_match, strip_latex # --------------------------------------------------------------------------- # strip_latex # --------------------------------------------------------------------------- + class TestStripLatex: def test_inline_math_delimiters(self): assert strip_latex(r"\(42\)") == "42" @@ -52,6 +45,7 @@ class TestStripLatex: # string_match # --------------------------------------------------------------------------- + class TestStringMatch: def test_exact(self): assert string_match("42", "42") == 1.0 @@ -77,6 +71,7 @@ class TestStringMatch: # float_match # --------------------------------------------------------------------------- + class TestFloatMatch: def test_exact(self): assert float_match("3.14", "3.14") == 1.0 @@ -107,6 +102,7 @@ class TestFloatMatch: # math_match # --------------------------------------------------------------------------- + class TestMathMatch: def test_returns_zero_without_math_verify(self, monkeypatch): """When math-verify is not importable, math_match should return 0.""" @@ -138,6 +134,7 @@ class TestMathMatch: # _mathrm_to_text helper # --------------------------------------------------------------------------- + class TestMathrmToText: def test_replaces_mathrm(self): assert _mathrm_to_text(r"\mathrm{cm}") == r"\text{cm}" @@ -154,6 +151,7 @@ class TestMathrmToText: # cascade_score — without dataset # --------------------------------------------------------------------------- + class TestCascadeScoreStandalone: def test_exact_string(self): assert cascade_score("42", "42") >= 0.99 @@ -178,6 +176,7 @@ class TestCascadeScoreStandalone: # cascade_score — with a real dataset # --------------------------------------------------------------------------- + class TestCascadeScoreWithDataset: def test_chain_sum_exact(self): ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42) @@ -205,6 +204,7 @@ class TestCascadeScoreWithDataset: # ProceduralDataset.score_answer_cascade convenience method # --------------------------------------------------------------------------- + class TestScoreAnswerCascadeMethod: def test_method_exists(self): ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0) @@ -238,11 +238,14 @@ class TestScoreAnswerCascadeMethod: # Top-level imports # --------------------------------------------------------------------------- + class TestTopLevelImports: def test_cascade_score_importable(self): from reasoning_gym import cascade_score as cs + assert callable(cs) def test_matchers_importable(self): - from reasoning_gym import string_match, float_match, math_match, strip_latex + from reasoning_gym import float_match, math_match, string_match, strip_latex + assert all(callable(f) for f in [string_match, float_match, math_match, strip_latex])