move best-of-n selection to util

This commit is contained in:
Shannon Sands 2025-05-14 10:35:12 -07:00
parent 4c00e2b209
commit 21cc528b85
3 changed files with 102 additions and 17 deletions

View file

@ -31,7 +31,8 @@ from atroposlib.utils import (
tokenize_for_trainer,
parse_tool_call,
truncate_thinking,
ensure_trajectory_token_limit
ensure_trajectory_token_limit,
select_best_index
)
logger = logging.getLogger(__name__)
@ -522,26 +523,24 @@ class BlackjackEnv(BaseEnv):
)
)
# Prepare items for sorting: (-advantage, token_length, original_index)
sortable_alternatives = []
for i in range(G):
# alt_tokens[i] is expected to be a list (possibly empty)
token_len = len(alt_tokens[i])
sortable_alternatives.append((-alt_advantages[i], token_len, i))
# token lengths for tie-breaking during selection
alt_token_lengths = [len(tkns) for tkns in alt_tokens]
sortable_alternatives.sort()
best_advantage_idx = sortable_alternatives[0][2]
# Log details of the selected alternative based on the sort
chosen_advantage_for_log = -sortable_alternatives[0][0]
chosen_token_length_for_log = sortable_alternatives[0][1]
best_advantage_idx = select_best_index(
primary_scores=alt_advantages,
secondary_scores=alt_token_lengths,
primary_higher_is_better=True,
secondary_lower_is_better=True
)
chosen_advantage_for_log = alt_advantages[best_advantage_idx]
chosen_token_length_for_log = alt_token_lengths[best_advantage_idx]
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Selected Alt {best_advantage_idx} "
f"(Adv: {chosen_advantage_for_log:.2f}, "
f"Tokens: {chosen_token_length_for_log}) "
f"from {G} alternatives using sort."
f"from {G} alternatives using select_best_index."
)
chosen_env_action = alt_env_actions[best_advantage_idx]
@ -563,7 +562,7 @@ class BlackjackEnv(BaseEnv):
ep.message_history = current_state_messages
response_for_history = truncate_thinking(
chosen_full_response, self.config.max_think_chars_history
chosen_full_response, self.tokenizer, self.config.max_think_chars_history
)
ep.message_history.append(
{"role": "agent", "content": response_for_history}
@ -730,7 +729,7 @@ class BlackjackEnv(BaseEnv):
metrics["num_turns"] = turn + 1
response_for_history = truncate_thinking(
full_agent_response, self.config.max_think_chars_history
full_agent_response, self.tokenizer, self.config.max_think_chars_history
)
ep.message_history.append(