mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
remove unused managed_server wrapper + tests
This commit is contained in:
parent
6a27e88023
commit
57fa229846
2 changed files with 0 additions and 565 deletions
|
|
@ -1,330 +0,0 @@
|
|||
"""
|
||||
AtroposManagedClient: AsyncOpenAI-compatible client backed by ManagedServer.
|
||||
|
||||
This module provides a drop-in replacement for AsyncOpenAI that uses Atropos's
|
||||
ManagedServer for inference, enabling token tracking for multi-turn RL training
|
||||
with the Verifiers library.
|
||||
|
||||
Usage:
|
||||
async with server_manager.managed_server(tokenizer=tokenizer) as managed:
|
||||
client = AtroposManagedClient(managed_server=managed, model="model-name")
|
||||
|
||||
# Use like AsyncOpenAI - tokens are tracked automatically
|
||||
response = await client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
# Token data is available on the response:
|
||||
# - response.prompt_token_ids
|
||||
# - response.choices[0].token_ids
|
||||
# - response.choices[0].logprobs.content[i].logprob
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer, SequenceNode
|
||||
|
||||
# =============================================================================
|
||||
# Enhanced Types for Token Data Injection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
class LogprobContent:
    """
    Per-token logprob record.

    Shaped to match what verifiers' parse_response_tokens() reads:
    ``response.choices[i].logprobs.content[j].logprob``.
    """

    logprob: float
    token: str = ""
    token_id: int = 0
    top_logprobs: Optional[List[Any]] = None
|
||||
|
||||
|
||||
@dataclass
class ChoiceLogprobs:
    """
    Logprobs container matching what verifiers expects.

    Verifiers accepts either attribute access
    (``choices[i].logprobs.content[j].logprob``) or dict access
    (``choices[i].logprobs["content"][j]["logprob"]``); this dataclass
    implements the attribute form.
    """

    content: List[LogprobContent] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class EnhancedChoice:
    """
    Chat-completion choice augmented with token data for RL training.

    In addition to the standard OpenAI choice fields, this carries the
    attributes verifiers reads:
    - token_ids: List[int] - completion token IDs
    - logprobs: ChoiceLogprobs - structured per-token logprobs
    """

    index: int  # position of this choice within the response
    message: ChatCompletionMessage  # assistant message produced by the model
    finish_reason: str  # e.g. "stop" or "length"
    token_ids: List[int]  # token IDs of the completion portion only
    logprobs: ChoiceLogprobs  # per-token logprob entries
|
||||
|
||||
|
||||
@dataclass
class EnhancedChatCompletion:
    """
    ChatCompletion carrying token data for RL training.

    Satisfies verifiers' parse_response_tokens() contract:
    - prompt_token_ids: list[int]
    - choices[i].token_ids: list[int]
    - choices[i].logprobs.content[j].logprob
    """

    id: str  # completion id from the underlying server
    created: int  # unix timestamp of creation
    model: str  # model that produced the completion
    object: str  # OpenAI object tag (e.g. "chat.completion")
    choices: List[EnhancedChoice]  # one enhanced choice per sample
    prompt_token_ids: List[int]  # token IDs of the shared prompt
    usage: Optional[Dict[str, int]] = None  # token-usage counts, if reported
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AsyncOpenAI-Compatible Client Classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class _CompletionsNamespace:
    """
    Mimics openai.resources.chat.completions.AsyncCompletions.

    Provides the async create() method that verifiers calls via
    ``client.chat.completions.create(...)``.
    """

    def __init__(self, parent: "AtroposManagedClient"):
        # Back-reference to the owning client, used to reach the
        # ManagedServer, the default model name, and the tokenizer.
        self.parent = parent

    async def create(
        self,
        *,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        n: int = 1,
        max_tokens: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        tools: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> EnhancedChatCompletion:
        """
        Create chat completion with token tracking.

        Returns ChatCompletion with additional attributes:
        - prompt_token_ids: list[int]
        - choices[i].token_ids: list[int]
        - choices[i].logprobs.content: list with logprob info

        Args:
            messages: List of message dicts with 'role' and 'content'
            model: Model name (defaults to client's model)
            n: Number of completions (should be 1 for multi-turn)
            max_tokens: Max tokens in completion (legacy param)
            max_completion_tokens: Max tokens in completion (new param)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            tools: Tool definitions for function calling
            stop: Stop sequences
            **kwargs: Additional parameters passed to ManagedServer
        """
        # Prefer max_completion_tokens (the newer OpenAI parameter) when it
        # was explicitly provided. An `is not None` test is used instead of
        # `or` so a caller-supplied falsy value (e.g. 0) is not silently
        # replaced by max_tokens.
        if max_completion_tokens is not None:
            effective_max_tokens = max_completion_tokens
        else:
            effective_max_tokens = max_tokens

        # Build kwargs for ManagedServer
        completion_kwargs = {
            "messages": messages,
            "model": model or self.parent.model,
            "n": n,
            "temperature": temperature,
            "top_p": top_p,
        }

        if effective_max_tokens is not None:
            completion_kwargs["max_tokens"] = effective_max_tokens

        if tools is not None:
            completion_kwargs["tools"] = tools

        if stop is not None:
            completion_kwargs["stop"] = stop

        # Forward any extra kwargs (like logprobs settings); None values are
        # dropped so they do not override ManagedServer defaults.
        for key, value in kwargs.items():
            if value is not None:
                completion_kwargs[key] = value

        # Run inference through ManagedServer so token data is tracked.
        completion = await self.parent.managed_server.chat_completion(
            **completion_kwargs
        )

        # Pull the recorded token state; nodes hold per-choice sequences.
        state = self.parent.managed_server.get_state()
        nodes: List[SequenceNode] = state["nodes"]

        # Inject token data into response
        return self._enhance_completion(completion, nodes)

    def _enhance_completion(
        self, completion: Any, nodes: List[SequenceNode]
    ) -> EnhancedChatCompletion:
        """
        Convert ManagedServer output to verifiers-compatible format.

        Extracts token data from SequenceNodes and injects it into the
        ChatCompletion response in the format verifiers expects.

        NOTE(review): choices and nodes are zipped positionally — this
        assumes get_state() returns exactly one node per choice, in the
        same order; confirm against ManagedServer.
        """
        enhanced_choices = []
        prompt_token_ids: List[int] = []

        for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
            # Find prompt/completion boundary from masked_tokens:
            # -100 indicates prompt tokens, actual token IDs indicate
            # completion tokens.
            prompt_len = sum(1 for m in node.masked_tokens if m == -100)

            # All choices share the same prompt, so record it once.
            if i == 0:
                prompt_token_ids = node.tokens[:prompt_len]

            completion_ids = node.tokens[prompt_len:]
            completion_logprobs = node.logprobs[prompt_len:]

            # Build logprobs structure verifiers expects
            logprobs_content = []
            tokenizer = self.parent.managed_server.tokenizer

            for token_id, logprob in zip(completion_ids, completion_logprobs):
                # Decode token to string if tokenizer available; fall back
                # to a placeholder rather than failing the whole response.
                token_str = ""
                if tokenizer is not None:
                    try:
                        token_str = tokenizer.decode([token_id])
                    except Exception:
                        token_str = f"<token_{token_id}>"

                logprobs_content.append(
                    LogprobContent(
                        logprob=logprob,
                        token=token_str,
                        token_id=token_id,
                    )
                )

            # Wrap the original choice with token data attached.
            enhanced_choice = EnhancedChoice(
                index=choice.index,
                message=choice.message,
                finish_reason=choice.finish_reason or "stop",
                token_ids=completion_ids,
                logprobs=ChoiceLogprobs(content=logprobs_content),
            )
            enhanced_choices.append(enhanced_choice)

        return EnhancedChatCompletion(
            id=completion.id,
            created=completion.created,
            model=completion.model,
            object=completion.object,
            choices=enhanced_choices,
            prompt_token_ids=prompt_token_ids,
            usage=completion.usage.model_dump() if completion.usage else None,
        )
|
||||
|
||||
|
||||
class _ChatNamespace:
    """Mimics openai.resources.chat.AsyncChat: exposes ``.completions``."""

    def __init__(self, parent: "AtroposManagedClient"):
        # Pure namespace plumbing; all real work lives in the completions
        # namespace, which keeps a back-reference to the client.
        self.completions = _CompletionsNamespace(parent)
|
||||
|
||||
|
||||
class AtroposManagedClient:
    """
    AsyncOpenAI-compatible client backed by ManagedServer.

    Presents the AsyncOpenAI surface (``client.chat.completions.create``)
    while routing inference through Atropos's ManagedServer, so token data
    is tracked automatically for multi-turn RL training with the Verifiers
    library.

    Responses carry the extra attributes verifiers'
    parse_response_tokens() reads:
    - response.prompt_token_ids
    - response.choices[i].token_ids
    - response.choices[i].logprobs.content[j].logprob

    Usage:
        async with server_manager.managed_server(tokenizer=tokenizer) as managed:
            client = AtroposManagedClient(
                managed_server=managed,
                model="Qwen/Qwen2.5-1.5B-Instruct"
            )

            # Pass to verifiers env.rollout()
            state = await vf_env.rollout(
                input=rollout_input,
                client=client,
                model="Qwen/Qwen2.5-1.5B-Instruct",
            )

            # Token data is now in state["trajectory"][i]["tokens"]
    """

    def __init__(
        self,
        managed_server: ManagedServer,
        model: str,
        base_url: Optional[str] = None,
    ):
        """
        Initialize the managed client.

        Args:
            managed_server: ManagedServer instance for inference and token tracking
            model: Model name to use for completions
            base_url: Optional base URL (for API compatibility, not used)
        """
        self.model = model
        self.managed_server = managed_server
        # Placeholder URL keeps AsyncOpenAI-style attribute access working.
        self.base_url = base_url or "http://managed-server"
        # Mirror AsyncOpenAI's namespace layout: client.chat.completions.
        self.chat = _ChatNamespace(self)

    def reset(self):
        """Reset token tracking state between rollouts."""
        self.managed_server.reset()

    async def close(self):
        """Compatibility no-op; cleanup is owned by ManagedServer."""

    def copy(self, **_kwargs) -> "AtroposManagedClient":
        """
        Return this same instance (API-compatibility shim).

        Verifiers may call client.copy() for certain operations; sharing
        the instance keeps ManagedServer state intact across such calls.
        """
        return self
|
||||
|
|
@ -1,235 +0,0 @@
|
|||
"""Tests for AtroposManagedClient - AsyncOpenAI-compatible wrapper for ManagedServer."""
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.server_handling.atropos_managed_client import (
|
||||
AtroposManagedClient,
|
||||
ChoiceLogprobs,
|
||||
EnhancedChatCompletion,
|
||||
LogprobContent,
|
||||
)
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
from atroposlib.envs.server_handling.server_harness import ServerHarness
|
||||
|
||||
|
||||
class MockTokenizer:
    """Character-level stand-in tokenizer for tests."""

    def __init__(self):
        # Special-token ids chosen to stay out of the printable range.
        self.eos_token_id = 2
        self.bos_token_id = 1

    def encode(self, text, add_special_tokens=True):
        """Encode each character as its code point, optionally BOS-prefixed."""
        prefix = [self.bos_token_id] if add_special_tokens else []
        return prefix + [ord(ch) for ch in text]

    def decode(self, tokens, skip_special_tokens=False):
        """Map code points back to characters, blanking control codes (<=31)."""
        specials = {self.bos_token_id, self.eos_token_id}
        pieces = []
        for tok in tokens:
            if skip_special_tokens and tok in specials:
                continue
            pieces.append(chr(tok) if tok > 31 else "")
        return "".join(pieces)

    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
        """Render messages as <role>content</role> tags; optionally encode."""
        rendered = "".join(
            f"<{msg['role']}>{msg['content']}</{msg['role']}>" for msg in messages
        )
        if add_generation_prompt:
            rendered += "<assistant>"
        return self.encode(rendered) if tokenize else rendered
|
||||
|
||||
|
||||
@pytest.fixture
def mock_server():
    """Provide a ServerHarness wired with a mock tokenizer and minimal config."""
    harness = ServerHarness()
    harness.tokenizer = MockTokenizer()

    # Minimal config object exposing just the attribute the client reads.
    class Config:
        model_name = "test_model"

    harness.config = Config()
    return harness
|
||||
|
||||
|
||||
@pytest.fixture
def managed_client(mock_server):
    """Wrap the mocked server in a ManagedServer and expose it via the client."""
    wrapped = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
    return AtroposManagedClient(managed_server=wrapped, model="test_model")
|
||||
|
||||
|
||||
class TestDataclasses:
    """Test the enhanced dataclasses."""

    def test_logprob_content(self):
        """LogprobContent stores the fields it is constructed with."""
        entry = LogprobContent(logprob=-0.5, token="hello", token_id=100)
        assert (entry.logprob, entry.token, entry.token_id) == (-0.5, "hello", 100)

    def test_choice_logprobs(self):
        """ChoiceLogprobs holds an ordered list of LogprobContent entries."""
        logprobs = ChoiceLogprobs(
            content=[LogprobContent(logprob=-0.1), LogprobContent(logprob=-0.2)]
        )
        assert len(logprobs.content) == 2
        assert logprobs.content[0].logprob == -0.1
|
||||
|
||||
|
||||
class TestAtroposManagedClient:
    """Test AtroposManagedClient behavior."""

    def test_reset(self, managed_client):
        """reset() must clear the ManagedServer's tracked nodes."""
        # Seed the managed server with dummy state, then clear it.
        managed_client.managed_server.current_nodes = ["dummy"]
        managed_client.reset()
        assert len(managed_client.managed_server.current_nodes) == 0

    def test_copy_returns_self(self, managed_client):
        """copy() hands back the same instance so server state stays shared."""
        assert managed_client.copy() is managed_client

    def test_namespace_structure(self, managed_client):
        """Client mirrors AsyncOpenAI's chat.completions.create layout."""
        chat = getattr(managed_client, "chat", None)
        assert chat is not None
        completions = getattr(chat, "completions", None)
        assert completions is not None
        assert hasattr(completions, "create")

    @pytest.mark.asyncio
    async def test_close_is_noop(self, managed_client):
        """close() must be awaitable and raise nothing."""
        await managed_client.close()
|
||||
|
||||
|
||||
class TestChatCompletionCreate:
    """Test the chat.completions.create() method."""

    @pytest.mark.asyncio
    async def test_basic_completion(self, mock_server, managed_client):
        """Test basic chat completion returns enhanced response."""
        messages = [{"role": "user", "content": "Hello"}]
        managed = managed_client.managed_server
        # Render the prompt exactly as ManagedServer will so the harness can
        # match it.  NOTE(review): relies on the private
        # _convert_messages_to_prompt helper — verify it stays in sync.
        prompt = managed._convert_messages_to_prompt(messages)
        prompt_tokens = mock_server.tokenizer.encode(prompt)

        # Canned model output: one token per character, uniform logprobs.
        output_text = "Hi there!"
        output_tokens = [ord(c) for c in output_text]
        output_logprobs = [-0.1] * len(output_tokens)

        # Prime the harness with the response for the upcoming request.
        mock_server.set_tokens_and_logprobs_response(
            prompt=prompt,
            prompt_tokens=prompt_tokens,
            output_tokens_list=[output_tokens],
            output_logprobs_list=[output_logprobs],
            finish_reasons=["stop"],
        )

        result = await managed_client.chat.completions.create(
            messages=messages,
            max_tokens=100,
        )

        # Should return EnhancedChatCompletion
        assert isinstance(result, EnhancedChatCompletion)
        assert len(result.choices) == 1
        assert result.choices[0].message.content == output_text

        # Should have prompt_token_ids
        assert len(result.prompt_token_ids) == len(prompt_tokens)

        # Should have token_ids on choice
        assert len(result.choices[0].token_ids) == len(output_tokens)
        assert result.choices[0].token_ids == output_tokens

        # Should have logprobs
        assert len(result.choices[0].logprobs.content) == len(output_tokens)
        assert result.choices[0].logprobs.content[0].logprob == -0.1

    @pytest.mark.asyncio
    async def test_max_completion_tokens_param(self, mock_server, managed_client):
        """Test max_completion_tokens is preferred over max_tokens."""
        messages = [{"role": "user", "content": "Hi"}]
        managed = managed_client.managed_server
        prompt = managed._convert_messages_to_prompt(messages)
        prompt_tokens = mock_server.tokenizer.encode(prompt)

        # Minimal single-token canned response.
        output_tokens = [ord("!")]
        output_logprobs = [-0.1]

        mock_server.set_tokens_and_logprobs_response(
            prompt=prompt,
            prompt_tokens=prompt_tokens,
            output_tokens_list=[output_tokens],
            output_logprobs_list=[output_logprobs],
            finish_reasons=["stop"],
        )

        # Should accept max_completion_tokens (new OpenAI param)
        result = await managed_client.chat.completions.create(
            messages=messages,
            max_completion_tokens=50,
        )

        assert isinstance(result, EnhancedChatCompletion)

    @pytest.mark.asyncio
    async def test_reset_between_rollouts(self, mock_server, managed_client):
        """Test that reset clears state between rollouts."""
        messages = [{"role": "user", "content": "Hello"}]
        managed = managed_client.managed_server
        prompt = managed._convert_messages_to_prompt(messages)
        prompt_tokens = mock_server.tokenizer.encode(prompt)

        output_tokens = [ord("!")]
        output_logprobs = [-0.1]

        mock_server.set_tokens_and_logprobs_response(
            prompt=prompt,
            prompt_tokens=prompt_tokens,
            output_tokens_list=[output_tokens],
            output_logprobs_list=[output_logprobs],
            finish_reasons=["stop"],
        )

        # First rollout
        await managed_client.chat.completions.create(messages=messages, max_tokens=10)
        state = managed_client.managed_server.get_state()
        assert len(state["nodes"]) == 1

        # Reset
        managed_client.reset()
        state = managed_client.managed_server.get_state()
        assert len(state["nodes"]) == 0

        # Setup for second rollout
        # NOTE(review): the harness response appears to be consumed per
        # request, hence the re-priming here — confirm against ServerHarness.
        mock_server.set_tokens_and_logprobs_response(
            prompt=prompt,
            prompt_tokens=prompt_tokens,
            output_tokens_list=[output_tokens],
            output_logprobs_list=[output_logprobs],
            finish_reasons=["stop"],
        )

        # Second rollout
        await managed_client.chat.completions.create(messages=messages, max_tokens=10)
        state = managed_client.managed_server.get_state()
        assert len(state["nodes"]) == 1
|
||||
|
||||
|
||||
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue