convert swe_rl to managedserver

This commit is contained in:
teknium 2025-11-14 06:13:02 +00:00
parent 9d3dbd1a73
commit be74c759e5

View file

@ -30,7 +30,6 @@ from atroposlib.envs.base import (
EvalHandlingEnum,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# Prompt Constants
THINKING_SYSTEM_PROMPT_CONTENT = "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem." # noqa: E501
@ -590,13 +589,18 @@ class SWERLEnv(BaseEnv):
)
try:
completions = await self.server.completion(
prompt=prompt_for_llm,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.8,
stop=stop_tokens,
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completions = await managed.completion(
prompt=prompt_for_llm,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.8,
stop=stop_tokens,
)
state = managed.get_state()
nodes = state["nodes"]
except aiohttp.ClientError as e:
self.logger.error(
f"HTTP client error during completion request for item {item_id}: {type(e).__name__}: {e}"
@ -629,16 +633,23 @@ class SWERLEnv(BaseEnv):
)
# Prepare to collect all conversations and their potential scores for this item
# This list will hold tuples of (conversation_messages, oracle_patch, finish_reason)
# This list will hold dicts with conversation_messages, oracle_patch, finish_reason, tokens, masks, logprobs
# which is the input format expected by the self.score method.
raw_rollouts_for_scoring = []
for choice in completions.choices:
for i, choice in enumerate(completions.choices):
current_trajectory_messages = messages_for_llm_prompt + [
{"role": "assistant", "content": choice.text.strip()}
]
raw_rollouts_for_scoring.append(
(current_trajectory_messages, oracle_patch, choice.finish_reason)
{
"messages": current_trajectory_messages,
"oracle_patch": oracle_patch,
"finish_reason": choice.finish_reason,
"tokens": nodes[i].tokens,
"masks": nodes[i].masked_tokens,
"logprobs": nodes[i].logprobs,
}
)
if not raw_rollouts_for_scoring:
@ -654,8 +665,8 @@ class SWERLEnv(BaseEnv):
num_scored_rollouts = len(scored_data.get("scores", []))
for i in range(num_scored_rollouts):
# raw_rollouts_for_scoring[i][0] is the list of message dicts for the i-th rollout
conversation_messages = raw_rollouts_for_scoring[i][0]
# raw_rollouts_for_scoring[i]["messages"] is the list of message dicts for the i-th rollout
conversation_messages = raw_rollouts_for_scoring[i]["messages"]
score_for_rollout = scored_data["scores"][i]
rollouts_with_scores_to_save.append(
{
@ -812,7 +823,7 @@ class SWERLEnv(BaseEnv):
return "\n".join(full_patch_parts)
async def score(
self, rollout_group_data: List[Tuple[List[Dict[str, str]], str, str]]
self, rollout_group_data: List[Dict[str, any]]
) -> Optional[ScoredDataGroup]:
scored_data = ScoredDataGroup()
scored_data["tokens"] = []
@ -820,6 +831,7 @@ class SWERLEnv(BaseEnv):
scored_data["scores"] = []
scored_data["messages"] = []
scored_data["overrides"] = []
scored_data["inference_logprobs"] = []
patch_format_correct_count_batch = 0
similarity_scores_batch_temp = []
@ -842,7 +854,14 @@ class SWERLEnv(BaseEnv):
# Collect all failed responses for immediate saving
failed_responses_this_group = []
for trajectory_messages, oracle_patch_str, finish_reason in rollout_group_data:
for rollout_item in rollout_group_data:
trajectory_messages = rollout_item["messages"]
oracle_patch_str = rollout_item["oracle_patch"]
finish_reason = rollout_item["finish_reason"]
tokens = rollout_item["tokens"]
masks = rollout_item["masks"]
logprobs = rollout_item["logprobs"]
assistant_response = ""
if (
trajectory_messages
@ -955,28 +974,15 @@ class SWERLEnv(BaseEnv):
}
)
try:
tokenized_output = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=trajectory_messages,
include_messages=True,
)
except Exception as e:
self.logger.error(f"Tokenization failed: {e}")
continue
if (
not tokenized_output
or not tokenized_output.get("tokens")
or not tokenized_output["tokens"][0]
):
# Remove examples with insufficient context
if len([1 for i in masks if i != -100]) < 10:
continue
scored_data["tokens"].append(tokenized_output["tokens"])
scored_data["masks"].append(tokenized_output["masks"])
scored_data["tokens"].append(tokens)
scored_data["masks"].append(masks)
scored_data["inference_logprobs"].append(logprobs)
scored_data["scores"].append(reward)
scored_data["messages"].append(
tokenized_output.get("messages", trajectory_messages)
)
scored_data["messages"].append(trajectory_messages)
scored_data["overrides"].append(override_dict)
if len(scored_data["scores"]) >= self.config.group_size:
break
@ -1156,17 +1162,15 @@ class SWERLEnv(BaseEnv):
failed_rollouts_with_scores_to_save = []
# Build the failed rollouts data structure
for i, (trajectory_messages, oracle_patch, finish_reason) in enumerate(
rollout_group_data
):
for i, rollout_item in enumerate(rollout_group_data):
if i < len(scored_data["scores"]):
score_for_rollout = scored_data["scores"][i]
failed_rollouts_with_scores_to_save.append(
{
"conversation": trajectory_messages, # Full conversation history
"conversation": rollout_item["messages"], # Full conversation history
"score": score_for_rollout,
"oracle_patch": oracle_patch,
"finish_reason": finish_reason,
"oracle_patch": rollout_item["oracle_patch"],
"finish_reason": rollout_item["finish_reason"],
}
)
@ -1301,14 +1305,15 @@ class SWERLEnv(BaseEnv):
)
try:
completions = await self.server.completion(
prompt=prompt_for_llm,
n=self.config.eval_n_samples,
max_tokens=self.config.max_token_length,
temperature=0.2,
stop=stop_tokens,
split="eval",
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completions = await managed.completion(
prompt=prompt_for_llm,
n=self.config.eval_n_samples,
max_tokens=self.config.max_token_length,
temperature=0.2,
stop=stop_tokens,
split="eval",
)
except aiohttp.ClientError as e:
self.logger.error(
f"HTTP client error during eval completion request for item {item_id}: {type(e).__name__}: {e}"