"""
|
|
KernelBench Environment Setup Instructions
|
|
----------------------------------------
|
|
|
|
Before running this script, you need to install KernelBench:
|
|
|
|
1. Install KernelBench from source:
|
|
pip install git@github.com:ScalingIntelligence/KernelBench.git
|
|
cd KernelBench
|
|
pip install -r requirements.txt
|
|
pip install -e .
|
|
cd -
|
|
|
|
2. Set variables at the top of this script:
|
|
KERNELBENCH_LEVEL: The difficulty level (1-3)
|
|
KERNELBENCH_PROBLEM_NUMBER: The specific problem number to solve
|
|
KERNELBENCH_DIR: the absolute path to your KernelBench install
|
|
|
|
These environment variables will be used to configure the evaluation environment.
|
|
"""

import multiprocessing as mp
import os
from multiprocessing.context import TimeoutError  # SECURITY FIX
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypedDict, Union

from datasets import load_dataset

# KernelBench imports (require the KernelBench install described in the module docstring)
from src.eval import eval_kernel_against_ref

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item

# Set the start method to 'spawn' for CUDA compatibility
mp.set_start_method("spawn", force=True)

KERNELBENCH_DIR = Path("/path/to/KernelBench")
KERNELBENCH_LEVEL = 1
KERNELBENCH_PROBLEM_NUMBER = 1

# Build CUDA extensions for compute capability 9.0 (Hopper); adjust for your GPU.
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"

# SECURITY FIX: bound how long we wait on a single kernel evaluation
EVAL_TIMEOUT_SECONDS = 300


def get_kernelbench_code(level: int, problem_id: int) -> str:
    """
    Return the `code` string for a given KernelBench level/problem combo.

    Raises ValueError if the problem_id is not found in that level.
    """
    split = f"level_{level}"
    ds = load_dataset("ScalingIntelligence/KernelBench", split=split)

    # Keep only rows whose `problem_id` exactly matches the desired one
    row = ds.filter(lambda x: x["problem_id"] == problem_id)
    if len(row) == 0:
        raise ValueError(f"{problem_id=} not found in {split=}")

    return row[0]["code"]
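

# Illustrative usage (a minimal sketch; requires network access to the
# HuggingFace Hub so the dataset can be downloaded):
#     ref_code = get_kernelbench_code(level=1, problem_id=1)
# returns the reference implementation source for level-1 problem 1.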


class KBRow(TypedDict):
    """Single-task record (prompt text plus meta)."""

    prompt: str  # full prompt given to the LLM
    sample_path: str


def evaluate_single_kernel(args):
    """Helper function to evaluate a single kernel in a worker process."""
    item, build_dir, ref_code = args

    # Strip an optional ```python ... ``` markdown fence around the generated code.
    # (str.strip("```python\n") removes *characters*, not a prefix, and can eat
    # leading/trailing code, so the fence is removed explicitly instead.)
    generated_src = item["messages"][-1]["content"].strip()
    if generated_src.startswith("```"):
        generated_src = generated_src.split("\n", 1)[1] if "\n" in generated_src else ""
    if generated_src.endswith("```"):
        generated_src = generated_src[:-3]

    # Initialize CUDA in the child process
    import torch

    torch.cuda.init()

    eval_result = eval_kernel_against_ref(
        original_model_src=ref_code,
        custom_model_src=generated_src,
        measure_performance=True,
        verbose=True,
        num_correct_trials=1,
        num_perf_trials=1,
        build_dir=build_dir,
        device="cuda:7",  # hardcoded GPU index; point this at a device you actually have
    )

    # Simple reward shaping: 0.3 for a successful compile, plus the runtime
    # reported by KernelBench (-1.0 when no runtime is available).
    compiled_flag = bool(getattr(eval_result, "compiled", False))
    runtime_val = float(getattr(eval_result, "runtime", -1.0))
    reward = 0.3 * (1 if compiled_flag else 0) + runtime_val

    return {
        "messages": item["messages"],
        "finish_reason": item["finish_reason"],
        "reward": reward,
    }
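

# Worked example of the reward above (numbers are hypothetical): a kernel that
# compiles and reports a runtime of 0.42 scores 0.3 * 1 + 0.42 = 0.72, while one
# that fails to compile and reports no runtime scores 0.3 * 0 + (-1.0) = -1.0.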


class KernelBenchEnv(BaseEnv):

    name = "kernelbench_parallel"

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        env_cfg = BaseEnvConfig(
            tokenizer_name="Qwen/Qwen3-4B",
            group_size=2,
            max_token_length=2048,
            batch_size=1,
            steps_per_eval=1,
            total_steps=1000,
            rollout_server_url="http://localhost:8000",
            use_wandb=False,
            wandb_name=f"kb_level{KERNELBENCH_LEVEL}_prob{KERNELBENCH_PROBLEM_NUMBER}_parallel",
        )

        server_cfgs = [
            APIServerConfig(
                model_name="Qwen/Qwen3-4B",
                base_url="http://localhost:9001/v1",
                api_key="DUMMY_KB_KEY",
                num_requests_for_eval=64,
            )
        ]
        return env_cfg, server_cfgs
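
    # NOTE: the APIServerConfig above assumes an OpenAI-compatible inference
    # server (e.g. vLLM or SGLang) is already serving Qwen/Qwen3-4B at
    # http://localhost:9001/v1; point base_url / model_name at your own endpoint.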

    async def setup(self):
        self.problem_spec = {
            "level": KERNELBENCH_LEVEL,
            "problem_id": KERNELBENCH_PROBLEM_NUMBER,
            # File name of the level-1 square-matrix-multiplication problem;
            # change this if you point the constants at a different problem.
            "problem_file": f"{KERNELBENCH_PROBLEM_NUMBER}_Square_matrix_multiplication_.py",
        }
        self.iter = 0
        with open("prompt.txt", "r", encoding="utf-8") as f:
            self.prompt = f.read()

        self.ref_code = get_kernelbench_code(
            KERNELBENCH_LEVEL, KERNELBENCH_PROBLEM_NUMBER
        )
        self.reward_buffer = list()
        self.pool = mp.Pool(processes=24)
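        # NOTE: the pool size of 24 is arbitrary; each worker compiles and
        # benchmarks one candidate kernel, so size the pool to the CPU cores
        # (and GPU memory) you can spare.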

    async def collect_trajectories(
        self, item: KBRow
    ) -> Tuple[ScoredDataGroup, List[Item]]:

        user_msg = {"role": "user", "content": self.prompt}

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            # Ask the inference server for group_size completions of the same prompt
            chat_completions = await managed.chat_completion(
                messages=[user_msg],
                n=self.config.group_size,
                max_tokens=self.config.max_token_length,
                temperature=0.0,
            )

            state = managed.get_state()
            nodes = state["nodes"]

        run_dir = KERNELBENCH_DIR / "runs" / "wandb" / "level_1" / "1"
        run_dir.mkdir(parents=True, exist_ok=True)

        to_score: List[Dict] = []
        to_backlog: List[Item] = []
        for i, choice in enumerate(chat_completions.choices):
            kernel_code = choice.message.content
            sample_path = run_dir / f"sample_{i}.cu"
            sample_path.write_text(kernel_code, encoding="utf-8")

            messages = (user_msg, {"role": "assistant", "content": kernel_code})
            to_score.append(
                {
                    "messages": messages,
                    "finish_reason": choice.finish_reason,
                    "tokens": nodes[i].tokens,
                    "masks": nodes[i].masked_tokens,
                    "logprobs": nodes[i].logprobs,
                }
            )

        to_postprocess = await self.score(to_score)

        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data: List[Dict]
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        scores = ScoredDataGroup(tokens=[], masks=[], scores=[], inference_logprobs=[])

        build_dir = os.path.join(
            "build", "kernelbench", str(KERNELBENCH_LEVEL), str(KERNELBENCH_PROBLEM_NUMBER)
        )
        os.makedirs(build_dir, exist_ok=True)

        eval_args = [(item, build_dir, self.ref_code) for item in rollout_group_data]

        # Evaluate all candidate kernels in parallel worker processes
        results = []
        for args in eval_args:
            result = self.pool.apply_async(evaluate_single_kernel, args=(args,))
            results.append(result)

        for i, result in enumerate(results):
            try:
                # SECURITY FIX: prevent unbounded blocking / hard DoS
                eval_result = result.get(timeout=EVAL_TIMEOUT_SECONDS)
                reward = eval_result["reward"]
            except TimeoutError:
                # Treat timeouts as failed evaluations
                reward = 0.0

            tokens = rollout_group_data[i]["tokens"]
            masks = rollout_group_data[i]["masks"]
            logprobs = rollout_group_data[i]["logprobs"]

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["inference_logprobs"].append(logprobs)
            scores["scores"].append(reward)
            self.reward_buffer.append(max(reward, 0))

        return scores if scores["tokens"] else None

    async def get_next_item(self) -> KBRow:
        return KBRow(prompt=self.prompt, sample_path="")

    async def evaluate(self, *args, **kwargs):
        if self.reward_buffer:
            avg_reward = sum(self.reward_buffer) / len(self.reward_buffer)
            self.eval_metrics.append(("eval/avg_reward", avg_reward))
            self.reward_buffer = list()

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}

        try:
            wandb_metrics["train/reward"] = sum(self.reward_buffer) / len(
                self.reward_buffer
            )
        except ZeroDivisionError:
            pass

        self.reward_buffer = list()
        self.eval_metrics = list()
        await super().wandb_log(wandb_metrics)

    async def cleanup(self):
        self.pool.close()
        self.pool.join()
        await super().cleanup()
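

# How this environment is typically launched (a sketch; subcommand names and
# flags come from atroposlib's BaseEnv CLI and may differ across versions):
#     python path/to/this_file.py process   # generate and score rollouts locally
#     python path/to/this_file.py serve     # attach to a running rollout server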


if __name__ == "__main__":
    KernelBenchEnv.cli()