Add environment support via the tool API

This commit is contained in:
dmahan93 2026-03-03 19:51:30 -06:00
parent c8eb63f33d
commit 12d61d197f
15 changed files with 2632 additions and 21 deletions

View file

@ -143,6 +143,10 @@ class ManagedServer:
return None
return self._translator
# Placeholder used to protect <think> blocks from chat templates that strip them
_THINK_OPEN = "__MNGD_THINK__"
_THINK_CLOSE = "__MNGD_ENDTHINK__"
def _convert_messages_to_prompt(
self,
messages: List[Dict[str, str]],
@ -175,6 +179,11 @@ class ManagedServer:
len(messages) == 0 or messages[-1].get("role") != "assistant"
)
# Protect <think> blocks in assistant messages — some chat templates
# (e.g. Qwen3) strip them during re-rendering, which breaks prefix
# matching for multi-turn sequence extension.
messages = self._protect_think_blocks(messages)
# Build kwargs
template_kwargs = {
"tokenize": False,
@ -184,11 +193,41 @@ class ManagedServer:
template_kwargs["tools"] = tools
# Use the tokenizer's chat template
return self.tokenizer.apply_chat_template(messages, **template_kwargs)
prompt = self.tokenizer.apply_chat_template(messages, **template_kwargs)
# Restore <think> blocks
prompt = prompt.replace(self._THINK_OPEN, "<think>")
prompt = prompt.replace(self._THINK_CLOSE, "</think>")
return prompt
else:
# Fallback for tokenizers without chat template
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
def _protect_think_blocks(
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
"""Replace <think>...</think> with placeholders in assistant messages.
Only touches assistant messages that already have content (i.e., messages
being replayed from prior turns, not the generation prompt). This prevents
chat templates from stripping or relocating think blocks.
"""
out = []
for msg in messages:
if (
msg.get("role") == "assistant"
and msg.get("content")
and "<think>" in msg["content"]
):
content = msg["content"]
content = content.replace("<think>", self._THINK_OPEN)
content = content.replace("</think>", self._THINK_CLOSE)
out.append({**msg, "content": content})
else:
out.append(msg)
return out
def _debug_requests_enabled(self) -> bool:
"""Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
@ -265,18 +304,22 @@ class ManagedServer:
output_logprobs: List[float],
completion_text: str,
finish_reason: str = "stop",
extending_node: Optional[SequenceNode] = None,
) -> SequenceNode:
"""
Create a sequence node with proper masking.
Args:
input_text: The input prompt text
parent_node: Parent node to extend from (if available)
parent_node: Parent node to extend from (tree mode)
prompt_tokens: Token IDs for the prompt
output_tokens: Token IDs for the output/completion
output_logprobs: Logprobs for output tokens
completion_text: The completion text
finish_reason: Finish reason from server
extending_node: Node being extended (default mode). When provided,
carries forward its masked_tokens and logprobs so previous
completions stay unmasked across multi-turn extensions.
Returns:
SequenceNode with properly masked tokens and logprobs
@ -284,19 +327,6 @@ class ManagedServer:
# Combine text
full_text = input_text + completion_text
# If we have a parent node, we should use its tokens as the prompt base
if parent_node is not None:
# Use parent's full tokens as the prompt
prompt_tokens = parent_node.tokens.copy()
# Combine tokens
full_tokens = prompt_tokens + output_tokens
prompt_len = len(prompt_tokens)
# Create masked tokens: -100 for prompt, actual IDs for completion
masked_tokens = [-100] * prompt_len + output_tokens
# Create masked logprobs: 1.0 for prompt, actual for completion
# Pad logprobs to match token length if needed
if len(output_logprobs) < len(output_tokens):
output_logprobs = output_logprobs + [1.0] * (
@ -305,7 +335,30 @@ class ManagedServer:
elif len(output_logprobs) > len(output_tokens):
output_logprobs = output_logprobs[: len(output_tokens)]
full_logprobs = [1.0] * prompt_len + output_logprobs
# If we have a parent node (tree mode), use its tokens as the prompt base
if parent_node is not None:
prompt_tokens = parent_node.tokens.copy()
# Combine tokens
full_tokens = prompt_tokens + output_tokens
if extending_node is not None:
# Carry forward the extending node's mask and logprobs.
# The prompt_tokens = extending_node.tokens + new_suffix_tokens.
# We preserve the extending node's mask (which has previous
# completions unmasked) and mask only the new suffix as prompt.
suffix_len = len(prompt_tokens) - len(extending_node.tokens)
masked_tokens = (
extending_node.masked_tokens + [-100] * suffix_len + output_tokens
)
full_logprobs = (
extending_node.logprobs + [1.0] * suffix_len + output_logprobs
)
else:
# Fresh node — mask entire prompt
prompt_len = len(prompt_tokens)
masked_tokens = [-100] * prompt_len + output_tokens
full_logprobs = [1.0] * prompt_len + output_logprobs
return SequenceNode(
full_text=full_text,
@ -431,6 +484,7 @@ class ManagedServer:
output_logprobs=output_logprobs,
completion_text=completion_text,
finish_reason=finish_reason,
extending_node=extending_node,
)
# Store node based on mode