diff --git a/environments/kernelbench_env/kernelbench_env.py b/environments/kernelbench_env/kernelbench_env.py index 364b8854..fe5d5e66 100644 --- a/environments/kernelbench_env/kernelbench_env.py +++ b/environments/kernelbench_env/kernelbench_env.py @@ -36,7 +36,6 @@ from atroposlib.envs.base import ( ScoredDataGroup, ) from atroposlib.type_definitions import Item -from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer # Set the start method to 'spawn' for CUDA compatibility mp.set_start_method("spawn", force=True) @@ -164,12 +163,16 @@ class KernelBenchEnv(BaseEnv): """ user_msg = {"role": "user", "content": self.prompt} - chat_completions = await self.server.chat_completion( - messages=[user_msg], - n=self.config.group_size, - max_tokens=self.config.max_token_length, - temperature=0.0, - ) + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + chat_completions = await managed.chat_completion( + messages=[user_msg], + n=self.config.group_size, + max_tokens=self.config.max_token_length, + temperature=0.0, + ) + + state = managed.get_state() + nodes = state["nodes"] # Path: runs//level_1/1/ run_dir = KERNELBENCH_DIR / "runs" / "wandb" / "level_1" / "1" @@ -187,6 +190,9 @@ class KernelBenchEnv(BaseEnv): { "messages": messages, "finish_reason": choice.finish_reason, + "tokens": nodes[i].tokens, + "masks": nodes[i].masked_tokens, + "logprobs": nodes[i].logprobs, } ) @@ -197,7 +203,7 @@ class KernelBenchEnv(BaseEnv): async def score( self, rollout_group_data: List[Dict] ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: - scores = ScoredDataGroup(tokens=[], masks=[], scores=[]) + scores = ScoredDataGroup(tokens=[], masks=[], scores=[], inference_logprobs=[]) # where we will build + compile kernels build_dir = os.path.join("build", "kernelbench", f"{1}", f"{1}") @@ -213,17 +219,18 @@ class KernelBenchEnv(BaseEnv): results.append(result) # Wait for all evaluations to complete and process results - for result in results: + for i, result in enumerate(results): eval_result = result.get() # This will wait for the result reward = eval_result["reward"] - # Tokenize in the main process since tokenizer isn't pickleable - out_dict = tokenize_for_trainer( - self.tokenizer, eval_result["messages"], eval_result["finish_reason"] - ) + # Use tokens, masks, and logprobs from managed_server nodes + tokens = rollout_group_data[i]["tokens"] + masks = rollout_group_data[i]["masks"] + logprobs = rollout_group_data[i]["logprobs"] - scores["tokens"].append(out_dict["tokens"]) - scores["masks"].append(out_dict["masks"]) + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["inference_logprobs"].append(logprobs) scores["scores"].append(reward) self.reward_buffer.append(max(reward, 0))