From 09e30a23081ad3daf955836bfc0fe12f984e0748 Mon Sep 17 00:00:00 2001 From: Pierre-Hugues Husson Date: Tue, 10 Jun 2025 23:58:11 +0200 Subject: [PATCH] Add a wikirace task --- reasoning_gym/games/__init__.py | 4 + reasoning_gym/games/wikirace.py | 223 ++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+) create mode 100644 reasoning_gym/games/wikirace.py diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index 91e5ca73..32bdeedd 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -23,6 +23,7 @@ from .sudoku import SudokuConfig, SudokuCurriculum, SudokuDataset from .survo import SurvoConfig, SurvoCurriculum, SurvoDataset from .tower_of_hanoi import HanoiConfig, HanoiCurriculum, HanoiDataset from .tsumego import TsumegoConfig, TsumegoCurriculum, TsumegoDataset +from .wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset __all__ = [ "BoxnetConfig", @@ -76,4 +77,7 @@ __all__ = [ "MahjongPuzzleConfig", "MahjongPuzzleDataset", "MahjongPuzzleCurriculum", + "WikiraceConfig", + "WikiraceDataset", + "WikiraceCurriculum", ] diff --git a/reasoning_gym/games/wikirace.py b/reasoning_gym/games/wikirace.py new file mode 100644 index 00000000..357a6b18 --- /dev/null +++ b/reasoning_gym/games/wikirace.py @@ -0,0 +1,223 @@ +import re +from collections import deque, defaultdict +from dataclasses import dataclass +from functools import lru_cache +from random import Random +from typing import Any, Optional +from datasets import load_dataset + +import sympy +from sympy import Symbol, symbols +from sympy.parsing.sympy_parser import parse_expr + +from ..coaching import BaseCurriculum, RangeAttributeDefinition +from ..factory import ProceduralDataset, register_dataset + +QUESTION_FORMAT_TEMPLATE = """ +You are playing WikiRace, trying to navigate from one Wikipedia article to another using only links. + +Answer with just the link number. + +Current article: {current} +Target article: {target} +Available links (numbered): +{formatted_links} + +Your path so far: {formatted_path} + +Think about which link is most likely to lead you toward the target article. +First, analyze each link briefly and how it connects to your goal, then select the most promising one. +""" + + +DATASET_NAME = "wikirace" + + +@dataclass +class WikiraceConfig: + """Configuration for WikiRace task generation""" + + min_distance: int = 3 + max_distance: int = 6 + seed: Optional[int] = None + size: int = 500 + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.min_distance > 1, "min_distance must be greater than 1" + assert self.max_distance >= self.min_distance, "max_distance must be >= min_distance" + +def load_wiki_graph(): + dataset = load_dataset("HuggingFaceTB/simplewiki-pruned-350k") + + graph = defaultdict(set) + titles = set() + + # Build the graph + for example in dataset['train']: + title = example['article'] + links = example['links'] + + titles.add(title) + + for link in links: + graph[title].add(link) + + # Note: Since titles was a set, and hash are naturally unstable + # We want to sort it, so that prng.choice() is stable + return graph, sorted(list(titles)) + +class WikiraceDataset(ProceduralDataset): + """Generates Wikirace Game tasks""" + + def __init__(self, config: WikiraceConfig): + self.wikigraph, self.wikititles = load_wiki_graph() + super().__init__(config=config, seed=config.seed, size=config.size) + + # We'll be computing a lot of shortest_path of very similar paths + # So cache it + @lru_cache(maxsize=128*1024) + def shortest_path(self, source, target): + if source not in self.wikigraph or target not in self.wikigraph: + return None + + if source == target: + return 1, [source] + + visited = {source} + queue = deque([(source, [source], 0)]) + + while queue: + current_node, path, l = queue.popleft() + for neighbor in self.wikigraph[current_node]: + if neighbor == target: + return 1 + l, (path + [neighbor]) + if neighbor not in visited: + visited.add(neighbor) + queue.append((neighbor, path + [neighbor], 1+l)) + + return None # No path found + + def __getitem__(self, idx: int) -> dict: + """Generate a single Wikirace Game task + + Returns: + dict with keys: + - question: str, the task description with numbers and target + - answer: str, one possible solution expression + - metadata: dict with generation parameters + """ + rng = Random(self.seed + idx) + + # Find a task that suits our min_distance/max_distance + # Since some pages might be dead-ends, we might need to try multiple times + while True: + source = rng.choice(self.wikititles) + target = source + chosen_distance = rng.randint(self.config.min_distance, self.config.max_distance) + path = [source] + length = 0 + while self.shortest_path(source, target)[0] != chosen_distance: + possibilities = self.wikigraph[target] - set(path) + if not possibilities: + break + # Since hash() is random, we need to sort the set into a list + # for prng stability + possibilities = sorted(list(possibilities)) + target = rng.choice(possibilities) + length += 1 + # Are we lost? Are we looping? Aborting + if length > 12: + break + path.append(target) + if self.shortest_path(source, target)[0] == chosen_distance: + break + else: + # We got lost in a loop or a dead end, try again + pass + _, path = self.shortest_path(source, target) + # This is the length of the current path (let's call it state) + # 0 mean that we are still at the source of the path we're searching for + path_len = rng.randint(0, min(self.config.min_distance, len(path)) - 2) + given_path = path[:path_len] + given_path = " => ".join(given_path) + current = path[path_len] + # Stable links + links = sorted(list(self.wikigraph[current])) + links = list(enumerate(links)) + question = QUESTION_FORMAT_TEMPLATE.format( + current = current, + target = target, + formatted_links = [f"{x[0]} - {x[1]}\n" for x in links], + formatted_path = given_path, + ) + answer = [x[0] for x in links if x[1] == path[path_len+1]][0] + + v= { + "question": question, + "answer": answer, + "metadata": { + "source_dataset": DATASET_NAME, + "source_index": idx, + "source": source, + "current": current, + "target": target, + "path": given_path, + "remaining_path": path[path_len:], + "links": links, + "difficulty": { + "distance": (self.config.min_distance, self.config.max_distance), + }, + }, + } + return v + + def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: + """Determine if the solution provided solves the problem""" + reward = 0.01 # Default reward + source = entry['metadata']['source'] + target = entry['metadata']['target'] + current = entry['metadata']['current'] + links = entry['metadata']['links'] + + if answer is None or not answer.strip(): + return reward + + try: + answer = answer.strip() + answer = int(answer) + link = links[answer][1] + new_distance = self.shortest_path(link, target)[1] + old_distance = self.shortest_path(current, target)[1] + if new_distance < old_distance: + # Path is shortet than before, it is following (a) shortest path! + return 1.0 + elif new_distance == old_distance: + # Path isn't shorter, but not longer either, that's still something + return 0.5 + else: + # At least answer is valid... + return 0.1 + + except Exception: + return 0.01 + + +class WikiraceCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(WikiraceCurriculum.__name__, WikiraceConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="distance", + levels=[3, 6, 9, 12, 15], + description="Number of source numbers", + lower_field_name="min_distance", + upper_field_name="max_distance", + ensure_interval=True, + ), + ) + +# Register the dataset +register_dataset(DATASET_NAME, WikiraceDataset, WikiraceConfig, WikiraceCurriculum)