reasoning-gym/tests/test_composite.py
2025-02-04 19:17:34 +01:00

106 lines
2.9 KiB
Python

import os
import pytest
import yaml
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
def create_test_config(tmp_path):
    """Write a sample composite-dataset YAML config and return its path.

    The file is named ``test_config.yaml`` and describes two weighted
    datasets, ``chain_sum`` (weight 0.3) and ``leg_counting`` (weight 0.7),
    with a top-level ``size`` of 100 and ``seed`` of 42.

    Args:
        tmp_path: Directory in which the file is created (for example the
            value of pytest's ``tmp_path`` fixture).

    Returns:
        The path of the written file, as produced by ``os.path.join``.
    """
    config = {
        "size": 100,
        "seed": 42,
        "datasets": [
            {
                "name": "chain_sum",
                "weight": 0.3,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                },
            },
            {
                "name": "leg_counting",
                "weight": 0.7,
                "config": {
                    "min_animals": 1,
                    "max_animals": 3,
                },
            },
        ],
    }
    config_path = os.path.join(tmp_path, "test_config.yaml")
    # Explicit encoding keeps the written file independent of the platform's
    # default text encoding.
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)
    return config_path
def test_composite_config_validation():
    """Expect an AssertionError for a negative size and for an empty dataset list."""
    invalid_kwargs = (
        {"size": -1},
        {"datasets": []},
    )
    for kwargs in invalid_kwargs:
        # Construction stays inside the context manager so an AssertionError
        # raised by either the constructor or validate() is accepted.
        with pytest.raises(AssertionError):
            CompositeConfig(**kwargs).validate()
def test_composite_dataset_deterministic():
    """Two datasets built from the same seeded config must yield equal items."""
    spec = DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
    shared_config = CompositeConfig(size=10, seed=42, datasets=[spec])

    first = CompositeDataset(shared_config)
    second = CompositeDataset(shared_config)

    # Compare item by item through indexing, over the first dataset's length.
    for index in range(len(first)):
        assert first[index] == second[index]
def test_composite_dataset_metadata():
    """Check that the first item's metadata names its source dataset and index."""
    spec = DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
    dataset = CompositeDataset(CompositeConfig(size=10, seed=42, datasets=[spec]))

    metadata = dataset[0]["metadata"]

    # Key presence is asserted before the values so a missing key fails with
    # an assertion rather than a KeyError.
    assert "source_dataset" in metadata
    assert "source_index" in metadata
    assert metadata["source_dataset"] == "chain_sum"
    assert isinstance(metadata["source_index"], int)
def test_composite_dataset_weights():
    """Raw weights 2.0 and 3.0 are expected to be exposed as 0.4 and 0.6."""
    specs = [
        DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
        DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
    ]
    dataset = CompositeDataset(CompositeConfig(size=1000, seed=42, datasets=specs))

    tolerance = 1e-6
    for position, expected in enumerate((0.4, 0.6)):
        assert abs(dataset.weights[position] - expected) < tolerance
def test_yaml_loading(tmp_path):
    """Load the sample YAML config and check its size, seed and dataset names."""
    yaml_file = create_test_config(tmp_path)
    loaded = CompositeConfig.from_yaml(yaml_file)

    assert loaded.size == 100
    assert loaded.seed == 42

    expected_names = ("chain_sum", "leg_counting")
    assert len(loaded.datasets) == 2
    for position, expected in enumerate(expected_names):
        assert loaded.datasets[position].name == expected