mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge bac04c6343 into c20c85256e
This commit is contained in:
commit
7e6f292540
1 changed file with 28 additions and 19 deletions
|
|
@@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
|
@@ -25,6 +26,8 @@ from atroposlib.envs.base import (
|
|||
from atroposlib.type_definitions import GameHistory, Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
system_prompt = ""
|
||||
|
||||
system_prompt += (
|
||||
|
|
@@ -86,9 +89,6 @@ class CodingEnv(BaseEnv):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.id = 0
|
||||
self.complete = []
|
||||
self.total = []
|
||||
self.complement = []
|
||||
self.eval_metrics = []
|
||||
self.train_metrics = {
|
||||
"rewards": [],
|
||||
|
|
@@ -219,10 +219,6 @@ class CodingEnv(BaseEnv):
|
|||
)
|
||||
|
||||
start_time = time.time()
|
||||
if train_or_valid == "train":
|
||||
self.total.append(item["idx"]) # LOCK
|
||||
self.complement = sorted(list(set(self.total) - set(self.complete)))
|
||||
print("TOTAL: ", self.complement)
|
||||
if train_or_valid == "train":
|
||||
tasks = [generate_and_score(i) for i in range(self.config.group_size)]
|
||||
else:
|
||||
|
|
@@ -242,21 +238,21 @@ class CodingEnv(BaseEnv):
|
|||
lengths = [r[6] for r in results]
|
||||
asst_messages = [r[7] for r in results]
|
||||
cur_id = item["idx"]
|
||||
if train_or_valid == "train":
|
||||
self.complete.append(cur_id) # LOCK
|
||||
self.complement = sorted(list(set(self.total) - set(self.complete)))
|
||||
|
||||
rprint(train_or_valid)
|
||||
rprint("CURRENT ID: ", cur_id, item["problem_type"])
|
||||
print("GEN LENGTHS: ", lengths)
|
||||
print("SCORES ", scored_data["scores"])
|
||||
print("MISSING ", self.complement)
|
||||
print("CURRENT TIME", time.time())
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"collect_trajectories split=%s idx=%s type=%s lengths=%s scores=%s elapsed=%.2fs",
|
||||
train_or_valid,
|
||||
cur_id,
|
||||
item["problem_type"],
|
||||
lengths,
|
||||
scored_data["scores"],
|
||||
dt,
|
||||
)
|
||||
# for error, code in zip(errors, codes):
|
||||
# print("ERROR:", error)
|
||||
# if 'output' in error and error['output'] == '':
|
||||
# print(code)
|
||||
print(f"Elapsed time: {dt:.2f} seconds")
|
||||
|
||||
async with self.lock:
|
||||
heap_sum = 0
|
||||
|
|
@@ -806,16 +802,29 @@ class CodingEnv(BaseEnv):
|
|||
rprint(test_cases["tests"]["fn_name"])
|
||||
metadata = {"error": "segmentation fault"}
|
||||
rprint("FAULT")
|
||||
rprint("code:\n", code)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"run_test modal remote error idx=%s fn=%s code_len=%s",
|
||||
cur_id,
|
||||
test_cases["tests"]["fn_name"],
|
||||
len(code),
|
||||
)
|
||||
rprint(metadata)
|
||||
except Exception as f:
|
||||
rprint("index", cur_id)
|
||||
rprint(test_cases["tests"]["fn_name"])
|
||||
rprint("except:", code)
|
||||
res = [False]
|
||||
metadata = {"error": "segmentation fault"}
|
||||
rprint("NON SEGFAULT")
|
||||
rprint("FAULT", f)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"run_test error idx=%s fn=%s code_len=%s exception=%r",
|
||||
cur_id,
|
||||
test_cases["tests"]["fn_name"],
|
||||
len(code),
|
||||
f,
|
||||
)
|
||||
for x in res:
|
||||
if not isinstance(x, (int, bool)):
|
||||
rprint("WARNING")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue