Add complex arithmetic curriculum (#310)

* Add complex arithmetic curriculum
This commit is contained in:
Adefioye 2025-03-09 18:28:51 -05:00 committed by GitHub
parent 9bd4f03dbd
commit 841663cc5a
2 changed files with 139 additions and 5 deletions

View file

@ -1,6 +1,10 @@
import pytest
from reasoning_gym.algebra.complex_arithmetic import ComplexArithmeticConfig, ComplexArithmeticDataset
from reasoning_gym.algebra.complex_arithmetic import (
ComplexArithmeticConfig,
ComplexArithmeticCurriculum,
ComplexArithmeticDataset,
)
def test_complex_arithmetic_basic():
@ -81,7 +85,9 @@ def test_complex_arithmetic_scoring():
def test_complex_arithmetic_division_by_zero():
"""Test that division by zero is handled properly."""
config = ComplexArithmeticConfig(operations=("/",), seed=42) # Only test division
config = ComplexArithmeticConfig(
operations=("+", "-", "*", "/"), operations_weights=[0.0, 0.0, 0.0, 1.0], seed=42
) # Only test division
dataset = ComplexArithmeticDataset(config)
# Check multiple items to ensure no division by zero
@ -131,3 +137,65 @@ def test_parse_string_to_complex():
assert dataset.parse_string_to_complex("invalid") is None
assert dataset.parse_string_to_complex("3 + i + 2") is None
assert dataset.parse_string_to_complex("3 + 2x") is None
def test_complex_arithmetic_curriculum():
"""Test the curriculum for complex arithmetic."""
curriculum = ComplexArithmeticCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: ComplexArithmeticCurriculum = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_real == base_cfg.min_imag == -10
assert base_cfg.max_real == base_cfg.max_imag == 10
assert base_cfg.operations_weights == [0.4, 0.4, 0.1, 0.1]
# Increase and validate increase in level
curriculum.increment_attr_level("min_real")
curriculum.increment_attr_level("min_imag")
curriculum.increment_attr_level("max_real")
curriculum.increment_attr_level("max_imag")
curriculum.increment_attr_level("operations_weights")
increased_cfg: ComplexArithmeticCurriculum = curriculum.generate_configuration(base_value)
assert increased_cfg.min_real == increased_cfg.min_imag == -100
assert increased_cfg.max_real == increased_cfg.max_imag == 100
assert increased_cfg.operations_weights == [0.25, 0.25, 0.25, 0.25]
# Decrease and validate decrease in level
curriculum.decrement_attr_level("min_real")
curriculum.decrement_attr_level("min_imag")
curriculum.decrement_attr_level("max_real")
curriculum.decrement_attr_level("max_imag")
curriculum.decrement_attr_level("operations_weights")
decreased_cfg: ComplexArithmeticCurriculum = curriculum.generate_configuration(base_value)
assert decreased_cfg.min_real == decreased_cfg.min_imag == -10
assert decreased_cfg.max_real == decreased_cfg.max_imag == 10
assert decreased_cfg.operations_weights == [0.4, 0.4, 0.1, 0.1]
# Test upper bound boundary condition
for _ in range(10):
curriculum.increment_attr_level("min_real")
curriculum.increment_attr_level("min_imag")
curriculum.increment_attr_level("max_real")
curriculum.increment_attr_level("max_imag")
curriculum.increment_attr_level("operations_weights")
upper_bound_cfg: ComplexArithmeticCurriculum = curriculum.generate_configuration(base_value)
assert upper_bound_cfg.min_real == upper_bound_cfg.min_imag == -100000000
assert upper_bound_cfg.max_real == upper_bound_cfg.max_imag == 100000000
assert upper_bound_cfg.operations_weights == [0.1, 0.1, 0.4, 0.4]
# Test lower bound boundary condition
for _ in range(10):
curriculum.decrement_attr_level("min_real")
curriculum.decrement_attr_level("min_imag")
curriculum.decrement_attr_level("max_real")
curriculum.decrement_attr_level("max_imag")
curriculum.decrement_attr_level("operations_weights")
lower_bound_cfg: ComplexArithmeticCurriculum = curriculum.generate_configuration(base_value)
assert lower_bound_cfg.min_real == lower_bound_cfg.min_imag == -10
assert lower_bound_cfg.max_real == lower_bound_cfg.max_imag == 10
assert lower_bound_cfg.operations_weights == [0.4, 0.4, 0.1, 0.1]