mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
added intermediate integration (#334)
This commit is contained in:
parent
ede43c58ba
commit
516bca57ab
2 changed files with 151 additions and 62 deletions
|
|
@ -9,18 +9,6 @@ from reasoning_gym.algebra.intermediate_integration import IntermediateIntegrati
|
|||
|
||||
def test_intermediate_integration_config_validation():
    """Check that each malformed configuration is rejected with an AssertionError."""
    # One entry per invalid configuration, checked in the order listed.
    invalid_kwargs = (
        {"problem_types": ["invalid_problem_type"]},
        {"substitution_types": ["invalid_substitution_type"]},
        {"by_parts_types": ["invalid_by_parts_type"]},
        {"linear_lower_bound": 2, "linear_upper_bound": 1},
    )

    for kwargs in invalid_kwargs:
        # Construction stays inside the context manager, so an AssertionError
        # raised by either the constructor or validate() satisfies the check.
        with pytest.raises(AssertionError):
            config = IntermediateIntegrationConfig(**kwargs)
            config.validate()
|
||||
|
|
@ -87,8 +75,6 @@ def test_intermediate_integration_dataset_items():
|
|||
assert "integrand" in item["metadata"]
|
||||
assert "problem_type" in item["metadata"]
|
||||
assert "variable" in item["metadata"]
|
||||
assert "type" in item["metadata"]
|
||||
|
||||
# verify answer is mathematical expression
|
||||
answer = item["answer"]
|
||||
answer = answer.replace(" + C", "")
|
||||
|
|
@ -143,3 +129,71 @@ def test_score_answer_cases():
|
|||
dummy_entry = {"metadata": metadata}
|
||||
score = dataset.score_answer(answer, entry=dummy_entry)
|
||||
assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"
|
||||
|
||||
|
||||
def test_intermediate_integration_curriculum():
    """Test the IntermediateIntegrationCurriculum functionality.

    Covers three scenarios for the ``problem_type_weights`` attribute:

    1. stepping the attribute level up and down directly,
    2. stepping it through the global level on a fresh curriculum,
    3. incrementing more often than there are levels and checking the
       generated weights stay at the last level.

    At every level the generated configuration is expected to carry a
    one-hot weight vector whose single ``1`` sits at the level's index.
    """
    from reasoning_gym.algebra.intermediate_integration import IntermediateIntegrationCurriculum

    attr = "problem_type_weights"
    num_levels = 8  # length of the weight vectors asserted below

    def one_hot(level):
        """Return the expected weight vector for ``level`` (1 at that index, 0 elsewhere)."""
        return [int(index == level) for index in range(num_levels)]

    def generated_weights(curr):
        """Generate a configuration from ``curr`` and return its problem type weights."""
        return curr.generate_configuration({}).problem_type_weights

    curriculum = IntermediateIntegrationCurriculum()

    # Test initial configuration
    assert generated_weights(curriculum) == one_hot(0)  # Default level 0

    # Test incrementing problem_type_weights attribute
    curriculum.increment_attr_level(attr)
    assert generated_weights(curriculum) == one_hot(1)  # Level 1

    # Test incrementing problem_type_weights attribute again
    curriculum.increment_attr_level(attr)
    assert generated_weights(curriculum) == one_hot(2)  # Level 2

    # Test decrementing problem_type_weights attribute
    curriculum.decrement_attr_level(attr)
    assert generated_weights(curriculum) == one_hot(1)  # Back to level 1

    # Test global level adjustments
    # Reset curriculum
    curriculum = IntermediateIntegrationCurriculum()
    assert curriculum.get_attr_level(attr) == 0

    # Increase global level
    curriculum.increment_global_level()
    assert curriculum.get_attr_level(attr) == 1
    assert generated_weights(curriculum) == one_hot(1)

    # Increase global level again
    curriculum.increment_global_level()
    assert curriculum.get_attr_level(attr) == 2
    assert generated_weights(curriculum) == one_hot(2)

    # Decrease global level
    curriculum.decrement_global_level()
    assert curriculum.get_attr_level(attr) == 1
    assert generated_weights(curriculum) == one_hot(1)

    # Test upper bound
    curriculum = IntermediateIntegrationCurriculum()  # Reset curriculum
    for _ in range(10):  # Try going beyond max level (7)
        curriculum.increment_attr_level(attr)

    assert generated_weights(curriculum) == one_hot(num_levels - 1)  # Should be capped at level 7
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue