formatting, cleanup

This commit is contained in:
Andreas Koepf 2025-01-24 17:12:42 +01:00
parent b767e58e48
commit 3dc80be7d2
12 changed files with 189 additions and 376 deletions

View file

@ -1,10 +1,10 @@
import random
import string
from dataclasses import dataclass
from typing import Optional, Tuple
import string
import sympy
from sympy import Symbol, solve, Eq
from sympy import Eq, Symbol, solve
from ..dataset import ProceduralDataset
@ -12,11 +12,12 @@ from ..dataset import ProceduralDataset
@dataclass
class SimpleEquationsConfig:
"""Configuration for simple equation task generation"""
min_terms: int = 2 # Minimum number of terms in expression
max_terms: int = 4 # Maximum number of terms
min_value: int = 1 # Minimum value for constants
max_value: int = 100 # Maximum value for constants
operators: tuple = ('+', '-', '*') # Allowed operators
operators: tuple = ("+", "-", "*") # Allowed operators
seed: Optional[int] = None
size: int = 500
@ -44,7 +45,7 @@ class SimpleEquationsDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict:
"""Generate a single equation task
Returns:
dict with keys:
- question: str, the equation to solve (e.g. "3 * x = 12")
@ -52,18 +53,18 @@ class SimpleEquationsDataset(ProceduralDataset):
- metadata: dict with generation parameters
"""
rng = random.Random(self.seed + idx)
# Get variable and generate equation
variable = self._get_variable(rng)
equation, solution = self._generate_equation(rng, variable)
return {
"question": rng.choice(self._prompt_templates).format(variable=variable, equation=equation),
"answer": str(solution),
"metadata": {
"equation": equation,
"variable": variable,
}
},
}
def _get_variable(self, rng: random.Random) -> str:
@ -72,60 +73,60 @@ class SimpleEquationsDataset(ProceduralDataset):
def _generate_equation(self, rng: random.Random, variable: str) -> Tuple[str, int]:
    """Generate an equation and its solution

    Builds a random left-hand expression from constant terms, exactly one of
    which is replaced by ``coef * variable``, joins the terms with randomly
    chosen operators, picks a random right-hand constant and solves for the
    variable. Candidates are redrawn until the first solution is a positive
    integer.

    Retries are done in a loop rather than by self-recursion, so a long run
    of rejected candidates cannot exhaust the interpreter's recursion limit.
    The random draws happen in the same order as before, so a given ``rng``
    state yields the same equation.

    Args:
        rng: Random number generator
        variable: Variable symbol to use in equation

    Returns:
        Tuple of (equation string, solution integer)

    Raises:
        RuntimeError: If no equation with a positive integer solution was
            found within the attempt budget (e.g. the configured value range
            makes such an equation impossible).
    """
    # Generous safety net: far more attempts than the recursive version could
    # ever make before hitting the recursion limit, but still finite.
    max_attempts = 10_000

    x = Symbol(variable)
    cfg = self.config

    for _ in range(max_attempts):
        # Generate all constant terms first
        num_terms = rng.randint(cfg.min_terms, cfg.max_terms)
        terms = [rng.randint(cfg.min_value, cfg.max_value) for _ in range(num_terms)]

        # Replace one random term with the variable term
        var_pos = rng.randint(0, num_terms - 1)
        coef = rng.randint(cfg.min_value, cfg.max_value)
        terms[var_pos] = coef * x

        # Apply operators between terms, left to right
        expr = terms[0]
        for term in terms[1:]:
            op = rng.choice(cfg.operators)
            if op == "+":
                expr = expr + term
            elif op == "-":
                expr = expr - term
            else:  # "*" (any other configured operator is also treated as multiplication)
                expr = expr * term
        left_side = expr

        # Generate right side
        right_side = rng.randint(cfg.min_value, cfg.max_value)

        # Create equation and solve for the variable
        solutions = solve(Eq(left_side, right_side), x)

        # No solution at all: draw a new candidate
        if not solutions:
            continue

        solution = solutions[0]

        # Only accept a positive integer solution
        if isinstance(solution, sympy.Integer) and solution > 0:
            return f"{left_side} = {right_side}", int(solution)

    raise RuntimeError(
        f"could not generate an equation with a positive integer solution "
        f"after {max_attempts} attempts"
    )
@ -134,7 +135,7 @@ def simple_equations_dataset(
max_terms: int = 5,
min_value: int = 1,
max_value: int = 100,
operators: tuple = ('+', '-', '*'),
operators: tuple = ("+", "-", "*"),
seed: Optional[int] = None,
size: int = 500,
) -> SimpleEquationsDataset: