mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-27 17:23:08 +00:00
pre-commit. add KB path as variable
This commit is contained in:
parent fbfe06771b
commit 9875f5dc06
1 changed file with 42 additions and 30 deletions
@@ -14,36 +14,40 @@ Before running this script, you need to install KernelBench:
 2. Set variables at the top of this script:
    KERNELBENCH_LEVEL: The difficulty level (1-3)
    KERNELBENCH_PROBLEM_NUMBER: The specific problem number to solve
+   KERNELBENCH_DIR: the absolute path to your KernelBench install

 These environment variables will be used to configure the evaluation environment.
 """

-from datasets import load_dataset
-
-import os
-import json
-import subprocess
-import asyncio
-import multiprocessing as mp
-import os
+import asyncio
+import json
+import multiprocessing as mp
+import os
+import subprocess
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, TypedDict, Union

-from atroposlib.envs.base import BaseEnv, BaseEnvConfig, APIServerConfig, ScoredDataGroup
-from atroposlib.type_definitions import Item, number
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    ScoredDataGroup,
+)
+from atroposlib.type_definitions import Item
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+from datasets import load_dataset

 # KernelBench imports
 from src.eval import eval_kernel_against_ref

 # Set the start method to 'spawn' for CUDA compatibility
-mp.set_start_method('spawn', force=True)
+mp.set_start_method("spawn", force=True)

-KERNELBENCH_DIR = Path("/home/artem_nous/KernelBench")
+KERNELBENCH_DIR = Path("/path/to/KernelBench")
 KERNELBENCH_LEVEL = 1
 KERNELBENCH_PROBLEM_NUMBER = 1

 os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"


 def get_kernelbench_code(level: int, problem_id: int) -> str:
     """
     Return the `code` string for a given KernelBench level/problem combo.
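Review note on the configuration block above: despite the docstring's mention of environment variables, KERNELBENCH_DIR, KERNELBENCH_LEVEL, and KERNELBENCH_PROBLEM_NUMBER are module-level constants edited in place. A minimal sketch of reading the same settings from actual environment variables, should that behavior be wanted (the fallback defaults are illustrative):

import os
from pathlib import Path

# Falls back to the placeholder values from the diff when the variables are unset.
KERNELBENCH_DIR = Path(os.environ.get("KERNELBENCH_DIR", "/path/to/KernelBench"))
KERNELBENCH_LEVEL = int(os.environ.get("KERNELBENCH_LEVEL", "1"))
KERNELBENCH_PROBLEM_NUMBER = int(os.environ.get("KERNELBENCH_PROBLEM_NUMBER", "1"))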
@@ -59,11 +63,14 @@ def get_kernelbench_code(level: int, problem_id: int) -> str:
     return row[0]["code"]

+
 class KBRow(TypedDict):
     """Single‑task record (prompt text plus meta)."""
-    prompt: str  # full prompt given to the LLM
+
+    prompt: str  # full prompt given to the LLM
     sample_path: str

+
 def evaluate_single_kernel(args):
     """Helper function to evaluate a single kernel in a process."""
     item, build_dir, ref_code = args
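Since KBRow is a TypedDict, an instance is an ordinary dict whose keys are checked statically. A self-contained usage sketch (the prompt text is a placeholder):

from typing import TypedDict

class KBRow(TypedDict):
    """Single-task record (prompt text plus meta)."""

    prompt: str  # full prompt given to the LLM
    sample_path: str

row = KBRow(prompt="<full prompt for the LLM>", sample_path="")
print(row["prompt"])  # a TypedDict is a plain dict at runtime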
@@ -71,6 +78,7 @@ def evaluate_single_kernel(args):
     # Initialize CUDA in the child process
     import torch
+
     torch.cuda.init()

     eval_result = eval_kernel_against_ref(
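The in-child torch import together with the module-level mp.set_start_method("spawn", force=True) is the standard pattern for avoiding inherited CUDA state across fork(). A runnable sketch of the same pattern, assuming PyTorch is installed and at least one GPU is visible (the device index 0 is illustrative):

import multiprocessing as mp

def cuda_worker(device_index):
    import torch  # imported in the child so CUDA initializes fresh there
    torch.cuda.init()
    return torch.cuda.get_device_name(device_index)

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    with mp.Pool(processes=2) as pool:
        # Each worker initializes its own CUDA context before touching the device.
        print(pool.map(cuda_worker, [0, 0]))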
@@ -81,7 +89,7 @@ def evaluate_single_kernel(args):
         num_correct_trials=1,
         num_perf_trials=1,
         build_dir=build_dir,
-        device="cuda:7"
+        device="cuda:7",
     )

     compiled_flag = bool(getattr(eval_result, "compiled", False))
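The reward is derived from the eval result through getattr with a default, so a missing attribute degrades gracefully instead of raising. A sketch of that shaping with a stub result object; only the "compiled" attribute appears in this diff, while the "correctness" attribute and the partial-credit value are assumptions for illustration:

class FakeEvalResult:
    # Stand-in for the object eval_kernel_against_ref returns.
    compiled = True
    correctness = False  # assumed attribute name, not confirmed by this diff

eval_result = FakeEvalResult()
compiled_flag = bool(getattr(eval_result, "compiled", False))
correct_flag = bool(getattr(eval_result, "correctness", False))
if compiled_flag and correct_flag:
    reward = 1.0
elif compiled_flag:
    reward = 0.1  # illustrative partial credit for compiling but failing checks
else:
    reward = 0.0
print(reward)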
@@ -93,9 +101,10 @@ def evaluate_single_kernel(args):
     return {
         "messages": item["messages"],
         "finish_reason": item["finish_reason"],
-        "reward": reward
+        "reward": reward,
     }


 class KernelBenchEnv(BaseEnv):
+
     name = "kernelbench_parallel"
@@ -104,13 +113,13 @@ class KernelBenchEnv(BaseEnv):
     def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
         env_cfg = BaseEnvConfig(
             tokenizer_name="Qwen/Qwen3-4B",
-            group_size=2,
+            group_size=2,
             max_token_length=2048,
             batch_size=1,
             steps_per_eval=1,
             total_steps=1000,
             rollout_server_url="http://localhost:8000",
-            use_wandb=False,
+            use_wandb=False,
             wandb_name=f"kb_level{KERNELBENCH_LEVEL}_prob{KERNELBENCH_PROBLEM_NUMBER}_parallel",
         )

@@ -118,7 +127,7 @@ class KernelBenchEnv(BaseEnv):
             APIServerConfig(
                 model_name="Qwen/Qwen3-4B",
                 base_url="http://localhost:9001/v1",
-                api_key="DUMMY_KB_KEY",
+                api_key="DUMMY_KB_KEY",
                 num_requests_for_eval=64,
             )
         ]
@@ -136,7 +145,9 @@ class KernelBenchEnv(BaseEnv):
             self.prompt = f.read()

         # Get reference code directly from the dataset
-        self.ref_code = get_kernelbench_code(KERNELBENCH_LEVEL, KERNELBENCH_PROBLEM_NUMBER)
+        self.ref_code = get_kernelbench_code(
+            KERNELBENCH_LEVEL, KERNELBENCH_PROBLEM_NUMBER
+        )
         self.reward_buffer = list()
         # Create a process pool for parallel processing
         self.pool = mp.Pool(processes=24)
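The mp.Pool created here is consumed further down with result.get(), and each work item reaches evaluate_single_kernel as one tuple. A self-contained sketch of that submit/collect pattern (evaluate_one and its toy reward stand in for the real kernel evaluation):

import multiprocessing as mp

def evaluate_one(args):
    item, build_dir, ref_code = args  # same unpacking shape as evaluate_single_kernel
    return {"item": item, "reward": float(item % 2)}

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    pool = mp.Pool(processes=4)
    # apply_async submits immediately; .get() blocks until each child finishes.
    results = [
        pool.apply_async(evaluate_one, ((i, "/tmp/build", "ref"),)) for i in range(4)
    ]
    print([r.get()["reward"] for r in results])
    pool.close()
    pool.join()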
@@ -161,7 +172,7 @@ class KernelBenchEnv(BaseEnv):
         )

         # Path: runs/<RUN_NAME>/level_1/1/
-        run_dir = KERNELBENCH_DIR / "runs" / "wandb" / "level_1" / "1"
+        run_dir = KERNELBENCH_DIR / "runs" / "wandb" / "level_1" / "1"
         run_dir.mkdir(parents=True, exist_ok=True)

         to_score: List[Dict] = []
@@ -178,7 +189,7 @@ class KernelBenchEnv(BaseEnv):
                     "finish_reason": choice.finish_reason,
                 }
             )
-
+
         to_postprocess = await self.score(to_score)

         return to_postprocess, to_backlog
@@ -205,14 +216,12 @@ class KernelBenchEnv(BaseEnv):
         for result in results:
             eval_result = result.get()  # This will wait for the result
             reward = eval_result["reward"]
-
+
             # Tokenize in the main process since tokenizer isn't pickleable
             out_dict = tokenize_for_trainer(
-                self.tokenizer,
-                eval_result["messages"],
-                eval_result["finish_reason"]
+                self.tokenizer, eval_result["messages"], eval_result["finish_reason"]
             )
-
+
             scores["tokens"].append(out_dict["tokens"])
             scores["masks"].append(out_dict["masks"])
             scores["scores"].append(reward)
@@ -222,7 +231,9 @@ class KernelBenchEnv(BaseEnv):

     async def get_next_item(self) -> KBRow:
         """Return the same single problem every time (env is tiny)."""
-        return KBRow(prompt=self.prompt, sample_path="")  # sample_path is no longer used
+        return KBRow(
+            prompt=self.prompt, sample_path=""
+        )  # sample_path is no longer used

     async def evaluate(self, *args, **kwargs):
         """Evaluate the current model on a set of test problems."""
@@ -236,9 +247,9 @@ class KernelBenchEnv(BaseEnv):
         wandb_metrics = {}

         try:
-            wandb_metrics["train/reward"] = sum(
-                self.reward_buffer
-            ) / len(self.reward_buffer)
+            wandb_metrics["train/reward"] = sum(self.reward_buffer) / len(
+                self.reward_buffer
+            )
         except ZeroDivisionError:
             pass

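The try/except ZeroDivisionError guards the first log call, when reward_buffer is still empty. An equivalent truthiness guard, shown only as a comparison (module-level names here stand in for the instance attributes):

reward_buffer = []  # empty on the first wandb log
wandb_metrics = {}
if reward_buffer:  # skips the division instead of catching ZeroDivisionError
    wandb_metrics["train/reward"] = sum(reward_buffer) / len(reward_buffer)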
@@ -252,7 +263,8 @@ class KernelBenchEnv(BaseEnv):
         self.pool.join()
         await super().cleanup()

+
 # -----------------------------------------------------------------------------

 if __name__ == "__main__":
-    KernelBenchEnv.cli()
+    KernelBenchEnv.cli()