diff --git a/tests/test_intermediate_integration.py b/tests/test_intermediate_integration.py new file mode 100644 index 00000000..fc35f387 --- /dev/null +++ b/tests/test_intermediate_integration.py @@ -0,0 +1,109 @@ +"""Tests for intermediate integration task generation""" + +import pytest +import sympy +from sympy.parsing.sympy_parser import parse_expr + +from reasoning_gym.algebra.intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset + + +def test_intermediate_integration_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(problem_types=["invalid_problem_type"]) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(substitution_types=["invalid_substitution_type"]) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(by_parts_types=["invalid_by_parts_type"]) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(linear_lower_bound=2, linear_upper_bound=1) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(linear_lower_bound=0) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(min_linear_degree=5, max_linear_degree=1) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(min_linear_degree=0) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(outer_constant_min=5, outer_constant_max=1) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(outer_constant_min=0) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(min_poly_degree=5, max_poly_degree=1) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(min_poly_degree=0) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(symbols=("x", "y")) + config.validate() + + with pytest.raises(AssertionError): + config = IntermediateIntegrationConfig(operators=("+", "-", "*", "/")) + config.validate() + + +def test_intermediate_integration_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = IntermediateIntegrationConfig(seed=42, size=10) + dataset1 = IntermediateIntegrationDataset(config) + dataset2 = IntermediateIntegrationDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_intermediate_integration_dataset_items(): + """Test that dataset items are valid""" + config = IntermediateIntegrationConfig(seed=42, size=10) + dataset = IntermediateIntegrationDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + assert "integrand" in item["metadata"] + assert "problem_type" in item["metadata"] + assert "variable" in item["metadata"] + assert "type" in item["metadata"] + + # verify answer is mathematical expression + answer = item["answer"] + answer = answer.replace(" + C", "") + assert isinstance(parse_expr(answer), sympy.Expr) + + +def test_solution_verification(): + """Test for solution verification of each answer""" + config = IntermediateIntegrationConfig(seed=42, size=10) + dataset = IntermediateIntegrationDataset(config) + + for item in dataset: + integrand = parse_expr(item["metadata"]["integrand"]) + variable = sympy.Symbol(item["metadata"]["variable"]) + answer = parse_expr(item["answer"].replace(" + C", "")) + + # Verify that the derivative of the answer equals the integrand + assert sympy.simplify(sympy.diff(answer, variable) - integrand) == 0