diff --git a/README.md b/README.md index 2107b030..a7301f5f 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets - `LegCountingDataset`: Generate animal leg counting word problems with various animals - `PrimeFactorizationDataset`: Generate prime factorization tasks with configurable number ranges - `TimeIntervalsDataset`: Generate time interval calculation tasks with various formats (time, date, datetime) and complexities +- `ComplexArithmeticDataset`: Generate complex number arithmetic problems (addition, subtraction, multiplication, division) with configurable ranges for the real and imaginary parts ### Algorithmic Tasks diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py index fc7a867a..68a6ad20 100644 --- a/reasoning_gym/algebra/__init__.py +++ b/reasoning_gym/algebra/__init__.py @@ -2,7 +2,7 @@ from .intermediate_integration import IntermediateIntegrationConfig, Intermediat from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset from .simple_integration import SimpleIntegrationConfig, SimpleIntegrationDataset - +from .complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset __all__ = [ "IntermediateIntegrationConfig", "IntermediateIntegrationDataset", @@ -12,4 +12,6 @@ __all__ = [ "SimpleEquationsConfig", "SimpleIntegrationConfig", "SimpleIntegrationDataset", + "ComplexArithmeticConfig", + "ComplexArithmeticDataset", ] diff --git a/reasoning_gym/algebra/complex_arithmetic.py b/reasoning_gym/algebra/complex_arithmetic.py new file mode 100644 index 00000000..b8b17cb1 --- /dev/null +++ b/reasoning_gym/algebra/complex_arithmetic.py @@ -0,0 +1,137 @@ +import random +from dataclasses import dataclass +from typing import Optional, Tuple +import cmath + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class ComplexArithmeticConfig: + min_real: int = -10 + max_real: int = 10 + 
min_imag: int = -10 + max_imag: int = 10 + operations: Tuple[str, ...] = ("+", "-", "*", "/") + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters.""" + assert self.max_real >= self.min_real, "max_real must be >= min_real" + assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag" + assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator" + + +class ComplexArithmeticDataset(ProceduralDataset): + """Generates complex number arithmetic problems.""" + + def __init__(self, config: ComplexArithmeticConfig): + self._prompt_templates = { + "+": "Add the complex numbers: ({a}) + ({b})", + "-": "Subtract the complex numbers: ({a}) - ({b})", + "*": "Multiply the complex numbers: ({a}) × ({b})", + "/": "Divide the complex numbers: ({a}) ÷ ({b})", + } + super().__init__(config=config, seed=config.seed, size=config.size) + + def _generate_complex(self, rng: random.Random) -> complex: + """Generate a random complex number.""" + real = rng.randint(self.config.min_real, self.config.max_real) + imag = rng.randint(self.config.min_imag, self.config.max_imag) + return complex(real, imag) + + def _format_complex(self, z: complex) -> str: + """Format complex number for display.""" + real, imag = z.real, z.imag + if imag == 0: + return f"{real:.0f}" + elif real == 0: + return f"{imag:.0f}i" + else: + sign = "+" if imag >= 0 else "-" + return f"{real:.0f} {sign} {abs(imag):.0f}i" + + def __getitem__(self, idx: int) -> dict: + rng = random.Random(self.seed + idx) + + # Generate two random complex numbers + a = self._generate_complex(rng) + b = self._generate_complex(rng) + + # For division, ensure denominator is not zero + while b == 0: + b = self._generate_complex(rng) + + # Choose random operation + op = rng.choice(self.config.operations) + + # Calculate result + if op == "+": + result = a + b + elif op == "-": + result = a - b + elif op == "*": + result = a * b + else: # op == "/" + 
def score_answer(self, answer: str, metadata: dict) -> float:
    """Score an answer against the expected complex result.

    Accepted notations (spaces and letter case are ignored, and either
    ``i`` or ``j`` may be used as the imaginary unit):

    * pure real numbers, e.g. ``"3"`` or ``"3.0"``
    * pure imaginary numbers, e.g. ``"2i"``, ``"i"``, ``"-i"``
    * full form, e.g. ``"3 + 2i"``, ``"3-i"``

    Args:
        answer: The answer text to score, or ``None``.
        metadata: Item metadata; ``metadata["result"]`` must be the
            ``(real, imag)`` pair of the expected result.

    Returns:
        ``1.0`` if the parsed answer is within ``1e-10`` of the expected
        result, otherwise ``0.0`` (including for ``None`` or unparsable
        answers).
    """
    if answer is None:
        return 0.0

    try:
        expected_result = complex(*metadata["result"])

        # Normalise to the notation understood by Python's complex():
        # no spaces, lowercase, 'j' as the imaginary unit. complex() itself
        # handles "3", "2j", "j", "-j", "3+2j" and "3-j", so no further
        # rewriting is needed (rewriting "2j" to "2+1j" would change its value).
        normalized = answer.replace(" ", "").lower().replace("i", "j")
        student_result = complex(normalized)

        # NaN distances compare False and therefore score 0.0.
        return float(abs(student_result - expected_result) < 1e-10)

    except (ValueError, TypeError, OverflowError):
        # ValueError/TypeError: the answer could not be parsed.
        # OverflowError: abs() of a complex number with huge components.
        return 0.0
def test_complex_arithmetic_basic():
    """Check the dataset size and the structure of the generated items."""
    config = ComplexArithmeticConfig(
        min_real=-5,
        max_real=5,
        min_imag=-5,
        max_imag=5,
        operations=("+", "-", "*", "/"),
        seed=42,
        size=10,
    )
    dataset = ComplexArithmeticDataset(config)

    # Test dataset size
    assert len(dataset) == 10

    # Test a specific item
    item = dataset[0]
    assert "question" in item
    assert "answer" in item
    assert "metadata" in item

    assert isinstance(item["question"], str)
    assert isinstance(item["answer"], str)
    assert isinstance(item["metadata"], dict)

    # Check metadata structure
    metadata = item["metadata"]
    assert "num1" in metadata
    assert "num2" in metadata
    assert "operation" in metadata
    assert "result" in metadata

    # Check data types in metadata: operands and result are (real, imag) pairs
    assert isinstance(metadata["num1"], tuple)
    assert isinstance(metadata["num2"], tuple)
    assert isinstance(metadata["result"], tuple)
    assert len(metadata["num1"]) == 2
    assert len(metadata["num2"]) == 2
    assert len(metadata["result"]) == 2
    assert isinstance(metadata["operation"], str)
    assert metadata["operation"] in config.operations

    # Items are seeded by index, so the same index must give the same item.
    assert dataset[0] == item

    # Walk the whole dataset in memory rather than dumping it to a file in
    # the working directory, so the test leaves no artifacts behind.
    for entry in dataset:
        assert "question" in entry
        assert "answer" in entry
        assert "metadata" in entry
def test_complex_arithmetic_division_by_zero():
    """Division-only problems must never be generated with a zero divisor."""
    division_only = ComplexArithmeticConfig(operations=("/",), seed=42)
    dataset = ComplexArithmeticDataset(division_only)

    # Sample the first ten problems and rebuild each divisor from its
    # (real, imag) metadata pair.
    for index in range(10):
        real_part, imag_part = dataset[index]["metadata"]["num2"]
        assert complex(real_part, imag_part) != 0