diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py
index 0918c325..cb14d210 100644
--- a/atroposlib/envs/server_handling/managed_server.py
+++ b/atroposlib/envs/server_handling/managed_server.py
@@ -7,6 +7,7 @@ This wrapper maintains a tree structure of sequences, where:
 - Branching occurs organically from different contexts and n > 1 completions
 """
 
+import os
 import time
 import uuid
 import warnings
@@ -131,6 +132,10 @@ class ManagedServer:
             # Fallback for tokenizers without chat template
             return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
 
+    def _debug_requests_enabled(self) -> bool:
+        """Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
+        return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
+
     def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
         """
         Find a node that this input extends (default mode).
@@ -284,6 +289,19 @@ class ManagedServer:
         completion_kwargs = kwargs.copy()
         completion_kwargs["prompt"] = prompt
         completion_kwargs.pop("messages", None)
+        if self._debug_requests_enabled():
+            msg_count = len(messages)
+            prompt_preview = prompt.replace("\n", "\\n")[:600]
+            print(
+                f"[ATROPOS_REQ_DEBUG] chat_completion messages={msg_count} "
+                f"n={completion_kwargs.get('n')} max_tokens={completion_kwargs.get('max_tokens')} "
+                f"temperature={completion_kwargs.get('temperature')}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
+                flush=True,
+            )
 
         # Set model name if not provided
         if "model" not in completion_kwargs:
diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py
index 96242754..503d4419 100644
--- a/atroposlib/envs/server_handling/vllm_server.py
+++ b/atroposlib/envs/server_handling/vllm_server.py
@@ -2,6 +2,7 @@
 # see example_trainer/vllm_api_server.py for an example
 
 import asyncio
+import os
 import warnings
 
 import aiohttp
@@ -189,6 +190,39 @@ class VLLMServer(APIServer):
         # Prepare request for VLLM native API
         request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
         request_data.update(kwargs)
+        debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
+        if debug_requests:
+            base = self.config.base_url.replace("/v1", "")
+            prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n")
+            print(
+                f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
+                f"prompt_token_len={len(prompt_tokens)}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] request_meta="
+                f"{{'n': {request_data.get('n')}, 'max_tokens': {request_data.get('max_tokens')}, "
+                f"'temperature': {request_data.get('temperature')}, 'top_p': {request_data.get('top_p')}}}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
+                flush=True,
+            )
+            print(
+                "[ATROPOS_REQ_DEBUG] curl_template="
+                + (
+                    "curl -s -X POST "
+                    + f"{base}/generate "
+                    + "-H \"Content-Type: application/json\" "
+                    + "-d '{\"prompt\": \"\", "
+                    + f"\"n\": {request_data.get('n', 1)}, "
+                    + f"\"max_tokens\": {request_data.get('max_tokens', 256)}, "
+                    + f"\"temperature\": {request_data.get('temperature', 1.0)}, "
+                    + f"\"top_p\": {request_data.get('top_p', 1.0)}}}'"
+                ),
+                flush=True,
+            )
 
         # Make async request to VLLM /generate endpoint
         async with aiohttp.ClientSession() as session:
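
Usage note (not part of the patch): a minimal sketch of how the new flag could be exercised. Only the ATROPOS_DEBUG_REQUESTS variable and the [ATROPOS_REQ_DEBUG] prefix come from the diff above; how the server objects are constructed is assumed and left out.

```python
# Illustrative sketch, not part of the diff: enable the verbose request logs
# added in this patch before any completions are issued in-process.
import os

os.environ["ATROPOS_DEBUG_REQUESTS"] = "1"  # checked per request via os.getenv

# ... construct ManagedServer / VLLMServer and run the environment as usual;
# chat_completion and the vLLM /generate path will then print request metadata,
# a prompt preview, and a curl template to stdout (flush=True, so logs appear
# immediately even under buffered output).
```

Equivalently, `export ATROPOS_DEBUG_REQUESTS=1` before launching the process has the same effect, since both code paths read the variable at request time rather than at import time.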