normalize answer and partial reward

This commit is contained in:
Zafir Stojanovski 2025-02-09 11:13:23 +01:00
parent 1f9d9d27ab
commit ef2a412c8b
2 changed files with 52 additions and 1 deletions

View file

@ -1,8 +1,9 @@
"""Prime factorization task generator"""
import math
from dataclasses import dataclass
from random import Random
from typing import List, Optional
from typing import Dict, List, Optional
from ..factory import ProceduralDataset, register_dataset
@ -43,6 +44,25 @@ class PrimeFactorizationDataset(ProceduralDataset):
break
return factors
def _normalize_answer(self, answer: str) -> List[int]:
"""Parse and sort factors from a string"""
return sorted([int(factor.strip()) for factor in answer.split("×")])
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
oracle_answer = entry["answer"]
reward = 0.0
if answer is not None:
oracle_answer_parsed = self._normalize_answer(oracle_answer)
answer_parsed = self._normalize_answer(answer)
if oracle_answer_parsed == answer_parsed:
reward = 1.0
elif math.prod(oracle_answer_parsed) == math.prod(answer_parsed):
reward = 0.5
else:
reward = 0.01
return reward
def __getitem__(self, idx: int) -> dict:
"""Generate a single prime factorization task"""
rng = Random(self.seed + idx)

View file

@ -85,6 +85,37 @@ def test_prime_factorization_known_values():
assert item["answer"] == "2 × 2 × 3"
def test_prime_factorization_score_answer():
"""Test scoring of answers"""
config = PrimeFactorizationConfig(min_value=12, max_value=12, size=1, seed=42) # Force specific number
dataset = PrimeFactorizationDataset(config)
item = dataset[0]
# Perfectly ordered answer
answer = "2 × 2 × 3"
assert dataset.score_answer(answer, item) == 1.0
# No white spaces answer (still correct)
answer = "2×2×3"
assert dataset.score_answer(answer, item) == 1.0
# Shuffled factors (still correct)
answer = "2 × 3 × 2"
assert dataset.score_answer(answer, item) == 1.0
# Partially correct answer (not all numbers are fully factorized)
answer = "2 × 6"
assert dataset.score_answer(answer, item) == 0.5
# Incorrect answer
answer = "2 × 5"
assert dataset.score_answer(answer, item) == 0.01
# Answer is none
answer = None
assert dataset.score_answer(answer, item) == 0.0
def is_prime(n: int) -> bool:
"""Helper function to check if a number is prime"""
if n < 2: