normalize answer and partial reward

This commit is contained in:
Zafir Stojanovski 2025-02-09 11:13:23 +01:00
parent 1f9d9d27ab
commit ef2a412c8b
2 changed files with 52 additions and 1 deletions

View file

@ -1,8 +1,9 @@
"""Prime factorization task generator"""
import math
from dataclasses import dataclass
from random import Random
from typing import Any, Dict, List, Optional
from ..factory import ProceduralDataset, register_dataset
@ -43,6 +44,25 @@ class PrimeFactorizationDataset(ProceduralDataset):
break
return factors
def _normalize_answer(self, answer: str) -> List[int]:
    """Turn a "×"-separated factor string into an ascending list of ints.

    Whitespace around each factor is ignored. A piece that is not a valid
    integer literal makes ``int`` raise ``ValueError``.
    """
    factors = [int(piece.strip()) for piece in answer.split("×")]
    factors.sort()
    return factors
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
    """Score a candidate factorization against the entry's reference answer.

    Args:
        answer: Candidate answer written as "×"-separated factors, or
            ``None`` when no answer was produced.
        entry: Dataset entry; ``entry["answer"]`` holds the reference
            factorization in the same "×"-separated format.

    Returns:
        1.0 if the candidate's sorted factors equal the reference's,
        0.5 if the factors differ but multiply to the same product,
        0.01 for any other answer (including one that cannot be parsed
        as integers), and 0.0 when ``answer`` is ``None``.
    """
    oracle_answer = entry["answer"]
    if answer is None:
        return 0.0

    expected_factors = self._normalize_answer(oracle_answer)
    try:
        given_factors = self._normalize_answer(answer)
    except ValueError:
        # The candidate is free-form text; anything that is not a list of
        # integers is simply a wrong answer and must not crash scoring.
        return 0.01

    if given_factors == expected_factors:
        return 1.0
    # Same product but different factors, e.g. "4 × 3" for 12: the number is
    # right but the factorization is not fully reduced to the reference.
    if math.prod(given_factors) == math.prod(expected_factors):
        return 0.5
    return 0.01
def __getitem__(self, idx: int) -> dict:
"""Generate a single prime factorization task"""
rng = Random(self.seed + idx)