mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
Based on the implementation and requirements, here's a concise commit message:
feat: Add CompositeDataset for weighted multi-dataset sampling
This commit is contained in:
parent
0561844779
commit
f07b6b7f61
3 changed files with 220 additions and 0 deletions
113
tests/test_composite.py
Normal file
113
tests/test_composite.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
import os
|
||||
import pytest
|
||||
import tempfile
|
||||
import yaml
|
||||
|
||||
from reasoning_gym.composite import CompositeDataset, CompositeConfig, DatasetSpec
|
||||
|
||||
|
||||
def create_test_config(tmp_path):
    """Write a sample composite-dataset YAML config and return its path.

    The file is named ``test_config.yaml`` and is created inside
    ``tmp_path``. It describes two weighted datasets (``chain_sum`` and
    ``leg_counting``) with a fixed size and seed.
    """
    chain_sum_entry = {
        "name": "chain_sum",
        "weight": 0.3,
        "config": {
            "min_terms": 2,
            "max_terms": 4,
        },
    }
    leg_counting_entry = {
        "name": "leg_counting",
        "weight": 0.7,
        "config": {
            "min_animals": 1,
            "max_animals": 3,
        },
    }
    config = {
        "size": 100,
        "seed": 42,
        "datasets": [chain_sum_entry, leg_counting_entry],
    }

    config_path = os.path.join(tmp_path, "test_config.yaml")
    with open(config_path, "w") as handle:
        yaml.dump(config, handle)

    return config_path
|
||||
|
||||
|
||||
def test_composite_config_validation():
    """Check that invalid configurations are rejected with AssertionError.

    Two invalid configurations are exercised in turn: a negative ``size``
    and an empty ``datasets`` list.
    """
    invalid_kwargs = (
        {"size": -1},
        {"datasets": []},
    )
    for kwargs in invalid_kwargs:
        # Construction stays inside the context manager so an assertion
        # raised while building the config is accepted as well.
        with pytest.raises(AssertionError):
            bad_config = CompositeConfig(**kwargs)
            bad_config.validate()
|
||||
|
||||
|
||||
def test_composite_dataset_deterministic():
    """Two datasets built from the same seeded config yield equal items."""
    chain_sum_spec = DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
    config = CompositeConfig(size=10, seed=42, datasets=[chain_sum_spec])

    first = CompositeDataset(config)
    second = CompositeDataset(config)

    # Compare item by item over the full length of the first dataset.
    for idx in range(len(first)):
        assert first[idx] == second[idx]
|
||||
|
||||
|
||||
def test_composite_dataset_metadata():
    """Items carry the originating dataset name and index in their metadata."""
    chain_sum_spec = DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
    config = CompositeConfig(size=10, seed=42, datasets=[chain_sum_spec])

    first_item = CompositeDataset(config)[0]
    metadata = first_item["metadata"]

    # Both provenance keys must be present ...
    assert "source_dataset" in metadata
    assert "source_index" in metadata
    # ... and hold the only configured dataset's name and an integer index.
    assert metadata["source_dataset"] == "chain_sum"
    assert isinstance(metadata["source_index"], int)
|
||||
|
||||
|
||||
def test_composite_dataset_weights():
    """Raw weights 2.0 and 3.0 are normalised to 0.4 and 0.6."""
    specs = [
        DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
        DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
    ]
    config = CompositeConfig(size=1000, seed=42, datasets=specs)

    dataset = CompositeDataset(config)

    # Each raw weight divided by their sum (5.0), in spec order.
    expected_weights = (0.4, 0.6)
    for position, expected in enumerate(expected_weights):
        assert abs(dataset.weights[position] - expected) < 1e-6
|
||||
|
||||
|
||||
def test_yaml_loading(tmp_path):
    """A config written to YAML is read back with the same size, seed and datasets."""
    yaml_path = create_test_config(tmp_path)
    loaded = CompositeConfig.from_yaml(yaml_path)

    assert loaded.size == 100
    assert loaded.seed == 42
    # One comparison covers both the dataset count and the order of the names.
    dataset_names = [spec.name for spec in loaded.datasets]
    assert dataset_names == ["chain_sum", "leg_counting"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue