diff --git a/environments/kernelbench_env.py b/environments/kernelbench_env.py
new file mode 100644
index 00000000..c84e8d0f
--- /dev/null
+++ b/environments/kernelbench_env.py
@@ -0,0 +1,205 @@
+# kernelbench_env.py
+import os
+import json
+import subprocess
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, TypedDict, Union
+
+from atroposlib.envs.base import BaseEnv, BaseEnvConfig, APIServerConfig, ScoredDataGroup
+from atroposlib.type_definitions import Item, number
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+
+# KernelBench imports
+from src.eval import eval_kernel_against_ref  # <- new import
+
+KERNELBENCH_DIR = Path("/path/to/KernelBench")  # <- point here to your clone
+
+
+class KBRow(TypedDict):
+    """Single-task record (prompt text plus meta)."""
+    prompt: str  # full prompt given to the LLM
+    sample_path: str
+
+
+class KBEnv(BaseEnv):
+    """
+    A stripped-down Atropos environment that only handles Level-1 / problem 1
+    (square matrix multiplication). It generates a group of candidate kernels
+    per rollout step, writes them to the expected `runs/{run_name}` layout,
+    then invokes KernelBench's evaluator to obtain a scalar reward.
+    """
+
+    name = "kernelbench"
+
+    # ---------- Static config helpers ----------------------------------------
+    @classmethod
+    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
+        env_cfg = BaseEnvConfig(
+            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
+            group_size=4,  # 4 candidate kernels per step
+            max_token_length=2048,
+            batch_size=4,
+            steps_per_eval=50,
+            total_steps=1000,
+            rollout_server_url="http://localhost:8000",
+            use_wandb=False,  # flip on if you want logging
+            wandb_name="kb_level1_prob1",
+        )
+
+        server_cfgs = [
+            APIServerConfig(
+                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
+                base_url="http://localhost:9001/v1",
+                api_key="DUMMY_KB_KEY",  # fill in if your proxy requires it
+                num_requests_for_eval=64,
+            )
+        ]
+        return env_cfg, server_cfgs
+
+    # --------------------- Data -----------------------------------------------
+    async def setup(self):
+        """
+        The single prompt is read from `prompt.txt`, which is expected to be
+        pre-built with KernelBench's prompt constructor so that it exactly
+        matches their evaluation format.
+        """
+        # Hard-code the spec for Level-1 / problem 1
+        self.problem_spec = {
+            "level": 1,
+            "problem_id": 1,
+            "problem_file": "1_Square_matrix_multiplication_.py",
+        }
+        self.iter = 0
+        with open("prompt.txt", "r", encoding="utf-8") as f:
+            self.prompt = f.read()
+
+        self.sample_path = "./sample.py"
+        self.reward_buffer = list()
+        self.eval_metrics = list()  # consumed by evaluate() / wandb_log()
+
+    # --------------------- Rollout / scoring ----------------------------------
+    async def collect_trajectories(
+        self, item: KBRow
+    ) -> Tuple[ScoredDataGroup, List[Item]]:
+        """
+        Ask the LLM `group_size` times; each completion should be *only* the
+        CUDA / Triton kernel (per KernelBench docs). We store them to
+        runs/{run_name}/{level}/{id}/sample_{i}.cu so that the official
+        evaluator picks them up.
+        """
+        user_msg = {"role": "user", "content": self.prompt}
+
+        chat_completions = await self.server.chat_completion(
+            messages=[user_msg],
+            n=self.config.group_size,
+            max_tokens=self.config.max_token_length,
+            temperature=0.0,
+        )
+
+        # Path: runs/{run_name}/level_1/1/
+        run_dir = KERNELBENCH_DIR / "runs" / self.config.wandb_name / "level_1" / "1"
+        run_dir.mkdir(parents=True, exist_ok=True)
+
+        to_score: List[Dict] = []
+        to_backlog: List[Item] = []
+        for i, choice in enumerate(chat_completions.choices):
+            kernel_code = choice.message.content
+            sample_path = run_dir / f"sample_{i}.cu"
+            sample_path.write_text(kernel_code, encoding="utf-8")
+
+            messages = (user_msg, {"role": "assistant", "content": kernel_code})
+            to_score.append(
+                {
+                    "messages": messages,
+                    "sample_path": str(sample_path),
+                    "finish_reason": choice.finish_reason,
+                }
+            )
+
+        to_postprocess = await self.score(to_score)
+
+        return to_postprocess, to_backlog
+
+    async def score(
+        self, rollout_group_data: List[Dict]
+    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
+        scores = ScoredDataGroup(tokens=[], masks=[], scores=[])
+        scores["tokens"] = list()
+        scores["masks"] = list()
+        scores["scores"] = list()
+
+        # where we will build + compile kernels
+        build_dir = os.path.join("build", "kernelbench", f"{1}", f"{1}")
+        os.makedirs(build_dir, exist_ok=True)
+
+        for item in rollout_group_data:
+            # Reference source is the prompt we sent to the model; the
+            # candidate kernel is what collect_trajectories wrote to disk.
+            ref_src = item["messages"][0]["content"]
+            custom_model_src = Path(item["sample_path"]).read_text()
+
+            eval_result = eval_kernel_against_ref(
+                ref_arch_src=ref_src,
+                custom_model_src=custom_model_src,
+                measure_performance=True,
+                verbose=True,
+                num_correct_trials=1,
+                num_perf_trials=1,
+                build_dir=build_dir,
+            )
+
+            compiled_flag = bool(getattr(eval_result, "compiled", False))
+            runtime_val = float(getattr(eval_result, "runtime", -1.0))
+
+            # 0.3 bonus if the kernel compiled, plus the measured runtime value
+            reward = 0.3 * (1 if compiled_flag else 0) + runtime_val
+
+            out_dict = tokenize_for_trainer(
+                self.tokenizer, item["messages"], item["finish_reason"]
+            )
+
+            scores["tokens"].append(out_dict["tokens"])
+            scores["masks"].append(out_dict["masks"])
+            scores["scores"].append(reward)
+
+        for score in scores["scores"]:
+            self.reward_buffer.append(max(score, 0))
+
+        return scores if scores["tokens"] else None
+
+    async def get_next_item(self) -> KBRow:
+        """Return the same single problem every time (env is tiny)."""
+        return KBRow(prompt=self.prompt, sample_path=self.sample_path)
+
+    async def evaluate(self, *args, **kwargs):
+        """Evaluate the current model on a set of test problems."""
+        # For now, we'll just log the average reward from the reward buffer
+        if self.reward_buffer:
+            avg_reward = sum(self.reward_buffer) / len(self.reward_buffer)
+            self.eval_metrics.append(("eval/avg_reward", avg_reward))
+            self.reward_buffer = list()
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Log the mean training reward; skip if the buffer is empty
+        try:
+            wandb_metrics["train/reward"] = sum(
+                self.reward_buffer
+            ) / len(self.reward_buffer)
+        except ZeroDivisionError:
+            pass
+
+        self.reward_buffer = list()
+        for item in self.eval_metrics:
+            wandb_metrics[item[0]] = item[1]
+        self.eval_metrics = list()
+        # Call the parent method to handle the server metrics
+        await super().wandb_log(wandb_metrics)
+
+
+# -----------------------------------------------------------------------------
+
+
+if __name__ == "__main__":
+    KBEnv.cli()
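Not part of the diff: a minimal sketch of the reward shaping used in `KBEnv.score()`, assuming `eval_kernel_against_ref` returns an object exposing `compiled` and `runtime` attributes (the two attributes the environment reads via `getattr`). The helper name `shape_reward` is hypothetical.

    def shape_reward(eval_result) -> float:
        # 0.3 bonus when the candidate kernel compiled successfully
        compiled = 1.0 if getattr(eval_result, "compiled", False) else 0.0
        # measured runtime, with -1.0 as the fallback when none was reported
        runtime = float(getattr(eval_result, "runtime", -1.0))
        return 0.3 * compiled + runtime

Note that `score()` clamps the values pushed into `reward_buffer` with `max(score, 0)`, so the -1.0 fallback never drags the logged averages below zero; the raw, unclamped rewards are what get returned to the trainer.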