This commit is contained in:
GarmashAlex 2026-03-28 03:15:10 +00:00 committed by GitHub
commit 7e6f292540
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -1,5 +1,6 @@
import asyncio
import json
import logging
import math
import os
import time
@@ -25,6 +26,8 @@ from atroposlib.envs.base import (
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
logger = logging.getLogger(__name__)
system_prompt = ""
system_prompt += (
@@ -86,9 +89,6 @@ class CodingEnv(BaseEnv):
super().__init__(*args, **kwargs)
self.id = 0
self.complete = []
self.total = []
self.complement = []
self.eval_metrics = []
self.train_metrics = {
"rewards": [],
@@ -219,10 +219,6 @@ class CodingEnv(BaseEnv):
)
start_time = time.time()
if train_or_valid == "train":
self.total.append(item["idx"]) # LOCK
self.complement = sorted(list(set(self.total) - set(self.complete)))
print("TOTAL: ", self.complement)
if train_or_valid == "train":
tasks = [generate_and_score(i) for i in range(self.config.group_size)]
else:
@@ -242,21 +238,21 @@ class CodingEnv(BaseEnv):
lengths = [r[6] for r in results]
asst_messages = [r[7] for r in results]
cur_id = item["idx"]
if train_or_valid == "train":
self.complete.append(cur_id) # LOCK
self.complement = sorted(list(set(self.total) - set(self.complete)))
rprint(train_or_valid)
rprint("CURRENT ID: ", cur_id, item["problem_type"])
print("GEN LENGTHS: ", lengths)
print("SCORES ", scored_data["scores"])
print("MISSING ", self.complement)
print("CURRENT TIME", time.time())
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"collect_trajectories split=%s idx=%s type=%s lengths=%s scores=%s elapsed=%.2fs",
train_or_valid,
cur_id,
item["problem_type"],
lengths,
scored_data["scores"],
dt,
)
# for error, code in zip(errors, codes):
# print("ERROR:", error)
# if 'output' in error and error['output'] == '':
# print(code)
print(f"Elapsed time: {dt:.2f} seconds")
async with self.lock:
heap_sum = 0
@@ -806,16 +802,29 @@ class CodingEnv(BaseEnv):
rprint(test_cases["tests"]["fn_name"])
metadata = {"error": "segmentation fault"}
rprint("FAULT")
rprint("code:\n", code)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"run_test modal remote error idx=%s fn=%s code_len=%s",
cur_id,
test_cases["tests"]["fn_name"],
len(code),
)
rprint(metadata)
except Exception as f:
rprint("index", cur_id)
rprint(test_cases["tests"]["fn_name"])
rprint("except:", code)
res = [False]
metadata = {"error": "segmentation fault"}
rprint("NON SEGFAULT")
rprint("FAULT", f)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"run_test error idx=%s fn=%s code_len=%s exception=%r",
cur_id,
test_cases["tests"]["fn_name"],
len(code),
f,
)
for x in res:
if not isinstance(x, (int, bool)):
rprint("WARNING")