diff --git a/reasoning_gym/algebra/complex_arithmetic.py b/reasoning_gym/algebra/complex_arithmetic.py index 903cdb99..1175d92f 100644 --- a/reasoning_gym/algebra/complex_arithmetic.py +++ b/reasoning_gym/algebra/complex_arithmetic.py @@ -1,9 +1,10 @@ import cmath import math import random -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -14,6 +15,7 @@ class ComplexArithmeticConfig: min_imag: int = -10 max_imag: int = 10 operations: tuple[str, ...] = ("+", "-", "*", "/") + operations_weights: list[float] = field(default_factory=lambda: [0.4, 0.4, 0.1, 0.1]) seed: Optional[int] = None size: int = 500 @@ -22,6 +24,7 @@ class ComplexArithmeticConfig: assert self.max_real >= self.min_real, "max_real must be >= min_real" assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag" assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator" + assert round(sum(self.operations_weights), 1) == 1.0, "operations_weights must sum to 1.0" class ComplexArithmeticDataset(ProceduralDataset): @@ -57,7 +60,7 @@ class ComplexArithmeticDataset(ProceduralDataset): rng = random.Random(self.seed + idx) # Choose random operation - op = rng.choice(self.config.operations) + op = rng.choices(self.config.operations, weights=self.config.operations_weights, k=1)[0] if op == "/": # For division, first generate the quotient (a) and divisor (b) @@ -91,6 +94,13 @@ class ComplexArithmeticDataset(ProceduralDataset): "num2": (b.real, b.imag), "operation": op, "result": (int(result.real), int(result.imag)), # Convert to int since we ensure whole numbers + "difficulty": { + "min_real": self.config.min_real, + "max_real": self.config.max_real, + "min_imag": self.config.min_imag, + "max_imag": self.config.max_imag, + "operations_weights": self.config.operations_weights, + }, }, 
class ComplexArithmeticCurriculum(BaseCurriculum):
    """Difficulty curriculum for the complex-number arithmetic dataset.

    Registers five scalar attributes, all of type ``AttributeType.STATIC`` and
    all with ``default_level=0``: the lower and upper bounds of the real and
    imaginary parts, and the weights used to sample the operation.
    """

    def __init__(self):
        super().__init__(ComplexArithmeticCurriculum.__name__, ComplexArithmeticConfig)

        def static_attr(attr, levels, text, floor):
            # Every attribute maps onto the config field of the same name.
            return ScalarAttributeDefinition(
                name=attr,
                field_name=attr,
                levels=levels,
                default_level=0,
                description=text,
                attr_type=AttributeType.STATIC,
                min_value=floor,
            )

        # Bound magnitudes per level; lower bounds are the negated values.
        magnitudes = (10, 100, 10000, 100000000)

        self._define_attributes(
            static_attr(
                "min_real",
                [-m for m in magnitudes],
                "Minimum real part for complex numbers",
                -10,
            ),
            static_attr(
                "max_real",
                list(magnitudes),
                "Maximum real part for complex numbers",
                10,
            ),
            static_attr(
                "min_imag",
                [-m for m in magnitudes],
                "Minimum imaginary part for complex numbers",
                -10,
            ),
            static_attr(
                "max_imag",
                list(magnitudes),
                "Maximum imaginary part for complex numbers",
                10,
            ),
            static_attr(
                "operations_weights",
                [
                    [0.4, 0.4, 0.1, 0.1],
                    [0.25, 0.25, 0.25, 0.25],
                    [0.2, 0.2, 0.3, 0.3],
                    [0.1, 0.1, 0.4, 0.4],
                ],
                "Operations weights to sample operation to use for each complex arithmetic problem",
                [0.4, 0.4, 0.1, 0.1],
            ),
        )


register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig, ComplexArithmeticCurriculum)
def test_complex_arithmetic_curriculum():
    """Test the curriculum for complex arithmetic.

    Walks every curriculum attribute up and down through its levels and checks
    the configuration generated at each stage, including that stepping past
    the last or first level leaves the attribute at that level.
    """
    curriculum = ComplexArithmeticCurriculum()
    base_value = {"size": 150, "seed": 1}
    attr_names = ("min_real", "min_imag", "max_real", "max_imag", "operations_weights")

    def step_all(step, times=1):
        # Apply `step` (increment or decrement) to every attribute, `times` times.
        for _ in range(times):
            for attr in attr_names:
                step(attr)

    def check_config(bound, weights):
        # generate_configuration returns a config object, not a curriculum,
        # so it is annotated as ComplexArithmeticConfig.
        cfg: ComplexArithmeticConfig = curriculum.generate_configuration(base_value)
        assert cfg.min_real == cfg.min_imag == -bound
        assert cfg.max_real == cfg.max_imag == bound
        assert cfg.operations_weights == weights
        return cfg

    base_cfg = check_config(10, [0.4, 0.4, 0.1, 0.1])
    assert base_cfg.seed == 1
    assert base_cfg.size == 150

    # Increase and validate increase in level
    step_all(curriculum.increment_attr_level)
    check_config(100, [0.25, 0.25, 0.25, 0.25])

    # Decrease and validate decrease in level
    step_all(curriculum.decrement_attr_level)
    check_config(10, [0.4, 0.4, 0.1, 0.1])

    # Test upper bound boundary condition: more increments than there are levels
    step_all(curriculum.increment_attr_level, times=10)
    check_config(100000000, [0.1, 0.1, 0.4, 0.4])

    # Test lower bound boundary condition: more decrements than there are levels
    step_all(curriculum.decrement_attr_level, times=10)
    check_config(10, [0.4, 0.4, 0.1, 0.1])