"""
KernelBench Environment

Setup Instructions
------------------
Before running this script, you need to install KernelBench:

1. Install KernelBench from source:

   git clone git@github.com:ScalingIntelligence/KernelBench.git
   cd KernelBench
   pip install -r requirements.txt
   pip install -e .
   cd -

2. Set variables at the top of this script:

   KERNELBENCH_LEVEL: The difficulty level (1-3)
   KERNELBENCH_PROBLEM_NUMBER: The specific problem number to solve
   KERNELBENCH_DIR: the absolute path to your KernelBench install

These variables will be used to configure the evaluation environment.
"""

import multiprocessing as mp
import os
from multiprocessing.context import TimeoutError  # SECURITY FIX
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypedDict, Union

from datasets import load_dataset

# KernelBench imports
from src.eval import eval_kernel_against_ref

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item

# Set the start method to 'spawn' for CUDA compatibility: CUDA contexts
# cannot survive a fork, so each worker must start fresh.
mp.set_start_method("spawn", force=True)

KERNELBENCH_DIR = Path("/path/to/KernelBench")
KERNELBENCH_LEVEL = 1
KERNELBENCH_PROBLEM_NUMBER = 1
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"

# SECURITY FIX: bound how long we wait on a single kernel evaluation
EVAL_TIMEOUT_SECONDS = 300


def get_kernelbench_code(level: int, problem_id: int) -> str:
    """Return the ``code`` string for a given KernelBench level/problem combo.

    Args:
        level: KernelBench difficulty level (1-3); selects the dataset split.
        problem_id: Problem number within that level.

    Returns:
        The reference model source for the requested problem.

    Raises:
        ValueError: If ``problem_id`` is not found in that level.
    """
    split = f"level_{level}"
    ds = load_dataset("ScalingIntelligence/KernelBench", split=split)
    # Keep only rows whose `problem_id` exactly matches the desired one
    row = ds.filter(lambda x: x["problem_id"] == problem_id)
    if len(row) == 0:
        raise ValueError(f"{problem_id=} not found in {split=}")
    return row[0]["code"]


class KBRow(TypedDict):
    """Single-task record (prompt text plus meta)."""

    prompt: str  # full prompt given to the LLM
    sample_path: str


def _extract_code(completion: str) -> str:
    """Strip a leading ```python (or bare ```) fence and a trailing ``` fence.

    BUG FIX: the previous code used ``str.strip("```python\\n")``, which
    removes any run of the characters ` p y t h o n \\n from BOTH ends of the
    string -- silently corrupting kernels whose source begins or ends with
    any of those characters.  This helper removes only literal fence markers.
    """
    text = completion.strip()
    if text.startswith("```python"):
        text = text[len("```python"):]
    elif text.startswith("```"):
        text = text[len("```"):]
    if text.endswith("```"):
        text = text[: -len("```")]
    return text.strip()


def evaluate_single_kernel(args):
    """Evaluate one generated kernel against the reference in a worker process.

    Args:
        args: Tuple of ``(item, build_dir, ref_code)`` -- the rollout dict
            (``messages``/``finish_reason``), the directory for compiled
            artifacts, and the reference model source.

    Returns:
        Dict carrying the original ``messages`` and ``finish_reason`` plus a
        scalar ``reward``: 0.3 if the kernel compiled, plus the reported
        runtime value (-1.0 when no runtime attribute is present).
    """
    item, build_dir, ref_code = args
    generated_src = _extract_code(item["messages"][-1]["content"])

    # Initialize CUDA in the child process: the pool uses 'spawn', so each
    # worker must create its own CUDA context.
    import torch

    torch.cuda.init()

    eval_result = eval_kernel_against_ref(
        original_model_src=ref_code,
        custom_model_src=generated_src,
        measure_performance=True,
        verbose=True,
        num_correct_trials=1,
        num_perf_trials=1,
        build_dir=build_dir,
        device="cuda:7",  # NOTE(review): hard-coded GPU index -- confirm this matches the host
    )
    compiled_flag = bool(getattr(eval_result, "compiled", False))
    runtime_val = float(getattr(eval_result, "runtime", -1.0))
    # Reward = compile bonus + raw runtime value reported by the harness.
    reward = 0.3 * (1 if compiled_flag else 0) + runtime_val
    return {
        "messages": item["messages"],
        "finish_reason": item["finish_reason"],
        "reward": reward,
    }


class KernelBenchEnv(BaseEnv):
    """Atropos environment that samples CUDA kernels from an LLM and scores
    them with KernelBench's evaluation harness in a process pool."""

    name = "kernelbench_parallel"

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Build the default environment config and API-server configs."""
        env_cfg = BaseEnvConfig(
            tokenizer_name="Qwen/Qwen3-4B",
            group_size=2,
            max_token_length=2048,
            batch_size=1,
            steps_per_eval=1,
            total_steps=1000,
            rollout_server_url="http://localhost:8000",
            use_wandb=False,
            wandb_name=f"kb_level{KERNELBENCH_LEVEL}_prob{KERNELBENCH_PROBLEM_NUMBER}_parallel",
        )
        server_cfgs = [
            APIServerConfig(
                model_name="Qwen/Qwen3-4B",
                base_url="http://localhost:9001/v1",
                api_key="DUMMY_KB_KEY",
                num_requests_for_eval=64,
            )
        ]
        return env_cfg, server_cfgs

    async def setup(self):
        """Load the prompt, the reference kernel, and start the worker pool."""
        self.problem_spec = {
            "level": KERNELBENCH_LEVEL,
            "problem_id": KERNELBENCH_PROBLEM_NUMBER,
            "problem_file": f"{KERNELBENCH_PROBLEM_NUMBER}_Square_matrix_multiplication_.py",
        }
        self.iter = 0
        with open("prompt.txt", "r", encoding="utf-8") as f:
            self.prompt = f.read()
        self.ref_code = get_kernelbench_code(
            KERNELBENCH_LEVEL, KERNELBENCH_PROBLEM_NUMBER
        )
        self.reward_buffer = list()
        self.pool = mp.Pool(processes=24)

    async def collect_trajectories(
        self, item: KBRow
    ) -> Tuple[ScoredDataGroup, List[Item]]:
        """Sample ``group_size`` kernels for the prompt and score them."""
        user_msg = {"role": "user", "content": self.prompt}

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            chat_completions = await managed.chat_completion(
                messages=[user_msg],
                n=self.config.group_size,
                max_tokens=self.config.max_token_length,
                temperature=0.0,
            )
            state = managed.get_state()
            nodes = state["nodes"]

        # Persist each sample alongside the KernelBench run artifacts.
        # CONSISTENCY FIX: derive the path from the module constants instead
        # of the hard-coded "level_1" / "1" (same value with the defaults).
        run_dir = (
            KERNELBENCH_DIR
            / "runs"
            / "wandb"
            / f"level_{KERNELBENCH_LEVEL}"
            / str(KERNELBENCH_PROBLEM_NUMBER)
        )
        run_dir.mkdir(parents=True, exist_ok=True)

        to_score: List[Dict] = []
        # BUG FIX: annotation was `list()` (a call expression, not a type).
        to_backlog: List[Item] = []
        for i, choice in enumerate(chat_completions.choices):
            kernel_code = choice.message.content
            sample_path = run_dir / f"sample_{i}.cu"
            sample_path.write_text(kernel_code, encoding="utf-8")
            messages = [user_msg, {"role": "assistant", "content": kernel_code}]
            to_score.append(
                {
                    "messages": messages,
                    "finish_reason": choice.finish_reason,
                    "tokens": nodes[i].tokens,
                    "masks": nodes[i].masked_tokens,
                    "logprobs": nodes[i].logprobs,
                }
            )

        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data: List[Dict]
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Evaluate each rollout's kernel in the pool and collect rewards."""
        scores = ScoredDataGroup(tokens=[], masks=[], scores=[], inference_logprobs=[])

        # CONSISTENCY FIX: use the module constants rather than literal 1/1
        # (identical path with the default settings).
        build_dir = os.path.join(
            "build",
            "kernelbench",
            f"{KERNELBENCH_LEVEL}",
            f"{KERNELBENCH_PROBLEM_NUMBER}",
        )
        os.makedirs(build_dir, exist_ok=True)

        eval_args = [(item, build_dir, self.ref_code) for item in rollout_group_data]
        results = [
            self.pool.apply_async(evaluate_single_kernel, args=(args,))
            for args in eval_args
        ]

        for i, result in enumerate(results):
            try:
                # SECURITY FIX: prevent unbounded blocking / hard DoS
                eval_result = result.get(timeout=EVAL_TIMEOUT_SECONDS)
                reward = eval_result["reward"]
            except TimeoutError:
                # Treat timeouts as failed evaluations.
                # NOTE(review): 0.0 outranks a compile failure's -0.7 reward;
                # confirm this ordering is intended.
                reward = 0.0

            scores["tokens"].append(rollout_group_data[i]["tokens"])
            scores["masks"].append(rollout_group_data[i]["masks"])
            scores["inference_logprobs"].append(rollout_group_data[i]["logprobs"])
            scores["scores"].append(reward)
            # Track a non-negative running reward for logging.
            self.reward_buffer.append(max(reward, 0))

        return scores if scores["tokens"] else None

    async def get_next_item(self) -> KBRow:
        """There is a single fixed task; always hand back the same prompt."""
        return KBRow(prompt=self.prompt, sample_path="")

    async def evaluate(self, *args, **kwargs):
        """Fold the accumulated rewards into an eval metric and reset."""
        if self.reward_buffer:
            avg_reward = sum(self.reward_buffer) / len(self.reward_buffer)
            self.eval_metrics.append(("eval/avg_reward", avg_reward))
            self.reward_buffer = list()

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Report the mean training reward, then clear local buffers."""
        if wandb_metrics is None:
            wandb_metrics = {}
        try:
            wandb_metrics["train/reward"] = sum(self.reward_buffer) / len(
                self.reward_buffer
            )
        except ZeroDivisionError:
            # No rewards collected since the last log; skip the metric.
            pass
        self.reward_buffer = list()
        self.eval_metrics = list()
        await super().wandb_log(wandb_metrics)

    async def cleanup(self):
        """Shut down the worker pool before the base-class cleanup."""
        self.pool.close()
        self.pool.join()
        await super().cleanup()


if __name__ == "__main__":
    KernelBenchEnv.cli()