mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix duplicate code + add safety checks
This commit is contained in:
parent
7da681ec46
commit
6b92ee16ec
3 changed files with 4 additions and 6 deletions
|
|
@ -36,9 +36,6 @@ def select_best_index(
|
||||||
raise ValueError("Primary and secondary score lists must have the same length.")
|
raise ValueError("Primary and secondary score lists must have the same length.")
|
||||||
|
|
||||||
num_items = len(primary_scores)
|
num_items = len(primary_scores)
|
||||||
if num_items == 0: # Should be caught by the first check, but as a safeguard.
|
|
||||||
raise ValueError("Input score lists cannot be empty.")
|
|
||||||
|
|
||||||
best_index = 0
|
best_index = 0
|
||||||
|
|
||||||
for i in range(1, num_items):
|
for i in range(1, num_items):
|
||||||
|
|
|
||||||
|
|
@ -552,7 +552,10 @@ class MathEnv(BaseEnv):
|
||||||
i
|
i
|
||||||
for i, score in enumerate(to_postprocess["scores"])
|
for i, score in enumerate(to_postprocess["scores"])
|
||||||
if score == 1.0
|
if score == 1.0
|
||||||
][0]
|
]
|
||||||
|
if len(pos_idx) == 0:
|
||||||
|
return None, to_backlog
|
||||||
|
pos_idx = pos_idx[0]
|
||||||
neg_idx = [
|
neg_idx = [
|
||||||
i
|
i
|
||||||
for i, score in enumerate(to_postprocess["scores"])
|
for i, score in enumerate(to_postprocess["scores"])
|
||||||
|
|
|
||||||
|
|
@ -251,8 +251,6 @@ class InterleavedInlineEnv(BaseEnv):
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
print(f"[DEBUG setup] kept {len(subset)} rows from Dataset")
|
print(f"[DEBUG setup] kept {len(subset)} rows from Dataset")
|
||||||
|
|
||||||
split = full.train_test_split(test_size=0.02, seed=42)
|
|
||||||
|
|
||||||
split = full.train_test_split(test_size=0.02, seed=42)
|
split = full.train_test_split(test_size=0.02, seed=42)
|
||||||
self.train, self.test = split["train"], split["test"]
|
self.train, self.test = split["train"], split["test"]
|
||||||
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
|
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue