mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Merge fdb93a3d7d into 49b07130b3
This commit is contained in:
commit
72999eda51
4 changed files with 343 additions and 0 deletions
|
|
@ -23,6 +23,7 @@ from .sudoku import SudokuConfig, SudokuCurriculum, SudokuDataset
|
|||
from .survo import SurvoConfig, SurvoCurriculum, SurvoDataset
|
||||
from .tower_of_hanoi import HanoiConfig, HanoiCurriculum, HanoiDataset
|
||||
from .tsumego import TsumegoConfig, TsumegoCurriculum, TsumegoDataset
|
||||
from .wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset
|
||||
|
||||
__all__ = [
|
||||
"BoxnetConfig",
|
||||
|
|
@ -76,4 +77,7 @@ __all__ = [
|
|||
"MahjongPuzzleConfig",
|
||||
"MahjongPuzzleDataset",
|
||||
"MahjongPuzzleCurriculum",
|
||||
"WikiraceConfig",
|
||||
"WikiraceDataset",
|
||||
"WikiraceCurriculum",
|
||||
]
|
||||
|
|
|
|||
238
reasoning_gym/games/wikirace.py
Normal file
238
reasoning_gym/games/wikirace.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
import re
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
except:
|
||||
raise Exception("wikirace requires datasets library. Run `pip install datasets`")
|
||||
|
||||
QUESTION_FORMAT_TEMPLATE = """
|
||||
You are playing WikiRace, trying to navigate from one Wikipedia article to another using only links.
|
||||
|
||||
Answer with just the link number.
|
||||
|
||||
Current article: {current}
|
||||
Target article: {target}
|
||||
Available links (numbered):
|
||||
{formatted_links}
|
||||
|
||||
Your path so far: {formatted_path}
|
||||
|
||||
Think about which link is most likely to lead you toward the target article.
|
||||
First, analyze each link briefly and how it connects to your goal, then select the most promising one.
|
||||
"""
|
||||
|
||||
|
||||
DATASET_NAME = "wikirace"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WikiraceConfig:
|
||||
"""Configuration for WikiRace task generation"""
|
||||
|
||||
min_distance: int = 3
|
||||
max_distance: int = 6
|
||||
max_tries: int = 100
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_distance > 1, "min_distance must be greater than 1"
|
||||
assert self.max_distance >= self.min_distance, "max_distance must be >= min_distance"
|
||||
|
||||
assert self.max_tries >= 1, "max_tries must be greater than 1"
|
||||
|
||||
|
||||
def load_wiki_graph():
|
||||
dataset = load_dataset("HuggingFaceTB/simplewiki-pruned-350k")
|
||||
|
||||
graph = defaultdict(set)
|
||||
titles = set()
|
||||
|
||||
# Build the graph
|
||||
for example in dataset["train"]:
|
||||
title = example["article"]
|
||||
links = example["links"]
|
||||
|
||||
titles.add(title)
|
||||
|
||||
for link in links:
|
||||
graph[title].add(link)
|
||||
|
||||
# Note: Since titles was a set, and hash are naturally unstable
|
||||
# We want to sort it, so that prng.choice() is stable
|
||||
return graph, sorted(list(titles))
|
||||
|
||||
|
||||
class WikiraceDataset(ProceduralDataset):
|
||||
"""Generates Wikirace Game tasks"""
|
||||
|
||||
def __init__(self, config: WikiraceConfig):
|
||||
self.wikigraph, self.wikititles = load_wiki_graph()
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
# We'll be computing a lot of shortest_path of very similar paths
|
||||
# So cache it
|
||||
@lru_cache(maxsize=128 * 1024)
|
||||
def shortest_path(self, source, target):
|
||||
if source not in self.wikigraph or target not in self.wikigraph:
|
||||
return None
|
||||
|
||||
if source == target:
|
||||
return 1, [source]
|
||||
|
||||
visited = {source}
|
||||
queue = deque([(source, [source], 0)])
|
||||
|
||||
while queue:
|
||||
current_node, path, l = queue.popleft()
|
||||
for neighbor in self.wikigraph[current_node]:
|
||||
if neighbor == target:
|
||||
return 1 + l, (path + [neighbor])
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append((neighbor, path + [neighbor], 1 + l))
|
||||
|
||||
return None # No path found
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single Wikirace Game task
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
- question: str, the task description with a source article, target article, and current chosen path
|
||||
- answer: str, one possible article on the shortest path
|
||||
- metadata: dict with generation parameters
|
||||
"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
# Find a task that suits our min_distance/max_distance
|
||||
# Since some pages might be dead-ends, we might need to try multiple times
|
||||
for _ in range(self.config.max_tries):
|
||||
source = rng.choice(self.wikititles)
|
||||
target = source
|
||||
chosen_distance = rng.randint(self.config.min_distance, self.config.max_distance)
|
||||
path = [source]
|
||||
length = 0
|
||||
while self.shortest_path(source, target)[0] != chosen_distance:
|
||||
possibilities = self.wikigraph[target] - set(path)
|
||||
if not possibilities:
|
||||
break
|
||||
# Since hash() is random, we need to sort the set into a list
|
||||
# for prng stability
|
||||
possibilities = sorted(list(possibilities))
|
||||
target = rng.choice(possibilities)
|
||||
length += 1
|
||||
# Are we lost? Are we looping? Aborting
|
||||
if length > 12:
|
||||
break
|
||||
path.append(target)
|
||||
if self.shortest_path(source, target)[0] == chosen_distance:
|
||||
break
|
||||
# We got lost in a loop or a dead end, try again
|
||||
|
||||
if self.shortest_path(source, target)[0] != chosen_distance:
|
||||
raise Exception(f"After {self.config.max_tries}, we failed to find a suitable wikipedia articles pair")
|
||||
|
||||
_, path = self.shortest_path(source, target)
|
||||
# This is the length of the current path (let's call it state)
|
||||
# 0 mean that we are still at the source of the path we're searching for
|
||||
path_len = rng.randint(0, min(self.config.min_distance, len(path)) - 2)
|
||||
given_path = path[:path_len]
|
||||
given_path = " => ".join(given_path)
|
||||
current = path[path_len]
|
||||
# Stable links
|
||||
links = sorted(list(self.wikigraph[current]))
|
||||
links = list(enumerate(links))
|
||||
question = QUESTION_FORMAT_TEMPLATE.format(
|
||||
current=current,
|
||||
target=target,
|
||||
formatted_links=[f"{x[0]} - {x[1]}\n" for x in links],
|
||||
formatted_path=given_path,
|
||||
)
|
||||
answer = [x[0] for x in links if x[1] == path[path_len + 1]][0]
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"source": source,
|
||||
"current": current,
|
||||
"target": target,
|
||||
"distance": chosen_distance,
|
||||
"path": given_path,
|
||||
"remaining_path": path[path_len:],
|
||||
"links": links,
|
||||
"difficulty": {
|
||||
"distance": (self.config.min_distance, self.config.max_distance),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves the problem"""
|
||||
reward = 0.01 # Default reward
|
||||
source = entry["metadata"]["source"]
|
||||
target = entry["metadata"]["target"]
|
||||
current = entry["metadata"]["current"]
|
||||
links = entry["metadata"]["links"]
|
||||
|
||||
if answer is None or not answer.strip():
|
||||
return reward
|
||||
|
||||
try:
|
||||
answer = answer.strip()
|
||||
answer = int(answer)
|
||||
if answer < 0:
|
||||
return 0.01
|
||||
link = links[answer][1]
|
||||
new_distance = self.shortest_path(link, target)[0]
|
||||
old_distance = self.shortest_path(current, target)[0]
|
||||
if new_distance < old_distance:
|
||||
# Path is shortet than before, it is following (a) shortest path!
|
||||
return 1.0
|
||||
elif new_distance == old_distance:
|
||||
# Path isn't shorter, but not longer either, that's still something
|
||||
return 0.5
|
||||
else:
|
||||
# At least answer is valid...
|
||||
return 0.1
|
||||
|
||||
except Exception:
|
||||
return 0.01
|
||||
|
||||
|
||||
class WikiraceCurriculum(BaseCurriculum):
|
||||
def __init__(self):
|
||||
super().__init__(WikiraceCurriculum.__name__, WikiraceConfig)
|
||||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
RangeAttributeDefinition(
|
||||
name="distance",
|
||||
levels=[3, 6, 9, 12, 15],
|
||||
description="Number of source numbers",
|
||||
lower_field_name="min_distance",
|
||||
upper_field_name="max_distance",
|
||||
ensure_interval=True,
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="max_tries",
|
||||
description="Max number of tries to find test cases",
|
||||
field_name="max_tries",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset(DATASET_NAME, WikiraceDataset, WikiraceConfig, WikiraceCurriculum)
|
||||
1
requirements-optional.txt
Normal file
1
requirements-optional.txt
Normal file
|
|
@ -0,0 +1 @@
|
|||
datasets>=3.6.0
|
||||
100
tests/test_wikirace.py
Normal file
100
tests/test_wikirace.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.games.wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset
|
||||
|
||||
|
||||
def test_wikirace_game_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = WikiraceConfig(min_distance=0)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = WikiraceConfig(min_distance=3, max_distance=2)
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = WikiraceConfig(max_tries=-2)
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_wikirace_game_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config1 = WikiraceConfig(seed=42, size=2)
|
||||
dataset1 = WikiraceDataset(config1)
|
||||
config2 = WikiraceConfig(seed=42, size=2)
|
||||
dataset2 = WikiraceDataset(config2)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_wikirace_game_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = WikiraceConfig(
|
||||
seed=42,
|
||||
size=2,
|
||||
)
|
||||
dataset = WikiraceDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata contains required fields
|
||||
assert "source" in item["metadata"]
|
||||
assert "links" in item["metadata"]
|
||||
assert "target" in item["metadata"]
|
||||
assert "current" in item["metadata"]
|
||||
assert "distance" in item["metadata"]
|
||||
|
||||
# Verify number of source numbers is within config range
|
||||
assert config.min_distance <= item["metadata"]["distance"] <= config.max_distance
|
||||
|
||||
# A non-int answer fails
|
||||
assert dataset.score_answer(answer="nope", entry=item) == 0.01
|
||||
|
||||
# A negative answer fails
|
||||
assert dataset.score_answer(answer="-1", entry=item) == 0.01
|
||||
|
||||
# An out of bond answer fails
|
||||
assert dataset.score_answer(answer=str(len(item["metadata"]["links"])), entry=item) == 0.01
|
||||
|
||||
# A parsable answer gives at least 0.1
|
||||
assert dataset.score_answer(answer="0", entry=item) >= 0.1
|
||||
|
||||
# The expected answer gives 1.0
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
|
||||
|
||||
def test_wikirace_game_single():
|
||||
"""Test a known item"""
|
||||
config = WikiraceConfig(
|
||||
seed=42,
|
||||
size=1,
|
||||
)
|
||||
dataset = WikiraceDataset(config)
|
||||
item = dataset[0]
|
||||
|
||||
# If those asserts fails, it probably just means you changed the generation algorithm, which is fine
|
||||
# you'll have have to update this test
|
||||
assert item["metadata"]["source"] == "Vadim Bakatin"
|
||||
assert item["metadata"]["target"] == "Azerbaijan Technological University"
|
||||
assert item["metadata"]["distance"] == 3
|
||||
assert len(item["metadata"]["path"]) == 0
|
||||
|
||||
# If those asserts fails, it is most likely an actual error
|
||||
|
||||
# Only valid answer is 4 - Moscow
|
||||
assert dataset.score_answer(answer="4", entry=item) == 1.0
|
||||
# Selecting 8 - Russians makes you go further away from the target
|
||||
assert dataset.score_answer(answer="2", entry=item) == 0.1
|
||||
# Selecting 0 - Commmunist Party of the Soviet Union doesn't get you further away, but it doesn't get you closer either
|
||||
assert dataset.score_answer(answer="2", entry=item) == 0.1
|
||||
|
||||
# Use this to check the results if you need to update this test
|
||||
# (with pytest -s)
|
||||
# for (i,_) in item['metadata']['links']:
|
||||
# print(i, dataset.score_answer(answer=str(i), entry=item))
|
||||
Loading…
Add table
Add a link
Reference in a new issue