update local server

This commit is contained in:
Shannon Sands 2025-05-10 08:18:41 +10:00
parent c506bb147e
commit ba604d44f9
3 changed files with 251 additions and 320 deletions

View file

@ -968,193 +968,145 @@ class BlackjackEnv(BaseEnv):
filtered_trajectory: List[BlackjackScoredDataGroup] = []
for step_idx, original_step_data in enumerate(trajectory):
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} has "
f"{len(original_step_data['messages'])} alternatives."
)
if (
not original_step_data.get("messages")
or not original_step_data.get("tokens")
or not original_step_data.get("masks")
if not (
original_step_data.get("messages")
and original_step_data.get("tokens")
and original_step_data.get("masks")
and original_step_data.get("seed") is not None
and original_step_data.get("parsed_actions") is not None # Specific to MC version
):
logger.warning(
f"[_ensure_trajectory_token_limit] Step {step_idx} "
f"is missing messages, tokens, or masks. Skipping."
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env "
f"is missing critical data. Skipping."
)
continue
current_step_messages_orig = [
msgs.copy() for msgs in original_step_data["messages"]
]
current_step_tokens_orig = [
tkns.copy() for tkns in original_step_data["tokens"]
]
current_step_masks_orig = [
msks.copy() for msks in original_step_data["masks"]
]
num_alternatives = len(current_step_messages_orig)
if num_alternatives == 0:
filtered_trajectory.append(original_step_data)
continue
max_initial_tokens = max(
len(alt_tokens) for alt_tokens in current_step_tokens_orig
)
# Initial token calculation from original data
max_initial_tokens = 0
if original_step_data["tokens"]:
max_initial_tokens = max(
len(alt_tokens) for alt_tokens in original_step_data["tokens"] if isinstance(alt_tokens, list)
) if any(isinstance(alt_tokens, list) for alt_tokens in original_step_data["tokens"]) else 0
if max_initial_tokens <= self.config.max_trajectory_tokens:
filtered_trajectory.append(original_step_data)
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} compliant in MC env. "
f"Max tokens: {max_initial_tokens}"
)
continue
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) "
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting uniform truncation."
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env (max tokens: {max_initial_tokens}) "
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
)
working_messages = [msgs.copy() for msgs in current_step_messages_orig]
working_tokens = [tkns.copy() for tkns in current_step_tokens_orig]
working_masks = [msks.copy() for msks in current_step_masks_orig]
working_messages = [msgs_list.copy() for msgs_list in original_step_data["messages"] or []]
working_tokens = [tkns_list.copy() for tkns_list in original_step_data["tokens"] or []]
working_masks = [msks_list.copy() for msks_list in original_step_data["masks"] or []]
max_current_tokens = max_initial_tokens
num_alternatives = len(working_messages)
step_successfully_truncated = False
while True:
num_messages_to_pop_per_alt = [0] * num_alternatives
can_truncate_globally = True
if num_alternatives == 0:
logger.warning(f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives after copying. Skipping.")
continue
retokenization_error_this_step = False
while max_current_tokens > self.config.max_trajectory_tokens:
target_pop_counts_per_alt = []
for alt_idx in range(num_alternatives):
alt_msg_list = working_messages[alt_idx]
num_preserved_at_end = 0
if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in ["agent", "assistant"] + UNMASKED_ROLES:
num_preserved_at_end = 1
if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment":
num_preserved_at_end = 2
available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
min_len_to_preserve = 1
if len(alt_msg_list) > 0 and alt_msg_list[-1]["role"] in ["agent"]:
min_len_to_preserve += 1
if (
len(alt_msg_list) > 1
and alt_msg_list[-2]["role"] == "environment"
):
min_len_to_preserve += 1
if len(alt_msg_list) <= min_len_to_preserve:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
if (
len(alt_msg_list) > 2
and alt_msg_list[1]["role"] == "environment"
and alt_msg_list[2]["role"] == "agent"
):
if (len(alt_msg_list) - 2) < min_len_to_preserve:
if (len(alt_msg_list) - 1) < min_len_to_preserve:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
else:
num_messages_to_pop_per_alt[alt_idx] = 1
else:
num_messages_to_pop_per_alt[alt_idx] = 2
elif len(alt_msg_list) > 1:
if (len(alt_msg_list) - 1) < min_len_to_preserve:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
else:
num_messages_to_pop_per_alt[alt_idx] = 1
if available_to_pop <= 0:
target_pop_counts_per_alt.append(0)
else:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
if not can_truncate_globally:
break
min_pop_count = float("inf")
for count in num_messages_to_pop_per_alt:
if count > 0:
min_pop_count = min(min_pop_count, count)
if min_pop_count == float("inf") or min_pop_count == 0:
break
successfully_retokenized_all = True
new_alt_tokens_list = []
new_alt_masks_list = []
can_pop_pair = (
available_to_pop >= 2 and
len(alt_msg_list) > 2 and
alt_msg_list[1]["role"] == "environment" and
alt_msg_list[2]["role"] in ["agent", "assistant"] + UNMASKED_ROLES
)
if can_pop_pair:
target_pop_counts_per_alt.append(2)
else:
target_pop_counts_per_alt.append(1)
positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
if not positive_pop_counts:
break
min_pop_this_round = min(positive_pop_counts)
temp_new_alt_tokens = []
temp_new_alt_masks = []
max_tokens_after_this_trunc = 0
for alt_idx in range(num_alternatives):
for _ in range(min_pop_count):
for _ in range(min_pop_this_round):
if len(working_messages[alt_idx]) > 1:
working_messages[alt_idx].pop(1)
else:
logger.error(
f"[_ensure_trajectory_token_limit] Critical error during pop for "
f"alt {alt_idx}, step {step_idx}."
f"[_ensure_trajectory_token_limit] MC env: Critical error during pop for "
f"alt {alt_idx}, step {step_idx}. List too short."
)
successfully_retokenized_all = False
break
if not successfully_retokenized_all:
break
retokenization_error_this_step = True; break
if retokenization_error_this_step: break
try:
tokenized_alt = tokenize_for_trainer(
self.tokenizer, working_messages[alt_idx]
)
new_alt_tokens_list.append(tokenized_alt["tokens"])
new_alt_masks_list.append(tokenized_alt["masks"])
max_tokens_after_this_trunc = max(
max_tokens_after_this_trunc, len(tokenized_alt["tokens"])
)
tokenized_alt = tokenize_for_trainer(self.tokenizer, working_messages[alt_idx])
temp_new_alt_tokens.append(tokenized_alt["tokens"])
temp_new_alt_masks.append(tokenized_alt["masks"])
max_tokens_after_this_trunc = max(max_tokens_after_this_trunc, len(tokenized_alt["tokens"]))
except Exception as e:
logger.error(
f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} "
f"[_ensure_trajectory_token_limit] MC env: Error re-tokenizing alt {alt_idx} "
f"in step {step_idx} after truncation: {e}"
)
successfully_retokenized_all = False
break
retokenization_error_this_step = True; break
if retokenization_error_this_step: break
if not successfully_retokenized_all:
step_successfully_truncated = False
break
working_tokens = new_alt_tokens_list
working_masks = new_alt_masks_list
working_tokens = temp_new_alt_tokens
working_masks = temp_new_alt_masks
max_current_tokens = max_tokens_after_this_trunc
logger.debug(
f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_count}, "
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, after uniform pop of {min_pop_this_round}, "
f"max tokens: {max_current_tokens}"
)
if max_current_tokens <= self.config.max_trajectory_tokens:
step_successfully_truncated = True
break
if step_successfully_truncated:
updated_step_data = BlackjackScoredDataGroup(
seed=original_step_data["seed"],
messages=working_messages,
tokens=working_tokens,
masks=working_masks,
scores=original_step_data["scores"],
parsed_actions=original_step_data["parsed_actions"],
)
if not retokenization_error_this_step and max_current_tokens <= self.config.max_trajectory_tokens:
updated_step_data: BlackjackScoredDataGroup = {
"seed": original_step_data["seed"],
"messages": working_messages,
"tokens": working_tokens,
"masks": working_masks,
"scores": original_step_data.get("scores"),
"parsed_actions": original_step_data.get("parsed_actions") # MC version specific
}
filtered_trajectory.append(updated_step_data)
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} successfully truncated. "
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx} successfully processed. "
f"Final max tokens: {max_current_tokens}"
)
else:
if max_current_tokens > self.config.max_trajectory_tokens:
logger.warning(
f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
f"Max tokens ({max_current_tokens}) still exceed limit "
f"({self.config.max_trajectory_tokens}) after maximum possible "
f"uniform truncation or re-tokenization error."
)
logger.warning(
f"[_ensure_trajectory_token_limit] MC env: Discarding step {step_idx}. "
f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) "
f"or retokenization error occurred ({retokenization_error_this_step})."
)
if len(filtered_trajectory) < len(trajectory):
logger.warning(
f"[_ensure_trajectory_token_limit] Filtered out {len(trajectory) - len(filtered_trajectory)} steps "
f"due to token limit constraints. Original trajectory length: {len(trajectory)}, "
f"Filtered: {len(filtered_trajectory)}"
f"[_ensure_trajectory_token_limit] MC env: Filtered out "
f"{len(trajectory) - len(filtered_trajectory)} steps "
f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}"
)
return filtered_trajectory

View file

@ -5,7 +5,8 @@ import os
import random
from dotenv import load_dotenv
from environments.game_environments.gymnasium.blackjack_env import BlackjackEnv
from environments.game_environments.gymnasium.blackjack_env import BlackjackEnv, BlackjackEnvConfig
from atroposlib.envs.base import OpenaiConfig, EvalHandlingEnum
load_dotenv()
@ -13,51 +14,72 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_arguments():
parser = argparse.ArgumentParser(description="Blackjack environment local server")
parser.add_argument(
"--config",
type=str,
default="blackjack_local",
help="Configuration file name (without .yaml extension, relative to "
"envs/gymnasium/configs), or full path to a YAML file.",
)
return parser.parse_args()
# def parse_arguments(): # Removed
# parser = argparse.ArgumentParser(description="Blackjack environment local server")
# parser.add_argument(
# "--config",
# type=str,
# default="blackjack_local",
# help="Configuration file name (without .yaml extension, relative to "
# "envs/gymnasium/configs), or full path to a YAML file.",
# )
# return parser.parse_args()
async def main():
logger.info("Starting Blackjack environment server")
logger.info("Starting Blackjack environment local debug runner")
args = parse_arguments()
# args = parse_arguments() # Removed
# Determine the config name/path for config_init
# config_init expects the name relative to its own configs dir, or an absolute path
config_input = args.config
if not os.path.isabs(config_input) and not config_input.endswith(".yaml"):
# Assume it's a name relative to the blackjack env's config dir
config_name_or_path = config_input
logger.info(f"Using relative config name: {config_name_or_path}")
else:
# It's likely an absolute path or path relative to cwd
config_name_or_path = os.path.abspath(config_input)
logger.info(f"Using absolute config path: {config_name_or_path}")
# Removed logic for config_name_or_path and BlackjackEnv.config_init
# Use the environment's config_init method to load configurations
try:
config, server_configs = BlackjackEnv.config_init(config_name_or_path)
logger.info("Configuration loaded successfully via BlackjackEnv.config_init")
logger.debug(f"Loaded Env Config: {config}")
logger.debug(f"Loaded Server Configs: {server_configs}")
except Exception as e:
logger.exception(
f"Failed to load configuration using BlackjackEnv.config_init: {e}"
# Create hardcoded configurations for local debugging
env_config = BlackjackEnvConfig(
# BaseEnvConfig fields, tailored for debug
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=1, # Debug single generation path
use_wandb=False,
wandb_name="blackjack_local_debug", # Explicitly set for debug
max_num_workers=1,
rollout_server_url="http://localhost:8000", # Standard default
total_steps=1,
batch_size=1, # Consistent with 1 step, 1 worker, group_size 1
steps_per_eval=0, # No eval steps needed
max_token_length=1024 * 4, # Reduced for faster local debugging if necessary
inference_weight=1.0,
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.NONE, # No evaluation in this script
eval_limit_ratio=0.0,
# BlackjackEnvConfig specific fields (from blackjack_env.py's definition or defaults)
env_name="Blackjack-v1",
temperature=0.2, # Lower temperature for more deterministic debug output
top_p=0.9, # Standard default
max_turns=5, # Standard default
thinking_active=True,
eval_episodes=0, # No evaluation episodes
max_think_chars_history=3000,
max_trajectory_tokens=24576,
debug_mode=True, # Enable debug logging from the environment
mc_samples=1, # With group_size=1, this means 1 MC rollout for V(s)
)
server_configs = [
OpenaiConfig(
model_name="gpt-4.1-mini", # Ensure this is locally available if not mocked
base_url="https://api.openai.com/v1", # Explicitly set OpenAI base URL
api_key=os.getenv("OPENAI_API_KEY"), # Use env var or default
num_requests_for_eval=0, # No eval requests
)
return # Cannot proceed without config
]
logger.info("Using hardcoded debug configuration.")
logger.debug(f"Env Config: {env_config}")
logger.debug(f"Server Configs: {server_configs}")
# Create and set up the environment using the loaded configs
try:
env = BlackjackEnv(
config=config,
config=env_config,
server_configs=server_configs,
slurm=False, # Explicitly false for local testing
)

View file

@ -1513,199 +1513,156 @@ class BlackjackEnv(BaseEnv):
filtered_trajectory: List[BlackjackScoredDataGroup] = []
for step_idx, original_step_data in enumerate(trajectory):
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} has {len(original_step_data['messages'])} messages."
)
if (
not original_step_data.get("messages")
or not original_step_data.get("tokens")
or not original_step_data.get("masks")
if not (
original_step_data.get("messages")
and original_step_data.get("tokens")
and original_step_data.get("masks")
and original_step_data.get("seed") is not None # seed is mandatory for new group
):
logger.warning(
f"[_ensure_trajectory_token_limit] Step {step_idx} is missing messages, tokens, or masks. Skipping."
f"[_ensure_trajectory_token_limit] Step {step_idx} "
f"is missing critical data (messages, tokens, masks, or seed). Skipping."
)
continue
current_step_messages_orig = [
msgs.copy() for msgs in original_step_data["messages"]
]
for alt_idx, alt_msgs in enumerate(current_step_messages_orig):
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx}, alt {alt_idx} has {len(alt_msgs)} messages."
)
for msg_idx, msg in enumerate(alt_msgs):
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx}, alt {alt_idx}, msg {msg['content']}"
)
current_step_tokens_orig = [
tkns.copy() for tkns in original_step_data["tokens"]
]
current_step_masks_orig = [
msks.copy() for msks in original_step_data["masks"]
]
num_alternatives = len(current_step_messages_orig)
if num_alternatives == 0:
filtered_trajectory.append(original_step_data)
continue
max_initial_tokens = max(
len(alt_tokens) for alt_tokens in current_step_tokens_orig
)
# Initial token calculation from original data to see if truncation is needed
# Ensure tokens are lists of integers before calling len
max_initial_tokens = 0
if original_step_data["tokens"]:
max_initial_tokens = max(
len(alt_tokens) for alt_tokens in original_step_data["tokens"] if isinstance(alt_tokens, list)
) if any(isinstance(alt_tokens, list) for alt_tokens in original_step_data["tokens"]) else 0
if max_initial_tokens <= self.config.max_trajectory_tokens:
filtered_trajectory.append(original_step_data)
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} compliant. "
f"Max tokens: {max_initial_tokens}"
)
continue
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) exceeds limit "
f"({self.config.max_trajectory_tokens}). Attempting uniform truncation."
f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) "
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
)
working_messages = [msgs.copy() for msgs in current_step_messages_orig]
working_tokens = [tkns.copy() for tkns in current_step_tokens_orig]
working_masks = [msks.copy() for msks in current_step_masks_orig]
# Prepare working copies for modification
# Ensure deep copies for lists of dicts if dicts are modified, but here we pop from list of dicts.
working_messages = [msgs_list.copy() for msgs_list in original_step_data["messages"] or []]
working_tokens = [tkns_list.copy() for tkns_list in original_step_data["tokens"] or []]
working_masks = [msks_list.copy() for msks_list in original_step_data["masks"] or []]
max_current_tokens = max_initial_tokens
num_alternatives = len(working_messages)
step_successfully_truncated = False
while True:
num_messages_to_pop_per_alt = [0] * num_alternatives
can_truncate_globally = True
if num_alternatives == 0: # Should not happen if initial checks passed
logger.warning(f"[_ensure_trajectory_token_limit] Step {step_idx} has no alternatives after copying. Skipping.")
continue
retokenization_error_this_step = False
while max_current_tokens > self.config.max_trajectory_tokens:
target_pop_counts_per_alt = []
for alt_idx in range(num_alternatives):
alt_msg_list = working_messages[alt_idx]
min_len_to_preserve = 1
if (
len(alt_msg_list) > 0
and alt_msg_list[-1]["role"] in UNMASKED_ROLES
):
min_len_to_preserve += 1
if (
len(alt_msg_list) > 1
and alt_msg_list[-2]["role"] == "environment"
):
min_len_to_preserve += 1
# Calculate how many initial messages (after system prompt) can be popped.
# Preserving: system prompt (index 0), last agent response, and its preceding env observation.
num_preserved_at_end = 0
if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in UNMASKED_ROLES:
num_preserved_at_end = 1 # Last agent response
if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment":
num_preserved_at_end = 2 # Agent response + preceding env observation
# Number of messages available for popping (between system prompt and preserved end messages)
# Subtract 1 for the system prompt itself (which is never popped from index 0).
available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
if len(alt_msg_list) <= min_len_to_preserve:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
if (
len(alt_msg_list) > 2
and alt_msg_list[1]["role"] == "environment"
and alt_msg_list[2]["role"] in UNMASKED_ROLES
):
if (len(alt_msg_list) - 2) < min_len_to_preserve:
if (len(alt_msg_list) - 1) < min_len_to_preserve:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
else:
num_messages_to_pop_per_alt[alt_idx] = 1
else:
num_messages_to_pop_per_alt[alt_idx] = 2
elif len(alt_msg_list) > 1:
if (len(alt_msg_list) - 1) < min_len_to_preserve:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
else:
num_messages_to_pop_per_alt[alt_idx] = 1
if available_to_pop <= 0:
target_pop_counts_per_alt.append(0)
else:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
# Try to pop a pair (environment, agent) if they are at list[1] and list[2]
can_pop_pair = (
available_to_pop >= 2 and
len(alt_msg_list) > 2 and # Ensure messages at index 1 and 2 exist
alt_msg_list[1]["role"] == "environment" and
alt_msg_list[2]["role"] in UNMASKED_ROLES
)
if can_pop_pair:
target_pop_counts_per_alt.append(2)
else: # Can pop at least 1 since available_to_pop > 0
target_pop_counts_per_alt.append(1)
positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
if not positive_pop_counts:
break # No alternative can be truncated further
min_pop_this_round = min(positive_pop_counts)
if not can_truncate_globally:
break
min_pop_count = float("inf")
for count in num_messages_to_pop_per_alt:
if count > 0:
min_pop_count = min(min_pop_count, count)
if min_pop_count == float("inf") or min_pop_count == 0:
break
successfully_retokenized_all = True
new_alt_tokens_list = []
new_alt_masks_list = []
# Pop messages and re-tokenize
temp_new_alt_tokens = []
temp_new_alt_masks = []
max_tokens_after_this_trunc = 0
for alt_idx in range(num_alternatives):
for _ in range(min_pop_count):
if len(working_messages[alt_idx]) > 1:
for _ in range(min_pop_this_round):
if len(working_messages[alt_idx]) > 1: # Ensure there's something to pop after system
working_messages[alt_idx].pop(1)
else:
logger.error(
f"[_ensure_trajectory_token_limit] Critical error during "
f"pop for alt {alt_idx}, step {step_idx}."
f"[_ensure_trajectory_token_limit] Critical error during pop for "
f"alt {alt_idx}, step {step_idx}. List too short."
)
successfully_retokenized_all = False
break
if not successfully_retokenized_all:
break
retokenization_error_this_step = True; break
if retokenization_error_this_step: break
try:
tokenized_alt = tokenize_for_trainer(
self.tokenizer, working_messages[alt_idx]
)
new_alt_tokens_list.append(tokenized_alt["tokens"])
new_alt_masks_list.append(tokenized_alt["masks"])
max_tokens_after_this_trunc = max(
max_tokens_after_this_trunc, len(tokenized_alt["tokens"])
)
tokenized_alt = tokenize_for_trainer(self.tokenizer, working_messages[alt_idx])
temp_new_alt_tokens.append(tokenized_alt["tokens"])
temp_new_alt_masks.append(tokenized_alt["masks"])
max_tokens_after_this_trunc = max(max_tokens_after_this_trunc, len(tokenized_alt["tokens"]))
except Exception as e:
logger.error(
f"[_ensure_trajectory_token_limit] Error re-tokenizing "
f"alt {alt_idx} in step {step_idx} after truncation: {e}"
f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} "
f"in step {step_idx} after truncation: {e}"
)
successfully_retokenized_all = False
break
retokenization_error_this_step = True; break
if retokenization_error_this_step: break
if not successfully_retokenized_all:
step_successfully_truncated = False
break
working_tokens = new_alt_tokens_list
working_masks = new_alt_masks_list
working_tokens = temp_new_alt_tokens
working_masks = temp_new_alt_masks
max_current_tokens = max_tokens_after_this_trunc
logger.debug(
f"[_ensure_trajectory_token_limit] Step {step_idx}, after "
f"uniform pop of {min_pop_count}, max tokens: {max_current_tokens}"
f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_this_round}, "
f"max tokens: {max_current_tokens}"
)
# End of while loop for truncation attempts
if max_current_tokens <= self.config.max_trajectory_tokens:
step_successfully_truncated = True
break
if step_successfully_truncated:
updated_step_data = original_step_data.copy()
updated_step_data["messages"] = working_messages
updated_step_data["tokens"] = working_tokens
updated_step_data["masks"] = working_masks
if not retokenization_error_this_step and max_current_tokens <= self.config.max_trajectory_tokens:
updated_step_data: BlackjackScoredDataGroup = {
"seed": original_step_data["seed"],
"messages": working_messages,
"tokens": working_tokens,
"masks": working_masks,
"scores": original_step_data.get("scores"),
"parsed_action": original_step_data.get("parsed_action")
}
filtered_trajectory.append(updated_step_data)
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} successfully truncated. "
f"[_ensure_trajectory_token_limit] Step {step_idx} successfully processed. "
f"Final max tokens: {max_current_tokens}"
)
else:
if max_current_tokens > self.config.max_trajectory_tokens:
logger.warning(
f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
f"Max tokens ({max_current_tokens}) still exceed limit "
f"({self.config.max_trajectory_tokens}) after maximum possible "
f"uniform truncation or re-tokenization error."
)
logger.warning(
f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) "
f"or retokenization error occurred ({retokenization_error_this_step})."
)
if len(filtered_trajectory) < len(trajectory):
logger.warning(
f"[_ensure_trajectory_token_limit] Filtered out "
f"{len(trajectory) - len(filtered_trajectory)} steps "
f"due to token limit constraints. Original trajectory length: {len(trajectory)}, "
f"Filtered: {len(filtered_trajectory)}"
f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}"
)
return filtered_trajectory