mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
normalize answer and partial reward
This commit is contained in:
parent
1f9d9d27ab
commit
ef2a412c8b
2 changed files with 52 additions and 1 deletions
|
|
@ -1,8 +1,9 @@
|
|||
"""Prime factorization task generator"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -43,6 +44,25 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
|||
break
|
||||
return factors
|
||||
|
||||
def _normalize_answer(self, answer: str) -> List[int]:
|
||||
"""Parse and sort factors from a string"""
|
||||
return sorted([int(factor.strip()) for factor in answer.split("×")])
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
    """Score a proposed factorization against the entry's reference answer.

    Args:
        answer: Proposed answer as factors separated by "×", or None when
            no answer was given.
        entry: Dataset item; ``entry["answer"]`` holds the reference
            factorization in the same "×"-separated format.

    Returns:
        0.0 if ``answer`` is None;
        1.0 if the parsed factors equal the reference factors (order and
        surrounding whitespace are ignored by ``_normalize_answer``);
        0.5 if the factors differ but multiply to the same product
        (e.g. not fully factorized);
        0.01 otherwise, including when ``answer`` cannot be parsed as
        "×"-separated integers.
    """
    # NOTE(review): `any` in the signature is the builtin function, not
    # typing.Any; left unchanged because typing.Any is not imported in the
    # visible part of this file.
    oracle_answer = entry["answer"]
    if answer is None:
        return 0.0

    oracle_factors = self._normalize_answer(oracle_answer)
    try:
        answer_factors = self._normalize_answer(answer)
    except ValueError:
        # Free-form or empty text is a wrong answer, not a reason to crash
        # the scorer: give it the same reward as any other incorrect answer.
        return 0.01

    if answer_factors == oracle_factors:
        return 1.0
    if math.prod(answer_factors) == math.prod(oracle_factors):
        # Right number, but not the (fully) prime factorization.
        return 0.5
    return 0.01
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single prime factorization task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
|
|
|||
|
|
@ -85,6 +85,37 @@ def test_prime_factorization_known_values():
|
|||
assert item["answer"] == "2 × 2 × 3"
|
||||
|
||||
|
||||
def test_prime_factorization_score_answer():
    """Check the reward score_answer gives to each kind of answer."""
    # min_value == max_value forces the specific number 12
    config = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42)
    dataset = PrimeFactorizationDataset(config)
    item = dataset[0]

    # (candidate answer, expected reward), checked in this order
    expected_rewards = [
        ("2 × 2 × 3", 1.0),  # perfectly ordered answer
        ("2×2×3", 1.0),  # no white spaces (still correct)
        ("2 × 3 × 2", 1.0),  # shuffled factors (still correct)
        ("2 × 6", 0.5),  # partially correct: not fully factorized
        ("2 × 5", 0.01),  # incorrect answer
        (None, 0.0),  # answer is None
    ]
    for candidate, reward in expected_rewards:
        assert dataset.score_answer(candidate, item) == reward
|
||||
|
||||
|
||||
def is_prime(n: int) -> bool:
|
||||
"""Helper function to check if a number is prime"""
|
||||
if n < 2:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue