register composite dataset

This commit is contained in:
Andreas Koepf 2025-02-04 19:17:34 +01:00
parent 6ec8f782d7
commit 3e28a14d54
4 changed files with 54 additions and 67 deletions

View file

@ -1,9 +1,9 @@
import os
import pytest
import tempfile
import yaml
from reasoning_gym.composite import CompositeDataset, CompositeConfig, DatasetSpec
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
def create_test_config(tmp_path):
@ -18,7 +18,7 @@ def create_test_config(tmp_path):
"config": {
"min_terms": 2,
"max_terms": 4,
}
},
},
{
"name": "leg_counting",
@ -26,15 +26,16 @@ def create_test_config(tmp_path):
"config": {
"min_animals": 1,
"max_animals": 3,
}
}
]
},
},
],
}
config_path = os.path.join(tmp_path, "test_config.yaml")
with open(config_path, 'w') as f:
print(config_path)
with open(config_path, "w") as f:
yaml.dump(config, f)
return config_path
@ -43,7 +44,7 @@ def test_composite_config_validation():
with pytest.raises(AssertionError):
config = CompositeConfig(size=-1)
config.validate()
with pytest.raises(AssertionError):
config = CompositeConfig(datasets=[])
config.validate()
@ -52,16 +53,12 @@ def test_composite_config_validation():
def test_composite_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
config = CompositeConfig(
size=10,
seed=42,
datasets=[
DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
]
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)
dataset1 = CompositeDataset(config)
dataset2 = CompositeDataset(config)
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
@ -69,16 +66,12 @@ def test_composite_dataset_deterministic():
def test_composite_dataset_metadata():
"""Test that metadata includes source dataset information"""
config = CompositeConfig(
size=10,
seed=42,
datasets=[
DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
]
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)
dataset = CompositeDataset(config)
item = dataset[0]
assert "source_dataset" in item["metadata"]
assert "source_index" in item["metadata"]
assert item["metadata"]["source_dataset"] == "chain_sum"
@ -93,9 +86,9 @@ def test_composite_dataset_weights():
datasets=[
DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
]
],
)
dataset = CompositeDataset(config)
assert abs(dataset.weights[0] - 0.4) < 1e-6
assert abs(dataset.weights[1] - 0.6) < 1e-6
@ -105,7 +98,7 @@ def test_yaml_loading(tmp_path):
"""Test loading configuration from YAML"""
config_path = create_test_config(tmp_path)
config = CompositeConfig.from_yaml(config_path)
assert config.size == 100
assert config.seed == 42
assert len(config.datasets) == 2