wiggle imports

This commit is contained in:
Rich Jones 2025-02-20 16:23:40 +01:00
parent 39b739917e
commit 2188c53308

View file

@ -4,7 +4,6 @@ from random import Random
from typing import Any, Dict, List, Optional
from ..factory import ProceduralDataset, register_dataset
from .needle_data import NAMES, SUBJECTS, VERBS
@dataclass
@ -18,9 +17,7 @@ class NeedleHaystackConfig:
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.num_statements > 0, "num_statements must be greater than 0"
assert self.num_statements < len(NAMES) * len(VERBS) * len(
SUBJECTS
), f"num_statements must be less than {len(NAMES) * len(VERBS) * len(SUBJECTS)}"
assert self.num_statements < 168387000, f"num_statements must be less than {168387000}"
def generate_unique_triplets(names: List[str], verbs: List[str], subjects: List[str], n: int, rng) -> Dict[str, Any]:
@ -85,6 +82,8 @@ class NeedleHaystackDataset(ProceduralDataset):
- answer: None, indicating to use the dynamic evaluator
- metadata: dict with generation parameters and example solution
"""
from .needle_data import NAMES, SUBJECTS, VERBS
rng = Random(self.seed + idx)
stack = generate_unique_triplets(NAMES, VERBS, SUBJECTS, self.config.num_statements, rng)