diff --git a/atroposlib/envs/server_handling/atropos_managed_client.py b/atroposlib/envs/server_handling/atropos_managed_client.py
deleted file mode 100644
index 75e78e39..00000000
--- a/atroposlib/envs/server_handling/atropos_managed_client.py
+++ /dev/null
@@ -1,330 +0,0 @@
-"""
-AtroposManagedClient: AsyncOpenAI-compatible client backed by ManagedServer.
-
-This module provides a drop-in replacement for AsyncOpenAI that uses Atropos's
-ManagedServer for inference, enabling token tracking for multi-turn RL training
-with the Verifiers library.
-
-Usage:
-    async with server_manager.managed_server(tokenizer=tokenizer) as managed:
-        client = AtroposManagedClient(managed_server=managed, model="model-name")
-
-        # Use like AsyncOpenAI - tokens are tracked automatically
-        response = await client.chat.completions.create(
-            messages=[{"role": "user", "content": "Hello"}],
-            max_tokens=100
-        )
-
-        # Token data is available on the response:
-        # - response.prompt_token_ids
-        # - response.choices[0].token_ids
-        # - response.choices[0].logprobs.content[i].logprob
-"""
-
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
-
-from openai.types.chat.chat_completion_message import ChatCompletionMessage
-
-from atroposlib.envs.server_handling.managed_server import ManagedServer, SequenceNode
-
-# =============================================================================
-# Enhanced Types for Token Data Injection
-# =============================================================================
-
-
-@dataclass
-class LogprobContent:
-    """
-    Single token logprob entry.
-
-    Compatible with verifiers' parse_response_tokens() which accesses:
-    - response.choices[i].logprobs.content[j].logprob
-    """
-
-    logprob: float
-    token: str = ""
-    token_id: int = 0
-    top_logprobs: Optional[List[Any]] = None
-
-
-@dataclass
-class ChoiceLogprobs:
-    """
-    Logprobs structure compatible with verifiers expectations.
-
-    Verifiers checks for either object or dict format:
-    - Object: response.choices[i].logprobs.content[j].logprob
-    - Dict: response.choices[i].logprobs["content"][j]["logprob"]
-
-    This dataclass supports the object format.
-    """
-
-    content: List[LogprobContent] = field(default_factory=list)
-
-
-@dataclass
-class EnhancedChoice:
-    """
-    Choice with token_ids and logprobs for RL training.
-
-    Adds the following attributes that verifiers expects:
-    - token_ids: List[int] - completion token IDs
-    - logprobs: ChoiceLogprobs - structured logprobs
-    """
-
-    index: int
-    message: ChatCompletionMessage
-    finish_reason: str
-    token_ids: List[int]
-    logprobs: ChoiceLogprobs
-
-
-@dataclass
-class EnhancedChatCompletion:
-    """
-    ChatCompletion with token data for RL training.
-
-    Compatible with verifiers' parse_response_tokens() expectations:
-    - prompt_token_ids: list[int]
-    - choices[i].token_ids: list[int]
-    - choices[i].logprobs.content[j].logprob
-    """
-
-    id: str
-    created: int
-    model: str
-    object: str
-    choices: List[EnhancedChoice]
-    prompt_token_ids: List[int]
-    usage: Optional[Dict[str, int]] = None
-
-
-# =============================================================================
-# AsyncOpenAI-Compatible Client Classes
-# =============================================================================
-
-
-class _CompletionsNamespace:
-    """
-    Mimics openai.resources.chat.completions.AsyncCompletions.
-
-    Provides the create() method that verifiers calls.
- """ - - def __init__(self, parent: "AtroposManagedClient"): - self.parent = parent - - async def create( - self, - *, - messages: List[Dict[str, Any]], - model: Optional[str] = None, - n: int = 1, - max_tokens: Optional[int] = None, - max_completion_tokens: Optional[int] = None, - temperature: float = 1.0, - top_p: float = 1.0, - tools: Optional[List[Dict]] = None, - stop: Optional[List[str]] = None, - **kwargs, - ) -> EnhancedChatCompletion: - """ - Create chat completion with token tracking. - - Returns ChatCompletion with additional attributes: - - prompt_token_ids: list[int] - - choices[i].token_ids: list[int] - - choices[i].logprobs.content: list with logprob info - - Args: - messages: List of message dicts with 'role' and 'content' - model: Model name (defaults to client's model) - n: Number of completions (should be 1 for multi-turn) - max_tokens: Max tokens in completion (legacy param) - max_completion_tokens: Max tokens in completion (new param) - temperature: Sampling temperature - top_p: Nucleus sampling parameter - tools: Tool definitions for function calling - stop: Stop sequences - **kwargs: Additional parameters passed to ManagedServer - """ - # Use max_completion_tokens if provided, else max_tokens - effective_max_tokens = max_completion_tokens or max_tokens - - # Build kwargs for ManagedServer - completion_kwargs = { - "messages": messages, - "model": model or self.parent.model, - "n": n, - "temperature": temperature, - "top_p": top_p, - } - - if effective_max_tokens is not None: - completion_kwargs["max_tokens"] = effective_max_tokens - - if tools is not None: - completion_kwargs["tools"] = tools - - if stop is not None: - completion_kwargs["stop"] = stop - - # Add any extra kwargs (like logprobs settings) - for key, value in kwargs.items(): - if value is not None: - completion_kwargs[key] = value - - # Call ManagedServer for inference - completion = await self.parent.managed_server.chat_completion( - **completion_kwargs - ) - - # Get token state from managed server - state = self.parent.managed_server.get_state() - nodes: List[SequenceNode] = state["nodes"] - - # Inject token data into response - return self._enhance_completion(completion, nodes) - - def _enhance_completion( - self, completion: Any, nodes: List[SequenceNode] - ) -> EnhancedChatCompletion: - """ - Convert ManagedServer output to verifiers-compatible format. - - Extracts token data from SequenceNodes and injects it into the - ChatCompletion response in the format verifiers expects. 
- """ - enhanced_choices = [] - prompt_token_ids: List[int] = [] - - for i, (choice, node) in enumerate(zip(completion.choices, nodes)): - # Find prompt/completion boundary from masked_tokens - # -100 indicates prompt tokens, actual token IDs indicate completion - prompt_len = sum(1 for m in node.masked_tokens if m == -100) - - # Extract prompt and completion portions - if i == 0: - prompt_token_ids = node.tokens[:prompt_len] - - completion_ids = node.tokens[prompt_len:] - completion_logprobs = node.logprobs[prompt_len:] - - # Build logprobs structure verifiers expects - logprobs_content = [] - tokenizer = self.parent.managed_server.tokenizer - - for token_id, logprob in zip(completion_ids, completion_logprobs): - # Decode token to string if tokenizer available - token_str = "" - if tokenizer is not None: - try: - token_str = tokenizer.decode([token_id]) - except Exception: - token_str = f"" - - logprobs_content.append( - LogprobContent( - logprob=logprob, - token=token_str, - token_id=token_id, - ) - ) - - # Create enhanced choice with token data - enhanced_choice = EnhancedChoice( - index=choice.index, - message=choice.message, - finish_reason=choice.finish_reason or "stop", - token_ids=completion_ids, - logprobs=ChoiceLogprobs(content=logprobs_content), - ) - enhanced_choices.append(enhanced_choice) - - return EnhancedChatCompletion( - id=completion.id, - created=completion.created, - model=completion.model, - object=completion.object, - choices=enhanced_choices, - prompt_token_ids=prompt_token_ids, - usage=completion.usage.model_dump() if completion.usage else None, - ) - - -class _ChatNamespace: - """Mimics openai.resources.chat.AsyncChat.""" - - def __init__(self, parent: "AtroposManagedClient"): - self.completions = _CompletionsNamespace(parent) - - -class AtroposManagedClient: - """ - AsyncOpenAI-compatible client backed by ManagedServer. - - This client provides the same interface as AsyncOpenAI but uses Atropos's - ManagedServer for inference, enabling automatic token tracking for - multi-turn RL training with the Verifiers library. - - The key feature is that responses include token data attributes that - verifiers' parse_response_tokens() expects: - - response.prompt_token_ids - - response.choices[i].token_ids - - response.choices[i].logprobs.content[j].logprob - - Usage: - async with server_manager.managed_server(tokenizer=tokenizer) as managed: - client = AtroposManagedClient( - managed_server=managed, - model="Qwen/Qwen2.5-1.5B-Instruct" - ) - - # Pass to verifiers env.rollout() - state = await vf_env.rollout( - input=rollout_input, - client=client, - model="Qwen/Qwen2.5-1.5B-Instruct", - ) - - # Token data is now in state["trajectory"][i]["tokens"] - """ - - def __init__( - self, - managed_server: ManagedServer, - model: str, - base_url: Optional[str] = None, - ): - """ - Initialize the managed client. 
-
-        Args:
-            managed_server: ManagedServer instance for inference and token tracking
-            model: Model name to use for completions
-            base_url: Optional base URL (for API compatibility, not used)
-        """
-        self.managed_server = managed_server
-        self.model = model
-        self.base_url = base_url or "http://managed-server"
-
-        # Mimic AsyncOpenAI namespace structure
-        self.chat = _ChatNamespace(self)
-
-    def reset(self):
-        """Reset token tracking state between rollouts."""
-        self.managed_server.reset()
-
-    async def close(self):
-        """Compatibility method - no-op since ManagedServer handles cleanup."""
-        pass
-
-    def copy(self, **_kwargs) -> "AtroposManagedClient":
-        """
-        Create a copy of this client (for API compatibility).
-
-        Verifiers may call client.copy() for certain operations.
-        Returns self since we want to maintain the same ManagedServer state.
-        """
-        return self
diff --git a/atroposlib/tests/test_atropos_managed_client.py b/atroposlib/tests/test_atropos_managed_client.py
deleted file mode 100644
index cb1cf955..00000000
--- a/atroposlib/tests/test_atropos_managed_client.py
+++ /dev/null
@@ -1,235 +0,0 @@
-"""Tests for AtroposManagedClient - AsyncOpenAI-compatible wrapper for ManagedServer."""
-
-import pytest
-
-from atroposlib.envs.server_handling.atropos_managed_client import (
-    AtroposManagedClient,
-    ChoiceLogprobs,
-    EnhancedChatCompletion,
-    LogprobContent,
-)
-from atroposlib.envs.server_handling.managed_server import ManagedServer
-from atroposlib.envs.server_handling.server_harness import ServerHarness
-
-
-class MockTokenizer:
-    """Mock tokenizer for testing."""
-
-    def __init__(self):
-        self.eos_token_id = 2
-        self.bos_token_id = 1
-
-    def encode(self, text, add_special_tokens=True):
-        """Simple character-based encoding for testing."""
-        tokens = [ord(c) for c in text]
-        if add_special_tokens:
-            tokens = [self.bos_token_id] + tokens
-        return tokens
-
-    def decode(self, tokens, skip_special_tokens=False):
-        """Simple character-based decoding for testing."""
-        if skip_special_tokens:
-            tokens = [
-                t for t in tokens if t not in [self.bos_token_id, self.eos_token_id]
-            ]
-        return "".join([chr(t) if t > 31 else "" for t in tokens])
-
-    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
-        """Simple chat template for testing."""
-        result = ""
-        for msg in messages:
-            result += f"<{msg['role']}>{msg['content']}"
-        if add_generation_prompt:
-            result += "<assistant>"
-        if tokenize:
-            return self.encode(result)
-        return result
-
-
-@pytest.fixture
-def mock_server():
-    """Create a mock server with a tokenizer."""
-    server = ServerHarness()
-    server.tokenizer = MockTokenizer()
-
-    class Config:
-        model_name = "test_model"
-
-    server.config = Config()
-    return server
-
-
-@pytest.fixture
-def managed_client(mock_server):
-    """Create an AtroposManagedClient with mocked server."""
-    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
-    return AtroposManagedClient(managed_server=managed, model="test_model")
-
-
-class TestDataclasses:
-    """Test the enhanced dataclasses."""
-
-    def test_logprob_content(self):
-        """Test LogprobContent creation."""
-        lp = LogprobContent(logprob=-0.5, token="hello", token_id=100)
-        assert lp.logprob == -0.5
-        assert lp.token == "hello"
-        assert lp.token_id == 100
-
-    def test_choice_logprobs(self):
-        """Test ChoiceLogprobs structure."""
-        content = [
-            LogprobContent(logprob=-0.1),
-            LogprobContent(logprob=-0.2),
-        ]
-        logprobs = ChoiceLogprobs(content=content)
-        assert len(logprobs.content) == 2
-        assert logprobs.content[0].logprob == -0.1
-
-
-class TestAtroposManagedClient:
-    """Test AtroposManagedClient behavior."""
-
-    def test_reset(self, managed_client):
-        """Test reset clears ManagedServer state."""
-        # Add some state to managed server
-        managed_client.managed_server.current_nodes = ["dummy"]
-
-        # Reset should clear it
-        managed_client.reset()
-        assert len(managed_client.managed_server.current_nodes) == 0
-
-    def test_copy_returns_self(self, managed_client):
-        """Test copy returns same instance for state sharing."""
-        copied = managed_client.copy()
-        assert copied is managed_client
-
-    def test_namespace_structure(self, managed_client):
-        """Test client has correct namespace structure like AsyncOpenAI."""
-        assert hasattr(managed_client, "chat")
-        assert hasattr(managed_client.chat, "completions")
-        assert hasattr(managed_client.chat.completions, "create")
-
-    @pytest.mark.asyncio
-    async def test_close_is_noop(self, managed_client):
-        """Test close() doesn't raise."""
-        await managed_client.close()  # Should not raise
-
-
-class TestChatCompletionCreate:
-    """Test the chat.completions.create() method."""
-
-    @pytest.mark.asyncio
-    async def test_basic_completion(self, mock_server, managed_client):
-        """Test basic chat completion returns enhanced response."""
-        messages = [{"role": "user", "content": "Hello"}]
-        managed = managed_client.managed_server
-        prompt = managed._convert_messages_to_prompt(messages)
-        prompt_tokens = mock_server.tokenizer.encode(prompt)
-
-        output_text = "Hi there!"
-        output_tokens = [ord(c) for c in output_text]
-        output_logprobs = [-0.1] * len(output_tokens)
-
-        mock_server.set_tokens_and_logprobs_response(
-            prompt=prompt,
-            prompt_tokens=prompt_tokens,
-            output_tokens_list=[output_tokens],
-            output_logprobs_list=[output_logprobs],
-            finish_reasons=["stop"],
-        )
-
-        result = await managed_client.chat.completions.create(
-            messages=messages,
-            max_tokens=100,
-        )
-
-        # Should return EnhancedChatCompletion
-        assert isinstance(result, EnhancedChatCompletion)
-        assert len(result.choices) == 1
-        assert result.choices[0].message.content == output_text
-
-        # Should have prompt_token_ids
-        assert len(result.prompt_token_ids) == len(prompt_tokens)
-
-        # Should have token_ids on choice
-        assert len(result.choices[0].token_ids) == len(output_tokens)
-        assert result.choices[0].token_ids == output_tokens
-
-        # Should have logprobs
-        assert len(result.choices[0].logprobs.content) == len(output_tokens)
-        assert result.choices[0].logprobs.content[0].logprob == -0.1
-
-    @pytest.mark.asyncio
-    async def test_max_completion_tokens_param(self, mock_server, managed_client):
-        """Test max_completion_tokens is preferred over max_tokens."""
-        messages = [{"role": "user", "content": "Hi"}]
-        managed = managed_client.managed_server
-        prompt = managed._convert_messages_to_prompt(messages)
-        prompt_tokens = mock_server.tokenizer.encode(prompt)
-
-        output_tokens = [ord("!")]
-        output_logprobs = [-0.1]
-
-        mock_server.set_tokens_and_logprobs_response(
-            prompt=prompt,
-            prompt_tokens=prompt_tokens,
-            output_tokens_list=[output_tokens],
-            output_logprobs_list=[output_logprobs],
-            finish_reasons=["stop"],
-        )
-
-        # Should accept max_completion_tokens (new OpenAI param)
-        result = await managed_client.chat.completions.create(
-            messages=messages,
-            max_completion_tokens=50,
-        )
-
-        assert isinstance(result, EnhancedChatCompletion)
-
-    @pytest.mark.asyncio
-    async def test_reset_between_rollouts(self, mock_server, managed_client):
-        """Test that reset clears state between rollouts."""
messages = [{"role": "user", "content": "Hello"}] - managed = managed_client.managed_server - prompt = managed._convert_messages_to_prompt(messages) - prompt_tokens = mock_server.tokenizer.encode(prompt) - - output_tokens = [ord("!")] - output_logprobs = [-0.1] - - mock_server.set_tokens_and_logprobs_response( - prompt=prompt, - prompt_tokens=prompt_tokens, - output_tokens_list=[output_tokens], - output_logprobs_list=[output_logprobs], - finish_reasons=["stop"], - ) - - # First rollout - await managed_client.chat.completions.create(messages=messages, max_tokens=10) - state = managed_client.managed_server.get_state() - assert len(state["nodes"]) == 1 - - # Reset - managed_client.reset() - state = managed_client.managed_server.get_state() - assert len(state["nodes"]) == 0 - - # Setup for second rollout - mock_server.set_tokens_and_logprobs_response( - prompt=prompt, - prompt_tokens=prompt_tokens, - output_tokens_list=[output_tokens], - output_logprobs_list=[output_logprobs], - finish_reasons=["stop"], - ) - - # Second rollout - await managed_client.chat.completions.create(messages=messages, max_tokens=10) - state = managed_client.managed_server.get_state() - assert len(state["nodes"]) == 1 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"])