diff --git a/environments/game_environments/gymnasium/blackjack_env.py b/environments/game_environments/gymnasium/blackjack_env.py index efe857d1..8db0dc21 100644 --- a/environments/game_environments/gymnasium/blackjack_env.py +++ b/environments/game_environments/gymnasium/blackjack_env.py @@ -968,193 +968,145 @@ class BlackjackEnv(BaseEnv): filtered_trajectory: List[BlackjackScoredDataGroup] = [] for step_idx, original_step_data in enumerate(trajectory): - logger.info( - f"[_ensure_trajectory_token_limit] Step {step_idx} has " - f"{len(original_step_data['messages'])} alternatives." - ) - if ( - not original_step_data.get("messages") - or not original_step_data.get("tokens") - or not original_step_data.get("masks") + if not ( + original_step_data.get("messages") + and original_step_data.get("tokens") + and original_step_data.get("masks") + and original_step_data.get("seed") is not None + and original_step_data.get("parsed_actions") is not None # Specific to MC version ): logger.warning( - f"[_ensure_trajectory_token_limit] Step {step_idx} " - f"is missing messages, tokens, or masks. Skipping." + f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env " + f"is missing critical data. Skipping." ) continue - current_step_messages_orig = [ - msgs.copy() for msgs in original_step_data["messages"] - ] - current_step_tokens_orig = [ - tkns.copy() for tkns in original_step_data["tokens"] - ] - current_step_masks_orig = [ - msks.copy() for msks in original_step_data["masks"] - ] - - num_alternatives = len(current_step_messages_orig) - if num_alternatives == 0: - filtered_trajectory.append(original_step_data) - continue - - max_initial_tokens = max( - len(alt_tokens) for alt_tokens in current_step_tokens_orig - ) + # Initial token calculation from original data + max_initial_tokens = 0 + if original_step_data["tokens"]: + max_initial_tokens = max( + len(alt_tokens) for alt_tokens in original_step_data["tokens"] if isinstance(alt_tokens, list) + ) if any(isinstance(alt_tokens, list) for alt_tokens in original_step_data["tokens"]) else 0 if max_initial_tokens <= self.config.max_trajectory_tokens: filtered_trajectory.append(original_step_data) + logger.info( + f"[_ensure_trajectory_token_limit] Step {step_idx} compliant in MC env. " + f"Max tokens: {max_initial_tokens}" + ) continue logger.info( - f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) " - f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting uniform truncation." + f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env (max tokens: {max_initial_tokens}) " + f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation." ) - working_messages = [msgs.copy() for msgs in current_step_messages_orig] - working_tokens = [tkns.copy() for tkns in current_step_tokens_orig] - working_masks = [msks.copy() for msks in current_step_masks_orig] + working_messages = [msgs_list.copy() for msgs_list in original_step_data["messages"] or []] + working_tokens = [tkns_list.copy() for tkns_list in original_step_data["tokens"] or []] + working_masks = [msks_list.copy() for msks_list in original_step_data["masks"] or []] max_current_tokens = max_initial_tokens + num_alternatives = len(working_messages) - step_successfully_truncated = False - while True: - num_messages_to_pop_per_alt = [0] * num_alternatives - can_truncate_globally = True + if num_alternatives == 0: + logger.warning(f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives after copying. Skipping.") + continue + retokenization_error_this_step = False + while max_current_tokens > self.config.max_trajectory_tokens: + target_pop_counts_per_alt = [] for alt_idx in range(num_alternatives): alt_msg_list = working_messages[alt_idx] + num_preserved_at_end = 0 + if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in ["agent", "assistant"] + UNMASKED_ROLES: + num_preserved_at_end = 1 + if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment": + num_preserved_at_end = 2 + + available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end - min_len_to_preserve = 1 - if len(alt_msg_list) > 0 and alt_msg_list[-1]["role"] in ["agent"]: - min_len_to_preserve += 1 - if ( - len(alt_msg_list) > 1 - and alt_msg_list[-2]["role"] == "environment" - ): - min_len_to_preserve += 1 - - if len(alt_msg_list) <= min_len_to_preserve: - num_messages_to_pop_per_alt[alt_idx] = 0 - can_truncate_globally = False - break - - if ( - len(alt_msg_list) > 2 - and alt_msg_list[1]["role"] == "environment" - and alt_msg_list[2]["role"] == "agent" - ): - if (len(alt_msg_list) - 2) < min_len_to_preserve: - if (len(alt_msg_list) - 1) < min_len_to_preserve: - num_messages_to_pop_per_alt[alt_idx] = 0 - can_truncate_globally = False - break - else: - num_messages_to_pop_per_alt[alt_idx] = 1 - else: - num_messages_to_pop_per_alt[alt_idx] = 2 - elif len(alt_msg_list) > 1: - if (len(alt_msg_list) - 1) < min_len_to_preserve: - num_messages_to_pop_per_alt[alt_idx] = 0 - can_truncate_globally = False - break - else: - num_messages_to_pop_per_alt[alt_idx] = 1 + if available_to_pop <= 0: + target_pop_counts_per_alt.append(0) else: - num_messages_to_pop_per_alt[alt_idx] = 0 - can_truncate_globally = False - break - - if not can_truncate_globally: - break - - min_pop_count = float("inf") - for count in num_messages_to_pop_per_alt: - if count > 0: - min_pop_count = min(min_pop_count, count) - - if min_pop_count == float("inf") or min_pop_count == 0: - break - - successfully_retokenized_all = True - new_alt_tokens_list = [] - new_alt_masks_list = [] + can_pop_pair = ( + available_to_pop >= 2 and + len(alt_msg_list) > 2 and + alt_msg_list[1]["role"] == "environment" and + alt_msg_list[2]["role"] in ["agent", "assistant"] + UNMASKED_ROLES + ) + if can_pop_pair: + target_pop_counts_per_alt.append(2) + else: + target_pop_counts_per_alt.append(1) + + positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0] + if not positive_pop_counts: + break + + min_pop_this_round = min(positive_pop_counts) + temp_new_alt_tokens = [] + temp_new_alt_masks = [] max_tokens_after_this_trunc = 0 for alt_idx in range(num_alternatives): - for _ in range(min_pop_count): + for _ in range(min_pop_this_round): if len(working_messages[alt_idx]) > 1: working_messages[alt_idx].pop(1) else: logger.error( - f"[_ensure_trajectory_token_limit] Critical error during pop for " - f"alt {alt_idx}, step {step_idx}." + f"[_ensure_trajectory_token_limit] MC env: Critical error during pop for " + f"alt {alt_idx}, step {step_idx}. List too short." ) - successfully_retokenized_all = False - break - if not successfully_retokenized_all: - break - + retokenization_error_this_step = True; break + if retokenization_error_this_step: break + try: - tokenized_alt = tokenize_for_trainer( - self.tokenizer, working_messages[alt_idx] - ) - new_alt_tokens_list.append(tokenized_alt["tokens"]) - new_alt_masks_list.append(tokenized_alt["masks"]) - max_tokens_after_this_trunc = max( - max_tokens_after_this_trunc, len(tokenized_alt["tokens"]) - ) + tokenized_alt = tokenize_for_trainer(self.tokenizer, working_messages[alt_idx]) + temp_new_alt_tokens.append(tokenized_alt["tokens"]) + temp_new_alt_masks.append(tokenized_alt["masks"]) + max_tokens_after_this_trunc = max(max_tokens_after_this_trunc, len(tokenized_alt["tokens"])) except Exception as e: logger.error( - f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} " + f"[_ensure_trajectory_token_limit] MC env: Error re-tokenizing alt {alt_idx} " f"in step {step_idx} after truncation: {e}" ) - successfully_retokenized_all = False - break + retokenization_error_this_step = True; break + + if retokenization_error_this_step: break - if not successfully_retokenized_all: - step_successfully_truncated = False - break - - working_tokens = new_alt_tokens_list - working_masks = new_alt_masks_list + working_tokens = temp_new_alt_tokens + working_masks = temp_new_alt_masks max_current_tokens = max_tokens_after_this_trunc logger.debug( - f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_count}, " + f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, after uniform pop of {min_pop_this_round}, " f"max tokens: {max_current_tokens}" ) - if max_current_tokens <= self.config.max_trajectory_tokens: - step_successfully_truncated = True - break - - if step_successfully_truncated: - updated_step_data = BlackjackScoredDataGroup( - seed=original_step_data["seed"], - messages=working_messages, - tokens=working_tokens, - masks=working_masks, - scores=original_step_data["scores"], - parsed_actions=original_step_data["parsed_actions"], - ) + if not retokenization_error_this_step and max_current_tokens <= self.config.max_trajectory_tokens: + updated_step_data: BlackjackScoredDataGroup = { + "seed": original_step_data["seed"], + "messages": working_messages, + "tokens": working_tokens, + "masks": working_masks, + "scores": original_step_data.get("scores"), + "parsed_actions": original_step_data.get("parsed_actions") # MC version specific + } filtered_trajectory.append(updated_step_data) logger.info( - f"[_ensure_trajectory_token_limit] Step {step_idx} successfully truncated. " + f"[_ensure_trajectory_token_limit] MC env: Step {step_idx} successfully processed. " f"Final max tokens: {max_current_tokens}" ) else: - if max_current_tokens > self.config.max_trajectory_tokens: - logger.warning( - f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. " - f"Max tokens ({max_current_tokens}) still exceed limit " - f"({self.config.max_trajectory_tokens}) after maximum possible " - f"uniform truncation or re-tokenization error." - ) + logger.warning( + f"[_ensure_trajectory_token_limit] MC env: Discarding step {step_idx}. " + f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) " + f"or retokenization error occurred ({retokenization_error_this_step})." + ) if len(filtered_trajectory) < len(trajectory): logger.warning( - f"[_ensure_trajectory_token_limit] Filtered out {len(trajectory) - len(filtered_trajectory)} steps " - f"due to token limit constraints. Original trajectory length: {len(trajectory)}, " - f"Filtered: {len(filtered_trajectory)}" + f"[_ensure_trajectory_token_limit] MC env: Filtered out " + f"{len(trajectory) - len(filtered_trajectory)} steps " + f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}" ) return filtered_trajectory diff --git a/environments/game_environments/gymnasium/blackjack_local_server.py b/environments/game_environments/gymnasium/blackjack_local_server.py index 432b1184..7a646c47 100644 --- a/environments/game_environments/gymnasium/blackjack_local_server.py +++ b/environments/game_environments/gymnasium/blackjack_local_server.py @@ -5,7 +5,8 @@ import os import random from dotenv import load_dotenv -from environments.game_environments.gymnasium.blackjack_env import BlackjackEnv +from environments.game_environments.gymnasium.blackjack_env import BlackjackEnv, BlackjackEnvConfig +from atroposlib.envs.base import OpenaiConfig, EvalHandlingEnum load_dotenv() @@ -13,51 +14,72 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def parse_arguments(): - parser = argparse.ArgumentParser(description="Blackjack environment local server") - parser.add_argument( - "--config", - type=str, - default="blackjack_local", - help="Configuration file name (without .yaml extension, relative to " - "envs/gymnasium/configs), or full path to a YAML file.", - ) - return parser.parse_args() +# def parse_arguments(): # Removed +# parser = argparse.ArgumentParser(description="Blackjack environment local server") +# parser.add_argument( +# "--config", +# type=str, +# default="blackjack_local", +# help="Configuration file name (without .yaml extension, relative to " +# "envs/gymnasium/configs), or full path to a YAML file.", +# ) +# return parser.parse_args() async def main(): - logger.info("Starting Blackjack environment server") + logger.info("Starting Blackjack environment local debug runner") - args = parse_arguments() + # args = parse_arguments() # Removed - # Determine the config name/path for config_init - # config_init expects the name relative to its own configs dir, or an absolute path - config_input = args.config - if not os.path.isabs(config_input) and not config_input.endswith(".yaml"): - # Assume it's a name relative to the blackjack env's config dir - config_name_or_path = config_input - logger.info(f"Using relative config name: {config_name_or_path}") - else: - # It's likely an absolute path or path relative to cwd - config_name_or_path = os.path.abspath(config_input) - logger.info(f"Using absolute config path: {config_name_or_path}") + # Removed logic for config_name_or_path and BlackjackEnv.config_init - # Use the environment's config_init method to load configurations - try: - config, server_configs = BlackjackEnv.config_init(config_name_or_path) - logger.info("Configuration loaded successfully via BlackjackEnv.config_init") - logger.debug(f"Loaded Env Config: {config}") - logger.debug(f"Loaded Server Configs: {server_configs}") - except Exception as e: - logger.exception( - f"Failed to load configuration using BlackjackEnv.config_init: {e}" + # Create hardcoded configurations for local debugging + env_config = BlackjackEnvConfig( + # BaseEnvConfig fields, tailored for debug + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=1, # Debug single generation path + use_wandb=False, + wandb_name="blackjack_local_debug", # Explicitly set for debug + max_num_workers=1, + rollout_server_url="http://localhost:8000", # Standard default + total_steps=1, + batch_size=1, # Consistent with 1 step, 1 worker, group_size 1 + steps_per_eval=0, # No eval steps needed + max_token_length=1024 * 4, # Reduced for faster local debugging if necessary + inference_weight=1.0, + data_path_to_save_groups=None, + eval_handling=EvalHandlingEnum.NONE, # No evaluation in this script + eval_limit_ratio=0.0, + + # BlackjackEnvConfig specific fields (from blackjack_env.py's definition or defaults) + env_name="Blackjack-v1", + temperature=0.2, # Lower temperature for more deterministic debug output + top_p=0.9, # Standard default + max_turns=5, # Standard default + thinking_active=True, + eval_episodes=0, # No evaluation episodes + max_think_chars_history=3000, + max_trajectory_tokens=24576, + debug_mode=True, # Enable debug logging from the environment + mc_samples=1, # With group_size=1, this means 1 MC rollout for V(s) + ) + + server_configs = [ + OpenaiConfig( + model_name="gpt-4.1-mini", # Ensure this is locally available if not mocked + base_url="https://api.openai.com/v1", # Explicitly set OpenAI base URL + api_key=os.getenv("OPENAI_API_KEY"), # Use env var or default + num_requests_for_eval=0, # No eval requests ) - return # Cannot proceed without config + ] + logger.info("Using hardcoded debug configuration.") + logger.debug(f"Env Config: {env_config}") + logger.debug(f"Server Configs: {server_configs}") # Create and set up the environment using the loaded configs try: env = BlackjackEnv( - config=config, + config=env_config, server_configs=server_configs, slurm=False, # Explicitly false for local testing ) diff --git a/environments/game_environments/gymnasium/blackjack_no_mc_env.py b/environments/game_environments/gymnasium/blackjack_no_mc_env.py index e6dca5c1..90a2c3fb 100644 --- a/environments/game_environments/gymnasium/blackjack_no_mc_env.py +++ b/environments/game_environments/gymnasium/blackjack_no_mc_env.py @@ -1513,199 +1513,156 @@ class BlackjackEnv(BaseEnv): filtered_trajectory: List[BlackjackScoredDataGroup] = [] for step_idx, original_step_data in enumerate(trajectory): - logger.info( - f"[_ensure_trajectory_token_limit] Step {step_idx} has {len(original_step_data['messages'])} messages." - ) - if ( - not original_step_data.get("messages") - or not original_step_data.get("tokens") - or not original_step_data.get("masks") + if not ( + original_step_data.get("messages") + and original_step_data.get("tokens") + and original_step_data.get("masks") + and original_step_data.get("seed") is not None # seed is mandatory for new group ): logger.warning( - f"[_ensure_trajectory_token_limit] Step {step_idx} is missing messages, tokens, or masks. Skipping." + f"[_ensure_trajectory_token_limit] Step {step_idx} " + f"is missing critical data (messages, tokens, masks, or seed). Skipping." ) continue - current_step_messages_orig = [ - msgs.copy() for msgs in original_step_data["messages"] - ] - for alt_idx, alt_msgs in enumerate(current_step_messages_orig): - logger.info( - f"[_ensure_trajectory_token_limit] Step {step_idx}, alt {alt_idx} has {len(alt_msgs)} messages." - ) - for msg_idx, msg in enumerate(alt_msgs): - logger.info( - f"[_ensure_trajectory_token_limit] Step {step_idx}, alt {alt_idx}, msg {msg['content']}" - ) - current_step_tokens_orig = [ - tkns.copy() for tkns in original_step_data["tokens"] - ] - current_step_masks_orig = [ - msks.copy() for msks in original_step_data["masks"] - ] - - num_alternatives = len(current_step_messages_orig) - if num_alternatives == 0: - filtered_trajectory.append(original_step_data) - continue - - max_initial_tokens = max( - len(alt_tokens) for alt_tokens in current_step_tokens_orig - ) + # Initial token calculation from original data to see if truncation is needed + # Ensure tokens are lists of integers before calling len + max_initial_tokens = 0 + if original_step_data["tokens"]: + max_initial_tokens = max( + len(alt_tokens) for alt_tokens in original_step_data["tokens"] if isinstance(alt_tokens, list) + ) if any(isinstance(alt_tokens, list) for alt_tokens in original_step_data["tokens"]) else 0 if max_initial_tokens <= self.config.max_trajectory_tokens: filtered_trajectory.append(original_step_data) + logger.info( + f"[_ensure_trajectory_token_limit] Step {step_idx} compliant. " + f"Max tokens: {max_initial_tokens}" + ) continue logger.info( - f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) exceeds limit " - f"({self.config.max_trajectory_tokens}). Attempting uniform truncation." + f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) " + f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation." ) - working_messages = [msgs.copy() for msgs in current_step_messages_orig] - working_tokens = [tkns.copy() for tkns in current_step_tokens_orig] - working_masks = [msks.copy() for msks in current_step_masks_orig] + # Prepare working copies for modification + # Ensure deep copies for lists of dicts if dicts are modified, but here we pop from list of dicts. + working_messages = [msgs_list.copy() for msgs_list in original_step_data["messages"] or []] + working_tokens = [tkns_list.copy() for tkns_list in original_step_data["tokens"] or []] + working_masks = [msks_list.copy() for msks_list in original_step_data["masks"] or []] max_current_tokens = max_initial_tokens + num_alternatives = len(working_messages) - step_successfully_truncated = False - while True: - num_messages_to_pop_per_alt = [0] * num_alternatives - can_truncate_globally = True + if num_alternatives == 0: # Should not happen if initial checks passed + logger.warning(f"[_ensure_trajectory_token_limit] Step {step_idx} has no alternatives after copying. Skipping.") + continue + retokenization_error_this_step = False + while max_current_tokens > self.config.max_trajectory_tokens: + target_pop_counts_per_alt = [] for alt_idx in range(num_alternatives): alt_msg_list = working_messages[alt_idx] - min_len_to_preserve = 1 - if ( - len(alt_msg_list) > 0 - and alt_msg_list[-1]["role"] in UNMASKED_ROLES - ): - min_len_to_preserve += 1 - if ( - len(alt_msg_list) > 1 - and alt_msg_list[-2]["role"] == "environment" - ): - min_len_to_preserve += 1 + # Calculate how many initial messages (after system prompt) can be popped. + # Preserving: system prompt (index 0), last agent response, and its preceding env observation. + num_preserved_at_end = 0 + if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in UNMASKED_ROLES: + num_preserved_at_end = 1 # Last agent response + if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment": + num_preserved_at_end = 2 # Agent response + preceding env observation + + # Number of messages available for popping (between system prompt and preserved end messages) + # Subtract 1 for the system prompt itself (which is never popped from index 0). + available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end - if len(alt_msg_list) <= min_len_to_preserve: - num_messages_to_pop_per_alt[alt_idx] = 0 - can_truncate_globally = False - break - - if ( - len(alt_msg_list) > 2 - and alt_msg_list[1]["role"] == "environment" - and alt_msg_list[2]["role"] in UNMASKED_ROLES - ): - if (len(alt_msg_list) - 2) < min_len_to_preserve: - if (len(alt_msg_list) - 1) < min_len_to_preserve: - num_messages_to_pop_per_alt[alt_idx] = 0 - can_truncate_globally = False - break - else: - num_messages_to_pop_per_alt[alt_idx] = 1 - else: - num_messages_to_pop_per_alt[alt_idx] = 2 - elif len(alt_msg_list) > 1: - if (len(alt_msg_list) - 1) < min_len_to_preserve: - num_messages_to_pop_per_alt[alt_idx] = 0 - can_truncate_globally = False - break - else: - num_messages_to_pop_per_alt[alt_idx] = 1 + if available_to_pop <= 0: + target_pop_counts_per_alt.append(0) else: - num_messages_to_pop_per_alt[alt_idx] = 0 - can_truncate_globally = False - break + # Try to pop a pair (environment, agent) if they are at list[1] and list[2] + can_pop_pair = ( + available_to_pop >= 2 and + len(alt_msg_list) > 2 and # Ensure messages at index 1 and 2 exist + alt_msg_list[1]["role"] == "environment" and + alt_msg_list[2]["role"] in UNMASKED_ROLES + ) + if can_pop_pair: + target_pop_counts_per_alt.append(2) + else: # Can pop at least 1 since available_to_pop > 0 + target_pop_counts_per_alt.append(1) + + positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0] + if not positive_pop_counts: + break # No alternative can be truncated further + + min_pop_this_round = min(positive_pop_counts) - if not can_truncate_globally: - break - - min_pop_count = float("inf") - for count in num_messages_to_pop_per_alt: - if count > 0: - min_pop_count = min(min_pop_count, count) - - if min_pop_count == float("inf") or min_pop_count == 0: - break - - successfully_retokenized_all = True - new_alt_tokens_list = [] - new_alt_masks_list = [] + # Pop messages and re-tokenize + temp_new_alt_tokens = [] + temp_new_alt_masks = [] max_tokens_after_this_trunc = 0 for alt_idx in range(num_alternatives): - for _ in range(min_pop_count): - if len(working_messages[alt_idx]) > 1: + for _ in range(min_pop_this_round): + if len(working_messages[alt_idx]) > 1: # Ensure there's something to pop after system working_messages[alt_idx].pop(1) else: logger.error( - f"[_ensure_trajectory_token_limit] Critical error during " - f"pop for alt {alt_idx}, step {step_idx}." + f"[_ensure_trajectory_token_limit] Critical error during pop for " + f"alt {alt_idx}, step {step_idx}. List too short." ) - successfully_retokenized_all = False - break - if not successfully_retokenized_all: - break - + retokenization_error_this_step = True; break + if retokenization_error_this_step: break + try: - tokenized_alt = tokenize_for_trainer( - self.tokenizer, working_messages[alt_idx] - ) - new_alt_tokens_list.append(tokenized_alt["tokens"]) - new_alt_masks_list.append(tokenized_alt["masks"]) - max_tokens_after_this_trunc = max( - max_tokens_after_this_trunc, len(tokenized_alt["tokens"]) - ) + tokenized_alt = tokenize_for_trainer(self.tokenizer, working_messages[alt_idx]) + temp_new_alt_tokens.append(tokenized_alt["tokens"]) + temp_new_alt_masks.append(tokenized_alt["masks"]) + max_tokens_after_this_trunc = max(max_tokens_after_this_trunc, len(tokenized_alt["tokens"])) except Exception as e: logger.error( - f"[_ensure_trajectory_token_limit] Error re-tokenizing " - f"alt {alt_idx} in step {step_idx} after truncation: {e}" + f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} " + f"in step {step_idx} after truncation: {e}" ) - successfully_retokenized_all = False - break + retokenization_error_this_step = True; break + + if retokenization_error_this_step: break - if not successfully_retokenized_all: - step_successfully_truncated = False - break - - working_tokens = new_alt_tokens_list - working_masks = new_alt_masks_list + working_tokens = temp_new_alt_tokens + working_masks = temp_new_alt_masks max_current_tokens = max_tokens_after_this_trunc logger.debug( - f"[_ensure_trajectory_token_limit] Step {step_idx}, after " - f"uniform pop of {min_pop_count}, max tokens: {max_current_tokens}" + f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_this_round}, " + f"max tokens: {max_current_tokens}" ) + # End of while loop for truncation attempts - if max_current_tokens <= self.config.max_trajectory_tokens: - step_successfully_truncated = True - break - - if step_successfully_truncated: - updated_step_data = original_step_data.copy() - updated_step_data["messages"] = working_messages - updated_step_data["tokens"] = working_tokens - updated_step_data["masks"] = working_masks + if not retokenization_error_this_step and max_current_tokens <= self.config.max_trajectory_tokens: + updated_step_data: BlackjackScoredDataGroup = { + "seed": original_step_data["seed"], + "messages": working_messages, + "tokens": working_tokens, + "masks": working_masks, + "scores": original_step_data.get("scores"), + "parsed_action": original_step_data.get("parsed_action") + } filtered_trajectory.append(updated_step_data) logger.info( - f"[_ensure_trajectory_token_limit] Step {step_idx} successfully truncated. " + f"[_ensure_trajectory_token_limit] Step {step_idx} successfully processed. " f"Final max tokens: {max_current_tokens}" ) else: - if max_current_tokens > self.config.max_trajectory_tokens: - logger.warning( - f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. " - f"Max tokens ({max_current_tokens}) still exceed limit " - f"({self.config.max_trajectory_tokens}) after maximum possible " - f"uniform truncation or re-tokenization error." - ) + logger.warning( + f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. " + f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) " + f"or retokenization error occurred ({retokenization_error_this_step})." + ) if len(filtered_trajectory) < len(trajectory): logger.warning( f"[_ensure_trajectory_token_limit] Filtered out " f"{len(trajectory) - len(filtered_trajectory)} steps " - f"due to token limit constraints. Original trajectory length: {len(trajectory)}, " - f"Filtered: {len(filtered_trajectory)}" + f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}" ) return filtered_trajectory