mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
convert instruction following env to use managed server
This commit is contained in:
parent
4738fabd57
commit
6d6a02eb38
1 changed files with 22 additions and 9 deletions
|
|
@ -644,22 +644,33 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
)
|
||||
|
||||
try:
|
||||
completions = await self.server.completion(
|
||||
prompt=prompt_str,
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.8, # Temperature for diverse responses during training rollouts
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
completions = await managed.completion(
|
||||
prompt=prompt_str,
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.8, # Temperature for diverse responses during training rollouts
|
||||
)
|
||||
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Exception during completion generation: {e}")
|
||||
return None, []
|
||||
|
||||
to_score_list = []
|
||||
for choice in completions.choices:
|
||||
for i, choice in enumerate(completions.choices):
|
||||
trajectory_messages = [dict(msg_fset) for msg_fset in item[0]] # Fresh copy
|
||||
trajectory_messages.append({"role": "assistant", "content": choice.text})
|
||||
to_score_list.append(
|
||||
(tuple(trajectory_messages), answer_info)
|
||||
{
|
||||
"messages": tuple(trajectory_messages),
|
||||
"answer_info": answer_info,
|
||||
"tokens": nodes[i].tokens,
|
||||
"masks": nodes[i].masked_tokens,
|
||||
"logprobs": nodes[i].logprobs,
|
||||
}
|
||||
) # Pass answer_info
|
||||
|
||||
if not to_score_list:
|
||||
|
|
@ -677,7 +688,9 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
# If scored_data is None, it might be because the group was skipped for being too easy
|
||||
# We need to calculate the scores ourselves to handle the item properly
|
||||
temp_scores = []
|
||||
for trajectory_messages, answer_info in to_score_list:
|
||||
for rollout_item in to_score_list:
|
||||
trajectory_messages = rollout_item["messages"]
|
||||
answer_info = rollout_item["answer_info"]
|
||||
model_response_text = trajectory_messages[-1]["content"]
|
||||
func_name = answer_info["func_name"]
|
||||
args_for_verifier = answer_info["args"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue