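Use more native type hints (built-in `list`/`dict`/`tuple` generics instead of `typing.List`/`Dict`/`Tuple`; fix `any` -> `Any`)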

This commit is contained in:
Andreas Koepf 2025-02-21 21:23:14 +01:00
parent ae26704d05
commit 74f590e24f
19 changed files with 90 additions and 92 deletions

View file

@ -1,7 +1,6 @@
import re
from dataclasses import dataclass
from random import Random
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from ..factory import ProceduralDataset, register_dataset
@ -20,7 +19,7 @@ class NeedleHaystackConfig:
assert self.num_statements < 168387000, f"num_statements must be less than {168387000}"
def generate_unique_triplets(names: List[str], verbs: List[str], subjects: List[str], n: int, rng) -> Dict[str, Any]:
def generate_unique_triplets(names: list[str], verbs: list[str], subjects: list[str], n: int, rng) -> dict[str, Any]:
"""
Generate n unique random triplets (name, verb, subject) without generating the full Cartesian product in memory.
@ -29,14 +28,14 @@ def generate_unique_triplets(names: List[str], verbs: List[str], subjects: List[
randomly chosen as the 'needle'.
Args:
names (List[str]): List of names.
verbs (List[str]): List of verbs.
subjects (List[str]): List of subjects.
names (list[str]): List of names.
verbs (list[str]): List of verbs.
subjects (list[str]): List of subjects.
n (int): Number of unique triplets to generate.
rng (random.Random): A pre-seeded random number generator.
Returns:
Dict[str, Any]: A dictionary with:
dict[str, Any]: A dictionary with:
- "triplets": a list of n unique triplets (tuples of (name, verb, subject)),
- "needle": one triplet randomly chosen from the list.
@ -47,7 +46,7 @@ def generate_unique_triplets(names: List[str], verbs: List[str], subjects: List[
# Use a range for memory efficiency and sample n unique indices.
indices = rng.sample(range(total_possible), n)
triplets: List[Tuple[str, str, str]] = []
triplets: list[tuple[str, str, str]] = []
num_verbs = len(verbs)
num_subjects = len(subjects)
@ -101,12 +100,12 @@ class NeedleHaystackDataset(ProceduralDataset):
"metadata": {"question": question},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the task.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
entry (dict[str, Any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.