convert instruction following env to use managed server

This commit is contained in:
teknium 2025-11-14 09:49:04 +00:00
parent 4738fabd57
commit 6d6a02eb38

View file

@ -644,22 +644,33 @@ class InstructionFollowingEnv(BaseEnv):
)
try:
completions = await self.server.completion(
prompt=prompt_str,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.8, # Temperature for diverse responses during training rollouts
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completions = await managed.completion(
prompt=prompt_str,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.8, # Temperature for diverse responses during training rollouts
)
state = managed.get_state()
nodes = state["nodes"]
except Exception as e:
print(f"ERROR: Exception during completion generation: {e}")
return None, []
to_score_list = []
for choice in completions.choices:
for i, choice in enumerate(completions.choices):
trajectory_messages = [dict(msg_fset) for msg_fset in item[0]] # Fresh copy
trajectory_messages.append({"role": "assistant", "content": choice.text})
to_score_list.append(
(tuple(trajectory_messages), answer_info)
{
"messages": tuple(trajectory_messages),
"answer_info": answer_info,
"tokens": nodes[i].tokens,
"masks": nodes[i].masked_tokens,
"logprobs": nodes[i].logprobs,
}
) # Pass answer_info
if not to_score_list:
@ -677,7 +688,9 @@ class InstructionFollowingEnv(BaseEnv):
# If scored_data is None, it might be because the group was skipped for being too easy
# We need to calculate the scores ourselves to handle the item properly
temp_scores = []
for trajectory_messages, answer_info in to_score_list:
for rollout_item in to_score_list:
trajectory_messages = rollout_item["messages"]
answer_info = rollout_item["answer_info"]
model_response_text = trajectory_messages[-1]["content"]
func_name = answer_info["func_name"]
args_for_verifier = answer_info["args"]