""" KernelBench Environment Setup Instructions ---------------------------------------- Before running this script, you need to install KernelBench: 1. Install KernelBench from source: pip install git@github.com:ScalingIntelligence/KernelBench.git cd KernelBench pip install -r requirements.txt pip install -e . cd - 2. Set variables at the top of this script: KERNELBENCH_LEVEL: The difficulty level (1-3) KERNELBENCH_PROBLEM_NUMBER: The specific problem number to solve KERNELBENCH_DIR: the absolute path to your KernelBench install These environment variables will be used to configure the evaluation environment. """ import multiprocessing as mp import os from multiprocessing.context import TimeoutError # SECURITY FIX from pathlib import Path from typing import Dict, List, Optional, Tuple, TypedDict, Union from datasets import load_dataset # KernelBench imports from src.eval import eval_kernel_against_ref from atroposlib.envs.base import ( APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataGroup, ) from atroposlib.type_definitions import Item # Set the start method to 'spawn' for CUDA compatibility mp.set_start_method("spawn", force=True) KERNELBENCH_DIR = Path("/path/to/KernelBench") KERNELBENCH_LEVEL = 1 KERNELBENCH_PROBLEM_NUMBER = 1 os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" # SECURITY FIX: bound how long we wait on a single kernel evaluation EVAL_TIMEOUT_SECONDS = 300 def get_kernelbench_code(level: int, problem_id: int) -> str: """ Return the `code` string for a given KernelBench level/problem combo. Raises ValueError if the problem_id is not found in that level. """ split = f"level_{level}" ds = load_dataset("ScalingIntelligence/KernelBench", split=split) # Keep only rows whose `problem_id` exactly matches the desired one row = ds.filter(lambda x: x["problem_id"] == problem_id) if len(row) == 0: raise ValueError(f"{problem_id=} not found in {split=}") return row[0]["code"] class KBRow(TypedDict): """Single-task record (prompt text plus meta).""" prompt: str # full prompt given to the LLM sample_path: str def evaluate_single_kernel(args): """Helper function to evaluate a single kernel in a process.""" item, build_dir, ref_code = args generated_src = item["messages"][-1]["content"].strip("```python\n").strip("```") # Initialize CUDA in the child process import torch torch.cuda.init() eval_result = eval_kernel_against_ref( original_model_src=ref_code, custom_model_src=generated_src, measure_performance=True, verbose=True, num_correct_trials=1, num_perf_trials=1, build_dir=build_dir, device="cuda:7", ) compiled_flag = bool(getattr(eval_result, "compiled", False)) runtime_val = float(getattr(eval_result, "runtime", -1.0)) reward = 0.3 * (1 if compiled_flag else 0) + runtime_val return { "messages": item["messages"], "finish_reason": item["finish_reason"], "reward": reward, } class KernelBenchEnv(BaseEnv): name = "kernelbench_parallel" @classmethod def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: env_cfg = BaseEnvConfig( tokenizer_name="Qwen/Qwen3-4B", group_size=2, max_token_length=2048, batch_size=1, steps_per_eval=1, total_steps=1000, rollout_server_url="http://localhost:8000", use_wandb=False, wandb_name=f"kb_level{KERNELBENCH_LEVEL}_prob{KERNELBENCH_PROBLEM_NUMBER}_parallel", ) server_cfgs = [ APIServerConfig( model_name="Qwen/Qwen3-4B", base_url="http://localhost:9001/v1", api_key="DUMMY_KB_KEY", num_requests_for_eval=64, ) ] return env_cfg, server_cfgs async def setup(self): self.problem_spec = { "level": KERNELBENCH_LEVEL, "problem_id": KERNELBENCH_PROBLEM_NUMBER, "problem_file": f"{KERNELBENCH_PROBLEM_NUMBER}_Square_matrix_multiplication_.py", } self.iter = 0 with open("prompt.txt", "r", encoding="utf-8") as f: self.prompt = f.read() self.ref_code = get_kernelbench_code( KERNELBENCH_LEVEL, KERNELBENCH_PROBLEM_NUMBER ) self.reward_buffer = list() self.pool = mp.Pool(processes=24) async def collect_trajectories( self, item: KBRow ) -> Tuple[ScoredDataGroup, List[Item]]: user_msg = {"role": "user", "content": self.prompt} async with self.server.managed_server(tokenizer=self.tokenizer) as managed: chat_completions = await managed.chat_completion( messages=[user_msg], n=self.config.group_size, max_tokens=self.config.max_token_length, temperature=0.0, ) state = managed.get_state() nodes = state["nodes"] run_dir = KERNELBENCH_DIR / "runs" / "wandb" / "level_1" / "1" run_dir.mkdir(parents=True, exist_ok=True) to_score: List[Dict] = [] to_backlog: list() = [] for i, choice in enumerate(chat_completions.choices): kernel_code = choice.message.content sample_path = run_dir / f"sample_{i}.cu" sample_path.write_text(kernel_code, encoding="utf-8") messages = (user_msg, {"role": "assistant", "content": kernel_code}) to_score.append( { "messages": messages, "finish_reason": choice.finish_reason, "tokens": nodes[i].tokens, "masks": nodes[i].masked_tokens, "logprobs": nodes[i].logprobs, } ) to_postprocess = await self.score(to_score) return to_postprocess, to_backlog async def score( self, rollout_group_data: List[Dict] ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: scores = ScoredDataGroup(tokens=[], masks=[], scores=[], inference_logprobs=[]) build_dir = os.path.join("build", "kernelbench", f"{1}", f"{1}") os.makedirs(build_dir, exist_ok=True) eval_args = [(item, build_dir, self.ref_code) for item in rollout_group_data] results = [] for args in eval_args: result = self.pool.apply_async(evaluate_single_kernel, args=(args,)) results.append(result) for i, result in enumerate(results): try: # SECURITY FIX: prevent unbounded blocking / hard DoS eval_result = result.get(timeout=EVAL_TIMEOUT_SECONDS) reward = eval_result["reward"] except TimeoutError: # Treat timeouts as failed evaluations reward = 0.0 tokens = rollout_group_data[i]["tokens"] masks = rollout_group_data[i]["masks"] logprobs = rollout_group_data[i]["logprobs"] scores["tokens"].append(tokens) scores["masks"].append(masks) scores["inference_logprobs"].append(logprobs) scores["scores"].append(reward) self.reward_buffer.append(max(reward, 0)) return scores if scores["tokens"] else None async def get_next_item(self) -> KBRow: return KBRow(prompt=self.prompt, sample_path="") async def evaluate(self, *args, **kwargs): if self.reward_buffer: avg_reward = sum(self.reward_buffer) / len(self.reward_buffer) self.eval_metrics.append(("eval/avg_reward", avg_reward)) self.reward_buffer = list() async def wandb_log(self, wandb_metrics: Optional[Dict] = None): if wandb_metrics is None: wandb_metrics = {} try: wandb_metrics["train/reward"] = sum(self.reward_buffer) / len( self.reward_buffer ) except ZeroDivisionError: pass self.reward_buffer = list() self.eval_metrics = list() await super().wandb_log(wandb_metrics) async def cleanup(self): self.pool.close() self.pool.join() await super().cleanup() if __name__ == "__main__": KernelBenchEnv.cli()