diff --git a/environments/mcp_env.py b/environments/mcp_env.py
new file mode 100644
index 00000000..3d2601cc
--- /dev/null
+++ b/environments/mcp_env.py
@@ -0,0 +1,464 @@
+import json
+import random
+import re
+from typing import Dict, List, Optional, Tuple, Union
+
+import wandb
+from datasets import load_dataset
+from tqdm.asyncio import tqdm_asyncio
+
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    EvalHandlingEnum,
+    Item,
+    ScoredDataGroup,
+)
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+
+system_prompt = (
+    "You are an AI assistant capable of using tools to answer requests. "
+    "When a tool is required, you must generate a single JSON object specifying the tool and its arguments, "
+    "wrapped in <tool_call> </tool_call> tags. "
+    "The JSON format is: {\"tool_name\": \"<tool_name>\", \"arguments\": {}}. "
+    "Do not output any text before or after this JSON object. "
+    "You may use <think> </think> tags for your internal reasoning before producing the JSON output."
+)
+
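+# Example of the kind of completion this prompt is meant to elicit (illustrative
+# only; the tool name and arguments are made up for demonstration):
+#   <think>The user wants current weather, so I should call the weather tool.</think>
+#   <tool_call>{"tool_name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>
+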
+
+class SingleToolCallingEnv(BaseEnv):
+    def __init__(
+        self,
+        config: BaseEnvConfig,
+        server_configs: List[APIServerConfig],
+        slurm=True,
+        testing=False,
+    ):
+        super().__init__(config, server_configs, slurm, testing)
+        self.percent_correct_buffer = list()
+        self.eval_metrics = list()
+        # Add tracking for wandb visualizations
+        self.rollouts_for_wandb = []
+        self.completion_lengths = []
+
+    @classmethod
+    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
+        env_config = BaseEnvConfig(
+            tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
+            group_size=32,
+            use_wandb=True,
+            rollout_server_url="http://localhost:8000",
+            total_steps=2000,
+            batch_size=1024,
+            steps_per_eval=20,
+            max_token_length=1024 * 16,
+            inference_weight=1.0,
+            wandb_name="toolcall_think",
+            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
+            eval_limit_ratio=0.1,
+        )
+        server_configs = [
+            APIServerConfig(
+                model_name="Qwen/Qwen2.5-1.5B-Instruct",
+                base_url="http://localhost:9004/v1",
+                api_key="x",
+                num_max_requests_at_once=32,
+                num_requests_for_eval=256,
+            ),
+            APIServerConfig(
+                model_name="Qwen/Qwen2.5-1.5B-Instruct",
+                base_url="http://localhost:9005/v1",
+                api_key="x",
+                num_max_requests_at_once=32,
+                num_requests_for_eval=256,
+            ),
+        ]
+
+        return env_config, server_configs
+
+    async def create_rollout_table(self, wandb_metrics):
+        if len(self.rollouts_for_wandb) > 0:
+            table = wandb.Table(columns=["text", "score", "expected_tool_call"])
+            for group in self.rollouts_for_wandb:
+                for item in group:
+                    table.add_data(item[0], item[1], item[2])
+            wandb_metrics["train/rollouts"] = table
+
+        self.rollouts_for_wandb = []
+        return wandb_metrics
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        """
+        Log to wandb with comprehensive metrics.
+        """
+        if wandb_metrics is None:
+            wandb_metrics = dict()
+
+        # Try to calculate percent_correct, skip if there's a division by zero
+        try:
+            wandb_metrics["train/percent_correct"] = sum(
+                self.percent_correct_buffer
+            ) / len(self.percent_correct_buffer)
+        except ZeroDivisionError:
+            # Skip if buffer is empty
+            pass
+
+        self.percent_correct_buffer = list()
+        for item in self.eval_metrics:
+            wandb_metrics[item[0]] = item[1]
+        self.eval_metrics = list()
+        await super().wandb_log(wandb_metrics)
+
+    async def setup(self):
+        # Load the full dataset
+        full_dataset = load_dataset(
+            "NousResearch/XLAM-Atropos",
+            "default",
+            split="train",
+        )
+
+        full_dataset = full_dataset.shuffle(seed=42)
+
+        # Create train/test split on the fly (98% train, 2% test)
+        split_dataset = full_dataset.train_test_split(test_size=0.02, seed=42)
+
+        # Keep the splits as is - no need to reformat
+        self.train = split_dataset["train"]
+        self.test = split_dataset["test"]
+
+        self.iter = 0
+
+    async def rollout_and_score_eval(self, test_item):
+        # Extract conversations from the test item
+        conversations = test_item["conversations"]
+
+        # Find the system, human, and expected assistant messages
+        system_message = next(
+            (msg for msg in conversations if msg["from"] == "system"), None
+        )
+        human_message = next(
+            (msg for msg in conversations if msg["from"] == "human"), None
+        )
+        expected_gpt_message = next(
+            (msg for msg in conversations if msg["from"] == "gpt"), None
+        )
+
+        if not human_message or not expected_gpt_message:
+            return 0  # Skip invalid conversations
+
+        # Create messages for the model
+        messages = []
+        if system_message:
+            messages.append(
+                {
+                    "role": "system",
+                    "content": system_prompt + "\n\n" + system_message["value"],
+                }
+            )
+        messages.append({"role": "user", "content": human_message["value"]})
+
+        # Apply chat template to convert messages to a single string
+        prompt = self.tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=False
+        )
+
+        # Get model completion using completion() instead of chat_completion()
+        completion = await self.server.completion(
+            prompt=prompt,
+            n=1,
+            max_tokens=1024 * 15,
+            temperature=1.0,
+            split="eval",
+        )
+
+        # Extract the model's response from the completion
+        model_response = completion.choices[0].text
+        expected_response = expected_gpt_message["value"]
+
+        # Extract and compare tool calls
+        score = self._compare_tool_calls(model_response, expected_response)
+        return score
+
+    def _extract_tool_call_jsons(self, text):
+        """
+        Extract multiple JSONs from within <tool_call> tags.
+
+        Args:
+            text: Text containing tool calls
+
+        Returns:
+            List of parsed JSON objects, or an empty list if extraction/parsing fails
+        """
+        # Find all content between <tool_call> tags
+        matches = re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", text, re.DOTALL)
+        tool_calls = []
+
+        for match in matches:
+            try:
+                # Parse the JSON content
+                tool_call = json.loads(match)
+                tool_calls.append(tool_call)
+            except json.JSONDecodeError:
+                # Skip invalid JSON but continue processing other matches
+                continue
+
+        return tool_calls
+
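+    # Illustrative example of the extractor's behavior (the tool name and arguments
+    # are made up): given the text
+    #   '<tool_call>{"tool_name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
+    # _extract_tool_call_jsons returns
+    #   [{"tool_name": "get_weather", "arguments": {"city": "Paris"}}].
+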
+    def _compare_tool_calls(self, model_response, expected_response):
+        """
+        Compare multiple tool calls by extracting JSONs from <tool_call> tags and
+        comparing their content.
+
+        Returns:
+            1 if all tool calls match (all expected calls are present with correct
+            values), 0 otherwise
+        """
+        # Extract JSONs from the model response. The expected side may already be a
+        # parsed dict (as supplied by get_next_item) or a list of dicts; only strings
+        # need to be parsed out of <tool_call> tags.
+        model_jsons = self._extract_tool_call_jsons(model_response)
+        if isinstance(expected_response, dict):
+            expected_jsons = [expected_response]
+        elif isinstance(expected_response, list):
+            expected_jsons = expected_response
+        else:
+            expected_jsons = self._extract_tool_call_jsons(expected_response)
+
+        # If we couldn't extract any JSONs on either side, return 0
+        if not model_jsons or not expected_jsons:
+            return 0
+
+        # Copy the expected_jsons to avoid modifying the original
+        remaining_expected_jsons = expected_jsons.copy()
+
+        # For each model JSON, try to find a matching expected JSON
+        for model_json in model_jsons:
+            found_match = False
+
+            for i, expected_json in enumerate(remaining_expected_jsons):
+                if self._json_objects_match(model_json, expected_json):
+                    # Remove the matched expected JSON
+                    remaining_expected_jsons.pop(i)
+                    found_match = True
+                    break
+
+            # If no match was found for this model JSON, return 0
+            if not found_match:
+                return 0
+
+        # If we've matched all expected JSONs (none remaining), return 1
+        return 1 if not remaining_expected_jsons else 0
+
+    def _json_objects_match(self, json1, json2):
+        """
+        Check if two JSON objects match, with all fields in json2 existing in json1
+        with the same values.
+
+        Args:
+            json1: First JSON object (model output)
+            json2: Second JSON object (expected values)
+
+        Returns:
+            True if the objects match, False otherwise
+        """
+        try:
+            # Check that all expected fields are in the model response
+            for key in json2:
+                if key not in json1:
+                    return False
+
+                # For nested dictionaries (like 'arguments'), check all values
+                if isinstance(json2[key], dict) and isinstance(json1[key], dict):
+                    for arg_key in json2[key]:
+                        if arg_key not in json1[key]:
+                            return False
+                        if json2[key][arg_key] != json1[key][arg_key]:
+                            return False
+                # For non-dictionary fields, check direct equality
+                elif json2[key] != json1[key]:
+                    return False
+
+            # All checks passed
+            return True
+        except Exception:
+            # Any error in comparison counts as failure
+            return False
+
+    async def evaluate(self, *args, **kwargs):
+        eval_tasks = []
+        for test_item in self.test:
+            eval_tasks.append(self.rollout_and_score_eval(test_item))
+        scores = await tqdm_asyncio.gather(*eval_tasks)
+        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
+
+    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
+        # Extract messages from the item
+        messages = []
+        for role_dict in item[0]:
+            messages.append(dict(role_dict))
+
+        # Apply chat template to convert messages to a single string
+        prompt = self.tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=False
+        )
+
+        # Get completions from the model using completion() instead of chat_completion()
+        completions = await self.server.completion(
+            prompt=prompt,
+            n=self.config.group_size,
+            max_tokens=1024 * 15,
+            temperature=0.8,  # Using temperature to get diverse responses
+        )
+
+        to_score = list()
+
+        for i, completion_choice in enumerate(completions.choices):
+            # Create a copy of the prompt messages
+            trajectory_messages = []
+            for role_dict in item[0]:
+                trajectory_messages.append(dict(role_dict))
+
+            # Add the model's response
+            trajectory_messages.append(
+                {"role": "assistant", "content": completion_choice.text}
+            )
+
+            # Add to the scoring queue with the expected answer
+            to_score.append(
+                (
+                    tuple(trajectory_messages),
+                    item[1],  # The expected tool call
+                )
+            )
+
+        # Call score to get the scored data
+        scored_data = await self.score(to_score)
+        to_backlog = []
+
+        return scored_data, to_backlog
+
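+    # Scoring overview (implemented below): each rollout is rewarded 1.0 when its
+    # tool call(s) match the expected call and -1.0 otherwise. When every rollout in
+    # the group is correct, a length-based penalty replaces the uniform reward so
+    # shorter correct completions are preferred; groups with no variance in score
+    # are dropped, since they carry no learning signal.
+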
+    async def score(
+        self, rollout_group_data
+    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
+        scores = ScoredDataGroup()
+        scores["tokens"] = list()
+        scores["masks"] = list()
+        scores["scores"] = list()
+
+        # Extract the expected tool call(s) from the answer. The answer may be a
+        # parsed dict (from get_next_item) or a string containing <tool_call> tags
+        # (as in the eval data); normalize it to a list of dicts.
+        expected_answer = rollout_group_data[0][1]
+        if isinstance(expected_answer, dict):
+            expected_jsons = [expected_answer]
+        else:
+            expected_jsons = self._extract_tool_call_jsons(expected_answer)
+
+        # If we can't extract the expected tool call JSONs, skip this item
+        if not expected_jsons:
+            return None
+
+        # Shuffle to avoid bias in selection
+        random.shuffle(rollout_group_data)
+
+        for item in rollout_group_data:
+            # Extract the model's response
+            model_response = item[0][-1]["content"]
+
+            # Score 1 if the tool calls match, 0 otherwise
+            reward = 1 if self._compare_tool_calls(model_response, item[1]) else 0
+
+            # Tokenize the conversation for learning
+            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
+            tokens = out_dict["tokens"]
+            masks = out_dict["masks"]
+
+            # Remove examples with insufficient context
+            if len([1 for i in masks if i != -100]) < 10:
+                continue
+
+            scores["tokens"].append(tokens)
+            scores["masks"].append(masks)
+            scores["scores"].append(1.0 if reward else -1.0)
+
+            # Break once we have enough examples
+            if len(scores["tokens"]) >= self.config.group_size:
+                break
+
+        # If nothing survived filtering, there is nothing to train on
+        if not scores["tokens"]:
+            return None
+
+        # Record success rate metrics
+        for score in scores["scores"]:
+            self.percent_correct_buffer.append(max(score, 0))
+
+        # Apply a length penalty if all responses are correct
+        if all([score == 1.0 for score in scores["scores"]]):
+            # Calculate token lengths
+            token_lengths = [len(token) for token in scores["tokens"]]
+            if max(token_lengths) == 0:
+                # Edge case protection
+                return None
+
+            # Get the max allowed token length from config
+            max_allowed_length = self.config.max_token_length
+            # Set the threshold at 50% of max_token_length - no penalty below this
+            length_threshold = max_allowed_length * 0.5
+
+            # Apply the modified length penalty with threshold
+            scores["scores"] = []
+            for length in token_lengths:
+                if length <= length_threshold:
+                    # No penalty for responses under the threshold
+                    scores["scores"].append(1.0)
+                else:
+                    # How far we are between the threshold and the max, as a fraction
+                    percentage_of_range = (length - length_threshold) / (
+                        max_allowed_length - length_threshold
+                    )
+                    # Cap at 1.0 in case the length exceeds max_allowed_length
+                    percentage_of_range = min(percentage_of_range, 1.0)
+                    # Apply a linear penalty scaling from 1.0 down to 0.0
+                    scores["scores"].append(1.0 - percentage_of_range)
+
+        # Check if all scores are the same (no learning signal)
+        if all(scores["scores"][0] == score for score in scores["scores"]):
+            return None
+
+        return scores
+
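+    # Assumed record layout for the MCP dataset consumed below. The field names come
+    # from get_next_item; the values are made up for illustration:
+    #   {
+    #       "user_prompt_text": "What's the weather in Paris?",
+    #       "expected_mcp_call": {"tool_name": "get_weather", "arguments": {"city": "Paris"}},
+    #   }
+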
+    async def get_next_item(self):
+        next_item_data = self.train[self.iter % len(self.train)]
+        self.iter += 1
+
+        user_prompt_content = next_item_data["user_prompt_text"]  # user request text
+        expected_mcp_call_dict = next_item_data["expected_mcp_call"]  # expected call dict
+
+        prompt_messages = []
+        # Prepend the shared system prompt so training rollouts receive the same
+        # tool-calling instructions as the eval path. A per-item system message
+        # (e.g. tool schemas) could be appended here as well if the dataset provides one.
+        prompt_messages.append(
+            frozenset({"role": "system", "content": system_prompt}.items())
+        )
+        prompt_messages.append(
+            frozenset({"role": "user", "content": user_prompt_content}.items())
+        )
+
+        # The answer is the expected MCP tool call as a Python dict; score() and
+        # _compare_tool_calls accept it in this form.
+        answer = expected_mcp_call_dict
+
+        return (tuple(prompt_messages), answer)
+
+    async def add_rollouts_for_wandb(
+        self,
+        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
+        item: Item = None,
+    ):
+        # Save rollouts for later logging to the wandb rollout table
+        num_keep = self.config.num_rollouts_per_group_for_logging
+        if num_keep == -1:
+            num_keep = self.config.group_size
+        # Don't index past the number of rollouts actually kept by score()
+        num_keep = min(num_keep, len(scored_data["tokens"]))
+        self.rollouts_for_wandb.append(
+            [
+                (
+                    self.tokenizer.decode(scored_data["tokens"][i]),
+                    scored_data["scores"][i],
+                    item[1],  # Just keep the expected tool call
+                )
+                for i in range(num_keep)
+            ]
+        )
+        if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
+            self.rollouts_for_wandb.pop(0)
+
+
+if __name__ == "__main__":
+    SingleToolCallingEnv.cli()