mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
normalize answer and partial reward
This commit is contained in:
parent
1f9d9d27ab
commit
ef2a412c8b
2 changed files with 52 additions and 1 deletions
|
|
@ -1,8 +1,9 @@
|
|||
"""Prime factorization task generator"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -43,6 +44,25 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
|||
break
|
||||
return factors
|
||||
|
||||
def _normalize_answer(self, answer: str) -> List[int]:
|
||||
"""Parse and sort factors from a string"""
|
||||
return sorted([int(factor.strip()) for factor in answer.split("×")])
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
    """Score a submitted factorization against the reference in ``entry``.

    Both the reference and the submission are parsed with
    ``self._normalize_answer`` before being compared.

    Args:
        answer: The submitted answer, or ``None`` when nothing was submitted.
        entry: Dataset item; only ``entry["answer"]`` is read.

    Returns:
        ``0.0`` when ``answer`` is ``None``; ``1.0`` when the parsed factor
        lists are identical; ``0.5`` when the lists differ but multiply to
        the same product; ``0.01`` for every other submission, including
        one that cannot be parsed as integers.
    """
    # NOTE(review): ``any`` in the signature is the builtin function;
    # ``typing.Any`` was presumably intended.
    oracle_answer = entry["answer"]
    if answer is None:
        return 0.0

    # The reference is parsed outside the ``try`` on purpose: a malformed
    # reference should surface as an error rather than be scored.
    expected = self._normalize_answer(oracle_answer)
    try:
        submitted = self._normalize_answer(answer)
    except ValueError:
        # A fragment of the submission was not an integer; treat it like any
        # other wrong answer instead of letting the scorer crash.
        return 0.01

    if submitted == expected:
        return 1.0
    if math.prod(submitted) == math.prod(expected):
        # Same number, but not the same multiset of factors.
        return 0.5
    return 0.01
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single prime factorization task"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue