diff --git a/atroposlib/utils/__init__.py b/atroposlib/utils/__init__.py index cd578734..98fd052e 100644 --- a/atroposlib/utils/__init__.py +++ b/atroposlib/utils/__init__.py @@ -3,31 +3,5 @@ Utility functions and classes for the atroposlib package. """ from .config_handler import ConfigHandler -from .message_history_utils import ( - strip_thinking, - truncate_thinking, - ensure_trajectory_token_limit, -) -from .tokenize_for_trainer import tokenize_for_trainer -from .tool_call_parser import parse_tool_call -from .advantages import ( - allclose_to_first, - compute_stats, - compute_discounted_returns, - compute_grpo_process_supervision_advantages, -) -from .best_of_n_selection import select_best_index -__all__ = [ - "ConfigHandler", - "strip_thinking", - "truncate_thinking", - "tokenize_for_trainer", - "parse_tool_call", - "allclose_to_first", - "compute_stats", - "compute_discounted_returns", - "compute_grpo_process_supervision_advantages", - "ensure_trajectory_token_limit", - "select_best_index", -] +__all__ = ["ConfigHandler"] diff --git a/environments/game_environments/gymnasium/blackjack_env_no_thinking.py b/environments/game_environments/gymnasium/blackjack_env_no_thinking.py index ccdc9cea..81a6c883 100644 --- a/environments/game_environments/gymnasium/blackjack_env_no_thinking.py +++ b/environments/game_environments/gymnasium/blackjack_env_no_thinking.py @@ -1,5 +1,6 @@ import logging from typing import Dict, List, Optional, Tuple +import json import gymnasium as gym import random @@ -7,6 +8,7 @@ import random from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem from atroposlib.type_definitions import Item, Message from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer +from atroposlib.utils.tool_call_parser import parse_tool_call logger = logging.getLogger(__name__) @@ -42,6 +44,39 @@ class BlackjackEnvNoThinking(BaseEnv): self.episode_outcomes_buffer: List[float] = [] self.eval_metrics_custom: List[Tuple[str, 
float]] = [] + # Define tools available to the LLM + self.tools = [ + { + "type": "function", + "function": { + "name": "take_action", + "description": "Choose to 'hit' or 'stick' in Blackjack.", + "parameters": { + # Parameters are implicitly defined by the arguments of the function call + # For this simple case, let's assume the LLM will provide arguments.action + # based on the prompt. A more robust schema would define 'action' here. + "type": "object", + "properties": { + "action": {"type": "string", "enum": ["hit", "stick"]} + }, + "required": ["action"], + }, + }, + } + ] + + tools_json = json.dumps(self.tools) + # Updated system prompt for tool calling + self.system_prompt = ( + "You are an AI agent playing Blackjack. " + "You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n" + f"<tools>\n{tools_json}\n</tools>\n\n" + "For your function call, return a JSON object with function name and arguments " + "within <tool_call></tool_call> tags with the following schema:\n" + '<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n' + "Your full answer format should be (NO THINKING BLOCK):\n" + '<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n' + ) @classmethod def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]: @@ -76,12 +111,45 @@ class BlackjackEnvNoThinking(BaseEnv): ) def _parse_action_from_llm(self, llm_response: str) -> Optional[int]: - """Parses 'hit' or 'stick' from the LLM response.""" - action_str = llm_response.strip().lower() - if action_str in ACTION_STR_TO_INT: - return ACTION_STR_TO_INT[action_str] - logger.warning(f"Could not parse action from LLM response: '{llm_response}'") - return None + """Parses the action from the LLM's tool_call response.""" + if not llm_response: + logger.warning( + "Attempted to parse an empty LLM response. Returning invalid action (None)." 
+ ) + return None + + parsed_name, parsed_args, is_error = parse_tool_call( + llm_response, self.tools, ["tool_call"] # Expecting <tool_call> + ) + + if is_error: + error_detail = ( + str(parsed_name) # Error message is in parsed_name if is_error + if parsed_name + else "Parser indicated error, but no specific message was returned." + ) + logger.warning( + f"Failed to parse tool call. Full response: '{llm_response}'. Error: {error_detail}" + ) + return None + + if parsed_name != "take_action": + logger.warning( + f"Expected tool call name 'take_action', but got '{parsed_name}'. Response: '{llm_response}'" + ) + return None + + action_str = parsed_args.get("action", "").lower() + if action_str == "hit": + return ACTION_HIT + elif action_str == "stick": + return ACTION_STICK + else: + logger.warning( + f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. " + f"Full response: '{llm_response}'. Parsed args: {parsed_args}" + ) + return None async def collect_trajectory( self, item: Item ) @@ -109,10 +177,8 @@ class BlackjackEnvNoThinking(BaseEnv): env.close() return None, [] - system_prompt = ( - "You are playing Blackjack. Respond with either 'hit' or 'stick'." 
- ) - messages.append({"role": "system", "content": system_prompt}) + # Use the class system_prompt + messages.append({"role": "system", "content": self.system_prompt}) current_obs_str = self._format_observation(obs) messages.append({"role": "user", "content": current_obs_str}) @@ -126,7 +192,7 @@ class BlackjackEnvNoThinking(BaseEnv): logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.") break - max_tokens_for_action = 10 + max_tokens_for_action = 512 try: chat_completions = await server.chat_completion( @@ -136,6 +202,7 @@ class BlackjackEnvNoThinking(BaseEnv): temperature=0.5, ) llm_action_response = chat_completions.choices[0].message.content.strip() + logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'") # Log raw response except Exception as e: logger.error(f"[Seed: {seed}] LLM API error: {e}") break diff --git a/environments/game_environments/gymnasium/blackjack_env_thinking.py b/environments/game_environments/gymnasium/blackjack_env_thinking.py index 3c03dd64..85396a41 100644 --- a/environments/game_environments/gymnasium/blackjack_env_thinking.py +++ b/environments/game_environments/gymnasium/blackjack_env_thinking.py @@ -27,13 +27,10 @@ from atroposlib.envs.base import ( OpenaiConfig, ScoredDataGroup, ) -from atroposlib.utils import ( - tokenize_for_trainer, - parse_tool_call, - truncate_thinking, - ensure_trajectory_token_limit, - select_best_index -) +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer +from atroposlib.utils.message_history_utils import truncate_thinking +from atroposlib.utils.tool_call_parser import parse_tool_call +from atroposlib.utils.best_of_n_selection import select_best_index logger = logging.getLogger(__name__) diff --git a/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py b/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py new file mode 100644 index 00000000..d942c878 --- /dev/null +++ 
b/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py @@ -0,0 +1,121 @@ +import asyncio +import logging +import os +import random +from typing import Optional + +from dotenv import load_dotenv + +from atroposlib.envs.base import EvalHandlingEnum, OpenaiConfig, ScoredDataItem +from environments.game_environments.gymnasium.blackjack_env_no_thinking import ( + BlackjackEnvNoThinking, + BlackjackEnvNoThinkingConfig, +) + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main(): + logger.info( + "Starting Blackjack (No Thinking) environment local debug runner" + ) + + env_config = BlackjackEnvNoThinkingConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=1, + use_wandb=False, + wandb_name="blackjack_no_thinking_local_debug", + max_num_workers=1, + rollout_server_url="http://localhost:8000", + total_steps=1, + batch_size=1, + steps_per_eval=0, + max_token_length=1024, + inference_weight=1.0, + data_path_to_save_groups=None, + eval_handling=EvalHandlingEnum.NONE, + eval_limit_ratio=0.0, + env_name="Blackjack-v1", + max_episode_turns=10, + eval_episodes=0, + ) + server_configs = [ + OpenaiConfig( + model_name="gpt-4.1-nano", + base_url="https://api.openai.com/v1", + api_key=os.getenv("OPENAI_API_KEY"), + num_requests_for_eval=0, + ) + ] + logger.info("Using hardcoded debug configuration for No Thinking Blackjack.") + logger.debug(f"Env Config: {env_config}") + logger.debug(f"Server Configs: {server_configs}") + + try: + env = BlackjackEnvNoThinking( + config=env_config, + server_configs=server_configs, + slurm=False, + testing=False, + ) + except Exception as e: + logger.exception(f"Failed to initialize BlackjackEnvNoThinking: {e}") + return + + logger.info("Running a single trajectory directly using collect_trajectory") + try: + await env.setup() + seed = random.randint(0, 1000000) + item_for_env = {"seed": seed} + logger.info(f"Using seed: {seed} for 
item: {item_for_env}") + + result_tuple = await env.collect_trajectory(item_for_env) + + scored_data_item: Optional[ScoredDataItem] = None + if result_tuple and result_tuple[0]: + scored_data_item = result_tuple[0] + logger.info( + f"Trajectory collection complete. Score: {scored_data_item.get('scores')}" + ) + if env_config.include_messages and scored_data_item.get('messages'): + logger.info("Collected Messages:") + for i, msg in enumerate(scored_data_item['messages']): + logger.info(f" {i}. Role: {msg['role']}, Content: '{str(msg['content'])[:150]}...'") + logger.info(f"Tokens ({len(scored_data_item.get('tokens', []))}): {str(scored_data_item.get('tokens'))[:100]}...") + logger.info(f"Masks ({len(scored_data_item.get('masks', []))}): {str(scored_data_item.get('masks'))[:100]}...") + else: + logger.error("Trajectory collection did not return a ScoredDataItem.") + + episode_summary_reward = None + if env.episode_outcomes_buffer: + episode_summary_reward = env.episode_outcomes_buffer[-1] + + if episode_summary_reward is not None: + logger.info("\n========== Episode Summary ==========") + logger.info(f"Seed: {seed}") + logger.info( + f"Final Environment reward (Score): {episode_summary_reward:.2f}" + ) + outcome_str = "Draw" + if episode_summary_reward > 0: + outcome_str = "Win" + elif episode_summary_reward < 0: + outcome_str = "Loss" + logger.info(f"Game Outcome: {outcome_str}") + logger.info("=======================================") + else: + logger.error( + f"Could not get episode summary for seed {seed} from metrics buffer." + ) + + except Exception as e: + logger.exception( + f"An error occurred during trajectory collection or summary: {e}" + ) + + +if __name__ == "__main__": + asyncio.run(main())