diff --git a/reasoning_gym/cognition/needle_haystack.py b/reasoning_gym/cognition/needle_haystack.py index 02596104..93cf1cb8 100644 --- a/reasoning_gym/cognition/needle_haystack.py +++ b/reasoning_gym/cognition/needle_haystack.py @@ -18,6 +18,9 @@ class NeedleHaystackConfig: def validate(self) -> None: """Validate configuration parameters""" assert self.num_statements > 0, "num_statements must be greater than 0" + assert self.num_statements < len(NAMES) * len(VERBS) * len( + SUBJECTS + ), f"num_statements must be less than {len(NAMES) * len(VERBS) * len(SUBJECTS)}" def generate_unique_triplets(names: List[str], verbs: List[str], subjects: List[str], n: int, rng) -> Dict[str, Any]: @@ -44,8 +47,6 @@ def generate_unique_triplets(names: List[str], verbs: List[str], subjects: List[ ValueError: If n exceeds the total number of unique triplets possible. """ total_possible = len(names) * len(verbs) * len(subjects) - if n > total_possible: - raise ValueError("Requested n exceeds the total number of unique combinations.") # Use a range for memory efficiency and sample n unique indices. indices = rng.sample(range(total_possible), n)