mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
106 lines
2.9 KiB
Python
106 lines
2.9 KiB
Python
import os
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
|
|
|
|
|
|
def create_test_config(tmp_path):
    """Write a composite-dataset YAML config into *tmp_path* for tests.

    The config requests 100 items with seed 42, drawn from two dataset
    entries: ``chain_sum`` (weight 0.3) and ``leg_counting`` (weight 0.7),
    each with a small ``config`` mapping of its own.

    Args:
        tmp_path: Directory in which ``test_config.yaml`` is created
            (anything accepted by ``os.path.join``, e.g. pytest's
            ``tmp_path`` fixture).

    Returns:
        The path of the written YAML file, as produced by ``os.path.join``.
    """
    config = {
        "size": 100,
        "seed": 42,
        "datasets": [
            {
                "name": "chain_sum",
                "weight": 0.3,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                },
            },
            {
                "name": "leg_counting",
                "weight": 0.7,
                "config": {
                    "min_animals": 1,
                    "max_animals": 3,
                },
            },
        ],
    }

    config_path = os.path.join(tmp_path, "test_config.yaml")
    # Explicit encoding keeps the file independent of the platform's
    # default locale encoding.
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    return config_path
|
|
|
|
|
|
def test_composite_config_validation():
    """Invalid configurations must be rejected with an AssertionError."""
    invalid_kwargs = (
        {"size": -1},  # negative size
        {"datasets": []},  # empty dataset list
    )
    for kwargs in invalid_kwargs:
        # Construction stays inside the context manager so an assertion
        # raised either while building the config or in validate() counts.
        with pytest.raises(AssertionError):
            bad_config = CompositeConfig(**kwargs)
            bad_config.validate()
|
|
|
|
|
|
def test_composite_dataset_deterministic():
    """Two datasets built from the same seeded config yield equal items."""
    spec = DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
    config = CompositeConfig(size=10, seed=42, datasets=[spec])

    first = CompositeDataset(config)
    second = CompositeDataset(config)

    # Compare item by item, over the length reported by the first dataset.
    for idx in range(len(first)):
        left, right = first[idx], second[idx]
        assert left == right
|
|
|
|
|
|
def test_composite_dataset_metadata():
    """Items carry the originating dataset name and index in their metadata."""
    spec = DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
    config = CompositeConfig(size=10, seed=42, datasets=[spec])

    first_item = CompositeDataset(config)[0]
    metadata = first_item["metadata"]

    # Both bookkeeping keys must be present before their values are checked.
    for required_key in ("source_dataset", "source_index"):
        assert required_key in metadata

    assert metadata["source_dataset"] == "chain_sum"
    assert isinstance(metadata["source_index"], int)
|
|
|
|
|
|
def test_composite_dataset_weights():
    """Raw weights 2.0 and 3.0 are exposed as 0.4 and 0.6 on the dataset."""
    specs = [
        DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
        DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
    ]
    config = CompositeConfig(size=1000, seed=42, datasets=specs)

    dataset = CompositeDataset(config)

    # 2 / (2 + 3) and 3 / (2 + 3), checked with an absolute tolerance.
    tolerance = 1e-6
    for position, expected in enumerate((0.4, 0.6)):
        assert abs(dataset.weights[position] - expected) < tolerance
|
|
|
|
|
|
def test_yaml_loading(tmp_path):
    """A config written to YAML is read back by CompositeConfig.from_yaml."""
    loaded = CompositeConfig.from_yaml(create_test_config(tmp_path))

    # Scalar settings survive the round trip.
    assert (loaded.size, loaded.seed) == (100, 42)

    # Both dataset entries are present, in the order they were written.
    assert len(loaded.datasets) == 2
    first_spec, second_spec = loaded.datasets[0], loaded.datasets[1]
    assert first_spec.name == "chain_sum"
    assert second_spec.name == "leg_counting"
|