mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
convert tool_calling_server to managedserver
This commit is contained in:
parent
0a3c15c7ad
commit
d8c68a93e3
1 changed file with 42 additions and 27 deletions
|
|
@@ -15,7 +15,6 @@ from atroposlib.envs.base import (
|
||||||
Item,
|
Item,
|
||||||
ScoredDataGroup,
|
ScoredDataGroup,
|
||||||
)
|
)
|
||||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
||||||
|
|
||||||
system_prompt = (
|
system_prompt = (
|
||||||
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
|
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
|
||||||
|
|
@@ -157,14 +156,15 @@ class SingleToolCallingEnv(BaseEnv):
|
||||||
messages, add_generation_prompt=True, tokenize=False
|
messages, add_generation_prompt=True, tokenize=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get model completion using completion() instead of chat_completion()
|
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||||
completion = await self.server.completion(
|
# Get model completion using completion() instead of chat_completion()
|
||||||
prompt=prompt,
|
completion = await managed.completion(
|
||||||
n=1,
|
prompt=prompt,
|
||||||
max_tokens=1024 * 15,
|
n=1,
|
||||||
temperature=1.0,
|
max_tokens=1024 * 15,
|
||||||
split="eval",
|
temperature=1.0,
|
||||||
)
|
split="eval",
|
||||||
|
)
|
||||||
|
|
||||||
# Extract the model's response from the completion
|
# Extract the model's response from the completion
|
||||||
model_response = completion.choices[0].text
|
model_response = completion.choices[0].text
|
||||||
|
|
@@ -289,13 +289,18 @@ class SingleToolCallingEnv(BaseEnv):
|
||||||
messages, add_generation_prompt=True, tokenize=False
|
messages, add_generation_prompt=True, tokenize=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get completions from the model using completion() instead of chat_completion()
|
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||||
completions = await self.server.completion(
|
# Get completions from the model using completion() instead of chat_completion()
|
||||||
prompt=prompt,
|
completions = await managed.completion(
|
||||||
n=self.config.group_size,
|
prompt=prompt,
|
||||||
max_tokens=1024 * 15,
|
n=self.config.group_size,
|
||||||
temperature=0.8, # Using temperature to get diverse responses
|
max_tokens=1024 * 15,
|
||||||
)
|
temperature=0.8, # Using temperature to get diverse responses
|
||||||
|
)
|
||||||
|
|
||||||
|
state = managed.get_state()
|
||||||
|
nodes = state["nodes"]
|
||||||
|
|
||||||
to_score = list()
|
to_score = list()
|
||||||
|
|
||||||
for i, completion_choice in enumerate(completions.choices):
|
for i, completion_choice in enumerate(completions.choices):
|
||||||
|
|
@@ -311,10 +316,13 @@ class SingleToolCallingEnv(BaseEnv):
|
||||||
|
|
||||||
# Add to scoring queue with expected answer
|
# Add to scoring queue with expected answer
|
||||||
to_score.append(
|
to_score.append(
|
||||||
(
|
{
|
||||||
tuple(trajectory_messages),
|
"messages": tuple(trajectory_messages),
|
||||||
item[1], # The expected tool call JSON
|
"expected_tool_call": item[1], # The expected tool call JSON
|
||||||
)
|
"tokens": nodes[i].tokens,
|
||||||
|
"masks": nodes[i].masked_tokens,
|
||||||
|
"logprobs": nodes[i].logprobs,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call score to get the scored data
|
# Call score to get the scored data
|
||||||
|
|
@@ -330,9 +338,12 @@ class SingleToolCallingEnv(BaseEnv):
|
||||||
scores["tokens"] = list()
|
scores["tokens"] = list()
|
||||||
scores["masks"] = list()
|
scores["masks"] = list()
|
||||||
scores["scores"] = list()
|
scores["scores"] = list()
|
||||||
|
scores["inference_logprobs"] = list()
|
||||||
|
|
||||||
# Extract the expected JSONs from the answer
|
# Extract the expected JSONs from the answer
|
||||||
expected_jsons = self._extract_tool_call_jsons(rollout_group_data[0][1])
|
expected_jsons = self._extract_tool_call_jsons(
|
||||||
|
rollout_group_data[0]["expected_tool_call"]
|
||||||
|
)
|
||||||
|
|
||||||
# If we can't extract the expected tool call JSONs, skip this item
|
# If we can't extract the expected tool call JSONs, skip this item
|
||||||
if not expected_jsons:
|
if not expected_jsons:
|
||||||
|
|
@@ -343,15 +354,18 @@ class SingleToolCallingEnv(BaseEnv):
|
||||||
|
|
||||||
for item in rollout_group_data:
|
for item in rollout_group_data:
|
||||||
# Extract the model's response
|
# Extract the model's response
|
||||||
model_response = item[0][-1]["content"]
|
model_response = item["messages"][-1]["content"]
|
||||||
|
|
||||||
# Score 1 if tool calls match, 0 otherwise
|
# Score 1 if tool calls match, 0 otherwise
|
||||||
reward = 1 if self._compare_tool_calls(model_response, item[1]) else 0
|
reward = (
|
||||||
|
1
|
||||||
|
if self._compare_tool_calls(model_response, item["expected_tool_call"])
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
# Tokenize the conversation for learning
|
tokens = item["tokens"]
|
||||||
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
|
masks = item["masks"]
|
||||||
tokens = out_dict["tokens"]
|
logprobs = item["logprobs"]
|
||||||
masks = out_dict["masks"]
|
|
||||||
|
|
||||||
# Remove examples with insufficient context
|
# Remove examples with insufficient context
|
||||||
if len([1 for i in masks if i != -100]) < 10:
|
if len([1 for i in masks if i != -100]) < 10:
|
||||||
|
|
@@ -359,6 +373,7 @@ class SingleToolCallingEnv(BaseEnv):
|
||||||
|
|
||||||
scores["tokens"].append(tokens)
|
scores["tokens"].append(tokens)
|
||||||
scores["masks"].append(masks)
|
scores["masks"].append(masks)
|
||||||
|
scores["inference_logprobs"].append(logprobs)
|
||||||
scores["scores"].append(1.0 if reward else -1.0)
|
scores["scores"].append(1.0 if reward else -1.0)
|
||||||
|
|
||||||
# Break once we have enough examples
|
# Break once we have enough examples
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue