mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
move assert to config
This commit is contained in:
parent
621c20d8d8
commit
39b739917e
1 changed files with 3 additions and 2 deletions
|
|
@ -18,6 +18,9 @@ class NeedleHaystackConfig:
|
|||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.num_statements > 0, "num_statements must be greater than 0"
|
||||
assert self.num_statements < len(NAMES) * len(VERBS) * len(
|
||||
SUBJECTS
|
||||
), f"num_statements must be less than {len(NAMES) * len(VERBS) * len(SUBJECTS)}"
|
||||
|
||||
|
||||
def generate_unique_triplets(names: List[str], verbs: List[str], subjects: List[str], n: int, rng) -> Dict[str, Any]:
|
||||
|
|
@ -44,8 +47,6 @@ def generate_unique_triplets(names: List[str], verbs: List[str], subjects: List[
|
|||
ValueError: If n exceeds the total number of unique triplets possible.
|
||||
"""
|
||||
total_possible = len(names) * len(verbs) * len(subjects)
|
||||
if n > total_possible:
|
||||
raise ValueError("Requested n exceeds the total number of unique combinations.")
|
||||
|
||||
# Use a range for memory efficiency and sample n unique indices.
|
||||
indices = rng.sample(range(total_possible), n)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue