mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
3348e31a29
commit
d932d9c03b
3 changed files with 333 additions and 129 deletions
|
|
@ -237,4 +237,3 @@ The code execution endpoint (`lcb_modal_endpoint.py`) implements extensive safet
|
|||
## 📄 License
|
||||
|
||||
This environment is part of the Atropos training framework. See the main repository for license information.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,27 +1,29 @@
|
|||
import asyncio
|
||||
from datetime import datetime
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
import numpy as np
|
||||
import time
|
||||
import modal
|
||||
import math
|
||||
from pydantic import Field
|
||||
from collections import deque
|
||||
import json
|
||||
import wandb
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import modal
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from pydantic import Field
|
||||
from rich import print as rprint
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
EvalHandlingEnum,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import GameHistory, Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
from rich import print as rprint
|
||||
|
||||
system_prompt = ""
|
||||
|
||||
|
|
@ -32,6 +34,7 @@ FORMATTING_WITHOUT_STARTER_CODE = "Read the inputs from stdin, solve the problem
|
|||
|
||||
async_semaphore = asyncio.Semaphore(100)
|
||||
|
||||
|
||||
def get_prompt(question, problem_type, starter_code=None):
|
||||
prompt = ""
|
||||
prompt += f"Question: {question}\n\n"
|
||||
|
|
@ -46,30 +49,49 @@ def get_prompt(question, problem_type, starter_code=None):
|
|||
prompt += f"### Answer: (use the provided format with backticks)\n\n"
|
||||
return prompt
|
||||
|
||||
|
||||
run_test = modal.Function.from_name("joeli-lcb", "run_test")
|
||||
|
||||
|
||||
class CodeConfig(BaseEnvConfig):
|
||||
dataset_name: str = Field("normal", description="dataset name, should be normal or deepmind")
|
||||
dataset_name: str = Field(
|
||||
"normal", description="dataset name, should be normal or deepmind"
|
||||
)
|
||||
temperature: float = Field(0.6, description="model temperature")
|
||||
eval_temperature: float = Field(0.6, description="model temperature during evaluation")
|
||||
eval_temperature: float = Field(
|
||||
0.6, description="model temperature during evaluation"
|
||||
)
|
||||
top_p: float = Field(0.95, description="top p")
|
||||
eval_top_p: float = Field(0.95, description="eval top p")
|
||||
start_idx: int = Field(0, description="start index in training set")
|
||||
max_eval_token_length: int = Field(40960, description="max sequence length during evaluation")
|
||||
max_eval_token_length: int = Field(
|
||||
40960, description="max sequence length during evaluation"
|
||||
)
|
||||
|
||||
|
||||
class CodingEnv(BaseEnv):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
self.id = 0
|
||||
self.complete = []
|
||||
self.total = []
|
||||
self.complement = []
|
||||
self.eval_metrics = []
|
||||
self.train_metrics = {"rewards": [], "overlong_ratio": [], "pass_gs": [], "completion_lengths": [], "correct_completion_lengths": [], "incorrect_completion_lengths": []}
|
||||
self.train_metrics = {
|
||||
"rewards": [],
|
||||
"overlong_ratio": [],
|
||||
"pass_gs": [],
|
||||
"completion_lengths": [],
|
||||
"correct_completion_lengths": [],
|
||||
"incorrect_completion_lengths": [],
|
||||
}
|
||||
self.cur_time = datetime.now().strftime("%Y-%m-%d %H.%M.%S")
|
||||
|
||||
self.temp_metrics = {"indices": [], "num_correct": [], }
|
||||
self.temp_metrics = {
|
||||
"indices": [],
|
||||
"num_correct": [],
|
||||
}
|
||||
|
||||
self.blacklist = set()
|
||||
self.deq = deque()
|
||||
|
|
@ -85,7 +107,7 @@ class CodingEnv(BaseEnv):
|
|||
batch_size=-1,
|
||||
steps_per_eval=10,
|
||||
max_batches_offpolicy=2,
|
||||
max_eval_workers=128, # max workers per node * num gpus
|
||||
max_eval_workers=128, # max workers per node * num gpus
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
eval_limit_ratio=0.25,
|
||||
max_token_length=32768,
|
||||
|
|
@ -112,7 +134,6 @@ class CodingEnv(BaseEnv):
|
|||
|
||||
return env_config, server_configs
|
||||
|
||||
|
||||
# item looks like {problem, tests, problem_type, idx}
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
|
|
@ -120,25 +141,38 @@ class CodingEnv(BaseEnv):
|
|||
rprint("COLLECTING TRAJECTORIES")
|
||||
train_or_valid = item["split"]
|
||||
system_msg = {"role": "system", "content": system_prompt}
|
||||
user_msg = {"role": "user", "content": get_prompt(item["problem"], item["problem_type"], item.get("starter_code", None))}
|
||||
user_msg = {
|
||||
"role": "user",
|
||||
"content": get_prompt(
|
||||
item["problem"], item["problem_type"], item.get("starter_code", None)
|
||||
),
|
||||
}
|
||||
|
||||
prompt_tokens = tokenize_for_trainer(self.tokenizer, chat=[system_msg, user_msg])
|
||||
prompt_tokens = tokenize_for_trainer(
|
||||
self.tokenizer, chat=[system_msg, user_msg]
|
||||
)
|
||||
buffer = 32
|
||||
|
||||
async def generate_and_score(index: int) -> Tuple[List[int], List[int], float]:
|
||||
# Step 1: Generate single completion
|
||||
|
||||
max_tokens = self.config.max_token_length - len(prompt_tokens["tokens"]) - buffer
|
||||
max_tokens = (
|
||||
self.config.max_token_length - len(prompt_tokens["tokens"]) - buffer
|
||||
)
|
||||
temp = self.config.temperature
|
||||
top_p = self.config.top_p
|
||||
if train_or_valid == "test":
|
||||
max_tokens = self.config.max_eval_token_length - len(prompt_tokens["tokens"]) - buffer
|
||||
max_tokens = (
|
||||
self.config.max_eval_token_length
|
||||
- len(prompt_tokens["tokens"])
|
||||
- buffer
|
||||
)
|
||||
top_p = self.config.eval_top_p
|
||||
temp = self.config.eval_temperature
|
||||
|
||||
chat_completion = await self.server.chat_completion(
|
||||
messages=[system_msg, user_msg],
|
||||
n=1, ### CHANGE
|
||||
n=1, ### CHANGE
|
||||
max_tokens=max_tokens,
|
||||
temperature=temp,
|
||||
top_p=top_p,
|
||||
|
|
@ -156,8 +190,10 @@ class CodingEnv(BaseEnv):
|
|||
if isinstance(tests, str):
|
||||
tests = json.loads(tests)
|
||||
tests["fn_name"] = fn_name
|
||||
|
||||
score_input = [(messages, tests, item["idx"], chat_completion.choices[0].finish_reason)]
|
||||
|
||||
score_input = [
|
||||
(messages, tests, item["idx"], chat_completion.choices[0].finish_reason)
|
||||
]
|
||||
scored_group, tup = await self.score(score_input)
|
||||
return (
|
||||
scored_group["tokens"][0],
|
||||
|
|
@ -169,10 +205,10 @@ class CodingEnv(BaseEnv):
|
|||
tup[2],
|
||||
assistant_msg,
|
||||
)
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
if train_or_valid == "train":
|
||||
self.total.append(item["idx"]) # LOCK
|
||||
self.total.append(item["idx"]) # LOCK
|
||||
self.complement = sorted(list(set(self.total) - set(self.complete)))
|
||||
print("TOTAL: ", self.complement)
|
||||
if train_or_valid == "train":
|
||||
|
|
@ -182,7 +218,7 @@ class CodingEnv(BaseEnv):
|
|||
results = await asyncio.gather(*tasks)
|
||||
end_time = time.time()
|
||||
|
||||
dt = end_time-start_time
|
||||
dt = end_time - start_time
|
||||
scored_data = ScoredDataGroup()
|
||||
scored_data["tokens"] = [r[0] for r in results]
|
||||
scored_data["masks"] = [r[1] for r in results]
|
||||
|
|
@ -195,7 +231,7 @@ class CodingEnv(BaseEnv):
|
|||
asst_messages = [r[7] for r in results]
|
||||
cur_id = item["idx"]
|
||||
if train_or_valid == "train":
|
||||
self.complete.append(cur_id) # LOCK
|
||||
self.complete.append(cur_id) # LOCK
|
||||
self.complement = sorted(list(set(self.total) - set(self.complete)))
|
||||
|
||||
rprint(train_or_valid)
|
||||
|
|
@ -204,10 +240,10 @@ class CodingEnv(BaseEnv):
|
|||
print("SCORES ", scored_data["scores"])
|
||||
print("MISSING ", self.complement)
|
||||
print("CURRENT TIME", time.time())
|
||||
#for error, code in zip(errors, codes):
|
||||
# for error, code in zip(errors, codes):
|
||||
# print("ERROR:", error)
|
||||
#if 'output' in error and error['output'] == '':
|
||||
# print(code)
|
||||
# if 'output' in error and error['output'] == '':
|
||||
# print(code)
|
||||
print(f"Elapsed time: {dt:.2f} seconds")
|
||||
|
||||
async with self.lock:
|
||||
|
|
@ -215,15 +251,44 @@ class CodingEnv(BaseEnv):
|
|||
for x in scored_data["scores"]:
|
||||
if math.isclose(1.0, x):
|
||||
heap_sum += 1
|
||||
num_override = sum([x["set_advantage_to_zero"] for x in scored_data["overrides"]])
|
||||
num_override = sum(
|
||||
[x["set_advantage_to_zero"] for x in scored_data["overrides"]]
|
||||
)
|
||||
|
||||
if train_or_valid == "train":
|
||||
self.train_metrics["rewards"].append(1.0 * heap_sum / self.config.group_size)
|
||||
self.train_metrics["overlong_ratio"].append(1.0 * num_override / self.config.group_size)
|
||||
self.train_metrics["rewards"].append(
|
||||
1.0 * heap_sum / self.config.group_size
|
||||
)
|
||||
self.train_metrics["overlong_ratio"].append(
|
||||
1.0 * num_override / self.config.group_size
|
||||
)
|
||||
self.train_metrics["pass_gs"].append(1.0 * (heap_sum > 0))
|
||||
self.train_metrics["completion_lengths"].append(sum([len(x) for x in scored_data["tokens"]]) / len(scored_data["tokens"]))
|
||||
self.train_metrics["correct_completion_lengths"].append(sum([len(x) for x, y in zip(scored_data["tokens"], scored_data["scores"]) if math.isclose(y, 1.0)]))
|
||||
self.train_metrics["incorrect_completion_lengths"].append(sum([len(x) for x, y in zip(scored_data["tokens"], scored_data["scores"]) if math.isclose(y, -1.0)]))
|
||||
self.train_metrics["completion_lengths"].append(
|
||||
sum([len(x) for x in scored_data["tokens"]])
|
||||
/ len(scored_data["tokens"])
|
||||
)
|
||||
self.train_metrics["correct_completion_lengths"].append(
|
||||
sum(
|
||||
[
|
||||
len(x)
|
||||
for x, y in zip(
|
||||
scored_data["tokens"], scored_data["scores"]
|
||||
)
|
||||
if math.isclose(y, 1.0)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.train_metrics["incorrect_completion_lengths"].append(
|
||||
sum(
|
||||
[
|
||||
len(x)
|
||||
for x, y in zip(
|
||||
scored_data["tokens"], scored_data["scores"]
|
||||
)
|
||||
if math.isclose(y, -1.0)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.temp_metrics["indices"].append(cur_id)
|
||||
self.temp_metrics["num_correct"].append(heap_sum)
|
||||
|
||||
|
|
@ -239,13 +304,32 @@ class CodingEnv(BaseEnv):
|
|||
script_dir = os.path.join(script_dir, "train_logs")
|
||||
else:
|
||||
script_dir = os.path.join(script_dir, "eval_logs")
|
||||
file_path = os.path.join(script_dir, "qwen_data_dump_"+self.cur_time + ".txt")
|
||||
file_path_long = os.path.join(script_dir, "qwen_data_dump_long"+self.cur_time+".txt")
|
||||
file_path = os.path.join(
|
||||
script_dir, "qwen_data_dump_" + self.cur_time + ".txt"
|
||||
)
|
||||
file_path_long = os.path.join(
|
||||
script_dir, "qwen_data_dump_long" + self.cur_time + ".txt"
|
||||
)
|
||||
with open(file_path, "a") as f:
|
||||
mp = {"cur_id": cur_id, "num_correct": heap_sum, "total": self.config.group_size, "scores": scored_data["scores"], "lengths": lengths}
|
||||
mp = {
|
||||
"cur_id": cur_id,
|
||||
"num_correct": heap_sum,
|
||||
"total": self.config.group_size,
|
||||
"scores": scored_data["scores"],
|
||||
"lengths": lengths,
|
||||
}
|
||||
f.write(json.dumps(mp) + "\n")
|
||||
with open(file_path_long, "a") as f:
|
||||
mp = {"cur_id": cur_id, "num_correct": heap_sum, "total": self.config.group_size, "scores": scored_data["scores"], "lengths": lengths, "errors": errors, "codes": codes, "gen": asst_messages[0]}
|
||||
mp = {
|
||||
"cur_id": cur_id,
|
||||
"num_correct": heap_sum,
|
||||
"total": self.config.group_size,
|
||||
"scores": scored_data["scores"],
|
||||
"lengths": lengths,
|
||||
"errors": errors,
|
||||
"codes": codes,
|
||||
"gen": asst_messages[0],
|
||||
}
|
||||
f.write(json.dumps(mp) + "\n")
|
||||
|
||||
return scored_data, []
|
||||
|
|
@ -287,15 +371,36 @@ class CodingEnv(BaseEnv):
|
|||
|
||||
scores = scored_data["scores"]
|
||||
num_correct = sum(1 for x in scores if math.isclose(x, 1.0))
|
||||
num_overlong = sum(x["set_advantage_to_zero"] for x in scored_data["overrides"])
|
||||
num_overlong = sum(
|
||||
x["set_advantage_to_zero"] for x in scored_data["overrides"]
|
||||
)
|
||||
|
||||
async with self.lock2:
|
||||
overlong_ratio.append(num_overlong / len(scores))
|
||||
pass_gs.append(num_correct > 0)
|
||||
temp_completion_lengths.append(sum([len(x) for x in scored_data["tokens"]]) / len(scored_data["tokens"]))
|
||||
temp_completion_lengths.append(
|
||||
sum([len(x) for x in scored_data["tokens"]])
|
||||
/ len(scored_data["tokens"])
|
||||
)
|
||||
|
||||
correct_completion_lengths.extend([len(x) for x, y in zip(scored_data["tokens"], scored_data["scores"]) if math.isclose(y, 1.0)])
|
||||
incorrect_completion_lengths.extend([len(x) for x, y in zip(scored_data["tokens"], scored_data["scores"]) if math.isclose(y, -1.0)])
|
||||
correct_completion_lengths.extend(
|
||||
[
|
||||
len(x)
|
||||
for x, y in zip(
|
||||
scored_data["tokens"], scored_data["scores"]
|
||||
)
|
||||
if math.isclose(y, 1.0)
|
||||
]
|
||||
)
|
||||
incorrect_completion_lengths.extend(
|
||||
[
|
||||
len(x)
|
||||
for x, y in zip(
|
||||
scored_data["tokens"], scored_data["scores"]
|
||||
)
|
||||
if math.isclose(y, -1.0)
|
||||
]
|
||||
)
|
||||
all_total.append(len(scores))
|
||||
all_correct.append(num_correct)
|
||||
if item["difficulty"] == "easy":
|
||||
|
|
@ -316,14 +421,26 @@ class CodingEnv(BaseEnv):
|
|||
if n - c < k:
|
||||
return 1.0
|
||||
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
|
||||
|
||||
all_pass_1 = sum([estimator(n, c, 1) for n, c in zip(all_total, all_correct)]) / len(all_total)
|
||||
easy_pass_1 = sum([estimator(n, c, 1) for n, c in zip(easy_total, easy_correct)]) / (len(easy_total) + 1e-6)
|
||||
medium_pass_1 = sum([estimator(n, c, 1) for n, c in zip(medium_total, medium_correct)]) / (len(medium_total) + 1e-6)
|
||||
hard_pass_1 = sum([estimator(n, c, 1) for n, c in zip(hard_total, hard_correct)]) / (len(hard_total) + 1e-6)
|
||||
|
||||
self.eval_metrics.append(("eval/overlong_ratio", 1.0 * sum(overlong_ratio) / len(overlong_ratio)))
|
||||
self.eval_metrics.append(("eval/pass@group_size", 1.0 * sum(pass_gs) / len(pass_gs)))
|
||||
all_pass_1 = sum(
|
||||
[estimator(n, c, 1) for n, c in zip(all_total, all_correct)]
|
||||
) / len(all_total)
|
||||
easy_pass_1 = sum(
|
||||
[estimator(n, c, 1) for n, c in zip(easy_total, easy_correct)]
|
||||
) / (len(easy_total) + 1e-6)
|
||||
medium_pass_1 = sum(
|
||||
[estimator(n, c, 1) for n, c in zip(medium_total, medium_correct)]
|
||||
) / (len(medium_total) + 1e-6)
|
||||
hard_pass_1 = sum(
|
||||
[estimator(n, c, 1) for n, c in zip(hard_total, hard_correct)]
|
||||
) / (len(hard_total) + 1e-6)
|
||||
|
||||
self.eval_metrics.append(
|
||||
("eval/overlong_ratio", 1.0 * sum(overlong_ratio) / len(overlong_ratio))
|
||||
)
|
||||
self.eval_metrics.append(
|
||||
("eval/pass@group_size", 1.0 * sum(pass_gs) / len(pass_gs))
|
||||
)
|
||||
|
||||
avg_comp_len = sum(temp_completion_lengths) / len(temp_completion_lengths)
|
||||
|
||||
|
|
@ -333,12 +450,22 @@ class CodingEnv(BaseEnv):
|
|||
self.eval_metrics.append(("eval/hard_pass_1", hard_pass_1))
|
||||
self.eval_metrics.append(("eval/completion_length", avg_comp_len))
|
||||
|
||||
self.eval_metrics.append(("eval/correct_completion_length", sum(correct_completion_lengths) / len(correct_completion_lengths)))
|
||||
self.eval_metrics.append(("eval/incorrect_completion_length", sum(incorrect_completion_lengths) / len(incorrect_completion_lengths)))
|
||||
self.eval_metrics.append(
|
||||
(
|
||||
"eval/correct_completion_length",
|
||||
sum(correct_completion_lengths) / len(correct_completion_lengths),
|
||||
)
|
||||
)
|
||||
self.eval_metrics.append(
|
||||
(
|
||||
"eval/incorrect_completion_length",
|
||||
sum(incorrect_completion_lengths) / len(incorrect_completion_lengths),
|
||||
)
|
||||
)
|
||||
print("STATS", self.eval_metrics)
|
||||
|
||||
return
|
||||
|
||||
|
||||
async def offline_filter(self):
|
||||
rprint("OFFLINE FILTERING")
|
||||
train_data = self.train
|
||||
|
|
@ -364,10 +491,9 @@ class CodingEnv(BaseEnv):
|
|||
f.write(json.dumps(curr_step) + "\n")
|
||||
|
||||
tasks = [asyncio.create_task(filter_temp(i)) for i in range(len(train_data))]
|
||||
await asyncio.gather(*tasks)
|
||||
await asyncio.gather(*tasks)
|
||||
print("DONE")
|
||||
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
|
@ -378,8 +504,18 @@ class CodingEnv(BaseEnv):
|
|||
async with self.lock:
|
||||
if self.wandb_step == 1 and self.curr_step > 5:
|
||||
self.wandb_step = self.curr_step
|
||||
cnt = sum((i != 0 and i != self.config.group_size) for i in self.temp_metrics["num_correct"])
|
||||
print("TEMP METRICS INDICES: ", self.temp_metrics["indices"], "\n", self.temp_metrics["num_correct"], cnt, len(self.temp_metrics["num_correct"]))
|
||||
cnt = sum(
|
||||
(i != 0 and i != self.config.group_size)
|
||||
for i in self.temp_metrics["num_correct"]
|
||||
)
|
||||
print(
|
||||
"TEMP METRICS INDICES: ",
|
||||
self.temp_metrics["indices"],
|
||||
"\n",
|
||||
self.temp_metrics["num_correct"],
|
||||
cnt,
|
||||
len(self.temp_metrics["num_correct"]),
|
||||
)
|
||||
for k, v in wandb_metrics.items():
|
||||
print(k, v)
|
||||
|
||||
|
|
@ -390,30 +526,56 @@ class CodingEnv(BaseEnv):
|
|||
old_completion_idx = 0
|
||||
new_completion_idx = 0
|
||||
|
||||
#assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths), f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
|
||||
print("TEMP METRICS LENGTHS", len(self.temp_metrics["num_correct"]), "COMPLETION LENGTHS", len(self.completion_lengths))
|
||||
# assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths), f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
|
||||
print(
|
||||
"TEMP METRICS LENGTHS",
|
||||
len(self.temp_metrics["num_correct"]),
|
||||
"COMPLETION LENGTHS",
|
||||
len(self.completion_lengths),
|
||||
)
|
||||
print("BATCH SZ", self.config.batch_size)
|
||||
|
||||
while idx < len(self.temp_metrics["num_correct"]):
|
||||
if self.temp_metrics["num_correct"][idx] != 0 and self.temp_metrics["num_correct"][idx] != self.config.group_size:
|
||||
if (
|
||||
self.temp_metrics["num_correct"][idx] != 0
|
||||
and self.temp_metrics["num_correct"][idx] != self.config.group_size
|
||||
):
|
||||
good_updates += 1
|
||||
idx += 1
|
||||
if good_updates == self.config.batch_size // self.config.group_size:
|
||||
wandb_metrics["train_rewards/rewards"] = sum(self.train_metrics["rewards"][old_idx:idx]) / len(self.train_metrics["rewards"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/overlong_ratio"] = sum(self.train_metrics["overlong_ratio"][old_idx:idx]) / len(self.train_metrics["overlong_ratio"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/pass@group_size"] = sum(self.train_metrics["pass_gs"][old_idx:idx]) / len(self.train_metrics["pass_gs"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/completion_lengths"] = sum(self.train_metrics["completion_lengths"][old_idx:idx]) / len(self.train_metrics["completion_lengths"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/correct_completion_lengths"] = sum(self.train_metrics["correct_completion_lengths"][old_idx:idx]) / sum(self.temp_metrics["num_correct"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/incorrect_completion_lengths"] = sum(self.train_metrics["incorrect_completion_lengths"][old_idx:idx]) / (self.config.group_size * (idx - old_idx) - sum(self.temp_metrics["num_correct"][old_idx:idx]))
|
||||
|
||||
wandb_metrics["train_rewards/rewards"] = sum(
|
||||
self.train_metrics["rewards"][old_idx:idx]
|
||||
) / len(self.train_metrics["rewards"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/overlong_ratio"] = sum(
|
||||
self.train_metrics["overlong_ratio"][old_idx:idx]
|
||||
) / len(self.train_metrics["overlong_ratio"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/pass@group_size"] = sum(
|
||||
self.train_metrics["pass_gs"][old_idx:idx]
|
||||
) / len(self.train_metrics["pass_gs"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/completion_lengths"] = sum(
|
||||
self.train_metrics["completion_lengths"][old_idx:idx]
|
||||
) / len(self.train_metrics["completion_lengths"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/correct_completion_lengths"] = sum(
|
||||
self.train_metrics["correct_completion_lengths"][old_idx:idx]
|
||||
) / sum(self.temp_metrics["num_correct"][old_idx:idx])
|
||||
wandb_metrics["train_rewards/incorrect_completion_lengths"] = sum(
|
||||
self.train_metrics["incorrect_completion_lengths"][old_idx:idx]
|
||||
) / (
|
||||
self.config.group_size * (idx - old_idx)
|
||||
- sum(self.temp_metrics["num_correct"][old_idx:idx])
|
||||
)
|
||||
|
||||
new_completion_idx += self.config.batch_size
|
||||
|
||||
assert old_completion_idx <= len(self.completion_lengths), f"OLD COMPLETION IDX {old_completion_idx} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
|
||||
|
||||
assert old_completion_idx <= len(
|
||||
self.completion_lengths
|
||||
), f"OLD COMPLETION IDX {old_completion_idx} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
|
||||
|
||||
wandb_metrics["train/completion_lengths"] = sum(
|
||||
self.completion_lengths[old_completion_idx:new_completion_idx]
|
||||
) / len(self.completion_lengths[old_completion_idx:new_completion_idx])
|
||||
) / len(
|
||||
self.completion_lengths[old_completion_idx:new_completion_idx]
|
||||
)
|
||||
wandb_metrics["train/completion_lengths_std"] = np.std(
|
||||
self.completion_lengths[old_completion_idx:new_completion_idx]
|
||||
)
|
||||
|
|
@ -424,12 +586,18 @@ class CodingEnv(BaseEnv):
|
|||
self.completion_lengths[old_completion_idx:new_completion_idx]
|
||||
)
|
||||
wandb_metrics["train/completion_lengths_p95"] = (
|
||||
np.array(self.completion_lengths[old_completion_idx:new_completion_idx]) > (0.95 * self.max_token_len)
|
||||
np.array(
|
||||
self.completion_lengths[
|
||||
old_completion_idx:new_completion_idx
|
||||
]
|
||||
)
|
||||
> (0.95 * self.max_token_len)
|
||||
).mean()
|
||||
|
||||
|
||||
if self.wandb_prepend is not None:
|
||||
wandb_metrics = {
|
||||
f"{self.wandb_prepend}_{k}": v for k, v in wandb_metrics.items()
|
||||
f"{self.wandb_prepend}_{k}": v
|
||||
for k, v in wandb_metrics.items()
|
||||
}
|
||||
print("WANDB LOG")
|
||||
wandb.log(wandb_metrics, step=self.wandb_step, commit=True)
|
||||
|
|
@ -440,16 +608,25 @@ class CodingEnv(BaseEnv):
|
|||
old_idx = idx
|
||||
old_completion_idx = new_completion_idx
|
||||
|
||||
|
||||
self.completion_lengths = self.completion_lengths[old_completion_idx:]
|
||||
self.temp_metrics["indices"] = self.temp_metrics["indices"][old_idx:]
|
||||
self.temp_metrics["num_correct"] = self.temp_metrics["num_correct"][old_idx:]
|
||||
self.temp_metrics["num_correct"] = self.temp_metrics["num_correct"][
|
||||
old_idx:
|
||||
]
|
||||
self.train_metrics["rewards"] = self.train_metrics["rewards"][old_idx:]
|
||||
self.train_metrics["overlong_ratio"] = self.train_metrics["overlong_ratio"][old_idx:]
|
||||
self.train_metrics["overlong_ratio"] = self.train_metrics["overlong_ratio"][
|
||||
old_idx:
|
||||
]
|
||||
self.train_metrics["pass_gs"] = self.train_metrics["pass_gs"][old_idx:]
|
||||
self.train_metrics["completion_lengths"] = self.train_metrics["completion_lengths"][old_idx:]
|
||||
self.train_metrics["correct_completion_lengths"] = self.train_metrics["correct_completion_lengths"][old_idx:]
|
||||
self.train_metrics["incorrect_completion_lengths"] = self.train_metrics["incorrect_completion_lengths"][old_idx:]
|
||||
self.train_metrics["completion_lengths"] = self.train_metrics[
|
||||
"completion_lengths"
|
||||
][old_idx:]
|
||||
self.train_metrics["correct_completion_lengths"] = self.train_metrics[
|
||||
"correct_completion_lengths"
|
||||
][old_idx:]
|
||||
self.train_metrics["incorrect_completion_lengths"] = self.train_metrics[
|
||||
"incorrect_completion_lengths"
|
||||
][old_idx:]
|
||||
|
||||
print("WANDB STEPS vs STATUS STEP", self.wandb_step, self.curr_step)
|
||||
|
||||
|
|
@ -474,23 +651,26 @@ class CodingEnv(BaseEnv):
|
|||
async def setup(self):
|
||||
"""Setup the environment"""
|
||||
if self.config.dataset_name == "deepmind":
|
||||
self.train = load_dataset("deepmind/code_contests", split="train") ### CHANGE
|
||||
self.train = load_dataset(
|
||||
"deepmind/code_contests", split="train"
|
||||
) ### CHANGE
|
||||
else:
|
||||
self.train = load_dataset("NousResearch/RLVR_Coding_Problems", split="train")
|
||||
test = load_dataset("NousResearch/lcb_test",split="test")
|
||||
self.train = load_dataset(
|
||||
"NousResearch/RLVR_Coding_Problems", split="train"
|
||||
)
|
||||
test = load_dataset("NousResearch/lcb_test", split="test")
|
||||
self.test = []
|
||||
for problem in test:
|
||||
self.test.append(problem)
|
||||
self.test[-1]["idx"] = len(self.test) - 1
|
||||
self.test[-1]["split"] = "test"
|
||||
|
||||
self.iter = 0 ### CHANGE
|
||||
self.iter = 0 ### CHANGE
|
||||
self.lock = asyncio.Lock()
|
||||
self.lock2 = asyncio.Lock()
|
||||
|
||||
self.wandb_step = 1
|
||||
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(script_dir, "perfect_indices.txt")
|
||||
with open(file_path, "r") as f:
|
||||
|
|
@ -506,7 +686,7 @@ class CodingEnv(BaseEnv):
|
|||
arr_short.append(i)
|
||||
print("ARR SHORT LENGTH", len(arr_short))
|
||||
temp_arr = arr_short * 100
|
||||
self.deq = deque(temp_arr[self.config.start_idx:])
|
||||
self.deq = deque(temp_arr[self.config.start_idx :])
|
||||
else:
|
||||
self.good_indices = [i for i in range(10000)]
|
||||
for i in range(10000):
|
||||
|
|
@ -520,14 +700,14 @@ class CodingEnv(BaseEnv):
|
|||
await self.offline_filter()
|
||||
rprint("FINISH blacklisting")
|
||||
"""
|
||||
|
||||
|
||||
"""
|
||||
st = time.time()
|
||||
await self.evaluate() ### CHANGE
|
||||
ed = time.time()
|
||||
rprint("ELAPSED TIME FOR EVALUATION: ", ed-st)
|
||||
"""
|
||||
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""
|
||||
Get the next items to be rolled out
|
||||
|
|
@ -542,7 +722,12 @@ class CodingEnv(BaseEnv):
|
|||
next_item["split"] = "train"
|
||||
if self.config.dataset_name == "deepmind":
|
||||
next_item["problem"] = next_item["description"]
|
||||
next_item["tests"] = {"input": next_item["private_tests"]["input"] + next_item["generated_tests"]["input"], "output": next_item["private_tests"]["output"] + next_item["generated_tests"]["output"]}
|
||||
next_item["tests"] = {
|
||||
"input": next_item["private_tests"]["input"]
|
||||
+ next_item["generated_tests"]["input"],
|
||||
"output": next_item["private_tests"]["output"]
|
||||
+ next_item["generated_tests"]["output"],
|
||||
}
|
||||
next_item["problem_type"] = "stdin_stdout"
|
||||
return next_item
|
||||
|
||||
|
|
@ -554,7 +739,7 @@ class CodingEnv(BaseEnv):
|
|||
return python_blocks
|
||||
|
||||
async def score(self, rollout_group_data):
|
||||
|
||||
|
||||
assert len(rollout_group_data) == 1
|
||||
cur_id = rollout_group_data[0][2]
|
||||
|
||||
|
|
@ -577,7 +762,7 @@ class CodingEnv(BaseEnv):
|
|||
else:
|
||||
if item[3] != "stop":
|
||||
rprint("FINISH REASON", item[3])
|
||||
#assert item[3] == "stop"
|
||||
# assert item[3] == "stop"
|
||||
|
||||
out_dict = tokenize_for_trainer(self.tokenizer, item[0], item[3])
|
||||
tokens = out_dict["tokens"]
|
||||
|
|
@ -593,7 +778,6 @@ class CodingEnv(BaseEnv):
|
|||
code = None
|
||||
test_cases = {"tests": item[1]}
|
||||
|
||||
|
||||
codes.append(code)
|
||||
|
||||
if code == None:
|
||||
|
|
@ -617,7 +801,7 @@ class CodingEnv(BaseEnv):
|
|||
rprint(test_cases["tests"]["fn_name"])
|
||||
rprint("except:", code)
|
||||
res = [False]
|
||||
metadata= {"error": "segmentation fault"}
|
||||
metadata = {"error": "segmentation fault"}
|
||||
rprint("NON SEGFAULT")
|
||||
rprint("FAULT", f)
|
||||
for x in res:
|
||||
|
|
@ -633,9 +817,9 @@ class CodingEnv(BaseEnv):
|
|||
|
||||
errors.append(metadata)
|
||||
|
||||
|
||||
return scores, (codes[0], errors[0], lengths[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(system_prompt)
|
||||
CodingEnv.cli()
|
||||
CodingEnv.cli()
|
||||
|
|
|
|||
|
|
@ -1,37 +1,34 @@
|
|||
import ast
|
||||
import json
|
||||
import sys
|
||||
import faulthandler
|
||||
import json
|
||||
import platform
|
||||
import modal
|
||||
|
||||
# used for debugging to time steps
|
||||
from datetime import datetime
|
||||
|
||||
# to run the solution files we're using a timing based approach
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
# used for debugging to time steps
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from io import StringIO
|
||||
|
||||
# used for testing the code that reads from input
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
# from pyext import RuntimeModule
|
||||
from types import ModuleType
|
||||
|
||||
from enum import Enum
|
||||
from decimal import Decimal
|
||||
import time
|
||||
|
||||
# used for testing the code that reads from input
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
import modal
|
||||
import numpy as np
|
||||
|
||||
import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
|
||||
|
||||
image = modal.Image.debian_slim(python_version="3.10").pip_install("numpy")
|
||||
app = modal.App("joeli-lcb", image=image)
|
||||
|
||||
|
||||
def truncatefn(s, length=300):
|
||||
if isinstance(s, str):
|
||||
pass
|
||||
|
|
@ -117,7 +114,10 @@ def clean_if_name(code: str) -> str:
|
|||
last_block = astree.body[-1]
|
||||
if isinstance(last_block, ast.If):
|
||||
condition = last_block.test
|
||||
if ast.unparse(condition).strip() == "__name__ == '__main__'" or ast.unparse(condition).strip() == "__name__ == \"__main__\"":
|
||||
if (
|
||||
ast.unparse(condition).strip() == "__name__ == '__main__'"
|
||||
or ast.unparse(condition).strip() == '__name__ == "__main__"'
|
||||
):
|
||||
code = (
|
||||
ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) # type: ignore
|
||||
)
|
||||
|
|
@ -181,10 +181,10 @@ def call_method(method, inputs):
|
|||
@patch("sys.stdin.read", lambda *args: inputs)
|
||||
# @patch('sys.stdout.write', print)
|
||||
def _inner_call_method(_method):
|
||||
if 'stdin' in _method.__globals__:
|
||||
_method.__globals__['stdin'] = mock_stdin
|
||||
if 'stdout' in _method.__globals__:
|
||||
_method.__globals__['stdout'] = sys.stdout
|
||||
if "stdin" in _method.__globals__:
|
||||
_method.__globals__["stdin"] = mock_stdin
|
||||
if "stdout" in _method.__globals__:
|
||||
_method.__globals__["stdout"] = sys.stdout
|
||||
try:
|
||||
return _method()
|
||||
except SystemExit as e:
|
||||
|
|
@ -259,11 +259,21 @@ def grade_call_based(
|
|||
if method is None:
|
||||
return
|
||||
|
||||
all_inputs = [[json.loads(line) for line in inputs.split("\n")] if isinstance(inputs, str) else inputs for inputs in all_inputs]
|
||||
all_outputs = [json.loads(output) if isinstance(output, str) else output[0] for output in all_outputs]
|
||||
all_inputs = [
|
||||
(
|
||||
[json.loads(line) for line in inputs.split("\n")]
|
||||
if isinstance(inputs, str)
|
||||
else inputs
|
||||
)
|
||||
for inputs in all_inputs
|
||||
]
|
||||
all_outputs = [
|
||||
json.loads(output) if isinstance(output, str) else output[0]
|
||||
for output in all_outputs
|
||||
]
|
||||
|
||||
#all_inputs = [[json.loads(line) for line in inputs.split("\n")] for inputs in all_inputs]
|
||||
#all_outputs = [json.loads(output) for output in all_outputs]
|
||||
# all_inputs = [[json.loads(line) for line in inputs.split("\n")] for inputs in all_inputs]
|
||||
# all_outputs = [json.loads(output) for output in all_outputs]
|
||||
|
||||
total_execution = 0
|
||||
all_results = []
|
||||
|
|
@ -327,6 +337,7 @@ def grade_call_based(
|
|||
|
||||
return all_results, {"execution time": total_execution}
|
||||
|
||||
|
||||
def decimals_are_close(a: Decimal, b: Decimal, tol: Decimal = Decimal("1e-4")) -> bool:
|
||||
return abs(a - b) <= tol
|
||||
|
||||
|
|
@ -450,7 +461,10 @@ def grade_stdio(
|
|||
if decimal_prediction_line == decimal_gtout_line:
|
||||
continue
|
||||
|
||||
if len(decimal_prediction_line) == len(decimal_gtout_line) and all(decimals_are_close(a, b) for a, b in zip(decimal_prediction_line, decimal_gtout_line)):
|
||||
if len(decimal_prediction_line) == len(decimal_gtout_line) and all(
|
||||
decimals_are_close(a, b)
|
||||
for a, b in zip(decimal_prediction_line, decimal_gtout_line)
|
||||
):
|
||||
continue
|
||||
|
||||
all_results.append(-2)
|
||||
|
|
@ -460,14 +474,22 @@ def grade_stdio(
|
|||
return all_results, {"execution time": total_execution_time}
|
||||
|
||||
|
||||
@app.function(max_containers=100,cpu=1.0, restrict_modal_access=False, max_inputs=1, timeout=600, block_network=False, retries=3)
|
||||
@app.function(
|
||||
max_containers=100,
|
||||
cpu=1.0,
|
||||
restrict_modal_access=False,
|
||||
max_inputs=1,
|
||||
timeout=600,
|
||||
block_network=False,
|
||||
retries=3,
|
||||
)
|
||||
def run_test(sample, test=None, debug=False, timeout=15):
|
||||
"""
|
||||
if test(generated_code) is not None it'll try to run the code.
|
||||
otherwise it'll just return an input and output pair.
|
||||
"""
|
||||
print("VALID, went into function")
|
||||
#test = json.loads(test)
|
||||
# test = json.loads(test)
|
||||
signal.signal(signal.SIGALRM, timeout_handler)
|
||||
|
||||
# Disable functionalities that can make destructive changes to the test.
|
||||
|
|
@ -478,7 +500,7 @@ def run_test(sample, test=None, debug=False, timeout=15):
|
|||
print(f"start = {datetime.now().time()}")
|
||||
|
||||
try:
|
||||
#in_outs = json.loads(sample["input_output"])
|
||||
# in_outs = json.loads(sample["input_output"])
|
||||
in_outs = sample["tests"]
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
|
@ -630,4 +652,3 @@ def reliability_guard(maximum_memory_bytes=None):
|
|||
sys.modules["resource"] = None
|
||||
sys.modules["psutil"] = None
|
||||
sys.modules["tkinter"] = None
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue