Refactor SimpleEquations

This commit is contained in:
EduardDurech 2025-02-08 20:51:18 +00:00
parent b3e61988c1
commit 7dce30324b
6 changed files with 789 additions and 201 deletions

View file

@ -1,7 +1,7 @@
from .polynomial_equations import PolynomialEquationsExercise
# from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset
from .simple_equations import SimpleEquationsExercise
__all__ = [
# "SimpleEquationsDataset",
"SimpleEquationsExercise",
"PolynomialEquationsExercise",
]

View file

@ -1,119 +1,125 @@
import random
import string
from dataclasses import dataclass
from typing import Optional, Tuple
"""
Simple equations exercise that generates and solves linear equations with one variable.
"""
import sympy
from sympy import Eq, Symbol, solve
from typing import Dict, Any
from sympy import Symbol, solve, parse_expr, Eq
from ..factory import ProceduralDataset, register_dataset
class SimpleEquationsExercise:
"""Exercise generator for simple equations with one variable."""
def __init__(self):
self.curriculum = None
@dataclass
class SimpleEquationsConfig:
"""Configuration for simple equation task generation"""
min_terms: int = 2 # Minimum number of terms in expression
max_terms: int = 4 # Maximum number of terms
min_value: int = 1 # Minimum value for constants
max_value: int = 100 # Maximum value for constants
operators: tuple = ("+", "-", "*") # Allowed operators
seed: Optional[int] = None
size: int = 500
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_terms > 0, "min_terms must be positive"
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
assert self.min_value > 0, "min_value must be positive"
assert self.max_value >= self.min_value, "max_value must be >= min_value"
assert len(self.operators) > 0, "must specify at least one operator"
assert all(op in ("+", "-", "*") for op in self.operators), "invalid operator specified"
class SimpleEquationsDataset(ProceduralDataset):
"""Generates simple equations with one variable to solve"""
def __init__(self, config: SimpleEquationsConfig):
self._prompt_templates = [
"Find the value of {variable} in the equation: {equation}",
"Solve for {variable}: {equation}",
"Determine the value of {variable} that satisfies: {equation}",
]
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single equation task
def generate(self, curriculum: Any) -> Dict[str, Any]:
"""
Generate a simple equation problem using the curriculum.
Returns:
dict with keys:
- question: str, the equation to solve (e.g. "3 * x = 12")
- answer: str, the solution value (e.g. "4")
- metadata: dict with generation parameters
Dict containing:
- question: str (e.g. "Find the value of x in the equation: 3*x + 2 = 4*x - 1")
- answer: str (the solution value, e.g. "3")
- metadata: dict with details (equation, variable, etc.)
"""
rng = random.Random(self.seed + idx)
self.curriculum = curriculum
template = curriculum.get_template(curriculum.rng)
return template.eval(self, curriculum.rng)
# Get variable and generate equation
variable = self._get_variable(rng)
equation, solution = self._generate_equation(rng, variable)
def _parse_expression(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
Parse the template metadata into structured data.
return {
"question": rng.choice(self._prompt_templates).format(variable=variable, equation=equation),
"answer": str(solution),
"metadata": {
"equation": equation,
"variable": variable,
The metadata structure is expected to be:
{
"lhs": {
"term_0": {
"sign": str, # "" or "-"
"coeff": str, # coefficient value with "*" if needed
"variable": str # variable name or empty
},
"term_1": {...}, # Same structure as term_0
...,
"op_0": str, # "+" or "-" between terms
"op_1": str, # More operators if needed
...
},
"rhs": { # Same structure as lhs
...
},
"variable": {
"var": str # The variable name used in the equation
}
}
def _get_variable(self, rng: random.Random) -> str:
"""Get a random lowercase variable name"""
return rng.choice(string.ascii_lowercase)
Args:
metadata: Raw metadata from template evaluation
Returns:
Dictionary containing:
- lhs_terms: List[str] of formatted term strings for left side
- rhs_terms: List[str] of formatted term strings for right side
- lhs_operators: List[str] of operators between left terms
- rhs_operators: List[str] of operators between right terms
- variable: str, the variable name used
"""
def parse_side(side_parts: Dict[str, Any]) -> tuple[list, list]:
"""Helper to parse one side of the equation."""
terms = []
operators = []
i = 0
while f"term_{i}" in side_parts:
term_dict = side_parts[f"term_{i}"]
terms.append("".join(term_dict[k] for k in ("sign", "coeff", "variable")))
if f"op_{i}" in side_parts:
operators.append(side_parts[f"op_{i}"])
i += 1
return terms, operators
def _generate_equation(self, rng: random.Random, variable: str) -> Tuple[str, int]:
"""Generate an equation and its solution
# Parse both sides of the equation
lhs_terms, lhs_operators = parse_side(metadata["lhs"])
rhs_terms, rhs_operators = parse_side(metadata["rhs"])
return {
"lhs_terms": lhs_terms,
"rhs_terms": rhs_terms,
"lhs_operators": lhs_operators,
"rhs_operators": rhs_operators,
"variable": metadata["variable"]["var"]
}
def _evaluate_expression(self, parsed: Dict[str, Any]) -> str:
"""
Evaluate the equation and find its solution.
Args:
rng: Random number generator
variable: Variable symbol to use in equation
parsed: Dictionary containing parsed expression data
Returns:
Tuple of (equation string, solution integer)
String representation of the solution
"""
x = Symbol(variable)
# Create sympy symbol from parsed variable
var = Symbol(parsed["variable"])
# Generate terms for left side
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
terms = []
# Build left and right expressions
def build_expr(terms: list, operators: list) -> str:
"""Helper to build expression string from terms and operators."""
expr = terms[0]
for i, op in enumerate(operators):
expr = f"{expr} {op} {terms[i + 1]}"
return expr
# Generate all constant terms first
for _ in range(num_terms):
value = rng.randint(self.config.min_value, self.config.max_value)
terms.append(value)
lhs_expr = build_expr(parsed["lhs_terms"], parsed["lhs_operators"])
rhs_expr = build_expr(parsed["rhs_terms"], parsed["rhs_operators"])
# Replace one random term with the variable term
var_pos = rng.randint(0, num_terms - 1)
coef = rng.randint(self.config.min_value, self.config.max_value)
if "*" in self.config.operators:
terms[var_pos] = coef * x
else:
terms[var_pos] = x
try:
# Parse both sides into sympy expressions
lhs = parse_expr(lhs_expr, local_dict={parsed["variable"]: var})
rhs = parse_expr(rhs_expr, local_dict={parsed["variable"]: var})
# Apply operators between terms
expr = terms[0]
for i in range(1, num_terms):
op = rng.choice(self.config.operators)
if op == "+":
expr = expr + terms[i]
elif op == "-":
expr = expr - terms[i]
else: # '*'
expr = expr * terms[i]
# Solve the equation
solution = solve(Eq(lhs, rhs), var)
left_side = expr
solution_value = rng.randint(self.config.min_value, self.config.max_value)
right_side = left_side.subs(x, solution_value)
return f"{left_side} = {right_side}", solution_value
register_dataset("simple_equations", SimpleEquationsDataset, SimpleEquationsConfig)
# Convert to float and return as string
if solution:
return str(float(solution[0]))
return ""
except Exception as e:
return f"Error solving equation: {lhs_expr} = {rhs_expr}\nError: {str(e)}"