diff --git a/reasoning_gym/games/wikirace.py b/reasoning_gym/games/wikirace.py index 357a6b18..45c787ec 100644 --- a/reasoning_gym/games/wikirace.py +++ b/reasoning_gym/games/wikirace.py @@ -1,18 +1,18 @@ import re -from collections import deque, defaultdict +from collections import defaultdict, deque from dataclasses import dataclass from functools import lru_cache from random import Random from typing import Any, Optional -from datasets import load_dataset - -import sympy -from sympy import Symbol, symbols -from sympy.parsing.sympy_parser import parse_expr from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset +try: + from datasets import load_dataset +except: + raise Exception("wikirace requires datasets library. Run `pip install datasets`") + QUESTION_FORMAT_TEMPLATE = """ You are playing WikiRace, trying to navigate from one Wikipedia article to another using only links. @@ -39,6 +39,7 @@ class WikiraceConfig: min_distance: int = 3 max_distance: int = 6 + max_tries: int = 100 seed: Optional[int] = None size: int = 500 @@ -47,6 +48,9 @@ class WikiraceConfig: assert self.min_distance > 1, "min_distance must be greater than 1" assert self.max_distance >= self.min_distance, "max_distance must be >= min_distance" + assert self.max_tries >= 1, "max_tries must be greater than 1" + + def load_wiki_graph(): dataset = load_dataset("HuggingFaceTB/simplewiki-pruned-350k") @@ -54,9 +58,9 @@ def load_wiki_graph(): titles = set() # Build the graph - for example in dataset['train']: - title = example['article'] - links = example['links'] + for example in dataset["train"]: + title = example["article"] + links = example["links"] titles.add(title) @@ -67,6 +71,7 @@ def load_wiki_graph(): # We want to sort it, so that prng.choice() is stable return graph, sorted(list(titles)) + class WikiraceDataset(ProceduralDataset): """Generates Wikirace Game tasks""" @@ -76,7 +81,7 @@ class WikiraceDataset(ProceduralDataset): # We'll be computing a lot of shortest_path of very similar paths # So cache it - @lru_cache(maxsize=128*1024) + @lru_cache(maxsize=128 * 1024) def shortest_path(self, source, target): if source not in self.wikigraph or target not in self.wikigraph: return None @@ -94,7 +99,7 @@ class WikiraceDataset(ProceduralDataset): return 1 + l, (path + [neighbor]) if neighbor not in visited: visited.add(neighbor) - queue.append((neighbor, path + [neighbor], 1+l)) + queue.append((neighbor, path + [neighbor], 1 + l)) return None # No path found @@ -103,15 +108,15 @@ class WikiraceDataset(ProceduralDataset): Returns: dict with keys: - - question: str, the task description with numbers and target - - answer: str, one possible solution expression + - question: str, the task description with a source article, target article, and current chosen path + - answer: str, one possible article on the shortest path - metadata: dict with generation parameters """ rng = Random(self.seed + idx) # Find a task that suits our min_distance/max_distance # Since some pages might be dead-ends, we might need to try multiple times - while True: + for _ in range(self.config.max_tries): source = rng.choice(self.wikititles) target = source chosen_distance = rng.randint(self.config.min_distance, self.config.max_distance) @@ -132,9 +137,11 @@ class WikiraceDataset(ProceduralDataset): path.append(target) if self.shortest_path(source, target)[0] == chosen_distance: break - else: - # We got lost in a loop or a dead end, try again - pass + # We got lost in a loop or a dead end, try again + + if self.shortest_path(source, target)[0] != chosen_distance: + raise Exception(f"After {self.config.max_tries}, we failed to find a suitable wikipedia articles pair") + _, path = self.shortest_path(source, target) # This is the length of the current path (let's call it state) # 0 mean that we are still at the source of the path we're searching for @@ -146,22 +153,23 @@ class WikiraceDataset(ProceduralDataset): links = sorted(list(self.wikigraph[current])) links = list(enumerate(links)) question = QUESTION_FORMAT_TEMPLATE.format( - current = current, - target = target, - formatted_links = [f"{x[0]} - {x[1]}\n" for x in links], - formatted_path = given_path, + current=current, + target=target, + formatted_links=[f"{x[0]} - {x[1]}\n" for x in links], + formatted_path=given_path, ) - answer = [x[0] for x in links if x[1] == path[path_len+1]][0] + answer = [x[0] for x in links if x[1] == path[path_len + 1]][0] - v= { + return { "question": question, - "answer": answer, + "answer": str(answer), "metadata": { "source_dataset": DATASET_NAME, "source_index": idx, "source": source, "current": current, "target": target, + "distance": chosen_distance, "path": given_path, "remaining_path": path[path_len:], "links": links, @@ -170,15 +178,14 @@ class WikiraceDataset(ProceduralDataset): }, }, } - return v def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Determine if the solution provided solves the problem""" reward = 0.01 # Default reward - source = entry['metadata']['source'] - target = entry['metadata']['target'] - current = entry['metadata']['current'] - links = entry['metadata']['links'] + source = entry["metadata"]["source"] + target = entry["metadata"]["target"] + current = entry["metadata"]["current"] + links = entry["metadata"]["links"] if answer is None or not answer.strip(): return reward @@ -186,9 +193,11 @@ class WikiraceDataset(ProceduralDataset): try: answer = answer.strip() answer = int(answer) + if answer < 0: + return 0.01 link = links[answer][1] - new_distance = self.shortest_path(link, target)[1] - old_distance = self.shortest_path(current, target)[1] + new_distance = self.shortest_path(link, target)[0] + old_distance = self.shortest_path(current, target)[0] if new_distance < old_distance: # Path is shortet than before, it is following (a) shortest path! return 1.0 @@ -217,7 +226,13 @@ class WikiraceCurriculum(BaseCurriculum): upper_field_name="max_distance", ensure_interval=True, ), + ScalarAttributeDefinition( + name="max_tries", + description="Max number of tries to find test cases", + field_name="max_tries", + ), ) + # Register the dataset register_dataset(DATASET_NAME, WikiraceDataset, WikiraceConfig, WikiraceCurriculum) diff --git a/requirements-optional.txt b/requirements-optional.txt new file mode 100644 index 00000000..18354c69 --- /dev/null +++ b/requirements-optional.txt @@ -0,0 +1 @@ +datasets>=3.6.0 diff --git a/tests/test_wikirace.py b/tests/test_wikirace.py new file mode 100644 index 00000000..8ca58979 --- /dev/null +++ b/tests/test_wikirace.py @@ -0,0 +1,100 @@ +import pytest + +from reasoning_gym.games.wikirace import WikiraceConfig, WikiraceCurriculum, WikiraceDataset + + +def test_wikirace_game_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = WikiraceConfig(min_distance=0) + config.validate() + + with pytest.raises(AssertionError): + config = WikiraceConfig(min_distance=3, max_distance=2) + config.validate() + + with pytest.raises(AssertionError): + config = WikiraceConfig(max_tries=-2) + config.validate() + + +def test_wikirace_game_deterministic(): + """Test that dataset generates same items with same seed""" + config1 = WikiraceConfig(seed=42, size=2) + dataset1 = WikiraceDataset(config1) + config2 = WikiraceConfig(seed=42, size=2) + dataset2 = WikiraceDataset(config2) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_wikirace_game_items(): + """Test basic properties of generated items""" + config = WikiraceConfig( + seed=42, + size=2, + ) + dataset = WikiraceDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata contains required fields + assert "source" in item["metadata"] + assert "links" in item["metadata"] + assert "target" in item["metadata"] + assert "current" in item["metadata"] + assert "distance" in item["metadata"] + + # Verify number of source numbers is within config range + assert config.min_distance <= item["metadata"]["distance"] <= config.max_distance + + # A non-int answer fails + assert dataset.score_answer(answer="nope", entry=item) == 0.01 + + # A negative answer fails + assert dataset.score_answer(answer="-1", entry=item) == 0.01 + + # An out of bond answer fails + assert dataset.score_answer(answer=str(len(item["metadata"]["links"])), entry=item) == 0.01 + + # A parsable answer gives at least 0.1 + assert dataset.score_answer(answer="0", entry=item) >= 0.1 + + # The expected answer gives 1.0 + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + + +def test_wikirace_game_single(): + """Test a known item""" + config = WikiraceConfig( + seed=42, + size=1, + ) + dataset = WikiraceDataset(config) + item = dataset[0] + + # If those asserts fails, it probably just means you changed the generation algorithm, which is fine + # you'll have have to update this test + assert item["metadata"]["source"] == "Vadim Bakatin" + assert item["metadata"]["target"] == "Azerbaijan Technological University" + assert item["metadata"]["distance"] == 3 + assert len(item["metadata"]["path"]) == 0 + + # If those asserts fails, it is most likely an actual error + + # Only valid answer is 4 - Moscow + assert dataset.score_answer(answer="4", entry=item) == 1.0 + # Selecting 8 - Russians makes you go further away from the target + assert dataset.score_answer(answer="2", entry=item) == 0.1 + # Selecting 0 - Commmunist Party of the Soviet Union doesn't get you further away, but it doesn't get you closer either + assert dataset.score_answer(answer="2", entry=item) == 0.1 + + # Use this to check the results if you need to update this test + # (with pytest -s) + # for (i,_) in item['metadata']['links']: + # print(i, dataset.score_answer(answer=str(i), entry=item))