diff --git a/atroposlib/utils/__init__.py b/atroposlib/utils/__init__.py
index 368eee11..cd578734 100644
--- a/atroposlib/utils/__init__.py
+++ b/atroposlib/utils/__init__.py
@@ -16,6 +16,7 @@ from .advantages import (
     compute_discounted_returns,
     compute_grpo_process_supervision_advantages,
 )
+from .best_of_n_selection import select_best_index
 
 __all__ = [
     "ConfigHandler",
@@ -28,4 +29,5 @@ __all__ = [
     "compute_discounted_returns",
     "compute_grpo_process_supervision_advantages",
     "ensure_trajectory_token_limit",
+    "select_best_index",
 ]
diff --git a/atroposlib/utils/best_of_n_selection.py b/atroposlib/utils/best_of_n_selection.py
new file mode 100644
index 00000000..c8539d9b
--- /dev/null
+++ b/atroposlib/utils/best_of_n_selection.py
@@ -0,0 +1,69 @@
+"""
+Greedy selection of the best alternative in a group of alternatives.
+
+For a group of alternatives, select the one with the highest score (raw
+rewards or advantages), breaking ties with a secondary score.
+"""
+
+from typing import List, Union
+
+
+def select_best_index(
+    primary_scores: List[Union[float, int]],
+    secondary_scores: List[Union[float, int]],
+    primary_higher_is_better: bool = True,
+    secondary_lower_is_better: bool = True,
+) -> int:
+    """
+    Selects the index of the best item from a list based on primary and secondary scores.
+
+    Args:
+        primary_scores: A list of scores that are the primary criterion for selection.
+        secondary_scores: A list of scores used for tie-breaking if primary scores are equal.
+        primary_higher_is_better: If True, higher primary scores are considered better.
+            If False, lower primary scores are considered better.
+        secondary_lower_is_better: If True, lower secondary scores are considered better
+            for tie-breaking. If False, higher secondary scores are considered better.
+
+    Returns:
+        The index of the best item.
+
+    Raises:
+        ValueError: If primary_scores and secondary_scores have different lengths or are empty.
+    """
+    if not primary_scores or not secondary_scores:
+        raise ValueError("Input score lists cannot be empty.")
+    if len(primary_scores) != len(secondary_scores):
+        raise ValueError("Primary and secondary score lists must have the same length.")
+
+    best_index = 0
+
+    for i in range(1, len(primary_scores)):
+        primary_score_i = primary_scores[i]
+        primary_score_best = primary_scores[best_index]
+
+        # Primary criterion: a strictly better primary score wins outright.
+        if primary_higher_is_better:
+            current_primary_is_better = primary_score_i > primary_score_best
+        else:
+            current_primary_is_better = primary_score_i < primary_score_best
+
+        if current_primary_is_better:
+            best_index = i
+            continue
+
+        # Exact-equality tie on the primary score: fall back to the secondary
+        # score (e.g. prefer the shorter response among equally-scored ones).
+        if primary_score_i == primary_score_best:
+            secondary_score_i = secondary_scores[i]
+            secondary_score_best = secondary_scores[best_index]
+
+            if secondary_lower_is_better:
+                secondary_is_better = secondary_score_i < secondary_score_best
+            else:
+                secondary_is_better = secondary_score_i > secondary_score_best
+
+            if secondary_is_better:
+                best_index = i
+
+    return best_index
diff --git a/environments/game_environments/gymnasium/blackjack_env.py b/environments/game_environments/gymnasium/blackjack_env.py
index b84d8e5a..8f2e6518 100644
--- a/environments/game_environments/gymnasium/blackjack_env.py
+++ b/environments/game_environments/gymnasium/blackjack_env.py
@@ -31,7 +31,8 @@ from atroposlib.utils import (
     tokenize_for_trainer,
     parse_tool_call,
     truncate_thinking,
-    ensure_trajectory_token_limit
+    ensure_trajectory_token_limit,
+    select_best_index
 )
 
 logger = logging.getLogger(__name__)
@@ -522,26 +523,24 @@ class BlackjackEnv(BaseEnv):
             )
         )
 
-        # Prepare items for sorting: (-advantage, token_length, original_index)
-        sortable_alternatives = []
-        for i in range(G):
-            # alt_tokens[i] is expected to be a list (possibly empty)
-            token_len = len(alt_tokens[i])
-            sortable_alternatives.append((-alt_advantages[i], token_len, i))
+        # token lengths for tie-breaking during selection
+        alt_token_lengths = [len(tkns) for tkns in alt_tokens]
 
-        sortable_alternatives.sort()
-
-        best_advantage_idx = sortable_alternatives[0][2]
-
-        # Log details of the selected alternative based on the sort
-        chosen_advantage_for_log = -sortable_alternatives[0][0]
-        chosen_token_length_for_log = sortable_alternatives[0][1]
+        best_advantage_idx = select_best_index(
+            primary_scores=alt_advantages,
+            secondary_scores=alt_token_lengths,
+            primary_higher_is_better=True,
+            secondary_lower_is_better=True
+        )
+
+        chosen_advantage_for_log = alt_advantages[best_advantage_idx]
+        chosen_token_length_for_log = alt_token_lengths[best_advantage_idx]
         logger.debug(
             f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
             f"Selected Alt {best_advantage_idx} "
             f"(Adv: {chosen_advantage_for_log:.2f}, "
             f"Tokens: {chosen_token_length_for_log}) "
-            f"from {G} alternatives using sort."
+            f"from {G} alternatives using select_best_index."
         )
 
         chosen_env_action = alt_env_actions[best_advantage_idx]
@@ -563,7 +562,7 @@ class BlackjackEnv(BaseEnv):
         ep.message_history = current_state_messages
 
         response_for_history = truncate_thinking(
-            chosen_full_response, self.config.max_think_chars_history
+            chosen_full_response, self.tokenizer, self.config.max_think_chars_history
         )
         ep.message_history.append(
             {"role": "agent", "content": response_for_history}
@@ -730,7 +729,7 @@ class BlackjackEnv(BaseEnv):
         metrics["num_turns"] = turn + 1
 
         response_for_history = truncate_thinking(
-            full_agent_response, self.config.max_think_chars_history
+            full_agent_response, self.tokenizer, self.config.max_think_chars_history
        )
         ep.message_history.append(