mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
Minor question template & score_answer improvements (#261)
* math prompt improvements * ignore brackets in complex_arithmetic results * improve additional instruction in prompt of polynomial_equations * more strict tests for score_answer in polynomial_equations * simplify special reward handling * fix test_intermediate_integration * fix sokoban dataset * add common dataset score_answer consistency test
This commit is contained in:
parent
061282e373
commit
5d7fbac0ad
106 changed files with 403 additions and 507 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
|
@ -129,7 +129,11 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
result -= c
|
||||
|
||||
expression = " ".join(expression_parts)
|
||||
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
|
||||
try:
|
||||
q = Decimal(f"0.{'0' * max(decimal_places)}")
|
||||
result = result.quantize(q)
|
||||
except InvalidOperation:
|
||||
pass
|
||||
return expression, result
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
|
|
@ -141,16 +145,19 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
Returns:
|
||||
1.0 for exact numerical match, 0.01 otherwise
|
||||
"""
|
||||
if answer is None or len(answer.strip()) == 0:
|
||||
if not isinstance(answer, str) or len(answer.strip()) == 0:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
student_answer = Decimal(answer.strip())
|
||||
oracle_answer = Decimal(entry["answer"])
|
||||
|
||||
return 1.0 if student_answer == oracle_answer else 0.01
|
||||
except (ValueError, TypeError, ArithmeticError):
|
||||
return 0.01
|
||||
if student_answer == oracle_answer:
|
||||
return 1.0
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue