style: fix black and isort formatting

Run black and isort to satisfy pre-commit checks.

Made-with: Cursor
This commit is contained in:
Ritvik19 2026-04-17 14:32:25 +00:00
parent 83fcceb317
commit d6a5a8a9f1
2 changed files with 16 additions and 10 deletions

View file

@ -28,6 +28,7 @@ if TYPE_CHECKING:
# LaTeX normalisation # LaTeX normalisation
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def strip_latex(s: str) -> str: def strip_latex(s: str) -> str:
"""Remove common LaTeX wrappers and normalise whitespace. """Remove common LaTeX wrappers and normalise whitespace.
@ -47,6 +48,7 @@ def strip_latex(s: str) -> str:
# Individual matchers # Individual matchers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def string_match(predicted: str, expected: str) -> float: def string_match(predicted: str, expected: str) -> float:
"""Case-insensitive exact string comparison after stripping whitespace.""" """Case-insensitive exact string comparison after stripping whitespace."""
try: try:
@ -123,6 +125,7 @@ def _mathrm_to_text(s: str) -> str:
# Full cascade # Full cascade
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def cascade_score( def cascade_score(
answer: str, answer: str,
expected: str, expected: str,

View file

@ -1,20 +1,13 @@
import pytest import pytest
import reasoning_gym import reasoning_gym
from reasoning_gym.scoring import ( from reasoning_gym.scoring import _mathrm_to_text, cascade_score, float_match, math_match, string_match, strip_latex
cascade_score,
float_match,
math_match,
string_match,
strip_latex,
_mathrm_to_text,
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# strip_latex # strip_latex
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestStripLatex: class TestStripLatex:
def test_inline_math_delimiters(self): def test_inline_math_delimiters(self):
assert strip_latex(r"\(42\)") == "42" assert strip_latex(r"\(42\)") == "42"
@ -52,6 +45,7 @@ class TestStripLatex:
# string_match # string_match
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestStringMatch: class TestStringMatch:
def test_exact(self): def test_exact(self):
assert string_match("42", "42") == 1.0 assert string_match("42", "42") == 1.0
@ -77,6 +71,7 @@ class TestStringMatch:
# float_match # float_match
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestFloatMatch: class TestFloatMatch:
def test_exact(self): def test_exact(self):
assert float_match("3.14", "3.14") == 1.0 assert float_match("3.14", "3.14") == 1.0
@ -107,6 +102,7 @@ class TestFloatMatch:
# math_match # math_match
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestMathMatch: class TestMathMatch:
def test_returns_zero_without_math_verify(self, monkeypatch): def test_returns_zero_without_math_verify(self, monkeypatch):
"""When math-verify is not importable, math_match should return 0.""" """When math-verify is not importable, math_match should return 0."""
@ -138,6 +134,7 @@ class TestMathMatch:
# _mathrm_to_text helper # _mathrm_to_text helper
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestMathrmToText: class TestMathrmToText:
def test_replaces_mathrm(self): def test_replaces_mathrm(self):
assert _mathrm_to_text(r"\mathrm{cm}") == r"\text{cm}" assert _mathrm_to_text(r"\mathrm{cm}") == r"\text{cm}"
@ -154,6 +151,7 @@ class TestMathrmToText:
# cascade_score — without dataset # cascade_score — without dataset
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestCascadeScoreStandalone: class TestCascadeScoreStandalone:
def test_exact_string(self): def test_exact_string(self):
assert cascade_score("42", "42") >= 0.99 assert cascade_score("42", "42") >= 0.99
@ -178,6 +176,7 @@ class TestCascadeScoreStandalone:
# cascade_score — with a real dataset # cascade_score — with a real dataset
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestCascadeScoreWithDataset: class TestCascadeScoreWithDataset:
def test_chain_sum_exact(self): def test_chain_sum_exact(self):
ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42) ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
@ -205,6 +204,7 @@ class TestCascadeScoreWithDataset:
# ProceduralDataset.score_answer_cascade convenience method # ProceduralDataset.score_answer_cascade convenience method
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestScoreAnswerCascadeMethod: class TestScoreAnswerCascadeMethod:
def test_method_exists(self): def test_method_exists(self):
ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0) ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0)
@ -238,11 +238,14 @@ class TestScoreAnswerCascadeMethod:
# Top-level imports # Top-level imports
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestTopLevelImports: class TestTopLevelImports:
def test_cascade_score_importable(self): def test_cascade_score_importable(self):
from reasoning_gym import cascade_score as cs from reasoning_gym import cascade_score as cs
assert callable(cs) assert callable(cs)
def test_matchers_importable(self): def test_matchers_importable(self):
from reasoning_gym import string_match, float_match, math_match, strip_latex from reasoning_gym import float_match, math_match, string_match, strip_latex
assert all(callable(f) for f in [string_match, float_match, math_match, strip_latex]) assert all(callable(f) for f in [string_match, float_match, math_match, strip_latex])