diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py index 68a6ad20..6f40b106 100644 --- a/reasoning_gym/algebra/__init__.py +++ b/reasoning_gym/algebra/__init__.py @@ -1,8 +1,9 @@ +from .complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset from .intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset from .simple_integration import SimpleIntegrationConfig, SimpleIntegrationDataset -from .complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset + __all__ = [ "IntermediateIntegrationConfig", "IntermediateIntegrationDataset", diff --git a/reasoning_gym/algebra/complex_arithmetic.py b/reasoning_gym/algebra/complex_arithmetic.py index b8b17cb1..a90aaba4 100644 --- a/reasoning_gym/algebra/complex_arithmetic.py +++ b/reasoning_gym/algebra/complex_arithmetic.py @@ -1,7 +1,7 @@ +import cmath import random from dataclasses import dataclass from typing import Optional, Tuple -import cmath from ..factory import ProceduralDataset, register_dataset @@ -54,18 +54,18 @@ class ComplexArithmeticDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: rng = random.Random(self.seed + idx) - + # Generate two random complex numbers a = self._generate_complex(rng) b = self._generate_complex(rng) - + # For division, ensure denominator is not zero while b == 0: b = self._generate_complex(rng) # Choose random operation op = rng.choice(self.config.operations) - + # Calculate result if op == "+": result = a + b @@ -76,10 +76,7 @@ class ComplexArithmeticDataset(ProceduralDataset): else: # op == "/" result = a / b - question = self._prompt_templates[op].format( - a=self._format_complex(a), - b=self._format_complex(b) - ) + question = self._prompt_templates[op].format(a=self._format_complex(a), b=self._format_complex(b)) return { "question": question, @@ -100,32 +97,32 @@ class ComplexArithmeticDataset(ProceduralDataset): try: # Convert the expected result from metadata expected_result = complex(*metadata["result"]) - + # Parse student answer # Remove spaces and convert to lowercase answer = answer.replace(" ", "").lower() - + # Handle different forms of writing complex numbers if "i" not in answer and "j" not in answer: # Pure real number return abs(complex(float(answer)) - expected_result) < 1e-10 # Replace 'i' with 'j' for Python's complex number notation - answer = answer.replace('i', 'j') - + answer = answer.replace("i", "j") + # Handle cases like "2j" (add plus sign) - if answer[0] == 'j': - answer = '1' + answer - elif answer[-1] == 'j' and not any(c in answer[:-1] for c in '+-'): - answer = answer.replace('j', '+1j') - + if answer[0] == "j": + answer = "1" + answer + elif answer[-1] == "j" and not any(c in answer[:-1] for c in "+-"): + answer = answer.replace("j", "+1j") + # Add missing real or imaginary parts - if 'j' not in answer: - answer += '+0j' - + if "j" not in answer: + answer += "+0j" + # Parse the answer string into a complex number student_result = complex(answer) - + # Check if the results are close enough (allowing for minor floating-point differences) return float(abs(student_result - expected_result) < 1e-10) @@ -134,4 +131,4 @@ class ComplexArithmeticDataset(ProceduralDataset): return 0.0 -register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig) \ No newline at end of file +register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig) diff --git a/tests/test_complex_arithmetic.py b/tests/test_complex_arithmetic.py index 94ef9490..2e562c26 100644 --- a/tests/test_complex_arithmetic.py +++ b/tests/test_complex_arithmetic.py @@ -1,25 +1,20 @@ import pytest + from reasoning_gym.algebra.complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset def test_complex_arithmetic_basic(): """Test basic functionality of complex arithmetic dataset.""" config = ComplexArithmeticConfig( - min_real=-5, - max_real=5, - min_imag=-5, - max_imag=5, - operations=("+", "-", "*", "/"), - seed=42, - size=10 + min_real=-5, max_real=5, min_imag=-5, max_imag=5, operations=("+", "-", "*", "/"), seed=42, size=10 ) dataset = ComplexArithmeticDataset(config) - + print(dataset) # Test dataset size assert len(dataset) == 10 - + # Test a specific item item = dataset[0] assert "question" in item @@ -30,13 +25,13 @@ def test_complex_arithmetic_basic(): assert isinstance(item["question"], str) assert isinstance(item["answer"], str) assert isinstance(item["metadata"], dict) - + # Check metadata structure assert "num1" in item["metadata"] assert "num2" in item["metadata"] assert "operation" in item["metadata"] assert "result" in item["metadata"] - + # Check data types in metadata assert isinstance(item["metadata"]["num1"], tuple) assert isinstance(item["metadata"]["num2"], tuple) @@ -45,7 +40,7 @@ def test_complex_arithmetic_basic(): assert isinstance(item["metadata"]["operation"], str) assert isinstance(item["metadata"]["result"], tuple) - # dump dataset into a text file + # dump dataset into a text file with open("complex_arithmetic_dataset.txt", "w") as f: for item in dataset: f.write(str(item) + "\n") @@ -55,17 +50,15 @@ def test_complex_arithmetic_scoring(): """Test scoring function with various answer formats.""" config = ComplexArithmeticConfig(seed=42) dataset = ComplexArithmeticDataset(config) - + # Create a test case with known answer - metadata = { - "result": (3.0, 2.0) # represents 3 + 2i - } - + metadata = {"result": (3.0, 2.0)} # represents 3 + 2i + # Test various correct answer formats assert dataset.score_answer("3 + 2i", metadata) == 1.0 assert dataset.score_answer("3+2i", metadata) == 1.0 assert dataset.score_answer("3.0 + 2.0i", metadata) == 1.0 - + # Test incorrect answers assert dataset.score_answer("2 + 3i", metadata) == 0.0 assert dataset.score_answer("3", metadata) == 0.0 @@ -76,14 +69,11 @@ def test_complex_arithmetic_scoring(): def test_complex_arithmetic_division_by_zero(): """Test that division by zero is handled properly.""" - config = ComplexArithmeticConfig( - operations=("/",), # Only test division - seed=42 - ) + config = ComplexArithmeticConfig(operations=("/",), seed=42) # Only test division dataset = ComplexArithmeticDataset(config) - + # Check multiple items to ensure no division by zero for i in range(10): item = dataset[i] num2 = complex(*item["metadata"]["num2"]) - assert num2 != 0 \ No newline at end of file + assert num2 != 0