decimal refactor

This commit is contained in:
vncntt 2025-02-19 14:46:27 -08:00
parent 810afd5d05
commit 83e5e92126
2 changed files with 77 additions and 16 deletions

View file

@ -1,6 +1,7 @@
import random
from dataclasses import dataclass
from typing import Optional
from decimal import Decimal
from typing import Any, Dict, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -73,7 +74,7 @@ class DecimalChainSumDataset(ProceduralDataset):
},
}
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, float]:
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, Decimal]:
"""Generate a single decimal chain sum task
Args:
@ -85,37 +86,43 @@ class DecimalChainSumDataset(ProceduralDataset):
max_decimal_places: Maximum number of decimal places
Returns:
Tuple of (expression string, result float)
Tuple of (expression string, result Decimal)
"""
if self.config.allow_negation:
# Allow both positive and negative numbers
constants = [rng.randint(-max_value, max_value) for _ in range(num_terms)]
else:
# Only positive numbers
constants = [rng.randint(min_value, max_value) for _ in range(num_terms)]
# Convert constants to Decimal
constants = [
Decimal(
str(
rng.randint(-max_value, max_value)
if self.config.allow_negation
else rng.randint(min_value, max_value)
)
)
for _ in range(num_terms)
]
# Generate decimal places for each term
decimal_places = [
rng.randint(self.config.min_decimal_places, self.config.max_decimal_places) for _ in range(num_terms)
]
# Add decimal parts using Decimal for precise arithmetic
for i in range(num_terms):
min_val = 0 if decimal_places[i] == 0 else 10 ** (decimal_places[i] - 1)
max_val = (10 ** decimal_places[i]) - 1
decimal = rng.randint(min_val, max_val)
constants[i] += decimal / 10 ** decimal_places[i]
decimal_part = Decimal(str(rng.randint(min_val, max_val))) / Decimal(str(10 ** decimal_places[i]))
constants[i] += decimal_part
operators = [rng.choice(["+", "-"]) for _ in range(num_terms - 1)]
expression_parts = []
result = constants[0]
expression_parts.append(f"{constants[0]:.{max(decimal_places)}f}")
expression_parts.append(f"{constants[0]:.{decimal_places[0]}f}")
for i, op in enumerate(operators):
c = constants[i + 1]
expression_parts.append(op)
expression_parts.append(f"{c:.{max(decimal_places)}f}")
expression_parts.append(f"{c:.{decimal_places[i+1]}f}")
if op == "+":
result += c
@ -123,5 +130,25 @@ class DecimalChainSumDataset(ProceduralDataset):
result -= c
expression = " ".join(expression_parts)
result = round(result, max(decimal_places))
result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}"))
return expression, result
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
    """Score the answer by comparing decimal values instead of strings.

    Both the answer and the oracle value are parsed as ``Decimal``, so
    representations that differ only in formatting (for example ``"1.5"``
    and ``"1.50"``) compare as equal.

    Args:
        answer: The answer to score; may be None.
        entry: The entry containing the oracle answer under the "answer" key.

    Returns:
        0.0 if the answer is None or blank, 1.0 for an exact numerical match,
        and 0.01 otherwise (including when either value cannot be parsed as
        a Decimal).
    """
    if answer is None:
        return 0.0

    # Strip once and reuse the result for both the blank check and parsing.
    candidate = answer.strip()
    if not candidate:
        return 0.0

    try:
        is_match = Decimal(candidate) == Decimal(entry["answer"])
    except (ValueError, TypeError, ArithmeticError):
        # decimal.InvalidOperation (raised for unparsable text) derives from
        # ArithmeticError; TypeError covers unsupported oracle value types.
        return 0.01
    return 1.0 if is_match else 0.01