mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add env using the tool api stuff
This commit is contained in:
parent
c8eb63f33d
commit
12d61d197f
15 changed files with 2632 additions and 21 deletions
|
|
@ -143,6 +143,10 @@ class ManagedServer:
|
|||
return None
|
||||
return self._translator
|
||||
|
||||
# Placeholder used to protect <think> blocks from chat templates that strip them
|
||||
_THINK_OPEN = "__MNGD_THINK__"
|
||||
_THINK_CLOSE = "__MNGD_ENDTHINK__"
|
||||
|
||||
def _convert_messages_to_prompt(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
|
|
@ -175,6 +179,11 @@ class ManagedServer:
|
|||
len(messages) == 0 or messages[-1].get("role") != "assistant"
|
||||
)
|
||||
|
||||
# Protect <think> blocks in assistant messages — some chat templates
|
||||
# (e.g. Qwen3) strip them during re-rendering, which breaks prefix
|
||||
# matching for multi-turn sequence extension.
|
||||
messages = self._protect_think_blocks(messages)
|
||||
|
||||
# Build kwargs
|
||||
template_kwargs = {
|
||||
"tokenize": False,
|
||||
|
|
@ -184,11 +193,41 @@ class ManagedServer:
|
|||
template_kwargs["tools"] = tools
|
||||
|
||||
# Use the tokenizer's chat template
|
||||
return self.tokenizer.apply_chat_template(messages, **template_kwargs)
|
||||
prompt = self.tokenizer.apply_chat_template(messages, **template_kwargs)
|
||||
|
||||
# Restore <think> blocks
|
||||
prompt = prompt.replace(self._THINK_OPEN, "<think>")
|
||||
prompt = prompt.replace(self._THINK_CLOSE, "</think>")
|
||||
|
||||
return prompt
|
||||
else:
|
||||
# Fallback for tokenizers without chat template
|
||||
return "\n".join([f"{m['role']}: {m.get('content', '')}" for m in messages])
|
||||
|
||||
def _protect_think_blocks(
|
||||
self, messages: List[Dict[str, str]]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Replace <think>...</think> with placeholders in assistant messages.
|
||||
|
||||
Only touches assistant messages that already have content (i.e., messages
|
||||
being replayed from prior turns, not the generation prompt). This prevents
|
||||
chat templates from stripping or relocating think blocks.
|
||||
"""
|
||||
out = []
|
||||
for msg in messages:
|
||||
if (
|
||||
msg.get("role") == "assistant"
|
||||
and msg.get("content")
|
||||
and "<think>" in msg["content"]
|
||||
):
|
||||
content = msg["content"]
|
||||
content = content.replace("<think>", self._THINK_OPEN)
|
||||
content = content.replace("</think>", self._THINK_CLOSE)
|
||||
out.append({**msg, "content": content})
|
||||
else:
|
||||
out.append(msg)
|
||||
return out
|
||||
|
||||
def _debug_requests_enabled(self) -> bool:
|
||||
"""Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
|
||||
return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
|
||||
|
|
@ -265,18 +304,22 @@ class ManagedServer:
|
|||
output_logprobs: List[float],
|
||||
completion_text: str,
|
||||
finish_reason: str = "stop",
|
||||
extending_node: Optional[SequenceNode] = None,
|
||||
) -> SequenceNode:
|
||||
"""
|
||||
Create a sequence node with proper masking.
|
||||
|
||||
Args:
|
||||
input_text: The input prompt text
|
||||
parent_node: Parent node to extend from (if available)
|
||||
parent_node: Parent node to extend from (tree mode)
|
||||
prompt_tokens: Token IDs for the prompt
|
||||
output_tokens: Token IDs for the output/completion
|
||||
output_logprobs: Logprobs for output tokens
|
||||
completion_text: The completion text
|
||||
finish_reason: Finish reason from server
|
||||
extending_node: Node being extended (default mode). When provided,
|
||||
carries forward its masked_tokens and logprobs so previous
|
||||
completions stay unmasked across multi-turn extensions.
|
||||
|
||||
Returns:
|
||||
SequenceNode with properly masked tokens and logprobs
|
||||
|
|
@ -284,19 +327,6 @@ class ManagedServer:
|
|||
# Combine text
|
||||
full_text = input_text + completion_text
|
||||
|
||||
# If we have a parent node, we should use its tokens as the prompt base
|
||||
if parent_node is not None:
|
||||
# Use parent's full tokens as the prompt
|
||||
prompt_tokens = parent_node.tokens.copy()
|
||||
|
||||
# Combine tokens
|
||||
full_tokens = prompt_tokens + output_tokens
|
||||
prompt_len = len(prompt_tokens)
|
||||
|
||||
# Create masked tokens: -100 for prompt, actual IDs for completion
|
||||
masked_tokens = [-100] * prompt_len + output_tokens
|
||||
|
||||
# Create masked logprobs: 1.0 for prompt, actual for completion
|
||||
# Pad logprobs to match token length if needed
|
||||
if len(output_logprobs) < len(output_tokens):
|
||||
output_logprobs = output_logprobs + [1.0] * (
|
||||
|
|
@ -305,7 +335,30 @@ class ManagedServer:
|
|||
elif len(output_logprobs) > len(output_tokens):
|
||||
output_logprobs = output_logprobs[: len(output_tokens)]
|
||||
|
||||
full_logprobs = [1.0] * prompt_len + output_logprobs
|
||||
# If we have a parent node (tree mode), use its tokens as the prompt base
|
||||
if parent_node is not None:
|
||||
prompt_tokens = parent_node.tokens.copy()
|
||||
|
||||
# Combine tokens
|
||||
full_tokens = prompt_tokens + output_tokens
|
||||
|
||||
if extending_node is not None:
|
||||
# Carry forward the extending node's mask and logprobs.
|
||||
# The prompt_tokens = extending_node.tokens + new_suffix_tokens.
|
||||
# We preserve the extending node's mask (which has previous
|
||||
# completions unmasked) and mask only the new suffix as prompt.
|
||||
suffix_len = len(prompt_tokens) - len(extending_node.tokens)
|
||||
masked_tokens = (
|
||||
extending_node.masked_tokens + [-100] * suffix_len + output_tokens
|
||||
)
|
||||
full_logprobs = (
|
||||
extending_node.logprobs + [1.0] * suffix_len + output_logprobs
|
||||
)
|
||||
else:
|
||||
# Fresh node — mask entire prompt
|
||||
prompt_len = len(prompt_tokens)
|
||||
masked_tokens = [-100] * prompt_len + output_tokens
|
||||
full_logprobs = [1.0] * prompt_len + output_logprobs
|
||||
|
||||
return SequenceNode(
|
||||
full_text=full_text,
|
||||
|
|
@ -431,6 +484,7 @@ class ManagedServer:
|
|||
output_logprobs=output_logprobs,
|
||||
completion_text=completion_text,
|
||||
finish_reason=finish_reason,
|
||||
extending_node=extending_node,
|
||||
)
|
||||
|
||||
# Store node based on mode
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue