diff --git a/environments/instruction_following_algorithm_environment.py b/environments/instruction_following_algorithm_environment.py index 56144477..69b2b3ec 100644 --- a/environments/instruction_following_algorithm_environment.py +++ b/environments/instruction_following_algorithm_environment.py @@ -644,22 +644,33 @@ class InstructionFollowingEnv(BaseEnv): ) try: - completions = await self.server.completion( - prompt=prompt_str, - n=self.config.group_size, - max_tokens=self.config.max_token_length, - temperature=0.8, # Temperature for diverse responses during training rollouts - ) + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completions = await managed.completion( + prompt=prompt_str, + n=self.config.group_size, + max_tokens=self.config.max_token_length, + temperature=0.8, # Temperature for diverse responses during training rollouts + ) + + state = managed.get_state() + nodes = state["nodes"] + except Exception as e: print(f"ERROR: Exception during completion generation: {e}") return None, [] to_score_list = [] - for choice in completions.choices: + for i, choice in enumerate(completions.choices): trajectory_messages = [dict(msg_fset) for msg_fset in item[0]] # Fresh copy trajectory_messages.append({"role": "assistant", "content": choice.text}) to_score_list.append( - (tuple(trajectory_messages), answer_info) + { + "messages": tuple(trajectory_messages), + "answer_info": answer_info, + "tokens": nodes[i].tokens, + "masks": nodes[i].masked_tokens, + "logprobs": nodes[i].logprobs, + } ) # Pass answer_info if not to_score_list: @@ -677,7 +688,9 @@ class InstructionFollowingEnv(BaseEnv): # If scored_data is None, it might be because the group was skipped for being too easy # We need to calculate the scores ourselves to handle the item properly temp_scores = [] - for trajectory_messages, answer_info in to_score_list: + for rollout_item in to_score_list: + trajectory_messages = rollout_item["messages"] + answer_info = rollout_item["answer_info"] model_response_text = trajectory_messages[-1]["content"] func_name = answer_info["func_name"] args_for_verifier = answer_info["args"]