mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
added intermediate integration (#334)
This commit is contained in:
parent
ede43c58ba
commit
516bca57ab
2 changed files with 151 additions and 62 deletions
|
|
@ -1,28 +1,27 @@
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
|
|
||||||
|
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IntermediateIntegrationConfig:
|
class IntermediateIntegrationConfig:
|
||||||
problem_types: tuple = ("substitution", "by_parts")
|
problem_types: tuple = (
|
||||||
substitution_types: tuple = (
|
"linear",
|
||||||
"linear", # (ax + b)^n
|
"radical",
|
||||||
"trigonometric", # sin**2(x)cos(x)
|
"log_inverse_trig",
|
||||||
"exponential", # 2xe^x**2
|
"trigonometric",
|
||||||
"radical", # x (3x + 2)^1/2
|
"polynomial_exp_trig",
|
||||||
|
"exponential",
|
||||||
|
"cyclic",
|
||||||
|
"repeated_parts",
|
||||||
)
|
)
|
||||||
|
problem_type_weights: list[float] = field(
|
||||||
# Integration by parts problem categories
|
default_factory=lambda: [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]
|
||||||
by_parts_types: tuple = (
|
|
||||||
"polynomial_exp_trig", # e.g. x^2*e^x
|
|
||||||
"log_inverse_trig", # e.g. ln(x)/arctan(x)
|
|
||||||
"cyclic", # e.g. e^x*sinx requiring cyclic integration
|
|
||||||
"repeated_parts", # Requires multiple integration by parts
|
|
||||||
)
|
)
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500
|
size: int = 500
|
||||||
|
|
@ -35,7 +34,7 @@ class IntermediateIntegrationConfig:
|
||||||
outer_constant_max: int = 3
|
outer_constant_max: int = 3
|
||||||
min_poly_degree: int = 1 # degree of polynomial in by parts problem
|
min_poly_degree: int = 1 # degree of polynomial in by parts problem
|
||||||
max_poly_degree: int = 3
|
max_poly_degree: int = 3
|
||||||
symbols: tuple = ("x", "X")
|
symbols: tuple = "x"
|
||||||
operators: tuple = (
|
operators: tuple = (
|
||||||
"+",
|
"+",
|
||||||
"-",
|
"-",
|
||||||
|
|
@ -43,6 +42,9 @@ class IntermediateIntegrationConfig:
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
"""Validate the configuration parameters of the integral problem"""
|
"""Validate the configuration parameters of the integral problem"""
|
||||||
|
assert len(self.problem_types) == len(
|
||||||
|
self.problem_type_weights
|
||||||
|
), "problem_types and problem_type_weights must have the same length"
|
||||||
assert self.size > 0, "size must be positive"
|
assert self.size > 0, "size must be positive"
|
||||||
assert self.linear_lower_bound > 0, "linear_lower_bound must be positive"
|
assert self.linear_lower_bound > 0, "linear_lower_bound must be positive"
|
||||||
assert self.linear_upper_bound >= self.linear_lower_bound, "linear_upper_bound must be >= linear_lower_bound"
|
assert self.linear_upper_bound >= self.linear_lower_bound, "linear_upper_bound must be >= linear_lower_bound"
|
||||||
|
|
@ -54,13 +56,6 @@ class IntermediateIntegrationConfig:
|
||||||
assert self.max_poly_degree >= self.min_poly_degree, "max_poly_degree must be >= min_poly_degree"
|
assert self.max_poly_degree >= self.min_poly_degree, "max_poly_degree must be >= min_poly_degree"
|
||||||
assert all(op in ("+", "-") for op in self.operators), "invalid operator specified"
|
assert all(op in ("+", "-") for op in self.operators), "invalid operator specified"
|
||||||
assert all(symbols in ("x", "X") for symbols in self.symbols), "invalid symbol specified"
|
assert all(symbols in ("x", "X") for symbols in self.symbols), "invalid symbol specified"
|
||||||
assert all(t in ("substitution", "by_parts") for t in self.problem_types), "invalid problem type"
|
|
||||||
assert all(
|
|
||||||
t in ("linear", "trigonometric", "exponential", "radical") for t in self.substitution_types
|
|
||||||
), "invalid substitution type"
|
|
||||||
assert all(
|
|
||||||
t in ("polynomial_exp_trig", "log_inverse_trig", "cyclic", "repeated_parts") for t in self.by_parts_types
|
|
||||||
), "invalid by_parts type"
|
|
||||||
|
|
||||||
|
|
||||||
class IntermediateIntegrationDataset(ProceduralDataset):
|
class IntermediateIntegrationDataset(ProceduralDataset):
|
||||||
|
|
@ -78,6 +73,7 @@ class IntermediateIntegrationDataset(ProceduralDataset):
|
||||||
]
|
]
|
||||||
self.added_instruction = """
|
self.added_instruction = """
|
||||||
When performing calculations, please follow these guidelines:
|
When performing calculations, please follow these guidelines:
|
||||||
|
Use same variable symbols as given in the question
|
||||||
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
|
1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2.
|
||||||
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
|
2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`.
|
||||||
3. Use `exp(x)` or `E**(x)` for the exponential function (i.e. use capital E for Euler's number).
|
3. Use `exp(x)` or `E**(x)` for the exponential function (i.e. use capital E for Euler's number).
|
||||||
|
|
@ -175,13 +171,22 @@ When performing calculations, please follow these guidelines:
|
||||||
"""Generate logarithmic or inverse trigonometric integrand"""
|
"""Generate logarithmic or inverse trigonometric integrand"""
|
||||||
func_type = rng.choice(["log", "asin", "atan"])
|
func_type = rng.choice(["log", "asin", "atan"])
|
||||||
|
|
||||||
|
coefficient = rng.randint(1, 3)
|
||||||
if func_type == "log":
|
if func_type == "log":
|
||||||
log_arg = x if rng.random() < 0.8 else x ** rng.randint(2, 3)
|
log_arg = x if rng.random() < 0.8 else x ** rng.randint(2, 3)
|
||||||
func = sympy.ln(log_arg)
|
func = sympy.ln(log_arg)
|
||||||
else:
|
elif func_type == "asin":
|
||||||
coefficient = rng.randint(1, 3)
|
# For asin(ax), the integral is:
|
||||||
func = sympy.Function(func_type)(coefficient * x)
|
# x*asin(ax) + (1/a)*sqrt(1-(ax)^2)
|
||||||
|
inner_coef = coefficient
|
||||||
|
func = sympy.asin(inner_coef * x)
|
||||||
|
elif func_type == "atan":
|
||||||
|
# For atan(ax), the integral is:
|
||||||
|
# x*atan(ax) - (1/2a)*ln(1+(ax)^2)
|
||||||
|
inner_coef = coefficient
|
||||||
|
func = sympy.atan(inner_coef * x)
|
||||||
|
|
||||||
|
# The sympy.integrate will correctly handle all these cases
|
||||||
return self._get_outer_constant(rng) * func
|
return self._get_outer_constant(rng) * func
|
||||||
|
|
||||||
def _generate_cyclic_integral(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
|
def _generate_cyclic_integral(self, rng: random.Random, x: sympy.Symbol) -> sympy.Expr:
|
||||||
|
|
@ -202,28 +207,24 @@ When performing calculations, please follow these guidelines:
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
"""Generate either substitution or by-parts problem"""
|
"""Generate either substitution or by-parts problem"""
|
||||||
rng = random.Random(self.seed + index)
|
rng = random.Random(self.seed + index)
|
||||||
problem_type = rng.choice(self.config.problem_types)
|
problem_type = rng.choices(self.config.problem_types, weights=self.config.problem_type_weights, k=1)[0]
|
||||||
x = sympy.Symbol(rng.choice(self.config.symbols))
|
x = sympy.Symbol(rng.choice(self.config.symbols))
|
||||||
|
|
||||||
if problem_type == "substitution":
|
if problem_type == "linear":
|
||||||
substitution_type = rng.choice(self.config.substitution_types)
|
|
||||||
if substitution_type == "linear":
|
|
||||||
integrand = self._generate_linear_substitution_problem(rng, x)
|
integrand = self._generate_linear_substitution_problem(rng, x)
|
||||||
elif substitution_type == "trigonometric":
|
elif problem_type == "trigonometric":
|
||||||
integrand = self._generate_trigonometric_substitution(rng, x)
|
integrand = self._generate_trigonometric_substitution(rng, x)
|
||||||
elif substitution_type == "exponential":
|
elif problem_type == "exponential":
|
||||||
integrand = self._generate_exponential_substitution(rng, x)
|
integrand = self._generate_exponential_substitution(rng, x)
|
||||||
elif substitution_type == "radical":
|
elif problem_type == "radical":
|
||||||
integrand = self._generate_radical_substitution(rng, x)
|
integrand = self._generate_radical_substitution(rng, x)
|
||||||
else:
|
elif problem_type == "log_inverse_trig":
|
||||||
parts_type = rng.choice(self.config.by_parts_types)
|
|
||||||
if parts_type == "polynomial_exp_trig":
|
|
||||||
integrand = self._generate_polynomial_exp_trig(rng, x)
|
|
||||||
elif parts_type == "log_inverse_trig":
|
|
||||||
integrand = self._generate_log_inverse_trig(rng, x)
|
integrand = self._generate_log_inverse_trig(rng, x)
|
||||||
elif parts_type == "cyclic":
|
elif problem_type == "polynomial_exp_trig":
|
||||||
|
integrand = self._generate_polynomial_exp_trig(rng, x)
|
||||||
|
elif problem_type == "cyclic":
|
||||||
integrand = self._generate_cyclic_integral(rng, x)
|
integrand = self._generate_cyclic_integral(rng, x)
|
||||||
elif parts_type == "repeated_parts":
|
elif problem_type == "repeated_parts":
|
||||||
integrand = self._generate_repeated_parts(rng, x)
|
integrand = self._generate_repeated_parts(rng, x)
|
||||||
|
|
||||||
answer = sympy.integrate(integrand, x)
|
answer = sympy.integrate(integrand, x)
|
||||||
|
|
@ -237,8 +238,10 @@ When performing calculations, please follow these guidelines:
|
||||||
"integrand": str(integrand),
|
"integrand": str(integrand),
|
||||||
"problem_type": problem_type,
|
"problem_type": problem_type,
|
||||||
"variable": str(x),
|
"variable": str(x),
|
||||||
"type": substitution_type if problem_type == "substitution" else parts_type,
|
|
||||||
"expected_answer_expression": answer,
|
"expected_answer_expression": answer,
|
||||||
|
"difficulty": {
|
||||||
|
"problem_type_weights": self.config.problem_type_weights,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -264,4 +267,36 @@ When performing calculations, please follow these guidelines:
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
|
|
||||||
register_dataset("intermediate_integration", IntermediateIntegrationDataset, IntermediateIntegrationConfig)
|
class IntermediateIntegrationCurriculum(BaseCurriculum):
|
||||||
|
"""Curriculum for intermediate integration problems"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(IntermediateIntegrationCurriculum.__name__, IntermediateIntegrationConfig)
|
||||||
|
self._define_attributes(
|
||||||
|
ScalarAttributeDefinition(
|
||||||
|
name="problem_type_weights",
|
||||||
|
field_name="problem_type_weights",
|
||||||
|
levels=[
|
||||||
|
[1, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 1, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 1, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 1, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 1, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 1, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||||
|
],
|
||||||
|
default_level=0,
|
||||||
|
description="The weights of the problem types",
|
||||||
|
attr_type=AttributeType.STATIC,
|
||||||
|
min_value=[1, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
"intermediate_integration",
|
||||||
|
IntermediateIntegrationDataset,
|
||||||
|
IntermediateIntegrationConfig,
|
||||||
|
IntermediateIntegrationCurriculum,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -9,18 +9,6 @@ from reasoning_gym.algebra.intermediate_integration import IntermediateIntegrati
|
||||||
|
|
||||||
def test_intermediate_integration_config_validation():
|
def test_intermediate_integration_config_validation():
|
||||||
"""Test that invalid configs raise appropriate errors"""
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = IntermediateIntegrationConfig(problem_types=["invalid_problem_type"])
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = IntermediateIntegrationConfig(substitution_types=["invalid_substitution_type"])
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
config = IntermediateIntegrationConfig(by_parts_types=["invalid_by_parts_type"])
|
|
||||||
config.validate()
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = IntermediateIntegrationConfig(linear_lower_bound=2, linear_upper_bound=1)
|
config = IntermediateIntegrationConfig(linear_lower_bound=2, linear_upper_bound=1)
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
@ -87,8 +75,6 @@ def test_intermediate_integration_dataset_items():
|
||||||
assert "integrand" in item["metadata"]
|
assert "integrand" in item["metadata"]
|
||||||
assert "problem_type" in item["metadata"]
|
assert "problem_type" in item["metadata"]
|
||||||
assert "variable" in item["metadata"]
|
assert "variable" in item["metadata"]
|
||||||
assert "type" in item["metadata"]
|
|
||||||
|
|
||||||
# verify answer is mathematical expression
|
# verify answer is mathematical expression
|
||||||
answer = item["answer"]
|
answer = item["answer"]
|
||||||
answer = answer.replace(" + C", "")
|
answer = answer.replace(" + C", "")
|
||||||
|
|
@ -143,3 +129,71 @@ def test_score_answer_cases():
|
||||||
dummy_entry = {"metadata": metadata}
|
dummy_entry = {"metadata": metadata}
|
||||||
score = dataset.score_answer(answer, entry=dummy_entry)
|
score = dataset.score_answer(answer, entry=dummy_entry)
|
||||||
assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"
|
assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_intermediate_integration_curriculum():
|
||||||
|
"""Test the IntermediateIntegrationCurriculum functionality"""
|
||||||
|
from reasoning_gym.algebra.intermediate_integration import (
|
||||||
|
IntermediateIntegrationConfig,
|
||||||
|
IntermediateIntegrationCurriculum,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a config for the curriculum
|
||||||
|
config = IntermediateIntegrationConfig(
|
||||||
|
size=150, seed=1, problem_type_weights=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]
|
||||||
|
)
|
||||||
|
|
||||||
|
curriculum = IntermediateIntegrationCurriculum()
|
||||||
|
|
||||||
|
# Test initial configuration
|
||||||
|
base_cfg = curriculum.generate_configuration({})
|
||||||
|
assert base_cfg.problem_type_weights == [1, 0, 0, 0, 0, 0, 0, 0] # Default level 0
|
||||||
|
|
||||||
|
# Test incrementing problem_type_weights attribute
|
||||||
|
curriculum.increment_attr_level("problem_type_weights")
|
||||||
|
level1_cfg = curriculum.generate_configuration({})
|
||||||
|
assert level1_cfg.problem_type_weights == [0, 1, 0, 0, 0, 0, 0, 0] # Level 1
|
||||||
|
|
||||||
|
# Test incrementing problem_type_weights attribute again
|
||||||
|
curriculum.increment_attr_level("problem_type_weights")
|
||||||
|
level2_cfg = curriculum.generate_configuration({})
|
||||||
|
assert level2_cfg.problem_type_weights == [0, 0, 1, 0, 0, 0, 0, 0] # Level 2
|
||||||
|
|
||||||
|
# Test decrementing problem_type_weights attribute
|
||||||
|
curriculum.decrement_attr_level("problem_type_weights")
|
||||||
|
back_to_level1_cfg = curriculum.generate_configuration({})
|
||||||
|
assert back_to_level1_cfg.problem_type_weights == [0, 1, 0, 0, 0, 0, 0, 0] # Back to level 1
|
||||||
|
|
||||||
|
# Test global level adjustments
|
||||||
|
# Reset curriculum
|
||||||
|
curriculum = IntermediateIntegrationCurriculum()
|
||||||
|
assert curriculum.get_attr_level("problem_type_weights") == 0
|
||||||
|
|
||||||
|
# Increase global level
|
||||||
|
curriculum.increment_global_level()
|
||||||
|
assert curriculum.get_attr_level("problem_type_weights") == 1
|
||||||
|
|
||||||
|
global_level_cfg = curriculum.generate_configuration({})
|
||||||
|
assert global_level_cfg.problem_type_weights == [0, 1, 0, 0, 0, 0, 0, 0]
|
||||||
|
|
||||||
|
# Increase global level again
|
||||||
|
curriculum.increment_global_level()
|
||||||
|
assert curriculum.get_attr_level("problem_type_weights") == 2
|
||||||
|
|
||||||
|
global_level_cfg_2 = curriculum.generate_configuration({})
|
||||||
|
assert global_level_cfg_2.problem_type_weights == [0, 0, 1, 0, 0, 0, 0, 0]
|
||||||
|
|
||||||
|
# Decrease global level
|
||||||
|
curriculum.decrement_global_level()
|
||||||
|
assert curriculum.get_attr_level("problem_type_weights") == 1
|
||||||
|
|
||||||
|
global_level_cfg_3 = curriculum.generate_configuration({})
|
||||||
|
assert global_level_cfg_3.problem_type_weights == [0, 1, 0, 0, 0, 0, 0, 0]
|
||||||
|
|
||||||
|
# Test upper bound
|
||||||
|
curriculum = IntermediateIntegrationCurriculum() # Reset curriculum
|
||||||
|
for _ in range(10): # Try going beyond max level (7)
|
||||||
|
curriculum.increment_attr_level("problem_type_weights")
|
||||||
|
|
||||||
|
max_cfg = curriculum.generate_configuration({})
|
||||||
|
assert max_cfg.problem_type_weights == [0, 0, 0, 0, 0, 0, 0, 1] # Should be capped at level 7
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue