mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
more linting
This commit is contained in:
parent
46892c7bdc
commit
bfdf862829
2 changed files with 35 additions and 18 deletions
|
|
@ -3,10 +3,9 @@ import random
|
|||
import re
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
import wandb
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
import json
|
||||
import random
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import wandb
|
||||
|
|
@ -20,7 +18,7 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|||
system_prompt = (
|
||||
"You are an AI assistant capable of using tools to answer requests. "
|
||||
"When a tool is required, you must generate a single JSON object specifying the tool and its arguments. "
|
||||
"The JSON format is: {\"tool_name\": \"<tool_name>\", \"arguments\": {<key_value_args>}}. "
|
||||
'The JSON format is: {"tool_name": "<tool_name>", "arguments": {<key_value_args>}}. '
|
||||
"Do not output any text before or after this JSON object. "
|
||||
"You may use <think></think> tags for your internal reasoning before producing the JSON output."
|
||||
)
|
||||
|
|
@ -57,8 +55,8 @@ class McpEnv(BaseEnv):
|
|||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
dataset_path="my_mcp_dataset.json", # ADDED: Path to your JSON dataset
|
||||
num_rollouts_per_group_for_logging=4, # Added for logging
|
||||
num_rollouts_to_keep=4 # Added for logging
|
||||
num_rollouts_per_group_for_logging=4, # Added for logging
|
||||
num_rollouts_to_keep=4, # Added for logging
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
|
|
@ -110,7 +108,7 @@ class McpEnv(BaseEnv):
|
|||
for item in self.eval_metrics:
|
||||
wandb_metrics[item[0]] = item[1]
|
||||
self.eval_metrics = list()
|
||||
wandb_metrics = await self.create_rollout_table(wandb_metrics) # Moved here
|
||||
wandb_metrics = await self.create_rollout_table(wandb_metrics) # Moved here
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def setup(self):
|
||||
|
|
@ -122,12 +120,16 @@ class McpEnv(BaseEnv):
|
|||
# streaming=False,
|
||||
# split="train",
|
||||
# )
|
||||
with open(self.config.dataset_path, 'r') as f: # Load dataset from path
|
||||
full_dataset_list = json.load(f)
|
||||
with open(self.config.dataset_path, "r") as f: # Load dataset from path
|
||||
full_dataset_list = json.load(f)
|
||||
|
||||
full_dataset = load_dataset("json", data_files={"train": self.config.dataset_path})["train"] # Load JSON directly. This is for using 'datasets' methods
|
||||
full_dataset = load_dataset(
|
||||
"json", data_files={"train": self.config.dataset_path}
|
||||
)[
|
||||
"train"
|
||||
] # Load JSON directly. This is for using 'datasets' methods
|
||||
# full_dataset = full_dataset.shuffle(seed=42) # shuffle here
|
||||
full_dataset = full_dataset.shuffle(seed=42) # Shuffling datasets object
|
||||
full_dataset = full_dataset.shuffle(seed=42) # Shuffling datasets object
|
||||
|
||||
|
||||
# Create train/test split on the fly (e.g., 95% train, 5% test)
|
||||
|
|
@ -151,7 +153,11 @@ class McpEnv(BaseEnv):
|
|||
completion = await self.server.completion(prompt=prompt, n=1, max_tokens=1024)
|
||||
model_response = completion.choices[0].text
|
||||
|
||||
score_value = 1.0 if self._compare_mcp_tool_calls(model_response, expected_mcp_call_dict) else 0.0
|
||||
score_value = (
|
||||
1.0
|
||||
if self._compare_mcp_tool_calls(model_response, expected_mcp_call_dict)
|
||||
else 0.0
|
||||
)
|
||||
return score_value
|
||||
|
||||
async def _extract_mcp_tool_call(self, model_response_text: str) -> Optional[Dict]:
|
||||
|
|
@ -162,7 +168,9 @@ class McpEnv(BaseEnv):
|
|||
# If the model includes <think> tags, strip them first if they are outside the JSON
|
||||
# For simplicity, assuming the model's primary output after <think> is the JSON
|
||||
if "</think>" in model_response_text:
|
||||
model_response_text = model_response_text.split("</think>", 1)[-1].strip()
|
||||
model_response_text = model_response_text.split("</think>", 1)[
|
||||
-1
|
||||
].strip()
|
||||
|
||||
return json.loads(model_response_text)
|
||||
except json.JSONDecodeError:
|
||||
|
|
@ -170,7 +178,9 @@ class McpEnv(BaseEnv):
|
|||
except Exception: # Other potential errors
|
||||
return None
|
||||
|
||||
async def _compare_mcp_tool_calls(self, model_response_text: str, expected_mcp_call_dict: Dict) -> bool:
|
||||
async def _compare_mcp_tool_calls(
|
||||
self, model_response_text: str, expected_mcp_call_dict: Dict
|
||||
) -> bool:
|
||||
"""
|
||||
Compares the model's generated MCP tool call with the expected one.
|
||||
Returns:
|
||||
|
|
@ -247,7 +257,9 @@ class McpEnv(BaseEnv):
|
|||
scores = await tqdm_asyncio.gather(*eval_tasks)
|
||||
self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
|
||||
|
||||
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]: # this one
|
||||
async def collect_trajectories(
|
||||
self, item
|
||||
) -> Tuple[ScoredDataGroup, List]: # this one
|
||||
# Extract messages from the item
|
||||
messages = []
|
||||
for role_dict in item[0]:
|
||||
|
|
@ -309,12 +321,18 @@ class McpEnv(BaseEnv):
|
|||
# print("rollout_group_data is empty, skipping")
|
||||
# return None
|
||||
|
||||
#NEW CODE THAT IS WORKING!!!!
|
||||
# NEW CODE THAT IS WORKING!!!!
|
||||
for item in rollout_group_data:
|
||||
model_response = item[0][-1]["content"]
|
||||
expected_mcp_call_dict = item[1]
|
||||
|
||||
reward = 1.0 if await self._compare_mcp_tool_calls(model_response, expected_mcp_call_dict) else -1.0
|
||||
reward = (
|
||||
1.0
|
||||
if await self._compare_mcp_tool_calls(
|
||||
model_response, expected_mcp_call_dict
|
||||
)
|
||||
else -1.0
|
||||
)
|
||||
|
||||
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
|
||||
tokens = out_dict["tokens"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue