mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Template Abstraction
This commit is contained in:
parent
227319f1da
commit
c2e3dbc826
8 changed files with 325 additions and 387 deletions
|
|
@ -1,12 +1,13 @@
|
|||
from dataclasses import dataclass
|
||||
"""
|
||||
Chain arithmetic exercise that evaluates expressions with operator precedence.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
import operator
|
||||
import numpy as np
|
||||
from reasoning_gym.core.base_curriculum import BaseCurriculum
|
||||
|
||||
@dataclass
|
||||
class ChainSumDataset:
|
||||
"""Dataset generator for chain arithmetic problems."""
|
||||
class ChainSumExercise:
|
||||
"""Exercise generator for chain arithmetic problems."""
|
||||
def __init__(self):
|
||||
# Define operator mappings
|
||||
self.pedmas = {
|
||||
|
|
@ -18,8 +19,16 @@ class ChainSumDataset:
|
|||
}
|
||||
self.curriculum = None
|
||||
|
||||
def generate(self, curriculum: BaseCurriculum) -> Dict[str, Any]:
|
||||
"""Generate a problem using the curriculum's template system"""
|
||||
def generate(self, curriculum: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a problem using the curriculum's template system.
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
- question: str (e.g. "What is 2 + 3 * 4?")
|
||||
- answer: float (the computed result)
|
||||
- metadata: dict with parsed expression details
|
||||
"""
|
||||
self.curriculum = curriculum
|
||||
max_attempts = 10
|
||||
|
||||
|
|
@ -32,35 +41,68 @@ class ChainSumDataset:
|
|||
continue
|
||||
raise
|
||||
|
||||
def _parse_expression(self, executed_parts: Dict[str, str]) -> tuple[list, list]:
|
||||
"""Extract values and operators from executed parts"""
|
||||
values = []
|
||||
operators = []
|
||||
def _parse_expression(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse the template metadata into structured data.
|
||||
|
||||
Args:
|
||||
metadata: Raw metadata from template evaluation
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- values: List of numeric values
|
||||
- operators: List of operators
|
||||
- structure: Expression structure info
|
||||
"""
|
||||
expr_parts = metadata["expression"]["executed_parts"]
|
||||
parsed = {
|
||||
"values": [],
|
||||
"operators": [],
|
||||
"structure": {
|
||||
"num_terms": 0,
|
||||
"notations": []
|
||||
}
|
||||
}
|
||||
|
||||
# Extract values
|
||||
i = 0
|
||||
while f"term_{i}" in executed_parts:
|
||||
val = executed_parts[f"term_{i}"].lstrip('+')
|
||||
while f"term_{i}" in expr_parts:
|
||||
val = expr_parts[f"term_{i}"].lstrip('+')
|
||||
try:
|
||||
num = val.lstrip('-')
|
||||
if num.startswith(('0b', '0x')):
|
||||
sign = -1 if val.startswith('-') else 1
|
||||
base = 2 if num.startswith('0b') else 16 if num.startswith('0x') else 10
|
||||
values.append(sign * float(int(num[2:], base)))
|
||||
parsed["values"].append(sign * float(int(num[2:], base)))
|
||||
parsed["structure"]["notations"].append(f"base{base}")
|
||||
else:
|
||||
values.append(float(val))
|
||||
parsed["values"].append(float(val))
|
||||
parsed["structure"]["notations"].append("scientific" if 'e' in num.lower() else "regular")
|
||||
except ValueError:
|
||||
values.append(val)
|
||||
parsed["values"].append(val)
|
||||
parsed["structure"]["notations"].append("unknown")
|
||||
i += 1
|
||||
|
||||
parsed["structure"]["num_terms"] = i
|
||||
|
||||
# Extract operators
|
||||
for i in range(len(values) - 1):
|
||||
if f"op_{i}" in executed_parts:
|
||||
operators.append(executed_parts[f"op_{i}"])
|
||||
for i in range(len(parsed["values"]) - 1):
|
||||
if f"op_{i}" in expr_parts:
|
||||
parsed["operators"].append(expr_parts[f"op_{i}"])
|
||||
|
||||
return values, operators
|
||||
return parsed
|
||||
|
||||
def _evaluate_expression(self, parsed: Dict[str, Any]) -> float:
|
||||
"""
|
||||
Evaluate expression respecting operator precedence.
|
||||
|
||||
Args:
|
||||
parsed: Dictionary containing parsed expression data
|
||||
Returns:
|
||||
float: The computed result
|
||||
"""
|
||||
values = parsed["values"]
|
||||
operators = parsed["operators"]
|
||||
|
||||
def _evaluate_expression(self, values: list, operators: list) -> float:
|
||||
"""Evaluate expression respecting operator precedence"""
|
||||
if not operators:
|
||||
return values[0] if values else 0
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue