convert kernelbench env to use managedserver
parent c4ecc42139
commit 8e851a5ad4

1 changed file with 22 additions and 15 deletions
@@ -36,7 +36,6 @@ from atroposlib.envs.base import (
     ScoredDataGroup,
 )
 from atroposlib.type_definitions import Item
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 # Set the start method to 'spawn' for CUDA compatibility
 mp.set_start_method("spawn", force=True)
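The surviving context lines matter here: KernelBench compiles and runs CUDA kernels in worker processes, and a forked child inherits a CUDA context it cannot safely reuse, which is why the module forces the 'spawn' start method. A minimal standalone sketch of the same pattern (the _probe worker is illustrative, not from the repo):

import multiprocessing as mp

def _probe(x: int) -> int:
    # Under 'spawn' this runs in a fresh interpreter, so it could
    # safely initialize CUDA without inheriting a forked context.
    return x * x

if __name__ == "__main__":
    # force=True mirrors the diff: override any start method that
    # an importing process may already have set.
    mp.set_start_method("spawn", force=True)
    with mp.Pool(2) as pool:
        print(pool.map(_probe, [1, 2, 3]))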
@@ -164,12 +163,16 @@ class KernelBenchEnv(BaseEnv):
         """
         user_msg = {"role": "user", "content": self.prompt}
 
-        chat_completions = await self.server.chat_completion(
-            messages=[user_msg],
-            n=self.config.group_size,
-            max_tokens=self.config.max_token_length,
-            temperature=0.0,
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            chat_completions = await managed.chat_completion(
+                messages=[user_msg],
+                n=self.config.group_size,
+                max_tokens=self.config.max_token_length,
+                temperature=0.0,
+            )
+
+            state = managed.get_state()
+            nodes = state["nodes"]
 
         # Path: runs/<RUN_NAME>/level_1/1/
         run_dir = KERNELBENCH_DIR / "runs" / "wandb" / "level_1" / "1"
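This hunk is the core of the conversion: instead of calling self.server.chat_completion directly, the environment opens a managed-server context, samples inside it, and then reads back the state the server recorded, where each node carries the tokens, masks, and logprobs for one completion. A minimal sketch of that control flow with hypothetical stand-ins (FakeServer and FakeManaged are not atroposlib types, just the assumed shape of the interface):

import asyncio
from contextlib import asynccontextmanager

class FakeManaged:
    # Assumed shape of the handle yielded by managed_server().
    async def chat_completion(self, messages, n=1, **kwargs):
        return {"choices": [{"text": "kernel src"}] * n}

    def get_state(self):
        # The diff reads state["nodes"]; each node exposes
        # .tokens, .masked_tokens, and .logprobs per completion.
        return {"nodes": []}

class FakeServer:
    @asynccontextmanager
    async def managed_server(self, tokenizer=None):
        yield FakeManaged()

async def main():
    server = FakeServer()
    async with server.managed_server(tokenizer=None) as managed:
        completions = await managed.chat_completion(messages=[], n=2)
        nodes = managed.get_state()["nodes"]
        print(len(completions["choices"]), len(nodes))

asyncio.run(main())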
@@ -187,6 +190,9 @@ class KernelBenchEnv(BaseEnv):
                 {
                     "messages": messages,
                     "finish_reason": choice.finish_reason,
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
                 }
             )
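Each queued rollout dict now carries the server-recorded tokenization alongside the transcript, so scoring no longer needs to re-tokenize. A sketch of the per-rollout payload; the field names come from the diff, while the concrete types are assumptions:

from typing import Any, Dict, List, TypedDict

class RolloutItem(TypedDict):
    messages: List[Dict[str, Any]]  # chat transcript for this completion
    finish_reason: str              # e.g. "stop" or "length"
    tokens: List[int]               # node.tokens from the managed server
    masks: List[int]                # node.masked_tokens: trainer loss mask
    logprobs: List[float]           # per-token inference logprobs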
@@ -197,7 +203,7 @@ class KernelBenchEnv(BaseEnv):
     async def score(
         self, rollout_group_data: List[Dict]
     ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
-        scores = ScoredDataGroup(tokens=[], masks=[], scores=[])
+        scores = ScoredDataGroup(tokens=[], masks=[], scores=[], inference_logprobs=[])
 
         # where we will build + compile kernels
         build_dir = os.path.join("build", "kernelbench", f"{1}", f"{1}")
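ScoredDataGroup gains an inference_logprobs column, and correctness depends on all four lists growing in lockstep so that index i always refers to the same rollout. A sketch of that invariant, modeling ScoredDataGroup as a plain dict (the real class lives in atroposlib.envs.base):

# Dict-like model of ScoredDataGroup, for illustration only.
scores = dict(tokens=[], masks=[], scores=[], inference_logprobs=[])

def append_rollout(scores, tokens, masks, logprobs, reward):
    # Append to every column together so index i stays one rollout.
    scores["tokens"].append(tokens)
    scores["masks"].append(masks)
    scores["inference_logprobs"].append(logprobs)
    scores["scores"].append(reward)

append_rollout(scores, [1, 2], [1, 1], [-0.1, -0.2], 1.0)
assert len({len(col) for col in scores.values()}) == 1  # lockstep check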
@@ -213,17 +219,18 @@ class KernelBenchEnv(BaseEnv):
             results.append(result)
 
         # Wait for all evaluations to complete and process results
-        for result in results:
+        for i, result in enumerate(results):
             eval_result = result.get()  # This will wait for the result
             reward = eval_result["reward"]
 
-            # Tokenize in the main process since tokenizer isn't pickleable
-            out_dict = tokenize_for_trainer(
-                self.tokenizer, eval_result["messages"], eval_result["finish_reason"]
-            )
+            # Use tokens, masks, and logprobs from managed_server nodes
+            tokens = rollout_group_data[i]["tokens"]
+            masks = rollout_group_data[i]["masks"]
+            logprobs = rollout_group_data[i]["logprobs"]
 
-            scores["tokens"].append(out_dict["tokens"])
-            scores["masks"].append(out_dict["masks"])
+            scores["tokens"].append(tokens)
+            scores["masks"].append(masks)
+            scores["inference_logprobs"].append(logprobs)
             scores["scores"].append(reward)
             self.reward_buffer.append(max(reward, 0))
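The loop switches to enumerate because results[i] (the pool evaluation) and rollout_group_data[i] (the cached server tokenization) must describe the same completion; iterating without the index would leave no way to pair them. A standalone sketch of the pairing, with dummy data standing in for the multiprocessing results:

# Dummy stand-ins: eval_results[i] was produced by the rollout
# stored at rollout_group_data[i], so pairing is positional.
rollout_group_data = [
    {"tokens": [1, 2], "masks": [1, 1], "logprobs": [-0.1, -0.3]},
    {"tokens": [3, 4], "masks": [1, 1], "logprobs": [-0.2, -0.4]},
]
eval_results = [{"reward": 1.0}, {"reward": -1.0}]

for i, eval_result in enumerate(eval_results):
    rollout = rollout_group_data[i]  # same index => same completion
    print(i, eval_result["reward"], rollout["tokens"])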