import json
import logging
import random
from typing import Dict, List, Optional, Tuple

import gymnasium as gym

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
from atroposlib.type_definitions import Item, Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from atroposlib.utils.tool_call_parser import parse_tool_call

logger = logging.getLogger(__name__)

# Action encoding used by gymnasium's Blackjack-v1: 1 = hit, 0 = stick.
ACTION_HIT = 1
ACTION_STICK = 0
ACTION_MAP_TO_STR = {ACTION_HIT: "hit", ACTION_STICK: "stick"}
ACTION_STR_TO_INT = {v: k for k, v in ACTION_MAP_TO_STR.items()}


class BlackjackEnvNoThinkingConfig(BaseEnvConfig):
    """
    Configuration for the BlackjackEnvNoThinking environment.
    """

    env_name: str = "Blackjack-v1"  # gymnasium environment id
    max_episode_turns: int = 10  # hard cap on LLM turns per episode
    eval_episodes: int = 100  # number of episodes per evaluation pass


class BlackjackEnvNoThinking(BaseEnv):
    """Blackjack RL environment where the LLM replies with a bare tool call (no thinking block)."""

    name = "blackjack_no_thinking"
    env_config_cls = BlackjackEnvNoThinkingConfig

    def __init__(
        self,
        config: BlackjackEnvNoThinkingConfig,
        server_configs: List[OpenaiConfig],
        slurm: bool = True,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config: BlackjackEnvNoThinkingConfig = config
        # Rolling buffer of episode outcomes (+1/0/-1), flushed on each wandb_log.
        self.episode_outcomes_buffer: List[float] = []
        self.eval_metrics_custom: List[Tuple[str, float]] = []

        # Define tools available to the LLM
        self.tools = [
            {
                "type": "function",
                "function": {
                    "name": "take_action",
                    "description": "Choose to 'hit' or 'stick' in Blackjack.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "action": {"type": "string", "enum": ["hit", "stick"]}
                        },
                        "required": ["action"],
                    },
                },
            }
        ]
        tools_json = json.dumps(self.tools)

        # System prompt for tool calling. NOTE(fix): the <tools> / <tool_call>
        # tag markers were missing from the prompt even though
        # _parse_action_from_llm() expects the call wrapped in <tool_call> tags
        # (it passes ["tool_call"] to parse_tool_call) — restored here so the
        # instructions match what the parser accepts.
        self.system_prompt = (
            "You are an AI agent playing Blackjack. "
            "You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
            f"<tools>\n{tools_json}\n</tools>\n\n"
            "For your function call, return a JSON object with function name and arguments "
            "within <tool_call> </tool_call> tags with the following schema:\n"
            '<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
            "Your full answer format should be (NO THINKING BLOCK):\n"
            '<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n\n'
        )

    @classmethod
    def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
        """Return the default env config and inference-server configs for this environment."""
        env_config = BlackjackEnvNoThinkingConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=16,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            max_token_length=2048,
            wandb_name=cls.name,
            steps_per_eval=50,
            max_episode_turns=10,
            eval_episodes=100,
        )
        server_configs = [
            OpenaiConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=128,
            ),
        ]
        return env_config, server_configs

    def _format_observation(self, obs: Tuple[int, int, int]) -> str:
        """Converts a Blackjack observation to a human-readable string."""
        player_sum, dealer_card, usable_ace = obs
        return (
            f"Your current hand sum is {player_sum}. "
            f"The dealer is showing a {dealer_card}. "
            f"You have a usable ace: {'yes' if usable_ace else 'no'}."
        )

    def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
        """Parses the action from the LLM's tool_call response.

        Returns ACTION_HIT / ACTION_STICK, or None when the response is empty,
        the tool call cannot be parsed, the wrong tool is named, or the
        'action' argument is not one of 'hit'/'stick'.
        """
        if not llm_response:
            logger.warning(
                "Attempted to parse an empty LLM response. Returning invalid action (None)."
            )
            return None

        parsed_name, parsed_args, is_error = parse_tool_call(
            llm_response, self.tools, ["tool_call"]  # Expecting <tool_call> tags
        )

        if is_error:
            error_detail = (
                str(parsed_name)  # Error message is in parsed_name if is_error
                if parsed_name
                else "Parser indicated error, but no specific message was returned."
            )
            logger.warning(
                f"Failed to parse tool call. Full response: '{llm_response}'. Error: {error_detail}"
            )
            return None

        if parsed_name != "take_action":
            logger.warning(
                f"Expected tool call name 'take_action', but got '{parsed_name}'. Response: '{llm_response}'"
            )
            return None

        action_str = parsed_args.get("action", "").lower()
        if action_str == "hit":
            return ACTION_HIT
        elif action_str == "stick":
            return ACTION_STICK
        else:
            logger.warning(
                f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. "
                f"Full response: '{llm_response}'. Parsed args: {parsed_args}"
            )
            return None

    async def collect_trajectory(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataItem], List[Item]]:
        """
        Collects a single trajectory (episode) for the Blackjack environment.
        The LLM directly outputs 'hit' or 'stick'.
        The 'score' in ScoredDataItem is the final game outcome (+1, 0, -1).
        """
        seed = item["seed"]
        messages: List[Message] = []
        game_reward = 0.0

        try:
            env = gym.make(self.config.env_name)
        except Exception as e:
            logger.error(f"Failed to make environment {self.config.env_name}: {e}")
            return None, []

        try:
            obs, info = env.reset(seed=seed)
        except Exception as e:
            logger.error(f"Failed to reset environment with seed {seed}: {e}")
            env.close()
            return None, []

        # Use the class system_prompt
        messages.append({"role": "system", "content": self.system_prompt})
        current_obs_str = self._format_observation(obs)
        messages.append({"role": "user", "content": current_obs_str})

        async with self.server.dedicated_server() as server:
            for _ in range(self.config.max_episode_turns):
                # Stop before the rendered conversation would exceed the budget
                # (50-token margin for the next assistant turn).
                if (
                    len(self.tokenizer.apply_chat_template(messages, tokenize=False))
                    > self.config.max_token_length - 50
                ):
                    logger.warning(
                        f"[Seed: {seed}] Max token length reached, truncating episode."
                    )
                    break

                max_tokens_for_action = 512
                try:
                    chat_completions = await server.chat_completion(
                        messages=messages,
                        n=1,
                        max_tokens=max_tokens_for_action,
                        temperature=0.5,
                    )
                    llm_action_response = chat_completions.choices[0].message.content.strip()
                    # Log raw response
                    logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'")
                except Exception as e:
                    logger.error(f"[Seed: {seed}] LLM API error: {e}")
                    break

                messages.append({"role": "assistant", "content": llm_action_response})

                action = self._parse_action_from_llm(llm_action_response)
                if action is None:
                    # Unparseable/invalid action is penalized as a loss.
                    logger.warning(f"[Seed: {seed}] Invalid action parsed. Ending episode.")
                    game_reward = -1.0
                    break

                try:
                    obs, reward, terminated, truncated, _ = env.step(action)
                    game_reward = float(reward)
                except Exception as e:
                    logger.error(f"[Seed: {seed}] Error stepping env: {e}")
                    break

                if terminated or truncated:
                    break

                current_obs_str = self._format_observation(obs)
                messages.append({"role": "user", "content": current_obs_str})

        env.close()
        self.episode_outcomes_buffer.append(game_reward)

        tokenization_result = tokenize_for_trainer(
            tokenizer=self.tokenizer, chat=messages, train_on_all_assistant_turns=True
        )
        tokens = tokenization_result["tokens"]
        masks = tokenization_result["masks"]

        scored_data_item = ScoredDataItem(
            messages=messages if self.config.include_messages else None,
            tokens=tokens,
            masks=masks,
            scores=game_reward,
        )
        return scored_data_item, []

    async def get_next_item(self) -> Item:
        """Return the next training item: a fresh random seed for the env."""
        next_seed = random.randint(0, 1_000_000)
        return {"seed": next_seed}

    async def setup(self):
        """No external resources to prepare; just log startup."""
        logger.info(f"Setting up {self.name} environment.")

    async def evaluate(self, *args, **kwargs):
        """Run `eval_episodes` held-out episodes (seeds disjoint from training)
        and record win/loss/draw rates plus average reward."""
        logger.info(
            f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes."
        )
        wins = 0
        losses = 0
        draws = 0
        eval_outcomes: List[float] = []

        for i in range(self.config.eval_episodes):
            # Eval seeds are drawn from a range disjoint from training seeds.
            seed = random.randint(1_000_001, 2_000_000)
            item = {"seed": seed}
            scored_item_tuple = await self.collect_trajectory(item)
            if scored_item_tuple and scored_item_tuple[0]:
                outcome = scored_item_tuple[0]["scores"]
                eval_outcomes.append(outcome)
            else:
                logger.warning(
                    f"Evaluation episode {i+1} (seed {seed}) failed to produce data."
                )

        if not eval_outcomes:
            logger.warning("No evaluation episodes completed successfully.")
            self.eval_metrics_custom = []
            return

        for outcome in eval_outcomes:
            if outcome > 0:
                wins += 1
            elif outcome < 0:
                losses += 1
            else:
                draws += 1

        num_completed = len(eval_outcomes)
        win_rate = wins / num_completed if num_completed > 0 else 0
        loss_rate = losses / num_completed if num_completed > 0 else 0
        draw_rate = draws / num_completed if num_completed > 0 else 0
        avg_reward = sum(eval_outcomes) / num_completed if num_completed > 0 else 0

        self.eval_metrics_custom = [
            (f"{self.name}_eval/win_rate", win_rate),
            (f"{self.name}_eval/loss_rate", loss_rate),
            (f"{self.name}_eval/draw_rate", draw_rate),
            (f"{self.name}_eval/avg_reward", avg_reward),
            (f"{self.name}_eval/num_completed_episodes", num_completed),
        ]
        logger.info(
            f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}"
        )

    async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
        """Flush buffered training outcomes and any pending eval metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}

        if self.episode_outcomes_buffer:
            avg_training_reward = sum(self.episode_outcomes_buffer) / len(
                self.episode_outcomes_buffer
            )
            wandb_metrics[f"{self.name}_train/avg_episode_reward"] = avg_training_reward
            train_wins = sum(1 for r in self.episode_outcomes_buffer if r > 0)
            train_losses = sum(1 for r in self.episode_outcomes_buffer if r < 0)
            train_draws = sum(1 for r in self.episode_outcomes_buffer if r == 0)
            wandb_metrics[f"{self.name}_train/win_count"] = train_wins
            wandb_metrics[f"{self.name}_train/loss_count"] = train_losses
            wandb_metrics[f"{self.name}_train/draw_count"] = train_draws
            wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(
                self.episode_outcomes_buffer
            )
        self.episode_outcomes_buffer = []

        for key, value in self.eval_metrics_custom:
            wandb_metrics[key] = value
        self.eval_metrics_custom = []

        await super().wandb_log(wandb_metrics)


if __name__ == "__main__":
    BlackjackEnvNoThinking.cli()