mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
added intermediate integration dataset generator
This commit is contained in:
parent
d4706c7128
commit
ed1492ba05
1 changed file with 234 additions and 0 deletions
234
reasoning_gym/algebra/intermediate_integration.py
Normal file
234
reasoning_gym/algebra/intermediate_integration.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from fractions import Fraction
|
||||
from typing import List, Optional
|
||||
|
||||
import sympy
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
class IntermediateIntegrationConfig:
    """Configuration for generating intermediate integration problems.

    Problems are solved either by substitution or by integration by parts.
    The tuple fields select which categories may be sampled; the integer
    bounds control the size of the sampled coefficients and exponents.
    Call :meth:`validate` to check the values for consistency.
    """

    problem_types: tuple = ("substitution", "by_parts")

    # Substitution problem categories (c is the signed outer constant)
    substitution_types: tuple = (
        "linear",  # c * (a*x + b)**n
        "trigonometric",  # c * a*cos(a*x + b) * sin(a*x + b)**n, or -a*sin(...) * cos(...)**n
        "exponential",  # c * u'(x) * exp(u(x)) with u linear or quadratic
        "radical",  # c * a * sqrt(a*x + b)
    )

    # Integration by parts problem categories
    by_parts_types: tuple = (
        "polynomial_exp_trig",  # x**n times exp(x), sin(k*x) or cos(k*x)
        "log_inverse_trig",  # ln(x), ln(x**k), asin(k*x) or atan(k*x)
        "cyclic",  # exp(x)*sin(x), exp(x)*cos(x) or sin(x)*cos(x)
        "repeated_parts",  # x**n times sin(x), cos(x) or exp(x)
    )
    seed: Optional[int] = None
    size: int = 500

    linear_lower_bound: int = 1  # bounds for the magnitude of inner-expression coefficients
    linear_upper_bound: int = 10
    min_linear_degree: int = 2  # degree of linear expression in substitution problem
    max_linear_degree: int = 4
    outer_constant_min: int = 1  # multiplicative constant to multiply integrand by
    outer_constant_max: int = 3
    min_poly_degree: int = 1  # degree of polynomial in by parts problem
    max_poly_degree: int = 3
    symbols: tuple = ("x", "X")  # candidate names for the integration variable
    operators: tuple = (
        "+",
        "-",
    )  # signs that may be applied to sampled constants

    def validate(self) -> None:
        """Validate the configuration parameters of the integral problem.

        Raises:
            AssertionError: if any parameter is out of range, names an unknown
                category, or leaves a tuple empty that the generator must
                sample from.
        """
        assert self.size > 0, "size must be positive"
        assert self.linear_lower_bound > 0, "linear_lower_bound must be positive"
        assert self.linear_upper_bound >= self.linear_lower_bound, "linear_upper_bound must be >= linear_lower_bound"
        assert self.min_linear_degree > 0, "min_linear_degree must be positive"
        assert self.max_linear_degree >= self.min_linear_degree, "max_linear_degree must be >= min_linear_degree"
        assert self.outer_constant_min > 0, "outer_constant_min must be positive"
        assert self.outer_constant_max >= self.outer_constant_min, "outer_constant_max must be >= outer_constant_min"
        assert self.min_poly_degree > 0, "min_poly_degree must be positive"
        assert self.max_poly_degree >= self.min_poly_degree, "max_poly_degree must be >= min_poly_degree"
        assert all(op in ("+", "-") for op in self.operators), "invalid operator specified"
        assert all(symbol in ("x", "X") for symbol in self.symbols), "invalid symbol specified"
        assert all(t in ("substitution", "by_parts") for t in self.problem_types), "invalid problem type"
        assert all(
            t in ("linear", "trigonometric", "exponential", "radical") for t in self.substitution_types
        ), "invalid substitution type"
        assert all(
            t in ("polynomial_exp_trig", "log_inverse_trig", "cyclic", "repeated_parts") for t in self.by_parts_types
        ), "invalid by_parts type"

        # Emptiness checks come last so that configurations which already
        # failed one of the checks above keep failing with the same message.
        # An empty tuple passes the `all(...)` tests vacuously but would make
        # `rng.choice` raise IndexError during generation.
        assert self.operators, "operators must not be empty"
        assert self.symbols, "symbols must not be empty"
        assert self.problem_types, "problem_types must not be empty"
        # A category tuple only has to be populated when its problem type can
        # actually be selected.
        assert (
            "substitution" not in self.problem_types or self.substitution_types
        ), "substitution_types must not be empty when 'substitution' is a problem type"
        assert (
            "by_parts" not in self.problem_types or self.by_parts_types
        ), "by_parts_types must not be empty when 'by_parts' is a problem type"
|
||||
|
||||
|
||||
class IntermediateIntegrationDataset(ProceduralDataset):
    """Generates intermediate integration problems.

    Each item is an indefinite integral that is meant to be solved either by
    substitution or by integration by parts. Items are generated on demand
    from a random generator seeded with ``self.seed + index``.
    """

    def __init__(self, config: IntermediateIntegrationConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)
        # {variable} keeps the differential in the prompt consistent with the
        # symbol sampled for the integrand (it may be "x" or "X").
        self.prompt_template = [
            "Find the indefinite integral: ∫ {integrand} d{variable}",
            "Calculate the antiderivative: ∫ {integrand} d{variable}",
            "Evaluate the indefinite integral: ∫ {integrand} d{variable}",
        ]

    def _get_outer_constant(self, rng: random.Random) -> int:
        """Helper to generate signed outer constant from config"""
        value = rng.randint(self.config.outer_constant_min, self.config.outer_constant_max)
        return -value if rng.choice(self.config.operators) == "-" else value

    def _get_signed_coefficient(self, rng: random.Random) -> int:
        """Sample a signed coefficient within the configured linear bounds.

        The sign is drawn before the magnitude; this order is part of the
        deterministic random sequence and must not be swapped.
        """
        sign = -1 if rng.choice(self.config.operators) == "-" else 1
        return sign * rng.randint(self.config.linear_lower_bound, self.config.linear_upper_bound)

    def _generate_linear_substitution_problem(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
        """Generate a linear substitution problem with outer constant.

        The integrand has the form ``c * (a*x ± b)**n``.
        """
        a = rng.randint(self.config.linear_lower_bound, self.config.linear_upper_bound)
        b = rng.randint(self.config.linear_lower_bound, self.config.linear_upper_bound)

        linear_function = a * x + (b if rng.choice(self.config.operators) == "+" else -b)
        degree = rng.randint(self.config.min_linear_degree, self.config.max_linear_degree)

        return self._get_outer_constant(rng) * linear_function**degree

    def _generate_exponential_substitution(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
        """Generate exponential substitution problem with outer constant.

        The integrand has the form ``c * u'(x) * exp(u(x))`` where ``u`` is a
        linear or quadratic polynomial with signed coefficients.
        """
        exponent_type = rng.choice(["linear", "quadratic"])

        # Generate terms with signs
        num_terms = 2 if exponent_type == "linear" else 3
        terms = [self._get_signed_coefficient(rng) for _ in range(num_terms)]

        if exponent_type == "linear":
            u = terms[0] * x + terms[1]
            du_dx = terms[0]
        else:  # Quadratic
            u = terms[0] * x**2 + terms[1] * x + terms[2]
            du_dx = 2 * terms[0] * x + terms[1]

        return self._get_outer_constant(rng) * du_dx * sympy.exp(u)

    def _generate_radical_substitution(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
        """Generate radical substitution problem with outer constant.

        The integrand has the form ``c * a * sqrt(a*x + b)``.
        """
        # Linear expression under the radical: a*x + b with signed coefficients
        a = self._get_signed_coefficient(rng)
        b = self._get_signed_coefficient(rng)

        u = a * x + b
        derivative = a  # du/dx

        integrand = derivative * sympy.sqrt(u)
        return self._get_outer_constant(rng) * integrand

    def _generate_trigonometric_substitution(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
        """Generate trigonometric substitution with outer constant.

        With ``u = a*x + b`` the integrand is either
        ``c * a*cos(u) * sin(u)**n`` or ``c * -a*sin(u) * cos(u)**n``.
        """
        trig_func = rng.choice(["sin", "cos"])

        # Generate signed coefficients
        a = self._get_signed_coefficient(rng)
        b = self._get_signed_coefficient(rng)

        inner = a * x + b
        power = rng.randint(1, 4)
        if trig_func == "sin":
            integrand = a * sympy.cos(inner) * sympy.sin(inner) ** power
        else:
            integrand = -a * sympy.sin(inner) * sympy.cos(inner) ** power
        return self._get_outer_constant(rng) * integrand

    def _generate_polynomial_exp_trig(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
        """Generate polynomial × exponential/trigonometric integrand"""
        poly_degree = rng.randint(self.config.min_poly_degree, self.config.max_poly_degree)

        func_type = rng.choice(["exp", "sin", "cos"])
        if func_type == "exp":
            transcendental = sympy.exp(x)
        else:
            coefficient = rng.randint(1, 3)
            # Use the real SymPy functions: sympy.Function("sin") would create
            # an undefined function that sympy.integrate cannot evaluate.
            trig = sympy.sin if func_type == "sin" else sympy.cos
            transcendental = trig(coefficient * x)

        polynomial = x**poly_degree
        integrand = polynomial * transcendental
        return self._get_outer_constant(rng) * integrand

    def _generate_log_inverse_trig(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
        """Generate logarithmic or inverse trigonometric integrand"""
        func_type = rng.choice(["log", "asin", "atan"])

        if func_type == "log":
            # Mostly ln(x); otherwise ln(x**2) or ln(x**3)
            log_arg = x if rng.random() < 0.8 else x ** rng.randint(2, 3)
            func = sympy.ln(log_arg)
        else:
            coefficient = rng.randint(1, 3)
            # Use the real SymPy functions rather than sympy.Function(name),
            # which would yield an undefined function with no known integral.
            inverse_trig = sympy.asin if func_type == "asin" else sympy.atan
            func = inverse_trig(coefficient * x)

        return self._get_outer_constant(rng) * func

    def _generate_cyclic_integral(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
        """Generate cyclic integral (e.g., e^x * sinx)"""
        func_pair = rng.choice(
            [(sympy.exp(x), sympy.sin(x)), (sympy.exp(x), sympy.cos(x)), (sympy.sin(x), sympy.cos(x))]
        )
        integrand = func_pair[0] * func_pair[1]
        return self._get_outer_constant(rng) * integrand

    def _generate_repeated_parts(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
        """Generate problem requiring multiple integration by parts.

        The integrand is ``c * x**n`` times ``sin(x)``, ``cos(x)`` or ``exp(x)``.
        """
        # The original range was randint(3, max_poly_degree), which raises
        # ValueError for max_poly_degree < 3. Clamp the range instead: the
        # degree is at least 2 (the minimum that needs more than one round of
        # integration by parts) and the result is unchanged whenever
        # max_poly_degree >= 3.
        upper = max(2, self.config.max_poly_degree)
        lower = min(3, upper)
        poly_degree = rng.randint(lower, upper)
        transcendental = rng.choice([sympy.sin(x), sympy.cos(x), sympy.exp(x)])
        integrand = x**poly_degree * transcendental
        return self._get_outer_constant(rng) * integrand

    def __getitem__(self, index: int) -> dict:
        """Generate either substitution or by-parts problem.

        Args:
            index: position of the item; combined with the dataset seed to
                seed the random generator used for this item.

        Returns:
            A dict with the ``question`` text, the ``answer`` (the
            antiderivative computed by SymPy followed by ``" + C"``) and a
            ``metadata`` dict describing the integrand, the problem type, the
            variable and the sampled category.

        Raises:
            ValueError: if the sampled category has no generator.
        """
        rng = random.Random(self.seed + index)
        problem_type = rng.choice(self.config.problem_types)
        x = sympy.Symbol(rng.choice(self.config.symbols))

        if problem_type == "substitution":
            category = rng.choice(self.config.substitution_types)
            generators = {
                "linear": self._generate_linear_substitution_problem,
                "trigonometric": self._generate_trigonometric_substitution,
                "exponential": self._generate_exponential_substitution,
                "radical": self._generate_radical_substitution,
            }
        else:
            category = rng.choice(self.config.by_parts_types)
            generators = {
                "polynomial_exp_trig": self._generate_polynomial_exp_trig,
                "log_inverse_trig": self._generate_log_inverse_trig,
                "cyclic": self._generate_cyclic_integral,
                "repeated_parts": self._generate_repeated_parts,
            }

        if category not in generators:
            raise ValueError(f"unknown {problem_type} type: {category!r}")
        integrand = generators[category](rng, x)

        answer = sympy.integrate(integrand, x)
        question = rng.choice(self.prompt_template).format(integrand=integrand, variable=str(x))
        return {
            "question": question,
            "answer": str(answer) + " + C",
            "metadata": {
                "integrand": str(integrand),
                "problem_type": problem_type,
                "variable": str(x),
                "type": category,
            },
        }
|
||||
Loading…
Add table
Add a link
Reference in a new issue