added intermediate integration (#334)

This commit is contained in:
joesharratt1229 2025-03-11 23:57:51 +01:00 committed by GitHub
parent ede43c58ba
commit 516bca57ab
2 changed files with 151 additions and 62 deletions

View file

@@ -9,18 +9,6 @@ from reasoning_gym.algebra.intermediate_integration import IntermediateIntegrati
def test_intermediate_integration_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = IntermediateIntegrationConfig(problem_types=["invalid_problem_type"])
config.validate()
with pytest.raises(AssertionError):
config = IntermediateIntegrationConfig(substitution_types=["invalid_substitution_type"])
config.validate()
with pytest.raises(AssertionError):
config = IntermediateIntegrationConfig(by_parts_types=["invalid_by_parts_type"])
config.validate()
with pytest.raises(AssertionError):
config = IntermediateIntegrationConfig(linear_lower_bound=2, linear_upper_bound=1)
config.validate()
@@ -87,8 +75,6 @@ def test_intermediate_integration_dataset_items():
assert "integrand" in item["metadata"]
assert "problem_type" in item["metadata"]
assert "variable" in item["metadata"]
assert "type" in item["metadata"]
# verify answer is mathematical expression
answer = item["answer"]
answer = answer.replace(" + C", "")
@@ -143,3 +129,71 @@ def test_score_answer_cases():
dummy_entry = {"metadata": metadata}
score = dataset.score_answer(answer, entry=dummy_entry)
assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"
def test_intermediate_integration_curriculum():
    """Check how IntermediateIntegrationCurriculum maps levels to problem_type_weights.

    The test walks the ``problem_type_weights`` attribute up and down, both
    per-attribute and through the global level, and asserts that each level
    ``k`` yields a weight list with a single 1 at index ``k``. It finishes by
    over-incrementing the attribute and asserting the result stays at the
    last index (7).
    """
    from reasoning_gym.algebra.intermediate_integration import (
        IntermediateIntegrationConfig,
        IntermediateIntegrationCurriculum,
    )

    attr = "problem_type_weights"
    n_types = 8

    def only(level):
        # Weight list selecting exactly one problem type: 1 at `level`, 0 elsewhere.
        return [1 if idx == level else 0 for idx in range(n_types)]

    def current_weights(cur):
        # Generate a fresh configuration and read back its weights.
        return cur.generate_configuration({}).problem_type_weights

    # A uniformly weighted config is constructed here; it is not handed to the
    # curriculum or referenced again anywhere below.
    config = IntermediateIntegrationConfig(size=150, seed=1, problem_type_weights=[0.125] * n_types)

    curriculum = IntermediateIntegrationCurriculum()

    # A freshly created curriculum starts at level 0.
    assert current_weights(curriculum) == only(0)

    # Per-attribute moves: up to 1, up to 2, then back down to 1.
    attr_steps = (
        (curriculum.increment_attr_level, 1),
        (curriculum.increment_attr_level, 2),
        (curriculum.decrement_attr_level, 1),
    )
    for move, level in attr_steps:
        move(attr)
        assert current_weights(curriculum) == only(level)

    # Global-level moves, starting again from a fresh curriculum.
    curriculum = IntermediateIntegrationCurriculum()
    assert curriculum.get_attr_level(attr) == 0

    global_steps = (
        (curriculum.increment_global_level, 1),
        (curriculum.increment_global_level, 2),
        (curriculum.decrement_global_level, 1),
    )
    for move, level in global_steps:
        move()
        assert curriculum.get_attr_level(attr) == level
        assert current_weights(curriculum) == only(level)

    # Upper bound: ten increments exceed the highest level (7); the weights
    # are expected to remain at the last problem type.
    curriculum = IntermediateIntegrationCurriculum()
    for _ in range(10):
        curriculum.increment_attr_level(attr)
    assert current_weights(curriculum) == only(n_types - 1)