Feat: expose score_answer function without needing to instantiate a dataset (#422)

* feat: get `score_answer` for a given dataset

* fix: `self` error

* add test
This commit is contained in:
rasdani 2025-04-18 10:36:44 +02:00 committed by GitHub
parent 169d8c3aec
commit 72e45e9401
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 95 additions and 2 deletions

View file

@@ -3,7 +3,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasonin
"""
from . import algebra, algorithmic, arc, arithmetic, code, cognition, data, games, geometry, graphs, induction, logic
from .factory import create_dataset, register_dataset
from .factory import create_dataset, get_score_answer_fn, register_dataset
__version__ = "0.1.19"
__all__ = [
@@ -21,4 +21,5 @@ __all__ = [
"induction",
"create_dataset",
"register_dataset",
"get_score_answer_fn",
]

View file

@@ -1,5 +1,5 @@
from dataclasses import is_dataclass
from typing import Optional, Type, TypeVar
from typing import Callable, Optional, Type, TypeVar
from reasoning_gym.coaching.base_curriculum import BaseCurriculum, ConfigT
@@ -96,3 +96,24 @@ def create_curriculum(name: str) -> BaseCurriculum:
def has_curriculum(name: str) -> bool:
    """Return True if a curriculum is registered under *name* in CURRICULA."""
    return name in CURRICULA
def get_score_answer_fn(name: str) -> Callable[[Optional[str], dict], float]:
    """
    Get the score answer function for the named dataset.

    The returned callable is the dataset's bound ``score_answer`` method and
    is invoked as ``score_fn(answer, entry)``, returning a float score.
    (The previous annotation ``Callable[[], float]`` wrongly declared a
    zero-argument callable.)

    Args:
        name: Registered dataset name

    Returns:
        Score function for the dataset

    Raises:
        ValueError: If dataset not found
    """
    if name not in DATASETS:
        raise ValueError(f"Dataset '{name}' not registered")
    dataset_cls, config_cls = DATASETS[name]
    # Instantiate with the default config; the bound method keeps the dataset
    # instance alive for as long as the returned callable is referenced.
    return dataset_cls(config=config_cls()).score_answer

View file

@@ -0,0 +1,71 @@
"""
Tests for the get_score_answer_fn helper with hard-coded sample cases.
"""
import pytest
from reasoning_gym import get_score_answer_fn
# Each dict drives one parametrized run of the test below:
#   dataset      - registered dataset name passed to get_score_answer_fn
#   entry        - minimal dataset entry (the answer, plus any metadata the
#                  scorer reads, e.g. word_sorting's "sorted_words")
#   model_answer - candidate answer handed to the scorer
#   expected     - score the dataset's score_answer is expected to return
#   id           - human-readable pytest id for the case
TEST_CASES = [
    {
        "dataset": "letter_jumble",
        "entry": {"answer": "second opportunity to receive"},
        "model_answer": "second opportunity to receive",
        "expected": 1.0,
        "id": "rg_4806-correct",
    },
    {
        "dataset": "word_sorting",
        "entry": {
            "answer": "arrive, burdens, computers, federal, louder, paragraphs, side, specified, virus",
            "metadata": {
                "sorted_words": [
                    "arrive",
                    "burdens",
                    "computers",
                    "federal",
                    "louder",
                    "paragraphs",
                    "side",
                    "specified",
                    "virus",
                ]
            },
        },
        "model_answer": "arrive, burdens, computers, federal, louder, paragraphs, side, specified, virus",
        "expected": 1.0,
        "id": "rg_16004-word_sorting-correct",
    },
    {
        "dataset": "spell_backward",
        "entry": {"answer": "ssiknu"},
        "model_answer": "ssiknu",
        "expected": 1.0,
        "id": "rg_14211-correct",
    },
    {
        "dataset": "letter_jumble",
        "entry": {"answer": "second opportunity to receive"},
        "model_answer": "completely wrong answer here",
        "expected": 0.0,
        "id": "rg_4806-incorrect",
    },
    {
        "dataset": "spell_backward",
        "entry": {"answer": "ssiknu"},
        "model_answer": "unkiss",
        "expected": 0.0,
        "id": "rg_14211-incorrect",
    },
]
@pytest.mark.parametrize("case", TEST_CASES, ids=lambda c: c["id"])
def test_get_score_answer_fn_hardcoded(case):
    """
    Check that the scorer obtained via get_score_answer_fn produces the
    expected value for each hard-coded (entry, model_answer) pair.
    """
    dataset_name = case["dataset"]
    score_fn = get_score_answer_fn(dataset_name)
    actual = score_fn(case["model_answer"], case["entry"])
    expected = case["expected"]
    assert actual == pytest.approx(expected, abs=1e-8)