diff --git a/environments/swe_rl_env.py b/environments/swe_rl_env.py
index d202756f..e34dd505 100644
--- a/environments/swe_rl_env.py
+++ b/environments/swe_rl_env.py
@@ -30,7 +30,6 @@ from atroposlib.envs.base import (
     EvalHandlingEnum,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 # Prompt Constants
 THINKING_SYSTEM_PROMPT_CONTENT = "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."  # noqa: E501
@@ -590,13 +589,18 @@ class SWERLEnv(BaseEnv):
         )
 
         try:
-            completions = await self.server.completion(
-                prompt=prompt_for_llm,
-                n=self.config.group_size,
-                max_tokens=self.config.max_token_length,
-                temperature=0.8,
-                stop=stop_tokens,
-            )
+            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+                completions = await managed.completion(
+                    prompt=prompt_for_llm,
+                    n=self.config.group_size,
+                    max_tokens=self.config.max_token_length,
+                    temperature=0.8,
+                    stop=stop_tokens,
+                )
+
+                state = managed.get_state()
+                nodes = state["nodes"]
+
         except aiohttp.ClientError as e:
             self.logger.error(
                 f"HTTP client error during completion request for item {item_id}: {type(e).__name__}: {e}"
@@ -629,16 +633,23 @@ class SWERLEnv(BaseEnv):
         )
 
         # Prepare to collect all conversations and their potential scores for this item
-        # This list will hold tuples of (conversation_messages, oracle_patch, finish_reason)
+        # This list will hold dicts with conversation_messages, oracle_patch, finish_reason, tokens, masks, logprobs
         # which is the input format expected by the self.score method.
 
         raw_rollouts_for_scoring = []
-        for choice in completions.choices:
+        for i, choice in enumerate(completions.choices):
             current_trajectory_messages = messages_for_llm_prompt + [
                 {"role": "assistant", "content": choice.text.strip()}
             ]
             raw_rollouts_for_scoring.append(
-                (current_trajectory_messages, oracle_patch, choice.finish_reason)
+                {
+                    "messages": current_trajectory_messages,
+                    "oracle_patch": oracle_patch,
+                    "finish_reason": choice.finish_reason,
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
             )
 
         if not raw_rollouts_for_scoring:
@@ -654,8 +665,8 @@ class SWERLEnv(BaseEnv):
             num_scored_rollouts = len(scored_data.get("scores", []))
 
             for i in range(num_scored_rollouts):
-                # raw_rollouts_for_scoring[i][0] is the list of message dicts for the i-th rollout
-                conversation_messages = raw_rollouts_for_scoring[i][0]
+                # raw_rollouts_for_scoring[i]["messages"] is the list of message dicts for the i-th rollout
+                conversation_messages = raw_rollouts_for_scoring[i]["messages"]
                 score_for_rollout = scored_data["scores"][i]
                 rollouts_with_scores_to_save.append(
                     {
@@ -812,7 +823,7 @@ class SWERLEnv(BaseEnv):
         return "\n".join(full_patch_parts)
 
     async def score(
-        self, rollout_group_data: List[Tuple[List[Dict[str, str]], str, str]]
+        self, rollout_group_data: List[Dict[str, any]]
     ) -> Optional[ScoredDataGroup]:
         scored_data = ScoredDataGroup()
         scored_data["tokens"] = []
@@ -820,6 +831,7 @@ class SWERLEnv(BaseEnv):
         scored_data["scores"] = []
         scored_data["messages"] = []
         scored_data["overrides"] = []
+        scored_data["inference_logprobs"] = []
 
         patch_format_correct_count_batch = 0
         similarity_scores_batch_temp = []
@@ -842,7 +854,14 @@ class SWERLEnv(BaseEnv):
         # Collect all failed responses for immediate saving
         failed_responses_this_group = []
 
-        for trajectory_messages, oracle_patch_str, finish_reason in rollout_group_data:
+        for rollout_item in rollout_group_data:
+            trajectory_messages = rollout_item["messages"]
+            oracle_patch_str = rollout_item["oracle_patch"]
+            finish_reason = rollout_item["finish_reason"]
+            tokens = rollout_item["tokens"]
+            masks = rollout_item["masks"]
+            logprobs = rollout_item["logprobs"]
+
             assistant_response = ""
             if (
                 trajectory_messages
@@ -955,28 +974,15 @@ class SWERLEnv(BaseEnv):
                 }
             )
 
-            try:
-                tokenized_output = tokenize_for_trainer(
-                    tokenizer=self.tokenizer,
-                    chat=trajectory_messages,
-                    include_messages=True,
-                )
-            except Exception as e:
-                self.logger.error(f"Tokenization failed: {e}")
-                continue
-            if (
-                not tokenized_output
-                or not tokenized_output.get("tokens")
-                or not tokenized_output["tokens"][0]
-            ):
+            # Remove examples with insufficient context
+            if len([1 for i in masks if i != -100]) < 10:
                 continue
 
-            scored_data["tokens"].append(tokenized_output["tokens"])
-            scored_data["masks"].append(tokenized_output["masks"])
+            scored_data["tokens"].append(tokens)
+            scored_data["masks"].append(masks)
+            scored_data["inference_logprobs"].append(logprobs)
             scored_data["scores"].append(reward)
-            scored_data["messages"].append(
-                tokenized_output.get("messages", trajectory_messages)
-            )
+            scored_data["messages"].append(trajectory_messages)
             scored_data["overrides"].append(override_dict)
             if len(scored_data["scores"]) >= self.config.group_size:
                 break
@@ -1156,17 +1162,15 @@ class SWERLEnv(BaseEnv):
         failed_rollouts_with_scores_to_save = []
 
         # Build the failed rollouts data structure
-        for i, (trajectory_messages, oracle_patch, finish_reason) in enumerate(
-            rollout_group_data
-        ):
+        for i, rollout_item in enumerate(rollout_group_data):
             if i < len(scored_data["scores"]):
                 score_for_rollout = scored_data["scores"][i]
                 failed_rollouts_with_scores_to_save.append(
                     {
-                        "conversation": trajectory_messages,  # Full conversation history
+                        "conversation": rollout_item["messages"],  # Full conversation history
                         "score": score_for_rollout,
-                        "oracle_patch": oracle_patch,
-                        "finish_reason": finish_reason,
+                        "oracle_patch": rollout_item["oracle_patch"],
+                        "finish_reason": rollout_item["finish_reason"],
                     }
                 )
 
@@ -1301,14 +1305,15 @@ class SWERLEnv(BaseEnv):
         )
 
         try:
-            completions = await self.server.completion(
-                prompt=prompt_for_llm,
-                n=self.config.eval_n_samples,
-                max_tokens=self.config.max_token_length,
-                temperature=0.2,
-                stop=stop_tokens,
-                split="eval",
-            )
+            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+                completions = await managed.completion(
+                    prompt=prompt_for_llm,
+                    n=self.config.eval_n_samples,
+                    max_tokens=self.config.max_token_length,
+                    temperature=0.2,
+                    stop=stop_tokens,
+                    split="eval",
+                )
         except aiohttp.ClientError as e:
             self.logger.error(
                 f"HTTP client error during eval completion request for item {item_id}: {type(e).__name__}: {e}"
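For orientation, here is a minimal sketch (not part of the patch) of the per-rollout dict that `score()` now consumes and of the mask-based filter applied to it. The field names mirror the diff above; the literal values and the shape of the node records are illustrative assumptions, since `managed_server` and `get_state()` are defined in atroposlib and not shown here.

```python
# Illustrative only: field names follow the patch; every value below is fake.
# Each completion choice i is paired with managed.get_state()["nodes"][i],
# which supplies inference-time tokens, masks, and logprobs, replacing the
# old tokenize_for_trainer re-tokenization step.
rollout_item = {
    "messages": [
        {"role": "system", "content": "..."},
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},
    ],
    "oracle_patch": "diff --git a/foo.py b/foo.py ...",
    "finish_reason": "stop",
    "tokens": [101, 2023, 2003, 1012],         # token ids for the full trajectory
    "masks": [-100, -100, 2003, 1012],         # -100 = prompt token, excluded from loss
    "logprobs": [-0.11, -2.30, -0.05, -0.40],  # saved as scored_data["inference_logprobs"]
}

# score() drops rollouts with fewer than 10 unmasked (trainable) positions:
keep = len([1 for m in rollout_item["masks"] if m != -100]) >= 10
assert keep is False  # only 2 unmasked positions in this toy example
```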