diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py index 62f9df8b..815812d9 100644 --- a/atroposlib/envs/server_handling/server_baseline.py +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -35,7 +35,7 @@ class APIServerConfig(ServerBaseline): api_key: Optional[str] = Field(default="", description="API key for the server.") base_url: Optional[str] = Field(default="", description="Base URL for the server.") - server_type: Literal["openai"] = Field( + server_type: Literal["openai", "trl"] = Field( default="openai", description="Type of server to use, openai or trl" ) n_kwarg_is_ignored: bool = Field( diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index d81dd338..433c0f19 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -13,6 +13,7 @@ from atroposlib.envs.server_handling.server_baseline import ( ServerBaseline, ) from atroposlib.envs.server_handling.server_harness import ServerHarness +from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer class ServerManagerConfig(BaseModel): @@ -36,11 +37,15 @@ class ServerManager: if isinstance(configs, ServerBaseline): if configs.server_type == "openai": server_class = OpenAIServer + elif configs.server_type == "trl": + server_class = TrlVllmServer else: raise ValueError(f"Invalid server type: {configs.server_type}") else: if configs[0].server_type == "openai": server_class = OpenAIServer + elif configs[0].server_type == "trl": + server_class = TrlVllmServer else: raise ValueError(f"Invalid server type: {configs[0].server_type}") if testing: diff --git a/atroposlib/envs/server_handling/trl_vllm_server.py b/atroposlib/envs/server_handling/trl_vllm_server.py new file mode 100644 index 00000000..2dc18dd6 --- /dev/null +++ b/atroposlib/envs/server_handling/trl_vllm_server.py @@ -0,0 +1,281 @@ +""" +This is a server that interfaces with trl's vLLM server." + +Developed with much help from @winglian when they worked on integrating Atropos into Axolotl. +""" + +import asyncio +import time +import uuid +from typing import Optional + +import aiohttp +import numpy as np +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) +from openai.types.completion import Completion +from tenacity import retry, stop_after_attempt, wait_random_exponential +from transformers import AutoTokenizer + +from atroposlib.envs.server_handling.openai_server import AsyncSemWithAdaptiveWeight +from atroposlib.envs.server_handling.server_baseline import APIServerConfig + + +class TrlVllmServer: + """ + A server that interfaces with trl's vLLM server. 
+ """ + + def __init__(self, config: APIServerConfig): + self.config = config + self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once) + self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval) + self.server_healthy = True + self.attempts_list = [] + self.request_timings = [] + # in case eval is much different, we should keep different buffers + self.eval_attempts_list = [] + self.eval_request_timings = [] + self.check_task = None + self.initialized = False + self.tokenizer = AutoTokenizer.from_pretrained(config.model_name) + + async def update_weight(self, weight: float) -> None: + # need to update sems + self.sem.update_weight(weight) + self.eval_sem.update_weight(weight) + + async def check_server_status_task(self): + # TODO: Implement server health check for trl's vLLM server + self.server_healthy = True + + async def wandb_metrics( + self, metrics_dict: Optional[dict], server_name: Optional[str] + ): + if server_name is None: + server_name = "server" + if len(self.request_timings) > 0: + metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean( + self.request_timings + ) + metrics_dict[f"server/{server_name}_request_time_std"] = np.std( + self.request_timings + ) + metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile( + self.request_timings, 99 + ) + if len(self.eval_request_timings) > 0: + metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean( + self.eval_request_timings + ) + metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std( + self.eval_request_timings + ) + metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile( + self.eval_request_timings, 99 + ) + if len(self.attempts_list) > 0: + metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean( + self.attempts_list + ) + if len(self.eval_attempts_list) > 0: + metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean( + self.eval_attempts_list + ) + return metrics_dict + + async def _chat_handle(self, **kwargs) -> ChatCompletion: + url = f"{self.config.base_url}/generate/" + prompt = kwargs.get("messages", []) + prompt = self.tokenizer.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True + ) + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json={ + "prompts": [prompt], + "n": kwargs.get("n", 1), + "repetition_penalty": kwargs.get("repetition_penalty", 1.0), + "temperature": kwargs.get("temperature", 1.0), + "top_p": kwargs.get("top_p", 1.0), + "top_k": kwargs.get("top_k", -1), + "min_p": kwargs.get("min_p", 0.0), + "max_tokens": kwargs.get("max_tokens", 1024), + }, + ) as response: + completions = await response.json() + completions = ChatCompletion( + id=str(uuid.uuid4()), + object="chat.completion", + created=int(time.time()), + model=self.config.model_name, + choices=[ + Choice( + finish_reason=( + "stop" + if self.tokenizer.eos_token_id in completion + else "length" + ), + index=i, + message=ChatCompletionMessage( + content=self.tokenizer.decode(completion), + role="assistant", + ), + ) + for i, completion in enumerate(completions["completion_ids"]) + ], + ) + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion: + while not self.server_healthy: + await asyncio.sleep(1) + async with self.sem: + print(kwargs) + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + 
completions = await self._chat_handle(**kwargs) + stat_dict["end"] = time.time() + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion: + while not self.server_healthy: + await asyncio.sleep(1) + async with self.eval_sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self._chat_handle(**kwargs) + stat_dict["end"] = time.time() + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def chat_completion(self, **kwargs) -> ChatCompletion: + if not self.initialized: + if ( + self.config.base_url is not None + ): # skip health check if using OpenAI API + self.check_task = asyncio.create_task(self.check_server_status_task()) + else: + self.server_healthy = True + self.initialized = True + kwargs["model"] = self.config.model_name + split = kwargs.pop("split", "train") + stat_dict = {} + stat_dict["attempts"] = 0 + if split == "train": + ret_data = await self._chat_comp(stat_dict, **kwargs) + self.request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.attempts_list.append(stat_dict["attempts"]) + else: + # Give separate eval workers, if desired, gotta go fast for those evals + ret_data = await self._chat_eval(stat_dict, **kwargs) + self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.eval_attempts_list.append(stat_dict["attempts"]) + return ret_data + + async def _comp_handle(self, **kwargs) -> ChatCompletion: + url = f"{self.config.base_url}/generate/" + prompt = kwargs.get("prompt", "") + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json={ + "prompts": [prompt], + "n": kwargs.get("n", 1), + "repetition_penalty": kwargs.get("repetition_penalty", 1.0), + "temperature": kwargs.get("temperature", 1.0), + "top_p": kwargs.get("top_p", 1.0), + "top_k": kwargs.get("top_k", -1), + "min_p": kwargs.get("min_p", 0.0), + "max_tokens": kwargs.get("max_tokens", 1024), + }, + ) as response: + completions = await response.json() + completions = ChatCompletion( + id=str(uuid.uuid4()), + object="chat.completion", + created=int(time.time()), + model=self.config.model_name, + choices=[ + Choice( + finish_reason=( + "stop" + if self.tokenizer.eos_token_id in completion + else "length" + ), + index=i, + message=ChatCompletionMessage( + content=self.tokenizer.decode(completion), + role="assistant", + ), + ) + for i, completion in enumerate(completions["completion_ids"]) + ], + ) + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _comp(self, stat_dict, **kwargs) -> Completion: + while not self.server_healthy: + await asyncio.sleep(1) + async with self.sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self._comp_handle(**kwargs) + stat_dict["end"] = time.time() + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _comp_eval(self, stat_dict, **kwargs) -> Completion: + while not self.server_healthy: + await asyncio.sleep(1) + async with self.eval_sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self._comp_handle(**kwargs) + stat_dict["end"] = 
time.time()
+            return completions
+
+    async def completion(self, **kwargs) -> Completion:
+        if not self.initialized:
+            if (
+                self.config.base_url is not None
+            ):  # skip health check if using OpenAI API
+                self.check_task = asyncio.create_task(self.check_server_status_task())
+            else:
+                self.server_healthy = True
+            self.initialized = True
+        kwargs["model"] = self.config.model_name
+        split = kwargs.pop("split", "train")
+        stat_dict = {}
+        stat_dict["attempts"] = 0
+        if split == "train":
+            ret_data = await self._comp(stat_dict, **kwargs)
+            self.request_timings.append(stat_dict["end"] - stat_dict["start"])
+            self.attempts_list.append(stat_dict["attempts"])
+        else:
+            # Give separate eval workers, if desired, gotta go fast for those evals
+            ret_data = await self._comp_eval(stat_dict, **kwargs)
+            self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
+            self.eval_attempts_list.append(stat_dict["attempts"])
+        return ret_data
diff --git a/environments/gsm8k_server axolotl.py b/environments/gsm8k_server axolotl.py
new file mode 100644
index 00000000..244861f8
--- /dev/null
+++ b/environments/gsm8k_server axolotl.py
@@ -0,0 +1,300 @@
+import random
+from typing import Dict, List, Optional, Tuple, TypedDict, Union
+
+from datasets import load_dataset
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import LatexExtractionConfig, parse, verify
+from tqdm.asyncio import tqdm_asyncio
+
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    ScoredDataGroup,
+)
+from atroposlib.type_definitions import Item, number
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+
+system_prompt = (
+    "You are a deep thinking AI, you may use extremely long chains of thought "
+    "to deeply consider the problem and deliberate with yourself via systematic "
+    "reasoning processes to help come to a correct solution prior to answering. "
+    "You should enclose your thoughts and internal monologue inside <think> "
+    "</think> tags, and then provide your solution or response to the problem.\n\n"
+)
+
+system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less.
+
+You will then provide your answer like this: \\boxed{your answer here}
+It is important that you provide your answer in the correct format.
+If you do not, you will not receive credit for your answer.
+So please end your answer with \\boxed{your answer here}"""


+class GSM8kRow(TypedDict):
+    question: str
+    answer: str
+
+
+class GSM8kEnv(BaseEnv):
+
+    name = "gsm8k"
+
+    def __init__(
+        self,
+        config: BaseEnvConfig,
+        server_configs: List[APIServerConfig],
+        slurm=True,
+        testing=False,
+    ):
+        super().__init__(config, server_configs, slurm, testing)
+        self.percent_correct_buffer = list()
+        self.eval_metrics = list()
+        # Add tracking for wandb visualizations
+        self.rollouts_for_wandb = []
+        self.completion_lengths = []
+
+    @classmethod
+    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
+        env_config = BaseEnvConfig(
+            tokenizer_name="Qwen/Qwen3-1.7B",
+            group_size=8,
+            use_wandb=True,
+            rollout_server_url="http://localhost:8000",
+            total_steps=1000,
+            batch_size=12,
+            steps_per_eval=100,
+            max_token_length=2048,
+            wandb_name="gsm8k",
+        )
+        server_configs = [
+            APIServerConfig(
+                base_url="http://localhost:9001",
+                api_key="x",
+                num_requests_for_eval=256,
+                model_name="Qwen/Qwen3-1.7B",
+                server_type="trl",
+            ),
+        ]
+        return env_config, server_configs
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Try to calculate percent_correct, pass if there's a division by zero
+        try:
+            wandb_metrics["train/percent_correct"] = sum(
+                self.percent_correct_buffer
+            ) / len(self.percent_correct_buffer)
+        except ZeroDivisionError:
+            # Skip if buffer is empty
+            pass
+
+        self.percent_correct_buffer = list()
+        for item in self.eval_metrics:
+            wandb_metrics[item[0]] = item[1]
+        self.eval_metrics = list()
+        # Call the parent method to handle the server metrics
+        await super().wandb_log(wandb_metrics)
+
+    async def setup(self):
+        self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
+        test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
+        self.test = list()
+        for item in test_data:
+            self.test.append(
+                {
+                    "question": item["question"],
+                    "gold_answer": item["answer"]
+                    .split("#")[-1]
+                    .strip()
+                    .replace(",", ""),
+                }
+            )
+        self.iter = 0
+
+    def save_checkpoint(self, step, data=None):
+        if data is None:
+            data = {}
+        data["iter"] = self.iter
+        super().save_checkpoint(step, data)
+
+    async def rollout_and_score_eval(self, question: str, answer: str) -> number:
+        completion = await self.server.chat_completion(
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": question},
+            ],
+            n=1,
+            max_tokens=self.config.max_token_length,
+            temperature=0.0,
+            split="eval",
+        )
+        gold_parsed = parse(
+            "\\boxed{" + answer + "}",
+            extraction_mode="first_match",
+            extraction_config=[LatexExtractionConfig()],
+        )
+        answer_parsed = parse(
+            completion.choices[0].message.content.split("</think>")[-1],
+            extraction_config=[
+                LatexExtractionConfig(
+                    normalization_config=NormalizationConfig(
+                        nits=False,
+                        malformed_operators=False,
+                        basic_latex=True,
+                        equations=True,
+                        boxed="all",
+                        units=True,
+                    ),
+                    # Ensures that boxed is tried first
+                    boxed_match_priority=0,
+                    try_extract_without_anchor=False,
+                )
+            ],
+            extraction_mode="first_match",
+        )
+        score = 1 if verify(answer_parsed, gold_parsed) else 0
+        return score
+
+    async def evaluate(self, *args, **kwargs):
+        eval_tasks = []
+        for item in self.test:
+            eval_tasks.append(
+                self.rollout_and_score_eval(item["question"], item["gold_answer"])
+            )
+        scores = await tqdm_asyncio.gather(*eval_tasks)
+        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
+
+    async def collect_trajectories(
+        self, item: GSM8kRow
+    ) -> Tuple[ScoredDataGroup, list[Item]]:
+        user_message = {"role": "user", "content": item["question"]}
+        gold_answer = (
+            "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
+        )
+
+        chat_completions = await self.server.chat_completion(
+            messages=[{"role": "system", "content": system_prompt}, user_message],
+            n=self.config.group_size,
+            max_tokens=self.config.max_token_length,
+        )
+        to_score = list()
+        to_backlog = list()
+        for i, chat_completion in enumerate(chat_completions.choices):
+            messages = (
+                {"role": "system", "content": system_prompt},
+                user_message,
+                {"role": "assistant", "content": chat_completion.message.content},
+            )
+            to_score.append(
+                {
+                    "messages": messages,
+                    "gold_answer": gold_answer,
+                    "finish_reason": chat_completion.finish_reason,
+                }
+            )
+        to_postprocess = await self.score(to_score)
+        return to_postprocess, to_backlog
+
+    async def score(
+        self, rollout_group_data
+    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
+        scores = ScoredDataGroup()
+        scores["tokens"] = list()
+        scores["masks"] = list()
+        scores["scores"] = list()
+        gold_parsed = parse(
+            rollout_group_data[0]["gold_answer"],
+            extraction_mode="first_match",
+            extraction_config=[LatexExtractionConfig()],
+        )
+        if len(gold_parsed) != 0:
+            # We require the answer to be provided in correct latex (no malformed operators)
+            random.shuffle(rollout_group_data)
+            for item in rollout_group_data:
+                # print(item[0][-1]["content"])
+                answer_parsed = parse(
+                    item["messages"][-1]["content"].split("</think>")[-1],
+                    extraction_config=[
+                        LatexExtractionConfig(
+                            normalization_config=NormalizationConfig(
+                                nits=False,
+                                malformed_operators=False,
+                                basic_latex=True,
+                                equations=True,
+                                boxed="all",
+                                units=True,
+                            ),
+                            # Ensures that boxed is tried first
+                            boxed_match_priority=0,
+                            try_extract_without_anchor=False,
+                        )
+                    ],
+                    extraction_mode="first_match",
+                )
+                # Reward 1 if the content is the same as the ground truth, 0 otherwise
+                reward = verify(answer_parsed, gold_parsed)
+                # print(
+                #     f"message: {item[0][-1]['content']}, ground_truth: {item[1]}, reward: {reward}"
+                # )
+                out_dict = tokenize_for_trainer(
+                    self.tokenizer, item["messages"], item["finish_reason"]
+                )
+                tokens = out_dict["tokens"]
+                masks = out_dict["masks"]
+                # remove obviously bad examples
+                if len([1 for i in masks if i != -100]) < 10:
+                    continue
+                scores["tokens"].append(tokens)
+                scores["masks"].append(masks)
+                scores["scores"].append(1.0 if reward else -1.0)
+                if len(scores["tokens"]) >= self.config.group_size:
+                    break
+            for score in scores["scores"]:
+                self.percent_correct_buffer.append(max(score, 0))
+            # check if all the same
+            # print(scores['scores'])
+            if all([score == 1 for score in scores["scores"]]):
+                # Do length penalty :)
+                token_lengths = [len(token) for token in scores["tokens"]]
+                if max(token_lengths) == 0:
+                    # What? But don't want to crash a run so just in case...
+ return None + + # Get max allowed token length from config + max_allowed_length = self.config.max_token_length + # Set threshold at 50% of max_token_length - no penalty below this + length_threshold = max_allowed_length * 0.5 + + # Apply modified length penalty with threshold + scores["scores"] = [] + for length in token_lengths: + if length <= length_threshold: + # No penalty for responses under threshold + scores["scores"].append(1.0) + else: + # Calculate how far we are between threshold and max as a percentage + percentage_of_range = (length - length_threshold) / ( + max_allowed_length - length_threshold + ) + # Cap at 1.0 in case length exceeds max_allowed_length + percentage_of_range = min(percentage_of_range, 1.0) + # Apply linear penalty scaling from 1.0 down to 0.0 + scores["scores"].append(1.0 - percentage_of_range) + if all([scores["scores"][0] == score for score in scores["scores"]]): + return None # If all the same, we return None + return scores + else: + # If the gold solution is not parseable, we return None + return None + + async def get_next_item(self) -> GSM8kRow: + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + return next_item + + +if __name__ == "__main__": + GSM8kEnv.cli()
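
Reviewer note: TrlVllmServer assumes trl's vLLM server exposes a POST {base_url}/generate/ endpoint that accepts a JSON body with a "prompts" list plus sampling parameters, and replies with generated token ids under a "completion_ids" key; those ids are decoded with the configured tokenizer and wrapped into OpenAI-style ChatCompletion objects. The sketch below is a minimal, standalone illustration of that assumed exchange; the endpoint path, field names, URL, and model name are copied from _chat_handle and the example environment's config_init above, not from trl documentation.

import asyncio

import aiohttp
from transformers import AutoTokenizer


async def generate_once(base_url: str, model_name: str, prompt: str) -> str:
    # Mirrors TrlVllmServer._chat_handle: send one prompt, decode the returned token ids.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    payload = {
        "prompts": [prompt],
        "n": 1,
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
        "min_p": 0.0,
        "max_tokens": 1024,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/generate/", json=payload) as response:
            body = await response.json()
    # Expected reply shape (assumption): {"completion_ids": [[int, ...], ...]}
    return tokenizer.decode(body["completion_ids"][0])


if __name__ == "__main__":
    # URL and model mirror the example environment's config_init(); adjust to your deployment.
    print(asyncio.run(generate_once("http://localhost:9001", "Qwen/Qwen3-1.7B", "Hello!")))

Selecting the backend from an environment is then just a matter of setting server_type="trl" on APIServerConfig, as the example environment does; ServerManager routes requests through TrlVllmServer instead of OpenAIServer, and a trl vLLM server serving the same model is expected to already be listening at base_url.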