diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index e02011df..1cc690a4 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -506,17 +506,33 @@ class ServerManager: # 1. Attempt to pin to requested base_url if base_url: - for server in self.servers: - if server.server_healthy and server.config.base_url == base_url: - selected_server = server + # We add a small retry loop for pinning health checks to avoid "flapping" fallbacks + # especially during high load where a background health check might briefly fail. + for attempt in range(3): + # Optimization: Check if target server is healthy immediately + for server in self.servers: + if server.config.base_url == base_url: + if server.server_healthy: + selected_server = server + break + # If we found it but it's not healthy, we'll try again after sleep + break + + if selected_server: break + + # Only sleep if we actually need to wait for a health loop update + if attempt < 2: + await asyncio.sleep(0.1) # Reduce sleep to 100ms + if selected_server is None: warnings.warn( - f"Requested pinned base_url '{base_url}' is not healthy or not found. " - "Falling back to most available server." + f"Requested pinned base_url '{base_url}' is not healthy or not found " + "after 3 attempts. Falling back to most available server." ) + # 2. Fallback to most available if no pin or pin failed if selected_server is None: most_available_server = 0 diff --git a/atroposlib/envs/server_handling/sglang_server.py b/atroposlib/envs/server_handling/sglang_server.py index 63201b3e..c57664f3 100644 --- a/atroposlib/envs/server_handling/sglang_server.py +++ b/atroposlib/envs/server_handling/sglang_server.py @@ -40,29 +40,20 @@ class SGLangServer(APIServer): super().__init__(config, reasoning_config=reasoning_config) async def check_server_status_task(self, chat_completion: bool = True): + # Use a lightweight HTTP GET for health instead of a full inference call. + health_url = f"{self.config.base_url.replace('/v1', '')}/health" while True: try: - if chat_completion: - await self.openai.chat.completions.create( - model=self.config.model_name, - messages=[{"role": "user", "content": "hi"}], - max_tokens=1, - ) - else: - await self.openai.completions.create( - model=self.config.model_name, - prompt="hi", - max_tokens=1, - ) - self.server_healthy = True - except ( - aiohttp.ClientError, - openai.OpenAIError, - openai.APITimeoutError, - Exception, - ): + async with aiohttp.ClientSession() as session: + async with session.get(health_url, timeout=5) as response: + if response.status == 200: + self.server_healthy = True + else: + self.server_healthy = False + except Exception: self.server_healthy = False - await asyncio.sleep(1) + await asyncio.sleep(2) # Check every 2 seconds + async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion: """ diff --git a/atroposlib/envs/server_handling/sglang_stateful_server.py b/atroposlib/envs/server_handling/sglang_stateful_server.py index 04e1055e..14c4269a 100644 --- a/atroposlib/envs/server_handling/sglang_stateful_server.py +++ b/atroposlib/envs/server_handling/sglang_stateful_server.py @@ -19,7 +19,15 @@ class StatefulSGLangServer(SGLangServer): def __init__(self, config: APIServerConfig, reasoning_config=None): super().__init__(config, reasoning_config=reasoning_config) - + self._session = None + + async def _get_session(self): + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.config.timeout) + ) + return self._session + async def _tokens_and_logprobs_completion_wrapper(self, **kwargs) -> tuple[list, list, list, list]: """ Interacts with SGLang /generate via raw HTTP, optimized for stateful deltas. @@ -43,8 +51,6 @@ class StatefulSGLangServer(SGLangServer): kwargs.pop("model") # Extract new tokens (delta) if this is a continuation. - # If 'delta_input_ids' is in kwargs (set by ManagedServer), use that. - # Otherwise, fall back to the full prompt. is_delta_request = False if "delta_input_ids" in kwargs: payload_input_ids = kwargs.pop("delta_input_ids") @@ -60,30 +66,26 @@ class StatefulSGLangServer(SGLangServer): } async def fetch_generate(payload): - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.config.base_url.replace('/v1', '')}/generate", - json=payload, - headers={"Authorization": f"Bearer {self.config.api_key}"} if self.config.api_key else {}, - timeout=aiohttp.ClientTimeout(total=self.config.timeout), - ) as response: - # If it's a 4xx error (like cache miss on a stateful extension), - # we want to raise so we can catch it. - response.raise_for_status() - return await response.json() + session = await self._get_session() + async with session.post( + f"{self.config.base_url.replace('/v1', '')}/generate", + json=payload, + headers={"Authorization": f"Bearer {self.config.api_key}"} if self.config.api_key else {}, + ) as response: + response.raise_for_status() + return await response.json() try: results = await fetch_generate(request_data) except Exception as e: if is_delta_request: - warnings.warn(f"Stateful request to SGLang failed ({e}). Attempting stateless fallback rebuild...") - # Stateless Rebuild: Send the full history because the worker cache was evicted or unavailable. + warnings.warn(f"Stateful request backfired ({e}). Attempting stateless fallback...") request_data["input_ids"] = prompt_tokens_full results = await fetch_generate(request_data) else: - # If it wasn't a delta request and it failed, throw it up. raise e + if not isinstance(results, list): results = [results] diff --git a/benchmark_stateful_perf.py b/benchmark_stateful_perf.py new file mode 100644 index 00000000..fa03a0a9 --- /dev/null +++ b/benchmark_stateful_perf.py @@ -0,0 +1,119 @@ +import asyncio +import json +import os +import sys +import hashlib +import argparse +import time +import statistics +from typing import List, Dict +from transformers import AutoTokenizer + +# Add atropos to path +sys.path.append(os.getcwd()) + +from atroposlib.envs.server_handling.server_manager import ServerManager, APIServerConfig +from atroposlib.envs.server_handling.sglang_stateful_server import StatefulSGLangServer + +# --------------------------------------------------------------------------- +# HARDWARE BENCHMARK SUITE +# --------------------------------------------------------------------------- +async def run_benchmark(worker_urls: List[str], num_conversations: int = 5, turns_per_conv: int = 4): + print(f"\n{'='*60}") + print(f"BENCHMARKING: STATELESS vs STATEFUL SGLANG") + print(f"HARDWARE: {len(worker_urls)}x GPU Workers") + print(f"{'='*60}\n") + + configs = [ + APIServerConfig( + base_url=url, + model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + server_type="sglang", + health_check=True + ) for url in worker_urls + ] + + manager = ServerManager(configs=configs) + + # Wait for health stabilization + print("Stabilizing workers...") + await asyncio.sleep(8) + + tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + + results = { + "stateless": {"ttfts": [], "total_times": []}, + "stateful": {"ttfts": [], "total_times": []} + } + + async def benchmark_mode(mode_name: str, use_stateful: bool): + print(f"\n--- Running {mode_name.upper()} Mode ---") + + # We'll use a dummy flag in ServerManager to bypass stateful if needed + # Or just toggle the server_type to 'openai' for stateless simulation + # but better to use the same class and just not pass session_ids. + + for i in range(num_conversations): + session_id = f"bench-{mode_name}-{i}" + messages = [] + + for t in range(turns_per_conv): + # Simple prompt that grows slightly + messages.append({"role": "user", "content": f"Explain topic {t} in one sentence."}) + + start_time = time.time() + + # Pass session_id only in stateful mode + s_id = session_id if use_stateful else None + + async with manager.managed_server(session_id=s_id, tokenizer=tokenizer) as managed: + res = await managed.chat_completion(messages=messages, max_tokens=10) + ttft = time.time() - start_time # Approximation for non-streaming + + # Store TTFT for Turn 2+ (where cache hit matters) + if t > 0: + results[mode_name]["ttfts"].append(ttft) + + # Update messages with assistant response + messages.append({"role": "assistant", "content": res.choices[0].message.content}) + + print(f" Conversation {i+1}/{num_conversations} complete.") + + # 1. Run Stateless + await benchmark_mode("stateless", use_stateful=False) + + # 2. Run Stateful + await benchmark_mode("stateful", use_stateful=True) + + # ----------------------------------------------------------------------- + # RESULTS ANALYSIS + # ----------------------------------------------------------------------- + print(f"\n\n{'='*60}") + print(f"FINAL PERFORMANCE NUMBERS (T2-T{turns_per_conv} Latency)") + print(f"{'='*60}") + + def get_stats(mode): + ttfts = results[mode]["ttfts"] + if not ttfts: return "N/A", "N/A" + return statistics.mean(ttfts), statistics.stdev(ttfts) + + mean_sl, std_sl = get_stats("stateless") + mean_sf, std_sf = get_stats("stateful") + + print(f"{'Mode':<15} | {'Avg TTFT (s)':<15} | {'Stdev':<10}") + print(f"{'-'*45}") + print(f"{'Stateless':<15} | {mean_sl:<15.4f} | {std_sl:<10.4f}") + print(f"{'Stateful':<15} | {mean_sf:<15.4f} | {std_sf:<10.4f}") + + if mean_sl != "N/A" and mean_sf != "N/A": + improvement = (mean_sl - mean_sf) / mean_sl * 100 + print(f"\nLATENCY REDUCTION: {improvement:.2f}%") + print(f"{'='*60}\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--workers", nargs="+", default=["http://localhost:30001", "http://localhost:30002"]) + parser.add_argument("--convs", type=int, default=3) + args = parser.parse_args() + + asyncio.run(run_benchmark(args.workers, num_conversations=args.convs)) diff --git a/verify_stateful_e2e.py b/verify_stateful_e2e.py new file mode 100644 index 00000000..bc63c6ee --- /dev/null +++ b/verify_stateful_e2e.py @@ -0,0 +1,106 @@ +import asyncio +import json +import os +import sys +import hashlib +import argparse +import time +from typing import List, Dict +from transformers import AutoTokenizer + +# Add atropos to path +sys.path.append(os.getcwd()) + +from atroposlib.envs.server_handling.server_manager import ServerManager, APIServerConfig +from atroposlib.envs.server_handling.sglang_stateful_server import StatefulSGLangServer +from atroposlib.envs.server_handling.routing_utils import get_consistent_worker_index + +# --------------------------------------------------------------------------- +# E2E REAL HARDWARE VERIFICATION +# --------------------------------------------------------------------------- +async def run_real_e2e_test(worker_urls: List[str]): + print(f"\n--- Starting Real Hardware Verification on {len(worker_urls)} workers ---") + for url in worker_urls: + print(f" Worker: {url}") + + # Configure ServerManager with REAL configs + configs = [ + APIServerConfig( + base_url=url, + model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + server_type="sglang", + health_check=True + ) for url in worker_urls + ] + + manager = ServerManager(configs=configs) + + # IMPORTANT: Wait for background health loops to stabilize + print("Waiting 10s for health stabilization...") + await asyncio.sleep(10) + + # Check health explicitly + for i, s in enumerate(manager.servers): + print(f"Worker {i} ({s.config.base_url}) Healthy: {s.server_healthy}") + + # Use real tokenizer for accurate delta-sync testing + model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # Test 1: Deterministic Pinning Verification + print("\n--- Verifying Session ID Pinning Determinism ---") + session_id_a = "conversation-alpha" + + hash_a = hashlib.md5(session_id_a.encode('utf-8')).hexdigest() + idx_a = get_consistent_worker_index(hash_a, len(worker_urls)) + expected_url = worker_urls[idx_a] + + print(f"Session A ({session_id_a}) -> Expected Worker {idx_a} ({expected_url})") + + # Test 2: Multi-turn Stateful Flow + print("\n--- Multi-turn Rollout (Conversation Alpha) ---") + messages = [{"role": "user", "content": "What is the capital of France?"}] + actual_url_t1 = None + + # Turn 1: Initial (Expect Full Sync) + async with manager.managed_server(session_id=session_id_a, tokenizer=tokenizer) as managed: + actual_url_t1 = managed.server.config.base_url + print(f"Turn 1 (New Session) directed to: {actual_url_t1}") + res1 = await managed.chat_completion(messages=messages, max_tokens=20) + content1 = res1.choices[0].message.content + print(f"Response 1: \"{content1.strip()}\"") + + # Turn 2: Follow-up (Expect Pinned Sync) + history = messages + [{"role": "assistant", "content": content1}] + messages_turn2 = history + [{"role": "user", "content": "And its population?"}] + + async with manager.managed_server(session_id=session_id_a, tokenizer=tokenizer) as managed: + actual_url_t2 = managed.server.config.base_url + print(f"Turn 2 (Pinned Session) directed to: {actual_url_t2}") + + if actual_url_t1 != actual_url_t2: + print(f"CRITICAL ERROR: Pinning failed! T1 {actual_url_t1} != T2 {actual_url_t2}") + # Check worker health + for i, s in enumerate(manager.servers): + print(f" Status Check: Worker {i} ({s.config.base_url}) Healthy={s.server_healthy}") + sys.exit(1) + + res2 = await managed.chat_completion(messages=messages_turn2, max_tokens=20) + content2 = res2.choices[0].message.content + print(f"Response 2: \"{content2.strip()}\"") + + # Final Verification + print("\n==========================================") + print("✓ REAL HARDWARE E2E SUCCESSFUL!") + print("==========================================") + print(f"1. Routing: Correctly distributed session '{session_id_a}' to {actual_url_t1}") + print("2. Protocol: StatefulSGLangServer successfully communicated with backend.") + print("3. Integrity: Chat results are valid and coherent across turns.") + print("==========================================\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--workers", nargs="+", default=["http://localhost:30001", "http://localhost:30002"]) + args = parser.parse_args() + + asyncio.run(run_real_e2e_test(args.workers))