convert bootcamp to use managedserver

teknium 2025-11-14 10:17:48 +00:00
parent 8e851a5ad4
commit ae101ea8e4

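In short: both the rollout and eval paths stop calling self.server.chat_completion directly and instead route requests through a managed server context that records token-level state during generation, so score() can reuse the exact tokens, masks, and logprobs from inference instead of re-tokenizing with tokenize_for_trainer. A minimal sketch of the new pattern, assuming (as this diff implies) that managed_server yields a handle whose get_state() returns per-choice nodes:

    # Sketch only; names follow this diff rather than a documented API.
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions = await managed.chat_completion(messages=messages, n=4)
        state = managed.get_state()  # token bookkeeping captured during generation
        nodes = state["nodes"]       # one node per choice: .tokens, .masked_tokens, .logprobs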

@@ -17,7 +17,6 @@ from atroposlib.envs.base import (
     BaseEnvConfig,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 from .bootcamp_registry import create_bootcamp, get_available_bootcamps
@@ -168,14 +167,18 @@ class InternBootcampEnv(BaseEnv):
         messages, metadata = item
         logger.info(f"Collecting trajectories for item: {messages}")
 
-        # Get completions from the model using chat_completion
-        completions = await self.server.chat_completion(
-            messages=messages,
-            n=self.config.group_size,
-            max_tokens=self.config.max_token_length,
-            temperature=self.config.temperature,
-            top_p=self.config.top_p,
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get completions from the model using chat_completion
+            completions = await managed.chat_completion(
+                messages=messages,
+                n=self.config.group_size,
+                max_tokens=self.config.max_token_length,
+                temperature=self.config.temperature,
+                top_p=self.config.top_p,
+            )
+            state = managed.get_state()
+            nodes = state["nodes"]
 
         to_score = []
@@ -187,7 +190,16 @@ class InternBootcampEnv(BaseEnv):
                 {"role": "assistant", "content": model_response}
             ]
-            to_score.append((full_messages, metadata, model_response))
+            to_score.append(
+                {
+                    "messages": full_messages,
+                    "metadata": metadata,
+                    "model_response": model_response,
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
+            )
 
         # Score the trajectories immediately and return a ScoredDataGroup
         scored_data = await self.score(to_score)
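The items handed to score() are now dicts rather than (messages, metadata, model_response) tuples, so the token-level data rides along with the text. A hypothetical TypedDict, not part of the diff, just to make the new shape explicit:

    from typing import Any, TypedDict

    class RolloutItem(TypedDict):  # hypothetical; mirrors the dict built above
        messages: list[dict[str, str]]  # prompt plus the assistant turn
        metadata: dict[str, Any]        # includes the bootcamp "identity"
        model_response: str
        tokens: list[int]               # token ids captured at inference time
        masks: list[int]                # loss mask for the trainer
        logprobs: list[float]           # inference-time logprobs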
@@ -202,8 +214,16 @@ class InternBootcampEnv(BaseEnv):
         scored_data["masks"] = []
         scored_data["scores"] = []
         scored_data["messages"] = []
+        scored_data["inference_logprobs"] = []
 
-        for messages, metadata, model_response in rollout_group_data:
+        for rollout_item in rollout_group_data:
+            messages = rollout_item["messages"]
+            metadata = rollout_item["metadata"]
+            model_response = rollout_item["model_response"]
+            tokens = rollout_item["tokens"]
+            masks = rollout_item["masks"]
+            logprobs = rollout_item["logprobs"]
             # Verify the response using the bootcamp
             identity = metadata["identity"]
@@ -240,15 +260,9 @@ class InternBootcampEnv(BaseEnv):
                     f"identity={identity}"
                 )
 
-            # Tokenize for trainer
-            tokens_dict = tokenize_for_trainer(
-                self.tokenizer,
-                messages,
-                None,
-            )
-
-            scored_data["tokens"].append(tokens_dict["tokens"])
-            scored_data["masks"].append(tokens_dict["masks"])
+            scored_data["tokens"].append(tokens)
+            scored_data["masks"].append(masks)
+            scored_data["inference_logprobs"].append(logprobs)
             scored_data["scores"].append(final_score)
             scored_data["messages"].append(messages)
@@ -304,15 +318,16 @@ class InternBootcampEnv(BaseEnv):
             {"role": "user", "content": prompt},
         ]
 
-        # Get model response using chat_completion
-        completion = await self.server.chat_completion(
-            messages=messages,
-            n=1,
-            max_tokens=self.config.max_token_length,
-            temperature=0.0,  # Deterministic for evaluation
-            top_p=1.0,
-            split="eval",
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get model response using chat_completion
+            completion = await managed.chat_completion(
+                messages=messages,
+                n=1,
+                max_tokens=self.config.max_token_length,
+                temperature=0.0,  # Deterministic for evaluation
+                top_p=1.0,
+                split="eval",
+            )
 
         model_response = completion.choices[0].message.content
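On the eval path the managed state is never read; only the text of the single deterministic choice is used, and the async with still guarantees that whatever cleanup managed_server performs on exit runs even if chat_completion raises. A hypothetical guard for an empty response, not in the diff:

    # Hypothetical: fail loudly if the eval completion came back empty.
    if not completion.choices:
        raise RuntimeError("eval chat_completion returned no choices")
    model_response = completion.choices[0].message.content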