diff --git a/reasoning_gym/algebra/complex_arithmetic.py b/reasoning_gym/algebra/complex_arithmetic.py
new file mode 100644
index 00000000..616e9404
--- /dev/null
+++ b/reasoning_gym/algebra/complex_arithmetic.py
@@ -0,0 +1,163 @@
+import cmath
+import math
+import random
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+from ..factory import ProceduralDataset, register_dataset
+
+
+@dataclass
+class ComplexArithmeticConfig:
+    """Configuration for complex number arithmetic problems."""
+
+    min_real: int = -10  # inclusive bounds for the real part of each operand
+    max_real: int = 10
+    min_imag: int = -10  # inclusive bounds for the imaginary part of each operand
+    max_imag: int = 10
+    operations: Tuple[str, ...] = ("+", "-", "*", "/")  # operators to sample from
+    seed: Optional[int] = None
+    size: int = 500
+
+    def validate(self) -> None:
+        """Validate configuration parameters."""
+        assert self.max_real >= self.min_real, "max_real must be >= min_real"
+        assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag"
+        assert self.operations, "operations must not be empty"
+        assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator"
+        # The second operand is re-drawn until it is non-zero, so the ranges must
+        # allow at least one non-zero value or item generation would never finish.
+        assert not (
+            self.min_real == self.max_real == 0 and self.min_imag == self.max_imag == 0
+        ), "operand ranges must allow a non-zero complex number"
+
+
+class ComplexArithmeticDataset(ProceduralDataset):
+    """Generates complex number arithmetic problems."""
+
+    def __init__(self, config: ComplexArithmeticConfig):
+        self._prompt_templates = {
+            "+": "Add the complex numbers: ({a}) + ({b})",
+            "-": "Subtract the complex numbers: ({a}) - ({b})",
+            "*": "Multiply the complex numbers: ({a}) × ({b})",
+            "/": "Divide the complex numbers: ({a}) ÷ ({b})",
+        }
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def _generate_complex(self, rng: random.Random) -> complex:
+        """Generate a random complex number with integer parts in the configured ranges."""
+        real = rng.randint(self.config.min_real, self.config.max_real)
+        imag = rng.randint(self.config.min_imag, self.config.max_imag)
+        return complex(real, imag)
+
+    def _format_complex(self, z: complex) -> str:
+        """Format a complex number as ``a + bi`` with 2 decimal places.
+
+        A (numerically) zero imaginary part yields just the real part and a
+        (numerically) zero real part yields just ``bi``.
+        """
+        real, imag = z.real, z.imag
+        if abs(imag) < 1e-10:
+            return f"{real:.2f}"
+        elif abs(real) < 1e-10:
+            return f"{imag:.2f}i"
+        else:
+            sign = "+" if imag >= 0 else "-"
+            return f"{real:.2f} {sign} {abs(imag):.2f}i"
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate the problem at ``idx``, deterministically for a given seed.
+
+        Returns a dict with the ``question`` text, the formatted ``answer`` and
+        ``metadata`` holding both operands, the operator and the exact result as
+        ``(real, imag)`` tuples.
+        """
+        rng = random.Random(self.seed + idx)
+
+        # Generate two random complex numbers
+        a = self._generate_complex(rng)
+        b = self._generate_complex(rng)
+
+        # Keep the second operand non-zero so division is always defined. The
+        # operator is only drawn afterwards, so this applies to every operation.
+        while b == 0:
+            b = self._generate_complex(rng)
+
+        # Choose random operation
+        op = rng.choice(self.config.operations)
+
+        # Calculate result
+        if op == "+":
+            result = a + b
+        elif op == "-":
+            result = a - b
+        elif op == "*":
+            result = a * b
+        else:  # op == "/"
+            result = a / b
+
+        question = self._prompt_templates[op].format(a=self._format_complex(a), b=self._format_complex(b))
+
+        return {
+            "question": question,
+            "answer": self._format_complex(result),
+            "metadata": {
+                "num1": (a.real, a.imag),
+                "num2": (b.real, b.imag),
+                "operation": op,
+                "result": (result.real, result.imag),
+            },
+        }
+
+    @staticmethod
+    def parse_string_to_complex(answer: str) -> Optional[complex]:
+        """Parse an answer such as ``"3 + 2i"``, ``"-4i"``, ``"i"`` or ``"2.5"``.
+
+        Whitespace and letter case are ignored and ``i`` is accepted as the
+        imaginary unit. Returns ``None`` when the text is not a finite complex
+        number (unparseable input, ``nan`` or infinite parts).
+        """
+        if not isinstance(answer, str):
+            return None
+
+        # Python's complex() rejects inner whitespace and spells the unit "j".
+        # It already understands implicit coefficients ("j", "-j", "3+j") and
+        # plain reals, so no further rewriting of the string is needed.
+        normalized = "".join(answer.split()).lower().replace("i", "j")
+        try:
+            value = complex(normalized)
+        except ValueError:
+            return None
+
+        # Reject NaN/infinite values: min(1.0, nan) evaluates to 1.0,
+        # so a NaN answer must never reach the scoring formula.
+        if not cmath.isfinite(value):
+            return None
+        return value
+
+    def score_answer(self, answer: Optional[str], metadata: dict) -> float:
+        """Score the answer using exponential distance-based scoring.
+
+        Returns ``exp(-|answer - expected|)``: 1.0 for an exact match, decaying
+        towards 0.0 with distance. Missing or unparseable answers score 0.0.
+        """
+        if answer is None:
+            return 0.0
+
+        student_result = self.parse_string_to_complex(answer)
+        if student_result is None:
+            return 0.0
+
+        try:
+            expected_result = complex(*metadata["result"])
+            distance = abs(student_result - expected_result)
+        except (ValueError, TypeError, OverflowError):
+            # Malformed expected value, or a distance too large for a float.
+            return 0.0
+
+        if math.isnan(distance):
+            return 0.0
+        return min(1.0, math.exp(-distance))
+
+
+register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig)
diff --git a/tests/test_complex_arithmetic.py b/tests/test_complex_arithmetic.py
new file mode 100644
index 00000000..0317b93e
--- /dev/null
+++ b/tests/test_complex_arithmetic.py
@@ -0,0 +1,109 @@
+import pytest
+
+from reasoning_gym.algebra.complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset
+
+
+def test_complex_arithmetic_basic():
+    """Test basic functionality of complex arithmetic dataset."""
+    config = ComplexArithmeticConfig(
+        min_real=-5, max_real=5, min_imag=-5, max_imag=5, operations=("+", "-", "*", "/"), seed=42, size=10
+    )
+    dataset = ComplexArithmeticDataset(config)
+
+    # Test dataset size
+    assert len(dataset) == 10
+
+    # Test a specific item
+    item = dataset[0]
+    assert "question" in item
+    assert "answer" in item
+    assert "metadata" in item
+
+    # Add more detailed assertions
+    assert isinstance(item["question"], str)
+    assert isinstance(item["answer"], str)
+    assert isinstance(item["metadata"], dict)
+
+    # Check metadata structure
+    assert "num1" in item["metadata"]
+    assert "num2" in item["metadata"]
+    assert "operation" in item["metadata"]
+    assert "result" in item["metadata"]
+
+    # Check data types in metadata
+    assert isinstance(item["metadata"]["num1"], tuple)
+    assert isinstance(item["metadata"]["num2"], tuple)
+    assert len(item["metadata"]["num1"]) == 2  # Real and imaginary parts
+    assert len(item["metadata"]["num2"]) == 2
+    assert isinstance(item["metadata"]["operation"], str)
+    assert isinstance(item["metadata"]["result"], tuple)
+
+    # The answer string is rounded to 2 decimal places, so it can only match the
+    # exact result in metadata up to that rounding (division is rarely exact).
+    for i in range(len(dataset)):
+        item = dataset[i]
+        parsed = ComplexArithmeticDataset.parse_string_to_complex(item["answer"])
+        expected = complex(*item["metadata"]["result"])
+        assert parsed is not None
+        assert parsed.real == pytest.approx(expected.real, abs=0.0051)
+        assert parsed.imag == pytest.approx(expected.imag, abs=0.0051)
+
+
+def test_complex_arithmetic_scoring():
+    """Test scoring function with various answer formats and accuracies."""
+    config = ComplexArithmeticConfig(seed=42)
+    dataset = ComplexArithmeticDataset(config)
+
+    # Test case with answer 3 + 2i
+    metadata = {"result": (3.0, 2.0)}
+
+    # Test exact matches (should get score of 1.0)
+    assert dataset.score_answer("3 + 2i", metadata) == 1.0
+    assert dataset.score_answer("3+2i", metadata) == 1.0
+    assert dataset.score_answer("3.0 + 2.0i", metadata) == 1.0
+
+    # Test answers with small errors (should get high but < 1.0 scores)
+    assert 0.9 < dataset.score_answer("3.1 + 2i", metadata) < 1.0
+    assert 0.9 < dataset.score_answer("3 + 2.1i", metadata) < 1.0
+    assert 0.7 < dataset.score_answer("3.1 + 2.1i", metadata) < 0.95
+
+    # Test answers with moderate errors (should get medium scores)
+    assert 0.3 < dataset.score_answer("4 + 2i", metadata) < 0.4
+    assert 0.3 < dataset.score_answer("3 + 3i", metadata) < 0.4
+
+    # Test answers with large errors (should get very low scores)
+    assert dataset.score_answer("10 + 10i", metadata) < 0.01
+
+    # Test invalid answers (should get 0.0)
+    assert dataset.score_answer("invalid", metadata) == 0.0
+    assert dataset.score_answer(None, metadata) == 0.0
+    assert dataset.score_answer("inf + 2i", metadata) == 0.0
+    assert dataset.score_answer("nan", metadata) == 0.0
+    assert dataset.score_answer("", metadata) == 0.0
+
+
+def test_complex_arithmetic_parsing():
+    """Test parsing of pure real, pure imaginary and implicit-coefficient answers."""
+    parse = ComplexArithmeticDataset.parse_string_to_complex
+
+    assert parse("3 + 2i") == complex(3, 2)
+    assert parse("3 - 2I") == complex(3, -2)
+    assert parse("2.5") == complex(2.5, 0)
+    assert parse("5.00i") == complex(0, 5)
+    assert parse("-4i") == complex(0, -4)
+    assert parse("i") == complex(0, 1)
+    assert parse("-i") == complex(0, -1)
+    assert parse("3 + i") == complex(3, 1)
+    assert parse("not a number") is None
+
+
+def test_complex_arithmetic_division_by_zero():
+    """Test that division by zero is handled properly."""
+    config = ComplexArithmeticConfig(operations=("/",), seed=42)  # Only test division
+    dataset = ComplexArithmeticDataset(config)
+
+    # Check multiple items to ensure no division by zero
+    for i in range(10):
+        item = dataset[i]
+        num2 = complex(*item["metadata"]["num2"])
+        assert num2 != 0