mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)
convert bootcamp to use managedserver
This commit is contained in:
parent 8e851a5ad4
commit ae101ea8e4
1 changed file with 44 additions and 29 deletions
```diff
@@ -17,7 +17,6 @@ from atroposlib.envs.base import (
     BaseEnvConfig,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 from .bootcamp_registry import create_bootcamp, get_available_bootcamps
 
```
```diff
@@ -168,14 +167,18 @@ class InternBootcampEnv(BaseEnv):
         messages, metadata = item
         logger.info(f"Collecting trajectories for item: {messages}")
 
-        # Get completions from the model using chat_completion
-        completions = await self.server.chat_completion(
-            messages=messages,
-            n=self.config.group_size,
-            max_tokens=self.config.max_token_length,
-            temperature=self.config.temperature,
-            top_p=self.config.top_p,
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get completions from the model using chat_completion
+            completions = await managed.chat_completion(
+                messages=messages,
+                n=self.config.group_size,
+                max_tokens=self.config.max_token_length,
+                temperature=self.config.temperature,
+                top_p=self.config.top_p,
+            )
+
+            state = managed.get_state()
+            nodes = state["nodes"]
 
         to_score = []
 
```
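This is the core of the conversion: instead of calling `self.server.chat_completion` directly, the environment opens a `managed_server` context and issues the request through it. Judging from how this diff uses it, the managed wrapper records per-completion state (token ids, loss masks, logprobs) that can be read back with `get_state()`. A minimal sketch of the shape the new code depends on follows; the `Node` class and the comments about internals are illustrative stand-ins inferred from the fields the diff reads, not atroposlib's actual implementation:

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical per-completion record, mirroring the fields the diff
    reads back: nodes[i].tokens, nodes[i].masked_tokens, nodes[i].logprobs."""

    tokens: list[int]         # token ids for prompt + response
    masked_tokens: list[int]  # loss mask aligned with `tokens`
    logprobs: list[float]     # logprobs captured at inference time


# Inside the context, each chat_completion call appears to append one
# node per returned choice; afterwards the environment reads them back:
#
#   async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
#       completions = await managed.chat_completion(...)
#       nodes = managed.get_state()["nodes"]
```

The practical win is that tokenization happens once, at inference time, against the exact tokens the model produced, instead of being re-derived from rendered chat text later.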
```diff
@@ -187,7 +190,16 @@ class InternBootcampEnv(BaseEnv):
                 {"role": "assistant", "content": model_response}
             ]
 
-            to_score.append((full_messages, metadata, model_response))
+            to_score.append(
+                {
+                    "messages": full_messages,
+                    "metadata": metadata,
+                    "model_response": model_response,
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
+            )
 
         # Score the trajectories immediately and return a ScoredDataGroup
         scored_data = await self.score(to_score)
```
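Each `to_score` entry grows from a positional 3-tuple into a self-describing dict that carries the managed server's token-level data alongside the messages. The enclosing loop sits outside the hunk, but for the `nodes[i]` indexing to be correct it must enumerate choices in the same order the managed server recorded them. Roughly, as a hypothetical reconstruction of that hidden context:

```python
# Hypothetical reconstruction of the enclosing loop (not shown in the
# hunk): the index i must line up one-to-one with state["nodes"] so each
# rollout gets the tokens/masks/logprobs of its own response.
for i, choice in enumerate(completions.choices):
    model_response = choice.message.content
    full_messages = [*messages, {"role": "assistant", "content": model_response}]
    to_score.append(
        {
            "messages": full_messages,
            "metadata": metadata,
            "model_response": model_response,
            "tokens": nodes[i].tokens,
            "masks": nodes[i].masked_tokens,
            "logprobs": nodes[i].logprobs,
        }
    )
```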
```diff
@@ -202,8 +214,16 @@ class InternBootcampEnv(BaseEnv):
         scored_data["masks"] = []
         scored_data["scores"] = []
         scored_data["messages"] = []
+        scored_data["inference_logprobs"] = []
 
-        for messages, metadata, model_response in rollout_group_data:
+        for rollout_item in rollout_group_data:
+            messages = rollout_item["messages"]
+            metadata = rollout_item["metadata"]
+            model_response = rollout_item["model_response"]
+            tokens = rollout_item["tokens"]
+            masks = rollout_item["masks"]
+            logprobs = rollout_item["logprobs"]
+
             # Verify the response using the bootcamp
             identity = metadata["identity"]
 
```
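`score` switches from tuple unpacking to explicit dict reads, which keeps the loop readable as fields get added. If the six lines of indexing feel heavy, an equivalent one-statement unpack with `operator.itemgetter` would also work; this is a stylistic alternative, not part of the commit:

```python
from operator import itemgetter

# Pulls the same six fields in one statement; raises KeyError exactly
# like rollout_item[...] does if a field is missing.
fields = itemgetter(
    "messages", "metadata", "model_response", "tokens", "masks", "logprobs"
)
for rollout_item in rollout_group_data:
    messages, metadata, model_response, tokens, masks, logprobs = fields(rollout_item)
```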
```diff
@@ -240,15 +260,9 @@ class InternBootcampEnv(BaseEnv):
                 f"identity={identity}"
             )
 
-            # Tokenize for trainer
-            tokens_dict = tokenize_for_trainer(
-                self.tokenizer,
-                messages,
-                None,
-            )
-
-            scored_data["tokens"].append(tokens_dict["tokens"])
-            scored_data["masks"].append(tokens_dict["masks"])
+            scored_data["tokens"].append(tokens)
+            scored_data["masks"].append(masks)
+            scored_data["inference_logprobs"].append(logprobs)
             scored_data["scores"].append(final_score)
             scored_data["messages"].append(messages)
 
```
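This hunk is the payoff, and it explains the import removed at the top of the file: `tokenize_for_trainer` re-tokenized the chat messages after the fact, whereas the managed server captured `tokens`, `masks`, and `logprobs` at inference time, so they are appended verbatim. Re-tokenizing rendered chat text is not guaranteed to reproduce the exact token sequence the model sampled, and it cannot produce inference logprobs at all. A sanity check one might add after the loop, illustrative rather than part of the commit:

```python
# Every per-rollout list in the ScoredDataGroup should stay parallel:
# index j describes one rollout end to end.
n = len(scored_data["scores"])
assert all(
    len(scored_data[key]) == n
    for key in ("tokens", "masks", "inference_logprobs", "messages")
), "ScoredDataGroup lists drifted out of sync"
```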
```diff
@@ -304,15 +318,16 @@ class InternBootcampEnv(BaseEnv):
             {"role": "user", "content": prompt},
         ]
 
-        # Get model response using chat_completion
-        completion = await self.server.chat_completion(
-            messages=messages,
-            n=1,
-            max_tokens=self.config.max_token_length,
-            temperature=0.0,  # Deterministic for evaluation
-            top_p=1.0,
-            split="eval",
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get model response using chat_completion
+            completion = await managed.chat_completion(
+                messages=messages,
+                n=1,
+                max_tokens=self.config.max_token_length,
+                temperature=0.0,  # Deterministic for evaluation
+                top_p=1.0,
+                split="eval",
+            )
 
         model_response = completion.choices[0].message.content
 
```
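The eval path gets the same mechanical conversion. It only needs the completion text, so nothing reads `get_state()` here, and reading `completion.choices[0].message.content` after the `async with` exits is fine because the returned object outlives the context. Putting the pieces together, the converted rollout flow condenses to roughly the following; this is a sketch stitched from the hunks above, not a drop-in replacement for the real methods:

```python
async def collect_and_score(self, item):
    """Condensed sketch of the converted flow (names taken from the diff)."""
    messages, metadata = item
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )
        # One node per choice, carrying trainer-ready token-level data.
        nodes = managed.get_state()["nodes"]
    to_score = [
        {
            "messages": [*messages, {"role": "assistant", "content": c.message.content}],
            "metadata": metadata,
            "model_response": c.message.content,
            "tokens": nodes[i].tokens,
            "masks": nodes[i].masked_tokens,
            "logprobs": nodes[i].logprobs,
        }
        for i, c in enumerate(completions.choices)
    ]
    return await self.score(to_score)
```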