Prevent hangs in kernel evaluation by bounding worker waits
parent 405efa8302
commit 4c4aba108c

1 changed file with 15 additions and 30 deletions
@@ -23,6 +23,7 @@ import multiprocessing as mp
 import os
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, TypedDict, Union
+from multiprocessing.context import TimeoutError  # SECURITY FIX
 
 from datasets import load_dataset
 
@@ -46,6 +47,9 @@ KERNELBENCH_PROBLEM_NUMBER = 1
 
 os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
 
+# SECURITY FIX: bound how long we wait on a single kernel evaluation
+EVAL_TIMEOUT_SECONDS = 300
+
 
 def get_kernelbench_code(level: int, problem_id: int) -> str:
     """
@@ -64,7 +68,7 @@ def get_kernelbench_code(level: int, problem_id: int) -> str:
 
 
 class KBRow(TypedDict):
-    """Single‑task record (prompt text plus meta)."""
+    """Single-task record (prompt text plus meta)."""
 
     prompt: str  # full prompt given to the LLM
     sample_path: str
@@ -95,8 +99,6 @@ def evaluate_single_kernel(args):
     runtime_val = float(getattr(eval_result, "runtime", -1.0))
     reward = 0.3 * (1 if compiled_flag else 0) + runtime_val
 
-    # Note: We can't use the tokenizer here since it's not pickleable
-    # We'll return the raw data and tokenize in the main process
     return {
         "messages": item["messages"],
         "finish_reason": item["finish_reason"],
@@ -132,7 +134,6 @@ class KernelBenchEnv(BaseEnv):
         ]
         return env_cfg, server_cfgs
 
     # --------------------- Data ------------------------------------------------
     async def setup(self):
         self.problem_spec = {
             "level": KERNELBENCH_LEVEL,
@@ -143,24 +144,16 @@ class KernelBenchEnv(BaseEnv):
         with open("prompt.txt", "r", encoding="utf-8") as f:
             self.prompt = f.read()
 
         # Get reference code directly from the dataset
         self.ref_code = get_kernelbench_code(
             KERNELBENCH_LEVEL, KERNELBENCH_PROBLEM_NUMBER
         )
         self.reward_buffer = list()
         # Create a process pool for parallel processing
         self.pool = mp.Pool(processes=24)
 
     # --------------------- Rollout / scoring ----------------------------------
     async def collect_trajectories(
         self, item: KBRow
     ) -> Tuple[ScoredDataGroup, List[Item]]:
         """
         Ask the LLM `group_size` times; each completion should be *only* the
         CUDA / Triton kernel (per KernelBench docs). We store them to
         runs/{run_name}/{level}/{id}/sample_<n>.cu so that the official
         evaluator picks them up.
         """
 
         user_msg = {"role": "user", "content": self.prompt}
 
         async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
@@ -174,7 +167,6 @@ class KernelBenchEnv(BaseEnv):
             state = managed.get_state()
             nodes = state["nodes"]
 
             # Path: runs/<RUN_NAME>/level_1/1/
             run_dir = KERNELBENCH_DIR / "runs" / "wandb" / "level_1" / "1"
             run_dir.mkdir(parents=True, exist_ok=True)
 
@@ -183,7 +175,7 @@ class KernelBenchEnv(BaseEnv):
             for i, choice in enumerate(chat_completions.choices):
                 kernel_code = choice.message.content
                 sample_path = run_dir / f"sample_{i}.cu"
-                sample_path.write_text(kernel_code, encoding="utf‑8")
+                sample_path.write_text(kernel_code, encoding="utf-8")
 
                 messages = (user_msg, {"role": "assistant", "content": kernel_code})
                 to_score.append(
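Aside: the one-character change in this hunk fixes a real crash, not just typography. Python resolves `encoding=` names through the codec registry, and the U+2011 non-breaking hyphen in "utf‑8" makes the name unknown, so `write_text` raised `LookupError` before the fix. (The docstring edit earlier in the diff swaps the same character, but there it was purely cosmetic.) A minimal illustration, using a hypothetical path:

from pathlib import Path

p = Path("/tmp/sample_0.cu")  # hypothetical path, for illustration only
p.write_text("// kernel", encoding="utf-8")  # ASCII hyphen: known codec name

try:
    p.write_text("// kernel", encoding="utf‑8")  # U+2011 hyphen in the name
except LookupError as err:
    print(err)  # unknown encoding: utf‑8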
@@ -205,25 +197,25 @@ class KernelBenchEnv(BaseEnv):
     ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
         scores = ScoredDataGroup(tokens=[], masks=[], scores=[], inference_logprobs=[])
 
         # where we will build + compile kernels
         build_dir = os.path.join("build", "kernelbench", f"{1}", f"{1}")
         os.makedirs(build_dir, exist_ok=True)
 
         # Create arguments for parallel evaluation
         eval_args = [(item, build_dir, self.ref_code) for item in rollout_group_data]
 
         # Run evaluations in parallel
         results = []
         for args in eval_args:
             result = self.pool.apply_async(evaluate_single_kernel, args=(args,))
             results.append(result)
 
         # Wait for all evaluations to complete and process results
         for i, result in enumerate(results):
-            eval_result = result.get()  # This will wait for the result
-            reward = eval_result["reward"]
+            try:
+                # SECURITY FIX: prevent unbounded blocking / hard DoS
+                eval_result = result.get(timeout=EVAL_TIMEOUT_SECONDS)
+                reward = eval_result["reward"]
+            except TimeoutError:
+                # Treat timeouts as failed evaluations
+                reward = 0.0
 
             # Use tokens, masks, and logprobs from managed_server nodes
             tokens = rollout_group_data[i]["tokens"]
             masks = rollout_group_data[i]["masks"]
             logprobs = rollout_group_data[i]["logprobs"]
@@ -237,13 +229,9 @@ class KernelBenchEnv(BaseEnv):
         return scores if scores["tokens"] else None
 
     async def get_next_item(self) -> KBRow:
         """Return the same single problem every time (env is tiny)."""
-        return KBRow(
-            prompt=self.prompt, sample_path=""
-        )  # sample_path is no longer used
+        return KBRow(prompt=self.prompt, sample_path="")
 
     async def evaluate(self, *args, **kwargs):
         """Evaluate the current model on a set of test problems."""
         if self.reward_buffer:
             avg_reward = sum(self.reward_buffer) / len(self.reward_buffer)
             self.eval_metrics.append(("eval/avg_reward", avg_reward))
@@ -265,13 +253,10 @@ class KernelBenchEnv(BaseEnv):
         await super().wandb_log(wandb_metrics)
 
     async def cleanup(self):
         """Clean up resources when done."""
         self.pool.close()
         self.pool.join()
         await super().cleanup()
 
 
 # -----------------------------------------------------------------------------
 
 if __name__ == "__main__":
     KernelBenchEnv.cli()
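For reference, here is the bounded-wait pattern the commit adopts, reduced to a runnable sketch. `fake_eval`, the pool size, and the 2-second deadline are illustrative stand-ins (the environment's real worker is `evaluate_single_kernel` with a 300-second budget); only the `apply_async` / `get(timeout=...)` / `TimeoutError` plumbing mirrors the diff above:

import multiprocessing as mp
import time
from multiprocessing.context import TimeoutError  # same import the commit adds

EVAL_TIMEOUT_SECONDS = 2  # illustrative; the environment uses 300


def fake_eval(seconds: float) -> dict:
    """Stand-in for evaluate_single_kernel: sleep, then report success."""
    time.sleep(seconds)
    return {"reward": 1.0}


if __name__ == "__main__":
    with mp.Pool(processes=2) as pool:
        # One task finishes quickly; the other would hang past the deadline.
        results = [pool.apply_async(fake_eval, args=(s,)) for s in (0.1, 60)]

        rewards = []
        for result in results:
            try:
                # Bounded wait: raises TimeoutError instead of blocking forever.
                rewards.append(result.get(timeout=EVAL_TIMEOUT_SECONDS)["reward"])
            except TimeoutError:
                # Same policy as the commit: a timeout counts as a failed eval.
                rewards.append(0.0)

        print(rewards)  # [1.0, 0.0]

Worth noting: `AsyncResult.get(timeout=...)` only bounds the caller's wait. A genuinely hung worker keeps running and occupies its pool slot after the timeout; in the sketch the `with` block's implicit `Pool.terminate()` reclaims it on exit, whereas the environment's `cleanup()` uses `close()`/`join()`, which waits for workers to finish.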