Template Abstraction

This commit is contained in:
EduardDurech 2025-02-07 17:06:50 +00:00
parent 227319f1da
commit c2e3dbc826
8 changed files with 325 additions and 387 deletions

View file

@ -1,12 +1,13 @@
from dataclasses import dataclass
"""
Chain arithmetic exercise that evaluates expressions with operator precedence.
"""
from typing import Dict, Any
import operator
import numpy as np
from reasoning_gym.core.base_curriculum import BaseCurriculum
@dataclass
class ChainSumDataset:
"""Dataset generator for chain arithmetic problems."""
class ChainSumExercise:
"""Exercise generator for chain arithmetic problems."""
def __init__(self):
# Define operator mappings
self.pedmas = {
@ -18,8 +19,16 @@ class ChainSumDataset:
}
self.curriculum = None
def generate(self, curriculum: BaseCurriculum) -> Dict[str, Any]:
"""Generate a problem using the curriculum's template system"""
def generate(self, curriculum: Any) -> Dict[str, Any]:
"""
Generate a problem using the curriculum's template system.
Returns:
Dict containing:
- question: str (e.g. "What is 2 + 3 * 4?")
- answer: float (the computed result)
- metadata: dict with parsed expression details
"""
self.curriculum = curriculum
max_attempts = 10
@ -32,35 +41,68 @@ class ChainSumDataset:
continue
raise
def _parse_expression(self, executed_parts: Dict[str, str]) -> tuple[list, list]:
"""Extract values and operators from executed parts"""
values = []
operators = []
def _parse_expression(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Parse the template metadata into structured data.
Args:
metadata: Raw metadata from template evaluation
Returns:
Dictionary containing:
- values: List of numeric values
- operators: List of operators
- structure: Expression structure info
"""
expr_parts = metadata["expression"]["executed_parts"]
parsed = {
"values": [],
"operators": [],
"structure": {
"num_terms": 0,
"notations": []
}
}
# Extract values
i = 0
while f"term_{i}" in executed_parts:
val = executed_parts[f"term_{i}"].lstrip('+')
while f"term_{i}" in expr_parts:
val = expr_parts[f"term_{i}"].lstrip('+')
try:
num = val.lstrip('-')
if num.startswith(('0b', '0x')):
sign = -1 if val.startswith('-') else 1
base = 2 if num.startswith('0b') else 16 if num.startswith('0x') else 10
values.append(sign * float(int(num[2:], base)))
parsed["values"].append(sign * float(int(num[2:], base)))
parsed["structure"]["notations"].append(f"base{base}")
else:
values.append(float(val))
parsed["values"].append(float(val))
parsed["structure"]["notations"].append("scientific" if 'e' in num.lower() else "regular")
except ValueError:
values.append(val)
parsed["values"].append(val)
parsed["structure"]["notations"].append("unknown")
i += 1
parsed["structure"]["num_terms"] = i
# Extract operators
for i in range(len(values) - 1):
if f"op_{i}" in executed_parts:
operators.append(executed_parts[f"op_{i}"])
for i in range(len(parsed["values"]) - 1):
if f"op_{i}" in expr_parts:
parsed["operators"].append(expr_parts[f"op_{i}"])
return values, operators
return parsed
def _evaluate_expression(self, parsed: Dict[str, Any]) -> float:
"""
Evaluate expression respecting operator precedence.
Args:
parsed: Dictionary containing parsed expression data
Returns:
float: The computed result
"""
values = parsed["values"]
operators = parsed["operators"]
def _evaluate_expression(self, values: list, operators: list) -> float:
"""Evaluate expression respecting operator precedence"""
if not operators:
return values[0] if values else 0