diff --git a/reasoning_gym/algebra/intermediate_integration.py b/reasoning_gym/algebra/intermediate_integration.py new file mode 100644 index 00000000..df0388a8 --- /dev/null +++ b/reasoning_gym/algebra/intermediate_integration.py @@ -0,0 +1,234 @@ +import random +from dataclasses import dataclass +from fractions import Fraction +from typing import List, Optional + +import sympy + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class IntermediateIntegrationConfig: + problem_types: tuple = ("substitution", "by_parts") + substitution_types: tuple = ( + "linear", # (ax + b)^n + "trigonometric", # sin**2(x)cos(x) + "exponential", # 2xe^x**2 + "radical", # x (3x + 2)^1/2 + ) + + # Integration by parts problem categories + by_parts_types: tuple = ( + "polynomial_exp_trig", # e.g. x^2*e^x + "log_inverse_trig", # e.g. ln(x)/arctan(x) + "cyclic", # e.g. e^x*sinx requiring cyclic integration + "repeated_parts", # Requires multiple integration by parts + ) + seed: Optional[int] = None + size: int = 500 + + linear_lower_bound: int = 1 # coefficient of linear expression + linear_upper_bound: int = 10 + min_linear_degree: int = 2 # degree of linear expression in substitution problem + max_linear_degree: int = 4 + outer_constant_min: int = 1 # multiplicative constant to multiply integrand by + outer_constant_max: int = 3 + min_poly_degree: int = 1 # degree of polynomial in by parts problem + max_poly_degree: int = 3 + symbols: tuple = ("x", "X") + operators: tuple = ( + "+", + "-", + ) + + def validate(self) -> None: + """Validate the configuration parameters of the integral problem""" + assert self.size > 0, "size must be positive" + assert self.linear_lower_bound > 0, "linear_lower_bound must be positive" + assert self.linear_upper_bound >= self.linear_lower_bound, "linear_upper_bound must be >= linear_lower_bound" + assert self.min_linear_degree > 0, "min_linear_degree must be positive" + assert self.max_linear_degree >= self.min_linear_degree, "max_linear_degree must be >= min_linear_degree" + assert self.outer_constant_min > 0, "outer_constant_min must be positive" + assert self.outer_constant_max >= self.outer_constant_min, "outer_constant_max must be >= outer_constant_min" + assert self.min_poly_degree > 0, "min_poly_degree must be positive" + assert self.max_poly_degree >= self.min_poly_degree, "max_poly_degree must be >= min_poly_degree" + assert all(op in ("+", "-") for op in self.operators), "invalid operator specified" + assert all(symbols in ("x", "X") for symbols in self.symbols), "invalid symbol specified" + assert all(t in ("substitution", "by_parts") for t in self.problem_types), "invalid problem type" + assert all( + t in ("linear", "trigonometric", "exponential", "radical") for t in self.substitution_types + ), "invalid substitution type" + assert all( + t in ("polynomial_exp_trig", "log_inverse_trig", "cyclic", "repeated_parts") for t in self.by_parts_types + ), "invalid by_parts type" + + +class IntermediateIntegrationDataset(ProceduralDataset): + """Generates intermediate integration problem - either + by substitution or by parts""" + + """Add multiplicative constant""" + + def __init__(self, config: IntermediateIntegrationConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.prompt_template = [ + "Find the indefinite integral: ∫ {integrand} dx", + "Calculate the antiderivative: ∫ {integrand} dx", + "Evaluate the indefinite integral: ∫ {integrand} dx", + ] + + def _get_outer_constant(self, rng: random.Random) -> int: + """Helper to generate signed outer constant from config""" + value = rng.randint(self.config.outer_constant_min, self.config.outer_constant_max) + return -value if rng.choice(self.config.operators) == "-" else value + + def _generate_linear_substitution_problem(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate a linear substitution problem with outer constant""" + a = rng.randint(self.config.linear_lower_bound, self.config.linear_upper_bound) + b = rng.randint(self.config.linear_lower_bound, self.config.linear_upper_bound) + + linear_function = a * x + (b if rng.choice(self.config.operators) == "+" else -b) + degree = rng.randint(self.config.min_linear_degree, self.config.max_linear_degree) + + return self._get_outer_constant(rng) * linear_function**degree + + def _generate_exponential_substitution(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate exponential substitution problem with outer constant""" + exponent_type = rng.choice(["linear", "quadratic"]) + + # Generate terms with signs + num_terms = 2 if exponent_type == "linear" else 3 + terms = [ + (-1 if rng.choice(self.config.operators) == "-" else 1) + * rng.randint(self.config.linear_lower_bound, self.config.linear_upper_bound) + for _ in range(num_terms) + ] + + if exponent_type == "linear": + u = terms[0] * x + terms[1] + du_dx = terms[0] + else: # Quadratic + u = terms[0] * x**2 + terms[1] * x + terms[2] + du_dx = 2 * terms[0] * x + terms[1] + + return self._get_outer_constant(rng) * du_dx * sympy.exp(u) + + def _generate_radical_substitution(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate radical substitution problem with outer constant""" + + # Generate linear expression under radical: ax + b with possible negative coefficients + a = (-1 if rng.choice(self.config.operators) == "-" else 1) * rng.randint( + self.config.linear_lower_bound, self.config.linear_upper_bound + ) + b = (-1 if rng.choice(self.config.operators) == "-" else 1) * rng.randint( + self.config.linear_lower_bound, self.config.linear_upper_bound + ) + + u = a * x + b + derivative = a # du/dx + + integrand = derivative * sympy.sqrt(u) + return self._get_outer_constant(rng) * integrand + + def _generate_trigonometric_substitution(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate trigonometric substitution with outer constant""" + trig_func = rng.choice(["sin", "cos"]) + + # Generate signed coefficients + a = (-1 if rng.choice(self.config.operators) == "-" else 1) * rng.randint( + self.config.linear_lower_bound, self.config.linear_upper_bound + ) + b = (-1 if rng.choice(self.config.operators) == "-" else 1) * rng.randint( + self.config.linear_lower_bound, self.config.linear_upper_bound + ) + + inner = a * x + b + power = rng.randint(1, 4) + if trig_func == "sin": + integrand = a * sympy.cos(inner) * sympy.sin(inner) ** power + else: + integrand = -a * sympy.sin(inner) * sympy.cos(inner) ** power + return self._get_outer_constant(rng) * integrand + + def _generate_polynomial_exp_trig(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate polynomial × exponential/trigonometric integrand""" + poly_degree = rng.randint(self.config.min_poly_degree, self.config.max_poly_degree) + + func_type = rng.choice(["exp", "sin", "cos"]) + if func_type == "exp": + transcendental = sympy.exp(x) + else: + coefficient = rng.randint(1, 3) + transcendental = sympy.Function(func_type)(coefficient * x) + + polynomial = x**poly_degree + integrand = polynomial * transcendental + return self._get_outer_constant(rng) * integrand + + def _generate_log_inverse_trig(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate logarithmic or inverse trigonometric integrand""" + func_type = rng.choice(["log", "asin", "atan"]) + + if func_type == "log": + log_arg = x if rng.random() < 0.8 else x ** rng.randint(2, 3) + func = sympy.ln(log_arg) + else: + coefficient = rng.randint(1, 3) + func = sympy.Function(func_type)(coefficient * x) + + return self._get_outer_constant(rng) * func + + def _generate_cyclic_integral(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr: + """Generate cyclic integral (e.g., e^x * sinx)""" + func_pair = rng.choice( + [(sympy.exp(x), sympy.sin(x)), (sympy.exp(x), sympy.cos(x)), (sympy.sin(x), sympy.cos(x))] + ) + integrand = func_pair[0] * func_pair[1] + return self._get_outer_constant(rng) * integrand + + def _generate_repeated_parts(self, rng: random.Random, x: sympy.Symbol): + """Generate problem requiring multiple integration by parts""" + poly_degree = rng.randint(3, self.config.max_poly_degree) + transcendental = rng.choice([sympy.sin(x), sympy.cos(x), sympy.exp(x)]) + integrand = x**poly_degree * transcendental + return self._get_outer_constant(rng) * integrand + + def __getitem__(self, index: int): + """Generate either substitution or by-parts problem""" + rng = random.Random(self.seed + index) + problem_type = rng.choice(self.config.problem_types) + x = sympy.Symbol(rng.choice(self.config.symbols)) + + if problem_type == "substitution": + substitution_type = rng.choice(self.config.substitution_types) + if substitution_type == "linear": + integrand = self._generate_linear_substitution_problem(rng, x) + elif substitution_type == "trigonometric": + integrand = self._generate_trigonometric_substitution(rng, x) + elif substitution_type == "exponential": + integrand = self._generate_exponential_substitution(rng, x) + elif substitution_type == "radical": + integrand = self._generate_radical_substitution(rng, x) + else: + parts_type = rng.choice(self.config.by_parts_types) + if parts_type == "polynomial_exp_trig": + integrand = self._generate_polynomial_exp_trig(rng, x) + elif parts_type == "log_inverse_trig": + integrand = self._generate_log_inverse_trig(rng, x) + elif parts_type == "cyclic": + integrand = self._generate_cyclic_integral(rng, x) + elif parts_type == "repeated_parts": + integrand = self._generate_repeated_parts(rng, x) + + answer = sympy.integrate(integrand, x) + return { + "question": rng.choice(self.prompt_template).format(integrand=integrand), + "answer": str(answer) + " + C", + "metadata": { + "integrand": str(integrand), + "problem_type": problem_type, + "variable": str(x), + "type": substitution_type if problem_type == "substitution" else parts_type, + }, + }