mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add managed server to make grabbing logprobs easier w/ tokenized items
This commit is contained in:
parent
312f8859e3
commit
7bf4cfbf80
6 changed files with 1138 additions and 29 deletions
511
atroposlib/envs/server_handling/managed_server.py
Normal file
511
atroposlib/envs/server_handling/managed_server.py
Normal file
|
|
@ -0,0 +1,511 @@
|
|||
"""
|
||||
Managed server wrapper that tracks text sequences with aligned tokens and logprobs.
|
||||
|
||||
This wrapper maintains a tree structure of sequences, where:
|
||||
- Each node represents a complete text sequence (prompt + completion)
|
||||
- Tokens and logprobs are tracked with proper masking for training
|
||||
- Branching occurs organically from different contexts and n > 1 completions
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.completion import Completion, CompletionChoice
|
||||
from pydantic import BaseModel
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServer
|
||||
|
||||
|
||||
class SequenceNode(BaseModel):
    """
    One complete text sequence (prompt + completion) in the tracking structure.

    Attributes:
        full_text: The prompt and completion concatenated into one string.
        tokens: Token IDs for the entire sequence.
        masked_tokens: Same length as ``tokens``; prompt positions hold -100,
            completion positions hold the real token IDs (training mask).
        logprobs: Same length as ``tokens``; prompt positions hold 0.0,
            completion positions hold the sampled logprobs.
        metadata: Optional extra information (e.g. finish_reason).
    """

    full_text: str
    tokens: List[int]
    masked_tokens: List[int]
    logprobs: List[float]
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ManagedServer:
|
||||
"""
|
||||
Wrapper around APIServer that tracks sequences with aligned tokens and logprobs.
|
||||
|
||||
Maintains a tree structure keyed by input text, where each completion creates
|
||||
new branches. Provides proper masking for training (prompt tokens masked with -100,
|
||||
logprobs set to 0.0).
|
||||
|
||||
Uses the clean _tokens_and_logprobs_completion_wrapper interface internally.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: APIServer,
|
||||
tokenizer: Optional[Any] = None,
|
||||
track_tree: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the managed server.
|
||||
|
||||
Args:
|
||||
server: The underlying APIServer instance to wrap
|
||||
tokenizer: Optional tokenizer for encoding/decoding. If not provided,
|
||||
will attempt to extract from server or create from model name.
|
||||
track_tree: If True, maintains a tree structure with parent-child links
|
||||
(for multi-turn RL with per-step advantages). If False (default),
|
||||
maintains a simple list of current nodes that updates in-place.
|
||||
"""
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.track_tree = track_tree
|
||||
|
||||
# Initialize storage based on mode
|
||||
if track_tree:
|
||||
self.sequences: Dict[str, SequenceNode] = {} # Tree mode: dict lookup
|
||||
else:
|
||||
self.current_nodes: List[SequenceNode] = [] # Default mode: simple list
|
||||
|
||||
# Try to get tokenizer from server if not provided
|
||||
if self.tokenizer is None:
|
||||
self._initialize_tokenizer()
|
||||
|
||||
def _initialize_tokenizer(self):
|
||||
"""Initialize tokenizer from server or model name."""
|
||||
# Check if the wrapped server has a tokenizer
|
||||
if hasattr(self.server, "tokenizer"):
|
||||
self.tokenizer = self.server.tokenizer
|
||||
else:
|
||||
# Try to create from model name
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_name = self.server.config.model_name
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
f"Could not initialize tokenizer: {e}. "
|
||||
"Sequence tracking will be limited without tokenizer."
|
||||
)
|
||||
self.tokenizer = None
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""
|
||||
Convert chat messages to prompt text using tokenizer's chat template.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
|
||||
Returns:
|
||||
Formatted prompt string
|
||||
"""
|
||||
if self.tokenizer is None:
|
||||
# Fallback: simple concatenation
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
# Only add generation prompt if last message is not from assistant
|
||||
add_generation_prompt = (
|
||||
len(messages) == 0 or messages[-1].get("role") != "assistant"
|
||||
)
|
||||
|
||||
# Use the tokenizer's chat template
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
else:
|
||||
# Fallback for tokenizers without chat template
|
||||
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
||||
|
||||
def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
|
||||
"""
|
||||
Find a node that this input extends (default mode).
|
||||
|
||||
Args:
|
||||
input_text: The input text to check
|
||||
|
||||
Returns:
|
||||
The node that input_text extends, or None if no match
|
||||
"""
|
||||
if not self.current_nodes:
|
||||
return None
|
||||
|
||||
# Check if any current node's full_text is a prefix of the input
|
||||
# This means the input is extending that node
|
||||
for node in self.current_nodes:
|
||||
if input_text.startswith(node.full_text):
|
||||
return node
|
||||
return None
|
||||
|
||||
def _compute_input_ids(
|
||||
self, input_text: str, extending_node: Optional[SequenceNode]
|
||||
) -> List[int]:
|
||||
"""
|
||||
Compute input_ids for the prompt, using existing tokens if extending.
|
||||
|
||||
Args:
|
||||
input_text: The full input prompt text
|
||||
extending_node: Node being extended, if any
|
||||
|
||||
Returns:
|
||||
List of token IDs to use as input_ids
|
||||
"""
|
||||
if extending_node is not None:
|
||||
# Extending an existing sequence - use its tokens + tokenize the new part
|
||||
existing_text = extending_node.full_text
|
||||
new_text_suffix = input_text[len(existing_text) :]
|
||||
|
||||
# Tokenize only the new suffix (without BOS since we're continuing)
|
||||
if new_text_suffix:
|
||||
new_tokens = self.tokenizer.encode(
|
||||
new_text_suffix, add_special_tokens=False
|
||||
)
|
||||
return extending_node.tokens + new_tokens
|
||||
else:
|
||||
# No new text, just use existing tokens
|
||||
return extending_node.tokens.copy()
|
||||
else:
|
||||
# New sequence - tokenize the whole thing
|
||||
return self.tokenizer.encode(input_text, add_special_tokens=True)
|
||||
|
||||
def _find_parent_node(self, input_text: str) -> Optional[SequenceNode]:
|
||||
"""
|
||||
Find a parent node whose full_text matches the input_text (tree mode).
|
||||
|
||||
Args:
|
||||
input_text: The input text to search for
|
||||
|
||||
Returns:
|
||||
Parent SequenceNode if found, None otherwise
|
||||
"""
|
||||
return self.sequences.get(input_text, None)
|
||||
|
||||
def _create_sequence_node(
|
||||
self,
|
||||
input_text: str,
|
||||
parent_node: Optional[SequenceNode],
|
||||
prompt_tokens: List[int],
|
||||
output_tokens: List[int],
|
||||
output_logprobs: List[float],
|
||||
completion_text: str,
|
||||
finish_reason: str = "stop",
|
||||
) -> SequenceNode:
|
||||
"""
|
||||
Create a sequence node with proper masking.
|
||||
|
||||
Args:
|
||||
input_text: The input prompt text
|
||||
parent_node: Parent node to extend from (if available)
|
||||
prompt_tokens: Token IDs for the prompt
|
||||
output_tokens: Token IDs for the output/completion
|
||||
output_logprobs: Logprobs for output tokens
|
||||
completion_text: The completion text
|
||||
finish_reason: Finish reason from server
|
||||
|
||||
Returns:
|
||||
SequenceNode with properly masked tokens and logprobs
|
||||
"""
|
||||
# Combine text
|
||||
full_text = input_text + completion_text
|
||||
|
||||
# If we have a parent node, we should use its tokens as the prompt base
|
||||
if parent_node is not None:
|
||||
# Use parent's full tokens as the prompt
|
||||
prompt_tokens = parent_node.tokens.copy()
|
||||
|
||||
# Combine tokens
|
||||
full_tokens = prompt_tokens + output_tokens
|
||||
prompt_len = len(prompt_tokens)
|
||||
|
||||
# Create masked tokens: -100 for prompt, actual IDs for completion
|
||||
masked_tokens = [-100] * prompt_len + output_tokens
|
||||
|
||||
# Create masked logprobs: 0.0 for prompt, actual for completion
|
||||
# Pad logprobs to match token length if needed
|
||||
if len(output_logprobs) < len(output_tokens):
|
||||
output_logprobs = output_logprobs + [0.0] * (
|
||||
len(output_tokens) - len(output_logprobs)
|
||||
)
|
||||
elif len(output_logprobs) > len(output_tokens):
|
||||
output_logprobs = output_logprobs[: len(output_tokens)]
|
||||
|
||||
full_logprobs = [0.0] * prompt_len + output_logprobs
|
||||
|
||||
return SequenceNode(
|
||||
full_text=full_text,
|
||||
tokens=full_tokens,
|
||||
masked_tokens=masked_tokens,
|
||||
logprobs=full_logprobs,
|
||||
metadata={"finish_reason": finish_reason},
|
||||
)
|
||||
|
||||
    async def chat_completion(self, **kwargs) -> ChatCompletion:
        """
        Intercept chat completion call and track sequences.

        Internally converts to prompt, calls tokens_and_logprobs_completion,
        tracks the sequence, and reconstructs a ChatCompletion response.
        Each of the ``n`` sampled completions becomes its own tracked node.

        Args:
            **kwargs: Standard chat completion kwargs (messages, n, etc.)

        Returns:
            ChatCompletion response
        """
        # Get input text
        messages = kwargs.get("messages", [])
        prompt = self._convert_messages_to_prompt(messages)

        # Handle parent node and extending logic based on mode
        if self.track_tree:
            # Tree mode: look up parent in dict
            parent_node = self._find_parent_node(prompt)
            extending_node = None
        else:
            # Default mode: check if extending existing sequence
            extending_node = self._find_extending_node(prompt)
            parent_node = None  # Don't use parent merging in default mode

        # Convert to completion format (kwargs is copied so the caller's
        # dict is never mutated)
        completion_kwargs = kwargs.copy()
        completion_kwargs["prompt"] = prompt
        completion_kwargs.pop("messages", None)

        # Set model name if not provided
        if "model" not in completion_kwargs:
            completion_kwargs["model"] = self.server.config.model_name

        # Compute input_ids (using existing tokens if extending)
        if not self.track_tree and self.tokenizer is not None:
            input_ids = self._compute_input_ids(prompt, extending_node)
            completion_kwargs["input_ids"] = input_ids

        # Call the tokens and logprobs wrapper directly
        # NOTE(review): this relies on the wrapped server's private
        # _tokens_and_logprobs_completion_wrapper interface.
        (
            prompt_tokens,
            output_tokens_list,
            output_logprobs_list,
            finish_reasons,
        ) = await self.server._tokens_and_logprobs_completion_wrapper(
            **completion_kwargs
        )

        # Track each completion and build choices
        n = len(output_tokens_list)
        choices = []

        for i in range(n):
            output_tokens = output_tokens_list[i]
            output_logprobs = output_logprobs_list[i]
            # Default to "stop" when the server returned fewer finish reasons
            # than completions
            finish_reason_raw = finish_reasons[i] if i < len(finish_reasons) else "stop"

            # Extract finish_reason string from dict if needed
            if isinstance(finish_reason_raw, dict):
                finish_reason = finish_reason_raw.get("type", "stop")
            else:
                finish_reason = finish_reason_raw

            # Decode completion text
            if self.tokenizer is not None:
                completion_text = self.tokenizer.decode(
                    output_tokens, skip_special_tokens=True
                )
            else:
                # Crude fallback: treat token IDs as character ordinals,
                # dropping control codes -- only approximate without a tokenizer
                completion_text = "".join([chr(t) for t in output_tokens if t > 31])

            # Create and store sequence node
            node = self._create_sequence_node(
                input_text=prompt,
                parent_node=parent_node,
                prompt_tokens=prompt_tokens,
                output_tokens=output_tokens,
                output_logprobs=output_logprobs,
                completion_text=completion_text,
                finish_reason=finish_reason,
            )

            # Store node based on mode
            if self.track_tree:
                # Tree mode: key by full text in dict
                self.sequences[node.full_text] = node
            else:
                # Default mode: replace if extending, append if new context
                if extending_node is not None:
                    # Replace the extending node with the new extended version
                    try:
                        idx = self.current_nodes.index(extending_node)
                        self.current_nodes[idx] = node
                    except ValueError:
                        # Extending node not in list anymore, just append
                        self.current_nodes.append(node)
                else:
                    # New context - append to list
                    self.current_nodes.append(node)

            # Build choice
            choice = Choice(
                finish_reason=finish_reason,
                index=i,
                message=ChatCompletionMessage(content=completion_text, role="assistant"),
            )
            choices.append(choice)

        # Construct ChatCompletion response (usage/logprob fields are omitted;
        # callers needing tokens/logprobs should use get_state())
        return ChatCompletion(
            id=str(uuid.uuid4()),
            created=int(time.time()),
            model=self.server.config.model_name,
            object="chat.completion",
            choices=choices,
        )
|
||||
|
||||
    async def completion(self, **kwargs) -> Completion:
        """
        Intercept completion call and track sequences.

        Uses tokens_and_logprobs_completion internally, tracks the sequence,
        and reconstructs a Completion response. Each of the ``n`` sampled
        completions becomes its own tracked node.

        Args:
            **kwargs: Standard completion kwargs (prompt, n, etc.)

        Returns:
            Completion response
        """
        # Get input text
        prompt = kwargs.get("prompt", "")

        # Handle parent node and extending logic based on mode
        if self.track_tree:
            # Tree mode: look up parent in dict
            parent_node = self._find_parent_node(prompt)
            extending_node = None
        else:
            # Default mode: check if extending existing sequence
            extending_node = self._find_extending_node(prompt)
            parent_node = None  # Don't use parent merging in default mode

        # Set model name if not provided
        if "model" not in kwargs:
            kwargs["model"] = self.server.config.model_name

        # Compute input_ids (using existing tokens if extending)
        if not self.track_tree and self.tokenizer is not None:
            input_ids = self._compute_input_ids(prompt, extending_node)
            kwargs["input_ids"] = input_ids

        # Call the tokens and logprobs wrapper directly
        # NOTE(review): this relies on the wrapped server's private
        # _tokens_and_logprobs_completion_wrapper interface.
        (
            prompt_tokens,
            output_tokens_list,
            output_logprobs_list,
            finish_reasons,
        ) = await self.server._tokens_and_logprobs_completion_wrapper(**kwargs)

        # Track each completion and build choices
        n = len(output_tokens_list)
        choices = []

        for i in range(n):
            output_tokens = output_tokens_list[i]
            output_logprobs = output_logprobs_list[i]
            # Default to "stop" when the server returned fewer finish reasons
            # than completions
            finish_reason_raw = finish_reasons[i] if i < len(finish_reasons) else "stop"

            # Extract finish_reason string from dict if needed
            if isinstance(finish_reason_raw, dict):
                finish_reason = finish_reason_raw.get("type", "stop")
            else:
                finish_reason = finish_reason_raw

            # Decode completion text
            if self.tokenizer is not None:
                completion_text = self.tokenizer.decode(
                    output_tokens, skip_special_tokens=True
                )
            else:
                # Crude fallback: treat token IDs as character ordinals,
                # dropping control codes -- only approximate without a tokenizer
                completion_text = "".join([chr(t) for t in output_tokens if t > 31])

            # Create and store sequence node
            node = self._create_sequence_node(
                input_text=prompt,
                parent_node=parent_node,
                prompt_tokens=prompt_tokens,
                output_tokens=output_tokens,
                output_logprobs=output_logprobs,
                completion_text=completion_text,
                finish_reason=finish_reason,
            )

            # Store node based on mode
            if self.track_tree:
                # Tree mode: key by full text in dict
                self.sequences[node.full_text] = node
            else:
                # Default mode: replace if extending, append if new context
                if extending_node is not None:
                    # Replace the extending node with the new extended version
                    try:
                        idx = self.current_nodes.index(extending_node)
                        self.current_nodes[idx] = node
                    except ValueError:
                        # Extending node not in list anymore, just append
                        self.current_nodes.append(node)
                else:
                    # New context - append to list
                    self.current_nodes.append(node)

            # Build choice
            choice = CompletionChoice(
                finish_reason=finish_reason, index=i, text=completion_text
            )
            choices.append(choice)

        # Construct Completion response (usage/logprob fields are omitted;
        # callers needing tokens/logprobs should use get_state())
        return Completion(
            id=str(uuid.uuid4()),
            created=int(time.time()),
            model=self.server.config.model_name,
            object="text_completion",
            choices=choices,
        )
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current state of tracked sequences.
|
||||
|
||||
Returns:
|
||||
For default mode (track_tree=False):
|
||||
Dictionary with 'nodes': List[SequenceNode] - ready for training
|
||||
For tree mode (track_tree=True):
|
||||
Dictionary with 'sequences': Dict[str, SequenceNode] and 'tree' alias
|
||||
"""
|
||||
if self.track_tree:
|
||||
return {
|
||||
"sequences": self.sequences.copy(),
|
||||
"tree": self.sequences.copy(), # Alias for compatibility
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"nodes": self.current_nodes.copy(), # Return a copy so reset() doesn't affect it
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Clear all tracked sequences."""
|
||||
if self.track_tree:
|
||||
self.sequences.clear()
|
||||
else:
|
||||
self.current_nodes.clear()
|
||||
|
|
@ -87,6 +87,7 @@ def create_completion(
|
|||
class ServerHarness:
|
||||
def __init__(self):
|
||||
self.response_map = dict()
|
||||
self.tokens_and_logprobs_map = dict() # Map for tokens/logprobs responses
|
||||
self.sem = asyncio.Semaphore(1)
|
||||
self.eval_sem = asyncio.Semaphore(1)
|
||||
pass
|
||||
|
|
@ -110,6 +111,31 @@ class ServerHarness:
|
|||
    def set_desired_completion(self, input_message: str, completion: Completion):
        """Register the canned Completion to return for *input_message*."""
        self.response_map[input_message] = completion
|
||||
|
||||
def set_tokens_and_logprobs_response(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_tokens: list,
|
||||
output_tokens_list: list,
|
||||
output_logprobs_list: list,
|
||||
finish_reasons: list,
|
||||
):
|
||||
"""
|
||||
Set expected response for _tokens_and_logprobs_completion_wrapper.
|
||||
|
||||
Args:
|
||||
prompt: The prompt string (key)
|
||||
prompt_tokens: List of prompt token IDs
|
||||
output_tokens_list: List of lists of output token IDs (one per completion)
|
||||
output_logprobs_list: List of lists of output logprobs (one per completion)
|
||||
finish_reasons: List of finish reasons (one per completion)
|
||||
"""
|
||||
self.tokens_and_logprobs_map[prompt] = (
|
||||
prompt_tokens,
|
||||
output_tokens_list,
|
||||
output_logprobs_list,
|
||||
finish_reasons,
|
||||
)
|
||||
|
||||
async def chat_completion(self, *args, **kwargs) -> ChatCompletion:
|
||||
messages = kwargs.get("messages")
|
||||
dictkey = self.conv_to_dictkey(messages)
|
||||
|
|
@ -125,6 +151,21 @@ class ServerHarness:
|
|||
except KeyError as e:
|
||||
raise KeyError(f"KeyError: {e} for key:\n{prompt}")
|
||||
|
||||
async def _tokens_and_logprobs_completion_wrapper(
|
||||
self, **kwargs
|
||||
) -> tuple[list, list, list, list]:
|
||||
"""
|
||||
Mock implementation of tokens and logprobs completion wrapper.
|
||||
|
||||
Returns:
|
||||
Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons)
|
||||
"""
|
||||
prompt = kwargs.get("prompt")
|
||||
try:
|
||||
return self.tokens_and_logprobs_map.get(prompt)
|
||||
except KeyError as e:
|
||||
raise KeyError(f"KeyError: {e} for prompt:\n{prompt}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from openai.types.chat.chat_completion import ChatCompletion
|
|||
from openai.types.completion import Completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
|
|
@ -308,3 +309,50 @@ class ServerManager:
|
|||
yield self.servers[most_available_server]
|
||||
finally:
|
||||
pass
|
||||
|
||||
    @asynccontextmanager
    async def managed_server(
        self, tokenizer=None
    ) -> AsyncGenerator[ManagedServer, None]:
        """
        Context manager that provides a ManagedServer instance.

        The ManagedServer wraps the most available server and tracks text sequences
        with aligned tokens and logprobs. State is automatically cleared on exit.

        Args:
            tokenizer: Optional tokenizer to use. If not provided, will attempt to
                extract from server or create from model name.

        Yields:
            ManagedServer instance wrapping the selected server

        Example:
            async with server_manager.managed_server() as managed:
                response = await managed.chat_completion(
                    messages=[{"role": "user", "content": "Hello"}],
                    n=2
                )
                state = managed.get_state()
                # Process state...
            # State is automatically cleared when exiting context
        """
        # Pick the healthy server with the most free semaphore slots.
        # If no server is healthy, index 0 is used as a fallback.
        most_available_server = 0
        most_available_server_num_slots = -1
        for i, server in enumerate(self.servers):
            if not server.server_healthy:
                continue
            # NOTE(review): reads the private Semaphore._value counter to rank
            # availability -- works on CPython but is not a public API.
            if server.sem._value > most_available_server_num_slots:
                most_available_server = i
                most_available_server_num_slots = server.sem._value

        # Create ManagedServer wrapping the selected server
        managed = ManagedServer(
            server=self.servers[most_available_server], tokenizer=tokenizer
        )

        try:
            yield managed
        finally:
            # Clean up: reset tracked sequences
            managed.reset()
|
||||
|
|
|
|||
|
|
@ -148,14 +148,24 @@ class SGLangServer(APIServer):
|
|||
kwargs.get("model", None) is not None
|
||||
), "Model is required for completion!"
|
||||
assert (
|
||||
kwargs.get("prompt", None) is not None
|
||||
), "Prompt is required for completion!"
|
||||
kwargs.get("prompt", None) is not None or kwargs.get("input_ids", None) is not None
|
||||
), "Prompt or input_ids is required for completion!"
|
||||
|
||||
# Get n parameter for number of completions
|
||||
n = kwargs.get("n", 1)
|
||||
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
|
||||
|
||||
# Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
|
||||
if "input_ids" in kwargs:
|
||||
prompt_tokens = kwargs.pop("input_ids")
|
||||
kwargs.pop("prompt", None) # Remove prompt if it exists
|
||||
else:
|
||||
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
|
||||
|
||||
# Check for double BOS token, can happen if you use chat templates and forget that they insert a BOS token
|
||||
if prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]:
|
||||
if (
|
||||
len(prompt_tokens) >= 2
|
||||
and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]
|
||||
):
|
||||
prompt_tokens = prompt_tokens[1:]
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_new_tokens"] = kwargs.pop("max_tokens")
|
||||
|
|
|
|||
498
atroposlib/tests/test_managed_server.py
Normal file
498
atroposlib/tests/test_managed_server.py
Normal file
|
|
@ -0,0 +1,498 @@
|
|||
"""Tests for ManagedServer tracking of sequences with tokens and logprobs."""
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
from atroposlib.envs.server_handling.server_harness import ServerHarness
|
||||
|
||||
|
||||
class MockTokenizer:
    """Character-level stand-in tokenizer: one token per character (its ordinal)."""

    def __init__(self):
        # Fixed special-token IDs used throughout the tests.
        self.eos_token_id = 2
        self.bos_token_id = 1

    def encode(self, text, add_special_tokens=True):
        """Encode each character to its ordinal, optionally prepending BOS."""
        body = [ord(ch) for ch in text]
        return [self.bos_token_id] + body if add_special_tokens else body

    def decode(self, tokens, skip_special_tokens=False):
        """Map ordinals back to characters; control codes (<= 31) decode to ''."""
        specials = (self.bos_token_id, self.eos_token_id)
        pieces = []
        for tok in tokens:
            if skip_special_tokens and tok in specials:
                continue
            pieces.append(chr(tok) if tok > 31 else "")
        return "".join(pieces)

    def apply_chat_template(
        self, messages, tokenize=False, add_generation_prompt=True
    ):
        """Render messages as <role>content</role> runs, mimicking a chat template."""
        parts = [f"<{msg['role']}>{msg['content']}</{msg['role']}>" for msg in messages]
        if add_generation_prompt:
            parts.append("<assistant>")
        rendered = "".join(parts)
        return self.encode(rendered) if tokenize else rendered
|
||||
|
||||
|
||||
@pytest.fixture
def mock_server():
    """Create a mock server (ServerHarness) with a character-level tokenizer."""
    server = ServerHarness()
    server.tokenizer = MockTokenizer()

    # Add config for compatibility: ManagedServer reads
    # server.config.model_name when building responses.
    class Config:
        model_name = "test_model"

    server.config = Config()
    return server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_single_completion(mock_server):
    """Test single completion tracking: response fields, token masking, logprobs."""
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    prompt = "Hello"
    prompt_tokens = mock_server.tokenizer.encode(prompt)
    output_text = " World"
    output_tokens = [ord(c) for c in output_text]  # Don't include BOS
    output_logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]

    # Set up mock response
    mock_server.set_tokens_and_logprobs_response(
        prompt=prompt,
        prompt_tokens=prompt_tokens,
        output_tokens_list=[output_tokens],
        output_logprobs_list=[output_logprobs],
        finish_reasons=["stop"],
    )

    # Call managed server
    result = await managed.completion(prompt=prompt)

    # Check response
    assert result.choices[0].text == " World"
    assert result.choices[0].finish_reason == "stop"

    # Check tracked state (default mode uses nodes list)
    state = managed.get_state()
    assert len(state["nodes"]) == 1

    # Get the sequence node
    node = state["nodes"][0]
    full_text = prompt + output_text

    # Check structure
    assert node.full_text == full_text
    assert len(node.tokens) == len(prompt_tokens) + len(output_tokens)

    # Check masking: prompt should be -100, completion should have actual tokens
    prompt_len = len(prompt_tokens)
    assert all(t == -100 for t in node.masked_tokens[:prompt_len])
    assert node.masked_tokens[prompt_len:] == output_tokens

    # Check logprobs: prompt should be 0.0, completion should have actual logprobs
    assert all(lp == 0.0 for lp in node.logprobs[:prompt_len])
    assert node.logprobs[prompt_len:] == output_logprobs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completion(mock_server):
    """Test chat completion: message-to-prompt conversion plus node tracking."""
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    messages = [{"role": "user", "content": "Hello"}]
    # Use the same conversion ManagedServer applies internally so the mock
    # response is keyed by the exact prompt string it will look up.
    prompt = managed._convert_messages_to_prompt(messages)
    prompt_tokens = mock_server.tokenizer.encode(prompt)
    output_text = "Hi there!"
    output_tokens = [ord(c) for c in output_text]
    output_logprobs = [-0.1] * len(output_tokens)

    # Set up mock response
    mock_server.set_tokens_and_logprobs_response(
        prompt=prompt,
        prompt_tokens=prompt_tokens,
        output_tokens_list=[output_tokens],
        output_logprobs_list=[output_logprobs],
        finish_reasons=["stop"],
    )

    # Call managed server
    result = await managed.chat_completion(messages=messages)

    # Check response
    assert result.choices[0].message.content == output_text
    assert result.choices[0].message.role == "assistant"

    # Check tracked state
    state = managed.get_state()
    assert len(state["nodes"]) == 1

    # Verify tokens are properly tracked
    node = state["nodes"][0]
    prompt_len = len(prompt_tokens)
    assert all(t == -100 for t in node.masked_tokens[:prompt_len])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_multi_turn_conversation(mock_server):
    """Test multi-turn conversation: an extending prompt replaces its node in place."""
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    # Turn 1
    prompt_1 = "Hello"
    prompt_tokens_1 = mock_server.tokenizer.encode(prompt_1)
    output_1 = " World"
    output_tokens_1 = [ord(c) for c in output_1]
    output_logprobs_1 = [-0.1] * len(output_tokens_1)

    mock_server.set_tokens_and_logprobs_response(
        prompt=prompt_1,
        prompt_tokens=prompt_tokens_1,
        output_tokens_list=[output_tokens_1],
        output_logprobs_list=[output_logprobs_1],
        finish_reasons=["stop"],
    )

    await managed.completion(prompt=prompt_1)

    # Turn 2: extends turn 1
    prompt_2 = "Hello World"  # Parent's full_text
    # In a real scenario, prompt_tokens would be from parent node
    full_tokens_from_turn_1 = prompt_tokens_1 + output_tokens_1
    output_2 = "!"
    output_tokens_2 = [ord(c) for c in output_2]
    output_logprobs_2 = [-0.2]

    mock_server.set_tokens_and_logprobs_response(
        prompt=prompt_2,
        prompt_tokens=full_tokens_from_turn_1,  # Use parent's full tokens
        output_tokens_list=[output_tokens_2],
        output_logprobs_list=[output_logprobs_2],
        finish_reasons=["stop"],
    )

    await managed.completion(prompt=prompt_2)

    # Check state - turn 2 extended turn 1, so it should replace that node
    # (no second entry is appended)
    state = managed.get_state()
    assert len(state["nodes"]) == 1

    # Check the extended node
    node = state["nodes"][0]
    assert node.full_text == "Hello World!"

    # Check the extended node has correct tokens
    # Tokens should be: turn_1_full + output_2
    assert len(node.tokens) == len(full_tokens_from_turn_1) + len(output_tokens_2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_branching_with_n(mock_server):
    """Test branching when n > 1: each sampled completion gets its own node."""
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    prompt = "Hello"
    prompt_tokens = mock_server.tokenizer.encode(prompt)

    # Three different completions
    output_texts = [" World", " There", " Friend"]
    output_tokens_list = [[ord(c) for c in text] for text in output_texts]
    output_logprobs_list = [[-0.1] * len(tokens) for tokens in output_tokens_list]

    mock_server.set_tokens_and_logprobs_response(
        prompt=prompt,
        prompt_tokens=prompt_tokens,
        output_tokens_list=output_tokens_list,
        output_logprobs_list=output_logprobs_list,
        finish_reasons=["stop", "stop", "stop"],
    )

    result = await managed.completion(prompt=prompt, n=3)

    # Check we got 3 completions
    assert len(result.choices) == 3

    # Check state has 3 nodes (one per branch)
    state = managed.get_state()
    assert len(state["nodes"]) == 3

    # Verify each node has different text
    full_texts = {node.full_text for node in state["nodes"]}
    assert full_texts == {"Hello World", "Hello There", "Hello Friend"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_bos_token_handling(mock_server):
    """The BOS token must appear exactly once, at position 0 of the sequence."""
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    bos = mock_server.tokenizer.bos_token_id

    prompt = "Test"
    prompt_tokens = mock_server.tokenizer.encode(prompt)
    # Sanity check: the mock tokenizer prepends BOS when encoding a prompt.
    assert prompt_tokens[0] == bos

    completion_tokens = [ord(ch) for ch in "ing"]  # completion carries no BOS
    mock_server.set_tokens_and_logprobs_response(
        prompt=prompt,
        prompt_tokens=prompt_tokens,
        output_tokens_list=[completion_tokens],
        output_logprobs_list=[[-0.1] * len(completion_tokens)],
        finish_reasons=["stop"],
    )

    await managed.completion(prompt=prompt)

    nodes = managed.get_state()["nodes"]
    assert len(nodes) == 1
    tokens = nodes[0].tokens

    # BOS leads the tracked sequence and never recurs after position 0.
    assert tokens[0] == bos
    assert bos not in tokens[1:]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_clears_sequences(mock_server):
    """reset() must drop every tracked sequence node."""
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    prompt = "Test"
    mock_server.set_tokens_and_logprobs_response(
        prompt=prompt,
        prompt_tokens=mock_server.tokenizer.encode(prompt),
        output_tokens_list=[[ord("!")]],
        output_logprobs_list=[[-0.1]],
        finish_reasons=["stop"],
    )
    await managed.completion(prompt=prompt)

    # The completion produced at least one tracked node.
    assert len(managed.get_state()["nodes"]) > 0

    managed.reset()

    # After reset, no nodes remain.
    assert len(managed.get_state()["nodes"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tokenizer_initialization_from_server(mock_server):
    """Omitting the tokenizer argument falls back to the server's tokenizer."""
    managed = ManagedServer(mock_server)  # no explicit tokenizer passed

    # The wrapper picked up the server's tokenizer on its own.
    assert managed.tokenizer is not None
    assert managed.tokenizer == mock_server.tokenizer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_input_ids_extension(mock_server):
    """Extending a node tokenizes only the new suffix and reuses stored tokens."""
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    # First call: "Hello" -> " World".
    first_prompt = "Hello"
    first_prompt_tokens = mock_server.tokenizer.encode(first_prompt)
    first_completion_tokens = [ord(ch) for ch in " World"]
    mock_server.set_tokens_and_logprobs_response(
        prompt=first_prompt,
        prompt_tokens=first_prompt_tokens,
        output_tokens_list=[first_completion_tokens],
        output_logprobs_list=[[-0.1] * len(first_completion_tokens)],
        finish_reasons=["stop"],
    )
    await managed.completion(prompt=first_prompt)

    # Second call: "Hello World!" extends the tracked "Hello World" node.
    # Expected input_ids = the node's stored tokens + tokenize("!") WITHOUT
    # special tokens (no extra BOS in the middle of a sequence).
    parent = managed.current_nodes[0]
    expected_input_ids = parent.tokens + mock_server.tokenizer.encode(
        "!", add_special_tokens=False
    )

    second_completion_tokens = [ord(ch) for ch in " Yay"]
    mock_server.set_tokens_and_logprobs_response(
        prompt="Hello World!",
        prompt_tokens=expected_input_ids,  # must match what ManagedServer computes
        output_tokens_list=[second_completion_tokens],
        output_logprobs_list=[[-0.2] * len(second_completion_tokens)],
        finish_reasons=["stop"],
    )
    result = await managed.completion(prompt="Hello World!")

    assert result.choices[0].text == " Yay"

    # Exactly one node remains: the extension replaced the original.
    nodes = managed.get_state()["nodes"]
    assert len(nodes) == 1
    assert nodes[0].tokens == expected_input_ids + second_completion_tokens
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_multi_turn_chat_with_branching(mock_server):
    """
    Test complex multi-turn scenario:
    - Turn 1: n=8 group completion → 8 nodes (1 assistant turn each)
    - Turn 2: 8 individual calls extending each → extends those 8 nodes (2 assistant turns each)
    - Turn 3: Add system prompt, 8 calls → 8 NEW nodes (different context, 3 assistant turns each)

    Final: 16 nodes total
    """
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)

    # Turn 1: single chat completion with n=8 → 8 sibling branches.
    messages_1 = [{"role": "user", "content": "Hello"}]
    prompt_1 = managed._convert_messages_to_prompt(messages_1)
    prompt_tokens_1 = mock_server.tokenizer.encode(prompt_1)

    # Create 8 different responses
    responses_1 = [f"Response{i}" for i in range(8)]
    output_tokens_1 = [[ord(c) for c in resp] for resp in responses_1]
    output_logprobs_1 = [[-0.1] * len(tokens) for tokens in output_tokens_1]

    mock_server.set_tokens_and_logprobs_response(
        prompt=prompt_1,
        prompt_tokens=prompt_tokens_1,
        output_tokens_list=output_tokens_1,
        output_logprobs_list=output_logprobs_1,
        finish_reasons=["stop"] * 8,
    )

    await managed.chat_completion(messages=messages_1, n=8)

    # After turn 1: should have 8 nodes, one per branch.
    state = managed.get_state()
    assert len(state["nodes"]) == 8
    # The turn-2 loop below relies on nodes being stored in choice order;
    # assert that ordering explicitly (the original captured these texts into
    # an unused variable instead of checking them).
    for i, node in enumerate(state["nodes"]):
        assert f"Response{i}" in node.full_text

    # Turn 2: extend each of the 8 branches with another user+assistant turn.
    # NOTE: `state` was captured before this loop. That is safe because
    # iteration i only replaces node i, which has not been touched yet.
    for i in range(8):
        messages_2 = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": responses_1[i]},
            {"role": "user", "content": "Continue"},
        ]
        prompt_2 = managed._convert_messages_to_prompt(messages_2)

        # This prompt extends turn 1's output, so input_ids should reuse the
        # existing node's tokens plus the tokenization of the new suffix only.
        extending_node = state["nodes"][i]
        new_suffix = prompt_2[len(extending_node.full_text):]
        expected_input_ids = extending_node.tokens + mock_server.tokenizer.encode(
            new_suffix, add_special_tokens=False
        )

        response_2 = f"Continued{i}"
        output_tokens_2 = [ord(c) for c in response_2]
        output_logprobs_2 = [-0.2] * len(output_tokens_2)

        mock_server.set_tokens_and_logprobs_response(
            prompt=prompt_2,
            prompt_tokens=expected_input_ids,
            output_tokens_list=[output_tokens_2],
            output_logprobs_list=[output_logprobs_2],
            finish_reasons=["stop"],
        )

        await managed.chat_completion(messages=messages_2, n=1)

    # After turn 2: still 8 nodes (they were extended/replaced, not added).
    state = managed.get_state()
    assert len(state["nodes"]) == 8

    # Verify turn 2 nodes have 2 assistant turns each.
    for i in range(8):
        node = state["nodes"][i]
        assert f"Response{i}" in node.full_text
        assert f"Continued{i}" in node.full_text

    # Turn 3: add a system prompt at the start — this is a DIFFERENT context,
    # so these calls cannot extend the existing nodes and must create new ones.
    for i in range(8):
        messages_3 = [
            {"role": "system", "content": "Helpful"},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": responses_1[i]},
            {"role": "user", "content": "Continue"},
            {"role": "assistant", "content": f"Continued{i}"},
            {"role": "user", "content": "More"},
        ]
        prompt_3 = managed._convert_messages_to_prompt(messages_3)
        prompt_tokens_3 = mock_server.tokenizer.encode(prompt_3)

        response_3 = f"More{i}"
        output_tokens_3 = [ord(c) for c in response_3]
        output_logprobs_3 = [-0.3] * len(output_tokens_3)

        mock_server.set_tokens_and_logprobs_response(
            prompt=prompt_3,
            prompt_tokens=prompt_tokens_3,
            output_tokens_list=[output_tokens_3],
            output_logprobs_list=[output_logprobs_3],
            finish_reasons=["stop"],
        )

        await managed.chat_completion(messages=messages_3, n=1)

    # After turn 3: 8 (turn 2 nodes) + 8 (turn 3 new context) = 16 nodes total!
    state = managed.get_state()
    assert len(state["nodes"]) == 16

    # First 8 nodes: 2 assistant turns (no system prompt).
    for i in range(8):
        node = state["nodes"][i]
        assert "Helpful" not in node.full_text  # No system prompt
        assert f"Response{i}" in node.full_text
        assert f"Continued{i}" in node.full_text
        assert f"More{i}" not in node.full_text  # Not the third turn

    # Last 8 nodes: 3 assistant turns (with system prompt).
    for i in range(8, 16):
        node = state["nodes"][i]
        actual_i = i - 8
        assert "Helpful" in node.full_text  # Has system prompt
        assert f"Response{actual_i}" in node.full_text
        assert f"Continued{actual_i}" in node.full_text
        assert f"More{actual_i}" in node.full_text  # Has third turn
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly: delegate to pytest in verbose mode.
    pytest.main([__file__, "-v"])
|
||||
|
|
@ -315,8 +315,10 @@ class MathEnv(BaseEnv):
|
|||
curr_length = int(curr_length * (self.curr_step / self.config.total_steps))
|
||||
curr_length += self.config.start_tok_length
|
||||
thinking_len = min(thinking_len, curr_length)
|
||||
prompt_tokens, out_tokens, out_logprobs, finish_reasons = (
|
||||
await self.server.tokens_and_logprobs_completion(
|
||||
|
||||
# Use managed server for automatic token/logprob tracking
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
completion = await managed.completion(
|
||||
prompt=user_prompt,
|
||||
n=self.config.group_size,
|
||||
max_tokens=thinking_len,
|
||||
|
|
@ -324,22 +326,23 @@ class MathEnv(BaseEnv):
|
|||
top_p=1.0,
|
||||
stop=stop_list,
|
||||
)
|
||||
)
|
||||
# print(completions, flush=True)
|
||||
|
||||
# Get tracked sequences with aligned tokens and logprobs
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
|
||||
# Extract data from SequenceNodes for scoring
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
for i, (tokens, logprobs, finish_reason) in enumerate(
|
||||
zip(out_tokens, out_logprobs, finish_reasons)
|
||||
):
|
||||
message = self.tokenizer.decode(prompt_tokens + tokens)
|
||||
for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
|
||||
to_score.append(
|
||||
(
|
||||
message,
|
||||
item[1],
|
||||
finish_reason,
|
||||
prompt_tokens,
|
||||
tokens,
|
||||
logprobs,
|
||||
node.full_text, # Complete text (prompt + completion)
|
||||
item[1], # Answer
|
||||
choice.finish_reason, # finish_reason (already a clean string)
|
||||
node.tokens, # all tokens (prompt + completion)
|
||||
node.masked_tokens, # masked tokens (already formatted correctly)
|
||||
node.logprobs, # logprobs (already formatted correctly)
|
||||
)
|
||||
)
|
||||
to_postprocess = await self.score(to_score)
|
||||
|
|
@ -376,11 +379,13 @@ class MathEnv(BaseEnv):
|
|||
for item in rollout_group_data:
|
||||
scores["overrides"].append(dict())
|
||||
resp = item[0]
|
||||
finish_reason = item[2]
|
||||
user_prompt_tokens = item[3]
|
||||
out_toks = item[4]
|
||||
out_logps = item[5]
|
||||
if item[2]["type"] == "length":
|
||||
finish_reason = item[2] # Now a clean string like "stop" or "length"
|
||||
# ManagedServer already provides properly formatted data
|
||||
tokens = item[3] # Full token sequence
|
||||
masks = item[4] # Masked tokens (already formatted)
|
||||
inf_logp = item[5] # Logprobs (already formatted)
|
||||
|
||||
if finish_reason == "length":
|
||||
reward = False
|
||||
if self.config.mask_too_long_completions:
|
||||
scores["overrides"][-1]["set_advantage_to_zero"] = True
|
||||
|
|
@ -389,11 +394,7 @@ class MathEnv(BaseEnv):
|
|||
reward = await task
|
||||
if reward is None:
|
||||
return None
|
||||
tokens = user_prompt_tokens + out_toks
|
||||
masks = [-100 for _ in range(len(user_prompt_tokens))]
|
||||
masks = masks + out_toks
|
||||
inf_logp = [0 for _ in range(len(user_prompt_tokens))]
|
||||
inf_logp = inf_logp + out_logps
|
||||
|
||||
assert len(inf_logp) == len(
|
||||
masks
|
||||
), f"{len(inf_logp)}, {len(masks)} mismatch"
|
||||
|
|
@ -405,7 +406,7 @@ class MathEnv(BaseEnv):
|
|||
# remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
if (item[2] == "length") and (not self.config.mask_too_long_completions):
|
||||
if (finish_reason == "length") and (not self.config.mask_too_long_completions):
|
||||
scores["overrides"][-1]["set_advantage_to_zero"] = True
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue