This commit is contained in:
Oliver Stanley 2026-03-27 15:35:06 +00:00
parent 678622faec
commit 89cd82c647
2 changed files with 221 additions and 3 deletions

View file

@ -30,6 +30,8 @@ Solve the following task:
@dataclass
class PathStarConfig:
"""Configuration for Path Star dataset generation"""
degree: int = 3
node_range: int = 100_000
min_path_length: int = 3
@ -41,8 +43,13 @@ class PathStarConfig:
seed: Optional[int] = None
def validate(self) -> None:
assert self.degree >= 2 and self.min_path_length >= 1
assert self.node_range > self.degree * self.max_path_length + 1
"""Validate configuration parameters"""
assert self.degree >= 2, "degree must be at least 2"
assert self.min_path_length >= 1, "min_path_length must be at least 1"
assert self.min_path_length <= self.max_path_length, "min_path_length must be <= max_path_length"
assert (
self.node_range > self.degree * self.max_path_length + 1
), "node_range must exceed degree * max_path_length + 1 for unique labels"
class PathStarDataset(ProceduralDataset):
@ -51,6 +58,17 @@ class PathStarDataset(ProceduralDataset):
def __init__(self, config: PathStarConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score an answer. Path is unique in a star graph, so only exact match counts."""
if not isinstance(answer, str) or len(answer.strip()) == 0:
return 0.0
# Normalize: strip, collapse whitespace
answer_normalized = " ".join(answer.strip().split())
oracle_normalized = " ".join(entry["answer"].strip().split())
if answer_normalized == oracle_normalized:
return 1.0
return 0.0
def __getitem__(self, idx: int) -> dict[str, Any]:
rng = random.Random(self.seed + idx)
@ -81,7 +99,10 @@ class PathStarDataset(ProceduralDataset):
rng.shuffle(edges)
edges_str = "".join(f"|{u} {v}" for u, v in edges)
prefix = f"{edges_str}/{center} {goal} = "
if cfg.reversed:
prefix = f"{edges_str}/{goal} {center} = "
else:
prefix = f"{edges_str}/{center} {goal} = "
question = PROMPT_TEMPLATE.format(task=prefix)
# gold path
@ -94,6 +115,8 @@ class PathStarDataset(ProceduralDataset):
"question": question,
"answer": answer,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"center": center,
"goal": goal,
"path_length": path_length,

195
tests/test_path_star.py Normal file
View file

@ -0,0 +1,195 @@
"""Tests for Path Star graph problem generation"""
import pytest
from reasoning_gym.graphs.path_star import PathStarConfig, PathStarCurriculum, PathStarDataset
def test_path_star_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = PathStarConfig(degree=1) # Must be >= 2
config.validate()
with pytest.raises(AssertionError):
config = PathStarConfig(min_path_length=0) # Must be >= 1
config.validate()
with pytest.raises(AssertionError):
config = PathStarConfig(min_path_length=5, max_path_length=3) # min > max
config.validate()
with pytest.raises(AssertionError):
config = PathStarConfig(degree=3, max_path_length=5, node_range=16) # node_range too small (need > 3*5+1=16)
config.validate()
def test_path_star_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
config = PathStarConfig(seed=42, size=10)
dataset1 = PathStarDataset(config)
dataset2 = PathStarDataset(config)
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
def test_path_star_dataset_items():
"""Test basic properties of generated items"""
config = PathStarConfig(min_path_length=3, max_path_length=5, size=10, seed=42)
dataset = PathStarDataset(config)
for i in range(len(dataset)):
item = dataset[i]
# Check item structure
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata fields
assert item["metadata"]["source_dataset"] == "path_star"
assert item["metadata"]["source_index"] == i
assert "center" in item["metadata"]
assert "goal" in item["metadata"]
assert "path_length" in item["metadata"]
assert "goal_path" in item["metadata"]
assert "difficulty" in item["metadata"]
# Verify answer format: space-separated integers
answer_parts = item["answer"].split()
assert all(part.isdigit() for part in answer_parts)
# First node should be center, last should be goal
center = item["metadata"]["center"]
goal = item["metadata"]["goal"]
assert int(answer_parts[0]) == center
assert int(answer_parts[-1]) == goal
# Path length should match: center + path_length nodes
path_length = item["metadata"]["path_length"]
assert len(answer_parts) == path_length + 1
# Path length within configured range
assert config.min_path_length <= path_length <= config.max_path_length
def test_path_star_dataset_iteration():
"""Test that iteration respects dataset size"""
config = PathStarConfig(size=5, seed=42)
dataset = PathStarDataset(config)
items = list(dataset)
assert len(items) == config.size
# Test multiple iterations yield same items
assert items == list(dataset)
def test_path_star_answer_correctness():
"""Test that generated paths are valid by checking edge connectivity"""
config = PathStarConfig(size=20, seed=123)
dataset = PathStarDataset(config)
for i in range(len(dataset)):
item = dataset[i]
question = item["question"]
answer_parts = [int(x) for x in item["answer"].split()]
# Parse edges from the question
# Format: ...edges_str/start goal = ...
# Extract the task part between "Solve the following task:\n" and the end
task_line = question.split("Solve the following task:\n")[1].strip()
edge_part, _ = task_line.split("/")
edges = set()
for edge_str in edge_part.split("|"):
edge_str = edge_str.strip()
if edge_str:
u, v = edge_str.split()
edges.add((int(u), int(v)))
# Verify consecutive nodes in the answer are connected by edges
for j in range(len(answer_parts) - 1):
u, v = answer_parts[j], answer_parts[j + 1]
assert (u, v) in edges, f"Edge ({u}, {v}) not found in edges for item {i}"
def test_path_star_score_answer():
"""Test the score_answer method"""
config = PathStarConfig(seed=42, size=5)
dataset = PathStarDataset(config)
item = dataset[0]
oracle = item["answer"]
# Exact match
assert dataset.score_answer(oracle, item) == 1.0
# Match with extra whitespace
assert dataset.score_answer(f" {oracle} ", item) == 1.0
# Match with extra internal whitespace
spaced = oracle.replace(" ", " ")
assert dataset.score_answer(spaced, item) == 1.0
# Wrong answer
assert dataset.score_answer("0 1 2 3", item) == 0.0
# None
assert dataset.score_answer(None, item) == 0.0
# Empty string
assert dataset.score_answer("", item) == 0.0
def test_path_star_reversed():
"""Test that reversed=True produces correct answer and task format"""
config_fwd = PathStarConfig(seed=42, size=5, reversed=False)
config_rev = PathStarConfig(seed=42, size=5, reversed=True)
dataset_fwd = PathStarDataset(config_fwd)
dataset_rev = PathStarDataset(config_rev)
for i in range(len(dataset_fwd)):
item_fwd = dataset_fwd[i]
item_rev = dataset_rev[i]
# Reversed answer should be the forward answer reversed
fwd_parts = item_fwd["answer"].split()
rev_parts = item_rev["answer"].split()
assert rev_parts == list(reversed(fwd_parts))
# Task format should swap start/goal
center = item_fwd["metadata"]["center"]
goal = item_fwd["metadata"]["goal"]
assert f"/{center} {goal} = " in item_fwd["question"]
assert f"/{goal} {center} = " in item_rev["question"]
def test_path_star_curriculum():
"""Test curriculum creates valid configs at various levels"""
curriculum = PathStarCurriculum()
base_value = {"size": 150, "seed": 1}
# Level 0 (base)
base_cfg: PathStarConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.degree == 2
assert base_cfg.node_range == 10_000
assert base_cfg.min_path_length == 3 and base_cfg.max_path_length == 3
# Increment attributes
curriculum.increment_attr_level("degree")
curriculum.increment_attr_level("node_range")
curriculum.increment_attr_level("path_length")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.degree == 3
assert increased_cfg.node_range == 50_000
assert increased_cfg.min_path_length == 3 and increased_cfg.max_path_length == 5
# Decrement degree back
curriculum.decrement_attr_level("degree")
partial_cfg = curriculum.generate_configuration(base_value)
assert partial_cfg.degree == 2
assert partial_cfg.node_range == 50_000
assert partial_cfg.min_path_length == 3 and partial_cfg.max_path_length == 5