diff --git a/README.md b/README.md index a7301f5f..2107b030 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,6 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets - `LegCountingDataset`: Generate animal leg counting word problems with various animals - `PrimeFactorizationDataset`: Generate prime factorization tasks with configurable number ranges - `TimeIntervalsDataset`: Generate time interval calculation tasks with various formats (time, date, datetime) and complexities -- `ComplexArithmeticDataset`: Generate complex arithmetic problems with configurable number of integers ### Algorithmic Tasks diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py index 6f40b106..fc7a867a 100644 --- a/reasoning_gym/algebra/__init__.py +++ b/reasoning_gym/algebra/__init__.py @@ -1,4 +1,3 @@ -from .complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset from .intermediate_integration import IntermediateIntegrationConfig, IntermediateIntegrationDataset from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset @@ -13,6 +12,4 @@ __all__ = [ "SimpleEquationsConfig", "SimpleIntegrationConfig", "SimpleIntegrationDataset", - "ComplexArithmeticConfig", - "ComplexArithmeticDataset", ] diff --git a/reasoning_gym/algebra/complex_arithmetic.py b/reasoning_gym/algebra/complex_arithmetic.py deleted file mode 100644 index a90aaba4..00000000 --- a/reasoning_gym/algebra/complex_arithmetic.py +++ /dev/null @@ -1,134 +0,0 @@ -import cmath -import random -from dataclasses import dataclass -from typing import Optional, Tuple - -from ..factory import ProceduralDataset, register_dataset - - -@dataclass -class ComplexArithmeticConfig: - min_real: int = -10 - max_real: int = 10 - min_imag: int = -10 - max_imag: int = 10 - operations: Tuple[str, ...] = ("+", "-", "*", "/") - seed: Optional[int] = None - size: int = 500 - - def validate(self) -> None: - """Validate configuration parameters.""" - assert self.max_real >= self.min_real, "max_real must be >= min_real" - assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag" - assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator" - - -class ComplexArithmeticDataset(ProceduralDataset): - """Generates complex number arithmetic problems.""" - - def __init__(self, config: ComplexArithmeticConfig): - self._prompt_templates = { - "+": "Add the complex numbers: ({a}) + ({b})", - "-": "Subtract the complex numbers: ({a}) - ({b})", - "*": "Multiply the complex numbers: ({a}) × ({b})", - "/": "Divide the complex numbers: ({a}) ÷ ({b})", - } - super().__init__(config=config, seed=config.seed, size=config.size) - - def _generate_complex(self, rng: random.Random) -> complex: - """Generate a random complex number.""" - real = rng.randint(self.config.min_real, self.config.max_real) - imag = rng.randint(self.config.min_imag, self.config.max_imag) - return complex(real, imag) - - def _format_complex(self, z: complex) -> str: - """Format complex number for display.""" - real, imag = z.real, z.imag - if imag == 0: - return f"{real:.0f}" - elif real == 0: - return f"{imag:.0f}i" - else: - sign = "+" if imag >= 0 else "-" - return f"{real:.0f} {sign} {abs(imag):.0f}i" - - def __getitem__(self, idx: int) -> dict: - rng = random.Random(self.seed + idx) - - # Generate two random complex numbers - a = self._generate_complex(rng) - b = self._generate_complex(rng) - - # For division, ensure denominator is not zero - while b == 0: - b = self._generate_complex(rng) - - # Choose random operation - op = rng.choice(self.config.operations) - - # Calculate result - if op == "+": - result = a + b - elif op == "-": - result = a - b - elif op == "*": - result = a * b - else: # op == "/" - result = a / b - - question = self._prompt_templates[op].format(a=self._format_complex(a), b=self._format_complex(b)) - - return { - "question": question, - "answer": self._format_complex(result), - "metadata": { - "num1": (a.real, a.imag), - "num2": (b.real, b.imag), - "operation": op, - "result": (result.real, result.imag), - }, - } - - def score_answer(self, answer: str, metadata: dict) -> float: - """Score the answer, allowing for minor formatting differences.""" - if answer is None: - return 0.0 - - try: - # Convert the expected result from metadata - expected_result = complex(*metadata["result"]) - - # Parse student answer - # Remove spaces and convert to lowercase - answer = answer.replace(" ", "").lower() - - # Handle different forms of writing complex numbers - if "i" not in answer and "j" not in answer: - # Pure real number - return abs(complex(float(answer)) - expected_result) < 1e-10 - - # Replace 'i' with 'j' for Python's complex number notation - answer = answer.replace("i", "j") - - # Handle cases like "2j" (add plus sign) - if answer[0] == "j": - answer = "1" + answer - elif answer[-1] == "j" and not any(c in answer[:-1] for c in "+-"): - answer = answer.replace("j", "+1j") - - # Add missing real or imaginary parts - if "j" not in answer: - answer += "+0j" - - # Parse the answer string into a complex number - student_result = complex(answer) - - # Check if the results are close enough (allowing for minor floating-point differences) - return float(abs(student_result - expected_result) < 1e-10) - - except (ValueError, TypeError): - # If there's any error in parsing the answer - return 0.0 - - -register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig) diff --git a/tests/test_complex_arithmetic.py b/tests/test_complex_arithmetic.py deleted file mode 100644 index 2e562c26..00000000 --- a/tests/test_complex_arithmetic.py +++ /dev/null @@ -1,79 +0,0 @@ -import pytest - -from reasoning_gym.algebra.complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset - - -def test_complex_arithmetic_basic(): - """Test basic functionality of complex arithmetic dataset.""" - config = ComplexArithmeticConfig( - min_real=-5, max_real=5, min_imag=-5, max_imag=5, operations=("+", "-", "*", "/"), seed=42, size=10 - ) - dataset = ComplexArithmeticDataset(config) - - print(dataset) - - # Test dataset size - assert len(dataset) == 10 - - # Test a specific item - item = dataset[0] - assert "question" in item - assert "answer" in item - assert "metadata" in item - - # Add more detailed assertions - assert isinstance(item["question"], str) - assert isinstance(item["answer"], str) - assert isinstance(item["metadata"], dict) - - # Check metadata structure - assert "num1" in item["metadata"] - assert "num2" in item["metadata"] - assert "operation" in item["metadata"] - assert "result" in item["metadata"] - - # Check data types in metadata - assert isinstance(item["metadata"]["num1"], tuple) - assert isinstance(item["metadata"]["num2"], tuple) - assert len(item["metadata"]["num1"]) == 2 # Real and imaginary parts - assert len(item["metadata"]["num2"]) == 2 - assert isinstance(item["metadata"]["operation"], str) - assert isinstance(item["metadata"]["result"], tuple) - - # dump dataset into a text file - with open("complex_arithmetic_dataset.txt", "w") as f: - for item in dataset: - f.write(str(item) + "\n") - - -def test_complex_arithmetic_scoring(): - """Test scoring function with various answer formats.""" - config = ComplexArithmeticConfig(seed=42) - dataset = ComplexArithmeticDataset(config) - - # Create a test case with known answer - metadata = {"result": (3.0, 2.0)} # represents 3 + 2i - - # Test various correct answer formats - assert dataset.score_answer("3 + 2i", metadata) == 1.0 - assert dataset.score_answer("3+2i", metadata) == 1.0 - assert dataset.score_answer("3.0 + 2.0i", metadata) == 1.0 - - # Test incorrect answers - assert dataset.score_answer("2 + 3i", metadata) == 0.0 - assert dataset.score_answer("3", metadata) == 0.0 - assert dataset.score_answer("inf + 2i", metadata) == 0.0 - assert dataset.score_answer("2i", metadata) == 0.0 - assert dataset.score_answer("invalid", metadata) == 0.0 - - -def test_complex_arithmetic_division_by_zero(): - """Test that division by zero is handled properly.""" - config = ComplexArithmeticConfig(operations=("/",), seed=42) # Only test division - dataset = ComplexArithmeticDataset(config) - - # Check multiple items to ensure no division by zero - for i in range(10): - item = dataset[i] - num2 = complex(*item["metadata"]["num2"]) - assert num2 != 0