mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix duplicate code + add safety checks
This commit is contained in:
parent
7da681ec46
commit
6b92ee16ec
3 changed files with 4 additions and 6 deletions
|
|
@ -36,9 +36,6 @@ def select_best_index(
|
|||
raise ValueError("Primary and secondary score lists must have the same length.")
|
||||
|
||||
num_items = len(primary_scores)
|
||||
if num_items == 0: # Should be caught by the first check, but as a safeguard.
|
||||
raise ValueError("Input score lists cannot be empty.")
|
||||
|
||||
best_index = 0
|
||||
|
||||
for i in range(1, num_items):
|
||||
|
|
|
|||
|
|
@ -552,7 +552,10 @@ class MathEnv(BaseEnv):
|
|||
i
|
||||
for i, score in enumerate(to_postprocess["scores"])
|
||||
if score == 1.0
|
||||
][0]
|
||||
]
|
||||
if len(pos_idx) == 0:
|
||||
return None, to_backlog
|
||||
pos_idx = pos_idx[0]
|
||||
neg_idx = [
|
||||
i
|
||||
for i, score in enumerate(to_postprocess["scores"])
|
||||
|
|
|
|||
|
|
@ -251,8 +251,6 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
if DEBUG:
|
||||
print(f"[DEBUG setup] kept {len(subset)} rows from Dataset")
|
||||
|
||||
split = full.train_test_split(test_size=0.02, seed=42)
|
||||
|
||||
split = full.train_test_split(test_size=0.02, seed=42)
|
||||
self.train, self.test = split["train"], split["test"]
|
||||
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue