mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-03 17:53:26 +00:00
Merge fdb93a3d7d into 49b07130b3
This commit is contained in:
commit
72999eda51
4 changed files with 343 additions and 0 deletions
|
|
@ -23,6 +23,7 @@ from .sudoku import SudokuConfig, SudokuCurriculum, SudokuDataset
|
||||||
from .survo import SurvoConfig, SurvoCurriculum, SurvoDataset
|
from .survo import SurvoConfig, SurvoCurriculum, SurvoDataset
|
||||||
from .tower_of_hanoi import HanoiConfig, HanoiCurriculum, HanoiDataset
|
from .tower_of_hanoi import HanoiConfig, HanoiCurriculum, HanoiDataset
|
||||||
from .tsumego import TsumegoConfig, TsumegoCurriculum, TsumegoDataset
|
from .tsumego import TsumegoConfig, TsumegoCurriculum, TsumegoDataset
|
||||||
|
from .wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BoxnetConfig",
|
"BoxnetConfig",
|
||||||
|
|
@ -76,4 +77,7 @@ __all__ = [
|
||||||
"MahjongPuzzleConfig",
|
"MahjongPuzzleConfig",
|
||||||
"MahjongPuzzleDataset",
|
"MahjongPuzzleDataset",
|
||||||
"MahjongPuzzleCurriculum",
|
"MahjongPuzzleCurriculum",
|
||||||
|
"WikiraceConfig",
|
||||||
|
"WikiraceDataset",
|
||||||
|
"WikiraceCurriculum",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
238
reasoning_gym/games/wikirace.py
Normal file
238
reasoning_gym/games/wikirace.py
Normal file
|
|
@ -0,0 +1,238 @@
|
||||||
|
import re
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
from random import Random
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
try:
|
||||||
|
from datasets import load_dataset
|
||||||
|
except:
|
||||||
|
raise Exception("wikirace requires datasets library. Run `pip install datasets`")
|
||||||
|
|
||||||
|
QUESTION_FORMAT_TEMPLATE = """
|
||||||
|
You are playing WikiRace, trying to navigate from one Wikipedia article to another using only links.
|
||||||
|
|
||||||
|
Answer with just the link number.
|
||||||
|
|
||||||
|
Current article: {current}
|
||||||
|
Target article: {target}
|
||||||
|
Available links (numbered):
|
||||||
|
{formatted_links}
|
||||||
|
|
||||||
|
Your path so far: {formatted_path}
|
||||||
|
|
||||||
|
Think about which link is most likely to lead you toward the target article.
|
||||||
|
First, analyze each link briefly and how it connects to your goal, then select the most promising one.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
DATASET_NAME = "wikirace"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WikiraceConfig:
|
||||||
|
"""Configuration for WikiRace task generation"""
|
||||||
|
|
||||||
|
min_distance: int = 3
|
||||||
|
max_distance: int = 6
|
||||||
|
max_tries: int = 100
|
||||||
|
seed: Optional[int] = None
|
||||||
|
size: int = 500
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert self.min_distance > 1, "min_distance must be greater than 1"
|
||||||
|
assert self.max_distance >= self.min_distance, "max_distance must be >= min_distance"
|
||||||
|
|
||||||
|
assert self.max_tries >= 1, "max_tries must be greater than 1"
|
||||||
|
|
||||||
|
|
||||||
|
def load_wiki_graph():
|
||||||
|
dataset = load_dataset("HuggingFaceTB/simplewiki-pruned-350k")
|
||||||
|
|
||||||
|
graph = defaultdict(set)
|
||||||
|
titles = set()
|
||||||
|
|
||||||
|
# Build the graph
|
||||||
|
for example in dataset["train"]:
|
||||||
|
title = example["article"]
|
||||||
|
links = example["links"]
|
||||||
|
|
||||||
|
titles.add(title)
|
||||||
|
|
||||||
|
for link in links:
|
||||||
|
graph[title].add(link)
|
||||||
|
|
||||||
|
# Note: Since titles was a set, and hash are naturally unstable
|
||||||
|
# We want to sort it, so that prng.choice() is stable
|
||||||
|
return graph, sorted(list(titles))
|
||||||
|
|
||||||
|
|
||||||
|
class WikiraceDataset(ProceduralDataset):
|
||||||
|
"""Generates Wikirace Game tasks"""
|
||||||
|
|
||||||
|
def __init__(self, config: WikiraceConfig):
|
||||||
|
self.wikigraph, self.wikititles = load_wiki_graph()
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
# We'll be computing a lot of shortest_path of very similar paths
|
||||||
|
# So cache it
|
||||||
|
@lru_cache(maxsize=128 * 1024)
|
||||||
|
def shortest_path(self, source, target):
|
||||||
|
if source not in self.wikigraph or target not in self.wikigraph:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if source == target:
|
||||||
|
return 1, [source]
|
||||||
|
|
||||||
|
visited = {source}
|
||||||
|
queue = deque([(source, [source], 0)])
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
current_node, path, l = queue.popleft()
|
||||||
|
for neighbor in self.wikigraph[current_node]:
|
||||||
|
if neighbor == target:
|
||||||
|
return 1 + l, (path + [neighbor])
|
||||||
|
if neighbor not in visited:
|
||||||
|
visited.add(neighbor)
|
||||||
|
queue.append((neighbor, path + [neighbor], 1 + l))
|
||||||
|
|
||||||
|
return None # No path found
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
"""Generate a single Wikirace Game task
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys:
|
||||||
|
- question: str, the task description with a source article, target article, and current chosen path
|
||||||
|
- answer: str, one possible article on the shortest path
|
||||||
|
- metadata: dict with generation parameters
|
||||||
|
"""
|
||||||
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
|
# Find a task that suits our min_distance/max_distance
|
||||||
|
# Since some pages might be dead-ends, we might need to try multiple times
|
||||||
|
for _ in range(self.config.max_tries):
|
||||||
|
source = rng.choice(self.wikititles)
|
||||||
|
target = source
|
||||||
|
chosen_distance = rng.randint(self.config.min_distance, self.config.max_distance)
|
||||||
|
path = [source]
|
||||||
|
length = 0
|
||||||
|
while self.shortest_path(source, target)[0] != chosen_distance:
|
||||||
|
possibilities = self.wikigraph[target] - set(path)
|
||||||
|
if not possibilities:
|
||||||
|
break
|
||||||
|
# Since hash() is random, we need to sort the set into a list
|
||||||
|
# for prng stability
|
||||||
|
possibilities = sorted(list(possibilities))
|
||||||
|
target = rng.choice(possibilities)
|
||||||
|
length += 1
|
||||||
|
# Are we lost? Are we looping? Aborting
|
||||||
|
if length > 12:
|
||||||
|
break
|
||||||
|
path.append(target)
|
||||||
|
if self.shortest_path(source, target)[0] == chosen_distance:
|
||||||
|
break
|
||||||
|
# We got lost in a loop or a dead end, try again
|
||||||
|
|
||||||
|
if self.shortest_path(source, target)[0] != chosen_distance:
|
||||||
|
raise Exception(f"After {self.config.max_tries}, we failed to find a suitable wikipedia articles pair")
|
||||||
|
|
||||||
|
_, path = self.shortest_path(source, target)
|
||||||
|
# This is the length of the current path (let's call it state)
|
||||||
|
# 0 mean that we are still at the source of the path we're searching for
|
||||||
|
path_len = rng.randint(0, min(self.config.min_distance, len(path)) - 2)
|
||||||
|
given_path = path[:path_len]
|
||||||
|
given_path = " => ".join(given_path)
|
||||||
|
current = path[path_len]
|
||||||
|
# Stable links
|
||||||
|
links = sorted(list(self.wikigraph[current]))
|
||||||
|
links = list(enumerate(links))
|
||||||
|
question = QUESTION_FORMAT_TEMPLATE.format(
|
||||||
|
current=current,
|
||||||
|
target=target,
|
||||||
|
formatted_links=[f"{x[0]} - {x[1]}\n" for x in links],
|
||||||
|
formatted_path=given_path,
|
||||||
|
)
|
||||||
|
answer = [x[0] for x in links if x[1] == path[path_len + 1]][0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": question,
|
||||||
|
"answer": str(answer),
|
||||||
|
"metadata": {
|
||||||
|
"source_dataset": DATASET_NAME,
|
||||||
|
"source_index": idx,
|
||||||
|
"source": source,
|
||||||
|
"current": current,
|
||||||
|
"target": target,
|
||||||
|
"distance": chosen_distance,
|
||||||
|
"path": given_path,
|
||||||
|
"remaining_path": path[path_len:],
|
||||||
|
"links": links,
|
||||||
|
"difficulty": {
|
||||||
|
"distance": (self.config.min_distance, self.config.max_distance),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||||
|
"""Determine if the solution provided solves the problem"""
|
||||||
|
reward = 0.01 # Default reward
|
||||||
|
source = entry["metadata"]["source"]
|
||||||
|
target = entry["metadata"]["target"]
|
||||||
|
current = entry["metadata"]["current"]
|
||||||
|
links = entry["metadata"]["links"]
|
||||||
|
|
||||||
|
if answer is None or not answer.strip():
|
||||||
|
return reward
|
||||||
|
|
||||||
|
try:
|
||||||
|
answer = answer.strip()
|
||||||
|
answer = int(answer)
|
||||||
|
if answer < 0:
|
||||||
|
return 0.01
|
||||||
|
link = links[answer][1]
|
||||||
|
new_distance = self.shortest_path(link, target)[0]
|
||||||
|
old_distance = self.shortest_path(current, target)[0]
|
||||||
|
if new_distance < old_distance:
|
||||||
|
# Path is shortet than before, it is following (a) shortest path!
|
||||||
|
return 1.0
|
||||||
|
elif new_distance == old_distance:
|
||||||
|
# Path isn't shorter, but not longer either, that's still something
|
||||||
|
return 0.5
|
||||||
|
else:
|
||||||
|
# At least answer is valid...
|
||||||
|
return 0.1
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return 0.01
|
||||||
|
|
||||||
|
|
||||||
|
class WikiraceCurriculum(BaseCurriculum):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(WikiraceCurriculum.__name__, WikiraceConfig)
|
||||||
|
|
||||||
|
# Define attributes
|
||||||
|
self._define_attributes(
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="distance",
|
||||||
|
levels=[3, 6, 9, 12, 15],
|
||||||
|
description="Number of source numbers",
|
||||||
|
lower_field_name="min_distance",
|
||||||
|
upper_field_name="max_distance",
|
||||||
|
ensure_interval=True,
|
||||||
|
),
|
||||||
|
ScalarAttributeDefinition(
|
||||||
|
name="max_tries",
|
||||||
|
description="Max number of tries to find test cases",
|
||||||
|
field_name="max_tries",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the dataset
|
||||||
|
register_dataset(DATASET_NAME, WikiraceDataset, WikiraceConfig, WikiraceCurriculum)
|
||||||
1
requirements-optional.txt
Normal file
1
requirements-optional.txt
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
datasets>=3.6.0
|
||||||
100
tests/test_wikirace.py
Normal file
100
tests/test_wikirace.py
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.games.wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_wikirace_game_config_validation():
|
||||||
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = WikiraceConfig(min_distance=0)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = WikiraceConfig(min_distance=3, max_distance=2)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = WikiraceConfig(max_tries=-2)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wikirace_game_deterministic():
|
||||||
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
config1 = WikiraceConfig(seed=42, size=2)
|
||||||
|
dataset1 = WikiraceDataset(config1)
|
||||||
|
config2 = WikiraceConfig(seed=42, size=2)
|
||||||
|
dataset2 = WikiraceDataset(config2)
|
||||||
|
|
||||||
|
for i in range(len(dataset1)):
|
||||||
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_wikirace_game_items():
|
||||||
|
"""Test basic properties of generated items"""
|
||||||
|
config = WikiraceConfig(
|
||||||
|
seed=42,
|
||||||
|
size=2,
|
||||||
|
)
|
||||||
|
dataset = WikiraceDataset(config)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Check metadata contains required fields
|
||||||
|
assert "source" in item["metadata"]
|
||||||
|
assert "links" in item["metadata"]
|
||||||
|
assert "target" in item["metadata"]
|
||||||
|
assert "current" in item["metadata"]
|
||||||
|
assert "distance" in item["metadata"]
|
||||||
|
|
||||||
|
# Verify number of source numbers is within config range
|
||||||
|
assert config.min_distance <= item["metadata"]["distance"] <= config.max_distance
|
||||||
|
|
||||||
|
# A non-int answer fails
|
||||||
|
assert dataset.score_answer(answer="nope", entry=item) == 0.01
|
||||||
|
|
||||||
|
# A negative answer fails
|
||||||
|
assert dataset.score_answer(answer="-1", entry=item) == 0.01
|
||||||
|
|
||||||
|
# An out of bond answer fails
|
||||||
|
assert dataset.score_answer(answer=str(len(item["metadata"]["links"])), entry=item) == 0.01
|
||||||
|
|
||||||
|
# A parsable answer gives at least 0.1
|
||||||
|
assert dataset.score_answer(answer="0", entry=item) >= 0.1
|
||||||
|
|
||||||
|
# The expected answer gives 1.0
|
||||||
|
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_wikirace_game_single():
|
||||||
|
"""Test a known item"""
|
||||||
|
config = WikiraceConfig(
|
||||||
|
seed=42,
|
||||||
|
size=1,
|
||||||
|
)
|
||||||
|
dataset = WikiraceDataset(config)
|
||||||
|
item = dataset[0]
|
||||||
|
|
||||||
|
# If those asserts fails, it probably just means you changed the generation algorithm, which is fine
|
||||||
|
# you'll have have to update this test
|
||||||
|
assert item["metadata"]["source"] == "Vadim Bakatin"
|
||||||
|
assert item["metadata"]["target"] == "Azerbaijan Technological University"
|
||||||
|
assert item["metadata"]["distance"] == 3
|
||||||
|
assert len(item["metadata"]["path"]) == 0
|
||||||
|
|
||||||
|
# If those asserts fails, it is most likely an actual error
|
||||||
|
|
||||||
|
# Only valid answer is 4 - Moscow
|
||||||
|
assert dataset.score_answer(answer="4", entry=item) == 1.0
|
||||||
|
# Selecting 8 - Russians makes you go further away from the target
|
||||||
|
assert dataset.score_answer(answer="2", entry=item) == 0.1
|
||||||
|
# Selecting 0 - Commmunist Party of the Soviet Union doesn't get you further away, but it doesn't get you closer either
|
||||||
|
assert dataset.score_answer(answer="2", entry=item) == 0.1
|
||||||
|
|
||||||
|
# Use this to check the results if you need to update this test
|
||||||
|
# (with pytest -s)
|
||||||
|
# for (i,_) in item['metadata']['links']:
|
||||||
|
# print(i, dataset.score_answer(answer=str(i), entry=item))
|
||||||
Loading…
Add table
Add a link
Reference in a new issue