This commit is contained in:
Oliver Stanley 2026-03-27 15:35:06 +00:00
parent 678622faec
commit 89cd82c647
2 changed files with 221 additions and 3 deletions

View file

@ -30,6 +30,8 @@ Solve the following task:
@dataclass
class PathStarConfig:
"""Configuration for Path Star dataset generation"""
degree: int = 3
node_range: int = 100_000
min_path_length: int = 3
@ -41,8 +43,13 @@ class PathStarConfig:
seed: Optional[int] = None
def validate(self) -> None:
assert self.degree >= 2 and self.min_path_length >= 1
assert self.node_range > self.degree * self.max_path_length + 1
"""Validate configuration parameters"""
assert self.degree >= 2, "degree must be at least 2"
assert self.min_path_length >= 1, "min_path_length must be at least 1"
assert self.min_path_length <= self.max_path_length, "min_path_length must be <= max_path_length"
assert (
self.node_range > self.degree * self.max_path_length + 1
), "node_range must exceed degree * max_path_length + 1 for unique labels"
class PathStarDataset(ProceduralDataset):
@ -51,6 +58,17 @@ class PathStarDataset(ProceduralDataset):
def __init__(self, config: PathStarConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score an answer. Path is unique in a star graph, so only exact match counts."""
if not isinstance(answer, str) or len(answer.strip()) == 0:
return 0.0
# Normalize: strip, collapse whitespace
answer_normalized = " ".join(answer.strip().split())
oracle_normalized = " ".join(entry["answer"].strip().split())
if answer_normalized == oracle_normalized:
return 1.0
return 0.0
def __getitem__(self, idx: int) -> dict[str, Any]:
rng = random.Random(self.seed + idx)
@ -81,7 +99,10 @@ class PathStarDataset(ProceduralDataset):
rng.shuffle(edges)
edges_str = "".join(f"|{u} {v}" for u, v in edges)
prefix = f"{edges_str}/{center} {goal} = "
if cfg.reversed:
prefix = f"{edges_str}/{goal} {center} = "
else:
prefix = f"{edges_str}/{center} {goal} = "
question = PROMPT_TEMPLATE.format(task=prefix)
# gold path
@ -94,6 +115,8 @@ class PathStarDataset(ProceduralDataset):
"question": question,
"answer": answer,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"center": center,
"goal": goal,
"path_length": path_length,