This commit is contained in:
Pierre-Hugues HUSSON 2026-04-18 19:54:31 +05:30 committed by GitHub
commit 72999eda51
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 343 additions and 0 deletions

View file

@ -23,6 +23,7 @@ from .sudoku import SudokuConfig, SudokuCurriculum, SudokuDataset
from .survo import SurvoConfig, SurvoCurriculum, SurvoDataset
from .tower_of_hanoi import HanoiConfig, HanoiCurriculum, HanoiDataset
from .tsumego import TsumegoConfig, TsumegoCurriculum, TsumegoDataset
from .wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset
__all__ = [
"BoxnetConfig",
@ -76,4 +77,7 @@ __all__ = [
"MahjongPuzzleConfig",
"MahjongPuzzleDataset",
"MahjongPuzzleCurriculum",
"WikiraceConfig",
"WikiraceDataset",
"WikiraceCurriculum",
]

View file

@ -0,0 +1,238 @@
import re
from collections import defaultdict, deque
from dataclasses import dataclass
from functools import lru_cache
from random import Random
from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
try:
from datasets import load_dataset
except:
raise Exception("wikirace requires datasets library. Run `pip install datasets`")
QUESTION_FORMAT_TEMPLATE = """
You are playing WikiRace, trying to navigate from one Wikipedia article to another using only links.
Answer with just the link number.
Current article: {current}
Target article: {target}
Available links (numbered):
{formatted_links}
Your path so far: {formatted_path}
Think about which link is most likely to lead you toward the target article.
First, analyze each link briefly and how it connects to your goal, then select the most promising one.
"""
DATASET_NAME = "wikirace"
@dataclass
class WikiraceConfig:
"""Configuration for WikiRace task generation"""
min_distance: int = 3
max_distance: int = 6
max_tries: int = 100
seed: Optional[int] = None
size: int = 500
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_distance > 1, "min_distance must be greater than 1"
assert self.max_distance >= self.min_distance, "max_distance must be >= min_distance"
assert self.max_tries >= 1, "max_tries must be greater than 1"
def load_wiki_graph():
dataset = load_dataset("HuggingFaceTB/simplewiki-pruned-350k")
graph = defaultdict(set)
titles = set()
# Build the graph
for example in dataset["train"]:
title = example["article"]
links = example["links"]
titles.add(title)
for link in links:
graph[title].add(link)
# Note: Since titles was a set, and hash are naturally unstable
# We want to sort it, so that prng.choice() is stable
return graph, sorted(list(titles))
class WikiraceDataset(ProceduralDataset):
"""Generates Wikirace Game tasks"""
def __init__(self, config: WikiraceConfig):
self.wikigraph, self.wikititles = load_wiki_graph()
super().__init__(config=config, seed=config.seed, size=config.size)
# We'll be computing a lot of shortest_path of very similar paths
# So cache it
@lru_cache(maxsize=128 * 1024)
def shortest_path(self, source, target):
if source not in self.wikigraph or target not in self.wikigraph:
return None
if source == target:
return 1, [source]
visited = {source}
queue = deque([(source, [source], 0)])
while queue:
current_node, path, l = queue.popleft()
for neighbor in self.wikigraph[current_node]:
if neighbor == target:
return 1 + l, (path + [neighbor])
if neighbor not in visited:
visited.add(neighbor)
queue.append((neighbor, path + [neighbor], 1 + l))
return None # No path found
def __getitem__(self, idx: int) -> dict:
"""Generate a single Wikirace Game task
Returns:
dict with keys:
- question: str, the task description with a source article, target article, and current chosen path
- answer: str, one possible article on the shortest path
- metadata: dict with generation parameters
"""
rng = Random(self.seed + idx)
# Find a task that suits our min_distance/max_distance
# Since some pages might be dead-ends, we might need to try multiple times
for _ in range(self.config.max_tries):
source = rng.choice(self.wikititles)
target = source
chosen_distance = rng.randint(self.config.min_distance, self.config.max_distance)
path = [source]
length = 0
while self.shortest_path(source, target)[0] != chosen_distance:
possibilities = self.wikigraph[target] - set(path)
if not possibilities:
break
# Since hash() is random, we need to sort the set into a list
# for prng stability
possibilities = sorted(list(possibilities))
target = rng.choice(possibilities)
length += 1
# Are we lost? Are we looping? Aborting
if length > 12:
break
path.append(target)
if self.shortest_path(source, target)[0] == chosen_distance:
break
# We got lost in a loop or a dead end, try again
if self.shortest_path(source, target)[0] != chosen_distance:
raise Exception(f"After {self.config.max_tries}, we failed to find a suitable wikipedia articles pair")
_, path = self.shortest_path(source, target)
# This is the length of the current path (let's call it state)
# 0 mean that we are still at the source of the path we're searching for
path_len = rng.randint(0, min(self.config.min_distance, len(path)) - 2)
given_path = path[:path_len]
given_path = " => ".join(given_path)
current = path[path_len]
# Stable links
links = sorted(list(self.wikigraph[current]))
links = list(enumerate(links))
question = QUESTION_FORMAT_TEMPLATE.format(
current=current,
target=target,
formatted_links=[f"{x[0]} - {x[1]}\n" for x in links],
formatted_path=given_path,
)
answer = [x[0] for x in links if x[1] == path[path_len + 1]][0]
return {
"question": question,
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"source": source,
"current": current,
"target": target,
"distance": chosen_distance,
"path": given_path,
"remaining_path": path[path_len:],
"links": links,
"difficulty": {
"distance": (self.config.min_distance, self.config.max_distance),
},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the problem"""
reward = 0.01 # Default reward
source = entry["metadata"]["source"]
target = entry["metadata"]["target"]
current = entry["metadata"]["current"]
links = entry["metadata"]["links"]
if answer is None or not answer.strip():
return reward
try:
answer = answer.strip()
answer = int(answer)
if answer < 0:
return 0.01
link = links[answer][1]
new_distance = self.shortest_path(link, target)[0]
old_distance = self.shortest_path(current, target)[0]
if new_distance < old_distance:
# Path is shortet than before, it is following (a) shortest path!
return 1.0
elif new_distance == old_distance:
# Path isn't shorter, but not longer either, that's still something
return 0.5
else:
# At least answer is valid...
return 0.1
except Exception:
return 0.01
class WikiraceCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(WikiraceCurriculum.__name__, WikiraceConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="distance",
levels=[3, 6, 9, 12, 15],
description="Number of source numbers",
lower_field_name="min_distance",
upper_field_name="max_distance",
ensure_interval=True,
),
ScalarAttributeDefinition(
name="max_tries",
description="Max number of tries to find test cases",
field_name="max_tries",
),
)
# Register the dataset
register_dataset(DATASET_NAME, WikiraceDataset, WikiraceConfig, WikiraceCurriculum)

View file

@ -0,0 +1 @@
datasets>=3.6.0

100
tests/test_wikirace.py Normal file
View file

@ -0,0 +1,100 @@
import pytest
from reasoning_gym.games.wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset
def test_wikirace_game_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = WikiraceConfig(min_distance=0)
config.validate()
with pytest.raises(AssertionError):
config = WikiraceConfig(min_distance=3, max_distance=2)
config.validate()
with pytest.raises(AssertionError):
config = WikiraceConfig(max_tries=-2)
config.validate()
def test_wikirace_game_deterministic():
"""Test that dataset generates same items with same seed"""
config1 = WikiraceConfig(seed=42, size=2)
dataset1 = WikiraceDataset(config1)
config2 = WikiraceConfig(seed=42, size=2)
dataset2 = WikiraceDataset(config2)
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
def test_wikirace_game_items():
"""Test basic properties of generated items"""
config = WikiraceConfig(
seed=42,
size=2,
)
dataset = WikiraceDataset(config)
for item in dataset:
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata contains required fields
assert "source" in item["metadata"]
assert "links" in item["metadata"]
assert "target" in item["metadata"]
assert "current" in item["metadata"]
assert "distance" in item["metadata"]
# Verify number of source numbers is within config range
assert config.min_distance <= item["metadata"]["distance"] <= config.max_distance
# A non-int answer fails
assert dataset.score_answer(answer="nope", entry=item) == 0.01
# A negative answer fails
assert dataset.score_answer(answer="-1", entry=item) == 0.01
# An out of bond answer fails
assert dataset.score_answer(answer=str(len(item["metadata"]["links"])), entry=item) == 0.01
# A parsable answer gives at least 0.1
assert dataset.score_answer(answer="0", entry=item) >= 0.1
# The expected answer gives 1.0
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
def test_wikirace_game_single():
"""Test a known item"""
config = WikiraceConfig(
seed=42,
size=1,
)
dataset = WikiraceDataset(config)
item = dataset[0]
# If those asserts fails, it probably just means you changed the generation algorithm, which is fine
# you'll have have to update this test
assert item["metadata"]["source"] == "Vadim Bakatin"
assert item["metadata"]["target"] == "Azerbaijan Technological University"
assert item["metadata"]["distance"] == 3
assert len(item["metadata"]["path"]) == 0
# If those asserts fails, it is most likely an actual error
# Only valid answer is 4 - Moscow
assert dataset.score_answer(answer="4", entry=item) == 1.0
# Selecting 8 - Russians makes you go further away from the target
assert dataset.score_answer(answer="2", entry=item) == 0.1
# Selecting 0 - Commmunist Party of the Soviet Union doesn't get you further away, but it doesn't get you closer either
assert dataset.score_answer(answer="2", entry=item) == 0.1
# Use this to check the results if you need to update this test
# (with pytest -s)
# for (i,_) in item['metadata']['links']:
# print(i, dataset.score_answer(answer=str(i), entry=item))