mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
gsm8k trial
This commit is contained in:
parent
0dcc9156d2
commit
bbbfaf1680
2 changed files with 52 additions and 0 deletions
|
|
@ -7,6 +7,7 @@ This wrapper maintains a tree structure of sequences, where:
|
|||
- Branching occurs organically from different contexts and n > 1 completions
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
|
|
@ -131,6 +132,10 @@ class ManagedServer:
|
|||
# Fallback for tokenizers without chat template
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
|
||||
def _debug_requests_enabled(self) -> bool:
|
||||
"""Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
|
||||
return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
|
||||
|
||||
def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
|
||||
"""
|
||||
Find a node that this input extends (default mode).
|
||||
|
|
@ -284,6 +289,19 @@ class ManagedServer:
|
|||
completion_kwargs = kwargs.copy()
|
||||
completion_kwargs["prompt"] = prompt
|
||||
completion_kwargs.pop("messages", None)
|
||||
if self._debug_requests_enabled():
|
||||
msg_count = len(messages)
|
||||
prompt_preview = prompt.replace("\n", "\\n")[:600]
|
||||
print(
|
||||
f"[ATROPOS_REQ_DEBUG] chat_completion messages={msg_count} "
|
||||
f"n={completion_kwargs.get('n')} max_tokens={completion_kwargs.get('max_tokens')} "
|
||||
f"temperature={completion_kwargs.get('temperature')}",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Set model name if not provided
|
||||
if "model" not in completion_kwargs:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# see example_trainer/vllm_api_server.py for an example
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import aiohttp
|
||||
|
|
@ -189,6 +190,39 @@ class VLLMServer(APIServer):
|
|||
# Prepare request for VLLM native API
|
||||
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
|
||||
request_data.update(kwargs)
|
||||
debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
|
||||
if debug_requests:
|
||||
base = self.config.base_url.replace("/v1", "")
|
||||
prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n")
|
||||
print(
|
||||
f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
|
||||
f"prompt_token_len={len(prompt_tokens)}",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f"[ATROPOS_REQ_DEBUG] request_meta="
|
||||
f"{{'n': {request_data.get('n')}, 'max_tokens': {request_data.get('max_tokens')}, "
|
||||
f"'temperature': {request_data.get('temperature')}, 'top_p': {request_data.get('top_p')}}}",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
"[ATROPOS_REQ_DEBUG] curl_template="
|
||||
+ (
|
||||
"curl -s -X POST "
|
||||
+ f"{base}/generate "
|
||||
+ "-H \"Content-Type: application/json\" "
|
||||
+ "-d '{\"prompt\": \"<PROMPT_FROM_PREVIEW_OR_LOG>\", "
|
||||
+ f"\"n\": {request_data.get('n', 1)}, "
|
||||
+ f"\"max_tokens\": {request_data.get('max_tokens', 256)}, "
|
||||
+ f"\"temperature\": {request_data.get('temperature', 1.0)}, "
|
||||
+ f"\"top_p\": {request_data.get('top_p', 1.0)}}'"
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Make async request to VLLM /generate endpoint
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue