mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
fixes
This commit is contained in:
parent
678622faec
commit
89cd82c647
2 changed files with 221 additions and 3 deletions
|
|
@ -30,6 +30,8 @@ Solve the following task:
|
|||
|
||||
@dataclass
|
||||
class PathStarConfig:
|
||||
"""Configuration for Path Star dataset generation"""
|
||||
|
||||
degree: int = 3
|
||||
node_range: int = 100_000
|
||||
min_path_length: int = 3
|
||||
|
|
@ -41,8 +43,13 @@ class PathStarConfig:
|
|||
seed: Optional[int] = None
|
||||
|
||||
def validate(self) -> None:
|
||||
assert self.degree >= 2 and self.min_path_length >= 1
|
||||
assert self.node_range > self.degree * self.max_path_length + 1
|
||||
"""Validate configuration parameters"""
|
||||
assert self.degree >= 2, "degree must be at least 2"
|
||||
assert self.min_path_length >= 1, "min_path_length must be at least 1"
|
||||
assert self.min_path_length <= self.max_path_length, "min_path_length must be <= max_path_length"
|
||||
assert (
|
||||
self.node_range > self.degree * self.max_path_length + 1
|
||||
), "node_range must exceed degree * max_path_length + 1 for unique labels"
|
||||
|
||||
|
||||
class PathStarDataset(ProceduralDataset):
|
||||
|
|
@ -51,6 +58,17 @@ class PathStarDataset(ProceduralDataset):
|
|||
def __init__(self, config: PathStarConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score an answer. Path is unique in a star graph, so only exact match counts."""
|
||||
if not isinstance(answer, str) or len(answer.strip()) == 0:
|
||||
return 0.0
|
||||
# Normalize: strip, collapse whitespace
|
||||
answer_normalized = " ".join(answer.strip().split())
|
||||
oracle_normalized = " ".join(entry["answer"].strip().split())
|
||||
if answer_normalized == oracle_normalized:
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
rng = random.Random(self.seed + idx)
|
||||
|
||||
|
|
@ -81,7 +99,10 @@ class PathStarDataset(ProceduralDataset):
|
|||
rng.shuffle(edges)
|
||||
|
||||
edges_str = "".join(f"|{u} {v}" for u, v in edges)
|
||||
prefix = f"{edges_str}/{center} {goal} = "
|
||||
if cfg.reversed:
|
||||
prefix = f"{edges_str}/{goal} {center} = "
|
||||
else:
|
||||
prefix = f"{edges_str}/{center} {goal} = "
|
||||
question = PROMPT_TEMPLATE.format(task=prefix)
|
||||
|
||||
# gold path
|
||||
|
|
@ -94,6 +115,8 @@ class PathStarDataset(ProceduralDataset):
|
|||
"question": question,
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"center": center,
|
||||
"goal": goal,
|
||||
"path_length": path_length,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue