diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py
new file mode 100644
index 00000000..2740afea
--- /dev/null
+++ b/atroposlib/envs/server_handling/managed_server.py
@@ -0,0 +1,511 @@
+"""
+Managed server wrapper that tracks text sequences with aligned tokens and logprobs.
+
+This wrapper maintains a tree structure of sequences, where:
+- Each node represents a complete text sequence (prompt + completion)
+- Tokens and logprobs are tracked with proper masking for training
+- Branching occurs organically from different contexts and from n > 1 completions
+"""
+
+import time
+import uuid
+import warnings
+from typing import Any, Dict, List, Optional
+
+from openai.types.chat.chat_completion import (
+    ChatCompletion,
+    ChatCompletionMessage,
+    Choice,
+)
+from openai.types.completion import Completion, CompletionChoice
+from pydantic import BaseModel
+
+from atroposlib.envs.server_handling.server_baseline import APIServer
+
+
+class SequenceNode(BaseModel):
+    """
+    A node in the sequence tree representing a complete text sequence.
+
+    Attributes:
+        full_text: Complete text (prompt + completion)
+        tokens: Full token sequence (actual token IDs)
+        masked_tokens: Tokens with -100 for prompt positions, actual IDs for completion
+        logprobs: Logprobs with 0.0 for prompt positions, actual values for completion
+        metadata: Optional metadata (e.g., role information, finish_reason, etc.)
+    """
+
+    full_text: str
+    tokens: List[int]
+    masked_tokens: List[int]
+    logprobs: List[float]
+    metadata: Optional[Dict[str, Any]] = None
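+
+# Illustrative sketch (not executed): with a toy tokenizer where
+# encode("Hi") == [1, 72, 105] (BOS + characters) and a single completion
+# token 33 ("!") sampled with logprob -0.5, a tracked node lines up as:
+#   tokens        = [1, 72, 105, 33]
+#   masked_tokens = [-100, -100, -100, 33]
+#   logprobs      = [0.0, 0.0, 0.0, -0.5]
+# One entry per token; prompt positions are masked out for training.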
+
+
+class ManagedServer:
+    """
+    Wrapper around APIServer that tracks sequences with aligned tokens and logprobs.
+
+    In tree mode (track_tree=True), maintains a tree keyed by input text, where
+    each completion creates new branches. In the default mode, keeps a flat list
+    of current nodes that is updated in place. Provides proper masking for
+    training (prompt tokens masked with -100, prompt logprobs set to 0.0).
+
+    Uses the clean _tokens_and_logprobs_completion_wrapper interface internally.
+    """
+
+    def __init__(
+        self,
+        server: APIServer,
+        tokenizer: Optional[Any] = None,
+        track_tree: bool = False,
+    ):
+        """
+        Initialize the managed server.
+
+        Args:
+            server: The underlying APIServer instance to wrap
+            tokenizer: Optional tokenizer for encoding/decoding. If not provided,
+                will attempt to extract from server or create from model name.
+            track_tree: If True, maintains a tree structure with parent-child links
+                (for multi-turn RL with per-step advantages). If False (default),
+                maintains a simple list of current nodes that updates in place.
+        """
+        self.server = server
+        self.tokenizer = tokenizer
+        self.track_tree = track_tree
+
+        # Initialize storage based on mode
+        if track_tree:
+            self.sequences: Dict[str, SequenceNode] = {}  # Tree mode: dict lookup
+        else:
+            self.current_nodes: List[SequenceNode] = []  # Default mode: simple list
+
+        # Try to get the tokenizer from the server if not provided
+        if self.tokenizer is None:
+            self._initialize_tokenizer()
+
+    def _initialize_tokenizer(self):
+        """Initialize the tokenizer from the server or the model name."""
+        # Check if the wrapped server has a tokenizer
+        if hasattr(self.server, "tokenizer"):
+            self.tokenizer = self.server.tokenizer
+        else:
+            # Try to create one from the model name
+            try:
+                from transformers import AutoTokenizer
+
+                model_name = self.server.config.model_name
+                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+            except Exception as e:
+                warnings.warn(
+                    f"Could not initialize tokenizer: {e}. "
+                    "Sequence tracking will be limited without a tokenizer."
+                )
+                self.tokenizer = None
+
+    def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
+        """
+        Convert chat messages to prompt text using the tokenizer's chat template.
+
+        Args:
+            messages: List of message dicts with 'role' and 'content'
+
+        Returns:
+            Formatted prompt string
+        """
+        if self.tokenizer is None:
+            # Fallback: simple concatenation
+            return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
+
+        if hasattr(self.tokenizer, "apply_chat_template"):
+            # Only add a generation prompt if the last message is not from the assistant
+            add_generation_prompt = (
+                len(messages) == 0 or messages[-1].get("role") != "assistant"
+            )
+
+            # Use the tokenizer's chat template
+            return self.tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=add_generation_prompt
+            )
+        else:
+            # Fallback for tokenizers without a chat template
+            return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
+
+    def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
+        """
+        Find a node that this input extends (default mode).
+
+        Args:
+            input_text: The input text to check
+
+        Returns:
+            The node that input_text extends, or None if no match
+        """
+        if not self.current_nodes:
+            return None
+
+        # If any current node's full_text is a prefix of the input,
+        # the input is extending that node
+        for node in self.current_nodes:
+            if input_text.startswith(node.full_text):
+                return node
+        return None
+
+    def _compute_input_ids(
+        self, input_text: str, extending_node: Optional[SequenceNode]
+    ) -> List[int]:
+        """
+        Compute input_ids for the prompt, reusing existing tokens when extending.
+
+        Requires a tokenizer; callers guard on self.tokenizer being available.
+
+        Args:
+            input_text: The full input prompt text
+            extending_node: Node being extended, if any
+
+        Returns:
+            List of token IDs to use as input_ids
+        """
+        if extending_node is not None:
+            # Extending an existing sequence: reuse its tokens and tokenize the new part
+            existing_text = extending_node.full_text
+            new_text_suffix = input_text[len(existing_text) :]
+
+            # Tokenize only the new suffix (without BOS, since we're continuing)
+            if new_text_suffix:
+                new_tokens = self.tokenizer.encode(
+                    new_text_suffix, add_special_tokens=False
+                )
+                return extending_node.tokens + new_tokens
+            else:
+                # No new text, just use the existing tokens
+                return extending_node.tokens.copy()
+        else:
+            # New sequence: tokenize the whole thing
+            return self.tokenizer.encode(input_text, add_special_tokens=True)
+
+    def _find_parent_node(self, input_text: str) -> Optional[SequenceNode]:
+        """
+        Find a parent node whose full_text matches the input_text (tree mode).
+
+        Args:
+            input_text: The input text to search for
+
+        Returns:
+            Parent SequenceNode if found, None otherwise
+        """
+        return self.sequences.get(input_text, None)
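+
+    # Illustrative sketch (not executed): if a tracked node holds
+    # full_text="Hello World" and the next prompt is "Hello World!", only the
+    # suffix "!" is tokenized (add_special_tokens=False) and appended to the
+    # node's tokens, so a continuation never re-inserts a BOS token mid-sequence.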
+
+    def _create_sequence_node(
+        self,
+        input_text: str,
+        parent_node: Optional[SequenceNode],
+        prompt_tokens: List[int],
+        output_tokens: List[int],
+        output_logprobs: List[float],
+        completion_text: str,
+        finish_reason: str = "stop",
+    ) -> SequenceNode:
+        """
+        Create a sequence node with proper masking.
+
+        Args:
+            input_text: The input prompt text
+            parent_node: Parent node to extend from (if available)
+            prompt_tokens: Token IDs for the prompt
+            output_tokens: Token IDs for the output/completion
+            output_logprobs: Logprobs for output tokens
+            completion_text: The completion text
+            finish_reason: Finish reason from the server
+
+        Returns:
+            SequenceNode with properly masked tokens and logprobs
+        """
+        # Combine text
+        full_text = input_text + completion_text
+
+        # If we have a parent node, use its tokens as the prompt base
+        if parent_node is not None:
+            prompt_tokens = parent_node.tokens.copy()
+
+        # Combine tokens
+        full_tokens = prompt_tokens + output_tokens
+        prompt_len = len(prompt_tokens)
+
+        # Create masked tokens: -100 for prompt, actual IDs for completion
+        masked_tokens = [-100] * prompt_len + output_tokens
+
+        # Create masked logprobs: 0.0 for prompt, actual values for completion.
+        # Pad or trim logprobs to match the output token length if needed.
+        if len(output_logprobs) < len(output_tokens):
+            output_logprobs = output_logprobs + [0.0] * (
+                len(output_tokens) - len(output_logprobs)
+            )
+        elif len(output_logprobs) > len(output_tokens):
+            output_logprobs = output_logprobs[: len(output_tokens)]
+
+        full_logprobs = [0.0] * prompt_len + output_logprobs
+
+        return SequenceNode(
+            full_text=full_text,
+            tokens=full_tokens,
+            masked_tokens=masked_tokens,
+            logprobs=full_logprobs,
+            metadata={"finish_reason": finish_reason},
+        )
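+
+    # Illustrative sketch (not executed): if the server returns three output
+    # tokens but only two logprobs, the tail is padded so the lengths agree:
+    #   output_tokens   = [72, 105, 33]
+    #   output_logprobs = [-0.1, -0.2]  ->  [-0.1, -0.2, 0.0]
+    # After this, len(node.logprobs) == len(node.tokens) holds for every node.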
+
+    async def chat_completion(self, **kwargs) -> ChatCompletion:
+        """
+        Intercept a chat completion call and track sequences.
+
+        Internally converts the messages to a prompt, calls the tokens-and-logprobs
+        completion wrapper, tracks the sequence, and reconstructs a ChatCompletion
+        response.
+
+        Args:
+            **kwargs: Standard chat completion kwargs (messages, n, etc.)
+
+        Returns:
+            ChatCompletion response
+        """
+        # Get input text
+        messages = kwargs.get("messages", [])
+        prompt = self._convert_messages_to_prompt(messages)
+
+        # Handle parent node and extension logic based on mode
+        if self.track_tree:
+            # Tree mode: look up the parent in the dict
+            parent_node = self._find_parent_node(prompt)
+            extending_node = None
+        else:
+            # Default mode: check if we're extending an existing sequence
+            extending_node = self._find_extending_node(prompt)
+            parent_node = None  # Don't use parent merging in default mode
+
+        # Convert to completion format
+        completion_kwargs = kwargs.copy()
+        completion_kwargs["prompt"] = prompt
+        completion_kwargs.pop("messages", None)
+
+        # Set the model name if not provided
+        if "model" not in completion_kwargs:
+            completion_kwargs["model"] = self.server.config.model_name
+
+        # Compute input_ids (reusing existing tokens if extending)
+        if not self.track_tree and self.tokenizer is not None:
+            input_ids = self._compute_input_ids(prompt, extending_node)
+            completion_kwargs["input_ids"] = input_ids
+
+        # Call the tokens and logprobs wrapper directly
+        (
+            prompt_tokens,
+            output_tokens_list,
+            output_logprobs_list,
+            finish_reasons,
+        ) = await self.server._tokens_and_logprobs_completion_wrapper(
+            **completion_kwargs
+        )
+
+        # Track each completion and build choices
+        n = len(output_tokens_list)
+        choices = []
+
+        for i in range(n):
+            output_tokens = output_tokens_list[i]
+            output_logprobs = output_logprobs_list[i]
+            finish_reason_raw = finish_reasons[i] if i < len(finish_reasons) else "stop"
+
+            # Extract the finish_reason string from a dict if needed
+            if isinstance(finish_reason_raw, dict):
+                finish_reason = finish_reason_raw.get("type", "stop")
+            else:
+                finish_reason = finish_reason_raw
+
+            # Decode completion text
+            if self.tokenizer is not None:
+                completion_text = self.tokenizer.decode(
+                    output_tokens, skip_special_tokens=True
+                )
+            else:
+                # Lossy fallback when no tokenizer is available
+                completion_text = "".join([chr(t) for t in output_tokens if t > 31])
+
+            # Create and store the sequence node
+            node = self._create_sequence_node(
+                input_text=prompt,
+                parent_node=parent_node,
+                prompt_tokens=prompt_tokens,
+                output_tokens=output_tokens,
+                output_logprobs=output_logprobs,
+                completion_text=completion_text,
+                finish_reason=finish_reason,
+            )
+
+            # Store the node based on mode
+            if self.track_tree:
+                # Tree mode: key by full text in the dict
+                self.sequences[node.full_text] = node
+            else:
+                # Default mode: replace if extending, append if new context
+                if extending_node is not None:
+                    # Replace the extending node with the new extended version
+                    try:
+                        idx = self.current_nodes.index(extending_node)
+                        self.current_nodes[idx] = node
+                    except ValueError:
+                        # Extending node no longer in the list, just append
+                        self.current_nodes.append(node)
+                else:
+                    # New context: append to the list
+                    self.current_nodes.append(node)
+
+            # Build the choice
+            choice = Choice(
+                finish_reason=finish_reason,
+                index=i,
+                message=ChatCompletionMessage(content=completion_text, role="assistant"),
+            )
+            choices.append(choice)
+
+        # Construct the ChatCompletion response
+        return ChatCompletion(
+            id=str(uuid.uuid4()),
+            created=int(time.time()),
+            model=self.server.config.model_name,
+            object="chat.completion",
+            choices=choices,
+        )
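+
+    # Illustrative sketch (not executed): with n=2 on the same prompt, two
+    # choices come back and two SequenceNodes are tracked, one per branch:
+    #   await managed.chat_completion(
+    #       messages=[{"role": "user", "content": "Hi"}], n=2
+    #   )
+    #   assert len(managed.get_state()["nodes"]) == 2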
+
+    async def completion(self, **kwargs) -> Completion:
+        """
+        Intercept a completion call and track sequences.
+
+        Uses the tokens-and-logprobs completion wrapper internally, tracks the
+        sequence, and reconstructs a Completion response.
+
+        Args:
+            **kwargs: Standard completion kwargs (prompt, n, etc.)
+
+        Returns:
+            Completion response
+        """
+        # Get input text
+        prompt = kwargs.get("prompt", "")
+
+        # Handle parent node and extension logic based on mode
+        if self.track_tree:
+            # Tree mode: look up the parent in the dict
+            parent_node = self._find_parent_node(prompt)
+            extending_node = None
+        else:
+            # Default mode: check if we're extending an existing sequence
+            extending_node = self._find_extending_node(prompt)
+            parent_node = None  # Don't use parent merging in default mode
+
+        # Set the model name if not provided
+        if "model" not in kwargs:
+            kwargs["model"] = self.server.config.model_name
+
+        # Compute input_ids (reusing existing tokens if extending)
+        if not self.track_tree and self.tokenizer is not None:
+            input_ids = self._compute_input_ids(prompt, extending_node)
+            kwargs["input_ids"] = input_ids
+
+        # Call the tokens and logprobs wrapper directly
+        (
+            prompt_tokens,
+            output_tokens_list,
+            output_logprobs_list,
+            finish_reasons,
+        ) = await self.server._tokens_and_logprobs_completion_wrapper(**kwargs)
+
+        # Track each completion and build choices
+        n = len(output_tokens_list)
+        choices = []
+
+        for i in range(n):
+            output_tokens = output_tokens_list[i]
+            output_logprobs = output_logprobs_list[i]
+            finish_reason_raw = finish_reasons[i] if i < len(finish_reasons) else "stop"
+
+            # Extract the finish_reason string from a dict if needed
+            if isinstance(finish_reason_raw, dict):
+                finish_reason = finish_reason_raw.get("type", "stop")
+            else:
+                finish_reason = finish_reason_raw
+
+            # Decode completion text
+            if self.tokenizer is not None:
+                completion_text = self.tokenizer.decode(
+                    output_tokens, skip_special_tokens=True
+                )
+            else:
+                # Lossy fallback when no tokenizer is available
+                completion_text = "".join([chr(t) for t in output_tokens if t > 31])
+
+            # Create and store the sequence node
+            node = self._create_sequence_node(
+                input_text=prompt,
+                parent_node=parent_node,
+                prompt_tokens=prompt_tokens,
+                output_tokens=output_tokens,
+                output_logprobs=output_logprobs,
+                completion_text=completion_text,
+                finish_reason=finish_reason,
+            )
+
+            # Store the node based on mode
+            if self.track_tree:
+                # Tree mode: key by full text in the dict
+                self.sequences[node.full_text] = node
+            else:
+                # Default mode: replace if extending, append if new context
+                if extending_node is not None:
+                    # Replace the extending node with the new extended version
+                    try:
+                        idx = self.current_nodes.index(extending_node)
+                        self.current_nodes[idx] = node
+                    except ValueError:
+                        # Extending node no longer in the list, just append
+                        self.current_nodes.append(node)
+                else:
+                    # New context: append to the list
+                    self.current_nodes.append(node)
+
+            # Build the choice
+            choice = CompletionChoice(
+                finish_reason=finish_reason, index=i, text=completion_text
+            )
+            choices.append(choice)
+
+        # Construct the Completion response
+        return Completion(
+            id=str(uuid.uuid4()),
+            created=int(time.time()),
+            model=self.server.config.model_name,
+            object="text_completion",
+            choices=choices,
+        )
+
+    def get_state(self) -> Dict[str, Any]:
+        """
+        Get the current state of tracked sequences.
+
+        Returns:
+            For default mode (track_tree=False):
+                Dictionary with 'nodes': List[SequenceNode] - ready for training
+            For tree mode (track_tree=True):
+                Dictionary with 'sequences': Dict[str, SequenceNode] and a 'tree' alias
+        """
+        if self.track_tree:
+            return {
+                "sequences": self.sequences.copy(),
+                "tree": self.sequences.copy(),  # Alias for compatibility
+            }
+        else:
+            return {
+                # Return a copy so reset() doesn't affect it
+                "nodes": self.current_nodes.copy(),
+            }
+
+    def reset(self):
+        """Clear all tracked sequences."""
+        if self.track_tree:
+            self.sequences.clear()
+        else:
+            self.current_nodes.clear()
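+
+
+# Illustrative consumption sketch (not executed; assumes a ManagedServer
+# instance `managed` and a trainer that takes parallel per-sequence lists,
+# as math_server_zero.py does below):
+#   state = managed.get_state()
+#   for node in state["nodes"]:
+#       scores["tokens"].append(node.tokens)
+#       scores["masks"].append(node.masked_tokens)   # -100 marks prompt positions
+#       scores["logprobs"].append(node.logprobs)     # 0.0 at prompt positions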
diff --git a/atroposlib/envs/server_handling/server_harness.py b/atroposlib/envs/server_handling/server_harness.py
index b0807d98..af004d02 100644
--- a/atroposlib/envs/server_handling/server_harness.py
+++ b/atroposlib/envs/server_handling/server_harness.py
@@ -87,6 +87,7 @@ def create_completion(
 class ServerHarness:
     def __init__(self):
         self.response_map = dict()
+        self.tokens_and_logprobs_map = dict()  # Map for tokens/logprobs responses
         self.sem = asyncio.Semaphore(1)
         self.eval_sem = asyncio.Semaphore(1)
         pass
@@ -110,6 +111,31 @@ class ServerHarness:
     def set_desired_completion(self, input_message: str, completion: Completion):
         self.response_map[input_message] = completion
 
+    def set_tokens_and_logprobs_response(
+        self,
+        prompt: str,
+        prompt_tokens: list,
+        output_tokens_list: list,
+        output_logprobs_list: list,
+        finish_reasons: list,
+    ):
+        """
+        Set the expected response for _tokens_and_logprobs_completion_wrapper.
+
+        Args:
+            prompt: The prompt string (key)
+            prompt_tokens: List of prompt token IDs
+            output_tokens_list: List of lists of output token IDs (one per completion)
+            output_logprobs_list: List of lists of output logprobs (one per completion)
+            finish_reasons: List of finish reasons (one per completion)
+        """
+        self.tokens_and_logprobs_map[prompt] = (
+            prompt_tokens,
+            output_tokens_list,
+            output_logprobs_list,
+            finish_reasons,
+        )
+
     async def chat_completion(self, *args, **kwargs) -> ChatCompletion:
         messages = kwargs.get("messages")
         dictkey = self.conv_to_dictkey(messages)
@@ -125,6 +151,21 @@ class ServerHarness:
         except KeyError as e:
             raise KeyError(f"KeyError: {e} for key:\n{prompt}")
 
+    async def _tokens_and_logprobs_completion_wrapper(
+        self, **kwargs
+    ) -> tuple[list, list, list, list]:
+        """
+        Mock implementation of the tokens and logprobs completion wrapper.
+
+        Returns:
+            Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons)
+        """
+        prompt = kwargs.get("prompt")
+        try:
+            # Use indexing, not .get(): .get() returns None on a miss and never
+            # raises, which would have made the KeyError handler below dead code
+            return self.tokens_and_logprobs_map[prompt]
+        except KeyError as e:
+            raise KeyError(f"KeyError: {e} for prompt:\n{prompt}")
+
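+    # Illustrative usage in a test (a sketch; see test_managed_server.py below):
+    #   harness = ServerHarness()
+    #   harness.set_tokens_and_logprobs_response(
+    #       prompt="Hello",
+    #       prompt_tokens=[1, 72],
+    #       output_tokens_list=[[33]],
+    #       output_logprobs_list=[[-0.5]],
+    #       finish_reasons=["stop"],
+    #   )
+    #   # A ManagedServer wrapping this harness now gets the canned tuple back.
+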
 
 if __name__ == "__main__":
diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py
index 41ed3bf5..d2c89ab8 100644
--- a/atroposlib/envs/server_handling/server_manager.py
+++ b/atroposlib/envs/server_handling/server_manager.py
@@ -8,6 +8,7 @@ from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.completion import Completion
 from pydantic import BaseModel, Field
 
+from atroposlib.envs.server_handling.managed_server import ManagedServer
 from atroposlib.envs.server_handling.openai_server import OpenAIServer
 from atroposlib.envs.server_handling.server_baseline import (
     APIServer,
@@ -308,3 +309,50 @@ class ServerManager:
                 yield self.servers[most_available_server]
             finally:
                 pass
+
+    @asynccontextmanager
+    async def managed_server(
+        self, tokenizer=None
+    ) -> AsyncGenerator[ManagedServer, None]:
+        """
+        Context manager that provides a ManagedServer instance.
+
+        The ManagedServer wraps the most available server and tracks text sequences
+        with aligned tokens and logprobs. State is automatically cleared on exit.
+
+        Args:
+            tokenizer: Optional tokenizer to use. If not provided, will attempt to
+                extract from server or create from model name.
+
+        Yields:
+            ManagedServer instance wrapping the selected server
+
+        Example:
+            async with server_manager.managed_server() as managed:
+                response = await managed.chat_completion(
+                    messages=[{"role": "user", "content": "Hello"}],
+                    n=2,
+                )
+                state = managed.get_state()
+                # Process state...
+            # State is automatically cleared when exiting the context
+        """
+        # Pick the healthy server with the most free semaphore slots
+        most_available_server = 0
+        most_available_server_num_slots = -1
+        for i, server in enumerate(self.servers):
+            if not server.server_healthy:
+                continue
+            if server.sem._value > most_available_server_num_slots:
+                most_available_server = i
+                most_available_server_num_slots = server.sem._value
+
+        # Create a ManagedServer wrapping the selected server
+        managed = ManagedServer(
+            server=self.servers[most_available_server], tokenizer=tokenizer
+        )
+
+        try:
+            yield managed
+        finally:
+            # Clean up: reset tracked sequences
+            managed.reset()
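+            # Note: read state inside the context; after exit, get_state()
+            # returns empty collections. For example:
+            #   async with server_manager.managed_server() as m:
+            #       await m.completion(prompt="Hi")
+            #       nodes = m.get_state()["nodes"]  # capture before exit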
diff --git a/atroposlib/envs/server_handling/sglang_server.py b/atroposlib/envs/server_handling/sglang_server.py
index 1a2f1362..1e838374 100644
--- a/atroposlib/envs/server_handling/sglang_server.py
+++ b/atroposlib/envs/server_handling/sglang_server.py
@@ -148,14 +148,24 @@ class SGLangServer(APIServer):
         assert (
             kwargs.get("model", None) is not None
         ), "Model is required for completion!"
         assert (
-            kwargs.get("prompt", None) is not None
-        ), "Prompt is required for completion!"
+            kwargs.get("prompt", None) is not None
+            or kwargs.get("input_ids", None) is not None
+        ), "Prompt or input_ids is required for completion!"
         # Get n parameter for number of completions
         n = kwargs.get("n", 1)
-        prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
+
+        # Use input_ids if provided (from ManagedServer), otherwise tokenize the prompt
+        if "input_ids" in kwargs:
+            prompt_tokens = kwargs.pop("input_ids")
+            kwargs.pop("prompt", None)  # Remove prompt if it exists
+        else:
+            prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
+
         # Check for double BOS token, can happen if you use chat templates and forget that they insert a BOS token
-        if prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]:
+        if (
+            len(prompt_tokens) >= 2
+            and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
+        ):
             prompt_tokens = prompt_tokens[1:]
         if "max_tokens" in kwargs:
             kwargs["max_new_tokens"] = kwargs.pop("max_tokens")
diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py
new file mode 100644
index 00000000..b8a14046
--- /dev/null
+++ b/atroposlib/tests/test_managed_server.py
@@ -0,0 +1,498 @@
+"""Tests for ManagedServer tracking of sequences with tokens and logprobs."""
+
+import pytest
+
+from atroposlib.envs.server_handling.managed_server import ManagedServer
+from atroposlib.envs.server_handling.server_harness import ServerHarness
+
+
+class MockTokenizer:
+    """Mock tokenizer for testing."""
+
+    def __init__(self):
+        self.eos_token_id = 2
+        self.bos_token_id = 1
+
+    def encode(self, text, add_special_tokens=True):
+        """Simple character-based encoding for testing."""
+        tokens = [ord(c) for c in text]
+        if add_special_tokens:
+            tokens = [self.bos_token_id] + tokens
+        return tokens
+
+    def decode(self, tokens, skip_special_tokens=False):
+        """Simple character-based decoding for testing."""
+        if skip_special_tokens:
+            # Filter out special tokens
+            tokens = [
+                t for t in tokens if t not in [self.bos_token_id, self.eos_token_id]
+            ]
+        return "".join([chr(t) if t > 31 else "" for t in tokens])
+
+    def apply_chat_template(
+        self, messages, tokenize=False, add_generation_prompt=True
+    ):
+        """Simple chat template for testing."""
+        result = ""
+        for msg in messages:
+            result += f"<{msg['role']}>{msg['content']}"
+        if add_generation_prompt:
+            result += "<assistant>"
+        if tokenize:
+            return self.encode(result)
+        return result
+
+
+@pytest.fixture
+def mock_server():
+    """Create a mock server with a tokenizer."""
+    server = ServerHarness()
+    server.tokenizer = MockTokenizer()
+
+    # Add a minimal config for compatibility
+    class Config:
+        model_name = "test_model"
+
+    server.config = Config()
+    return server
+
+
+@pytest.mark.asyncio
+async def test_single_completion(mock_server):
+    """Test single completion tracking."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    prompt = "Hello"
+    prompt_tokens = mock_server.tokenizer.encode(prompt)
+    output_text = " World"
+    output_tokens = [ord(c) for c in output_text]  # Don't include BOS
+    output_logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]
+
+    # Set up the mock response
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt,
+        prompt_tokens=prompt_tokens,
+        output_tokens_list=[output_tokens],
+        output_logprobs_list=[output_logprobs],
+        finish_reasons=["stop"],
+    )
+
+    # Call the managed server
+    result = await managed.completion(prompt=prompt)
+
+    # Check the response
+    assert result.choices[0].text == " World"
+    assert result.choices[0].finish_reason == "stop"
+
+    # Check tracked state (default mode uses a nodes list)
+    state = managed.get_state()
+    assert len(state["nodes"]) == 1
+
+    # Get the sequence node
+    node = state["nodes"][0]
+    full_text = prompt + output_text
+
+    # Check structure
+    assert node.full_text == full_text
+    assert len(node.tokens) == len(prompt_tokens) + len(output_tokens)
+
+    # Check masking: prompt should be -100, completion should have actual tokens
+    prompt_len = len(prompt_tokens)
+    assert all(t == -100 for t in node.masked_tokens[:prompt_len])
+    assert node.masked_tokens[prompt_len:] == output_tokens
+
+    # Check logprobs: prompt should be 0.0, completion should have actual logprobs
+    assert all(lp == 0.0 for lp in node.logprobs[:prompt_len])
+    assert node.logprobs[prompt_len:] == output_logprobs
+
+
+@pytest.mark.asyncio
+async def test_chat_completion(mock_server):
+    """Test chat completion with message conversion."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    messages = [{"role": "user", "content": "Hello"}]
+    prompt = managed._convert_messages_to_prompt(messages)
+    prompt_tokens = mock_server.tokenizer.encode(prompt)
+    output_text = "Hi there!"
+    output_tokens = [ord(c) for c in output_text]
+    output_logprobs = [-0.1] * len(output_tokens)
+
+    # Set up the mock response
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt,
+        prompt_tokens=prompt_tokens,
+        output_tokens_list=[output_tokens],
+        output_logprobs_list=[output_logprobs],
+        finish_reasons=["stop"],
+    )
+
+    # Call the managed server
+    result = await managed.chat_completion(messages=messages)
+
+    # Check the response
+    assert result.choices[0].message.content == output_text
+    assert result.choices[0].message.role == "assistant"
+
+    # Check tracked state
+    state = managed.get_state()
+    assert len(state["nodes"]) == 1
+
+    # Verify tokens are properly tracked
+    node = state["nodes"][0]
+    prompt_len = len(prompt_tokens)
+    assert all(t == -100 for t in node.masked_tokens[:prompt_len])
+
+
+@pytest.mark.asyncio
+async def test_multi_turn_conversation(mock_server):
+    """Test a multi-turn conversation with parent node merging."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    # Turn 1
+    prompt_1 = "Hello"
+    prompt_tokens_1 = mock_server.tokenizer.encode(prompt_1)
+    output_1 = " World"
+    output_tokens_1 = [ord(c) for c in output_1]
+    output_logprobs_1 = [-0.1] * len(output_tokens_1)
+
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt_1,
+        prompt_tokens=prompt_tokens_1,
+        output_tokens_list=[output_tokens_1],
+        output_logprobs_list=[output_logprobs_1],
+        finish_reasons=["stop"],
+    )
+
+    await managed.completion(prompt=prompt_1)
+
+    # Turn 2: extends turn 1
+    prompt_2 = "Hello World"  # The parent's full_text
+    # In a real scenario, prompt_tokens would come from the parent node
+    full_tokens_from_turn_1 = prompt_tokens_1 + output_tokens_1
+    output_2 = "!"
+    output_tokens_2 = [ord(c) for c in output_2]
+    output_logprobs_2 = [-0.2]
+
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt_2,
+        prompt_tokens=full_tokens_from_turn_1,  # Use the parent's full tokens
+        output_tokens_list=[output_tokens_2],
+        output_logprobs_list=[output_logprobs_2],
+        finish_reasons=["stop"],
+    )
+
+    await managed.completion(prompt=prompt_2)
+
+    # Check state - turn 2 extended turn 1, so it should replace that node
+    state = managed.get_state()
+    assert len(state["nodes"]) == 1
+
+    # Check the extended node
+    node = state["nodes"][0]
+    assert node.full_text == "Hello World!"
+
+    # Check the extended node has the correct tokens:
+    # turn_1_full + output_2
+    assert len(node.tokens) == len(full_tokens_from_turn_1) + len(output_tokens_2)
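+
+
+# An illustrative extra test (a sketch, not part of the original suite),
+# exercising tree mode, which the tests above leave uncovered:
+@pytest.mark.asyncio
+async def test_tree_mode_keys_by_full_text(mock_server):
+    """Sketch: tree mode (track_tree=True) stores nodes keyed by full text."""
+    managed = ManagedServer(
+        mock_server, tokenizer=mock_server.tokenizer, track_tree=True
+    )
+
+    prompt = "Hello"
+    output_tokens = [ord(c) for c in " World"]
+
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt,
+        prompt_tokens=mock_server.tokenizer.encode(prompt),
+        output_tokens_list=[output_tokens],
+        output_logprobs_list=[[-0.1] * len(output_tokens)],
+        finish_reasons=["stop"],
+    )
+
+    await managed.completion(prompt=prompt)
+
+    # Tree mode exposes a dict keyed by each node's full_text
+    state = managed.get_state()
+    assert "Hello World" in state["sequences"]
+    assert state["tree"]["Hello World"].metadata["finish_reason"] == "stop"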
+
+
+@pytest.mark.asyncio
+async def test_branching_with_n(mock_server):
+    """Test branching when n > 1."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    prompt = "Hello"
+    prompt_tokens = mock_server.tokenizer.encode(prompt)
+
+    # Three different completions
+    output_texts = [" World", " There", " Friend"]
+    output_tokens_list = [[ord(c) for c in text] for text in output_texts]
+    output_logprobs_list = [[-0.1] * len(tokens) for tokens in output_tokens_list]
+
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt,
+        prompt_tokens=prompt_tokens,
+        output_tokens_list=output_tokens_list,
+        output_logprobs_list=output_logprobs_list,
+        finish_reasons=["stop", "stop", "stop"],
+    )
+
+    result = await managed.completion(prompt=prompt, n=3)
+
+    # Check we got 3 completions
+    assert len(result.choices) == 3
+
+    # Check the state has 3 nodes (one per branch)
+    state = managed.get_state()
+    assert len(state["nodes"]) == 3
+
+    # Verify each node has different text
+    full_texts = {node.full_text for node in state["nodes"]}
+    assert full_texts == {"Hello World", "Hello There", "Hello Friend"}
+
+
+@pytest.mark.asyncio
+async def test_bos_token_handling(mock_server):
+    """Test that the BOS token appears only at the start of the sequence."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    prompt = "Test"
+    # The tokenizer adds BOS
+    prompt_tokens = mock_server.tokenizer.encode(prompt)  # [1, 84, 101, 115, 116]
+    assert prompt_tokens[0] == mock_server.tokenizer.bos_token_id
+
+    output_text = "ing"
+    output_tokens = [ord(c) for c in output_text]  # Should NOT have BOS
+    output_logprobs = [-0.1] * len(output_tokens)
+
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt,
+        prompt_tokens=prompt_tokens,
+        output_tokens_list=[output_tokens],
+        output_logprobs_list=[output_logprobs],
+        finish_reasons=["stop"],
+    )
+
+    await managed.completion(prompt=prompt)
+
+    # Check the sequence
+    state = managed.get_state()
+    assert len(state["nodes"]) == 1
+    node = state["nodes"][0]
+
+    # Should have exactly one BOS at the start
+    assert node.tokens[0] == mock_server.tokenizer.bos_token_id
+    # And no BOS in the rest of the sequence
+    assert mock_server.tokenizer.bos_token_id not in node.tokens[1:]
+
+
+@pytest.mark.asyncio
+async def test_reset_clears_sequences(mock_server):
+    """Test that reset() clears all tracked sequences."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    prompt = "Test"
+    prompt_tokens = mock_server.tokenizer.encode(prompt)
+    output_tokens = [ord("!")]
+    output_logprobs = [-0.1]
+
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt,
+        prompt_tokens=prompt_tokens,
+        output_tokens_list=[output_tokens],
+        output_logprobs_list=[output_logprobs],
+        finish_reasons=["stop"],
+    )
+
+    await managed.completion(prompt=prompt)
+
+    # Verify nodes exist
+    state = managed.get_state()
+    assert len(state["nodes"]) > 0
+
+    # Reset
+    managed.reset()
+
+    # Verify nodes are cleared
+    state = managed.get_state()
+    assert len(state["nodes"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_tokenizer_initialization_from_server(mock_server):
+    """Test that the tokenizer is initialized from the server if available."""
+    managed = ManagedServer(mock_server)  # Don't pass a tokenizer
+
+    # Should have picked up the tokenizer from the server
+    assert managed.tokenizer is not None
+    assert managed.tokenizer == mock_server.tokenizer
+
+
+@pytest.mark.asyncio
+async def test_input_ids_extension(mock_server):
+    """Test that input_ids are computed correctly when extending sequences."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    # Turn 1
+    prompt_1 = "Hello"
+    prompt_tokens_1 = mock_server.tokenizer.encode(prompt_1)  # [1, 72, 101, 108, 108, 111]
+    output_1 = " World"
+    output_tokens_1 = [ord(c) for c in output_1]
+    output_logprobs_1 = [-0.1] * len(output_tokens_1)
+
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt_1,
+        prompt_tokens=prompt_tokens_1,
+        output_tokens_list=[output_tokens_1],
+        output_logprobs_list=[output_logprobs_1],
+        finish_reasons=["stop"],
+    )
+
+    await managed.completion(prompt=prompt_1)
+
+    # Turn 2: extends turn 1 with new text
+    prompt_2 = "Hello World!"  # Extends "Hello World" with "!"
+    # The input_ids should be: existing_node_tokens + tokenize("!")
+    node_1 = managed.current_nodes[0]
+    expected_input_ids = node_1.tokens + mock_server.tokenizer.encode(
+        "!", add_special_tokens=False
+    )
+
+    output_2 = " Yay"
+    output_tokens_2 = [ord(c) for c in output_2]
+    output_logprobs_2 = [-0.2] * len(output_tokens_2)
+
+    # The server should receive the computed input_ids
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt_2,
+        prompt_tokens=expected_input_ids,  # Should match what ManagedServer computes!
+        output_tokens_list=[output_tokens_2],
+        output_logprobs_list=[output_logprobs_2],
+        finish_reasons=["stop"],
+    )
+
+    result = await managed.completion(prompt=prompt_2)
+
+    # Verify the response
+    assert result.choices[0].text == " Yay"
+
+    # Verify we have 1 node (turn 2 replaced turn 1, since it extended it)
+    state = managed.get_state()
+    assert len(state["nodes"]) == 1
+
+    # Verify the node has the correct combined tokens
+    node = state["nodes"][0]
+    assert node.tokens == expected_input_ids + output_tokens_2
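+
+
+# An illustrative extra test (a sketch, not part of the original suite),
+# covering the dict-style finish_reason unwrapping in ManagedServer:
+@pytest.mark.asyncio
+async def test_finish_reason_dict_is_unwrapped(mock_server):
+    """Sketch: dict finish reasons are reduced to their 'type' string."""
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    prompt = "Hi"
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt,
+        prompt_tokens=mock_server.tokenizer.encode(prompt),
+        output_tokens_list=[[ord("!")]],
+        output_logprobs_list=[[-0.3]],
+        finish_reasons=[{"type": "length"}],
+    )
+
+    result = await managed.completion(prompt=prompt)
+    assert result.choices[0].finish_reason == "length"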
+
+
+@pytest.mark.asyncio
+async def test_multi_turn_chat_with_branching(mock_server):
+    """
+    Test a complex multi-turn scenario:
+    - Turn 1: n=8 group completion -> 8 nodes (1 assistant turn each)
+    - Turn 2: 8 individual calls extending each -> extends those 8 nodes
+      (2 assistant turns each)
+    - Turn 3: add a system prompt, 8 calls -> 8 NEW nodes (different context,
+      3 assistant turns each)
+    Final: 16 nodes total.
+    """
+    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
+
+    # Turn 1: single completion with n=8
+    messages_1 = [{"role": "user", "content": "Hello"}]
+    prompt_1 = managed._convert_messages_to_prompt(messages_1)
+    prompt_tokens_1 = mock_server.tokenizer.encode(prompt_1)
+
+    # Create 8 different responses
+    responses_1 = [f"Response{i}" for i in range(8)]
+    output_tokens_1 = [[ord(c) for c in resp] for resp in responses_1]
+    output_logprobs_1 = [[-0.1] * len(tokens) for tokens in output_tokens_1]
+
+    mock_server.set_tokens_and_logprobs_response(
+        prompt=prompt_1,
+        prompt_tokens=prompt_tokens_1,
+        output_tokens_list=output_tokens_1,
+        output_logprobs_list=output_logprobs_1,
+        finish_reasons=["stop"] * 8,
+    )
+
+    await managed.chat_completion(messages=messages_1, n=8)
+
+    # After turn 1: should have 8 nodes
+    state = managed.get_state()
+    assert len(state["nodes"]) == 8
+
+    # Turn 2: for each of the 8 nodes, extend with another user+assistant turn
+    for i in range(8):
+        messages_2 = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": responses_1[i]},
+            {"role": "user", "content": "Continue"},
+        ]
+        prompt_2 = managed._convert_messages_to_prompt(messages_2)
+
+        # This prompt extends turn 1's output, so input_ids should reuse existing tokens
+        extending_node = state["nodes"][i]
+        # The new part is just the user turn
+        new_suffix = prompt_2[len(extending_node.full_text):]
+        expected_input_ids = extending_node.tokens + mock_server.tokenizer.encode(
+            new_suffix, add_special_tokens=False
+        )
+
+        response_2 = f"Continued{i}"
+        output_tokens_2 = [ord(c) for c in response_2]
+        output_logprobs_2 = [-0.2] * len(output_tokens_2)
+
+        mock_server.set_tokens_and_logprobs_response(
+            prompt=prompt_2,
+            prompt_tokens=expected_input_ids,
+            output_tokens_list=[output_tokens_2],
+            output_logprobs_list=[output_logprobs_2],
+            finish_reasons=["stop"],
+        )
+
+        await managed.chat_completion(messages=messages_2, n=1)
+
+    # After turn 2: still 8 nodes (they were extended/replaced, not added)
+    state = managed.get_state()
+    assert len(state["nodes"]) == 8
+
+    # Verify turn 2 nodes have 2 assistant turns each
+    for i in range(8):
+        node = state["nodes"][i]
+        assert f"Response{i}" in node.full_text
+        assert f"Continued{i}" in node.full_text
+
+    # Turn 3: add a system prompt at the start - this creates a DIFFERENT context.
+    # These won't extend, because the prefix is different (system prompt added).
+    for i in range(8):
+        messages_3 = [
+            {"role": "system", "content": "Helpful"},
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": responses_1[i]},
+            {"role": "user", "content": "Continue"},
+            {"role": "assistant", "content": f"Continued{i}"},
+            {"role": "user", "content": "More"},
+        ]
+        prompt_3 = managed._convert_messages_to_prompt(messages_3)
+        prompt_tokens_3 = mock_server.tokenizer.encode(prompt_3)
+
+        response_3 = f"More{i}"
+        output_tokens_3 = [ord(c) for c in response_3]
+        output_logprobs_3 = [-0.3] * len(output_tokens_3)
+
+        mock_server.set_tokens_and_logprobs_response(
+            prompt=prompt_3,
+            prompt_tokens=prompt_tokens_3,
+            output_tokens_list=[output_tokens_3],
+            output_logprobs_list=[output_logprobs_3],
+            finish_reasons=["stop"],
+        )
+
+        await managed.chat_completion(messages=messages_3, n=1)
+
+    # After turn 3: 8 (turn 2 nodes) + 8 (turn 3 new context) = 16 nodes total!
+    state = managed.get_state()
+    assert len(state["nodes"]) == 16
+
+    # Verify the structure:
+    # First 8 nodes: 2 assistant turns (no system prompt)
+    for i in range(8):
+        node = state["nodes"][i]
+        assert "Helpful" not in node.full_text  # No system prompt
+        assert f"Response{i}" in node.full_text
+        assert f"Continued{i}" in node.full_text
+        assert f"More{i}" not in node.full_text  # Not the third turn
+
+    # Last 8 nodes: 3 assistant turns (with system prompt)
+    for i in range(8, 16):
+        node = state["nodes"][i]
+        actual_i = i - 8
+        assert "Helpful" in node.full_text  # Has the system prompt
+        assert f"Response{actual_i}" in node.full_text
+        assert f"Continued{actual_i}" in node.full_text
+        assert f"More{actual_i}" in node.full_text  # Has the third turn
+
+
+if __name__ == "__main__":
+    # Run the tests
+    pytest.main([__file__, "-v"])
diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py
index 7596e6cf..93b7fda1 100644
--- a/environments/math_server_zero.py
+++ b/environments/math_server_zero.py
@@ -315,8 +315,10 @@ class MathEnv(BaseEnv):
             curr_length = int(curr_length * (self.curr_step / self.config.total_steps))
             curr_length += self.config.start_tok_length
             thinking_len = min(thinking_len, curr_length)
-        prompt_tokens, out_tokens, out_logprobs, finish_reasons = (
-            await self.server.tokens_and_logprobs_completion(
+
+        # Use the managed server for automatic token/logprob tracking
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            completion = await managed.completion(
                 prompt=user_prompt,
                 n=self.config.group_size,
                 max_tokens=thinking_len,
@@ -324,22 +326,23 @@ class MathEnv(BaseEnv):
                 top_p=1.0,
                 stop=stop_list,
             )
-        )
-        # print(completions, flush=True)
+
+            # Get the tracked sequences with aligned tokens and logprobs
+            # (read inside the context, before it is cleared on exit)
+            state = managed.get_state()
+            nodes = state["nodes"]
+
+        # Extract data from the SequenceNodes for scoring
         to_score = list()
         to_backlog = list()
-        for i, (tokens, logprobs, finish_reason) in enumerate(
-            zip(out_tokens, out_logprobs, finish_reasons)
-        ):
-            message = self.tokenizer.decode(prompt_tokens + tokens)
+        for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
             to_score.append(
                 (
-                    message,
-                    item[1],
-                    finish_reason,
-                    prompt_tokens,
-                    tokens,
-                    logprobs,
+                    node.full_text,  # Complete text (prompt + completion)
+                    item[1],  # Answer
+                    choice.finish_reason,  # finish_reason (already a clean string)
+                    node.tokens,  # All tokens (prompt + completion)
+                    node.masked_tokens,  # Masked tokens (already formatted correctly)
+                    node.logprobs,  # Logprobs (already formatted correctly)
                 )
             )
         to_postprocess = await self.score(to_score)
@@ -376,11 +379,13 @@ class MathEnv(BaseEnv):
         for item in rollout_group_data:
             scores["overrides"].append(dict())
             resp = item[0]
-            finish_reason = item[2]
-            user_prompt_tokens = item[3]
-            out_toks = item[4]
-            out_logps = item[5]
-            if item[2]["type"] == "length":
+            finish_reason = item[2]  # Now a clean string like "stop" or "length"
+            # The ManagedServer already provides properly formatted data
+            tokens = item[3]  # Full token sequence
+            masks = item[4]  # Masked tokens (already formatted)
+            inf_logp = item[5]  # Logprobs (already formatted)
+
+            if finish_reason == "length":
                 reward = False
                 if self.config.mask_too_long_completions:
                     scores["overrides"][-1]["set_advantage_to_zero"] = True
@@ -389,11 +394,7 @@ class MathEnv(BaseEnv):
                 reward = await task
                 if reward is None:
                     return None
-            tokens = user_prompt_tokens + out_toks
-            masks = [-100 for _ in range(len(user_prompt_tokens))]
-            masks = masks + out_toks
-            inf_logp = [0 for _ in range(len(user_prompt_tokens))]
-            inf_logp = inf_logp + out_logps
+
             assert len(inf_logp) == len(
                 masks
             ), f"{len(inf_logp)}, {len(masks)} mismatch"
@@ -405,7 +406,7 @@ class MathEnv(BaseEnv):
             # remove obviously bad examples
             if len([1 for i in masks if i != -100]) < 10:
                 continue
-            if (item[2] == "length") and (not self.config.mask_too_long_completions):
+            if (finish_reason == "length") and (not self.config.mask_too_long_completions):
                 scores["overrides"][-1]["set_advantage_to_zero"] = True
             scores["tokens"].append(tokens)
             scores["masks"].append(masks)