mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-27 17:23:19 +00:00
use native types List->list, Dict->dict, Set->set, Tuple->tuple
This commit is contained in:
parent
5d02064b5a
commit
3e7ff3b084
95 changed files with 754 additions and 760 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -129,14 +129,14 @@ Return the final state of the program.
|
|||
"metadata": {},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves the AB task.
|
||||
|
||||
The function awards 1.0 for a correct answer.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer.
|
||||
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||
entry (dict[str, Any]): The original dataset entry containing the correct answer.
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ class BaseConversionDataset(ProceduralDataset):
|
|||
else:
|
||||
return f"base-{base}"
|
||||
|
||||
def _generate_conversion(self, rng: Random) -> Tuple[int, int, int]:
|
||||
def _generate_conversion(self, rng: Random) -> tuple[int, int, int]:
|
||||
"""Generate random value and source/target bases"""
|
||||
value = rng.randint(self.config.min_value, self.config.max_value)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ https://leetcode.com/problems/01-matrix/description/
|
|||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -112,7 +112,7 @@ class BinaryMatrixDataset(ProceduralDataset):
|
|||
"""Get a string representation of the matrix"""
|
||||
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
|
|
|
|||
|
|
@ -211,14 +211,14 @@ class CryptarithmDataset(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves the Cryptarithm task.
|
||||
|
||||
The function awards 1.0 for a correct format and answers for all alphabet pairs.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer already parsed by `extract_answer`
|
||||
answer_str (Dict[str, any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..."
|
||||
answer_str (dict[str, Any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..."
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import cellpylib as cpl
|
||||
|
||||
|
|
@ -86,14 +86,14 @@ class GameOfLifeDataset(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves the GoL task.
|
||||
|
||||
The function awards 1.0 for a correct answer.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer.
|
||||
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||
entry (dict[str, Any]): The original dataset entry containing the correct answer.
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -209,14 +209,14 @@ Return your solution as a JSON map of vertices to colors. (For example: {{0: 1,
|
|||
"metadata": {"possible_answer": solution, "puzzle": puzzle},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves the GraphColor task.
|
||||
|
||||
The function awards 1.0 for a correct answer.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer.
|
||||
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||
entry (dict[str, Any]): The original dataset entry containing the correct answer.
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import json
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..data import get_data_file_path
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
|
@ -88,7 +88,7 @@ class GroupAnagramsDataset(ProceduralDataset):
|
|||
anagrams = list(res.values())
|
||||
return self._sort_nested_list(anagrams)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score a single Group Anagrams question"""
|
||||
reward = 0.0
|
||||
if answer is not None:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from reasoning_gym.data import read_data_file
|
||||
|
||||
|
|
@ -123,14 +123,14 @@ class LetterJumbleDataset(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves this task.
|
||||
|
||||
The function awards 1.0 for a correct answer.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer.
|
||||
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||
entry (dict[str, Any]): The original dataset entry containing the correct answer.
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ class NumberFilteringDataset(ProceduralDataset):
|
|||
"""Format a number with specified decimal places"""
|
||||
return f"{num:.{decimals}f}"
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> Tuple[List[float], List[str]]:
|
||||
def _generate_numbers(self, rng: Random) -> tuple[list[float], list[str]]:
|
||||
"""Generate list of numbers and their string representations"""
|
||||
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
numbers = []
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ Please follow the instruction below:
|
|||
# Reparse to ensure exact decimal representation
|
||||
return f"{float(formatted):.{decimals}f}"
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> Tuple[List[float], List[str]]:
|
||||
def _generate_numbers(self, rng: Random) -> tuple[list[float], list[str]]:
|
||||
"""Generate list of numbers and their string representations"""
|
||||
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
|
||||
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class PalindromeDataset(ProceduralDataset):
|
|||
"""Return the palindrome string from the letter set."""
|
||||
return "".join(letters)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided is a valid palindrome.
|
||||
The answer is expected to be a single string
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import json
|
|||
import string
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -93,7 +93,7 @@ class PalindromePartitioningDataset(ProceduralDataset):
|
|||
_partition(0)
|
||||
return self._sort_list(res)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score a single Palindrome Partitioning question"""
|
||||
if answer is not None:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
"""Perform average / max pooling on a matrix"""
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -95,7 +94,7 @@ class PoolMatrixDataset(ProceduralDataset):
|
|||
]
|
||||
)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Score the answer based on the metadata"""
|
||||
|
||||
reward = 0.0
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ https://leetcode.com/problems/ransom-note/description/
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -95,14 +95,14 @@ class RansomNoteDataset(ProceduralDataset):
|
|||
"metadata": {"ransom_note": ransom_note, "magazine": magazine, "solution": answer, "solvable": solvable},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves this task.
|
||||
|
||||
The function awards 1.0 for a correct answer.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer.
|
||||
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
||||
entry (dict[str, Any]): The original dataset entry containing the correct answer.
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ class SentenceReorderingDataset(ProceduralDataset):
|
|||
"metadata": {"word_count": word_count},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
reward = 0.0
|
||||
expected_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class SpellBackwardDataset(ProceduralDataset):
|
|||
"metadata": {"word": word, "word_len": len(word)},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
reward = 0.0
|
||||
expected_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ https://leetcode.com/problems/spiral-matrix/description/
|
|||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -116,7 +116,7 @@ class SpiralMatrixDataset(ProceduralDataset):
|
|||
"metadata": {"matrix": matrix, "solution": answer},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"].strip()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dat
|
|||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -80,7 +80,7 @@ class StringInsertionDataset(ProceduralDataset):
|
|||
i += 1
|
||||
return "".join(output)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||
oracle_answer = entry["answer"]
|
||||
if answer is not None:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..data import get_data_file_path
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
|
@ -82,7 +82,7 @@ class WordLadderDataset(ProceduralDataset):
|
|||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
@classmethod
|
||||
def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> Dict[int, Set[str]]:
|
||||
def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> dict[int, set[str]]:
|
||||
"""Load words from CSV file organized by length"""
|
||||
# Validate length range before processing
|
||||
assert 3 <= min_length <= max_length <= 5, "Word length must be between 3 and 5 inclusive"
|
||||
|
|
@ -117,7 +117,7 @@ class WordLadderDataset(ProceduralDataset):
|
|||
|
||||
return word_sets
|
||||
|
||||
def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
|
||||
def _get_neighbors(self, word: str, word_set: set[str]) -> set[str]:
|
||||
"""Get neighbors from either precomputed graph or by computing on demand"""
|
||||
# Try precomputed graph first
|
||||
if len(word) in self.word_graphs and word in self.word_graphs[len(word)]:
|
||||
|
|
@ -132,7 +132,7 @@ class WordLadderDataset(ProceduralDataset):
|
|||
neighbors.add(neighbor)
|
||||
return neighbors
|
||||
|
||||
def _build_word_graph(self, word_length: int) -> Dict[str, Set[str]]:
|
||||
def _build_word_graph(self, word_length: int) -> dict[str, set[str]]:
|
||||
"""Build graph of word connections for given length, using caching"""
|
||||
# Return cached graph if it exists
|
||||
if word_length in self.word_graphs:
|
||||
|
|
@ -156,7 +156,7 @@ class WordLadderDataset(ProceduralDataset):
|
|||
self.word_graphs[word_length] = graph
|
||||
return self.word_graphs[word_length]
|
||||
|
||||
def _find_path(self, start: str, end: str, word_set: Set[str]) -> Optional[List[str]]:
|
||||
def _find_path(self, start: str, end: str, word_set: set[str]) -> Optional[list[str]]:
|
||||
"""Simplified path finding using BFS for shortest paths"""
|
||||
# Early exit if words are direct neighbors
|
||||
if end in self._get_neighbors(start, word_set):
|
||||
|
|
@ -181,7 +181,7 @@ class WordLadderDataset(ProceduralDataset):
|
|||
|
||||
return None
|
||||
|
||||
def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]:
|
||||
def _generate_word_pair(self, rng: Random, length: int) -> tuple[str, str, list[str]]:
|
||||
"""Simplified word pair generation"""
|
||||
word_set = self.word_sets[length]
|
||||
words_list = sorted(word_set)
|
||||
|
|
@ -220,7 +220,7 @@ class WordLadderDataset(ProceduralDataset):
|
|||
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
if answer is None:
|
||||
return 0
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import re
|
|||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from random import Random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..data import read_data_file
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
|
@ -84,7 +84,7 @@ class WordSortingDataset(ProceduralDataset):
|
|||
return "".join(c.upper() if rng.choice([True, False]) else c.lower() for c in word)
|
||||
return word # ORIGINAL case
|
||||
|
||||
def _generate_words(self, rng: Random) -> Tuple[List[str], List[str]]:
|
||||
def _generate_words(self, rng: Random) -> tuple[list[str], list[str]]:
|
||||
"""Generate list of words and their transformed versions"""
|
||||
count = rng.randint(self.config.min_words, self.config.max_words)
|
||||
|
||||
|
|
@ -122,7 +122,7 @@ class WordSortingDataset(ProceduralDataset):
|
|||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
oracle_answer = entry["metadata"]["sorted_words"]
|
||||
if answer is not None and len(answer) > 0:
|
||||
parsed_answer = [word.strip() for word in re.split(r",\s*", answer)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue