mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
decimal refactor
This commit is contained in:
parent
810afd5d05
commit
83e5e92126
2 changed files with 77 additions and 16 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
|
@ -73,7 +74,7 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, float]:
|
||||
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, Decimal]:
|
||||
"""Generate a single decimal chain sum task
|
||||
|
||||
Args:
|
||||
|
|
@ -85,37 +86,43 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
max_decimal_places: Maximum number of decimal places
|
||||
|
||||
Returns:
|
||||
Tuple of (expression string, result float)
|
||||
Tuple of (expression string, result Decimal)
|
||||
"""
|
||||
|
||||
if self.config.allow_negation:
|
||||
# Allow both positive and negative numbers
|
||||
constants = [rng.randint(-max_value, max_value) for _ in range(num_terms)]
|
||||
else:
|
||||
# Only positive numbers
|
||||
constants = [rng.randint(min_value, max_value) for _ in range(num_terms)]
|
||||
# Convert constants to Decimal
|
||||
constants = [
|
||||
Decimal(
|
||||
str(
|
||||
rng.randint(-max_value, max_value)
|
||||
if self.config.allow_negation
|
||||
else rng.randint(min_value, max_value)
|
||||
)
|
||||
)
|
||||
for _ in range(num_terms)
|
||||
]
|
||||
|
||||
# Generate decimal places for each term
|
||||
decimal_places = [
|
||||
rng.randint(self.config.min_decimal_places, self.config.max_decimal_places) for _ in range(num_terms)
|
||||
]
|
||||
|
||||
# Add decimal parts using Decimal for precise arithmetic
|
||||
for i in range(num_terms):
|
||||
min_val = 0 if decimal_places[i] == 0 else 10 ** (decimal_places[i] - 1)
|
||||
max_val = (10 ** decimal_places[i]) - 1
|
||||
decimal = rng.randint(min_val, max_val)
|
||||
constants[i] += decimal / 10 ** decimal_places[i]
|
||||
decimal_part = Decimal(str(rng.randint(min_val, max_val))) / Decimal(str(10 ** decimal_places[i]))
|
||||
constants[i] += decimal_part
|
||||
|
||||
operators = [rng.choice(["+", "-"]) for _ in range(num_terms - 1)]
|
||||
|
||||
expression_parts = []
|
||||
result = constants[0]
|
||||
|
||||
expression_parts.append(f"{constants[0]:.{max(decimal_places)}f}")
|
||||
expression_parts.append(f"{constants[0]:.{decimal_places[0]}f}")
|
||||
for i, op in enumerate(operators):
|
||||
c = constants[i + 1]
|
||||
expression_parts.append(op)
|
||||
expression_parts.append(f"{c:.{max(decimal_places)}f}")
|
||||
expression_parts.append(f"{c:.{decimal_places[i+1]}f}")
|
||||
|
||||
if op == "+":
|
||||
result += c
|
||||
|
|
@ -123,5 +130,25 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
result -= c
|
||||
|
||||
expression = " ".join(expression_parts)
|
||||
result = round(result, max(decimal_places))
|
||||
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
|
||||
return expression, result
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
"""Score the answer by comparing decimal values instead of strings.
|
||||
Args:
|
||||
answer: The answer to score
|
||||
entry: The entry containing the oracle answer
|
||||
|
||||
Returns:
|
||||
1.0 for exact numerical match, 0.01 otherwise
|
||||
"""
|
||||
if answer is None or len(answer.strip()) == 0:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
student_answer = Decimal(answer.strip())
|
||||
oracle_answer = Decimal(entry["answer"])
|
||||
|
||||
return 1.0 if student_answer == oracle_answer else 0.01
|
||||
except (ValueError, TypeError, ArithmeticError):
|
||||
return 0.01
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue