def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
    """Score a candidate antiderivative by differentiating it.

    The answer is parsed with sympy, differentiated with respect to
    ``metadata["variable"]`` and compared with ``metadata["integrand"]``.

    Args:
        answer: The submitted expression, e.g. ``"x**2 + C"``; may be ``None``.
        metadata: Mapping that provides the ``"variable"`` name and the
            ``"integrand"`` expression, both as strings.

    Returns:
        ``0.0`` when ``answer`` is ``None``; ``1.0`` when the derivative of
        the answer minus the integrand simplifies to zero; ``0.05`` when a
        non-blank answer was processed without error but does not match;
        ``0.01`` when any step raises an ``Exception`` (unparsable answer,
        missing metadata key, ...) or the processed answer is blank.
    """
    if answer is None:
        return 0.0

    try:
        var = metadata["variable"]
        symbol = sympy.Symbol(var)
        # "C" is bound explicitly so the integration constant parses as a
        # symbol and vanishes under differentiation.
        # NOTE(review): sympy documents that parse_expr evaluates its input
        # with eval(); answers come from an external source, so confirm this
        # runs in a context where that is acceptable.
        user_expr = sympy.parse_expr(answer, local_dict={var: symbol, "C": sympy.Symbol("C")})
        derivative = sympy.diff(user_expr, symbol)
        integrand = sympy.parse_expr(metadata["integrand"], local_dict={var: symbol})

        # Equivalence check: the difference must simplify to exactly zero.
        if sympy.simplify(derivative - integrand) == 0:
            return 1.0
        return 0.05 if answer.strip() else 0.01
    except Exception:
        # Any scoring failure earns the minimum score. Unlike a bare
        # ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
        return 0.01
def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
    """Score a submitted antiderivative against the stored integrand.

    The submission is parsed with sympy and differentiated with respect to
    ``metadata["variable"]``; it is correct when that derivative equals
    ``metadata["integrand"]`` after simplification.

    Args:
        answer: The submitted expression, e.g. ``"x**2 + C"``; may be ``None``.
        metadata: Mapping that provides the ``"variable"`` name and the
            ``"integrand"`` expression, both as strings.

    Returns:
        ``0.0`` when ``answer`` is ``None``; ``1.0`` for a correct answer;
        ``0.05`` for a non-blank answer that was processed without error but
        is wrong; ``0.01`` when any step raises an ``Exception`` or the
        processed answer is blank.
    """
    if answer is None:
        return 0.0

    try:
        var_name = metadata["variable"]
        var_symbol = sympy.Symbol(var_name)
        # Bind "C" so the integration constant is a plain symbol whose
        # derivative is zero.
        # NOTE(review): sympy documents that parse_expr evaluates its input
        # with eval(); confirm that is acceptable for externally supplied
        # answers.
        names = {var_name: var_symbol, "C": sympy.Symbol("C")}
        submitted = sympy.parse_expr(answer, local_dict=names)
        derivative = sympy.diff(submitted, var_symbol)
        integrand = sympy.parse_expr(metadata["integrand"], local_dict={var_name: var_symbol})

        # Equivalence check: the difference must simplify to exactly zero.
        if sympy.simplify(derivative - integrand) == 0:
            return 1.0
        return 0.05 if answer.strip() else 0.01
    except Exception:
        # Any scoring failure earns the minimum score. Unlike a bare
        # ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
        return 0.01
def test_verify_answer():
    """Each generated item's own answer must receive the full score."""
    config = IntermediateIntegrationConfig(seed=42)
    dataset = IntermediateIntegrationDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        score = dataset.score_answer(item["answer"], item["metadata"])
        # Report which item failed rather than a bare "assert 0.05 == 1.0".
        assert score == 1.0, f"Item {i}: answer {item['answer']!r} scored {score}"


def test_score_answer_cases():
    """Test various answer scoring scenarios"""
    config = IntermediateIntegrationConfig(seed=42)
    dataset = IntermediateIntegrationDataset(config)

    # Test cases: (answer, metadata, expected_score)
    test_cases = [
        # Correct answers
        ("x**2 + C", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("X**3 - 5*X + C", {"variable": "X", "integrand": "3*X**2 - 5"}, 1.0),
        ("sin(x) + C", {"variable": "x", "integrand": "cos(x)"}, 1.0),
        # Correct without explicit constant
        ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
        # Incorrect but properly formatted
        ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
        ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
        # Malformed expressions
        ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
        ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
        # Empty answer
        ("", {"variable": "x", "integrand": "2*x"}, 0.01),
        # Case sensitivity: the answer's symbol differs from the integration
        # variable, so the answer is well-formed but wrong
        ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
        ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
        # Alternative constant notation
        ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),
        # Simplification required
        ("x**2 + C + 5 - 5", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("(x**3)/3 - 2*x + C", {"variable": "x", "integrand": "x**2 - 2"}, 1.0),
    ]

    for answer, metadata, expected in test_cases:
        score = dataset.score_answer(answer, metadata)
        assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"
def test_verify_answer():
    """Each generated item's own answer must receive the full score."""
    config = SimpleIntegrationConfig(seed=42)
    dataset = SimpleIntegrationDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        score = dataset.score_answer(item["answer"], item["metadata"])
        # Report which item failed rather than a bare "assert 0.05 == 1.0".
        assert score == 1.0, f"Item {i}: answer {item['answer']!r} scored {score}"


def test_score_answer_cases():
    """Test various answer scoring scenarios"""
    config = SimpleIntegrationConfig(seed=42)
    dataset = SimpleIntegrationDataset(config)

    # Test cases: (answer, metadata, expected_score)
    test_cases = [
        # Correct answers
        ("x**2 + C", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("X**3 - 5*X + C", {"variable": "X", "integrand": "3*X**2 - 5"}, 1.0),
        ("sin(x) + C", {"variable": "x", "integrand": "cos(x)"}, 1.0),
        # Correct without explicit constant
        ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
        # Incorrect but properly formatted
        ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
        ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
        # Malformed expressions
        ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
        ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
        # Empty answer
        ("", {"variable": "x", "integrand": "2*x"}, 0.01),
        # Case sensitivity: the answer's symbol differs from the integration
        # variable, so the answer is well-formed but wrong
        ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
        ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
        # Alternative constant notation
        ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),
        # Simplification required
        ("x**2 + C + 5 - 5", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("(x**3)/3 - 2*x + C", {"variable": "x", "integrand": "x**2 - 2"}, 1.0),
    ]

    for answer, metadata, expected in test_cases:
        score = dataset.score_answer(answer, metadata)
        assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"