mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
move best-of-n selection to util
This commit is contained in:
parent
4c00e2b209
commit
21cc528b85
3 changed files with 102 additions and 17 deletions
|
|
@ -31,7 +31,8 @@ from atroposlib.utils import (
|
|||
tokenize_for_trainer,
|
||||
parse_tool_call,
|
||||
truncate_thinking,
|
||||
ensure_trajectory_token_limit
|
||||
ensure_trajectory_token_limit,
|
||||
select_best_index
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -522,26 +523,24 @@ class BlackjackEnv(BaseEnv):
|
|||
)
|
||||
)
|
||||
|
||||
# Prepare items for sorting: (-advantage, token_length, original_index)
|
||||
sortable_alternatives = []
|
||||
for i in range(G):
|
||||
# alt_tokens[i] is expected to be a list (possibly empty)
|
||||
token_len = len(alt_tokens[i])
|
||||
sortable_alternatives.append((-alt_advantages[i], token_len, i))
|
||||
# token lengths for tie-breaking during selection
|
||||
alt_token_lengths = [len(tkns) for tkns in alt_tokens]
|
||||
|
||||
sortable_alternatives.sort()
|
||||
|
||||
best_advantage_idx = sortable_alternatives[0][2]
|
||||
|
||||
# Log details of the selected alternative based on the sort
|
||||
chosen_advantage_for_log = -sortable_alternatives[0][0]
|
||||
chosen_token_length_for_log = sortable_alternatives[0][1]
|
||||
best_advantage_idx = select_best_index(
|
||||
primary_scores=alt_advantages,
|
||||
secondary_scores=alt_token_lengths,
|
||||
primary_higher_is_better=True,
|
||||
secondary_lower_is_better=True
|
||||
)
|
||||
|
||||
chosen_advantage_for_log = alt_advantages[best_advantage_idx]
|
||||
chosen_token_length_for_log = alt_token_lengths[best_advantage_idx]
|
||||
logger.debug(
|
||||
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
|
||||
f"Selected Alt {best_advantage_idx} "
|
||||
f"(Adv: {chosen_advantage_for_log:.2f}, "
|
||||
f"Tokens: {chosen_token_length_for_log}) "
|
||||
f"from {G} alternatives using sort."
|
||||
f"from {G} alternatives using select_best_index."
|
||||
)
|
||||
|
||||
chosen_env_action = alt_env_actions[best_advantage_idx]
|
||||
|
|
@ -563,7 +562,7 @@ class BlackjackEnv(BaseEnv):
|
|||
ep.message_history = current_state_messages
|
||||
|
||||
response_for_history = truncate_thinking(
|
||||
chosen_full_response, self.config.max_think_chars_history
|
||||
chosen_full_response, self.tokenizer, self.config.max_think_chars_history
|
||||
)
|
||||
ep.message_history.append(
|
||||
{"role": "agent", "content": response_for_history}
|
||||
|
|
@ -730,7 +729,7 @@ class BlackjackEnv(BaseEnv):
|
|||
metrics["num_turns"] = turn + 1
|
||||
|
||||
response_for_history = truncate_thinking(
|
||||
full_agent_response, self.config.max_think_chars_history
|
||||
full_agent_response, self.tokenizer, self.config.max_think_chars_history
|
||||
)
|
||||
|
||||
ep.message_history.append(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue