more linting

This commit is contained in:
Shannon Sands 2025-05-27 13:06:34 +10:00
parent 46892c7bdc
commit bfdf862829
2 changed files with 35 additions and 18 deletions

View file

@@ -3,10 +3,9 @@ import random
 import re
 from typing import Dict, List, Optional, Tuple, Union
-import wandb
 from datasets import load_dataset
 from tqdm.asyncio import tqdm_asyncio
+import wandb
 from atroposlib.envs.base import (
     APIServerConfig,
     BaseEnv,

View file

@@ -1,6 +1,4 @@
 import json
-import random
-import re
 from typing import Dict, List, Optional, Tuple, Union
 import wandb
@@ -20,7 +18,7 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 system_prompt = (
     "You are an AI assistant capable of using tools to answer requests. "
     "When a tool is required, you must generate a single JSON object specifying the tool and its arguments. "
-    "The JSON format is: {\"tool_name\": \"<tool_name>\", \"arguments\": {<key_value_args>}}. "
+    'The JSON format is: {"tool_name": "<tool_name>", "arguments": {<key_value_args>}}. '
     "Do not output any text before or after this JSON object. "
     "You may use <think></think> tags for your internal reasoning before producing the JSON output."
 )
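
The prompt above fully specifies the output contract, so a conforming completion can be checked mechanically. A minimal sketch of such a completion; the get_weather tool and its arguments are hypothetical, invented for illustration:

import json

# A response shaped the way the system prompt demands: optional <think>
# reasoning followed by exactly one JSON object, with nothing after it.
model_output = (
    "<think>The user asked about weather, so call the weather tool.</think>"
    '{"tool_name": "get_weather", "arguments": {"city": "Sydney"}}'
)

# Drop the optional <think> block, then parse the trailing JSON -- the
# same two steps _extract_mcp_tool_call performs further down this diff.
payload = model_output.split("</think>", 1)[-1].strip()
call = json.loads(payload)
assert call["tool_name"] == "get_weather"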
@@ -57,8 +55,8 @@ class McpEnv(BaseEnv):
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
             eval_limit_ratio=0.1,
             dataset_path="my_mcp_dataset.json",  # ADDED: Path to your JSON dataset
-            num_rollouts_per_group_for_logging=4, # Added for logging
-            num_rollouts_to_keep=4 # Added for logging
+            num_rollouts_per_group_for_logging=4,  # Added for logging
+            num_rollouts_to_keep=4,  # Added for logging
         )
         server_configs = [
             APIServerConfig(
@@ -110,7 +108,7 @@
         for item in self.eval_metrics:
             wandb_metrics[item[0]] = item[1]
         self.eval_metrics = list()
-        wandb_metrics = await self.create_rollout_table(wandb_metrics) # Moved here
+        wandb_metrics = await self.create_rollout_table(wandb_metrics)  # Moved here
         await super().wandb_log(wandb_metrics)

     async def setup(self):
@@ -122,12 +120,16 @@
         #     streaming=False,
         #     split="train",
         # )
-        with open(self.config.dataset_path, 'r') as f: # Load dataset from path
-            full_dataset_list = json.load(f)
+        with open(self.config.dataset_path, "r") as f:  # Load dataset from path
+            full_dataset_list = json.load(f)
-        full_dataset = load_dataset("json", data_files={"train": self.config.dataset_path})["train"] # Load JSON directly. This is for using 'datasets' methods
+        full_dataset = load_dataset(
+            "json", data_files={"train": self.config.dataset_path}
+        )[
+            "train"
+        ]  # Load JSON directly. This is for using 'datasets' methods
         # full_dataset = full_dataset.shuffle(seed=42) # shuffle here
-        full_dataset = full_dataset.shuffle(seed=42) # Shuffling datasets object
+        full_dataset = full_dataset.shuffle(seed=42)  # Shuffling datasets object

         # Create train/test split on the fly (e.g., 95% train, 5% test)
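
The datasets json builder accepts a file containing a top-level JSON array of objects, which is also what the json.load call above implies. A sketch of building and loading such a file; the prompt and expected_call field names are assumptions for illustration, since the record schema is not visible in this diff:

import json

from datasets import load_dataset

# Three illustrative records; "prompt" and "expected_call" are assumed
# field names, not a schema confirmed by this repository.
records = [
    {
        "prompt": f"What's the weather in {city}?",
        "expected_call": {
            "tool_name": "get_weather",
            "arguments": {"city": city},
        },
    }
    for city in ["Sydney", "Tokyo", "Oslo"]
]
with open("my_mcp_dataset.json", "w") as f:
    json.dump(records, f)

ds = load_dataset("json", data_files={"train": "my_mcp_dataset.json"})["train"]
ds = ds.shuffle(seed=42)
split = ds.train_test_split(test_size=0.05, seed=42)  # 95% train / 5% test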
@@ -151,7 +153,11 @@
         completion = await self.server.completion(prompt=prompt, n=1, max_tokens=1024)
         model_response = completion.choices[0].text
-        score_value = 1.0 if self._compare_mcp_tool_calls(model_response, expected_mcp_call_dict) else 0.0
+        score_value = (
+            1.0
+            # _compare_mcp_tool_calls is async, so it must be awaited here;
+            # a bare call returns an always-truthy coroutine.
+            if await self._compare_mcp_tool_calls(
+                model_response, expected_mcp_call_dict
+            )
+            else 0.0
+        )
         return score_value

     async def _extract_mcp_tool_call(self, model_response_text: str) -> Optional[Dict]:
@@ -162,7 +168,9 @@
             # If the model includes <think> tags, strip them first if they are outside the JSON
             # For simplicity, assuming the model's primary output after <think> is the JSON
             if "</think>" in model_response_text:
-                model_response_text = model_response_text.split("</think>", 1)[-1].strip()
+                model_response_text = model_response_text.split("</think>", 1)[
+                    -1
+                ].strip()

             return json.loads(model_response_text)
         except json.JSONDecodeError:
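
One property of this extraction worth noting: json.loads only succeeds when the JSON object is the entire remaining string, so any prose after the object lands in the JSONDecodeError branch and the method returns None. A standalone illustration:

import json

clean = '<think>pick a tool</think>{"tool_name": "t", "arguments": {}}'
payload = clean.split("</think>", 1)[-1].strip()
print(json.loads(payload))  # {'tool_name': 't', 'arguments': {}}

noisy = '{"tool_name": "t", "arguments": {}} Hope that helps!'
try:
    json.loads(noisy)
except json.JSONDecodeError:
    print("trailing text -> JSONDecodeError -> extraction returns None")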
@@ -170,7 +178,9 @@
         except Exception:  # Other potential errors
             return None

-    async def _compare_mcp_tool_calls(self, model_response_text: str, expected_mcp_call_dict: Dict) -> bool:
+    async def _compare_mcp_tool_calls(
+        self, model_response_text: str, expected_mcp_call_dict: Dict
+    ) -> bool:
         """
         Compares the model's generated MCP tool call with the expected one.

         Returns:
@@ -247,7 +257,9 @@
         scores = await tqdm_asyncio.gather(*eval_tasks)
         self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))

-    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]: # this one
+    async def collect_trajectories(
+        self, item
+    ) -> Tuple[ScoredDataGroup, List]:  # this one
         # Extract messages from the item
         messages = []
         for role_dict in item[0]:
@@ -309,12 +321,18 @@
         #     print("rollout_group_data is empty, skipping")
         #     return None

-        #NEW CODE THAT IS WORKING!!!!
+        # NEW CODE THAT IS WORKING!!!!
         for item in rollout_group_data:
             model_response = item[0][-1]["content"]
             expected_mcp_call_dict = item[1]
-            reward = 1.0 if await self._compare_mcp_tool_calls(model_response, expected_mcp_call_dict) else -1.0
+            reward = (
+                1.0
+                if await self._compare_mcp_tool_calls(
+                    model_response, expected_mcp_call_dict
+                )
+                else -1.0
+            )

             out_dict = tokenize_for_trainer(self.tokenizer, item[0])
             tokens = out_dict["tokens"]
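
The body of _compare_mcp_tool_calls never appears in this diff, only its signature and its call sites (strict 1.0/0.0 scoring in eval, 1.0/-1.0 as the training reward). A minimal sketch consistent with those call sites, written synchronously for brevity; this is an assumption, not the repository's implementation:

import json
from typing import Dict, Optional


def extract_tool_call(text: str) -> Optional[Dict]:
    """Parse the JSON tool call, ignoring an optional <think> prefix."""
    if "</think>" in text:
        text = text.split("</think>", 1)[-1].strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None


def compare_tool_calls(model_text: str, expected: Dict) -> bool:
    """True only when tool name and arguments both match exactly."""
    got = extract_tool_call(model_text)
    if got is None:
        return False
    return (
        got.get("tool_name") == expected.get("tool_name")
        and got.get("arguments") == expected.get("arguments")
    )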