Minor question template & score_answer improvements (#261)

* math prompt improvements
* ignore brackets in complex_arithmetic results
* improve additional instruction in prompt of polynomial_equations
* more strict tests for score_answer in polynomial_equations
* simplify special reward handling
* fix test_intermediate_integration
* fix sokoban dataset
* add common dataset score_answer consistency test
This commit is contained in:
Andreas Köpf 2025-03-04 21:55:09 +01:00 committed by GitHub
parent 061282e373
commit 5d7fbac0ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
106 changed files with 403 additions and 507 deletions

View file

@ -1,6 +1,6 @@
import random
from dataclasses import dataclass
from decimal import Decimal
from decimal import Decimal, InvalidOperation
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -129,7 +129,11 @@ class DecimalChainSumDataset(ProceduralDataset):
result -= c
expression = " ".join(expression_parts)
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
try:
q = Decimal(f"0.{'0' * max(decimal_places)}")
result = result.quantize(q)
except InvalidOperation:
pass
return expression, result
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -141,16 +145,19 @@ class DecimalChainSumDataset(ProceduralDataset):
Returns:
1.0 for exact numerical match, 0.01 otherwise
"""
if answer is None or len(answer.strip()) == 0:
if not isinstance(answer, str) or len(answer.strip()) == 0:
return 0.0
try:
student_answer = Decimal(answer.strip())
oracle_answer = Decimal(entry["answer"])
return 1.0 if student_answer == oracle_answer else 0.01
except (ValueError, TypeError, ArithmeticError):
return 0.01
if student_answer == oracle_answer:
return 1.0
except Exception:
pass
return 0.0
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig)