diff --git a/environments/game_environments/gymnasium/blackjack_env.py b/environments/game_environments/gymnasium/blackjack_env.py index 4d9d5aff..b414c5c5 100644 --- a/environments/game_environments/gymnasium/blackjack_env.py +++ b/environments/game_environments/gymnasium/blackjack_env.py @@ -11,7 +11,7 @@ Uses Monte Carlo sampling to estimate the value of the current state, similar to import json import logging import random -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple import gymnasium import numpy as np @@ -1025,7 +1025,8 @@ class BlackjackEnv(BaseEnv): if num_alternatives == 0: logger.warning( - f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives after copying. Skipping." + f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives" + " after copying. Skipping." ) continue @@ -1109,7 +1110,8 @@ class BlackjackEnv(BaseEnv): working_masks = temp_new_alt_masks max_current_tokens = max_tokens_after_this_trunc logger.debug( - f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, after uniform pop of {min_pop_this_round}, " + f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, " + f"after uniform pop of {min_pop_this_round}, " f"max tokens: {max_current_tokens}" ) diff --git a/environments/game_environments/gymnasium/blackjack_no_mc_env.py b/environments/game_environments/gymnasium/blackjack_no_mc_env.py index 0c906afd..f3fde3d6 100644 --- a/environments/game_environments/gymnasium/blackjack_no_mc_env.py +++ b/environments/game_environments/gymnasium/blackjack_no_mc_env.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 """ BlackjackEnv: Trainer environment for Gymnasium Blackjack @@ -13,7 +12,7 @@ but may not be as effective at learning correct strategy (it's effectively a ser import json import logging import random -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple import gymnasium from tqdm.asyncio import tqdm_asyncio @@ -48,7 +47,6 @@ class BlackjackEnvConfig(BaseEnvConfig): batch_size: int = 1024 max_think_chars_history: int = 3000 - # Should be higher than the max tokens to allow for multiple turns max_trajectory_tokens: int = 24576 debug_mode: bool = False @@ -259,7 +257,6 @@ class BlackjackEnv(BaseEnv): continue if action == -1: - # Penalty for parsing error is applied within _score_response. env_reward_sim = 0.0 else: _obs_sim, env_reward_sim, term_sim, trunc_sim, _info_sim = ( @@ -720,7 +717,6 @@ class BlackjackEnv(BaseEnv): ) return [] - # Ensure all elements are at least dictionaries before proceeding if not all( isinstance(rgd, dict) for rgd in rollout_group_data if rgd is not None ): @@ -728,10 +724,8 @@ class BlackjackEnv(BaseEnv): "score: rollout_group_data contains non-dictionary elements. " "Cannot proceed." ) - # Return a list of Nones matching input length or handle as error return [None] * len(rollout_group_data) - # 1. Determine Overall Game Outcome final_env_reward_for_outcome = 0.0 is_win = False seed_for_outcome = None @@ -776,7 +770,7 @@ class BlackjackEnv(BaseEnv): f"Invalid action ({action_for_outcome_step}) found at step {step_idx_outcome} " f"during game outcome replay. Assuming non-win outcome for scoring." ) - final_env_reward_for_outcome = 0.0 # Treat as non-win + final_env_reward_for_outcome = 0.0 break logger.debug( @@ -800,7 +794,7 @@ class BlackjackEnv(BaseEnv): f"Final env reward for outcome: {final_env_reward_for_outcome}" ) break - else: # Loop completed without break + else: logger.info( f"score [Seed: {seed_for_outcome}]: " f"Game outcome replay completed all steps. " @@ -840,25 +834,19 @@ class BlackjackEnv(BaseEnv): processed_rollout_data.append(None) continue - # Make a copy to modify scores current_step_group: BlackjackScoredDataGroup = ( original_step_group_untyped.copy() ) - step_seed = current_step_group.get( - "seed", "N/A" - ) # Use N/A if seed is somehow missing + step_seed = current_step_group.get("seed", "N/A") if current_step_group.get("scores") is None: logger.warning( f"score [Seed: {step_seed}, Step: {step_idx}]: " f"Scores are missing. Cannot apply win bonus or tie-breaking." ) - processed_rollout_data.append( - current_step_group - ) # Append original or a copy + processed_rollout_data.append(current_step_group) continue - # Ensure scores is a list of numbers original_scores = current_step_group["scores"] if not isinstance(original_scores, list) or not all( isinstance(s, (int, float)) for s in original_scores @@ -871,13 +859,11 @@ class BlackjackEnv(BaseEnv): processed_rollout_data.append(current_step_group) continue - modified_scores = original_scores.copy() # Work on a copy + modified_scores = original_scores.copy() - # 2. Apply Win Bonus (if applicable) - if is_win and modified_scores: # Ensure scores list is not empty + if is_win and modified_scores: try: max_score_in_step = -float("inf") - # Find max score correctly, even with Nones, though previous check should handle Nones in list valid_scores_for_max = [ s for s in modified_scores if isinstance(s, (int, float)) ] @@ -888,7 +874,6 @@ class BlackjackEnv(BaseEnv): ) else: max_score_in_step = max(valid_scores_for_max) - # Find first index of max_score_in_step. If multiple, bonus applies to first. best_alternative_idx_this_step = -1 for idx, score_val in enumerate(modified_scores): if score_val == max_score_in_step: @@ -913,10 +898,7 @@ class BlackjackEnv(BaseEnv): f"Could not find index of max score {max_score_in_step} for win bonus. " f"This should not happen if scores exist." ) - except ( - ValueError - ): # Should be caught by empty list check or valid_scores_for_max - # Split into two lines to avoid line length issues + except ValueError: score_msg = ( f"score [Seed: {step_seed}, Step: {step_idx}]: " f"Error finding max score for win bonus." @@ -929,7 +911,6 @@ class BlackjackEnv(BaseEnv): f"Unexpected error applying win bonus: {e_bonus}" ) - # 3. Apply Tie-Breaking Logic (to potentially bonus-adjusted scores) step_messages = current_step_group.get("messages") if not isinstance(step_messages, list) or len(modified_scores) != len( step_messages @@ -940,7 +921,7 @@ class BlackjackEnv(BaseEnv): f"({len(step_messages) if isinstance(step_messages, list) else 'not a list'}) " f"lengths, or messages missing. Skipping tie-breaking for this step." ) - elif modified_scores: # Ensure scores list is not empty for tie-breaking + elif modified_scores: token_lengths = [] valid_messages_for_tiebreak = True for alt_msg_list_idx, alt_msg_list in enumerate(step_messages): @@ -966,16 +947,12 @@ class BlackjackEnv(BaseEnv): f"Tokenization error for tie-breaking on alt {alt_msg_list_idx}: {e_tok}. " f"Defaulting token length to large value." ) - token_lengths.append( - float("inf") - ) # Penalize if tokenization fails + token_lengths.append(float("inf")) if valid_messages_for_tiebreak: - score_groups = {} # Maps score_value to list of indices + score_groups = {} for idx, score_val in enumerate(modified_scores): - if not isinstance( - score_val, (int, float) - ): # Skip non-numeric scores + if not isinstance(score_val, (int, float)): continue if score_val not in score_groups: score_groups[score_val] = [] @@ -983,8 +960,7 @@ class BlackjackEnv(BaseEnv): scores_after_tiebreak = modified_scores.copy() for score_val, indices_with_this_score in score_groups.items(): - if len(indices_with_this_score) > 1: # A tie is found - # Check if token_lengths are available for all tied indices + if len(indices_with_this_score) > 1: if not all( idx < len(token_lengths) for idx in indices_with_this_score @@ -999,16 +975,14 @@ class BlackjackEnv(BaseEnv): continue try: - # Sort tied indices by their token_lengths sorted_tied_indices = sorted( indices_with_this_score, key=lambda i: token_lengths[i], ) - # Apply penalty to all but the first (shortest token length) for rank, tied_idx in enumerate( sorted_tied_indices[1:], 1 - ): # Start rank from 1 for 2nd shortest + ): penalty = 0.0001 * rank scores_after_tiebreak[tied_idx] -= penalty logger.debug( @@ -1375,39 +1349,36 @@ class BlackjackEnv(BaseEnv): @classmethod def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]: env_config = BlackjackEnvConfig( - # Fields from fundamental_prediction_environment.py's BaseEnvConfig init: tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", - group_size=16, # From Base, as not in BJ no_mc config's direct definition + group_size=16, use_wandb=True, max_num_workers=128, rollout_server_url="http://localhost:8000", total_steps=2000, - batch_size=1024, # Matches BlackjackEnvConfig (no_mc) default as well + batch_size=1024, steps_per_eval=20, max_token_length=1024 * 16, inference_weight=1.0, - wandb_name="fundamental_metric_prediction", # Strict: Use value from fundamental_prediction + wandb_name="fundamental_metric_prediction", data_path_to_save_groups=None, eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, - # BlackjackEnvConfig (no_mc version) specific fields (those NOT in BaseEnvConfig from fundamental_prediction) - # using their defined defaults from BlackjackEnvConfig (no_mc): - env_name="Blackjack-v1", # Default from BlackjackEnvConfig (no_mc) - temperature=0.7, # Default from BlackjackEnvConfig (no_mc) - top_p=0.9, # Default from BlackjackEnvConfig (no_mc) - max_turns=5, # Default from BlackjackEnvConfig (no_mc) - thinking_active=True, # Default from BlackjackEnvConfig (no_mc) - eval_episodes=100, # Default from BlackjackEnvConfig (no_mc) - max_think_chars_history=3000, # Default from BlackjackEnvConfig (no_mc) - max_trajectory_tokens=24576, # Default from BlackjackEnvConfig (no_mc) - debug_mode=False, # Default from BlackjackEnvConfig (no_mc) + env_name="Blackjack-v1", + temperature=0.7, + top_p=0.9, + max_turns=5, + thinking_active=True, + eval_episodes=100, + max_think_chars_history=3000, + max_trajectory_tokens=24576, + debug_mode=False, ) server_configs = [ OpenaiConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", api_key="x", - num_requests_for_eval=256, # From fundamental_prediction_environment.py + num_requests_for_eval=256, ) ] return env_config, server_configs @@ -1514,8 +1485,7 @@ class BlackjackEnv(BaseEnv): original_step_data.get("messages") and original_step_data.get("tokens") and original_step_data.get("masks") - and original_step_data.get("seed") - is not None # seed is mandatory for new group + and original_step_data.get("seed") is not None ): logger.warning( f"[_ensure_trajectory_token_limit] Step {step_idx} " @@ -1523,8 +1493,6 @@ class BlackjackEnv(BaseEnv): ) continue - # Initial token calculation from original data to see if truncation is needed - # Ensure tokens are lists of integers before calling len max_initial_tokens = 0 if original_step_data["tokens"]: max_initial_tokens = ( @@ -1553,8 +1521,6 @@ class BlackjackEnv(BaseEnv): f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation." ) - # Prepare working copies for modification - # Ensure deep copies for lists of dicts if dicts are modified, but here we pop from list of dicts. working_messages = [ msgs_list.copy() for msgs_list in original_step_data["messages"] or [] ] @@ -1567,7 +1533,7 @@ class BlackjackEnv(BaseEnv): max_current_tokens = max_initial_tokens num_alternatives = len(working_messages) - if num_alternatives == 0: # Should not happen if initial checks passed + if num_alternatives == 0: logger.warning( f"[_ensure_trajectory_token_limit] Step {step_idx} has no alternatives after copying. Skipping." ) @@ -1579,60 +1545,47 @@ class BlackjackEnv(BaseEnv): for alt_idx in range(num_alternatives): alt_msg_list = working_messages[alt_idx] - # Calculate how many initial messages (after system prompt) can be popped. - # Preserving: system prompt (index 0), last agent response, and its preceding env observation. num_preserved_at_end = 0 if ( len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in UNMASKED_ROLES ): - num_preserved_at_end = 1 # Last agent response + num_preserved_at_end = 1 if ( len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment" ): - num_preserved_at_end = ( - 2 # Agent response + preceding env observation - ) + num_preserved_at_end = 2 - # Number of messages available for popping (between system prompt and preserved end messages) - # Subtract 1 for the system prompt itself (which is never popped from index 0). available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end if available_to_pop <= 0: target_pop_counts_per_alt.append(0) else: - # Try to pop a pair (environment, agent) if they are at list[1] and list[2] can_pop_pair = ( available_to_pop >= 2 and len(alt_msg_list) > 2 - and alt_msg_list[ # Ensure messages at index 1 and 2 exist - 1 - ]["role"] - == "environment" + and alt_msg_list[1]["role"] == "environment" and alt_msg_list[2]["role"] in UNMASKED_ROLES ) if can_pop_pair: target_pop_counts_per_alt.append(2) - else: # Can pop at least 1 since available_to_pop > 0 + else: target_pop_counts_per_alt.append(1) positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0] if not positive_pop_counts: - break # No alternative can be truncated further + break min_pop_this_round = min(positive_pop_counts) - # Pop messages and re-tokenize temp_new_alt_tokens = [] temp_new_alt_masks = [] max_tokens_after_this_trunc = 0 for alt_idx in range(num_alternatives): for _ in range(min_pop_this_round): - if ( - len(working_messages[alt_idx]) > 1 - ): # Ensure there's something to pop after system + if len(working_messages[alt_idx]) > 1: working_messages[alt_idx].pop(1) else: logger.error( @@ -1671,7 +1624,6 @@ class BlackjackEnv(BaseEnv): f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_this_round}, " f"max tokens: {max_current_tokens}" ) - # End of while loop for truncation attempts if ( not retokenization_error_this_step