use native types List->list, Dict->dict, Set->set, Tuple->tuple

This commit is contained in:
Andreas Koepf 2025-02-21 15:13:19 +01:00
parent 5d02064b5a
commit 3e7ff3b084
95 changed files with 754 additions and 760 deletions

View file

@ -1,6 +1,6 @@
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -129,14 +129,14 @@ Return the final state of the program.
"metadata": {},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the AB task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
entry (dict[str, Any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.

View file

@ -2,7 +2,7 @@
from dataclasses import dataclass
from random import Random
from typing import Optional, Tuple
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
@ -61,7 +61,7 @@ class BaseConversionDataset(ProceduralDataset):
else:
return f"base-{base}"
def _generate_conversion(self, rng: Random) -> Tuple[int, int, int]:
def _generate_conversion(self, rng: Random) -> tuple[int, int, int]:
"""Generate random value and source/target bases"""
value = rng.randint(self.config.min_value, self.config.max_value)

View file

@ -7,7 +7,7 @@ https://leetcode.com/problems/01-matrix/description/
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -112,7 +112,7 @@ class BinaryMatrixDataset(ProceduralDataset):
"""Get a string representation of the matrix"""
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"]
if answer is not None:

View file

@ -211,14 +211,14 @@ class CryptarithmDataset(ProceduralDataset):
},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the Cryptarithm task.
The function awards 1.0 for a correct format and answers for all alphabet pairs.
Args:
answer (Optional[str]): The user's answer already parsed by `extract_answer`
answer_str (Dict[str, any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..."
answer_str (dict[str, Any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..."
Returns:
float: The computed score between 0.0 and 1.0.

View file

@ -1,7 +1,7 @@
import json
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
import cellpylib as cpl
@ -86,14 +86,14 @@ class GameOfLifeDataset(ProceduralDataset):
},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the GoL task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
entry (dict[str, Any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.

View file

@ -1,7 +1,7 @@
import json
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -209,14 +209,14 @@ Return your solution as a JSON map of vertices to colors. (For example: {{0: 1,
"metadata": {"possible_answer": solution, "puzzle": puzzle},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the GraphColor task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
entry (dict[str, Any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.

View file

@ -10,7 +10,7 @@ import json
from collections import defaultdict
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset
@ -88,7 +88,7 @@ class GroupAnagramsDataset(ProceduralDataset):
anagrams = list(res.values())
return self._sort_nested_list(anagrams)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score a single Group Anagrams question"""
reward = 0.0
if answer is not None:

View file

@ -3,7 +3,7 @@
import re
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from reasoning_gym.data import read_data_file
@ -123,14 +123,14 @@ class LetterJumbleDataset(ProceduralDataset):
},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves this task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
entry (dict[str, Any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.

View file

@ -2,7 +2,7 @@
from dataclasses import dataclass
from random import Random
from typing import List, Optional, Tuple
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
@ -39,7 +39,7 @@ class NumberFilteringDataset(ProceduralDataset):
"""Format a number with specified decimal places"""
return f"{num:.{decimals}f}"
def _generate_numbers(self, rng: Random) -> Tuple[List[float], List[str]]:
def _generate_numbers(self, rng: Random) -> tuple[list[float], list[str]]:
"""Generate list of numbers and their string representations"""
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
numbers = []

View file

@ -2,7 +2,7 @@
from dataclasses import dataclass
from random import Random
from typing import List, Optional, Tuple
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
@ -46,7 +46,7 @@ Please follow the instruction below:
# Reparse to ensure exact decimal representation
return f"{float(formatted):.{decimals}f}"
def _generate_numbers(self, rng: Random) -> Tuple[List[float], List[str]]:
def _generate_numbers(self, rng: Random) -> tuple[list[float], list[str]]:
"""Generate list of numbers and their string representations"""
count = rng.randint(self.config.min_numbers, self.config.max_numbers)
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)

View file

@ -90,7 +90,7 @@ class PalindromeDataset(ProceduralDataset):
"""Return the palindrome string from the letter set."""
return "".join(letters)
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided is a valid palindrome.
The answer is expected to be a single string

View file

@ -8,7 +8,7 @@ import json
import string
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -93,7 +93,7 @@ class PalindromePartitioningDataset(ProceduralDataset):
_partition(0)
return self._sort_list(res)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score a single Palindrome Partitioning question"""
if answer is not None:
try:

View file

@ -1,9 +1,8 @@
"""Perform average / max pooling on a matrix"""
from copy import deepcopy
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
import numpy as np
@ -95,7 +94,7 @@ class PoolMatrixDataset(ProceduralDataset):
]
)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Score the answer based on the metadata"""
reward = 0.0

View file

@ -7,7 +7,7 @@ https://leetcode.com/problems/ransom-note/description/
from collections import defaultdict
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -95,14 +95,14 @@ class RansomNoteDataset(ProceduralDataset):
"metadata": {"ransom_note": ransom_note, "magazine": magazine, "solution": answer, "solvable": solvable},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves this task.
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
entry (dict[str, Any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.

View file

@ -92,7 +92,7 @@ class SentenceReorderingDataset(ProceduralDataset):
"metadata": {"word_count": word_count},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
expected_answer = entry["answer"]
if answer is not None:

View file

@ -49,7 +49,7 @@ class SpellBackwardDataset(ProceduralDataset):
"metadata": {"word": word, "word_len": len(word)},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
expected_answer = entry["answer"]
if answer is not None:

View file

@ -6,7 +6,7 @@ https://leetcode.com/problems/spiral-matrix/description/
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -116,7 +116,7 @@ class SpiralMatrixDataset(ProceduralDataset):
"metadata": {"matrix": matrix, "solution": answer},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"].strip()

View file

@ -5,7 +5,7 @@ https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dat
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -80,7 +80,7 @@ class StringInsertionDataset(ProceduralDataset):
i += 1
return "".join(output)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"]
if answer is not None:

View file

@ -3,7 +3,7 @@
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Optional
from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset
@ -82,7 +82,7 @@ class WordLadderDataset(ProceduralDataset):
super().__init__(config=config, seed=config.seed, size=config.size)
@classmethod
def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> Dict[int, Set[str]]:
def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> dict[int, set[str]]:
"""Load words from CSV file organized by length"""
# Validate length range before processing
assert 3 <= min_length <= max_length <= 5, "Word length must be between 3 and 5 inclusive"
@ -117,7 +117,7 @@ class WordLadderDataset(ProceduralDataset):
return word_sets
def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
def _get_neighbors(self, word: str, word_set: set[str]) -> set[str]:
"""Get neighbors from either precomputed graph or by computing on demand"""
# Try precomputed graph first
if len(word) in self.word_graphs and word in self.word_graphs[len(word)]:
@ -132,7 +132,7 @@ class WordLadderDataset(ProceduralDataset):
neighbors.add(neighbor)
return neighbors
def _build_word_graph(self, word_length: int) -> Dict[str, Set[str]]:
def _build_word_graph(self, word_length: int) -> dict[str, set[str]]:
"""Build graph of word connections for given length, using caching"""
# Return cached graph if it exists
if word_length in self.word_graphs:
@ -156,7 +156,7 @@ class WordLadderDataset(ProceduralDataset):
self.word_graphs[word_length] = graph
return self.word_graphs[word_length]
def _find_path(self, start: str, end: str, word_set: Set[str]) -> Optional[List[str]]:
def _find_path(self, start: str, end: str, word_set: set[str]) -> Optional[list[str]]:
"""Simplified path finding using BFS for shortest paths"""
# Early exit if words are direct neighbors
if end in self._get_neighbors(start, word_set):
@ -181,7 +181,7 @@ class WordLadderDataset(ProceduralDataset):
return None
def _generate_word_pair(self, rng: Random, length: int) -> Tuple[str, str, List[str]]:
def _generate_word_pair(self, rng: Random, length: int) -> tuple[str, str, list[str]]:
"""Simplified word pair generation"""
word_set = self.word_sets[length]
words_list = sorted(word_set)
@ -220,7 +220,7 @@ class WordLadderDataset(ProceduralDataset):
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if answer is None:
return 0

View file

@ -4,7 +4,7 @@ import re
from dataclasses import dataclass
from enum import StrEnum
from random import Random
from typing import Dict, List, Optional, Tuple
from typing import Any, Optional
from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset
@ -84,7 +84,7 @@ class WordSortingDataset(ProceduralDataset):
return "".join(c.upper() if rng.choice([True, False]) else c.lower() for c in word)
return word # ORIGINAL case
def _generate_words(self, rng: Random) -> Tuple[List[str], List[str]]:
def _generate_words(self, rng: Random) -> tuple[list[str], list[str]]:
"""Generate list of words and their transformed versions"""
count = rng.randint(self.config.min_words, self.config.max_words)
@ -122,7 +122,7 @@ class WordSortingDataset(ProceduralDataset):
},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["metadata"]["sorted_words"]
if answer is not None and len(answer) > 0:
parsed_answer = [word.strip() for word in re.split(r",\s*", answer)]