mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
linting
This commit is contained in:
parent
826de9e283
commit
67cfd961c5
6 changed files with 111 additions and 85 deletions
|
|
@ -3,19 +3,21 @@ Trajectory utils
|
|||
|
||||
Utils for managing trajectory sizing, formatting, compression, etc.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def strip_thinking(response_text: str) -> str:
|
||||
"""Helper to strip the <think> block of a response entirely.
|
||||
|
||||
|
||||
Args:
|
||||
response_text: The response text to strip.
|
||||
|
||||
|
|
@ -24,20 +26,24 @@ def strip_thinking(response_text: str) -> str:
|
|||
"""
|
||||
think_start_tag = "<think>"
|
||||
think_end_tag = "</think>"
|
||||
|
||||
|
||||
think_start_idx = response_text.find(think_start_tag)
|
||||
think_end_idx = response_text.find(think_end_tag)
|
||||
|
||||
if think_start_idx != -1 and think_end_idx != -1:
|
||||
return response_text[:think_start_idx] + response_text[think_end_idx + len(think_end_tag):]
|
||||
return (
|
||||
response_text[:think_start_idx]
|
||||
+ response_text[think_end_idx + len(think_end_tag) :]
|
||||
)
|
||||
else:
|
||||
return response_text
|
||||
|
||||
|
||||
|
||||
def truncate_thinking(
|
||||
response_text: str, tokenizer: PreTrainedTokenizer, max_think_tokens: int
|
||||
) -> str:
|
||||
"""Helper to truncate the <think> block of a response for message history based on token count.
|
||||
|
||||
|
||||
Args:
|
||||
response_text: The response text to truncate.
|
||||
tokenizer: The tokenizer to use for counting tokens.
|
||||
|
|
@ -60,9 +66,7 @@ def truncate_thinking(
|
|||
):
|
||||
return response_text
|
||||
|
||||
part_before_content = response_text[
|
||||
: think_start_idx + len(think_start_tag)
|
||||
]
|
||||
part_before_content = response_text[: think_start_idx + len(think_start_tag)]
|
||||
original_think_content_raw = response_text[
|
||||
think_start_idx + len(think_start_tag) : think_end_idx
|
||||
]
|
||||
|
|
@ -77,7 +81,7 @@ def truncate_thinking(
|
|||
all_think_tokens = tokenizer.encode(
|
||||
original_think_content_stripped, add_special_tokens=False
|
||||
)
|
||||
|
||||
|
||||
is_truncated_internally = False
|
||||
final_think_tokens: List[int]
|
||||
|
||||
|
|
@ -85,13 +89,15 @@ def truncate_thinking(
|
|||
final_think_tokens = all_think_tokens
|
||||
is_truncated_internally = False
|
||||
else:
|
||||
is_truncated_internally = True # Mark as truncated if len(all_think_tokens) > max_think_tokens
|
||||
is_truncated_internally = (
|
||||
True # Mark as truncated if len(all_think_tokens) > max_think_tokens
|
||||
)
|
||||
paragraphs = [
|
||||
p.strip()
|
||||
for p in original_think_content_stripped.split("\n\n")
|
||||
if p.strip()
|
||||
]
|
||||
|
||||
|
||||
attempted_paragraph_truncation = False
|
||||
if paragraphs:
|
||||
last_paragraph_text = paragraphs[-1]
|
||||
|
|
@ -104,12 +110,14 @@ def truncate_thinking(
|
|||
if len(last_paragraph_tokens) <= max_think_tokens:
|
||||
final_think_tokens = last_paragraph_tokens
|
||||
attempted_paragraph_truncation = True
|
||||
|
||||
if not attempted_paragraph_truncation: # Default to truncating the whole content from the end
|
||||
|
||||
if (
|
||||
not attempted_paragraph_truncation
|
||||
): # Default to truncating the whole content from the end
|
||||
# Ensure max_think_tokens is not negative, though practically it shouldn't be.
|
||||
slice_start = max(0, len(all_think_tokens) - max_think_tokens)
|
||||
final_think_tokens = all_think_tokens[slice_start:]
|
||||
|
||||
|
||||
# Decode the tokens to string
|
||||
decoded_think_content = tokenizer.decode(
|
||||
final_think_tokens, skip_special_tokens=True
|
||||
|
|
@ -123,7 +131,10 @@ def truncate_thinking(
|
|||
# Determine the final block content (empty or with newlines)
|
||||
final_internal_content_str_stripped = final_internal_content_str.strip()
|
||||
final_content_for_block: str
|
||||
if not final_internal_content_str_stripped or final_internal_content_str_stripped == "...":
|
||||
if (
|
||||
not final_internal_content_str_stripped
|
||||
or final_internal_content_str_stripped == "..."
|
||||
):
|
||||
final_content_for_block = ""
|
||||
else:
|
||||
final_content_for_block = f"\n{final_internal_content_str_stripped}\n"
|
||||
|
|
@ -137,6 +148,7 @@ def truncate_thinking(
|
|||
)
|
||||
return response_text
|
||||
|
||||
|
||||
def ensure_trajectory_token_limit(
|
||||
trajectory: List[ScoredDataGroup],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
|
|
@ -325,7 +337,7 @@ def ensure_trajectory_token_limit(
|
|||
else:
|
||||
logger.warning(
|
||||
f"[_ensure_trajectory_token_limit] MC env: Discarding step {step_idx}. "
|
||||
f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) "
|
||||
f"Max tokens ({max_current_tokens}) still exceed limit ({max_trajectory_tokens}) "
|
||||
f"or retokenization error occurred ({retokenization_error_this_step})."
|
||||
)
|
||||
|
||||
|
|
@ -336,7 +348,3 @@ def ensure_trajectory_token_limit(
|
|||
f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}"
|
||||
)
|
||||
return filtered_trajectory
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue