Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-28 17:29:30 +00:00)
convert swe_rl to managedserver
parent 9d3dbd1a73
commit be74c759e5
1 changed file with 53 additions and 48 deletions
@@ -30,7 +30,6 @@ from atroposlib.envs.base import (
     EvalHandlingEnum,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 # Prompt Constants
 THINKING_SYSTEM_PROMPT_CONTENT = "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."  # noqa: E501
@@ -590,13 +589,18 @@ class SWERLEnv(BaseEnv):
         )
 
         try:
-            completions = await self.server.completion(
-                prompt=prompt_for_llm,
-                n=self.config.group_size,
-                max_tokens=self.config.max_token_length,
-                temperature=0.8,
-                stop=stop_tokens,
-            )
+            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+                completions = await managed.completion(
+                    prompt=prompt_for_llm,
+                    n=self.config.group_size,
+                    max_tokens=self.config.max_token_length,
+                    temperature=0.8,
+                    stop=stop_tokens,
+                )
+
+                state = managed.get_state()
+                nodes = state["nodes"]
+
         except aiohttp.ClientError as e:
             self.logger.error(
                 f"HTTP client error during completion request for item {item_id}: {type(e).__name__}: {e}"
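The core of the conversion: the direct self.server.completion(...) call becomes a managed_server context that, besides returning completions, records per-sample token state retrievable via get_state(). Below is a minimal runnable sketch of the interface shape this hunk relies on; the stub is an illustrative assumption, not atroposlib's actual ManagedServer.

    import asyncio
    from contextlib import asynccontextmanager
    from dataclasses import dataclass
    from types import SimpleNamespace
    from typing import List

    @dataclass
    class Node:
        # Assumed per-sample record: token ids, trainer masks (-100 = ignored
        # position), and the sampling logprob of each generated token.
        tokens: List[int]
        masked_tokens: List[int]
        logprobs: List[float]

    class StubManaged:
        # Mimics only the two members the diff uses: completion() and get_state().
        async def completion(self, prompt, n, max_tokens, temperature, stop=None):
            # A real server would sample n completions and record one Node each;
            # here we fabricate a trivial sample per request.
            self._nodes = [
                Node(tokens=[1, 2, 3], masked_tokens=[-100, 2, 3], logprobs=[-0.1, -0.5, -0.2])
                for _ in range(n)
            ]
            return SimpleNamespace(
                choices=[SimpleNamespace(text=" a patch", finish_reason="stop") for _ in range(n)]
            )

        def get_state(self):
            # Invariant the environment depends on: state["nodes"][i] lines up
            # with completions.choices[i].
            return {"nodes": self._nodes}

    @asynccontextmanager
    async def managed_server(tokenizer=None):
        yield StubManaged()

    async def main():
        async with managed_server() as managed:
            completions = await managed.completion("fix the bug", n=2, max_tokens=64, temperature=0.8)
            nodes = managed.get_state()["nodes"]
        assert len(nodes) == len(completions.choices)

    asyncio.run(main())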
@@ -629,16 +633,23 @@ class SWERLEnv(BaseEnv):
             )
 
         # Prepare to collect all conversations and their potential scores for this item
-        # This list will hold tuples of (conversation_messages, oracle_patch, finish_reason)
+        # This list will hold dicts with conversation_messages, oracle_patch, finish_reason, tokens, masks, logprobs
         # which is the input format expected by the self.score method.
         raw_rollouts_for_scoring = []
 
-        for choice in completions.choices:
+        for i, choice in enumerate(completions.choices):
             current_trajectory_messages = messages_for_llm_prompt + [
                 {"role": "assistant", "content": choice.text.strip()}
             ]
             raw_rollouts_for_scoring.append(
-                (current_trajectory_messages, oracle_patch, choice.finish_reason)
+                {
+                    "messages": current_trajectory_messages,
+                    "oracle_patch": oracle_patch,
+                    "finish_reason": choice.finish_reason,
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
             )
 
         if not raw_rollouts_for_scoring:
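Each rollout is now carried as a dict instead of a positional tuple, and index i ties each choice back to its node in the managed-server state. The implied record schema, written as a TypedDict sketch (key names come from the hunk above; value types are inferred, and the name RolloutRecord is invented here), also pins down what the loosened score signature in a later hunk receives: there it is annotated as List[Dict[str, any]], where the builtin any presumably stands in for typing.Any.

    from typing import Dict, List, TypedDict

    class RolloutRecord(TypedDict):
        messages: List[Dict[str, str]]  # prompt messages plus the sampled assistant turn
        oracle_patch: str               # ground-truth patch the rollout is scored against
        finish_reason: str              # e.g. "stop" or "length"
        tokens: List[int]               # token ids from the managed server's node
        masks: List[int]                # -100 marks positions excluded from training
        logprobs: List[float]           # inference-time logprob for each sampled token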
@@ -654,8 +665,8 @@ class SWERLEnv(BaseEnv):
 
         num_scored_rollouts = len(scored_data.get("scores", []))
         for i in range(num_scored_rollouts):
-            # raw_rollouts_for_scoring[i][0] is the list of message dicts for the i-th rollout
-            conversation_messages = raw_rollouts_for_scoring[i][0]
+            # raw_rollouts_for_scoring[i]["messages"] is the list of message dicts for the i-th rollout
+            conversation_messages = raw_rollouts_for_scoring[i]["messages"]
             score_for_rollout = scored_data["scores"][i]
             rollouts_with_scores_to_save.append(
                 {
@@ -812,7 +823,7 @@ class SWERLEnv(BaseEnv):
         return "\n".join(full_patch_parts)
 
     async def score(
-        self, rollout_group_data: List[Tuple[List[Dict[str, str]], str, str]]
+        self, rollout_group_data: List[Dict[str, any]]
     ) -> Optional[ScoredDataGroup]:
         scored_data = ScoredDataGroup()
         scored_data["tokens"] = []
@@ -820,6 +831,7 @@ class SWERLEnv(BaseEnv):
         scored_data["scores"] = []
         scored_data["messages"] = []
         scored_data["overrides"] = []
+        scored_data["inference_logprobs"] = []
 
         patch_format_correct_count_batch = 0
         similarity_scores_batch_temp = []
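inference_logprobs joins the group's other per-rollout lists; everything is appended in lockstep during scoring, so the lists must stay index-aligned. A hypothetical sanity check (not part of the commit) makes that invariant explicit:

    # Hypothetical helper: every per-rollout list in a scored group, including
    # the newly added inference_logprobs, must have the same length.
    def assert_group_aligned(group):
        keys = ("tokens", "masks", "scores", "messages", "overrides", "inference_logprobs")
        lengths = {k: len(group[k]) for k in keys if k in group}
        assert len(set(lengths.values())) <= 1, f"misaligned group: {lengths}"

    assert_group_aligned({"tokens": [[1]], "masks": [[-100]], "scores": [1.0],
                          "messages": [[]], "overrides": [{}], "inference_logprobs": [[-0.1]]})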
@@ -842,7 +854,14 @@ class SWERLEnv(BaseEnv):
         # Collect all failed responses for immediate saving
         failed_responses_this_group = []
 
-        for trajectory_messages, oracle_patch_str, finish_reason in rollout_group_data:
+        for rollout_item in rollout_group_data:
+            trajectory_messages = rollout_item["messages"]
+            oracle_patch_str = rollout_item["oracle_patch"]
+            finish_reason = rollout_item["finish_reason"]
+            tokens = rollout_item["tokens"]
+            masks = rollout_item["masks"]
+            logprobs = rollout_item["logprobs"]
+
             assistant_response = ""
             if (
                 trajectory_messages
@@ -955,28 +974,15 @@ class SWERLEnv(BaseEnv):
                     }
                 )
 
-            try:
-                tokenized_output = tokenize_for_trainer(
-                    tokenizer=self.tokenizer,
-                    chat=trajectory_messages,
-                    include_messages=True,
-                )
-            except Exception as e:
-                self.logger.error(f"Tokenization failed: {e}")
-                continue
-            if (
-                not tokenized_output
-                or not tokenized_output.get("tokens")
-                or not tokenized_output["tokens"][0]
-            ):
+            # Remove examples with insufficient context
+            if len([1 for i in masks if i != -100]) < 10:
                 continue
 
-            scored_data["tokens"].append(tokenized_output["tokens"])
-            scored_data["masks"].append(tokenized_output["masks"])
+            scored_data["tokens"].append(tokens)
+            scored_data["masks"].append(masks)
+            scored_data["inference_logprobs"].append(logprobs)
             scored_data["scores"].append(reward)
-            scored_data["messages"].append(
-                tokenized_output.get("messages", trajectory_messages)
-            )
+            scored_data["messages"].append(trajectory_messages)
             scored_data["overrides"].append(override_dict)
             if len(scored_data["scores"]) >= self.config.group_size:
                 break
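Scoring no longer re-tokenizes the conversation with tokenize_for_trainer (hence the import dropped in the first hunk); the tokens, masks, and logprobs captured at inference time are appended directly. The old empty-tokenization guard becomes a minimum-context filter: -100 is the usual trainer ignore-index, so a rollout is kept only if at least 10 positions are trainable. The same predicate in isolation:

    # The filter from the hunk above: keep a rollout only if at least 10 mask
    # positions are trainable, i.e. not the -100 ignore value.
    def has_enough_context(masks, min_trainable=10):
        return sum(1 for m in masks if m != -100) >= min_trainable

    assert not has_enough_context([-100] * 50)         # prompt-only, fully masked
    assert has_enough_context([-100] * 50 + [7] * 12)  # 12 trainable positions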
@@ -1156,17 +1162,15 @@ class SWERLEnv(BaseEnv):
         failed_rollouts_with_scores_to_save = []
 
         # Build the failed rollouts data structure
-        for i, (trajectory_messages, oracle_patch, finish_reason) in enumerate(
-            rollout_group_data
-        ):
+        for i, rollout_item in enumerate(rollout_group_data):
             if i < len(scored_data["scores"]):
                 score_for_rollout = scored_data["scores"][i]
                 failed_rollouts_with_scores_to_save.append(
                     {
-                        "conversation": trajectory_messages,  # Full conversation history
+                        "conversation": rollout_item["messages"],  # Full conversation history
                         "score": score_for_rollout,
-                        "oracle_patch": oracle_patch,
-                        "finish_reason": finish_reason,
+                        "oracle_patch": rollout_item["oracle_patch"],
+                        "finish_reason": rollout_item["finish_reason"],
                     }
                 )
 
@@ -1301,14 +1305,15 @@ class SWERLEnv(BaseEnv):
             )
 
             try:
-                completions = await self.server.completion(
-                    prompt=prompt_for_llm,
-                    n=self.config.eval_n_samples,
-                    max_tokens=self.config.max_token_length,
-                    temperature=0.2,
-                    stop=stop_tokens,
-                    split="eval",
-                )
+                async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+                    completions = await managed.completion(
+                        prompt=prompt_for_llm,
+                        n=self.config.eval_n_samples,
+                        max_tokens=self.config.max_token_length,
+                        temperature=0.2,
+                        stop=stop_tokens,
+                        split="eval",
+                    )
             except aiohttp.ClientError as e:
                 self.logger.error(
                     f"HTTP client error during eval completion request for item {item_id}: {type(e).__name__}: {e}"
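The eval path gets the identical treatment, differing only in sample count (eval_n_samples), a lower temperature, and the split="eval" tag; unlike the rollout path it never calls get_state(). If the duplication grows, both call sites could share a helper along these lines (a sketch under the same assumptions as above, not part of the commit):

    # Hypothetical consolidation of the two managed_server call sites.
    async def sample_completions(server, tokenizer, prompt, *, want_state=False, **sampling):
        async with server.managed_server(tokenizer=tokenizer) as managed:
            completions = await managed.completion(prompt=prompt, **sampling)
            # Rollout collection needs the per-sample nodes; eval does not.
            state = managed.get_state() if want_state else None
        return completions, state

    # Rollouts: sample_completions(self.server, self.tokenizer, prompt_for_llm,
    #           want_state=True, n=group_size, temperature=0.8, ...)
    # Eval:     sample_completions(self.server, self.tokenizer, prompt_for_llm,
    #           n=eval_n_samples, temperature=0.2, split="eval", ...)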