mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00

update local server

This commit is contained in:
parent c506bb147e
commit ba604d44f9

3 changed files with 251 additions and 320 deletions
@@ -968,193 +968,145 @@ class BlackjackEnv(BaseEnv):
         filtered_trajectory: List[BlackjackScoredDataGroup] = []

         for step_idx, original_step_data in enumerate(trajectory):
             logger.info(
                 f"[_ensure_trajectory_token_limit] Step {step_idx} has "
                 f"{len(original_step_data['messages'])} alternatives."
             )
-            if (
-                not original_step_data.get("messages")
-                or not original_step_data.get("tokens")
-                or not original_step_data.get("masks")
+            if not (
+                original_step_data.get("messages")
+                and original_step_data.get("tokens")
+                and original_step_data.get("masks")
+                and original_step_data.get("seed") is not None
+                and original_step_data.get("parsed_actions") is not None  # Specific to MC version
             ):
                 logger.warning(
-                    f"[_ensure_trajectory_token_limit] Step {step_idx} "
-                    f"is missing messages, tokens, or masks. Skipping."
+                    f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env "
+                    f"is missing critical data. Skipping."
                 )
                 continue

-            current_step_messages_orig = [
-                msgs.copy() for msgs in original_step_data["messages"]
-            ]
-            current_step_tokens_orig = [
-                tkns.copy() for tkns in original_step_data["tokens"]
-            ]
-            current_step_masks_orig = [
-                msks.copy() for msks in original_step_data["masks"]
-            ]
-
-            num_alternatives = len(current_step_messages_orig)
-            if num_alternatives == 0:
-                filtered_trajectory.append(original_step_data)
-                continue
-
-            max_initial_tokens = max(
-                len(alt_tokens) for alt_tokens in current_step_tokens_orig
-            )
+            # Initial token calculation from original data
+            max_initial_tokens = 0
+            if original_step_data["tokens"]:
+                max_initial_tokens = max(
+                    len(alt_tokens) for alt_tokens in original_step_data["tokens"] if isinstance(alt_tokens, list)
+                ) if any(isinstance(alt_tokens, list) for alt_tokens in original_step_data["tokens"]) else 0

             if max_initial_tokens <= self.config.max_trajectory_tokens:
                 filtered_trajectory.append(original_step_data)
                 logger.info(
                     f"[_ensure_trajectory_token_limit] Step {step_idx} compliant in MC env. "
                     f"Max tokens: {max_initial_tokens}"
                 )
                 continue

             logger.info(
-                f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) "
-                f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting uniform truncation."
+                f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env (max tokens: {max_initial_tokens}) "
+                f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
             )

-            working_messages = [msgs.copy() for msgs in current_step_messages_orig]
-            working_tokens = [tkns.copy() for tkns in current_step_tokens_orig]
-            working_masks = [msks.copy() for msks in current_step_masks_orig]
+            working_messages = [msgs_list.copy() for msgs_list in original_step_data["messages"] or []]
+            working_tokens = [tkns_list.copy() for tkns_list in original_step_data["tokens"] or []]
+            working_masks = [msks_list.copy() for msks_list in original_step_data["masks"] or []]
             max_current_tokens = max_initial_tokens

-            step_successfully_truncated = False
-            while True:
-                num_messages_to_pop_per_alt = [0] * num_alternatives
-                can_truncate_globally = True
+            num_alternatives = len(working_messages)
+            if num_alternatives == 0:
+                logger.warning(f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives after copying. Skipping.")
+                continue
+
+            retokenization_error_this_step = False
+            while max_current_tokens > self.config.max_trajectory_tokens:
+                target_pop_counts_per_alt = []
                 for alt_idx in range(num_alternatives):
                     alt_msg_list = working_messages[alt_idx]
-
-                    min_len_to_preserve = 1
-                    if len(alt_msg_list) > 0 and alt_msg_list[-1]["role"] in ["agent"]:
-                        min_len_to_preserve += 1
-                    if (
-                        len(alt_msg_list) > 1
-                        and alt_msg_list[-2]["role"] == "environment"
-                    ):
-                        min_len_to_preserve += 1
-
-                    if len(alt_msg_list) <= min_len_to_preserve:
-                        num_messages_to_pop_per_alt[alt_idx] = 0
-                        can_truncate_globally = False
-                        break
-
-                    if (
-                        len(alt_msg_list) > 2
-                        and alt_msg_list[1]["role"] == "environment"
-                        and alt_msg_list[2]["role"] == "agent"
-                    ):
-                        if (len(alt_msg_list) - 2) < min_len_to_preserve:
-                            if (len(alt_msg_list) - 1) < min_len_to_preserve:
-                                num_messages_to_pop_per_alt[alt_idx] = 0
-                                can_truncate_globally = False
-                                break
-                            else:
-                                num_messages_to_pop_per_alt[alt_idx] = 1
-                        else:
-                            num_messages_to_pop_per_alt[alt_idx] = 2
-                    elif len(alt_msg_list) > 1:
-                        if (len(alt_msg_list) - 1) < min_len_to_preserve:
-                            num_messages_to_pop_per_alt[alt_idx] = 0
-                            can_truncate_globally = False
-                            break
-                        else:
-                            num_messages_to_pop_per_alt[alt_idx] = 1
+                    num_preserved_at_end = 0
+                    if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in ["agent", "assistant"] + UNMASKED_ROLES:
+                        num_preserved_at_end = 1
+                        if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment":
+                            num_preserved_at_end = 2
+
+                    available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
+
+                    if available_to_pop <= 0:
+                        target_pop_counts_per_alt.append(0)
                     else:
-                        num_messages_to_pop_per_alt[alt_idx] = 0
-                        can_truncate_globally = False
-                        break
-
-                if not can_truncate_globally:
-                    break
-
-                min_pop_count = float("inf")
-                for count in num_messages_to_pop_per_alt:
-                    if count > 0:
-                        min_pop_count = min(min_pop_count, count)
-
-                if min_pop_count == float("inf") or min_pop_count == 0:
-                    break
-
-                successfully_retokenized_all = True
-                new_alt_tokens_list = []
-                new_alt_masks_list = []
+                        can_pop_pair = (
+                            available_to_pop >= 2 and
+                            len(alt_msg_list) > 2 and
+                            alt_msg_list[1]["role"] == "environment" and
+                            alt_msg_list[2]["role"] in ["agent", "assistant"] + UNMASKED_ROLES
+                        )
+                        if can_pop_pair:
+                            target_pop_counts_per_alt.append(2)
+                        else:
+                            target_pop_counts_per_alt.append(1)
+
+                positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
+                if not positive_pop_counts:
+                    break
+
+                min_pop_this_round = min(positive_pop_counts)
+                temp_new_alt_tokens = []
+                temp_new_alt_masks = []
                 max_tokens_after_this_trunc = 0

                 for alt_idx in range(num_alternatives):
-                    for _ in range(min_pop_count):
+                    for _ in range(min_pop_this_round):
                         if len(working_messages[alt_idx]) > 1:
                             working_messages[alt_idx].pop(1)
                         else:
                             logger.error(
-                                f"[_ensure_trajectory_token_limit] Critical error during pop for "
-                                f"alt {alt_idx}, step {step_idx}."
+                                f"[_ensure_trajectory_token_limit] MC env: Critical error during pop for "
+                                f"alt {alt_idx}, step {step_idx}. List too short."
                             )
-                            successfully_retokenized_all = False
-                            break
-                    if not successfully_retokenized_all:
-                        break
+                            retokenization_error_this_step = True; break
+                    if retokenization_error_this_step: break

                     try:
-                        tokenized_alt = tokenize_for_trainer(
-                            self.tokenizer, working_messages[alt_idx]
-                        )
-                        new_alt_tokens_list.append(tokenized_alt["tokens"])
-                        new_alt_masks_list.append(tokenized_alt["masks"])
-                        max_tokens_after_this_trunc = max(
-                            max_tokens_after_this_trunc, len(tokenized_alt["tokens"])
-                        )
+                        tokenized_alt = tokenize_for_trainer(self.tokenizer, working_messages[alt_idx])
+                        temp_new_alt_tokens.append(tokenized_alt["tokens"])
+                        temp_new_alt_masks.append(tokenized_alt["masks"])
+                        max_tokens_after_this_trunc = max(max_tokens_after_this_trunc, len(tokenized_alt["tokens"]))
                     except Exception as e:
                         logger.error(
-                            f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} "
+                            f"[_ensure_trajectory_token_limit] MC env: Error re-tokenizing alt {alt_idx} "
                             f"in step {step_idx} after truncation: {e}"
                         )
-                        successfully_retokenized_all = False
-                        break
-
-                if not successfully_retokenized_all:
-                    step_successfully_truncated = False
-                    break
-
-                working_tokens = new_alt_tokens_list
-                working_masks = new_alt_masks_list
+                        retokenization_error_this_step = True; break
+
+                if retokenization_error_this_step: break
+
+                working_tokens = temp_new_alt_tokens
+                working_masks = temp_new_alt_masks
                 max_current_tokens = max_tokens_after_this_trunc
                 logger.debug(
-                    f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_count}, "
+                    f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, after uniform pop of {min_pop_this_round}, "
                     f"max tokens: {max_current_tokens}"
                 )

-                if max_current_tokens <= self.config.max_trajectory_tokens:
-                    step_successfully_truncated = True
-                    break
-
-            if step_successfully_truncated:
-                updated_step_data = BlackjackScoredDataGroup(
-                    seed=original_step_data["seed"],
-                    messages=working_messages,
-                    tokens=working_tokens,
-                    masks=working_masks,
-                    scores=original_step_data["scores"],
-                    parsed_actions=original_step_data["parsed_actions"],
-                )
+            if not retokenization_error_this_step and max_current_tokens <= self.config.max_trajectory_tokens:
+                updated_step_data: BlackjackScoredDataGroup = {
+                    "seed": original_step_data["seed"],
+                    "messages": working_messages,
+                    "tokens": working_tokens,
+                    "masks": working_masks,
+                    "scores": original_step_data.get("scores"),
+                    "parsed_actions": original_step_data.get("parsed_actions")  # MC version specific
+                }
                 filtered_trajectory.append(updated_step_data)
                 logger.info(
-                    f"[_ensure_trajectory_token_limit] Step {step_idx} successfully truncated. "
+                    f"[_ensure_trajectory_token_limit] MC env: Step {step_idx} successfully processed. "
                     f"Final max tokens: {max_current_tokens}"
                 )
             else:
-                if max_current_tokens > self.config.max_trajectory_tokens:
-                    logger.warning(
-                        f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
-                        f"Max tokens ({max_current_tokens}) still exceed limit "
-                        f"({self.config.max_trajectory_tokens}) after maximum possible "
-                        f"uniform truncation or re-tokenization error."
-                    )
+                logger.warning(
+                    f"[_ensure_trajectory_token_limit] MC env: Discarding step {step_idx}. "
+                    f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) "
+                    f"or retokenization error occurred ({retokenization_error_this_step})."
+                )

         if len(filtered_trajectory) < len(trajectory):
             logger.warning(
-                f"[_ensure_trajectory_token_limit] Filtered out {len(trajectory) - len(filtered_trajectory)} steps "
-                f"due to token limit constraints. Original trajectory length: {len(trajectory)}, "
-                f"Filtered: {len(filtered_trajectory)}"
+                f"[_ensure_trajectory_token_limit] MC env: Filtered out "
+                f"{len(trajectory) - len(filtered_trajectory)} steps "
+                f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}"
             )
         return filtered_trajectory
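In essence, the rewritten loop pops the oldest non-system messages uniformly across every alternative and re-tokenizes until the longest alternative fits the budget. Below is a hedged, self-contained sketch of that core idea, not code from this commit: `truncate_uniformly` and `count_tokens` are hypothetical stand-ins (the real method uses tokenize_for_trainer), and masks and the pair-popping refinement are ignored for brevity.

# Hedged sketch of uniform trajectory truncation (hypothetical names).
from typing import Callable, Dict, List

Message = Dict[str, str]


def truncate_uniformly(
    alternatives: List[List[Message]],
    count_tokens: Callable[[List[Message]], int],
    max_tokens: int,
) -> bool:
    """Pop the oldest non-system message from every alternative until the
    longest alternative fits the budget. Returns False if the floor is hit."""
    if not alternatives:
        return True
    while max(count_tokens(alt) for alt in alternatives) > max_tokens:
        # Keep the system prompt (index 0) plus the final env/agent exchange;
        # if any alternative is already at that floor, give up on the step.
        if any(len(alt) <= 3 for alt in alternatives):
            return False
        for alt in alternatives:
            alt.pop(1)  # drop the oldest message after the system prompt
    return True

A caller would drop the step from the trajectory when this returns False, mirroring the discard branch in the method above.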
@@ -5,7 +5,8 @@ import os
 import random

 from dotenv import load_dotenv
-from environments.game_environments.gymnasium.blackjack_env import BlackjackEnv
+from environments.game_environments.gymnasium.blackjack_env import BlackjackEnv, BlackjackEnvConfig
+from atroposlib.envs.base import OpenaiConfig, EvalHandlingEnum

 load_dotenv()

@@ -13,51 +14,72 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


-def parse_arguments():
-    parser = argparse.ArgumentParser(description="Blackjack environment local server")
-    parser.add_argument(
-        "--config",
-        type=str,
-        default="blackjack_local",
-        help="Configuration file name (without .yaml extension, relative to "
-        "envs/gymnasium/configs), or full path to a YAML file.",
-    )
-    return parser.parse_args()
+# def parse_arguments():  # Removed
+#     parser = argparse.ArgumentParser(description="Blackjack environment local server")
+#     parser.add_argument(
+#         "--config",
+#         type=str,
+#         default="blackjack_local",
+#         help="Configuration file name (without .yaml extension, relative to "
+#         "envs/gymnasium/configs), or full path to a YAML file.",
+#     )
+#     return parser.parse_args()


 async def main():
-    logger.info("Starting Blackjack environment server")
+    logger.info("Starting Blackjack environment local debug runner")

-    args = parse_arguments()
+    # args = parse_arguments()  # Removed

-    # Determine the config name/path for config_init
-    # config_init expects the name relative to its own configs dir, or an absolute path
-    config_input = args.config
-    if not os.path.isabs(config_input) and not config_input.endswith(".yaml"):
-        # Assume it's a name relative to the blackjack env's config dir
-        config_name_or_path = config_input
-        logger.info(f"Using relative config name: {config_name_or_path}")
-    else:
-        # It's likely an absolute path or path relative to cwd
-        config_name_or_path = os.path.abspath(config_input)
-        logger.info(f"Using absolute config path: {config_name_or_path}")
-
-    # Use the environment's config_init method to load configurations
-    try:
-        config, server_configs = BlackjackEnv.config_init(config_name_or_path)
-        logger.info("Configuration loaded successfully via BlackjackEnv.config_init")
-        logger.debug(f"Loaded Env Config: {config}")
-        logger.debug(f"Loaded Server Configs: {server_configs}")
-    except Exception as e:
-        logger.exception(
-            f"Failed to load configuration using BlackjackEnv.config_init: {e}"
-        )
-        return  # Cannot proceed without config
+    # Removed logic for config_name_or_path and BlackjackEnv.config_init
+
+    # Create hardcoded configurations for local debugging
+    env_config = BlackjackEnvConfig(
+        # BaseEnvConfig fields, tailored for debug
+        tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+        group_size=1,  # Debug single generation path
+        use_wandb=False,
+        wandb_name="blackjack_local_debug",  # Explicitly set for debug
+        max_num_workers=1,
+        rollout_server_url="http://localhost:8000",  # Standard default
+        total_steps=1,
+        batch_size=1,  # Consistent with 1 step, 1 worker, group_size 1
+        steps_per_eval=0,  # No eval steps needed
+        max_token_length=1024 * 4,  # Reduced for faster local debugging if necessary
+        inference_weight=1.0,
+        data_path_to_save_groups=None,
+        eval_handling=EvalHandlingEnum.NONE,  # No evaluation in this script
+        eval_limit_ratio=0.0,
+        # BlackjackEnvConfig specific fields (from blackjack_env.py's definition or defaults)
+        env_name="Blackjack-v1",
+        temperature=0.2,  # Lower temperature for more deterministic debug output
+        top_p=0.9,  # Standard default
+        max_turns=5,  # Standard default
+        thinking_active=True,
+        eval_episodes=0,  # No evaluation episodes
+        max_think_chars_history=3000,
+        max_trajectory_tokens=24576,
+        debug_mode=True,  # Enable debug logging from the environment
+        mc_samples=1,  # With group_size=1, this means 1 MC rollout for V(s)
+    )
+
+    server_configs = [
+        OpenaiConfig(
+            model_name="gpt-4.1-mini",  # Ensure this is locally available if not mocked
+            base_url="https://api.openai.com/v1",  # Explicitly set OpenAI base URL
+            api_key=os.getenv("OPENAI_API_KEY"),  # Use env var or default
+            num_requests_for_eval=0,  # No eval requests
+        )
+    ]
+    logger.info("Using hardcoded debug configuration.")
+    logger.debug(f"Env Config: {env_config}")
+    logger.debug(f"Server Configs: {server_configs}")

     # Create and set up the environment using the loaded configs
     try:
         env = BlackjackEnv(
-            config=config,
+            config=env_config,
             server_configs=server_configs,
             slurm=False,  # Explicitly false for local testing
         )
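One note on the hardcoded server config: `api_key=os.getenv("OPENAI_API_KEY")` silently yields None when the variable is unset, so the earlier `load_dotenv()` call is doing real work here. A minimal sketch of that dotenv-driven pattern with an explicit failure mode, using a generic `InferenceServerConfig` dataclass rather than the repo's `OpenaiConfig`:

# Hedged sketch (generic names, not the repo's classes).
import os
from dataclasses import dataclass

from dotenv import load_dotenv


@dataclass
class InferenceServerConfig:
    model_name: str
    base_url: str
    api_key: str


def load_debug_server_config() -> InferenceServerConfig:
    load_dotenv()  # reads a .env file from the working directory, if present
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is not set; add it to .env or the shell")
    return InferenceServerConfig(
        model_name="gpt-4.1-mini",
        base_url="https://api.openai.com/v1",
        api_key=api_key,
    )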
@@ -1513,199 +1513,156 @@ class BlackjackEnv(BaseEnv):
         filtered_trajectory: List[BlackjackScoredDataGroup] = []

         for step_idx, original_step_data in enumerate(trajectory):
             logger.info(
                 f"[_ensure_trajectory_token_limit] Step {step_idx} has {len(original_step_data['messages'])} messages."
             )
-            if (
-                not original_step_data.get("messages")
-                or not original_step_data.get("tokens")
-                or not original_step_data.get("masks")
+            if not (
+                original_step_data.get("messages")
+                and original_step_data.get("tokens")
+                and original_step_data.get("masks")
+                and original_step_data.get("seed") is not None  # seed is mandatory for new group
             ):
                 logger.warning(
-                    f"[_ensure_trajectory_token_limit] Step {step_idx} is missing messages, tokens, or masks. Skipping."
+                    f"[_ensure_trajectory_token_limit] Step {step_idx} "
+                    f"is missing critical data (messages, tokens, masks, or seed). Skipping."
                 )
                 continue

-            current_step_messages_orig = [
-                msgs.copy() for msgs in original_step_data["messages"]
-            ]
-            for alt_idx, alt_msgs in enumerate(current_step_messages_orig):
-                logger.info(
-                    f"[_ensure_trajectory_token_limit] Step {step_idx}, alt {alt_idx} has {len(alt_msgs)} messages."
-                )
-                for msg_idx, msg in enumerate(alt_msgs):
-                    logger.info(
-                        f"[_ensure_trajectory_token_limit] Step {step_idx}, alt {alt_idx}, msg {msg['content']}"
-                    )
-            current_step_tokens_orig = [
-                tkns.copy() for tkns in original_step_data["tokens"]
-            ]
-            current_step_masks_orig = [
-                msks.copy() for msks in original_step_data["masks"]
-            ]
-
-            num_alternatives = len(current_step_messages_orig)
-            if num_alternatives == 0:
-                filtered_trajectory.append(original_step_data)
-                continue
-
-            max_initial_tokens = max(
-                len(alt_tokens) for alt_tokens in current_step_tokens_orig
-            )
+            # Initial token calculation from original data to see if truncation is needed
+            # Ensure tokens are lists of integers before calling len
+            max_initial_tokens = 0
+            if original_step_data["tokens"]:
+                max_initial_tokens = max(
+                    len(alt_tokens) for alt_tokens in original_step_data["tokens"] if isinstance(alt_tokens, list)
+                ) if any(isinstance(alt_tokens, list) for alt_tokens in original_step_data["tokens"]) else 0

             if max_initial_tokens <= self.config.max_trajectory_tokens:
                 filtered_trajectory.append(original_step_data)
                 logger.info(
                     f"[_ensure_trajectory_token_limit] Step {step_idx} compliant. "
                     f"Max tokens: {max_initial_tokens}"
                 )
                 continue

             logger.info(
-                f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) exceeds limit "
-                f"({self.config.max_trajectory_tokens}). Attempting uniform truncation."
+                f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) "
+                f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
             )

-            working_messages = [msgs.copy() for msgs in current_step_messages_orig]
-            working_tokens = [tkns.copy() for tkns in current_step_tokens_orig]
-            working_masks = [msks.copy() for msks in current_step_masks_orig]
+            # Prepare working copies for modification
+            # Ensure deep copies for lists of dicts if dicts are modified, but here we pop from list of dicts.
+            working_messages = [msgs_list.copy() for msgs_list in original_step_data["messages"] or []]
+            working_tokens = [tkns_list.copy() for tkns_list in original_step_data["tokens"] or []]
+            working_masks = [msks_list.copy() for msks_list in original_step_data["masks"] or []]
             max_current_tokens = max_initial_tokens

-            step_successfully_truncated = False
-            while True:
-                num_messages_to_pop_per_alt = [0] * num_alternatives
-                can_truncate_globally = True
+            num_alternatives = len(working_messages)
+            if num_alternatives == 0:  # Should not happen if initial checks passed
+                logger.warning(f"[_ensure_trajectory_token_limit] Step {step_idx} has no alternatives after copying. Skipping.")
+                continue
+
+            retokenization_error_this_step = False
+            while max_current_tokens > self.config.max_trajectory_tokens:
+                target_pop_counts_per_alt = []
                 for alt_idx in range(num_alternatives):
                     alt_msg_list = working_messages[alt_idx]
-
-                    min_len_to_preserve = 1
-                    if (
-                        len(alt_msg_list) > 0
-                        and alt_msg_list[-1]["role"] in UNMASKED_ROLES
-                    ):
-                        min_len_to_preserve += 1
-                    if (
-                        len(alt_msg_list) > 1
-                        and alt_msg_list[-2]["role"] == "environment"
-                    ):
-                        min_len_to_preserve += 1
-
-                    if len(alt_msg_list) <= min_len_to_preserve:
-                        num_messages_to_pop_per_alt[alt_idx] = 0
-                        can_truncate_globally = False
-                        break
-
-                    if (
-                        len(alt_msg_list) > 2
-                        and alt_msg_list[1]["role"] == "environment"
-                        and alt_msg_list[2]["role"] in UNMASKED_ROLES
-                    ):
-                        if (len(alt_msg_list) - 2) < min_len_to_preserve:
-                            if (len(alt_msg_list) - 1) < min_len_to_preserve:
-                                num_messages_to_pop_per_alt[alt_idx] = 0
-                                can_truncate_globally = False
-                                break
-                            else:
-                                num_messages_to_pop_per_alt[alt_idx] = 1
-                        else:
-                            num_messages_to_pop_per_alt[alt_idx] = 2
-                    elif len(alt_msg_list) > 1:
-                        if (len(alt_msg_list) - 1) < min_len_to_preserve:
-                            num_messages_to_pop_per_alt[alt_idx] = 0
-                            can_truncate_globally = False
-                            break
-                        else:
-                            num_messages_to_pop_per_alt[alt_idx] = 1
+                    # Calculate how many initial messages (after system prompt) can be popped.
+                    # Preserving: system prompt (index 0), last agent response, and its preceding env observation.
+                    num_preserved_at_end = 0
+                    if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in UNMASKED_ROLES:
+                        num_preserved_at_end = 1  # Last agent response
+                        if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment":
+                            num_preserved_at_end = 2  # Agent response + preceding env observation
+
+                    # Number of messages available for popping (between system prompt and preserved end messages)
+                    # Subtract 1 for the system prompt itself (which is never popped from index 0).
+                    available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
+
+                    if available_to_pop <= 0:
+                        target_pop_counts_per_alt.append(0)
                     else:
-                        num_messages_to_pop_per_alt[alt_idx] = 0
-                        can_truncate_globally = False
-                        break
-
-                if not can_truncate_globally:
-                    break
-
-                min_pop_count = float("inf")
-                for count in num_messages_to_pop_per_alt:
-                    if count > 0:
-                        min_pop_count = min(min_pop_count, count)
-
-                if min_pop_count == float("inf") or min_pop_count == 0:
-                    break
-
-                successfully_retokenized_all = True
-                new_alt_tokens_list = []
-                new_alt_masks_list = []
+                        # Try to pop a pair (environment, agent) if they are at list[1] and list[2]
+                        can_pop_pair = (
+                            available_to_pop >= 2 and
+                            len(alt_msg_list) > 2 and  # Ensure messages at index 1 and 2 exist
+                            alt_msg_list[1]["role"] == "environment" and
+                            alt_msg_list[2]["role"] in UNMASKED_ROLES
+                        )
+                        if can_pop_pair:
+                            target_pop_counts_per_alt.append(2)
+                        else:  # Can pop at least 1 since available_to_pop > 0
+                            target_pop_counts_per_alt.append(1)
+
+                positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
+                if not positive_pop_counts:
+                    break  # No alternative can be truncated further
+
+                min_pop_this_round = min(positive_pop_counts)
+
+                # Pop messages and re-tokenize
+                temp_new_alt_tokens = []
+                temp_new_alt_masks = []
                 max_tokens_after_this_trunc = 0

                 for alt_idx in range(num_alternatives):
-                    for _ in range(min_pop_count):
-                        if len(working_messages[alt_idx]) > 1:
+                    for _ in range(min_pop_this_round):
+                        if len(working_messages[alt_idx]) > 1:  # Ensure there's something to pop after system
                             working_messages[alt_idx].pop(1)
                         else:
                             logger.error(
-                                f"[_ensure_trajectory_token_limit] Critical error during "
-                                f"pop for alt {alt_idx}, step {step_idx}."
+                                f"[_ensure_trajectory_token_limit] Critical error during pop for "
+                                f"alt {alt_idx}, step {step_idx}. List too short."
                             )
-                            successfully_retokenized_all = False
-                            break
-                    if not successfully_retokenized_all:
-                        break
+                            retokenization_error_this_step = True; break
+                    if retokenization_error_this_step: break

                     try:
-                        tokenized_alt = tokenize_for_trainer(
-                            self.tokenizer, working_messages[alt_idx]
-                        )
-                        new_alt_tokens_list.append(tokenized_alt["tokens"])
-                        new_alt_masks_list.append(tokenized_alt["masks"])
-                        max_tokens_after_this_trunc = max(
-                            max_tokens_after_this_trunc, len(tokenized_alt["tokens"])
-                        )
+                        tokenized_alt = tokenize_for_trainer(self.tokenizer, working_messages[alt_idx])
+                        temp_new_alt_tokens.append(tokenized_alt["tokens"])
+                        temp_new_alt_masks.append(tokenized_alt["masks"])
+                        max_tokens_after_this_trunc = max(max_tokens_after_this_trunc, len(tokenized_alt["tokens"]))
                     except Exception as e:
                         logger.error(
-                            f"[_ensure_trajectory_token_limit] Error re-tokenizing "
-                            f"alt {alt_idx} in step {step_idx} after truncation: {e}"
+                            f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} "
+                            f"in step {step_idx} after truncation: {e}"
                         )
-                        successfully_retokenized_all = False
-                        break
-
-                if not successfully_retokenized_all:
-                    step_successfully_truncated = False
-                    break
-
-                working_tokens = new_alt_tokens_list
-                working_masks = new_alt_masks_list
+                        retokenization_error_this_step = True; break
+
+                if retokenization_error_this_step: break
+
+                working_tokens = temp_new_alt_tokens
+                working_masks = temp_new_alt_masks
                 max_current_tokens = max_tokens_after_this_trunc
                 logger.debug(
-                    f"[_ensure_trajectory_token_limit] Step {step_idx}, after "
-                    f"uniform pop of {min_pop_count}, max tokens: {max_current_tokens}"
+                    f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_this_round}, "
+                    f"max tokens: {max_current_tokens}"
                 )
+            # End of while loop for truncation attempts

-                if max_current_tokens <= self.config.max_trajectory_tokens:
-                    step_successfully_truncated = True
-                    break
-
-            if step_successfully_truncated:
-                updated_step_data = original_step_data.copy()
-                updated_step_data["messages"] = working_messages
-                updated_step_data["tokens"] = working_tokens
-                updated_step_data["masks"] = working_masks
+            if not retokenization_error_this_step and max_current_tokens <= self.config.max_trajectory_tokens:
+                updated_step_data: BlackjackScoredDataGroup = {
+                    "seed": original_step_data["seed"],
+                    "messages": working_messages,
+                    "tokens": working_tokens,
+                    "masks": working_masks,
+                    "scores": original_step_data.get("scores"),
+                    "parsed_action": original_step_data.get("parsed_action")
+                }
                 filtered_trajectory.append(updated_step_data)
                 logger.info(
-                    f"[_ensure_trajectory_token_limit] Step {step_idx} successfully truncated. "
+                    f"[_ensure_trajectory_token_limit] Step {step_idx} successfully processed. "
                     f"Final max tokens: {max_current_tokens}"
                 )
             else:
-                if max_current_tokens > self.config.max_trajectory_tokens:
-                    logger.warning(
-                        f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
-                        f"Max tokens ({max_current_tokens}) still exceed limit "
-                        f"({self.config.max_trajectory_tokens}) after maximum possible "
-                        f"uniform truncation or re-tokenization error."
-                    )
+                logger.warning(
+                    f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
+                    f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) "
+                    f"or retokenization error occurred ({retokenization_error_this_step})."
+                )

         if len(filtered_trajectory) < len(trajectory):
             logger.warning(
                 f"[_ensure_trajectory_token_limit] Filtered out "
                 f"{len(trajectory) - len(filtered_trajectory)} steps "
-                f"due to token limit constraints. Original trajectory length: {len(trajectory)}, "
-                f"Filtered: {len(filtered_trajectory)}"
+                f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}"
             )
         return filtered_trajectory
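For readers tracing the new `target_pop_counts_per_alt` logic: each alternative reports how many leading messages it can safely give up this round, preferring to drop an aligned (environment, agent) pair so turns stay matched. A standalone sketch of that per-alternative decision, mirroring but not copied from the method above; `pop_count_for` is a hypothetical helper, and `UNMASKED_ROLES` is assumed to be the agent-side roles:

# Hedged sketch of the per-alternative pop-count decision (hypothetical names).
from typing import Dict, List

UNMASKED_ROLES = ["agent", "assistant"]  # assumption: roles whose turns are trained on


def pop_count_for(messages: List[Dict[str, str]]) -> int:
    # Preserve the trailing agent reply and, if present, the environment
    # observation right before it.
    preserved_at_end = 0
    if len(messages) > 1 and messages[-1]["role"] in UNMASKED_ROLES:
        preserved_at_end = 1
        if len(messages) > 2 and messages[-2]["role"] == "environment":
            preserved_at_end = 2

    # Messages between the system prompt (index 0) and the preserved tail.
    available = len(messages) - 1 - preserved_at_end
    if available <= 0:
        return 0  # nothing safe to drop

    # Prefer dropping a whole (environment, agent) pair at positions 1 and 2.
    pair_at_front = (
        available >= 2
        and len(messages) > 2
        and messages[1]["role"] == "environment"
        and messages[2]["role"] in UNMASKED_ROLES
    )
    return 2 if pair_at_front else 1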