diff --git a/environments/game_environments/gymnasium/blackjack_env_no_thinking.py b/environments/game_environments/gymnasium/blackjack_env_no_thinking.py
new file mode 100644
index 00000000..ccdc9cea
--- /dev/null
+++ b/environments/game_environments/gymnasium/blackjack_env_no_thinking.py
@@ -0,0 +1,265 @@
+import logging
+import random
+from typing import Dict, List, Optional, Tuple
+
+import gymnasium as gym
+
+from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
+from atroposlib.type_definitions import Item, Message
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+
+logger = logging.getLogger(__name__)
+
+ACTION_HIT = 1
+ACTION_STICK = 0
+ACTION_MAP_TO_STR = {ACTION_HIT: "hit", ACTION_STICK: "stick"}
+ACTION_STR_TO_INT = {v: k for k, v in ACTION_MAP_TO_STR.items()}
+
+
+class BlackjackEnvNoThinkingConfig(BaseEnvConfig):
+    """
+    Configuration for the BlackjackEnvNoThinking environment.
+    """
+
+    env_name: str = "Blackjack-v1"
+    max_episode_turns: int = 10
+    eval_episodes: int = 100
+
+
+class BlackjackEnvNoThinking(BaseEnv):
+    name = "blackjack_no_thinking"
+    env_config_cls = BlackjackEnvNoThinkingConfig
+
+    def __init__(
+        self,
+        config: BlackjackEnvNoThinkingConfig,
+        server_configs: List[OpenaiConfig],
+        slurm: bool = True,
+        testing: bool = False,
+    ):
+        super().__init__(config, server_configs, slurm, testing)
+        self.config: BlackjackEnvNoThinkingConfig = config
+        self.episode_outcomes_buffer: List[float] = []
+        self.eval_metrics_custom: List[Tuple[str, float]] = []
+
+    @classmethod
+    def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
+        env_config = BlackjackEnvNoThinkingConfig(
+            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+            group_size=16,
+            use_wandb=True,
+            rollout_server_url="http://localhost:8000",
+            max_token_length=2048,
+            wandb_name=cls.name,
+            steps_per_eval=50,
+            max_episode_turns=10,
+            eval_episodes=100,
+        )
+        server_configs = [
+            OpenaiConfig(
+                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=128,
+            ),
+        ]
+        return env_config, server_configs
+
+    def _format_observation(self, obs: Tuple[int, int, int]) -> str:
+        """Converts a Blackjack observation to a human-readable string."""
+        player_sum, dealer_card, usable_ace = obs
+        return (
+            f"Your current hand sum is {player_sum}. "
+            f"The dealer is showing a {dealer_card}. "
+            f"You have a usable ace: {'yes' if usable_ace else 'no'}."
+        )
+
+    def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
+        """Parses 'hit' or 'stick' from the LLM response."""
+        action_str = llm_response.strip().lower()
+        if action_str in ACTION_STR_TO_INT:
+            return ACTION_STR_TO_INT[action_str]
+        logger.warning(f"Could not parse action from LLM response: '{llm_response}'")
+        return None
+
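+    # Note: the parser above accepts only an exact 'hit' or 'stick' reply.
+    # If that proves brittle in practice (models may answer "Hit." or
+    # "I'll stick"), a lenient substring fallback is one possible extension:
+    #
+    #     lowered = llm_response.strip().lower()
+    #     for word, action in ACTION_STR_TO_INT.items():
+    #         if word in lowered:
+    #             return action
+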
+ """ + seed = item["seed"] + messages: List[Message] = [] + game_reward = 0.0 + num_turns = 0 + + try: + env = gym.make(self.config.env_name) + except Exception as e: + logger.error(f"Failed to make environment {self.config.env_name}: {e}") + return None, [] + + try: + obs, info = env.reset(seed=seed) + except Exception as e: + logger.error(f"Failed to reset environment with seed {seed}: {e}") + env.close() + return None, [] + + system_prompt = ( + "You are playing Blackjack. Respond with either 'hit' or 'stick'." + ) + messages.append({"role": "system", "content": system_prompt}) + + current_obs_str = self._format_observation(obs) + messages.append({"role": "user", "content": current_obs_str}) + + async with self.server.dedicated_server() as server: + for _ in range(self.config.max_episode_turns): + if ( + len(self.tokenizer.apply_chat_template(messages, tokenize=False)) + > self.config.max_token_length - 50 + ): + logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.") + break + + max_tokens_for_action = 10 + + try: + chat_completions = await server.chat_completion( + messages=messages, + n=1, + max_tokens=max_tokens_for_action, + temperature=0.5, + ) + llm_action_response = chat_completions.choices[0].message.content.strip() + except Exception as e: + logger.error(f"[Seed: {seed}] LLM API error: {e}") + break + + messages.append({"role": "assistant", "content": llm_action_response}) + + action = self._parse_action_from_llm(llm_action_response) + if action is None: + logger.warning(f"[Seed: {seed}] Invalid action parsed. Ending episode.") + game_reward = -1.0 + break + + try: + obs, reward, terminated, truncated, _ = env.step(action) + game_reward = float(reward) + except Exception as e: + logger.error(f"[Seed: {seed}] Error stepping env: {e}") + break + + if terminated or truncated: + break + + current_obs_str = self._format_observation(obs) + messages.append({"role": "user", "content": current_obs_str}) + + env.close() + self.episode_outcomes_buffer.append(game_reward) + + tokenization_result = tokenize_for_trainer( + tokenizer=self.tokenizer, + chat=messages, + train_on_all_assistant_turns=True + ) + + tokens = tokenization_result["tokens"] + masks = tokenization_result["masks"] + + scored_data_item = ScoredDataItem( + messages=messages if self.config.include_messages else None, + tokens=tokens, + masks=masks, + scores=game_reward, + ) + return scored_data_item, [] + + async def get_next_item(self) -> Item: + next_seed = random.randint(0, 1_000_000) + return {"seed": next_seed} + + async def setup(self): + logger.info(f"Setting up {self.name} environment.") + + async def evaluate(self, *args, **kwargs): + logger.info(f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes.") + + wins = 0 + losses = 0 + draws = 0 + + eval_outcomes: List[float] = [] + + for i in range(self.config.eval_episodes): + seed = random.randint(1_000_001, 2_000_000) + item = {"seed": seed} + scored_item_tuple = await self.collect_trajectory(item) + if scored_item_tuple and scored_item_tuple[0]: + outcome = scored_item_tuple[0]["scores"] + eval_outcomes.append(outcome) + else: + logger.warning(f"Evaluation episode {i+1} (seed {seed}) failed to produce data.") + + + if not eval_outcomes: + logger.warning("No evaluation episodes completed successfully.") + self.eval_metrics_custom = [] + return + + for outcome in eval_outcomes: + if outcome > 0: + wins += 1 + elif outcome < 0: + losses += 1 + else: + draws += 1 + + num_completed = len(eval_outcomes) + win_rate = wins 
+    async def get_next_item(self) -> Item:
+        next_seed = random.randint(0, 1_000_000)
+        return {"seed": next_seed}
+
+    async def setup(self):
+        logger.info(f"Setting up {self.name} environment.")
+
+    async def evaluate(self, *args, **kwargs):
+        logger.info(
+            f"Starting evaluation for {self.name} with "
+            f"{self.config.eval_episodes} episodes."
+        )
+
+        wins = 0
+        losses = 0
+        draws = 0
+
+        eval_outcomes: List[float] = []
+
+        for i in range(self.config.eval_episodes):
+            seed = random.randint(1_000_001, 2_000_000)
+            item = {"seed": seed}
+            scored_item_tuple = await self.collect_trajectory(item)
+            if scored_item_tuple and scored_item_tuple[0]:
+                outcome = scored_item_tuple[0]["scores"]
+                eval_outcomes.append(outcome)
+            else:
+                logger.warning(
+                    f"Evaluation episode {i+1} (seed {seed}) failed to produce data."
+                )
+
+        if not eval_outcomes:
+            logger.warning("No evaluation episodes completed successfully.")
+            self.eval_metrics_custom = []
+            return
+
+        for outcome in eval_outcomes:
+            if outcome > 0:
+                wins += 1
+            elif outcome < 0:
+                losses += 1
+            else:
+                draws += 1
+
+        num_completed = len(eval_outcomes)
+        win_rate = wins / num_completed
+        loss_rate = losses / num_completed
+        draw_rate = draws / num_completed
+        avg_reward = sum(eval_outcomes) / num_completed
+
+        self.eval_metrics_custom = [
+            (f"{self.name}_eval/win_rate", win_rate),
+            (f"{self.name}_eval/loss_rate", loss_rate),
+            (f"{self.name}_eval/draw_rate", draw_rate),
+            (f"{self.name}_eval/avg_reward", avg_reward),
+            (f"{self.name}_eval/num_completed_episodes", num_completed),
+        ]
+        logger.info(
+            f"Evaluation completed for {self.name}. "
+            f"Metrics: {self.eval_metrics_custom}"
+        )
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        if self.episode_outcomes_buffer:
+            avg_training_reward = sum(self.episode_outcomes_buffer) / len(
+                self.episode_outcomes_buffer
+            )
+            wandb_metrics[f"{self.name}_train/avg_episode_reward"] = avg_training_reward
+            train_wins = sum(1 for r in self.episode_outcomes_buffer if r > 0)
+            train_losses = sum(1 for r in self.episode_outcomes_buffer if r < 0)
+            train_draws = sum(1 for r in self.episode_outcomes_buffer if r == 0)
+            wandb_metrics[f"{self.name}_train/win_count"] = train_wins
+            wandb_metrics[f"{self.name}_train/loss_count"] = train_losses
+            wandb_metrics[f"{self.name}_train/draw_count"] = train_draws
+            wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(
+                self.episode_outcomes_buffer
+            )
+
+        self.episode_outcomes_buffer = []
+
+        for key, value in self.eval_metrics_custom:
+            wandb_metrics[key] = value
+        self.eval_metrics_custom = []
+
+        await super().wandb_log(wandb_metrics)
+
+
+if __name__ == "__main__":
+    BlackjackEnvNoThinking.cli()
diff --git a/environments/game_environments/gymnasium/blackjack_env.py b/environments/game_environments/gymnasium/blackjack_env_thinking.py
similarity index 97%
rename from environments/game_environments/gymnasium/blackjack_env.py
rename to environments/game_environments/gymnasium/blackjack_env_thinking.py
index 8f2e6518..3c03dd64 100644
--- a/environments/game_environments/gymnasium/blackjack_env.py
+++ b/environments/game_environments/gymnasium/blackjack_env_thinking.py
@@ -489,7 +489,7 @@ class BlackjackEnv(BaseEnv):
             advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
             # If we pass this then instead of raw scores, implicitly, we're
             # doing some credit assignment. Could maybe do bonus on a win too
-            # and apply with a discount factor to alts in winning trajectories
+            # and/or apply with a discount factor to alts in winning trajectories
             alt_advantages.append(advantage_i)
             logger.debug(
                 f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
@@ -663,6 +663,23 @@ class BlackjackEnv(BaseEnv):
     ) -> List[Optional[BlackjackScoredDataGroup]]:
         """Pass through rollout data. The 'scores' field in BlackjackScoredDataGroup
         already contains the A*(s,a) advantages from the collection phase.
+
+        Additional scoring heuristics could be applied here instead, e.g. a
+        bonus for the specific winning action trajectory.
+
+        Args:
+            rollout_group_data: the collected BlackjackScoredDataGroup objects.
+
+        Returns:
+            The same groups, with scores passed through unchanged.
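+
+        For example, a flat bonus on winning trajectories could be applied
+        here (illustrative sketch; 'won' stands in for however the final
+        game outcome is recovered from a group):
+
+            for group in rollout_group_data:
+                if won(group):
+                    group["scores"] = [s + 0.1 for s in group["scores"]]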
""" logger.info(f"[Score] Processing {len(rollout_group_data)} steps.") return rollout_group_data @@ -670,6 +679,16 @@ class BlackjackEnv(BaseEnv): async def collect_trajectories( self, item: Tuple[int, int] ) -> Tuple[List[BlackjackScoredDataGroup], List[Tuple[int, int]]]: + """Collect trajectories for training. + + Args: + item: Tuple containing the seed and the group index. + + Returns: + Tuple of two lists: + - List of BlackjackScoredDataGroup objects containing the collected rollout data. + - List of Tuple[int, int] objects for the backlog + """ seed, _ = item traj = await self._collect_trajectory(seed) if not traj: