atropos/environments/game_environments/gymnasium/blackjack_env_thinking.py

#!/usr/bin/env python3
"""
BlackjackEnv: Trainer environment for Gymnasium Blackjack
This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
using function-call style actions. Extends BaseEnv.
Sort of inspired by VinePPO, but uses a recursive exact value calculation
instead of Monte Carlo sampling (because the action space is so small and
the environment is deterministic, plus short episode lengths).
"""
import copy
import json
import logging
import random
import re
from typing import Any, Dict, List, Optional, Tuple
import gymnasium
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
OpenaiConfig,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from atroposlib.utils.message_history_utils import (
ensure_trajectory_token_limit,  # used in _collect_trajectory; assumed to live in this module
truncate_thinking,
)
from atroposlib.utils.tool_call_parser import parse_tool_call
from atroposlib.utils.best_of_n_selection import select_best_index
logger = logging.getLogger(__name__)
class BlackjackEnvConfig(BaseEnvConfig):
env_name: str = "Blackjack-v1"
temperature: float = 0.7
top_p: float = 0.9
max_turns: Optional[int] = 5
wandb_name: str = "blackjack"
thinking_active: bool = True
eval_episodes: int = 100
max_think_chars_history: int = 3000
max_trajectory_tokens: int = 24576  # seq_len of RL trainer
debug_mode: bool = False
group_size: int = 16
tiebreak_token_factor: float = 0.01  # weight of the length-based tiebreak bonus in _score_response
class BlackjackScoredDataGroup(ScoredDataGroup):
seed: int
tokens: Optional[List[List[int]]] = None
masks: Optional[List[List[int]]] = None
scores: Optional[List[float]] = None
messages: Optional[List[List[Dict]]] = None
parsed_actions: Optional[List[int]] = None
class EpisodeState:
def __init__(self, seed: int, env: gymnasium.Env):
self.seed = seed
self.env = env
self.message_history: List[Dict] = []
self.actions: List[int] = []
self.step_rewards: List[float] = []
self.total_reward: float = 0.0
self.num_steps: int = 0
self.num_correct_actions: int = 0
self.num_total_actions: int = 0
class BlackjackEnv(BaseEnv):
def __init__(
self,
config: BlackjackEnvConfig,
server_configs: List[OpenaiConfig],
slurm: bool = True,
testing: bool = False,
):
super().__init__(config, server_configs, slurm, testing)
self.episodes: Dict[int, EpisodeState] = {}
self.debug_mode = config.debug_mode
self.completed_episode_metrics_buffer: List[Dict[str, float]] = []
if self.debug_mode:
logger.setLevel(logging.DEBUG)
else:
if logger.level == logging.NOTSET or logger.level > logging.WARNING:
logger.setLevel(logging.WARNING)
self.tools = [
{
"type": "function",
"function": {
"name": "take_action",
"description": "Choose to 'hit' or 'stick' in Blackjack.",
"parameters": {
"action": {"type": "string", "enum": ["hit", "stick"]}
},
},
}
]
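# The tool schema is serialised verbatim into the system prompt below and is also the
# schema parse_tool_call validates model responses against.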
tools_json = json.dumps(self.tools)
self.system_prompt = (
"You are an AI agent playing Blackjack who uses extreme long chains of thought "
"to carefully consider the probabilities and optimal strategy. "
"You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
"You should enclose your thoughts and internal monologue inside <think> </think> tags, and then "
"provide your decision using the take_action function call. You may use extremely long chains "
"of thought to carefully consider the probabilities and optimal strategy.\n\n"
f"<tools>\n{tools_json}\n</tools>\n\n"
"For your function call, return a JSON object with function name and arguments "
"within <tool_call> </tool_call> tags with the following schema:\n"
'<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
"Your full answer format should be:\n"
"<think>\n[Your detailed reasoning process about whether to hit or stick]\n</think>\n\n"
'<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n\n'
"Remember to carefully consider the probabilities and optimal strategy for Blackjack."
)
def _get_or_create_episode(self, seed: int) -> EpisodeState:
if seed not in self.episodes:
env = gymnasium.make(self.config.env_name)
obs, _ = env.reset(seed=seed)
ep = EpisodeState(seed, env)
ep.message_history = [{"role": "system", "content": self.system_prompt}]
ep.message_history.append(
{"role": "environment", "content": self._format_observation(obs)}
)
self.episodes[seed] = ep
return self.episodes[seed]
def _format_observation(self, obs: Tuple[int, int, int]) -> str:
player_sum, dealer_card, usable_ace = obs
return f"Your hand sum is {player_sum}. Dealer showing: {dealer_card}. You have a usable ace: {usable_ace}."
def _score_response(
self,
env_reward: float,
response_text: str,
parsed_action: int,
episode_seed: int,
) -> float:
"""
Calculates a score for a single agent response based purely on environment reward
and a penalty for invalid action format.
"""
current_env_reward = env_reward * 1.0
# Format bonus/penalty: a valid parsed action earns +0.2, an unparseable one costs 0.2
if parsed_action == -1:
current_env_reward -= 0.2
else:
current_env_reward += 0.2
# Reward non-empty <think> ... </think> content (re.DOTALL so the block may span lines)
match = re.search(r"<think>(.*?)</think>", response_text, re.DOTALL)
if match:
thinking_content = match.group(1)
if thinking_content:
current_env_reward += 0.2
# Check there's actually valid content (not just whitespace)
if not thinking_content.strip():
current_env_reward -= 0.2
else:
current_env_reward -= 0.2
# Calculate the number of tokens in the agent's response
if response_text:
num_tokens = len(self.tokenizer.encode(response_text))
else:
num_tokens = 0
# Length tiebreak: shorter responses earn a slightly larger bonus
if self.config.max_token_length > 0:
token_ratio = min(1.0, num_tokens / self.config.max_token_length)
tiebreak_bonus = self.config.tiebreak_token_factor * (1.0 - token_ratio)
current_env_reward += tiebreak_bonus
return current_env_reward
def _parse_tool_call(self, response: str) -> int:
if not response:
logger.warning(
"Attempted to parse an empty response string. Returning invalid action (-1)."
)
return -1
parsed_name, parsed_args, is_error = parse_tool_call(
response, self.tools, ["tool_call"]
)
if is_error:
error_detail = (
parsed_name
if isinstance(parsed_name, str) and parsed_name
else "Parser indicated error, but no specific message was returned in the typical error slot."
)
logger.warning(
f"Failed to parse tool call. Full response: '{response}'. Error detail: {error_detail}"
)
return -1
action = parsed_args.get("action", "").lower()
if action == "hit":
return 1
elif action == "stick":
return 0
else:
logger.warning(
f"Successfully parsed tool call, but action is invalid. Action: '{action}'. "
f"Full response: '{response}'. Parsed args: {parsed_args}"
)
return -1
async def _sample_response(self, messages: List[Dict], n: int = 1) -> List[str]:
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
try:
completions = await self.server.completion(
prompt=prompt,
n=n,
max_tokens=self.config.max_token_length,
temperature=self.config.temperature,
top_p=self.config.top_p,
)
return [choice.text for choice in completions.choices]
except Exception as e:
logger.error(f"API error during completion: {e}")
return []
async def _estimate_value(
self,
episode_seed_for_sim: int,
env_actions_to_replay: List[int],
) -> float:
"""Calculate exact state value V*(s)
Args:
episode_seed_for_sim: The seed of the original episode for deterministic env creation.
env_actions_to_replay: List of environment actions (0 or 1) taken to reach current state s.
"""
v_star_cache: Dict[Tuple[int, int, int], float] = {}
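# Bellman backup over the two actions: V*(s) = max(Q*(s, stick), Q*(s, hit)). Each Q-value is
# evaluated on a deep copy of the env so the dealer/deck RNG state is preserved, and results
# are memoised on the observation tuple.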
def _get_v_star_recursive(
obs_tuple: Tuple[int, int, int], current_env: gymnasium.Env
) -> float:
player_sum, dealer_card, usable_ace = obs_tuple
# Base Case 1: Bust
if player_sum > 21:
return -1.0
# Base Case 2: Check memoization cache
if obs_tuple in v_star_cache:
return v_star_cache[obs_tuple]
env_for_stick = copy.deepcopy(current_env)
_, reward_stick, _, _, _ = env_for_stick.step(0)
# stick is terminal, so reward is final outcome
q_star_stick = reward_stick
# Q-value for HIT (action 1)
env_for_hit = copy.deepcopy(current_env)
obs_hit, reward_hit, term_hit, trunc_hit, _ = env_for_hit.step(1)
if term_hit or trunc_hit: # Game ended after hitting
q_star_hit = reward_hit
else: # Game continues, recursively find V* of next state
# reward_hit is typically 0 if the game didn't end
q_star_hit = reward_hit + _get_v_star_recursive(obs_hit, env_for_hit)
v_star = max(q_star_stick, q_star_hit)
v_star_cache[obs_tuple] = v_star
return v_star
sim_env = None
try:
sim_env = gymnasium.make(self.config.env_name)
current_obs, _ = sim_env.reset(seed=episode_seed_for_sim)
# Replay actions to reach the current state s_t
is_terminal_after_replay = False
for action_idx, prev_action in enumerate(env_actions_to_replay):
current_obs, _, term_replay, trunc_replay, _ = sim_env.step(prev_action)
if term_replay or trunc_replay:
logger.debug(
f"[_estimate_value] State became terminal during action replay "
f"(action {action_idx+1}/{len(env_actions_to_replay)} of prev_actions). Value is 0."
)
is_terminal_after_replay = True
break
if is_terminal_after_replay:
return 0.0
final_v_star = _get_v_star_recursive(current_obs, sim_env)
return final_v_star
except Exception as e:
logger.error(
f"[_estimate_value] Error during exact value"
f" calculation for seed {episode_seed_for_sim}, "
f"actions {env_actions_to_replay}: {e}",
exc_info=True,
)
return 0.0
finally:
if sim_env is not None:
sim_env.close()
async def _collect_trajectory(self, seed: int) -> List[BlackjackScoredDataGroup]:
"""Collect ONE trajectory, sampling G alternatives per step and scoring each with an
exact-value advantage A(s, a) = r + V*(s') - V*(s)."""
G = self.config.group_size
max_turns = self.config.max_turns or 5
trajectory_data_for_trainer: List[BlackjackScoredDataGroup] = []
episode_summary_metrics: Optional[Dict[str, Any]] = None
logger.info(
f"[Collect Trajectory Seed: {seed}] Starting trajectory. Group size G={G}."
)
try:
ep = self._get_or_create_episode(seed)
except Exception as e:
logger.error(
f"[Collect Trajectory Seed: {seed}] Failed to create/get episode: {e}",
exc_info=True,
)
return []
for turn in range(max_turns):
current_state_messages = ep.message_history.copy()
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}/{max_turns}] "
f"Current state history length: {len(current_state_messages)}"
)
try:
value_t = await self._estimate_value(
episode_seed_for_sim=ep.seed, env_actions_to_replay=ep.actions
)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Estimated V(s_t) = {value_t:.4f}"
)
except Exception as e_vt:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Error estimating V(s_t): {e_vt}",
exc_info=True,
)
break
messages_for_llm = current_state_messages.copy()
agent_prompt_content = "<think>\n" if self.config.thinking_active else ""
messages_for_llm.append({"role": "agent", "content": agent_prompt_content})
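# Prefill the agent turn with an opening <think> tag so the model continues its reasoning
# in-place; the prefix is prepended back onto each sampled completion below.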
try:
responses = await self._sample_response(messages_for_llm, n=G)
if len(responses) != G:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Expected {G} responses, got {len(responses)}. "
f"Aborting trajectory."
)
break
except Exception as e_sample:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Error sampling responses: {e_sample}",
exc_info=True,
)
break
alt_full_responses: List[str] = []
alt_parsed_actions: List[int] = []
alt_env_actions: List[int] = []
alt_raw_rewards: List[float] = []
alt_combined_rewards: List[float] = []
alt_next_state_msgs: List[List[Dict]] = []
alt_is_terminal: List[bool] = []
alt_tokens: List[List[int]] = []
alt_masks: List[List[int]] = []
alt_value_next: List[float] = []
alt_advantages: List[float] = []
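# Evaluate each of the G sampled alternatives in its own simulated env: replay the episode's
# committed actions from the seed, then apply the alternative's action, so the main episode
# env is never perturbed.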
for i in range(G):
llm_output_only = responses[i]
full_agent_response = agent_prompt_content + llm_output_only
alt_full_responses.append(full_agent_response)
parsed_action = self._parse_tool_call(full_agent_response)
alt_parsed_actions.append(parsed_action)
env_action = parsed_action if parsed_action != -1 else 0
alt_env_actions.append(env_action)
sim_env = None
raw_env_reward_i = 0.0
term_i, trunc_i = False, False
sim_obs_next = None  # guard against an unbound reference if replay terminates early
next_state_msgs_i = []
try:
sim_env = gymnasium.make(self.config.env_name)
_, _ = sim_env.reset(seed=ep.seed)
for prev_action in ep.actions:
_, _, term_replay, trunc_replay, _ = sim_env.step(prev_action)
if term_replay or trunc_replay:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
f"Sim env terminated during replay. State mismatch?"
)
term_i, trunc_i = True, True
raw_env_reward_i = 0.0
break
if not (term_i or trunc_i):
sim_obs_next, raw_env_reward_i, term_i, trunc_i, _ = (
sim_env.step(env_action)
)
alt_raw_rewards.append(raw_env_reward_i)
alt_is_terminal.append(term_i or trunc_i)
combined_reward_i = self._score_response(
raw_env_reward_i, full_agent_response, parsed_action, ep.seed
)
alt_combined_rewards.append(combined_reward_i)
current_state_plus_response = current_state_messages + [
{"role": "agent", "content": full_agent_response}
]
if sim_obs_next is not None:
next_state_msgs_i = current_state_plus_response + [
{
"role": "environment",
"content": self._format_observation(sim_obs_next),
}
]
else:
next_state_msgs_i = current_state_plus_response
alt_next_state_msgs.append(next_state_msgs_i)
tokenized_i = tokenize_for_trainer(
self.tokenizer, next_state_msgs_i
)
alt_tokens.append(tokenized_i["tokens"])
alt_masks.append(tokenized_i["masks"])
except Exception as e_sim:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
f"Error simulating action {env_action}: {e_sim}",
exc_info=True,
)
alt_raw_rewards.append(0.0)
alt_combined_rewards.append(-1.0)
alt_next_state_msgs.append(
current_state_messages
+ [{"role": "agent", "content": full_agent_response}]
)
alt_is_terminal.append(True)
alt_tokens.append([])
alt_masks.append([])
finally:
if sim_env:
sim_env.close()
# Estimate V*(s') for each alternative's successor state (terminal successors get 0.0)
for i in range(G):
if not alt_is_terminal[i]:
try:
actions_to_reach_s_prime = ep.actions + [alt_env_actions[i]]
value_next_i = await self._estimate_value(
episode_seed_for_sim=ep.seed,
env_actions_to_replay=actions_to_reach_s_prime,
)
alt_value_next.append(value_next_i)
except Exception as e_vn:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
f"Error estimating V(s'): {e_vn}",
exc_info=True,
)
alt_value_next.append(0.0)
else:
alt_value_next.append(0.0)
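# One-step advantage for each alternative: A(s, a_i) = shaped_reward_i + V*(s'_i) - V*(s_t),
# with V*(s'_i) = 0 for terminal successors.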
for i in range(G):
advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
# Passing advantages instead of raw scores gives implicit per-step credit assignment.
# Possible extensions: a bonus for the winning trajectory's chosen action, or a discount
# factor applied to alternatives in winning trajectories.
alt_advantages.append(advantage_i)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
f"CombinedR={alt_combined_rewards[i]:.2f}, V_t={value_t:.2f}, "
f"V_t+1={alt_value_next[i]:.2f} => Advantage={advantage_i:.2f}"
)
if (
len(alt_tokens) != G
or len(alt_masks) != G
or len(alt_advantages) != G
or len(alt_next_state_msgs) != G
or len(alt_parsed_actions) != G
):
# sanity check
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Mismatch in alternative list lengths after processing. "
f"Aborting trajectory."
)
break
trajectory_data_for_trainer.append(
BlackjackScoredDataGroup(
seed=ep.seed,
tokens=alt_tokens,
masks=alt_masks,
scores=alt_advantages,
messages=alt_next_state_msgs,
parsed_actions=alt_parsed_actions,
)
)
# token lengths for tie-breaking during selection
alt_token_lengths = [len(tkns) for tkns in alt_tokens]
best_advantage_idx = select_best_index(
primary_scores=alt_advantages,
secondary_scores=alt_token_lengths,
primary_higher_is_better=True,
secondary_lower_is_better=True
)
chosen_advantage_for_log = alt_advantages[best_advantage_idx]
chosen_token_length_for_log = alt_token_lengths[best_advantage_idx]
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Selected Alt {best_advantage_idx} "
f"(Adv: {chosen_advantage_for_log:.2f}, "
f"Tokens: {chosen_token_length_for_log}) "
f"from {G} alternatives using select_best_index."
)
chosen_env_action = alt_env_actions[best_advantage_idx]
chosen_full_response = alt_full_responses[best_advantage_idx]
chosen_raw_env_reward = alt_raw_rewards[best_advantage_idx]
chosen_is_terminal = alt_is_terminal[best_advantage_idx]
chosen_parsed_action = alt_parsed_actions[best_advantage_idx]
logger.info(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Chosen action to step env: "
f"{chosen_env_action} (from Alt {best_advantage_idx} with "
f"Adv {chosen_advantage_for_log:.2f})"
)
ep.num_total_actions += 1
if chosen_parsed_action != -1:
ep.num_correct_actions += 1
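# Only the chosen alternative is committed to the real episode history (with its thinking
# truncated); all G alternatives were already stored above for the trainer.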
ep.message_history = current_state_messages
response_for_history = truncate_thinking(
chosen_full_response, self.tokenizer, self.config.max_think_chars_history
)
ep.message_history.append(
{"role": "agent", "content": response_for_history}
)
try:
main_obs, main_reward, main_term, main_trunc, _ = ep.env.step(
chosen_env_action
)
if abs(main_reward - chosen_raw_env_reward) > 1e-6:
logger.warning(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Mismatch between simulated reward ({chosen_raw_env_reward}) and "
f"main env step reward ({main_reward}) for chosen action {chosen_env_action}."
)
if (main_term or main_trunc) != chosen_is_terminal:
logger.warning(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Mismatch between simulated terminal state ({chosen_is_terminal}) and "
f"main env step terminal state ({(main_term or main_trunc)}) "
f"for chosen action {chosen_env_action}."
)
term = main_term
trunc = main_trunc
obs = main_obs
ep.actions.append(chosen_env_action)
ep.step_rewards.append(main_reward)
ep.num_steps += 1
if obs:
ep.message_history.append(
{
"role": "environment",
"content": self._format_observation(obs),
}
)
except Exception as e_main_step:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Error stepping MAIN environment with chosen action {chosen_env_action}: {e_main_step}",
exc_info=True,
)
term, trunc = True, True
if term or trunc:
ep.total_reward = sum(ep.step_rewards)
logger.info(
f"[Collect Trajectory Seed: {seed}] Trajectory ended. "
f"Term={term}, Trunc={trunc}. Total raw env reward: {ep.total_reward}"
)
break
final_raw_reward = sum(ep.step_rewards) if ep.step_rewards else 0.0
logger.info(
f"[Collect Trajectory Seed: {seed}] Finished collecting trajectory. "
f"Steps collected: {len(trajectory_data_for_trainer)}, "
f"Final raw reward: {final_raw_reward:.2f}"
)
if ep:
game_outcome = 0
if final_raw_reward > 0:
game_outcome = 1
elif final_raw_reward < 0:
game_outcome = -1
# Episode summary for the wandb training-metrics buffer
episode_summary_metrics = {
"seed": ep.seed,
"total_reward": final_raw_reward,
"num_steps": ep.num_steps,
"num_correct_actions": ep.num_correct_actions,
"num_total_actions": ep.num_total_actions,
"game_outcome": game_outcome,
}
self.completed_episode_metrics_buffer.append(episode_summary_metrics)
logger.debug(
f"[Collect Trajectory Seed: {seed}] Added episode summary to buffer: {episode_summary_metrics}"
)
if seed in self.episodes:
try:
self.episodes[seed].env.close()
except Exception as e_close:
logger.warning(
f"[Collect Trajectory Seed: {seed}] Exception closing final env: {e_close}"
)
del self.episodes[seed]
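# Trim the collected step groups to the configured token budget before returning them to the
# trainer (ensure_trajectory_token_limit is assumed to drop or shorten over-long steps).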
return ensure_trajectory_token_limit(
trajectory_data_for_trainer,
self.tokenizer,
self.config.max_trajectory_tokens,
)
async def score(
self, rollout_group_data: List[BlackjackScoredDataGroup]
) -> List[Optional[BlackjackScoredDataGroup]]:
"""Pass through rollout data. The 'scores' field in BlackjackScoredDataGroup
already contains the A*(s,a) advantages from the collection phase.
If you wanted to play around with additional scoring metrics, you could do so here.
Eg, bonuses for the specific winning action trajectory
Args:
rollout_group_data: List of BlackjackScoredDataGroup objects containing the collected rollout data.
Returns:
List of BlackjackScoredDataGroup objects with the scores field updated.
"""
logger.info(f"[Score] Processing {len(rollout_group_data)} steps.")
return rollout_group_data
async def collect_trajectories(
self, item: Tuple[int, int]
) -> Tuple[List[BlackjackScoredDataGroup], List[Tuple[int, int]]]:
"""Collect trajectories for training.
Args:
item: Tuple containing the seed and the group index.
Returns:
Tuple of two lists:
- List of BlackjackScoredDataGroup objects containing the collected rollout data.
- List of Tuple[int, int] objects for the backlog
"""
seed, _ = item
traj = await self._collect_trajectory(seed)
if not traj:
logger.warning(f"[Collect Trajectories] Empty trajectory for seed {seed}.")
return traj, []
async def setup(self):
pass
async def get_next_item(self) -> Tuple[int, int]:
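# Training seeds are drawn from [0, 1_000_000]; evaluate() samples from [1_000_001, 2_000_000]
# so train and eval episodes never share a seed.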
return (random.randint(0, 1000000), 0)
async def rollout_and_score_eval(self, seed: int) -> Dict[str, float]:
"""Run a single episode for evaluation with a single response per step."""
ep = self._get_or_create_episode(seed)
max_turns = self.config.max_turns or 5
metrics = {
"seed": seed,
"total_reward": 0.0,
"num_turns": 0,
"num_correct_actions": 0,
"num_invalid_actions": 0,
"game_outcome": 0,
}
for turn in range(max_turns):
messages = ep.message_history.copy()
agent_prompt_content = "<think>\n" if self.config.thinking_active else ""
messages.append({"role": "agent", "content": agent_prompt_content})
responses = await self._sample_response(messages, n=1)
if not responses:
logger.error(
f"[Eval Seed: {seed}, Turn: {turn+1}] No response. Aborting."
)
break
llm_output_only = responses[0]
full_agent_response = agent_prompt_content + llm_output_only
action = self._parse_tool_call(full_agent_response)
if action == -1:
metrics["num_invalid_actions"] += 1
action = 0
else:
metrics["num_correct_actions"] += 1
try:
obs, reward, term, trunc, _ = ep.env.step(action)
except Exception as e:
logger.error(f"[Eval Seed: {seed}, Turn: {turn+1}] Env error: {e}")
term = True
reward = -1.0
obs = None
metrics["total_reward"] += reward
metrics["num_turns"] = turn + 1
response_for_history = truncate_thinking(
full_agent_response, self.tokenizer, self.config.max_think_chars_history
)
ep.message_history.append(
{"role": "agent", "content": response_for_history}
)
if obs:
ep.message_history.append(
{"role": "environment", "content": self._format_observation(obs)}
)
if term or trunc:
metrics["game_outcome"] = int(reward)
logger.info(f"[Eval Seed: {seed}] Episode ended. Outcome: {reward}")
break
ep.env.close()
del self.episodes[seed]
return metrics
async def evaluate(self, *args, **kwargs):
if not self.config.use_wandb:
logger.info("Skipping evaluation as wandb is not enabled.")
return
num_eval_episodes = self.config.eval_episodes
logger.info(f"Starting evaluation for {num_eval_episodes} episodes.")
eval_tasks = [
self.rollout_and_score_eval(random.randint(1000001, 2000000))
for _ in range(num_eval_episodes)
]
all_metrics = await tqdm_asyncio.gather(*eval_tasks)
valid_metrics = [m for m in all_metrics if m]
if not valid_metrics:
logger.warning("No valid evaluation metrics.")
return
num_completed = len(valid_metrics)
avg_total_reward = sum(m["total_reward"] for m in valid_metrics) / num_completed
avg_num_turns = sum(m["num_turns"] for m in valid_metrics) / num_completed
total_correct = sum(m["num_correct_actions"] for m in valid_metrics)
total_invalid = sum(m["num_invalid_actions"] for m in valid_metrics)
total_actions = total_correct + total_invalid
action_accuracy = total_correct / total_actions if total_actions > 0 else 0
win_rate = (
sum(1 for m in valid_metrics if m["game_outcome"] == 1) / num_completed
)
loss_rate = (
sum(1 for m in valid_metrics if m["game_outcome"] == -1) / num_completed
)
draw_rate = (
sum(1 for m in valid_metrics if m["game_outcome"] == 0) / num_completed
)
self.eval_metrics = [
("eval/avg_total_reward", avg_total_reward),
("eval/avg_num_turns", avg_num_turns),
("eval/action_accuracy", action_accuracy),
("eval/win_rate", win_rate),
("eval/loss_rate", loss_rate),
("eval/draw_rate", draw_rate),
("eval/num_completed_episodes", num_completed),
]
logger.info(f"Evaluation completed. Metrics: {self.eval_metrics}")
async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
if wandb_metrics is None:
wandb_metrics = {}
if self.completed_episode_metrics_buffer:
num_episodes = len(self.completed_episode_metrics_buffer)
avg_reward = (
sum(m["total_reward"] for m in self.completed_episode_metrics_buffer)
/ num_episodes
)
avg_steps = (
sum(m["num_steps"] for m in self.completed_episode_metrics_buffer)
/ num_episodes
)
win_rate = (
sum(
1
for m in self.completed_episode_metrics_buffer
if m["game_outcome"] == 1
)
/ num_episodes
)
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_reward"
] = avg_reward
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_steps"
] = avg_steps
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/episode_win_rate"
] = win_rate
wandb_metrics[f"{self.wandb_prepend or 'blackjack'}_train/num_episodes"] = (
num_episodes
)
self.completed_episode_metrics_buffer = []
await super().wandb_log(wandb_metrics)
@classmethod
def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
env_config = BlackjackEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=16,
use_wandb=True,
max_num_workers=128,
rollout_server_url="http://localhost:8000",
total_steps=2000,
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 16,
inference_weight=1.0,
wandb_name="blackjack",
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
env_name="Blackjack-v1",
temperature=0.7,
top_p=0.9,
max_turns=5,
thinking_active=True,
eval_episodes=100,
max_think_chars_history=3000,
max_trajectory_tokens=24576,
debug_mode=False,
tiebreak_token_factor=0.01,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256,
)
]
return env_config, server_configs
@classmethod
def cli(cls):
super().cli()
if __name__ == "__main__":
BlackjackEnv.cli()