Convert KernelBench env to use managed_server

This commit is contained in:
teknium 2025-11-14 10:15:01 +00:00
parent c4ecc42139
commit 8e851a5ad4

View file

@ -36,7 +36,6 @@ from atroposlib.envs.base import (
ScoredDataGroup,
)
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# Set the start method to 'spawn' for CUDA compatibility
mp.set_start_method("spawn", force=True)
@ -164,12 +163,16 @@ class KernelBenchEnv(BaseEnv):
"""
user_msg = {"role": "user", "content": self.prompt}
chat_completions = await self.server.chat_completion(
messages=[user_msg],
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.0,
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
chat_completions = await managed.chat_completion(
messages=[user_msg],
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.0,
)
state = managed.get_state()
nodes = state["nodes"]
# Path: runs/<RUN_NAME>/level_1/1/
run_dir = KERNELBENCH_DIR / "runs" / "wandb" / "level_1" / "1"
@ -187,6 +190,9 @@ class KernelBenchEnv(BaseEnv):
{
"messages": messages,
"finish_reason": choice.finish_reason,
"tokens": nodes[i].tokens,
"masks": nodes[i].masked_tokens,
"logprobs": nodes[i].logprobs,
}
)
@ -197,7 +203,7 @@ class KernelBenchEnv(BaseEnv):
async def score(
self, rollout_group_data: List[Dict]
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
scores = ScoredDataGroup(tokens=[], masks=[], scores=[])
scores = ScoredDataGroup(tokens=[], masks=[], scores=[], inference_logprobs=[])
# where we will build + compile kernels
build_dir = os.path.join("build", "kernelbench", f"{1}", f"{1}")
@ -213,17 +219,18 @@ class KernelBenchEnv(BaseEnv):
results.append(result)
# Wait for all evaluations to complete and process results
for result in results:
for i, result in enumerate(results):
eval_result = result.get() # This will wait for the result
reward = eval_result["reward"]
# Tokenize in the main process since tokenizer isn't pickleable
out_dict = tokenize_for_trainer(
self.tokenizer, eval_result["messages"], eval_result["finish_reason"]
)
# Use tokens, masks, and logprobs from managed_server nodes
tokens = rollout_group_data[i]["tokens"]
masks = rollout_group_data[i]["masks"]
logprobs = rollout_group_data[i]["logprobs"]
scores["tokens"].append(out_dict["tokens"])
scores["masks"].append(out_dict["masks"])
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["inference_logprobs"].append(logprobs)
scores["scores"].append(reward)
self.reward_buffer.append(max(reward, 0))