#!/usr/bin/env python3
"""
BlackjackEnv: Trainer environment for Gymnasium Blackjack

This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
using function-call style actions. Extends BaseEnv.

Sort of inspired by VinePPO, but uses a recursive exact value calculation instead of
Monte Carlo sampling (because the action space is so small and the environment is
deterministic, plus short episode lengths).
"""

import copy
import json
import logging
import random
import re
from typing import Dict, List, Optional, Tuple

import gymnasium
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    EvalHandlingEnum,
    ScoredDataGroup,
)
from atroposlib.utils.best_of_n_selection import select_best_index
from atroposlib.utils.message_history_utils import (
    ensure_trajectory_token_limit,
    truncate_thinking,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from atroposlib.utils.tool_call_parser import parse_tool_call

logger = logging.getLogger(__name__)

# Prefix that opens the agent turn when thinking is enabled; the model continues
# from inside the think block.
_THINK_PREFIX = "<think>\n"

# DOTALL is required: the think block always starts with a newline (see
# _THINK_PREFIX) and normally spans several lines, which "." would not match.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


class BlackjackEnvConfig(BaseEnvConfig):
    """Configuration for BlackjackEnv on top of the shared BaseEnvConfig fields."""

    env_name: str = "Blackjack-v1"
    temperature: float = 0.7
    top_p: float = 0.9
    max_turns: Optional[int] = 5
    wandb_name: str = "blackjack"
    thinking_active: bool = True
    eval_episodes: int = 100
    max_think_chars_history: int = 3000
    max_trajectory_tokens: int = 24576  # seq_len of RL trainer
    debug_mode: bool = False
    group_size: int = 16
    tiebreak_token_factor: float = 0.01


class BlackjackScoredDataGroup(ScoredDataGroup):
    """Per-step group of G alternatives; `scores` holds the advantages."""

    seed: int
    tokens: Optional[List[List[int]]] = None
    masks: Optional[List[List[int]]] = None
    scores: Optional[List[float]] = None
    messages: Optional[List[List[Dict]]] = None
    parsed_actions: Optional[List[int]] = None


class EpisodeState:
    """Mutable bookkeeping for one seeded episode and its live environment."""

    def __init__(self, seed: int, env: gymnasium.Env):
        self.seed = seed
        self.env = env
        self.message_history: List[Dict] = []
        # Environment actions (0 = stick, 1 = hit) actually applied so far;
        # replayed on fresh envs to reconstruct the current state.
        self.actions: List[int] = []
        self.step_rewards: List[float] = []
        self.total_reward: float = 0.0
        self.num_steps: int = 0
        # "Correct" means the chosen response parsed to a valid action.
        self.num_correct_actions: int = 0
        self.num_total_actions: int = 0


class BlackjackEnv(BaseEnv):
    """Best-of-n Blackjack trainer environment using tool-call style actions."""

    def __init__(
        self,
        config: BlackjackEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = True,
        testing: bool = False,
    ):
        """Set up logging, the tool schema and the system prompt."""
        super().__init__(config, server_configs, slurm, testing)
        self.episodes: Dict[int, EpisodeState] = {}
        self.debug_mode = config.debug_mode
        self.completed_episode_metrics_buffer: List[Dict[str, float]] = []

        if self.debug_mode:
            logger.setLevel(logging.DEBUG)
        else:
            # Only raise verbosity to WARNING; never override a stricter
            # level that was configured explicitly elsewhere.
            if logger.level == logging.NOTSET or logger.level > logging.WARNING:
                logger.setLevel(logging.WARNING)

        self.tools = [
            {
                "type": "function",
                "function": {
                    "name": "take_action",
                    "description": "Choose to 'hit' or 'stick' in Blackjack.",
                    "parameters": {
                        "action": {"type": "string", "enum": ["hit", "stick"]}
                    },
                },
            }
        ]

        tools_json = json.dumps(self.tools)
        # NOTE(review): the tag names below were missing from the prompt text.
        # <tool_call> matches the tag handed to parse_tool_call and <think>
        # matches the check in _score_response; the <tools> wrapper is assumed
        # from the usual function-calling prompt layout - confirm.
        self.system_prompt = (
            "You are an AI agent playing Blackjack who uses extreme long chains of thought "
            "to carefully consider the probabilities and optimal strategy. "
            "You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
            "You should enclose your thoughts and internal monologue inside <think> </think> tags, and then "
            "provide your decision using the take_action function call. You may use extremely long chains "
            "of thought to carefully consider the probabilities and optimal strategy.\n\n"
            f"<tools>\n{tools_json}\n</tools>\n\n"
            "For your function call, return a JSON object with function name and arguments "
            "within <tool_call> </tool_call> tags with the following schema:\n"
            '<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
            "Your full answer format should be:\n"
            "<think>\n[Your detailed reasoning process about whether to hit or stick]\n</think>\n\n"
            '<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n\n'
            "Remember to carefully consider the probabilities and optimal strategy for Blackjack."
        )

    def _get_or_create_episode(self, seed: int) -> EpisodeState:
        """Return the episode registered for `seed`, creating and resetting it if needed."""
        if seed not in self.episodes:
            env = gymnasium.make(self.config.env_name)
            try:
                obs, _ = env.reset(seed=seed)
            except Exception:
                # The env is not registered yet, so nobody else would close it.
                env.close()
                raise
            ep = EpisodeState(seed, env)
            ep.message_history = [{"role": "system", "content": self.system_prompt}]
            ep.message_history.append(
                {"role": "environment", "content": self._format_observation(obs)}
            )
            self.episodes[seed] = ep
        return self.episodes[seed]

    def _close_episode(self, seed: int) -> None:
        """Unregister the episode for `seed` (if any) and close its environment."""
        ep = self.episodes.pop(seed, None)
        if ep is None or getattr(ep, "env", None) is None:
            return
        try:
            ep.env.close()
        except Exception as e_close:
            logger.warning(
                f"[Close Episode Seed: {seed}] Exception closing environment for episode: {e_close}",
                exc_info=True,
            )

    def _format_observation(self, obs: Tuple[int, int, int]) -> str:
        """Render an observation tuple as the text shown to the agent."""
        player_sum, dealer_card, usable_ace = obs
        return (
            f"Your hand sum is {player_sum}. Dealer showing: {dealer_card}. "
            f"You have a usable ace: {usable_ace}."
        )

    def _score_response(
        self,
        env_reward: float,
        response_text: str,
        parsed_action: int,
    ) -> float:
        """
        Calculates a score for a single agent response from the environment reward
        plus format shaping terms.

        Args:
            env_reward: Raw reward returned by the environment for the action.
            response_text: Full agent response (prompt prefix + completion).
            parsed_action: 0/1 for a valid action, -1 if parsing failed.

        Returns:
            env_reward, +/-0.2 for a valid/invalid action, +/-0.2 for a
            non-blank/missing-or-blank think block, plus a small bonus that
            shrinks with response length.
        """
        current_env_reward = env_reward * 1.0

        # Action is good?
        if parsed_action == -1:
            current_env_reward -= 0.2
        else:
            current_env_reward += 0.2

        # Check the thinking tags exist
        match = _THINK_RE.search(response_text or "")
        if match:
            thinking_content = match.group(1)
            if thinking_content:
                current_env_reward += 0.2
            # Check there's actually valid content (not just whitespace)
            if not thinking_content.strip():
                current_env_reward -= 0.2
        else:
            current_env_reward -= 0.2

        # Calculate the number of tokens in the agent's response
        if response_text:
            num_tokens = len(self.tokenizer.encode(response_text))
        else:
            num_tokens = 0

        # tiebreak & small length penalty
        if self.config.max_token_length > 0:
            token_ratio = min(1.0, num_tokens / self.config.max_token_length)
            tiebreak_bonus = self.config.tiebreak_token_factor * (1.0 - token_ratio)
            current_env_reward += tiebreak_bonus

        return current_env_reward

    def _parse_tool_call(self, response: str) -> int:
        """Map a response to an env action: 1 = hit, 0 = stick, -1 = invalid."""
        if not response:
            logger.warning(
                "Attempted to parse an empty response string. Returning invalid action (-1)."
            )
            return -1

        parsed_name, parsed_args, is_error = parse_tool_call(
            response, self.tools, ["tool_call"]
        )

        if is_error:
            error_detail = (
                parsed_name
                if isinstance(parsed_name, str) and parsed_name
                else "Parser indicated error, but no specific message was returned in the typical error slot."
            )
            logger.warning(
                f"Failed to parse tool call. Full response: '{response}'. Error detail: {error_detail}"
            )
            return -1

        # The model may emit a non-string (or no) "action"; coerce instead of
        # letting .lower() raise, since callers invoke this outside a try block.
        args = parsed_args if isinstance(parsed_args, dict) else {}
        action = str(args.get("action", "")).lower()
        if action == "hit":
            return 1
        elif action == "stick":
            return 0
        else:
            logger.warning(
                f"Successfully parsed tool call, but action is invalid. Action: '{action}'. "
                f"Full response: '{response}'. Parsed args: {parsed_args}"
            )
            return -1

    async def _sample_response(self, messages: List[Dict], n: int = 1) -> List[str]:
        """Sample `n` completions for `messages`; returns [] if the API call fails."""
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        try:
            completions = await self.server.completion(
                prompt=prompt,
                n=n,
                max_tokens=self.config.max_token_length,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
            )
            return [choice.text for choice in completions.choices]
        except Exception as e:
            logger.error(f"API error during completion: {e}")
            return []

    async def _estimate_value(
        self,
        episode_seed_for_sim: int,
        env_actions_to_replay: List[int],
    ) -> float:
        """Calculate exact state value V*(s)

        A fresh env is reset with the episode seed and the given actions are
        replayed to reach state s. From there both actions are explored
        recursively on deep copies of the env, and the better outcome is
        returned.

        Args:
            episode_seed_for_sim: The seed of the original episode for deterministic env creation.
            env_actions_to_replay: List of environment actions (0 or 1) taken to reach current state s.

        Returns:
            max over actions of the simulated return from s. 0.0 if the replay
            already reaches a terminal state or if any error occurs.
        """
        v_star_cache: Dict[Tuple[int, int, int], float] = {}

        def _get_v_star_recursive(
            obs_tuple: Tuple[int, int, int], current_env: gymnasium.Env
        ) -> float:
            player_sum, _, _ = obs_tuple

            # Base Case 1: Bust
            if player_sum > 21:
                return -1.0

            # Base Case 2: Check memoization cache
            if obs_tuple in v_star_cache:
                return v_star_cache[obs_tuple]

            # Q-value for STICK (action 0). Copy so the caller's env is untouched.
            env_for_stick = copy.deepcopy(current_env)
            _, reward_stick, _, _, _ = env_for_stick.step(0)
            # stick is terminal, so reward is final outcome
            q_star_stick = reward_stick

            # Q-value for HIT (action 1)
            env_for_hit = copy.deepcopy(current_env)
            obs_hit, reward_hit, term_hit, trunc_hit, _ = env_for_hit.step(1)

            if term_hit or trunc_hit:
                # Game ended after hitting
                q_star_hit = reward_hit
            else:
                # Game continues, recursively find V* of next state
                # reward_hit is typically 0 if the game didn't end
                q_star_hit = reward_hit + _get_v_star_recursive(obs_hit, env_for_hit)

            v_star = max(q_star_stick, q_star_hit)
            v_star_cache[obs_tuple] = v_star
            return v_star

        sim_env = None
        try:
            sim_env = gymnasium.make(self.config.env_name)
            current_obs, _ = sim_env.reset(seed=episode_seed_for_sim)

            # Replay actions to reach the current state s_t
            for action_idx, prev_action in enumerate(env_actions_to_replay):
                current_obs, _, term_replay, trunc_replay, _ = sim_env.step(prev_action)
                if term_replay or trunc_replay:
                    logger.debug(
                        f"[_estimate_value] State became terminal during action replay "
                        f"(action {action_idx+1}/{len(env_actions_to_replay)} of prev_actions). Value is 0."
                    )
                    return 0.0

            return _get_v_star_recursive(current_obs, sim_env)

        except Exception as e:
            logger.error(
                f"[_estimate_value] Error during exact value"
                f" calculation for seed {episode_seed_for_sim}, "
                f"actions {env_actions_to_replay}: {e}",
                exc_info=True,
            )
            return 0.0
        finally:
            if sim_env is not None:
                sim_env.close()

    async def _next_step(
        self, ep: EpisodeState, current_turn: int, max_turns: int
    ) -> Tuple[Optional[BlackjackScoredDataGroup], bool]:
        """Process one step/turn of an episode.

        This involves estimating current state value, sampling multiple (G) responses
        from the LLM, evaluating each response by simulating its action, calculating
        advantages, selecting the best response/action, updating the episode state
        (message history, actions, rewards), and stepping the main environment.

        Args:
            ep: The current state of the episode.
            current_turn: The current turn number (0-indexed).
            max_turns: The maximum number of turns allowed in the episode.

        Returns:
            A tuple containing:
                - BlackjackScoredDataGroup: Data collected for this step (tokens, masks, advantages, etc.).
                                         None if a critical error occurred during the step.
                - bool: True if the episode terminated during this step, False otherwise.
        """
        G = self.config.group_size
        current_state_messages = ep.message_history.copy()
        logger.debug(
            f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}/{max_turns}] "
            f"Current state history length: {len(current_state_messages)}"
        )

        try:
            value_t = await self._estimate_value(
                episode_seed_for_sim=ep.seed, env_actions_to_replay=ep.actions
            )
            logger.debug(
                f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Estimated V(s_t) = {value_t:.4f}"
            )
        except Exception as e_vt:
            logger.error(
                f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Error estimating V(s_t): {e_vt}",
                exc_info=True,
            )
            return None, True  # Indicate error and episode termination

        messages_for_llm = current_state_messages.copy()
        agent_prompt_content = _THINK_PREFIX if self.config.thinking_active else ""
        messages_for_llm.append({"role": "agent", "content": agent_prompt_content})

        try:
            responses = await self._sample_response(messages_for_llm, n=G)
            if not responses or len(responses) != G:
                logger.error(
                    f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
                    f"Expected {G} responses, got {len(responses) if responses else 0}. "
                    f"Aborting step."
                )
                return None, True  # Indicate error and episode termination
        except Exception as e_sample:
            logger.error(
                f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Error sampling responses: {e_sample}",
                exc_info=True,
            )
            return None, True  # Indicate error and episode termination

        alt_full_responses: List[str] = []
        alt_parsed_actions: List[int] = []
        alt_env_actions: List[int] = []
        alt_combined_rewards: List[float] = []
        alt_next_state_msgs: List[List[Dict]] = []
        alt_is_terminal: List[bool] = []
        alt_tokens: List[List[int]] = []
        alt_masks: List[List[int]] = []

        for i, llm_output_only in enumerate(responses):
            full_agent_response = agent_prompt_content + llm_output_only
            parsed_action = self._parse_tool_call(full_agent_response)
            # Default to stick on parse error
            env_action = parsed_action if parsed_action != -1 else 0

            alt_full_responses.append(full_agent_response)
            alt_parsed_actions.append(parsed_action)
            alt_env_actions.append(env_action)

            current_state_plus_response_i = current_state_messages + [
                {"role": "agent", "content": full_agent_response}
            ]

            # Failure defaults. They are only overwritten once the whole
            # simulation succeeded, so every list below receives exactly one
            # entry per alternative and the lists stay index-aligned.
            combined_reward_i = -1.0
            is_terminal_i = True
            next_state_msgs_i = current_state_plus_response_i
            tokens_i: List[int] = []
            masks_i: List[int] = []

            sim_env_i = None
            try:
                sim_env_i = gymnasium.make(self.config.env_name)
                # replay env to same state as current episode
                sim_env_i.reset(seed=ep.seed)
                term_i, trunc_i = False, False
                raw_env_reward_i = 0.0
                sim_obs_next_i = None

                for prev_action_idx, prev_action in enumerate(ep.actions):
                    _, _, term_replay, trunc_replay, _ = sim_env_i.step(prev_action)
                    if term_replay or trunc_replay:
                        logger.error(
                            f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1} Alt: {i}] "
                            f"Sim env for alternative {i} terminated prematurely during history replay "
                            f"(action {prev_action_idx+1}/{len(ep.actions)}). State mismatch or unexpected termination."
                        )
                        term_i, trunc_i = True, True
                        break

                if not (term_i or trunc_i):
                    sim_obs_next_i, raw_env_reward_i, term_i, trunc_i, _ = (
                        sim_env_i.step(env_action)
                    )

                scored_i = self._score_response(
                    raw_env_reward_i, full_agent_response, parsed_action
                )

                msgs_i = current_state_plus_response_i
                if sim_obs_next_i is not None and not (term_i or trunc_i):
                    msgs_i = current_state_plus_response_i + [
                        {
                            "role": "environment",
                            "content": self._format_observation(sim_obs_next_i),
                        }
                    ]
                tokenized_i = tokenize_for_trainer(self.tokenizer, msgs_i)

                # Commit results only after every fallible call succeeded.
                combined_reward_i = scored_i
                is_terminal_i = term_i or trunc_i
                next_state_msgs_i = msgs_i
                tokens_i = tokenized_i["tokens"]
                masks_i = tokenized_i["masks"]

            except Exception as e_sim:
                logger.error(
                    f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1} Alt: {i}] "
                    f"Error simulating action {env_action} for alternative: {e_sim}",
                    exc_info=True,
                )
            finally:
                if sim_env_i is not None:
                    sim_env_i.close()

            alt_combined_rewards.append(combined_reward_i)
            alt_is_terminal.append(is_terminal_i)
            alt_next_state_msgs.append(next_state_msgs_i)
            alt_tokens.append(tokens_i)
            alt_masks.append(masks_i)

        # V(s') for each alternative; terminal successors are worth 0.
        alt_value_next: List[float] = []
        for i in range(G):
            if alt_is_terminal[i]:
                alt_value_next.append(0.0)
                continue
            try:
                actions_to_reach_s_prime = ep.actions + [alt_env_actions[i]]
                value_next_i = await self._estimate_value(
                    episode_seed_for_sim=ep.seed,
                    env_actions_to_replay=actions_to_reach_s_prime,
                )
                alt_value_next.append(value_next_i)
            except Exception as e_vn:
                logger.error(
                    f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1} Alt: {i}] "
                    f"Error estimating V(s') for alternative: {e_vn}",
                    exc_info=True,
                )
                alt_value_next.append(0.0)

        # Advantage: A(s, a) = r + V(s') - V(s)
        alt_advantages: List[float] = []
        for i, (combined_r, value_next) in enumerate(
            zip(alt_combined_rewards, alt_value_next)
        ):
            advantage_i = combined_r + value_next - value_t
            alt_advantages.append(advantage_i)
            logger.debug(
                f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1} Alt: {i}] "
                f"CombinedR={combined_r:.2f}, V_t={value_t:.2f}, "
                f"V_t+1={value_next:.2f} => Advantage={advantage_i:.2f}"
            )

        if not (
            len(alt_tokens) == G
            and len(alt_masks) == G
            and len(alt_advantages) == G
            and len(alt_next_state_msgs) == G
            and len(alt_parsed_actions) == G
        ):
            logger.error(
                f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
                f"Mismatch in alternative list lengths before creating ScoredDataGroup. "
                f"Tokens:{len(alt_tokens)}, Masks:{len(alt_masks)}, Adv:{len(alt_advantages)}, "
                f"Msgs:{len(alt_next_state_msgs)}, ParsedAct:{len(alt_parsed_actions)}. Expected {G} for all. "
                f"Aborting step."
            )
            return None, True

        current_step_data = BlackjackScoredDataGroup(
            seed=ep.seed,
            tokens=alt_tokens,
            masks=alt_masks,
            scores=alt_advantages,
            messages=alt_next_state_msgs,
            parsed_actions=alt_parsed_actions,
        )

        # Highest advantage wins; shorter tokenization breaks ties.
        alt_token_lengths = [len(tkns) if tkns else 0 for tkns in alt_tokens]
        best_advantage_idx = select_best_index(
            primary_scores=alt_advantages,
            secondary_scores=alt_token_lengths,
            primary_higher_is_better=True,
            secondary_lower_is_better=True,
        )

        # All per-alternative lists have length G here, so one bounds check
        # covers every lookup below.
        idx_in_range = best_advantage_idx < G
        chosen_advantage_for_log = (
            alt_advantages[best_advantage_idx] if idx_in_range else float("-inf")
        )
        chosen_token_length_for_log = (
            alt_token_lengths[best_advantage_idx] if idx_in_range else float("-inf")
        )
        logger.debug(
            f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
            f"Selected Alt {best_advantage_idx} "
            f"(Adv: {chosen_advantage_for_log}, "
            f"Tokens: {chosen_token_length_for_log}) "
            f"from {G} alternatives."
        )

        chosen_env_action = alt_env_actions[best_advantage_idx] if idx_in_range else 0
        chosen_full_response = (
            alt_full_responses[best_advantage_idx] if idx_in_range else ""
        )
        chosen_parsed_action = (
            alt_parsed_actions[best_advantage_idx] if idx_in_range else -1
        )

        logger.info(
            f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Chosen action to step env: "
            f"{chosen_env_action} (from Alt {best_advantage_idx} with "
            f"Adv {chosen_advantage_for_log})"
        )

        ep.num_total_actions += 1
        if chosen_parsed_action != -1:
            ep.num_correct_actions += 1

        response_for_history = truncate_thinking(
            chosen_full_response,
            self.tokenizer,
            self.config.max_think_chars_history,
        )
        ep.message_history.append({"role": "agent", "content": response_for_history})

        main_term_this_step, main_trunc_this_step = False, False
        try:
            (
                main_obs_next,
                main_reward_this_step,
                main_term_this_step,
                main_trunc_this_step,
                _,
            ) = ep.env.step(chosen_env_action)

            ep.actions.append(chosen_env_action)
            ep.step_rewards.append(main_reward_this_step)
            ep.num_steps += 1

            if main_obs_next is not None:
                ep.message_history.append(
                    {
                        "role": "environment",
                        "content": self._format_observation(main_obs_next),
                    }
                )
        except Exception as e_main_step:
            logger.error(
                f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
                f"Error stepping MAIN environment with chosen action {chosen_env_action}: {e_main_step}",
                exc_info=True,
            )
            main_term_this_step, main_trunc_this_step = True, True

        is_episode_terminal_this_step = main_term_this_step or main_trunc_this_step
        return current_step_data, is_episode_terminal_this_step

    async def score(
        self, rollout_group_data: List[BlackjackScoredDataGroup]
    ) -> List[Optional[BlackjackScoredDataGroup]]:
        """Pass through rollout data. The 'scores' field in BlackjackScoredDataGroup
        already contains the A*(s,a) advantages from the collection phase.

        If you wanted to play around with additional scoring metrics, you could do so here.
        Eg, bonuses for the specific winning action trajectory

        Args:
            rollout_group_data: List of BlackjackScoredDataGroup objects containing the collected rollout data.

        Returns:
            The same list, unmodified.
        """
        logger.info(f"[Score] Processing {len(rollout_group_data)} steps.")
        return rollout_group_data

    async def collect_trajectories(
        self, item: Tuple[int, int]
    ) -> Tuple[List[BlackjackScoredDataGroup], List[Tuple[int, int]]]:
        """Collect data for ONE FULL trajectory (episode) by repeatedly calling _next_step.

        This method initializes an episode, then iterates through turns, calling
        `_next_step` to get data for each turn. It accumulates this data and handles
        episode termination, final metric calculation, and cleanup.

        Args:
            item: Tuple containing the seed and the group index (group index currently unused here).

        Returns:
            Tuple of two lists:
                - List of BlackjackScoredDataGroup objects: Contains the collected data for each
                  step of the trajectory.
                - List of Tuple[int, int]: Backlog items (always empty in this implementation).
        """
        seed, _ = item
        G_config = self.config.group_size
        max_turns = self.config.max_turns or 5
        trajectory_data_for_trainer: List[BlackjackScoredDataGroup] = []

        logger.info(
            f"[Collect Trajectories Seed: {seed}] Starting new trajectory. "
            f"Group size G={G_config}, Max turns={max_turns}."
        )

        try:
            ep = self._get_or_create_episode(seed)
        except Exception as e:
            logger.error(
                f"[Collect Trajectories Seed: {seed}] Fatal error creating/getting episode: {e}",
                exc_info=True,
            )
            return [], []

        try:
            for turn_idx in range(max_turns):
                logger.debug(
                    f"[Collect Trajectories Seed: {seed}] Attempting turn {turn_idx + 1}/{max_turns}."
                )
                step_data, is_terminal_this_step = await self._next_step(
                    ep, turn_idx, max_turns
                )

                if step_data:
                    trajectory_data_for_trainer.append(step_data)
                else:
                    logger.error(
                        f"[Collect Trajectories Seed: {seed}] Turn {turn_idx + 1} failed to produce data."
                        " Terminating episode."
                    )
                    is_terminal_this_step = True

                if is_terminal_this_step:
                    final_reward_at_termination = (
                        sum(ep.step_rewards) if ep.step_rewards else 0.0
                    )
                    logger.info(
                        f"[Collect Trajectories Seed: {seed}] Episode ended at turn {turn_idx + 1}. "
                        f"Reason: step reported terminal. Total raw env reward: {final_reward_at_termination:.2f}"
                    )
                    break
            else:
                logger.info(
                    f"[Collect Trajectories Seed: {seed}] Episode reached max_turns ({max_turns})."
                )

            final_raw_reward = sum(ep.step_rewards) if ep.step_rewards else 0.0
            logger.info(
                f"[Collect Trajectories Seed: {seed}] Finished collecting trajectory. "
                f"Total steps in trajectory: {len(trajectory_data_for_trainer)}, "
                f"Actual turns in episode: {ep.num_steps}, "
                f"Final raw reward: {final_raw_reward:.2f}"
            )

            game_outcome = 0
            if final_raw_reward > 0:
                game_outcome = 1
            elif final_raw_reward < 0:
                game_outcome = -1

            episode_summary_metrics = {
                "seed": ep.seed,
                "total_reward": final_raw_reward,
                "num_steps": ep.num_steps,
                "num_correct_actions": ep.num_correct_actions,
                "num_total_actions": ep.num_total_actions,
                "game_outcome": game_outcome,
            }
            self.completed_episode_metrics_buffer.append(episode_summary_metrics)
            logger.debug(
                f"[Collect Trajectories Seed: {seed}] Added episode summary to buffer: {episode_summary_metrics}"
            )
        finally:
            # Always release the env and registry entry, even if a step raised.
            self._close_episode(seed)

        if not trajectory_data_for_trainer:
            logger.warning(
                f"[Collect Trajectories Seed: {seed}] Collected an empty trajectory (no valid steps)."
            )
            return [], []

        limited_trajectory_data = ensure_trajectory_token_limit(
            trajectory_data_for_trainer,
            self.tokenizer,
            self.config.max_trajectory_tokens,
        )

        if not limited_trajectory_data:
            logger.warning(
                f"[Collect Trajectories Seed: {seed}] Trajectory became empty after token limiting."
            )
            return [], []

        return limited_trajectory_data, []

    async def setup(self):
        """No setup work is needed for this environment."""
        pass

    async def get_next_item(self) -> Tuple[int, int]:
        """Return a random training seed in [0, 1000000] and a group index of 0."""
        return (random.randint(0, 1000000), 0)

    async def rollout_and_score_eval(self, seed: int) -> Dict[str, float]:
        """Run a single episode for evaluation with a single response per step."""
        ep = self._get_or_create_episode(seed)
        max_turns = self.config.max_turns or 5
        metrics = {
            "seed": seed,
            "total_reward": 0.0,
            "num_turns": 0,
            "num_correct_actions": 0,
            "num_invalid_actions": 0,
            "game_outcome": 0,
        }

        try:
            for turn in range(max_turns):
                messages = ep.message_history.copy()
                agent_prompt_content = (
                    _THINK_PREFIX if self.config.thinking_active else ""
                )
                messages.append({"role": "agent", "content": agent_prompt_content})

                responses = await self._sample_response(messages, n=1)
                if not responses:
                    logger.error(
                        f"[Eval Seed: {seed}, Turn: {turn+1}] No response. Aborting."
                    )
                    break

                llm_output_only = responses[0]
                full_agent_response = agent_prompt_content + llm_output_only

                action = self._parse_tool_call(full_agent_response)
                if action == -1:
                    metrics["num_invalid_actions"] += 1
                    action = 0  # Default to stick on parse error
                else:
                    metrics["num_correct_actions"] += 1

                try:
                    obs, reward, term, trunc, _ = ep.env.step(action)
                except Exception as e:
                    logger.error(f"[Eval Seed: {seed}, Turn: {turn+1}] Env error: {e}")
                    # Treat an env failure as a lost, finished episode.
                    term, trunc = True, False
                    reward = -1.0
                    obs = None

                metrics["total_reward"] += reward
                metrics["num_turns"] = turn + 1

                response_for_history = truncate_thinking(
                    full_agent_response,
                    self.tokenizer,
                    self.config.max_think_chars_history,
                )
                ep.message_history.append(
                    {"role": "agent", "content": response_for_history}
                )

                if obs is not None:
                    ep.message_history.append(
                        {
                            "role": "environment",
                            "content": self._format_observation(obs),
                        }
                    )

                if term or trunc:
                    metrics["game_outcome"] = int(reward)
                    logger.info(f"[Eval Seed: {seed}] Episode ended. Outcome: {reward}")
                    break
        finally:
            # pop-based cleanup: safe even if a concurrent eval with the same
            # random seed already removed the shared episode.
            self._close_episode(seed)

        return metrics

    async def evaluate(self, *args, **kwargs):
        """Run eval episodes concurrently and store aggregate metrics in self.eval_metrics."""
        if not self.config.use_wandb:
            logger.info("Skipping evaluation as wandb is not enabled.")
            return

        num_eval_episodes = self.config.eval_episodes
        logger.info(f"Starting evaluation for {num_eval_episodes} episodes.")

        # Eval seeds are drawn from a range disjoint from get_next_item's.
        eval_tasks = [
            self.rollout_and_score_eval(random.randint(1000001, 2000000))
            for _ in range(num_eval_episodes)
        ]
        all_metrics = await tqdm_asyncio.gather(*eval_tasks)

        valid_metrics = [m for m in all_metrics if m]
        if not valid_metrics:
            logger.warning("No valid evaluation metrics.")
            return

        num_completed = len(valid_metrics)
        avg_total_reward = sum(m["total_reward"] for m in valid_metrics) / num_completed
        avg_num_turns = sum(m["num_turns"] for m in valid_metrics) / num_completed

        total_correct = sum(m["num_correct_actions"] for m in valid_metrics)
        total_invalid = sum(m["num_invalid_actions"] for m in valid_metrics)
        total_actions = total_correct + total_invalid
        action_accuracy = total_correct / total_actions if total_actions > 0 else 0

        win_rate = (
            sum(1 for m in valid_metrics if m["game_outcome"] == 1) / num_completed
        )
        loss_rate = (
            sum(1 for m in valid_metrics if m["game_outcome"] == -1) / num_completed
        )
        draw_rate = (
            sum(1 for m in valid_metrics if m["game_outcome"] == 0) / num_completed
        )

        self.eval_metrics = [
            ("eval/avg_total_reward", avg_total_reward),
            ("eval/avg_num_turns", avg_num_turns),
            ("eval/action_accuracy", action_accuracy),
            ("eval/win_rate", win_rate),
            ("eval/loss_rate", loss_rate),
            ("eval/draw_rate", draw_rate),
            ("eval/num_completed_episodes", num_completed),
        ]
        logger.info(f"Evaluation completed. Metrics: {self.eval_metrics}")

    async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
        """Add buffered training-episode aggregates to `wandb_metrics`, then delegate to the base class."""
        if wandb_metrics is None:
            wandb_metrics = {}

        buffered = self.completed_episode_metrics_buffer
        if buffered:
            num_episodes = len(buffered)
            prefix = f"{self.wandb_prepend or 'blackjack'}_train"

            wandb_metrics[f"{prefix}/avg_episode_reward"] = (
                sum(m["total_reward"] for m in buffered) / num_episodes
            )
            wandb_metrics[f"{prefix}/avg_episode_steps"] = (
                sum(m["num_steps"] for m in buffered) / num_episodes
            )
            wandb_metrics[f"{prefix}/episode_win_rate"] = (
                sum(1 for m in buffered if m["game_outcome"] == 1) / num_episodes
            )
            wandb_metrics[f"{prefix}/num_episodes"] = num_episodes

        # Reset so each log call reports only episodes finished since the last one.
        self.completed_episode_metrics_buffer = []
        await super().wandb_log(wandb_metrics)

    @classmethod
    def config_init(cls) -> Tuple[BlackjackEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and server configs."""
        env_config = BlackjackEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=16,
            use_wandb=True,
            max_num_workers=128,
            rollout_server_url="http://localhost:8000",
            total_steps=2000,
            batch_size=1024,
            steps_per_eval=20,
            max_token_length=1024 * 16,
            inference_weight=1.0,
            # Matches the BlackjackEnvConfig default; the previous value named
            # an unrelated environment.
            wandb_name="blackjack",
            data_path_to_save_groups=None,
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            eval_limit_ratio=0.1,
            env_name="Blackjack-v1",
            temperature=0.7,
            top_p=0.9,
            max_turns=5,
            thinking_active=True,
            eval_episodes=100,
            max_think_chars_history=3000,
            max_trajectory_tokens=24576,
            debug_mode=False,
            tiebreak_token_factor=0.01,
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_requests_for_eval=256,
            )
        ]
        return env_config, server_configs

    @classmethod
    def cli(cls):
        """Run the command-line entry point inherited from BaseEnv."""
        super().cli()


if __name__ == "__main__":
    BlackjackEnv.cli()