import logging
from typing import Dict, List, Optional, Tuple
import json
import gymnasium as gym
import random
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
from atroposlib.type_definitions import Item, Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from atroposlib.utils.tool_call_parser import parse_tool_call
logger = logging.getLogger(__name__)
ACTION_HIT = 1
ACTION_STICK = 0
ACTION_MAP_TO_STR = {ACTION_HIT: "hit", ACTION_STICK: "stick"}
ACTION_STR_TO_INT = {v: k for k, v in ACTION_MAP_TO_STR.items()}
class BlackjackEnvNoThinkingConfig(BaseEnvConfig):
"""
Configuration for the BlackjackEnvNoThinking environment.
"""
env_name: str = "Blackjack-v1"
max_episode_turns: int = 10
eval_episodes: int = 100
class BlackjackEnvNoThinking(BaseEnv):
name = "blackjack_no_thinking"
env_config_cls = BlackjackEnvNoThinkingConfig
def __init__(
self,
config: BlackjackEnvNoThinkingConfig,
server_configs: List[OpenaiConfig],
slurm: bool = True,
testing: bool = False,
):
super().__init__(config, server_configs, slurm, testing)
self.config: BlackjackEnvNoThinkingConfig = config
self.episode_outcomes_buffer: List[float] = []
self.eval_metrics_custom: List[Tuple[str, float]] = []
# Define tools available to the LLM
self.tools = [
{
"type": "function",
"function": {
"name": "take_action",
"description": "Choose to 'hit' or 'stick' in Blackjack.",
"parameters": {
# Parameters are implicitly defined by the arguments of the function call
# For this simple case, let's assume the LLM will provide arguments.action
# based on the prompt. A more robust schema would define 'action' here.
"type": "object",
"properties": {
"action": {"type": "string", "enum": ["hit", "stick"]}
},
"required": ["action"],
},
},
}
]
tools_json = json.dumps(self.tools)
# Updated system prompt for tool calling
self.system_prompt = (
"You are an AI agent playing Blackjack. "
"You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
f"\n{tools_json}\n\n\n"
"For your function call, return a JSON object with function name and arguments "
"within tags with the following schema:\n"
'\n{"arguments": {"action": "hit"}, "name": "take_action"}\n\n\n'
"Your full answer format should be (NO THINKING BLOCK):\n"
'\n{"arguments": {"action": "stick"}, "name": "take_action"}\n\n'
)
@classmethod
def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
env_config = BlackjackEnvNoThinkingConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=16,
use_wandb=True,
rollout_server_url="http://localhost:8000",
max_token_length=2048,
wandb_name=cls.name,
steps_per_eval=50,
max_episode_turns=10,
eval_episodes=100,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=128,
),
]
return env_config, server_configs
def _format_observation(self, obs: Tuple[int, int, int]) -> str:
"""Converts a Blackjack observation to a human-readable string."""
player_sum, dealer_card, usable_ace = obs
return (
f"Your current hand sum is {player_sum}. "
f"The dealer is showing a {dealer_card}. "
f"You have a usable ace: {'yes' if usable_ace else 'no'}."
)
def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
"""Parses the action from the LLM's tool_call response."""
if not llm_response:
logger.warning(
"Attempted to parse an empty LLM response. Returning invalid action (None)."
)
return None
parsed_name, parsed_args, is_error = parse_tool_call(
llm_response, self.tools, ["tool_call"] # Expecting
)
if is_error:
error_detail = (
str(parsed_name) # Error message is in parsed_name if is_error
if parsed_name
else "Parser indicated error, but no specific message was returned."
)
logger.warning(
f"Failed to parse tool call. Full response: '{llm_response}'. Error: {error_detail}"
)
return None
if parsed_name != "take_action":
logger.warning(
f"Expected tool call name 'take_action', but got '{parsed_name}'. Response: '{llm_response}'"
)
return None
action_str = parsed_args.get("action", "").lower()
if action_str == "hit":
return ACTION_HIT
elif action_str == "stick":
return ACTION_STICK
else:
logger.warning(
f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. "
f"Full response: '{llm_response}'. Parsed args: {parsed_args}"
)
return None
async def collect_trajectory(
self, item: Item
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
"""
Collects a single trajectory (episode) for the Blackjack environment.
The LLM directly outputs 'hit' or 'stick'.
The 'score' in ScoredDataItem is the final game outcome (+1, 0, -1).
"""
seed = item["seed"]
messages: List[Message] = []
game_reward = 0.0
num_turns = 0
try:
env = gym.make(self.config.env_name)
except Exception as e:
logger.error(f"Failed to make environment {self.config.env_name}: {e}")
return None, []
try:
obs, info = env.reset(seed=seed)
except Exception as e:
logger.error(f"Failed to reset environment with seed {seed}: {e}")
env.close()
return None, []
# Use the class system_prompt
messages.append({"role": "system", "content": self.system_prompt})
current_obs_str = self._format_observation(obs)
messages.append({"role": "user", "content": current_obs_str})
async with self.server.dedicated_server() as server:
for _ in range(self.config.max_episode_turns):
if (
len(self.tokenizer.apply_chat_template(messages, tokenize=False))
> self.config.max_token_length - 50
):
logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.")
break
max_tokens_for_action = 512
try:
chat_completions = await server.chat_completion(
messages=messages,
n=1,
max_tokens=max_tokens_for_action,
temperature=0.5,
)
llm_action_response = chat_completions.choices[0].message.content.strip()
logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'") # Log raw response
except Exception as e:
logger.error(f"[Seed: {seed}] LLM API error: {e}")
break
messages.append({"role": "assistant", "content": llm_action_response})
action = self._parse_action_from_llm(llm_action_response)
if action is None:
logger.warning(f"[Seed: {seed}] Invalid action parsed. Ending episode.")
game_reward = -1.0
break
try:
obs, reward, terminated, truncated, _ = env.step(action)
game_reward = float(reward)
except Exception as e:
logger.error(f"[Seed: {seed}] Error stepping env: {e}")
break
if terminated or truncated:
break
current_obs_str = self._format_observation(obs)
messages.append({"role": "user", "content": current_obs_str})
env.close()
self.episode_outcomes_buffer.append(game_reward)
tokenization_result = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=messages,
train_on_all_assistant_turns=True
)
tokens = tokenization_result["tokens"]
masks = tokenization_result["masks"]
scored_data_item = ScoredDataItem(
messages=messages if self.config.include_messages else None,
tokens=tokens,
masks=masks,
scores=game_reward,
)
return scored_data_item, []
async def get_next_item(self) -> Item:
next_seed = random.randint(0, 1_000_000)
return {"seed": next_seed}
async def setup(self):
logger.info(f"Setting up {self.name} environment.")
async def evaluate(self, *args, **kwargs):
logger.info(f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes.")
wins = 0
losses = 0
draws = 0
eval_outcomes: List[float] = []
for i in range(self.config.eval_episodes):
seed = random.randint(1_000_001, 2_000_000)
item = {"seed": seed}
scored_item_tuple = await self.collect_trajectory(item)
if scored_item_tuple and scored_item_tuple[0]:
outcome = scored_item_tuple[0]["scores"]
eval_outcomes.append(outcome)
else:
logger.warning(f"Evaluation episode {i+1} (seed {seed}) failed to produce data.")
if not eval_outcomes:
logger.warning("No evaluation episodes completed successfully.")
self.eval_metrics_custom = []
return
for outcome in eval_outcomes:
if outcome > 0:
wins += 1
elif outcome < 0:
losses += 1
else:
draws += 1
num_completed = len(eval_outcomes)
win_rate = wins / num_completed if num_completed > 0 else 0
loss_rate = losses / num_completed if num_completed > 0 else 0
draw_rate = draws / num_completed if num_completed > 0 else 0
avg_reward = sum(eval_outcomes) / num_completed if num_completed > 0 else 0
self.eval_metrics_custom = [
(f"{self.name}_eval/win_rate", win_rate),
(f"{self.name}_eval/loss_rate", loss_rate),
(f"{self.name}_eval/draw_rate", draw_rate),
(f"{self.name}_eval/avg_reward", avg_reward),
(f"{self.name}_eval/num_completed_episodes", num_completed),
]
logger.info(f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}")
async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
if wandb_metrics is None:
wandb_metrics = {}
if self.episode_outcomes_buffer:
avg_training_reward = sum(self.episode_outcomes_buffer) / len(self.episode_outcomes_buffer)
wandb_metrics[f"{self.name}_train/avg_episode_reward"] = avg_training_reward
train_wins = sum(1 for r in self.episode_outcomes_buffer if r > 0)
train_losses = sum(1 for r in self.episode_outcomes_buffer if r < 0)
train_draws = sum(1 for r in self.episode_outcomes_buffer if r == 0)
wandb_metrics[f"{self.name}_train/win_count"] = train_wins
wandb_metrics[f"{self.name}_train/loss_count"] = train_losses
wandb_metrics[f"{self.name}_train/draw_count"] = train_draws
wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(self.episode_outcomes_buffer)
self.episode_outcomes_buffer = []
for key, value in self.eval_metrics_custom:
wandb_metrics[key] = value
self.eval_metrics_custom = []
await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
BlackjackEnvNoThinking.cli()