diff --git a/atroposlib/utils/best_of_n_selection.py b/atroposlib/utils/best_of_n_selection.py index c8539d9b..752cb599 100644 --- a/atroposlib/utils/best_of_n_selection.py +++ b/atroposlib/utils/best_of_n_selection.py @@ -4,7 +4,8 @@ Greedy selection of the best alternative in a group of alternatives. For a group of alternatives, select the one with the highest score (raw rewards or advantages). """ -from typing import List, Union, Tuple +from typing import List, Union + def select_best_index( primary_scores: List[Union[float, int]], @@ -35,7 +36,7 @@ def select_best_index( raise ValueError("Primary and secondary score lists must have the same length.") num_items = len(primary_scores) - if num_items == 0: # Should be caught by the first check, but as a safeguard. + if num_items == 0: # Should be caught by the first check, but as a safeguard. raise ValueError("Input score lists cannot be empty.") best_index = 0 @@ -49,10 +50,10 @@ def select_best_index( if primary_higher_is_better: if primary_score_i > primary_score_best: current_primary_is_better = True - else: # primary_lower_is_better + else: # primary_lower_is_better if primary_score_i < primary_score_best: current_primary_is_better = True - + if current_primary_is_better: best_index = i continue @@ -64,21 +65,16 @@ def select_best_index( if primary_score_i == primary_score_best: secondary_score_i = secondary_scores[i] secondary_score_best = secondary_scores[best_index] - + current_secondary_is_better_for_tiebreak = False if secondary_lower_is_better: if secondary_score_i < secondary_score_best: current_secondary_is_better_for_tiebreak = True - else: # secondary_higher_is_better + else: # secondary_higher_is_better if secondary_score_i > secondary_score_best: current_secondary_is_better_for_tiebreak = True - + if current_secondary_is_better_for_tiebreak: best_index = i return best_index - - - - - diff --git a/atroposlib/utils/message_history_utils.py b/atroposlib/utils/message_history_utils.py index 9fc9dd71..efcd169d 100644 --- a/atroposlib/utils/message_history_utils.py +++ b/atroposlib/utils/message_history_utils.py @@ -3,19 +3,21 @@ Trajectory utils Utils for managing trajectory sizing, formatting, compression, etc. """ + import logging from typing import List from transformers import PreTrainedTokenizer -from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer from atroposlib.envs.base import ScoredDataGroup +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer logger = logging.getLogger(__name__) + def strip_thinking(response_text: str) -> str: """Helper to strip the <think> block of a response entirely. - + Args: response_text: The response text to strip. @@ -24,20 +26,24 @@ def strip_thinking(response_text: str) -> str: """ think_start_tag = "<think>" think_end_tag = "</think>" - + think_start_idx = response_text.find(think_start_tag) think_end_idx = response_text.find(think_end_tag) if think_start_idx != -1 and think_end_idx != -1: - return response_text[:think_start_idx] + response_text[think_end_idx + len(think_end_tag):] + return ( + response_text[:think_start_idx] + + response_text[think_end_idx + len(think_end_tag) :] + ) else: return response_text - + + def truncate_thinking( response_text: str, tokenizer: PreTrainedTokenizer, max_think_tokens: int ) -> str: """Helper to truncate the <think> block of a response for message history based on token count. - + Args: response_text: The response text to truncate. tokenizer: The tokenizer to use for counting tokens. 
@@ -60,9 +66,7 @@ def truncate_thinking( ): return response_text - part_before_content = response_text[ - : think_start_idx + len(think_start_tag) - ] + part_before_content = response_text[: think_start_idx + len(think_start_tag)] original_think_content_raw = response_text[ think_start_idx + len(think_start_tag) : think_end_idx ] @@ -77,7 +81,7 @@ def truncate_thinking( all_think_tokens = tokenizer.encode( original_think_content_stripped, add_special_tokens=False ) - + is_truncated_internally = False final_think_tokens: List[int] @@ -85,13 +89,15 @@ def truncate_thinking( final_think_tokens = all_think_tokens is_truncated_internally = False else: - is_truncated_internally = True # Mark as truncated if len(all_think_tokens) > max_think_tokens + is_truncated_internally = ( + True # Mark as truncated if len(all_think_tokens) > max_think_tokens + ) paragraphs = [ p.strip() for p in original_think_content_stripped.split("\n\n") if p.strip() ] - + attempted_paragraph_truncation = False if paragraphs: last_paragraph_text = paragraphs[-1] @@ -104,12 +110,14 @@ def truncate_thinking( if len(last_paragraph_tokens) <= max_think_tokens: final_think_tokens = last_paragraph_tokens attempted_paragraph_truncation = True - - if not attempted_paragraph_truncation: # Default to truncating the whole content from the end + + if ( + not attempted_paragraph_truncation + ): # Default to truncating the whole content from the end # Ensure max_think_tokens is not negative, though practically it shouldn't be. slice_start = max(0, len(all_think_tokens) - max_think_tokens) final_think_tokens = all_think_tokens[slice_start:] - + # Decode the tokens to string decoded_think_content = tokenizer.decode( final_think_tokens, skip_special_tokens=True @@ -123,7 +131,10 @@ def truncate_thinking( # Determine the final block content (empty or with newlines) final_internal_content_str_stripped = final_internal_content_str.strip() final_content_for_block: str - if not final_internal_content_str_stripped or final_internal_content_str_stripped == "...": + if ( + not final_internal_content_str_stripped + or final_internal_content_str_stripped == "..." + ): final_content_for_block = "" else: final_content_for_block = f"\n{final_internal_content_str_stripped}\n" @@ -137,6 +148,7 @@ def truncate_thinking( ) return response_text + def ensure_trajectory_token_limit( trajectory: List[ScoredDataGroup], tokenizer: PreTrainedTokenizer, @@ -325,7 +337,7 @@ def ensure_trajectory_token_limit( else: logger.warning( f"[_ensure_trajectory_token_limit] MC env: Discarding step {step_idx}. " - f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) " + f"Max tokens ({max_current_tokens}) still exceed limit ({max_trajectory_tokens}) " f"or retokenization error occurred ({retokenization_error_this_step})." ) @@ -336,7 +348,3 @@ def ensure_trajectory_token_limit( f"due to token limit constraints. 
Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}" ) return filtered_trajectory - - - - diff --git a/environments/game_environments/gymnasium/README.md b/environments/game_environments/gymnasium/README.md index 7bc8f365..0194aa79 100644 --- a/environments/game_environments/gymnasium/README.md +++ b/environments/game_environments/gymnasium/README.md @@ -33,7 +33,7 @@ Key components of this approach: \[ A(s_t, a_i) = R_i + \gamma V(s'_{i}) - V(s_t) \] (In `_collect_trajectory`, `gamma` is effectively 1, and \(R_i\) is represented by `alt_combined_rewards[i]`, \(V(s'_{i})\) by `alt_value_next[i]`, and \(V(s_t)\) by `value_t`). - Note: This has nothing to do with GPRO's internal advantage calculations! Don't get it mixed up, this is just used to help provide some credit for intermediate actions where immediate action rewards aren't available (as well as selecting the next canoncial action). Supplementing the actually winning trajectory scores (as in, the canonical trajectory) with the final outcome and a discount factor to assign credit to earlier actions would be an obvious improvement, which has been left off to keep things a little simpler (and will be explored more in another environment with longer trajectories and more sparse rewards where it might matter more to training) + Note: This has nothing to do with GRPO's internal advantage calculations! Don't get it mixed up; this is just used to help provide some credit for intermediate actions where immediate action rewards aren't available (as well as selecting the next canonical action). Supplementing the actual winning trajectory's scores (as in, the canonical trajectory) with the final outcome and a discount factor to assign credit to earlier actions would be an obvious improvement, which has been left off to keep things a little simpler (and will be explored more in another environment with longer trajectories and more sparse rewards where it might matter more to training). * **Choosing the Path (`select_best_index`)**: The `select_best_index` function is then used to pick the alternative with the highest calculated advantage. This chosen alternative's action is what is actually "played" in the environment, advancing the episode to the next state `s_{t+1}`. The other `G-1` alternatives serve as counterfactual data for training. So, we end up with a "canonical" trajectory through the environment. For more comprehensive exploration of alternatives, we'd need to use some more comprehensive form of search like MCTS, which is overkill for something like Blackjack (but we'll demo in some other more complex environments to be added) @@ -62,7 +62,7 @@ The GRPO trainer typically computes a loss using these advantages. For example, \[ L = -\sum_{j=1}^{M} \sum_{k=1}^{K_j} \left( \frac{\pi_{\theta}(a_{jk} | s_j)}{\pi_{\theta_{\text{old}}}(a_{jk} | s_j)} A_{jk}^{\text{GRPO}} \right) \] (often with a KL divergence penalty for stability, ensuring the new policy \(\pi_{\theta}\) doesn't deviate too drastically from the old policy \(\pi_{\theta_{\text{old}}}\)). The `ratio = torch.exp(logp - logp.detach())` and `loss = -reward * ratio` (where `reward` is the \(A_{jk}^{\text{GRPO}}\) advantage) in a typical trainer snippet would align with this principle. -The `blackjack_env_thinking` environment's design is compatible with GRPO's core requirements for input data BUT allowing it to be used across long trajectories. 
We don't get a nice, well defined reward at every step of every environment - but we want to keep that nice, objective, outcome-oriented RLVR-style reward structure, even in reward-sparse environments. +The `blackjack_env_thinking` environment's design is compatible with GRPO's core requirements for input data while still allowing it to be used across long trajectories. We don't get a nice, well-defined reward at every step of every environment, but we want to keep that nice, objective, outcome-oriented RLVR-style reward structure, even in reward-sparse environments. 1. **Alternative Generation**: From a state \(s_t\), the environment generates `G` alternative continuations (thoughts and actions \(a_1, ..., a_G\)). 2. **Value-Informed Scoring (within the environment)**: For each alternative \(a_i\), the environment itself calculates a "score" using its internal value estimation: \(S_i = R_i + \gamma V_{\text{env}}(s'_{i}) - V_{\text{env}}(s_t)\). This score, \(S_i\), represents a local, value-informed assessment of that alternative's quality. @@ -102,4 +102,3 @@ In the `blackjack_env_no_thinking` environment: * The entire sequence of (observation, action, LLM response) usually fits within the model's `seq_len`. Blackjack is at most a few turns, so this is ok if you JUST want to train on actions, not additional long chains of thought. * `collect_trajectory` returns a single `ScoredDataItem` representing the full episode. The "score" is simply the final game outcome (e.g., +1 for a win) and some bonuses for formatting and correct tool calling. * The trainer can then process these entire episodes using the normal GRPO method (ie, we're just sending the full alternative trajectories and their scores to be compared, similar to the single-step bandit problems people are commonly using for RLVR). The complexity of per-step alternative generation for windowing and local value estimation isn't needed for fitting within `seq_len`. - diff --git a/environments/game_environments/gymnasium/blackjack_env_no_thinking.py b/environments/game_environments/gymnasium/blackjack_env_no_thinking.py index 81a6c883..88b54ddc 100644 --- a/environments/game_environments/gymnasium/blackjack_env_no_thinking.py +++ b/environments/game_environments/gymnasium/blackjack_env_no_thinking.py @@ -1,9 +1,9 @@ -import logging -from typing import Dict, List, Optional, Tuple import json +import logging +import random +from typing import Dict, List, Optional, Tuple import gymnasium as gym -import random from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem from atroposlib.type_definitions import Item, Message @@ -119,13 +119,13 @@ class BlackjackEnvNoThinking(BaseEnv): return None parsed_name, parsed_args, is_error = parse_tool_call( - llm_response, self.tools, ["tool_call"] # Expecting <tool_call> + llm_response, self.tools, ["tool_call"] # Expecting <tool_call> ) if is_error: error_detail = ( - str(parsed_name) # Error message is in parsed_name if is_error - if parsed_name + str(parsed_name) # Error message is in parsed_name if is_error + if parsed_name else "Parser indicated error, but no specific message was returned." ) logger.warning( @@ -146,7 +146,8 @@ class BlackjackEnvNoThinking(BaseEnv): return ACTION_STICK else: logger.warning( - f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. " + f"Successfully parsed tool call '{parsed_name}', " + f"but action argument is invalid. Action: '{action_str}'. " f"Full response: '{llm_response}'. 
Parsed args: {parsed_args}" ) return None @@ -162,14 +163,13 @@ class BlackjackEnvNoThinking(BaseEnv): seed = item["seed"] messages: List[Message] = [] game_reward = 0.0 - num_turns = 0 try: env = gym.make(self.config.env_name) except Exception as e: logger.error(f"Failed to make environment {self.config.env_name}: {e}") return None, [] - + try: obs, info = env.reset(seed=seed) except Exception as e: @@ -189,7 +189,9 @@ class BlackjackEnvNoThinking(BaseEnv): len(self.tokenizer.apply_chat_template(messages, tokenize=False)) > self.config.max_token_length - 50 ): - logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.") + logger.warning( + f"[Seed: {seed}] Max token length reached, truncating episode." + ) break max_tokens_for_action = 512 @@ -201,19 +203,25 @@ class BlackjackEnvNoThinking(BaseEnv): max_tokens=max_tokens_for_action, temperature=0.5, ) - llm_action_response = chat_completions.choices[0].message.content.strip() - logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'") # Log raw response + llm_action_response = chat_completions.choices[ + 0 + ].message.content.strip() + logger.info( + f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'" + ) # Log raw response except Exception as e: logger.error(f"[Seed: {seed}] LLM API error: {e}") break messages.append({"role": "assistant", "content": llm_action_response}) - + action = self._parse_action_from_llm(llm_action_response) if action is None: - logger.warning(f"[Seed: {seed}] Invalid action parsed. Ending episode.") - game_reward = -1.0 - break + logger.warning( + f"[Seed: {seed}] Invalid action parsed. Ending episode." + ) + game_reward = -1.0 + break try: obs, reward, terminated, truncated, _ = env.step(action) @@ -224,19 +232,17 @@ class BlackjackEnvNoThinking(BaseEnv): if terminated or truncated: break - + current_obs_str = self._format_observation(obs) messages.append({"role": "user", "content": current_obs_str}) - + env.close() self.episode_outcomes_buffer.append(game_reward) tokenization_result = tokenize_for_trainer( - tokenizer=self.tokenizer, - chat=messages, - train_on_all_assistant_turns=True + tokenizer=self.tokenizer, chat=messages, train_on_all_assistant_turns=True ) - + tokens = tokenization_result["tokens"] masks = tokenization_result["masks"] @@ -256,24 +262,27 @@ class BlackjackEnvNoThinking(BaseEnv): logger.info(f"Setting up {self.name} environment.") async def evaluate(self, *args, **kwargs): - logger.info(f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes.") - + logger.info( + f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes." + ) + wins = 0 losses = 0 draws = 0 - + eval_outcomes: List[float] = [] for i in range(self.config.eval_episodes): - seed = random.randint(1_000_001, 2_000_000) + seed = random.randint(1_000_001, 2_000_000) item = {"seed": seed} scored_item_tuple = await self.collect_trajectory(item) if scored_item_tuple and scored_item_tuple[0]: outcome = scored_item_tuple[0]["scores"] eval_outcomes.append(outcome) else: - logger.warning(f"Evaluation episode {i+1} (seed {seed}) failed to produce data.") - + logger.warning( + f"Evaluation episode {i+1} (seed {seed}) failed to produce data." 
+ ) if not eval_outcomes: logger.warning("No evaluation episodes completed successfully.") @@ -287,7 +296,7 @@ class BlackjackEnvNoThinking(BaseEnv): losses += 1 else: draws += 1 - + num_completed = len(eval_outcomes) win_rate = wins / num_completed if num_completed > 0 else 0 loss_rate = losses / num_completed if num_completed > 0 else 0 @@ -301,15 +310,18 @@ class BlackjackEnvNoThinking(BaseEnv): (f"{self.name}_eval/avg_reward", avg_reward), (f"{self.name}_eval/num_completed_episodes", num_completed), ] - logger.info(f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}") - + logger.info( + f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}" + ) async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None): if wandb_metrics is None: wandb_metrics = {} if self.episode_outcomes_buffer: - avg_training_reward = sum(self.episode_outcomes_buffer) / len(self.episode_outcomes_buffer) + avg_training_reward = sum(self.episode_outcomes_buffer) / len( + self.episode_outcomes_buffer + ) wandb_metrics[f"{self.name}_train/avg_episode_reward"] = avg_training_reward train_wins = sum(1 for r in self.episode_outcomes_buffer if r > 0) train_losses = sum(1 for r in self.episode_outcomes_buffer if r < 0) @@ -317,7 +329,9 @@ class BlackjackEnvNoThinking(BaseEnv): wandb_metrics[f"{self.name}_train/win_count"] = train_wins wandb_metrics[f"{self.name}_train/loss_count"] = train_losses wandb_metrics[f"{self.name}_train/draw_count"] = train_draws - wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(self.episode_outcomes_buffer) + wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len( + self.episode_outcomes_buffer + ) self.episode_outcomes_buffer = [] diff --git a/environments/game_environments/gymnasium/blackjack_env_thinking.py b/environments/game_environments/gymnasium/blackjack_env_thinking.py index bb49a64a..5cdba662 100644 --- a/environments/game_environments/gymnasium/blackjack_env_thinking.py +++ b/environments/game_environments/gymnasium/blackjack_env_thinking.py @@ -27,10 +27,13 @@ from atroposlib.envs.base import ( OpenaiConfig, ScoredDataGroup, ) -from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer -from atroposlib.utils.message_history_utils import truncate_thinking, ensure_trajectory_token_limit -from atroposlib.utils.tool_call_parser import parse_tool_call from atroposlib.utils.best_of_n_selection import select_best_index +from atroposlib.utils.message_history_utils import ( + ensure_trajectory_token_limit, + truncate_thinking, +) +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer +from atroposlib.utils.tool_call_parser import parse_tool_call logger = logging.getLogger(__name__) @@ -44,7 +47,7 @@ class BlackjackEnvConfig(BaseEnvConfig): thinking_active: bool = True eval_episodes: int = 100 max_think_chars_history: int = 3000 - max_trajectory_tokens: int = 24576 #seq_len of RL trainer + max_trajectory_tokens: int = 24576 # seq_len of RL trainer debug_mode: bool = False group_size: int = 16 tiebreak_token_factor: float = 0.01 @@ -526,9 +529,9 @@ class BlackjackEnv(BaseEnv): primary_scores=alt_advantages, secondary_scores=alt_token_lengths, primary_higher_is_better=True, - secondary_lower_is_better=True + secondary_lower_is_better=True, ) - + chosen_advantage_for_log = alt_advantages[best_advantage_idx] chosen_token_length_for_log = alt_token_lengths[best_advantage_idx] logger.debug( @@ -558,7 +561,9 @@ class BlackjackEnv(BaseEnv): ep.message_history = current_state_messages 
response_for_history = truncate_thinking( - chosen_full_response, self.tokenizer, self.config.max_think_chars_history + chosen_full_response, + self.tokenizer, + self.config.max_think_chars_history, ) ep.message_history.append( {"role": "agent", "content": response_for_history} diff --git a/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py b/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py index d942c878..d3ca559f 100644 --- a/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py +++ b/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py @@ -19,9 +19,7 @@ logger = logging.getLogger(__name__) async def main(): - logger.info( - "Starting Blackjack (No Thinking) environment local debug runner" - ) + logger.info("Starting Blackjack (No Thinking) environment local debug runner") env_config = BlackjackEnvNoThinkingConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", @@ -73,19 +71,25 @@ async def main(): logger.info(f"Using seed: {seed} for item: {item_for_env}") result_tuple = await env.collect_trajectory(item_for_env) - + scored_data_item: Optional[ScoredDataItem] = None if result_tuple and result_tuple[0]: scored_data_item = result_tuple[0] logger.info( f"Trajectory collection complete. Score: {scored_data_item.get('scores')}" ) - if env_config.include_messages and scored_data_item.get('messages'): + if env_config.include_messages and scored_data_item.get("messages"): logger.info("Collected Messages:") - for i, msg in enumerate(scored_data_item['messages']): - logger.info(f" {i}. Role: {msg['role']}, Content: '{str(msg['content'])[:150]}...'") - logger.info(f"Tokens ({len(scored_data_item.get('tokens', []))}): {str(scored_data_item.get('tokens'))[:100]}...") - logger.info(f"Masks ({len(scored_data_item.get('masks', []))}): {str(scored_data_item.get('masks'))[:100]}...") + for i, msg in enumerate(scored_data_item["messages"]): + logger.info( + f" {i}. Role: {msg['role']}, Content: '{str(msg['content'])[:150]}...'" + ) + logger.info( + f"Tokens ({len(scored_data_item.get('tokens', []))}): {str(scored_data_item.get('tokens'))[:100]}..." + ) + logger.info( + f"Masks ({len(scored_data_item.get('masks', []))}): {str(scored_data_item.get('masks'))[:100]}..." + ) else: logger.error("Trajectory collection did not return a ScoredDataItem.")
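
A minimal sketch (illustration only, not code from this PR) of how the per-alternative scoring described in the README combines with `select_best_index`: each alternative is scored as S_i = R_i + gamma * V(s'_i) - V(s_t) with gamma effectively 1, and ties on that score are broken toward the shorter completion, mirroring the call site in `blackjack_env_thinking.py`. All numeric values below are made up.

from atroposlib.utils.best_of_n_selection import select_best_index

# Made-up values for G = 3 alternatives at one state: combined rewards,
# next-state value estimates, current-state value, and completion lengths.
alt_combined_rewards = [0.0, 0.0, -1.0]
alt_value_next = [0.5, 0.5, 0.0]
value_t = 0.25
alt_token_lengths = [120, 95, 140]

# Advantage-style score from the README: S_i = R_i + gamma * V(s'_i) - V(s_t), gamma = 1.
alt_advantages = [
    r + v_next - value_t
    for r, v_next in zip(alt_combined_rewards, alt_value_next)
]  # -> [0.25, 0.25, -1.25]; alternatives 0 and 1 tie on advantage.

best_idx = select_best_index(
    primary_scores=alt_advantages,
    secondary_scores=alt_token_lengths,
    primary_higher_is_better=True,
    secondary_lower_is_better=True,
)
print(best_idx)  # 1: same advantage as alternative 0, but fewer tokens.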