diff --git a/atroposlib/utils/__init__.py b/atroposlib/utils/__init__.py
index cd578734..98fd052e 100644
--- a/atroposlib/utils/__init__.py
+++ b/atroposlib/utils/__init__.py
@@ -3,31 +3,5 @@ Utility functions and classes for the atroposlib package.
"""
from .config_handler import ConfigHandler
-from .message_history_utils import (
- strip_thinking,
- truncate_thinking,
- ensure_trajectory_token_limit,
-)
-from .tokenize_for_trainer import tokenize_for_trainer
-from .tool_call_parser import parse_tool_call
-from .advantages import (
- allclose_to_first,
- compute_stats,
- compute_discounted_returns,
- compute_grpo_process_supervision_advantages,
-)
-from .best_of_n_selection import select_best_index
-__all__ = [
- "ConfigHandler",
- "strip_thinking",
- "truncate_thinking",
- "tokenize_for_trainer",
- "parse_tool_call",
- "allclose_to_first",
- "compute_stats",
- "compute_discounted_returns",
- "compute_grpo_process_supervision_advantages",
- "ensure_trajectory_token_limit",
- "select_best_index",
-]
+__all__ = ["ConfigHandler"]
diff --git a/environments/game_environments/gymnasium/blackjack_env_no_thinking.py b/environments/game_environments/gymnasium/blackjack_env_no_thinking.py
index ccdc9cea..81a6c883 100644
--- a/environments/game_environments/gymnasium/blackjack_env_no_thinking.py
+++ b/environments/game_environments/gymnasium/blackjack_env_no_thinking.py
@@ -1,5 +1,6 @@
import logging
from typing import Dict, List, Optional, Tuple
+import json
import gymnasium as gym
import random
@@ -7,6 +8,7 @@ import random
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
from atroposlib.type_definitions import Item, Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+from atroposlib.utils.tool_call_parser import parse_tool_call
logger = logging.getLogger(__name__)
@@ -42,6 +44,39 @@ class BlackjackEnvNoThinking(BaseEnv):
self.episode_outcomes_buffer: List[float] = []
self.eval_metrics_custom: List[Tuple[str, float]] = []
+        # Tool schema advertised to the LLM: a single `take_action` function.
+        self.tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "take_action",
+                    "description": "Choose to 'hit' or 'stick' in Blackjack.",
+                    "parameters": {
+                        # JSON-Schema object with one required string argument,
+                        # `action`, restricted by `enum` to "hit" or "stick"; the
+                        # model is expected to return it as arguments.action.
+                        "type": "object",
+                        "properties": {
+                            "action": {"type": "string", "enum": ["hit", "stick"]}
+                        },
+                        "required": ["action"],
+                    },
+                },
+            }
+        ]
+
+        tools_json = json.dumps(self.tools)
+        # Tool-calling system prompt; the <tool_call> tag named here must match the tag _parse_action_from_llm parses.
+        self.system_prompt = (
+            "You are an AI agent playing Blackjack. "
+            "You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
+            f"<tools>\n{tools_json}\n</tools>\n\n"
+            "For your function call, return a JSON object with function name and arguments "
+            "within <tool_call> </tool_call> tags with the following schema:\n"
+            '<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
+            "Your full answer format should be (NO THINKING BLOCK):\n"
+            '<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n'
+        )
@classmethod
def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
@@ -76,12 +111,45 @@ class BlackjackEnvNoThinking(BaseEnv):
)
def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
- """Parses 'hit' or 'stick' from the LLM response."""
- action_str = llm_response.strip().lower()
- if action_str in ACTION_STR_TO_INT:
- return ACTION_STR_TO_INT[action_str]
- logger.warning(f"Could not parse action from LLM response: '{llm_response}'")
- return None
+        """Parse the `take_action` tool call in the LLM response into an action int, or None if invalid."""
+        if not llm_response:
+            logger.warning(
+                "Attempted to parse an empty LLM response. Returning invalid action (None)."
+            )
+            return None
+
+        parsed_name, parsed_args, is_error = parse_tool_call(
+            llm_response, self.tools, ["tool_call"]  # payload is expected inside <tool_call> tags
+        )
+
+        if is_error:
+            error_detail = (
+                str(parsed_name)  # Error message is in parsed_name if is_error
+                if parsed_name
+                else "Parser indicated error, but no specific message was returned."
+            )
+            logger.warning(
+                f"Failed to parse tool call. Full response: '{llm_response}'. Error: {error_detail}"
+            )
+            return None
+
+        if parsed_name != "take_action":
+            logger.warning(
+                f"Expected tool call name 'take_action', but got '{parsed_name}'. Response: '{llm_response}'"
+            )
+            return None
+        # Arguments are model output: a non-dict payload or non-string action must not raise here.
+        raw_action = parsed_args.get("action") if isinstance(parsed_args, dict) else None
+        action_str = raw_action.strip().lower() if isinstance(raw_action, str) else ""
+        if action_str == "hit":
+            return ACTION_HIT
+        if action_str == "stick":
+            return ACTION_STICK
+        logger.warning(
+            f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. "
+            f"Full response: '{llm_response}'. Parsed args: {parsed_args}"
+        )
+        return None
async def collect_trajectory(
self, item: Item
@@ -109,10 +177,8 @@ class BlackjackEnvNoThinking(BaseEnv):
env.close()
return None, []
- system_prompt = (
- "You are playing Blackjack. Respond with either 'hit' or 'stick'."
- )
- messages.append({"role": "system", "content": system_prompt})
+        # Use the instance's tool-calling system prompt (self.system_prompt)
+ messages.append({"role": "system", "content": self.system_prompt})
current_obs_str = self._format_observation(obs)
messages.append({"role": "user", "content": current_obs_str})
@@ -126,7 +192,7 @@ class BlackjackEnvNoThinking(BaseEnv):
logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.")
break
- max_tokens_for_action = 10
+ max_tokens_for_action = 512
try:
chat_completions = await server.chat_completion(
@@ -136,6 +202,7 @@ class BlackjackEnvNoThinking(BaseEnv):
temperature=0.5,
)
llm_action_response = chat_completions.choices[0].message.content.strip()
+ logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'") # Log raw response
except Exception as e:
logger.error(f"[Seed: {seed}] LLM API error: {e}")
break
diff --git a/environments/game_environments/gymnasium/blackjack_env_thinking.py b/environments/game_environments/gymnasium/blackjack_env_thinking.py
index 3c03dd64..85396a41 100644
--- a/environments/game_environments/gymnasium/blackjack_env_thinking.py
+++ b/environments/game_environments/gymnasium/blackjack_env_thinking.py
@@ -27,13 +27,10 @@ from atroposlib.envs.base import (
OpenaiConfig,
ScoredDataGroup,
)
-from atroposlib.utils import (
- tokenize_for_trainer,
- parse_tool_call,
- truncate_thinking,
- ensure_trajectory_token_limit,
- select_best_index
-)
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+from atroposlib.utils.message_history_utils import truncate_thinking
+from atroposlib.utils.tool_call_parser import parse_tool_call
+from atroposlib.utils.best_of_n_selection import select_best_index
logger = logging.getLogger(__name__)
diff --git a/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py b/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py
new file mode 100644
index 00000000..d942c878
--- /dev/null
+++ b/environments/game_environments/gymnasium/blackjack_local_server_no_thinking.py
@@ -0,0 +1,121 @@
+import asyncio
+import logging
+import os
+import random
+from typing import Optional
+
+from dotenv import load_dotenv
+
+from atroposlib.envs.base import EvalHandlingEnum, OpenaiConfig, ScoredDataItem
+from environments.game_environments.gymnasium.blackjack_env_no_thinking import (
+ BlackjackEnvNoThinking,
+ BlackjackEnvNoThinkingConfig,
+)
+
+load_dotenv()
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+async def main():
+ logger.info(
+ "Starting Blackjack (No Thinking) environment local debug runner"
+ )
+
+ env_config = BlackjackEnvNoThinkingConfig(
+ tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+ group_size=1,
+ use_wandb=False,
+ wandb_name="blackjack_no_thinking_local_debug",
+ max_num_workers=1,
+ rollout_server_url="http://localhost:8000",
+ total_steps=1,
+ batch_size=1,
+ steps_per_eval=0,
+ max_token_length=1024,
+ inference_weight=1.0,
+ data_path_to_save_groups=None,
+ eval_handling=EvalHandlingEnum.NONE,
+ eval_limit_ratio=0.0,
+ env_name="Blackjack-v1",
+ max_episode_turns=10,
+ eval_episodes=0,
+ )
+ server_configs = [
+ OpenaiConfig(
+ model_name="gpt-4.1-nano",
+ base_url="https://api.openai.com/v1",
+ api_key=os.getenv("OPENAI_API_KEY"),
+ num_requests_for_eval=0,
+ )
+ ]
+ logger.info("Using hardcoded debug configuration for No Thinking Blackjack.")
+ logger.debug(f"Env Config: {env_config}")
+ logger.debug(f"Server Configs: {server_configs}")
+
+ try:
+ env = BlackjackEnvNoThinking(
+ config=env_config,
+ server_configs=server_configs,
+ slurm=False,
+ testing=False,
+ )
+ except Exception as e:
+ logger.exception(f"Failed to initialize BlackjackEnvNoThinking: {e}")
+ return
+
+ logger.info("Running a single trajectory directly using collect_trajectory")
+ try:
+ await env.setup()
+ seed = random.randint(0, 1000000)
+ item_for_env = {"seed": seed}
+ logger.info(f"Using seed: {seed} for item: {item_for_env}")
+
+ result_tuple = await env.collect_trajectory(item_for_env)
+
+ scored_data_item: Optional[ScoredDataItem] = None
+ if result_tuple and result_tuple[0]:
+ scored_data_item = result_tuple[0]
+ logger.info(
+ f"Trajectory collection complete. Score: {scored_data_item.get('scores')}"
+ )
+ if env_config.include_messages and scored_data_item.get('messages'):
+ logger.info("Collected Messages:")
+ for i, msg in enumerate(scored_data_item['messages']):
+ logger.info(f" {i}. Role: {msg['role']}, Content: '{str(msg['content'])[:150]}...'")
+ logger.info(f"Tokens ({len(scored_data_item.get('tokens', []))}): {str(scored_data_item.get('tokens'))[:100]}...")
+ logger.info(f"Masks ({len(scored_data_item.get('masks', []))}): {str(scored_data_item.get('masks'))[:100]}...")
+ else:
+ logger.error("Trajectory collection did not return a ScoredDataItem.")
+
+ episode_summary_reward = None
+ if env.episode_outcomes_buffer:
+ episode_summary_reward = env.episode_outcomes_buffer[-1]
+
+ if episode_summary_reward is not None:
+ logger.info("\n========== Episode Summary ==========")
+ logger.info(f"Seed: {seed}")
+ logger.info(
+ f"Final Environment reward (Score): {episode_summary_reward:.2f}"
+ )
+ outcome_str = "Draw"
+ if episode_summary_reward > 0:
+ outcome_str = "Win"
+ elif episode_summary_reward < 0:
+ outcome_str = "Loss"
+ logger.info(f"Game Outcome: {outcome_str}")
+ logger.info("=======================================")
+ else:
+ logger.error(
+ f"Could not get episode summary for seed {seed} from metrics buffer."
+ )
+
+ except Exception as e:
+ logger.exception(
+ f"An error occurred during trajectory collection or summary: {e}"
+ )
+
+
+if __name__ == "__main__":
+ asyncio.run(main())