update local server

This commit is contained in:
Shannon Sands 2025-05-10 08:18:41 +10:00
parent c506bb147e
commit ba604d44f9
3 changed files with 251 additions and 320 deletions

View file

@@ -968,193 +968,145 @@ class BlackjackEnv(BaseEnv):
filtered_trajectory: List[BlackjackScoredDataGroup] = []
for step_idx, original_step_data in enumerate(trajectory):
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} has "
f"{len(original_step_data['messages'])} alternatives."
)
if (
not original_step_data.get("messages")
or not original_step_data.get("tokens")
or not original_step_data.get("masks")
if not (
original_step_data.get("messages")
and original_step_data.get("tokens")
and original_step_data.get("masks")
and original_step_data.get("seed") is not None
and original_step_data.get("parsed_actions") is not None # Specific to MC version
):
logger.warning(
f"[_ensure_trajectory_token_limit] Step {step_idx} "
f"is missing messages, tokens, or masks. Skipping."
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env "
f"is missing critical data. Skipping."
)
continue
current_step_messages_orig = [
msgs.copy() for msgs in original_step_data["messages"]
]
current_step_tokens_orig = [
tkns.copy() for tkns in original_step_data["tokens"]
]
current_step_masks_orig = [
msks.copy() for msks in original_step_data["masks"]
]
num_alternatives = len(current_step_messages_orig)
if num_alternatives == 0:
filtered_trajectory.append(original_step_data)
continue
max_initial_tokens = max(
len(alt_tokens) for alt_tokens in current_step_tokens_orig
)
# Initial token calculation from original data
max_initial_tokens = 0
if original_step_data["tokens"]:
max_initial_tokens = max(
len(alt_tokens) for alt_tokens in original_step_data["tokens"] if isinstance(alt_tokens, list)
) if any(isinstance(alt_tokens, list) for alt_tokens in original_step_data["tokens"]) else 0
if max_initial_tokens <= self.config.max_trajectory_tokens:
filtered_trajectory.append(original_step_data)
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} compliant in MC env. "
f"Max tokens: {max_initial_tokens}"
)
continue
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) "
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting uniform truncation."
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env (max tokens: {max_initial_tokens}) "
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
)
working_messages = [msgs.copy() for msgs in current_step_messages_orig]
working_tokens = [tkns.copy() for tkns in current_step_tokens_orig]
working_masks = [msks.copy() for msks in current_step_masks_orig]
working_messages = [msgs_list.copy() for msgs_list in original_step_data["messages"] or []]
working_tokens = [tkns_list.copy() for tkns_list in original_step_data["tokens"] or []]
working_masks = [msks_list.copy() for msks_list in original_step_data["masks"] or []]
max_current_tokens = max_initial_tokens
num_alternatives = len(working_messages)
step_successfully_truncated = False
while True:
num_messages_to_pop_per_alt = [0] * num_alternatives
can_truncate_globally = True
if num_alternatives == 0:
logger.warning(f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives after copying. Skipping.")
continue
retokenization_error_this_step = False
while max_current_tokens > self.config.max_trajectory_tokens:
target_pop_counts_per_alt = []
for alt_idx in range(num_alternatives):
alt_msg_list = working_messages[alt_idx]
num_preserved_at_end = 0
if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in ["agent", "assistant"] + UNMASKED_ROLES:
num_preserved_at_end = 1
if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment":
num_preserved_at_end = 2
available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
min_len_to_preserve = 1
if len(alt_msg_list) > 0 and alt_msg_list[-1]["role"] in ["agent"]:
min_len_to_preserve += 1
if (
len(alt_msg_list) > 1
and alt_msg_list[-2]["role"] == "environment"
):
min_len_to_preserve += 1
if len(alt_msg_list) <= min_len_to_preserve:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
if (
len(alt_msg_list) > 2
and alt_msg_list[1]["role"] == "environment"
and alt_msg_list[2]["role"] == "agent"
):
if (len(alt_msg_list) - 2) < min_len_to_preserve:
if (len(alt_msg_list) - 1) < min_len_to_preserve:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
else:
num_messages_to_pop_per_alt[alt_idx] = 1
else:
num_messages_to_pop_per_alt[alt_idx] = 2
elif len(alt_msg_list) > 1:
if (len(alt_msg_list) - 1) < min_len_to_preserve:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
else:
num_messages_to_pop_per_alt[alt_idx] = 1
if available_to_pop <= 0:
target_pop_counts_per_alt.append(0)
else:
num_messages_to_pop_per_alt[alt_idx] = 0
can_truncate_globally = False
break
if not can_truncate_globally:
break
min_pop_count = float("inf")
for count in num_messages_to_pop_per_alt:
if count > 0:
min_pop_count = min(min_pop_count, count)
if min_pop_count == float("inf") or min_pop_count == 0:
break
successfully_retokenized_all = True
new_alt_tokens_list = []
new_alt_masks_list = []
can_pop_pair = (
available_to_pop >= 2 and
len(alt_msg_list) > 2 and
alt_msg_list[1]["role"] == "environment" and
alt_msg_list[2]["role"] in ["agent", "assistant"] + UNMASKED_ROLES
)
if can_pop_pair:
target_pop_counts_per_alt.append(2)
else:
target_pop_counts_per_alt.append(1)
positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
if not positive_pop_counts:
break
min_pop_this_round = min(positive_pop_counts)
temp_new_alt_tokens = []
temp_new_alt_masks = []
max_tokens_after_this_trunc = 0
for alt_idx in range(num_alternatives):
for _ in range(min_pop_count):
for _ in range(min_pop_this_round):
if len(working_messages[alt_idx]) > 1:
working_messages[alt_idx].pop(1)
else:
logger.error(
f"[_ensure_trajectory_token_limit] Critical error during pop for "
f"alt {alt_idx}, step {step_idx}."
f"[_ensure_trajectory_token_limit] MC env: Critical error during pop for "
f"alt {alt_idx}, step {step_idx}. List too short."
)
successfully_retokenized_all = False
break
if not successfully_retokenized_all:
break
retokenization_error_this_step = True; break
if retokenization_error_this_step: break
try:
tokenized_alt = tokenize_for_trainer(
self.tokenizer, working_messages[alt_idx]
)
new_alt_tokens_list.append(tokenized_alt["tokens"])
new_alt_masks_list.append(tokenized_alt["masks"])
max_tokens_after_this_trunc = max(
max_tokens_after_this_trunc, len(tokenized_alt["tokens"])
)
tokenized_alt = tokenize_for_trainer(self.tokenizer, working_messages[alt_idx])
temp_new_alt_tokens.append(tokenized_alt["tokens"])
temp_new_alt_masks.append(tokenized_alt["masks"])
max_tokens_after_this_trunc = max(max_tokens_after_this_trunc, len(tokenized_alt["tokens"]))
except Exception as e:
logger.error(
f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} "
f"[_ensure_trajectory_token_limit] MC env: Error re-tokenizing alt {alt_idx} "
f"in step {step_idx} after truncation: {e}"
)
successfully_retokenized_all = False
break
retokenization_error_this_step = True; break
if retokenization_error_this_step: break
if not successfully_retokenized_all:
step_successfully_truncated = False
break
working_tokens = new_alt_tokens_list
working_masks = new_alt_masks_list
working_tokens = temp_new_alt_tokens
working_masks = temp_new_alt_masks
max_current_tokens = max_tokens_after_this_trunc
logger.debug(
f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_count}, "
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, after uniform pop of {min_pop_this_round}, "
f"max tokens: {max_current_tokens}"
)
if max_current_tokens <= self.config.max_trajectory_tokens:
step_successfully_truncated = True
break
if step_successfully_truncated:
updated_step_data = BlackjackScoredDataGroup(
seed=original_step_data["seed"],
messages=working_messages,
tokens=working_tokens,
masks=working_masks,
scores=original_step_data["scores"],
parsed_actions=original_step_data["parsed_actions"],
)
if not retokenization_error_this_step and max_current_tokens <= self.config.max_trajectory_tokens:
updated_step_data: BlackjackScoredDataGroup = {
"seed": original_step_data["seed"],
"messages": working_messages,
"tokens": working_tokens,
"masks": working_masks,
"scores": original_step_data.get("scores"),
"parsed_actions": original_step_data.get("parsed_actions") # MC version specific
}
filtered_trajectory.append(updated_step_data)
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} successfully truncated. "
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx} successfully processed. "
f"Final max tokens: {max_current_tokens}"
)
else:
if max_current_tokens > self.config.max_trajectory_tokens:
logger.warning(
f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
f"Max tokens ({max_current_tokens}) still exceed limit "
f"({self.config.max_trajectory_tokens}) after maximum possible "
f"uniform truncation or re-tokenization error."
)
logger.warning(
f"[_ensure_trajectory_token_limit] MC env: Discarding step {step_idx}. "
f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) "
f"or retokenization error occurred ({retokenization_error_this_step})."
)
if len(filtered_trajectory) < len(trajectory):
logger.warning(
f"[_ensure_trajectory_token_limit] Filtered out {len(trajectory) - len(filtered_trajectory)} steps "
f"due to token limit constraints. Original trajectory length: {len(trajectory)}, "
f"Filtered: {len(filtered_trajectory)}"
f"[_ensure_trajectory_token_limit] MC env: Filtered out "
f"{len(trajectory) - len(filtered_trajectory)} steps "
f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}"
)
return filtered_trajectory