mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
update local server
This commit is contained in:
parent
c506bb147e
commit
ba604d44f9
3 changed files with 251 additions and 320 deletions
|
|
@ -968,193 +968,145 @@ class BlackjackEnv(BaseEnv):
|
|||
filtered_trajectory: List[BlackjackScoredDataGroup] = []
|
||||
|
||||
for step_idx, original_step_data in enumerate(trajectory):
|
||||
logger.info(
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} has "
|
||||
f"{len(original_step_data['messages'])} alternatives."
|
||||
)
|
||||
if (
|
||||
not original_step_data.get("messages")
|
||||
or not original_step_data.get("tokens")
|
||||
or not original_step_data.get("masks")
|
||||
if not (
|
||||
original_step_data.get("messages")
|
||||
and original_step_data.get("tokens")
|
||||
and original_step_data.get("masks")
|
||||
and original_step_data.get("seed") is not None
|
||||
and original_step_data.get("parsed_actions") is not None # Specific to MC version
|
||||
):
|
||||
logger.warning(
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} "
|
||||
f"is missing messages, tokens, or masks. Skipping."
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env "
|
||||
f"is missing critical data. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
current_step_messages_orig = [
|
||||
msgs.copy() for msgs in original_step_data["messages"]
|
||||
]
|
||||
current_step_tokens_orig = [
|
||||
tkns.copy() for tkns in original_step_data["tokens"]
|
||||
]
|
||||
current_step_masks_orig = [
|
||||
msks.copy() for msks in original_step_data["masks"]
|
||||
]
|
||||
|
||||
num_alternatives = len(current_step_messages_orig)
|
||||
if num_alternatives == 0:
|
||||
filtered_trajectory.append(original_step_data)
|
||||
continue
|
||||
|
||||
max_initial_tokens = max(
|
||||
len(alt_tokens) for alt_tokens in current_step_tokens_orig
|
||||
)
|
||||
# Initial token calculation from original data
|
||||
max_initial_tokens = 0
|
||||
if original_step_data["tokens"]:
|
||||
max_initial_tokens = max(
|
||||
len(alt_tokens) for alt_tokens in original_step_data["tokens"] if isinstance(alt_tokens, list)
|
||||
) if any(isinstance(alt_tokens, list) for alt_tokens in original_step_data["tokens"]) else 0
|
||||
|
||||
if max_initial_tokens <= self.config.max_trajectory_tokens:
|
||||
filtered_trajectory.append(original_step_data)
|
||||
logger.info(
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} compliant in MC env. "
|
||||
f"Max tokens: {max_initial_tokens}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) "
|
||||
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting uniform truncation."
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env (max tokens: {max_initial_tokens}) "
|
||||
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
|
||||
)
|
||||
|
||||
working_messages = [msgs.copy() for msgs in current_step_messages_orig]
|
||||
working_tokens = [tkns.copy() for tkns in current_step_tokens_orig]
|
||||
working_masks = [msks.copy() for msks in current_step_masks_orig]
|
||||
working_messages = [msgs_list.copy() for msgs_list in original_step_data["messages"] or []]
|
||||
working_tokens = [tkns_list.copy() for tkns_list in original_step_data["tokens"] or []]
|
||||
working_masks = [msks_list.copy() for msks_list in original_step_data["masks"] or []]
|
||||
max_current_tokens = max_initial_tokens
|
||||
num_alternatives = len(working_messages)
|
||||
|
||||
step_successfully_truncated = False
|
||||
while True:
|
||||
num_messages_to_pop_per_alt = [0] * num_alternatives
|
||||
can_truncate_globally = True
|
||||
if num_alternatives == 0:
|
||||
logger.warning(f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives after copying. Skipping.")
|
||||
continue
|
||||
|
||||
retokenization_error_this_step = False
|
||||
while max_current_tokens > self.config.max_trajectory_tokens:
|
||||
target_pop_counts_per_alt = []
|
||||
for alt_idx in range(num_alternatives):
|
||||
alt_msg_list = working_messages[alt_idx]
|
||||
num_preserved_at_end = 0
|
||||
if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in ["agent", "assistant"] + UNMASKED_ROLES:
|
||||
num_preserved_at_end = 1
|
||||
if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment":
|
||||
num_preserved_at_end = 2
|
||||
|
||||
available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
|
||||
|
||||
min_len_to_preserve = 1
|
||||
if len(alt_msg_list) > 0 and alt_msg_list[-1]["role"] in ["agent"]:
|
||||
min_len_to_preserve += 1
|
||||
if (
|
||||
len(alt_msg_list) > 1
|
||||
and alt_msg_list[-2]["role"] == "environment"
|
||||
):
|
||||
min_len_to_preserve += 1
|
||||
|
||||
if len(alt_msg_list) <= min_len_to_preserve:
|
||||
num_messages_to_pop_per_alt[alt_idx] = 0
|
||||
can_truncate_globally = False
|
||||
break
|
||||
|
||||
if (
|
||||
len(alt_msg_list) > 2
|
||||
and alt_msg_list[1]["role"] == "environment"
|
||||
and alt_msg_list[2]["role"] == "agent"
|
||||
):
|
||||
if (len(alt_msg_list) - 2) < min_len_to_preserve:
|
||||
if (len(alt_msg_list) - 1) < min_len_to_preserve:
|
||||
num_messages_to_pop_per_alt[alt_idx] = 0
|
||||
can_truncate_globally = False
|
||||
break
|
||||
else:
|
||||
num_messages_to_pop_per_alt[alt_idx] = 1
|
||||
else:
|
||||
num_messages_to_pop_per_alt[alt_idx] = 2
|
||||
elif len(alt_msg_list) > 1:
|
||||
if (len(alt_msg_list) - 1) < min_len_to_preserve:
|
||||
num_messages_to_pop_per_alt[alt_idx] = 0
|
||||
can_truncate_globally = False
|
||||
break
|
||||
else:
|
||||
num_messages_to_pop_per_alt[alt_idx] = 1
|
||||
if available_to_pop <= 0:
|
||||
target_pop_counts_per_alt.append(0)
|
||||
else:
|
||||
num_messages_to_pop_per_alt[alt_idx] = 0
|
||||
can_truncate_globally = False
|
||||
break
|
||||
|
||||
if not can_truncate_globally:
|
||||
break
|
||||
|
||||
min_pop_count = float("inf")
|
||||
for count in num_messages_to_pop_per_alt:
|
||||
if count > 0:
|
||||
min_pop_count = min(min_pop_count, count)
|
||||
|
||||
if min_pop_count == float("inf") or min_pop_count == 0:
|
||||
break
|
||||
|
||||
successfully_retokenized_all = True
|
||||
new_alt_tokens_list = []
|
||||
new_alt_masks_list = []
|
||||
can_pop_pair = (
|
||||
available_to_pop >= 2 and
|
||||
len(alt_msg_list) > 2 and
|
||||
alt_msg_list[1]["role"] == "environment" and
|
||||
alt_msg_list[2]["role"] in ["agent", "assistant"] + UNMASKED_ROLES
|
||||
)
|
||||
if can_pop_pair:
|
||||
target_pop_counts_per_alt.append(2)
|
||||
else:
|
||||
target_pop_counts_per_alt.append(1)
|
||||
|
||||
positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
|
||||
if not positive_pop_counts:
|
||||
break
|
||||
|
||||
min_pop_this_round = min(positive_pop_counts)
|
||||
temp_new_alt_tokens = []
|
||||
temp_new_alt_masks = []
|
||||
max_tokens_after_this_trunc = 0
|
||||
|
||||
for alt_idx in range(num_alternatives):
|
||||
for _ in range(min_pop_count):
|
||||
for _ in range(min_pop_this_round):
|
||||
if len(working_messages[alt_idx]) > 1:
|
||||
working_messages[alt_idx].pop(1)
|
||||
else:
|
||||
logger.error(
|
||||
f"[_ensure_trajectory_token_limit] Critical error during pop for "
|
||||
f"alt {alt_idx}, step {step_idx}."
|
||||
f"[_ensure_trajectory_token_limit] MC env: Critical error during pop for "
|
||||
f"alt {alt_idx}, step {step_idx}. List too short."
|
||||
)
|
||||
successfully_retokenized_all = False
|
||||
break
|
||||
if not successfully_retokenized_all:
|
||||
break
|
||||
|
||||
retokenization_error_this_step = True; break
|
||||
if retokenization_error_this_step: break
|
||||
|
||||
try:
|
||||
tokenized_alt = tokenize_for_trainer(
|
||||
self.tokenizer, working_messages[alt_idx]
|
||||
)
|
||||
new_alt_tokens_list.append(tokenized_alt["tokens"])
|
||||
new_alt_masks_list.append(tokenized_alt["masks"])
|
||||
max_tokens_after_this_trunc = max(
|
||||
max_tokens_after_this_trunc, len(tokenized_alt["tokens"])
|
||||
)
|
||||
tokenized_alt = tokenize_for_trainer(self.tokenizer, working_messages[alt_idx])
|
||||
temp_new_alt_tokens.append(tokenized_alt["tokens"])
|
||||
temp_new_alt_masks.append(tokenized_alt["masks"])
|
||||
max_tokens_after_this_trunc = max(max_tokens_after_this_trunc, len(tokenized_alt["tokens"]))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} "
|
||||
f"[_ensure_trajectory_token_limit] MC env: Error re-tokenizing alt {alt_idx} "
|
||||
f"in step {step_idx} after truncation: {e}"
|
||||
)
|
||||
successfully_retokenized_all = False
|
||||
break
|
||||
retokenization_error_this_step = True; break
|
||||
|
||||
if retokenization_error_this_step: break
|
||||
|
||||
if not successfully_retokenized_all:
|
||||
step_successfully_truncated = False
|
||||
break
|
||||
|
||||
working_tokens = new_alt_tokens_list
|
||||
working_masks = new_alt_masks_list
|
||||
working_tokens = temp_new_alt_tokens
|
||||
working_masks = temp_new_alt_masks
|
||||
max_current_tokens = max_tokens_after_this_trunc
|
||||
logger.debug(
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_count}, "
|
||||
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, after uniform pop of {min_pop_this_round}, "
|
||||
f"max tokens: {max_current_tokens}"
|
||||
)
|
||||
|
||||
if max_current_tokens <= self.config.max_trajectory_tokens:
|
||||
step_successfully_truncated = True
|
||||
break
|
||||
|
||||
if step_successfully_truncated:
|
||||
updated_step_data = BlackjackScoredDataGroup(
|
||||
seed=original_step_data["seed"],
|
||||
messages=working_messages,
|
||||
tokens=working_tokens,
|
||||
masks=working_masks,
|
||||
scores=original_step_data["scores"],
|
||||
parsed_actions=original_step_data["parsed_actions"],
|
||||
)
|
||||
if not retokenization_error_this_step and max_current_tokens <= self.config.max_trajectory_tokens:
|
||||
updated_step_data: BlackjackScoredDataGroup = {
|
||||
"seed": original_step_data["seed"],
|
||||
"messages": working_messages,
|
||||
"tokens": working_tokens,
|
||||
"masks": working_masks,
|
||||
"scores": original_step_data.get("scores"),
|
||||
"parsed_actions": original_step_data.get("parsed_actions") # MC version specific
|
||||
}
|
||||
filtered_trajectory.append(updated_step_data)
|
||||
logger.info(
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} successfully truncated. "
|
||||
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx} successfully processed. "
|
||||
f"Final max tokens: {max_current_tokens}"
|
||||
)
|
||||
else:
|
||||
if max_current_tokens > self.config.max_trajectory_tokens:
|
||||
logger.warning(
|
||||
f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
|
||||
f"Max tokens ({max_current_tokens}) still exceed limit "
|
||||
f"({self.config.max_trajectory_tokens}) after maximum possible "
|
||||
f"uniform truncation or re-tokenization error."
|
||||
)
|
||||
logger.warning(
|
||||
f"[_ensure_trajectory_token_limit] MC env: Discarding step {step_idx}. "
|
||||
f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) "
|
||||
f"or retokenization error occurred ({retokenization_error_this_step})."
|
||||
)
|
||||
|
||||
if len(filtered_trajectory) < len(trajectory):
|
||||
logger.warning(
|
||||
f"[_ensure_trajectory_token_limit] Filtered out {len(trajectory) - len(filtered_trajectory)} steps "
|
||||
f"due to token limit constraints. Original trajectory length: {len(trajectory)}, "
|
||||
f"Filtered: {len(filtered_trajectory)}"
|
||||
f"[_ensure_trajectory_token_limit] MC env: Filtered out "
|
||||
f"{len(trajectory) - len(filtered_trajectory)} steps "
|
||||
f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}"
|
||||
)
|
||||
return filtered_trajectory
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue