mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
style: fix black and isort formatting
Run black and isort to satisfy pre-commit checks. Made-with: Cursor
This commit is contained in:
parent
83fcceb317
commit
d6a5a8a9f1
2 changed files with 16 additions and 10 deletions
|
|
@ -28,6 +28,7 @@ if TYPE_CHECKING:
|
||||||
# LaTeX normalisation
|
# LaTeX normalisation
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def strip_latex(s: str) -> str:
|
def strip_latex(s: str) -> str:
|
||||||
"""Remove common LaTeX wrappers and normalise whitespace.
|
"""Remove common LaTeX wrappers and normalise whitespace.
|
||||||
|
|
||||||
|
|
@ -47,6 +48,7 @@ def strip_latex(s: str) -> str:
|
||||||
# Individual matchers
|
# Individual matchers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def string_match(predicted: str, expected: str) -> float:
|
def string_match(predicted: str, expected: str) -> float:
|
||||||
"""Case-insensitive exact string comparison after stripping whitespace."""
|
"""Case-insensitive exact string comparison after stripping whitespace."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -123,6 +125,7 @@ def _mathrm_to_text(s: str) -> str:
|
||||||
# Full cascade
|
# Full cascade
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def cascade_score(
|
def cascade_score(
|
||||||
answer: str,
|
answer: str,
|
||||||
expected: str,
|
expected: str,
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,13 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import reasoning_gym
|
import reasoning_gym
|
||||||
from reasoning_gym.scoring import (
|
from reasoning_gym.scoring import _mathrm_to_text, cascade_score, float_match, math_match, string_match, strip_latex
|
||||||
cascade_score,
|
|
||||||
float_match,
|
|
||||||
math_match,
|
|
||||||
string_match,
|
|
||||||
strip_latex,
|
|
||||||
_mathrm_to_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# strip_latex
|
# strip_latex
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestStripLatex:
|
class TestStripLatex:
|
||||||
def test_inline_math_delimiters(self):
|
def test_inline_math_delimiters(self):
|
||||||
assert strip_latex(r"\(42\)") == "42"
|
assert strip_latex(r"\(42\)") == "42"
|
||||||
|
|
@ -52,6 +45,7 @@ class TestStripLatex:
|
||||||
# string_match
|
# string_match
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestStringMatch:
|
class TestStringMatch:
|
||||||
def test_exact(self):
|
def test_exact(self):
|
||||||
assert string_match("42", "42") == 1.0
|
assert string_match("42", "42") == 1.0
|
||||||
|
|
@ -77,6 +71,7 @@ class TestStringMatch:
|
||||||
# float_match
|
# float_match
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestFloatMatch:
|
class TestFloatMatch:
|
||||||
def test_exact(self):
|
def test_exact(self):
|
||||||
assert float_match("3.14", "3.14") == 1.0
|
assert float_match("3.14", "3.14") == 1.0
|
||||||
|
|
@ -107,6 +102,7 @@ class TestFloatMatch:
|
||||||
# math_match
|
# math_match
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestMathMatch:
|
class TestMathMatch:
|
||||||
def test_returns_zero_without_math_verify(self, monkeypatch):
|
def test_returns_zero_without_math_verify(self, monkeypatch):
|
||||||
"""When math-verify is not importable, math_match should return 0."""
|
"""When math-verify is not importable, math_match should return 0."""
|
||||||
|
|
@ -138,6 +134,7 @@ class TestMathMatch:
|
||||||
# _mathrm_to_text helper
|
# _mathrm_to_text helper
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestMathrmToText:
|
class TestMathrmToText:
|
||||||
def test_replaces_mathrm(self):
|
def test_replaces_mathrm(self):
|
||||||
assert _mathrm_to_text(r"\mathrm{cm}") == r"\text{cm}"
|
assert _mathrm_to_text(r"\mathrm{cm}") == r"\text{cm}"
|
||||||
|
|
@ -154,6 +151,7 @@ class TestMathrmToText:
|
||||||
# cascade_score — without dataset
|
# cascade_score — without dataset
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestCascadeScoreStandalone:
|
class TestCascadeScoreStandalone:
|
||||||
def test_exact_string(self):
|
def test_exact_string(self):
|
||||||
assert cascade_score("42", "42") >= 0.99
|
assert cascade_score("42", "42") >= 0.99
|
||||||
|
|
@ -178,6 +176,7 @@ class TestCascadeScoreStandalone:
|
||||||
# cascade_score — with a real dataset
|
# cascade_score — with a real dataset
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestCascadeScoreWithDataset:
|
class TestCascadeScoreWithDataset:
|
||||||
def test_chain_sum_exact(self):
|
def test_chain_sum_exact(self):
|
||||||
ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
|
ds = reasoning_gym.create_dataset("chain_sum", size=5, seed=42)
|
||||||
|
|
@ -205,6 +204,7 @@ class TestCascadeScoreWithDataset:
|
||||||
# ProceduralDataset.score_answer_cascade convenience method
|
# ProceduralDataset.score_answer_cascade convenience method
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestScoreAnswerCascadeMethod:
|
class TestScoreAnswerCascadeMethod:
|
||||||
def test_method_exists(self):
|
def test_method_exists(self):
|
||||||
ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0)
|
ds = reasoning_gym.create_dataset("chain_sum", size=1, seed=0)
|
||||||
|
|
@ -238,11 +238,14 @@ class TestScoreAnswerCascadeMethod:
|
||||||
# Top-level imports
|
# Top-level imports
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestTopLevelImports:
|
class TestTopLevelImports:
|
||||||
def test_cascade_score_importable(self):
|
def test_cascade_score_importable(self):
|
||||||
from reasoning_gym import cascade_score as cs
|
from reasoning_gym import cascade_score as cs
|
||||||
|
|
||||||
assert callable(cs)
|
assert callable(cs)
|
||||||
|
|
||||||
def test_matchers_importable(self):
|
def test_matchers_importable(self):
|
||||||
from reasoning_gym import string_match, float_match, math_match, strip_latex
|
from reasoning_gym import float_match, math_match, string_match, strip_latex
|
||||||
|
|
||||||
assert all(callable(f) for f in [string_match, float_match, math_match, strip_latex])
|
assert all(callable(f) for f in [string_match, float_match, math_match, strip_latex])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue