mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Feat: expose score_answer function without needing to instantiate a dataset (#422)
* feat: get `score_answer` for a given dataset * fix: `self` error * add test
This commit is contained in:
parent
169d8c3aec
commit
72e45e9401
3 changed files with 95 additions and 2 deletions
|
|
@ -3,7 +3,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasonin
|
|||
"""
|
||||
|
||||
from . import algebra, algorithmic, arc, arithmetic, code, cognition, data, games, geometry, graphs, induction, logic
|
||||
from .factory import create_dataset, register_dataset
|
||||
from .factory import create_dataset, get_score_answer_fn, register_dataset
|
||||
|
||||
__version__ = "0.1.19"
|
||||
__all__ = [
|
||||
|
|
@ -21,4 +21,5 @@ __all__ = [
|
|||
"induction",
|
||||
"create_dataset",
|
||||
"register_dataset",
|
||||
"get_score_answer_fn",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import is_dataclass
|
||||
from typing import Optional, Type, TypeVar
|
||||
from typing import Callable, Optional, Type, TypeVar
|
||||
|
||||
from reasoning_gym.coaching.base_curriculum import BaseCurriculum, ConfigT
|
||||
|
||||
|
|
@ -96,3 +96,24 @@ def create_curriculum(name: str) -> BaseCurriculum:
|
|||
|
||||
def has_curriculum(name: str) -> bool:
    """Check whether a curriculum has been registered under *name*."""
    registered = name in CURRICULA
    return registered
|
||||
|
||||
|
||||
def get_score_answer_fn(name: str) -> Callable[..., float]:
    """
    Get the score answer function for the named dataset.

    The returned callable is the bound ``score_answer`` method of a dataset
    instance created with its default configuration, so it is invoked as
    ``score_fn(answer, entry)`` and yields a float score — hence the
    ``Callable[..., float]`` annotation (the previous ``Callable[[], float]``
    wrongly declared a zero-argument callable).

    Args:
        name: Registered dataset name

    Returns:
        Score function for the dataset

    Raises:
        ValueError: If dataset not found
    """
    if name not in DATASETS:
        raise ValueError(f"Dataset '{name}' not registered")

    dataset_cls, config_cls = DATASETS[name]

    # Instantiate with the default config; the bound method keeps the
    # dataset instance alive for as long as the callable is referenced.
    return dataset_cls(config=config_cls()).score_answer
|
||||
|
|
|
|||
71
tests/test_get_score_answer_fn.py
Normal file
71
tests/test_get_score_answer_fn.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""
|
||||
Tests for the get_score_answer_fn helper with hard-coded sample cases.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym import get_score_answer_fn
|
||||
|
||||
# Hard-coded samples: for each registered dataset name, an entry dict, a
# candidate model answer, and the score the dataset's ``score_answer``
# implementation is expected to assign to that answer.
TEST_CASES = [
    {
        "id": "rg_4806-correct",
        "dataset": "letter_jumble",
        "entry": {"answer": "second opportunity to receive"},
        "model_answer": "second opportunity to receive",
        "expected": 1.0,
    },
    {
        "id": "rg_16004-word_sorting-correct",
        "dataset": "word_sorting",
        "entry": {
            "answer": "arrive, burdens, computers, federal, louder, paragraphs, side, specified, virus",
            "metadata": {
                "sorted_words": [
                    "arrive", "burdens", "computers", "federal", "louder",
                    "paragraphs", "side", "specified", "virus",
                ]
            },
        },
        "model_answer": "arrive, burdens, computers, federal, louder, paragraphs, side, specified, virus",
        "expected": 1.0,
    },
    {
        "id": "rg_14211-correct",
        "dataset": "spell_backward",
        "entry": {"answer": "ssiknu"},
        "model_answer": "ssiknu",
        "expected": 1.0,
    },
    {
        "id": "rg_4806-incorrect",
        "dataset": "letter_jumble",
        "entry": {"answer": "second opportunity to receive"},
        "model_answer": "completely wrong answer here",
        "expected": 0.0,
    },
    {
        "id": "rg_14211-incorrect",
        "dataset": "spell_backward",
        "entry": {"answer": "ssiknu"},
        "model_answer": "unkiss",
        "expected": 0.0,
    },
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", TEST_CASES, ids=lambda c: c["id"])
def test_get_score_answer_fn_hardcoded(case):
    """
    Verify that the scorer obtained via ``get_score_answer_fn`` produces the
    expected score for each hard-coded (model answer, entry) pair.
    """
    score_fn = get_score_answer_fn(case["dataset"])
    actual = score_fn(case["model_answer"], case["entry"])
    assert actual == pytest.approx(case["expected"], abs=1e-8)
|
||||
Loading…
Add table
Add a link
Reference in a new issue