Add complex arithmetic curriculum (#310)

* Add complex arithmetic curriculum
This commit is contained in:
Adefioye 2025-03-09 18:28:51 -05:00 committed by GitHub
parent 9bd4f03dbd
commit 841663cc5a
2 changed files with 139 additions and 5 deletions

View file

@ -1,9 +1,10 @@
import cmath import cmath
import math import math
import random import random
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -14,6 +15,7 @@ class ComplexArithmeticConfig:
min_imag: int = -10 min_imag: int = -10
max_imag: int = 10 max_imag: int = 10
operations: tuple[str, ...] = ("+", "-", "*", "/") operations: tuple[str, ...] = ("+", "-", "*", "/")
operations_weights: list[float] = field(default_factory=lambda: [0.4, 0.4, 0.1, 0.1])
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500
@ -22,6 +24,7 @@ class ComplexArithmeticConfig:
assert self.max_real >= self.min_real, "max_real must be >= min_real" assert self.max_real >= self.min_real, "max_real must be >= min_real"
assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag" assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag"
assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator" assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator"
assert round(sum(self.operations_weights), 1) == 1.0, "operations_weights must sum to 1.0"
class ComplexArithmeticDataset(ProceduralDataset): class ComplexArithmeticDataset(ProceduralDataset):
@ -57,7 +60,7 @@ class ComplexArithmeticDataset(ProceduralDataset):
rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
# Choose random operation # Choose random operation
op = rng.choice(self.config.operations) op = rng.choices(self.config.operations, weights=self.config.operations_weights, k=1)[0]
if op == "/": if op == "/":
# For division, first generate the quotient (a) and divisor (b) # For division, first generate the quotient (a) and divisor (b)
@ -91,6 +94,13 @@ class ComplexArithmeticDataset(ProceduralDataset):
"num2": (b.real, b.imag), "num2": (b.real, b.imag),
"operation": op, "operation": op,
"result": (int(result.real), int(result.imag)), # Convert to int since we ensure whole numbers "result": (int(result.real), int(result.imag)), # Convert to int since we ensure whole numbers
"difficulty": {
"min_real": self.config.min_real,
"max_real": self.config.max_real,
"min_imag": self.config.min_imag,
"max_imag": self.config.max_imag,
"operations_weights": self.config.operations_weights,
},
}, },
} }
@ -169,4 +179,60 @@ class ComplexArithmeticDataset(ProceduralDataset):
return 0.0 return 0.0
class ComplexArithmeticCurriculum(BaseCurriculum):
    """Difficulty curriculum for the complex arithmetic dataset.

    Declares five static attributes with four levels each: the lower and
    upper bounds of the real and imaginary parts (from +/-10 up to
    +/-100000000) and the weights used to sample the operation.
    """

    def __init__(self):
        super().__init__(ComplexArithmeticCurriculum.__name__, ComplexArithmeticConfig)

        # Magnitude of the range bounds at each level; the sign selects min vs. max.
        magnitudes = (10, 100, 10000, 100000000)

        def bound_attribute(attr_name, sign, description):
            # Each call builds its own levels list so no two attributes share one.
            return ScalarAttributeDefinition(
                name=attr_name,
                field_name=attr_name,
                levels=[sign * magnitude for magnitude in magnitudes],
                default_level=0,
                description=description,
                attr_type=AttributeType.STATIC,
                min_value=sign * magnitudes[0],
            )

        # With the default operation order ("+", "-", "*", "/"), later levels
        # move probability mass from addition/subtraction to
        # multiplication/division.
        weights_attribute = ScalarAttributeDefinition(
            name="operations_weights",
            field_name="operations_weights",
            levels=[
                [0.4, 0.4, 0.1, 0.1],
                [0.25, 0.25, 0.25, 0.25],
                [0.2, 0.2, 0.3, 0.3],
                [0.1, 0.1, 0.4, 0.4],
            ],
            default_level=0,
            description="Operations weights to sample operation to use for each complex arithmetic problem",
            attr_type=AttributeType.STATIC,
            min_value=[0.4, 0.4, 0.1, 0.1],
        )

        # Define attributes (same order as the config fields they drive)
        self._define_attributes(
            bound_attribute("min_real", -1, "Minimum real part for complex numbers"),
            bound_attribute("max_real", 1, "Maximum real part for complex numbers"),
            bound_attribute("min_imag", -1, "Minimum imaginary part for complex numbers"),
            bound_attribute("max_imag", 1, "Maximum imaginary part for complex numbers"),
            weights_attribute,
        )
register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig, ComplexArithmeticCurriculum)

View file

@ -1,6 +1,10 @@
import pytest import pytest
from reasoning_gym.algebra.complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset from reasoning_gym.algebra.complex_arithmetic import (
ComplexArithmeticConfig,
ComplexArithmeticCurriculum,
ComplexArithmeticDataset,
)
def test_complex_arithmetic_basic(): def test_complex_arithmetic_basic():
@ -81,7 +85,9 @@ def test_complex_arithmetic_scoring():
def test_complex_arithmetic_division_by_zero(): def test_complex_arithmetic_division_by_zero():
"""Test that division by zero is handled properly.""" """Test that division by zero is handled properly."""
config = ComplexArithmeticConfig(operations=("/",), seed=42) # Only test division config = ComplexArithmeticConfig(
operations=("+", "-", "*", "/"), operations_weights=[0.0, 0.0, 0.0, 1.0], seed=42
) # Only test division
dataset = ComplexArithmeticDataset(config) dataset = ComplexArithmeticDataset(config)
# Check multiple items to ensure no division by zero # Check multiple items to ensure no division by zero
@ -131,3 +137,65 @@ def test_parse_string_to_complex():
assert dataset.parse_string_to_complex("invalid") is None assert dataset.parse_string_to_complex("invalid") is None
assert dataset.parse_string_to_complex("3 + i + 2") is None assert dataset.parse_string_to_complex("3 + i + 2") is None
assert dataset.parse_string_to_complex("3 + 2x") is None assert dataset.parse_string_to_complex("3 + 2x") is None
def test_complex_arithmetic_curriculum():
    """Test the curriculum for complex arithmetic.

    Walks every attribute up and down through its levels and checks that the
    generated configuration follows, including clamping at both ends of the
    level list.
    """
    curriculum = ComplexArithmeticCurriculum()

    # Attributes are always stepped together, in this order.
    attr_names = ("min_real", "min_imag", "max_real", "max_imag", "operations_weights")
    base_value = {"size": 150, "seed": 1}

    def step_all(step, times=1):
        # Apply `step` (increment or decrement) to every attribute, `times` times.
        for _ in range(times):
            for attr_name in attr_names:
                step(attr_name)

    def assert_level(cfg: ComplexArithmeticConfig, bound, weights):
        # Real and imaginary ranges are expected to be symmetric: [-bound, bound].
        assert cfg.min_real == cfg.min_imag == -bound
        assert cfg.max_real == cfg.max_imag == bound
        assert cfg.operations_weights == weights

    # generate_configuration yields a dataset config, not a curriculum.
    base_cfg: ComplexArithmeticConfig = curriculum.generate_configuration(base_value)
    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert_level(base_cfg, 10, [0.4, 0.4, 0.1, 0.1])

    # Increase and validate increase in level
    step_all(curriculum.increment_attr_level)
    increased_cfg: ComplexArithmeticConfig = curriculum.generate_configuration(base_value)
    assert_level(increased_cfg, 100, [0.25, 0.25, 0.25, 0.25])

    # Decrease and validate decrease in level
    step_all(curriculum.decrement_attr_level)
    decreased_cfg: ComplexArithmeticConfig = curriculum.generate_configuration(base_value)
    assert_level(decreased_cfg, 10, [0.4, 0.4, 0.1, 0.1])

    # Test upper bound boundary condition: extra increments must stay at the last level
    step_all(curriculum.increment_attr_level, times=10)
    upper_bound_cfg: ComplexArithmeticConfig = curriculum.generate_configuration(base_value)
    assert_level(upper_bound_cfg, 100000000, [0.1, 0.1, 0.4, 0.4])

    # Test lower bound boundary condition: extra decrements must stay at the first level
    step_all(curriculum.decrement_attr_level, times=10)
    lower_bound_cfg: ComplexArithmeticConfig = curriculum.generate_configuration(base_value)
    assert_level(lower_bound_cfg, 10, [0.4, 0.4, 0.1, 0.1])