mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-30 17:40:45 +00:00
register composite dataset
This commit is contained in:
parent
6ec8f782d7
commit
3e28a14d54
4 changed files with 54 additions and 67 deletions
|
|
@ -1,9 +1,9 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import yaml
|
||||
|
||||
from reasoning_gym.composite import CompositeDataset, CompositeConfig, DatasetSpec
|
||||
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
|
||||
|
||||
|
||||
def create_test_config(tmp_path):
|
||||
|
|
@ -18,7 +18,7 @@ def create_test_config(tmp_path):
|
|||
"config": {
|
||||
"min_terms": 2,
|
||||
"max_terms": 4,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "leg_counting",
|
||||
|
|
@ -26,15 +26,16 @@ def create_test_config(tmp_path):
|
|||
"config": {
|
||||
"min_animals": 1,
|
||||
"max_animals": 3,
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
config_path = os.path.join(tmp_path, "test_config.yaml")
|
||||
with open(config_path, 'w') as f:
|
||||
print(config_path)
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(config, f)
|
||||
|
||||
|
||||
return config_path
|
||||
|
||||
|
||||
|
|
@ -43,7 +44,7 @@ def test_composite_config_validation():
|
|||
with pytest.raises(AssertionError):
|
||||
config = CompositeConfig(size=-1)
|
||||
config.validate()
|
||||
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = CompositeConfig(datasets=[])
|
||||
config.validate()
|
||||
|
|
@ -52,16 +53,12 @@ def test_composite_config_validation():
|
|||
def test_composite_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = CompositeConfig(
|
||||
size=10,
|
||||
seed=42,
|
||||
datasets=[
|
||||
DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
|
||||
]
|
||||
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
|
||||
)
|
||||
|
||||
|
||||
dataset1 = CompositeDataset(config)
|
||||
dataset2 = CompositeDataset(config)
|
||||
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
|
@ -69,16 +66,12 @@ def test_composite_dataset_deterministic():
|
|||
def test_composite_dataset_metadata():
|
||||
"""Test that metadata includes source dataset information"""
|
||||
config = CompositeConfig(
|
||||
size=10,
|
||||
seed=42,
|
||||
datasets=[
|
||||
DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
|
||||
]
|
||||
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
|
||||
)
|
||||
|
||||
|
||||
dataset = CompositeDataset(config)
|
||||
item = dataset[0]
|
||||
|
||||
|
||||
assert "source_dataset" in item["metadata"]
|
||||
assert "source_index" in item["metadata"]
|
||||
assert item["metadata"]["source_dataset"] == "chain_sum"
|
||||
|
|
@ -93,9 +86,9 @@ def test_composite_dataset_weights():
|
|||
datasets=[
|
||||
DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
|
||||
DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
dataset = CompositeDataset(config)
|
||||
assert abs(dataset.weights[0] - 0.4) < 1e-6
|
||||
assert abs(dataset.weights[1] - 0.6) < 1e-6
|
||||
|
|
@ -105,7 +98,7 @@ def test_yaml_loading(tmp_path):
|
|||
"""Test loading configuration from YAML"""
|
||||
config_path = create_test_config(tmp_path)
|
||||
config = CompositeConfig.from_yaml(config_path)
|
||||
|
||||
|
||||
assert config.size == 100
|
||||
assert config.seed == 42
|
||||
assert len(config.datasets) == 2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue