reasoning-gym/reasoning_gym/dataset.py
Ritvik Rastogi 49b07130b3
feat: add scoring cascade for reducing false negatives (#526)
* feat: add scoring cascade for reducing false negatives in answer verification

* style: fix black and isort formatting

Run black and isort to satisfy pre-commit checks.

Made-with: Cursor

* docs: add scoring cascade example to Quickstart section

Mention the experimental scoring cascade feature at the end of the
Quickstart section with a disclaimer and complete usage examples
showing both the dataset method and standalone function.

Made-with: Cursor

* docs: shorten scoring cascade section in README

Trim to a concise standalone example per review feedback.

Made-with: Cursor

* docs: simplify scoring cascade description in README

Made-with: Cursor

* update readme

---------

Co-authored-by: Zafir Stojanovski <zaf.stojano@gmail.com>
2026-04-17 21:39:15 +02:00

148 lines
5.3 KiB
Python

"""Base class for procedural dataset generators"""
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sized
from copy import deepcopy
from random import Random
from typing import Any, Iterator, Optional, Type, TypeVar
class ProceduralDataset(ABC, Sized, Iterable[dict[str, Any]]):
    """Abstract base class for procedural dataset generators"""

    def __init__(self, config: Any, seed: Optional[int] = None, size: int = 500):
        """Initialize the dataset with config, optional seed and size

        Args:
            config: Configuration object; if it exposes a callable ``validate``
                attribute, it is invoked before anything else is stored.
            seed: Base RNG seed; a random 32-bit value is drawn when omitted.
            size: Virtual number of items exposed via ``len()`` and iteration.
        """
        if hasattr(config, "validate") and callable(config.validate):
            config.validate()
        self.config = config
        self.size = size
        # NOTE: randint's bounds are inclusive, so a drawn seed lies in [0, 2**32].
        self.seed = seed if seed is not None else Random().randint(0, 2**32)
        # Fix: initialize the iteration cursor eagerly so calling next() on a
        # dataset that was never passed through iter() iterates from index 0
        # instead of raising AttributeError.
        self._current_idx = 0

    @property
    def category(self) -> str:
        """Extract category from the module name."""
        module_name = self.__class__.__module__
        parts = module_name.split(".")
        if len(parts) >= 3:
            return parts[1]  # reasoning_gym.{category}.dataset_name
        return "other"

    def __len__(self) -> int:
        """Return the virtual size of the dataset"""
        return self.size

    def __iter__(self):
        """Make the dataset iterable (restarts iteration from index 0)"""
        self._current_idx = 0
        return self

    def __next__(self) -> dict[str, Any]:
        """Get next item in iteration"""
        if self._current_idx >= self.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    @abstractmethod
    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Generate a single dataset item

        Args:
            idx: Index of the item to generate

        Returns:
            dict containing at least:
            - question: str
            - answer: str
            - metadata: dict
        """
        raise NotImplementedError

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score ``answer`` against the oracle answer stored in ``entry``.

        Overwrite this method in derived classes if a single oracle answer is
        not available.

        Returns:
            1.0 for an exact match, partial credit of
            ``len(oracle) / len(answer)`` when the oracle answer is contained
            in the response, and 0.0 otherwise (including ``None``/empty).
        """
        oracle_answer = entry["answer"]
        reward = 0.0
        if isinstance(answer, str) and len(answer) > 0:
            if answer == oracle_answer:
                reward = 1.0
            elif oracle_answer in answer:
                # Partial credit proportional to how much of the response the
                # oracle answer accounts for.
                reward = len(oracle_answer) / len(answer)
        return reward

    def score_answer_cascade(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score with fallback cascade (LaTeX stripping, string, float, math matching).

        Runs this dataset's ``score_answer`` first, then progressively more
        lenient matchers. The cascade can only upgrade, never downgrade.

        Requires ``pip install reasoning-gym[scoring]`` for the ``math_match``
        step (other steps work without extra dependencies).
        """
        # Imported lazily so the optional scoring dependencies are only
        # required when the cascade is actually used.
        from .scoring import cascade_score

        if answer is None:
            return 0.0
        return cascade_score(answer, entry.get("answer", ""), dataset=self, entry=entry)
T = TypeVar("T", bound="ProceduralDataset")


class ReseedingDataset(Iterable[dict[str, Any]]):
    """Wrapper that makes any ProceduralDataset infinite by reseeding when reaching the end"""

    def __init__(self, dataset: T, chunk_size: int = 500):
        """Initialize with dataset instance and chunk size

        Args:
            dataset: The ProceduralDataset instance to wrap
            chunk_size: Size of each generated chunk before reseeding
        """
        self.dataset = dataset
        self.dataset_cls: Type[T] = type(dataset)
        self.chunk_size = chunk_size
        # Begin positioned at item 0 of a freshly seeded chunk 0.
        self._rewind()

    def _rewind(self) -> None:
        """Reset iteration state to the first item of a newly built chunk 0."""
        self._current_chunk = 0
        self._current_dataset = self._create_chunk(0)
        self._current_idx = 0

    def _create_chunk(self, chunk_num: int) -> T:
        """Build a fresh dataset instance whose seed is unique to ``chunk_num``."""
        chunk_config = deepcopy(self.dataset.config)
        if hasattr(chunk_config, "seed"):
            # Offset the wrapped dataset's base seed by the chunk number,
            # wrapping around at 2^32 to stay within 32-bit seed range.
            chunk_config.seed = (self.dataset.seed + chunk_num) % (2**32)
        return self.dataset_cls(chunk_config)

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Restart iteration from the very first item and return self."""
        self._rewind()
        return self

    def __next__(self) -> dict[str, Any]:
        """Return the next item, rolling over to a freshly seeded chunk when exhausted."""
        exhausted = self._current_idx >= self.chunk_size
        if exhausted:
            self._current_chunk += 1
            self._current_dataset = self._create_chunk(self._current_chunk)
            self._current_idx = 0
        item = self._current_dataset[self._current_idx]
        self._current_idx += 1
        return item

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Delegate scoring to the wrapped dataset's implementation."""
        return self.dataset.score_answer(answer, entry)

    def score_answer_cascade(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Delegate cascade scoring to the wrapped dataset's implementation."""
        return self.dataset.score_answer_cascade(answer, entry)