diff --git a/environments/game_environments/gymnasium/blackjack_env.py b/environments/game_environments/gymnasium/blackjack_env.py
index fecd759b..c6dbdf81 100644
--- a/environments/game_environments/gymnasium/blackjack_env.py
+++ b/environments/game_environments/gymnasium/blackjack_env.py
@@ -5,7 +5,9 @@ BlackjackEnv: Trainer environment for Gymnasium Blackjack
This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
using function-call style actions. Extends BaseEnv.
-Uses Monte Carlo sampling to estimate the value of the current state, similar to VinePPO
+Inspired by VinePPO, but uses a recursive exact value calculation instead of
+Monte Carlo sampling: the action space is tiny, episodes are short, and the
+environment is deterministic once seeded, so exact values are cheap to compute.
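+
+Concretely, for an observation s = (player_sum, dealer_card, usable_ace):
+
+    V*(s) = max(Q*(s, stick), Q*(s, hit))
+
+Q*(s, stick) is the immediate terminal reward from sticking; Q*(s, hit) adds
+the hit reward to V*(s') of the successor state, with busts as a terminal
+base case. Values are memoized per observation tuple.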
"""
import copy
@@ -16,7 +18,6 @@ import re
from typing import Any, Dict, List, Optional, Tuple
import gymnasium
-import numpy as np
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
@@ -44,7 +45,6 @@ class BlackjackEnvConfig(BaseEnvConfig):
max_trajectory_tokens: int = 24576
debug_mode: bool = False
group_size: int = 16
- mc_samples: int = 3
tiebreak_token_factor: float = 0.01
@@ -166,11 +166,10 @@ class BlackjackEnv(BaseEnv):
current_env_reward -= 0.2
# Calculate the number of tokens in the agent's response
- num_tokens = len(
- tokenize_for_trainer_multistep(
- self.tokenizer, [{"role": "agent", "content": response_text}]
- )["tokens"]
- )
+ if response_text:
+ num_tokens = len(self.tokenizer.encode(response_text))
+ else:
+ num_tokens = 0
# tiebreak & small length penalty
if self.config.max_token_length > 0:
@@ -232,7 +231,7 @@ class BlackjackEnv(BaseEnv):
episode_seed_for_sim: int,
env_actions_to_replay: List[int],
) -> float:
- """Calculate exact state value V*(s) using recursive calls and memoization.
+        """Calculate exact state value V*(s) via memoized recursion.
Args:
episode_seed_for_sim: The seed of the original episode for deterministic env creation.
@@ -240,7 +239,9 @@ class BlackjackEnv(BaseEnv):
"""
v_star_cache: Dict[Tuple[int, int, int], float] = {}
- def _get_v_star_recursive(obs_tuple: Tuple[int, int, int], current_env: gymnasium.Env) -> float:
+ def _get_v_star_recursive(
+ obs_tuple: Tuple[int, int, int], current_env: gymnasium.Env
+ ) -> float:
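+            # Memoized Bellman backup: V*(s) = max(Q*(s, stick), Q*(s, hit)).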
player_sum, dealer_card, usable_ace = obs_tuple
# Base Case 1: Bust
@@ -259,13 +260,13 @@ class BlackjackEnv(BaseEnv):
# Q-value for HIT (action 1)
env_for_hit = copy.deepcopy(current_env)
obs_hit, reward_hit, term_hit, trunc_hit, _ = env_for_hit.step(1)
-
- if term_hit or trunc_hit: # Game ended after hitting
+
+ if term_hit or trunc_hit: # Game ended after hitting
q_star_hit = reward_hit
- else: # Game continues, recursively find V* of next state
+ else: # Game continues, recursively find V* of next state
# reward_hit is typically 0 if the game didn't end
q_star_hit = reward_hit + _get_v_star_recursive(obs_hit, env_for_hit)
-
+
v_star = max(q_star_stick, q_star_hit)
v_star_cache[obs_tuple] = v_star
return v_star
@@ -286,7 +287,7 @@ class BlackjackEnv(BaseEnv):
)
is_terminal_after_replay = True
break
-
+
if is_terminal_after_replay:
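+                # The episode already ended while replaying the chosen actions,
+                # so its return is realized and no future value remains.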
return 0.0
final_v_star = _get_v_star_recursive(current_obs, sim_env)
@@ -294,10 +295,12 @@ class BlackjackEnv(BaseEnv):
except Exception as e:
logger.error(
- f"[_estimate_value] Error during exact value calculation for seed {episode_seed_for_sim}, actions {env_actions_to_replay}: {e}",
- exc_info=True
+ f"[_estimate_value] Error during exact value"
+ f" calculation for seed {episode_seed_for_sim}, "
+ f"actions {env_actions_to_replay}: {e}",
+ exc_info=True,
)
- return 0.0 # Return a default value on error
+ return 0.0
finally:
if sim_env is not None:
sim_env.close()
@@ -305,14 +308,13 @@ class BlackjackEnv(BaseEnv):
async def collect_trajectory(self, seed: int) -> List[BlackjackScoredDataGroup]:
"""Collect data for ONE trajectory, evaluating G alternatives per step using MC advantages."""
G = self.config.group_size
- K = self.config.mc_samples
max_turns = self.config.max_turns or 5
trajectory_data_for_trainer: List[BlackjackScoredDataGroup] = []
episode_summary_metrics: Optional[Dict[str, Any]] = None
logger.info(
- f"[Collect Trajectory Seed: {seed}] Starting trajectory. Group size G={G}, MC samples K={K}."
+ f"[Collect Trajectory Seed: {seed}] Starting trajectory. Group size G={G}."
)
try:
@@ -333,8 +335,7 @@ class BlackjackEnv(BaseEnv):
try:
value_t = await self._estimate_value(
- episode_seed_for_sim=ep.seed,
- env_actions_to_replay=ep.actions
+ episode_seed_for_sim=ep.seed, env_actions_to_replay=ep.actions
)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Estimated V(s_t) = {value_t:.4f}"
@@ -395,11 +396,9 @@ class BlackjackEnv(BaseEnv):
next_state_msgs_i = []
try:
sim_env = gymnasium.make(self.config.env_name)
- sim_obs, _ = sim_env.reset(seed=ep.seed)
+ _, _ = sim_env.reset(seed=ep.seed)
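+                    # Replay the episode's past actions to rebuild the current
+                    # state before simulating this alternative; only the
+                    # termination flags matter here.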
for prev_action in ep.actions:
- sim_obs, _, term_replay, trunc_replay, _ = sim_env.step(
- prev_action
- )
+ _, _, term_replay, trunc_replay, _ = sim_env.step(prev_action)
if term_replay or trunc_replay:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
@@ -468,7 +467,7 @@ class BlackjackEnv(BaseEnv):
actions_to_reach_s_prime = ep.actions + [alt_env_actions[i]]
value_next_i = await self._estimate_value(
episode_seed_for_sim=ep.seed,
- env_actions_to_replay=actions_to_reach_s_prime
+ env_actions_to_replay=actions_to_reach_s_prime,
)
alt_value_next.append(value_next_i)
except Exception as e_vn:
@@ -483,6 +482,9 @@ class BlackjackEnv(BaseEnv):
for i in range(G):
advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
+                # Passing advantages (A(s_t, a_i) = r_i + V*(s'_i) - V*(s_t))
+                # instead of raw scores implicitly performs credit assignment.
+                # A win bonus, discounted across the alternatives of winning
+                # trajectories, could be layered on top.
alt_advantages.append(advantage_i)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
@@ -497,6 +499,7 @@ class BlackjackEnv(BaseEnv):
or len(alt_next_state_msgs) != G
or len(alt_parsed_actions) != G
):
+                # Sanity check: every per-alternative list must have exactly G entries.
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Mismatch in alternative list lengths after processing. "
@@ -509,38 +512,31 @@ class BlackjackEnv(BaseEnv):
seed=ep.seed,
tokens=alt_tokens,
masks=alt_masks,
- scores=alt_combined_rewards,
+ scores=alt_advantages,
messages=alt_next_state_msgs,
parsed_actions=alt_parsed_actions,
)
)
- # Determine the best alternative: highest advantage, tie-broken by shortest token length.
- if G == 0: # Should ideally not occur with G = self.config.group_size > 0
- logger.error(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"No alternatives to choose from (G=0). Aborting turn."
- )
- break # Exit the current turn processing.
-
# Prepare items for sorting: (-advantage, token_length, original_index)
- # Sorting this list will place the best alternative (highest advantage, then shortest tokens) first.
sortable_alternatives = []
for i in range(G):
# alt_tokens[i] is expected to be a list (possibly empty)
token_len = len(alt_tokens[i])
sortable_alternatives.append((-alt_advantages[i], token_len, i))
-
- sortable_alternatives.sort() # Sorts in-place
-
+
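+            # Lexicographic sort: highest advantage first, ties broken by the
+            # shortest response in tokens.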
+ sortable_alternatives.sort()
+
best_advantage_idx = sortable_alternatives[0][2]
-
+
# Log details of the selected alternative based on the sort
- chosen_advantage_for_log = -sortable_alternatives[0][0] # Revert sign for logging
+ chosen_advantage_for_log = -sortable_alternatives[0][0]
chosen_token_length_for_log = sortable_alternatives[0][1]
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Selected Alt {best_advantage_idx} (Adv: {chosen_advantage_for_log:.2f}, Tokens: {chosen_token_length_for_log}) "
+ f"Selected Alt {best_advantage_idx} "
+ f"(Adv: {chosen_advantage_for_log:.2f}, "
+ f"Tokens: {chosen_token_length_for_log}) "
f"from {G} alternatives using sort."
)
@@ -570,7 +566,7 @@ class BlackjackEnv(BaseEnv):
)
try:
- main_obs, main_reward, main_term, main_trunc, main_info = ep.env.step(
+ main_obs, main_reward, main_term, main_trunc, _ = ep.env.step(
chosen_env_action
)
if abs(main_reward - chosen_raw_env_reward) > 1e-6:
@@ -630,7 +626,7 @@ class BlackjackEnv(BaseEnv):
game_outcome = 1
elif final_raw_reward < 0:
game_outcome = -1
-
+            # Summarize the finished episode for the aggregate-metrics buffer.
episode_summary_metrics = {
"seed": ep.seed,
"total_reward": final_raw_reward,
@@ -658,7 +654,9 @@ class BlackjackEnv(BaseEnv):
async def score(
self, rollout_group_data: List[BlackjackScoredDataGroup]
) -> List[Optional[BlackjackScoredDataGroup]]:
- """Return rollout data with advantages as scores."""
+ """Pass through rollout data. The 'scores' field in BlackjackScoredDataGroup
+ already contains the A*(s,a) advantages from the collection phase.
+ """
logger.info(f"[Score] Processing {len(rollout_group_data)} steps.")
return rollout_group_data
diff --git a/environments/game_environments/gymnasium/blackjack_no_mc_env.py b/environments/game_environments/gymnasium/blackjack_no_mc_env.py
deleted file mode 100644
index fffb6aeb..00000000
--- a/environments/game_environments/gymnasium/blackjack_no_mc_env.py
+++ /dev/null
@@ -1,1665 +0,0 @@
-"""
-BlackjackEnv: Trainer environment for Gymnasium Blackjack
-
-This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
-using function-call style actions. Extends BaseEnv.
-
-Alternative formulation of BlackjackEnv that uses a best-of-n approach to select actions
-and no Monte Carlo sampling (direct bonus for winning trajectory). Much faster to train,
-but may not be as effective at learning correct strategy (it's effectively a series of bandits).
-"""
-
-import json
-import logging
-import random
-from typing import Any, Dict, List, Optional, Tuple
-
-import gymnasium
-from tqdm.asyncio import tqdm_asyncio
-
-from atroposlib.envs.base import (
- BaseEnv,
- BaseEnvConfig,
- EvalHandlingEnum,
- OpenaiConfig,
- ScoredDataGroup,
-)
-from atroposlib.type_definitions import Message
-from atroposlib.utils.tokenize_for_trainer import (
- UNMASKED_ROLES,
- tokenize_for_trainer_multistep,
-)
-from atroposlib.utils.tool_call_parser import parse_tool_call
-
-logger = logging.getLogger(__name__)
-
-
-class BlackjackEnvConfig(BaseEnvConfig):
- """
- Configuration for the Blackjack environment trainer.
- """
-
- env_name: str = "Blackjack-v1"
- temperature: float = 0.7
- top_p: float = 0.9
- max_turns: Optional[int] = 5
- wandb_name: str = "blackjack"
-
- thinking_active: bool = True
- eval_episodes: int = 100
-
- batch_size: int = 1024
- max_think_chars_history: int = 3000
- max_trajectory_tokens: int = 24576
- debug_mode: bool = False
-
-
-class BlackjackScoredDataGroup(ScoredDataGroup):
- """
- Represents the scored data for a single step in a Blackjack trajectory, potentially including multiple alternatives.
- """
-
- seed: int
- tokens: Optional[List[List[int]]] = None
- masks: Optional[List[List[int]]] = None
- scores: Optional[List[float]] = None
- messages: Optional[List[List[Message]]] = None
- parsed_action: Optional[int] = None
-
-
-class EpisodeState:
- """
- Stores per-episode state: gym env, history, actions, rewards, trajectory.
- """
-
- def __init__(self, seed: int, env: gymnasium.Env):
- self.seed: int = seed
- self.env: gymnasium.Env = env
- self.message_history: List[Message] = []
- self.actions: List[int] = []
- self.step_rewards: List[float] = []
- self.trajectory: List[BlackjackScoredDataGroup] = []
- self.total_env_reward: float = 0.0
- self.num_correct_actions: int = 0
- self.num_total_actions: int = 0
-
-
-class BlackjackEnv(BaseEnv):
- """
- Trainer environment for Gymnasium Blackjack using a best-of-n approach with function-call style actions.
- """
-
- def __init__(
- self,
- config: BlackjackEnvConfig,
- server_configs: List[OpenaiConfig],
- slurm: bool = True,
- testing: bool = False,
- ):
- super().__init__(config, server_configs, slurm, testing)
- self.episodes: Dict[int, EpisodeState] = {}
- self.debug_mode = config.debug_mode
- self.completed_episode_metrics_buffer: List[Dict[str, Any]] = []
-
- if self.debug_mode:
- logger.setLevel(logging.DEBUG)
- else:
- if logger.level == logging.NOTSET or logger.level > logging.WARNING:
- logger.setLevel(logging.WARNING)
-
- self.tools = [
- {
- "type": "function",
- "function": {
- "name": "take_action",
- "description": "Choose to 'hit' or 'stick' in Blackjack.",
- "parameters": {
- "action": {"type": "string", "enum": ["hit", "stick"]}
- },
- },
- }
- ]
-
- tools_json = json.dumps(self.tools)
- self.system_prompt = (
- "You are an AI agent playing Blackjack who uses extreme long chains of thought "
- "to carefully consider the probabilities and optimal strategy. "
- "You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
-            "You should enclose your thoughts and internal monologue inside <think> </think> tags, and then "
- "provide your decision using the take_action function call. You may use extremely long chains "
- "of thought to carefully consider the probabilities and optimal strategy.\n\n"
-            f"<tools>\n{tools_json}\n</tools>\n\n"
- "For your function call, return a JSON object with function name and arguments "
-            "within <tool_call> </tool_call> tags with the following schema:\n"
-            '<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
- "Your answer format should be:\n"
-            "<think>\n"
-            "[Your detailed reasoning process about whether to hit or stick]\n"
-            "</think>\n\n"
-            '<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n\n'
- "Remember to carefully consider the probabilities and optimal strategy for Blackjack."
- )
-
- def _get_or_create_episode(self, seed: int) -> EpisodeState:
- """Retrieve existing or create a new episode state keyed by seed."""
- if seed not in self.episodes:
- env = gymnasium.make(self.config.env_name)
- obs, _ = env.reset(seed=seed)
- ep = EpisodeState(seed, env)
- ep.message_history = [{"role": "system", "content": self.system_prompt}]
- formatted = self._format_observation(obs)
- ep.message_history.append({"role": "environment", "content": formatted})
- self.episodes[seed] = ep
- return self.episodes[seed]
-
- def _format_observation(self, obs: Tuple[int, int, int]) -> str:
- """Convert Blackjack observation to text for LLM."""
- player_sum, dealer_card, usable_ace = obs
- return (
- f"Your hand sum is {player_sum}. "
- f"Dealer showing: {dealer_card}. "
- f"You have a usable ace: {usable_ace}."
- )
-
- def _parse_tool_call(self, response: str) -> int:
- """Extract 'hit'/'stick' and map to action 1/0."""
- tool_name, arguments, is_error = parse_tool_call(
- response, self.tools, ["tool_call"]
- )
-
- logger.warning(
- f"Parsed tool call: name={tool_name}, args={arguments}, error={is_error}"
- )
-
- if is_error:
- logger.warning(f"Failed to parse tool call from response: {response}")
- return -1
-
- action = arguments.get("action", "").lower()
- if action == "hit":
- return 1
- elif action == "stick":
- return 0
- else:
- logger.warning(f"Invalid action value: {action}")
- return -1
-
- def _score_response(
- self,
- env_reward: float,
- response_text: str,
- parsed_action: int,
- episode_seed: int,
- ) -> float:
- """
- Calculates a score for a single agent response based purely on environment reward
- and a penalty for invalid action format.
- """
- current_env_reward = env_reward
-
- if parsed_action == -1:
- current_env_reward -= 0.5
- logger.debug(
- f"[_score_response Seed: {episode_seed}] Penalty applied for invalid action format (-0.5)."
- )
-
- final_score = current_env_reward
-
- logger.debug(
- f"[_score_response Seed: {episode_seed}] Final Score Calculation: "
- f"Env Reward (raw): {env_reward:.4f}, "
- f"Env Reward (adjusted for invalid): {current_env_reward:.4f}, "
- f"==> Final Score (from env): {final_score:.4f}"
- )
- return final_score
-
- async def _select_best_action(
- self, episode: EpisodeState, actions: List[int], responses: List[str]
- ) -> Tuple[int, List[float]]:
- """
- Simulates and scores multiple candidate actions to select the best one.
-
- Args:
- episode: The current episode state.
- actions: A list of parsed actions (0, 1, or -1) corresponding to the responses.
-            responses: A list of full agent responses (<think>...</think><tool_call>...</tool_call>).
-
- Returns:
- A tuple containing:
- - The best action selected (0, 1, or -1).
- - A list of scores for each action/response.
- """
- if len(actions) != len(responses):
- logger.error(
- f"[_select_best_action Seed: {episode.seed}] "
- f"Mismatch between actions ({len(actions)}) and responses ({len(responses)}) count."
- )
- default_action = next((a for a in actions if a != -1), -1)
- return default_action, [-10.0] * len(actions)
-
- scores = [0.0] * len(actions)
- token_lengths = [0] * len(actions)
-
- try:
- for idx, (action, response_text) in enumerate(zip(actions, responses)):
- sim_env = gymnasium.make(self.config.env_name)
- sim_obs, sim_info = sim_env.reset(seed=episode.seed)
- valid_sim = True
- for past_action in episode.actions:
- sim_obs, _, term, trunc, sim_info = sim_env.step(past_action)
- if term or trunc:
- logger.warning(
- f"[_select_best_action Seed: {episode.seed}] "
- f"Episode terminated during history replay before simulating action {idx}. "
- f"Assigning low score."
- )
- valid_sim = False
- break
- if not valid_sim:
- scores[idx] = -10.0
- continue
-
- if action == -1:
- env_reward_sim = 0.0
- else:
- _obs_sim, env_reward_sim, term_sim, trunc_sim, _info_sim = (
- sim_env.step(action)
- )
- logger.debug(
- f"[_select_best_action Seed: {episode.seed}] Sim Action {idx} "
- f"(val:{action}) -> Reward:{env_reward_sim}, Term:{term_sim}"
- )
-
- combined_score = self._score_response(
- env_reward=env_reward_sim,
- response_text=response_text,
- parsed_action=action,
- episode_seed=episode.seed,
- )
- scores[idx] = combined_score
- token_lengths[idx] = len(self.tokenizer.encode(response_text))
-
- except Exception as e:
- logger.exception(
- f"[_select_best_action Seed: {episode.seed}] "
- f"Error during action simulation/scoring: {e}"
- )
- default_action = next((a for a in actions if a != -1), -1)
- return default_action, [-10.0] * len(actions)
-
- best_score = float("-inf")
- best_action = -1
- best_action_idx = -1
-
- if scores:
- best_score = max(scores)
- potential_best_indices = [
- i for i, score in enumerate(scores) if score == best_score
- ]
-
- valid_indices = [i for i in potential_best_indices if actions[i] != -1]
- if valid_indices:
- if len(valid_indices) > 1:
- try:
- best_action_idx = min(
- valid_indices, key=lambda i: token_lengths[i]
- )
- logger.debug(
- f"[_select_best_action Seed: {episode.seed}] "
- f"Tie-breaking valid actions based on token length. Chosen index: {best_action_idx}"
- )
- except IndexError:
- logger.warning(
- f"[_select_best_action Seed: {episode.seed}] "
- f"IndexError during token length tie-breaking. Defaulting to first valid index."
- )
- best_action_idx = valid_indices[0]
- else:
- best_action_idx = valid_indices[0]
- elif potential_best_indices:
- best_action_idx = potential_best_indices[0]
- logger.debug(
- f"[_select_best_action Seed: {episode.seed}] "
- f"All best scores correspond to invalid actions. Choosing first: index {best_action_idx}"
- )
- else:
- logger.error(
- f"[_select_best_action Seed: {episode.seed}] "
- f"No potential best indices found despite scores existing. Returning default action -1."
- )
- best_action_idx = -1
-
- if best_action_idx != -1:
- best_action = actions[best_action_idx]
- else:
- best_action = -1
-
- logger.info(
- f"[_select_best_action Seed: {episode.seed}] Selected action: {best_action} "
- f"(Index: {best_action_idx}, "
- f"Score: {scores[best_action_idx] if best_action_idx != -1 else 'N/A'}) "
- f"from scores: {['{:.4f}'.format(s) for s in scores]}"
- )
- else:
- logger.error(
- f"[_select_best_action Seed: {episode.seed}] No scores calculated. Returning default action -1."
- )
-
- return best_action, scores
-
- async def collect_trajectory(self, seed: int) -> List[BlackjackScoredDataGroup]:
- """
- Run a single episode from the given seed, using a best-of-n approach each step.
- Refactored to use _select_best_action.
- Returns a list of BlackjackScoredDataGroup, one per time step.
- """
- ep = self._get_or_create_episode(seed)
- max_turns = self.config.max_turns if self.config.max_turns is not None else 5
- logger.info(
- f"[Collect Trajectory Seed: {seed}] Starting episode. Max turns: {max_turns}"
- )
-
- for turn in range(max_turns):
- logger.debug(
- f"[Collect Trajectory Seed: {seed}] Starting Turn {turn + 1}/{max_turns}"
- )
- messages_for_prompt = ep.message_history.copy()
-
- if self.config.thinking_active:
-                messages_for_prompt.append({"role": "agent", "content": "<think>\n"})
- else:
- messages_for_prompt.append({"role": "agent", "content": ""})
-
- prompt = self.tokenizer.apply_chat_template(
- messages_for_prompt, tokenize=False
- )
- logger.debug(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Prompting LLM..."
- )
-
- try:
- completions = await self.server.completion(
- prompt=prompt,
- n=self.config.group_size,
- max_tokens=self.config.max_token_length,
- temperature=self.config.temperature,
- top_p=self.config.top_p,
- )
- except Exception as api_error:
- logger.exception(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"API Error during self.server.completion: {api_error}"
- )
- return self._ensure_trajectory_token_limit(ep.trajectory)
-
- if (
- not completions
- or not completions.choices
- or len(completions.choices) != self.config.group_size
- ):
- logger.error(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"API did not return the expected number of choices "
- f"({self.config.group_size} vs {len(completions.choices) if completions else 0}). "
- f"Aborting episode."
- )
- return self._ensure_trajectory_token_limit(ep.trajectory)
-
- alt_actions: List[int] = []
- alt_responses: List[str] = []
- for choice_idx, choice in enumerate(completions.choices):
- response_text = (
- choice.text
- if hasattr(choice, "text")
- else getattr(choice.message, "content", "")
- )
- full_response = (
-                    ("<think>\n" + response_text)
- if self.config.thinking_active
- else response_text
- )
- alt_responses.append(full_response)
-
- parsed_act = self._parse_tool_call(full_response)
- alt_actions.append(parsed_act)
- logger.debug(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Choice {choice_idx}: Parsed Action={parsed_act}, Response Length={len(full_response)}"
- )
-
- logger.debug(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Selecting best action..."
- )
- best_action, scores = await self._select_best_action(
- ep, alt_actions, alt_responses
- )
-
- best_action_idx = -1
- try:
- best_score_val = max(scores)
- possible_indices = [
- i
- for i, (act, score) in enumerate(zip(alt_actions, scores))
- if act == best_action and score == best_score_val
- ]
- if possible_indices:
- best_action_idx = possible_indices[0]
- logger.info(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Best action selected: {best_action} "
- f"(Index: {best_action_idx}), "
- f"Score: {scores[best_action_idx]:.4f}"
- )
- else:
- logger.warning(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Could not find index for best action {best_action} with score {best_score_val}. "
- f"Trying first occurrence of action."
- )
- best_action_idx = alt_actions.index(best_action)
- logger.info(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Fallback - Best action selected: {best_action} (Index: {best_action_idx}), "
- f"Score: {scores[best_action_idx]:.4f}"
- )
-
- best_response = alt_responses[best_action_idx]
- except (ValueError, IndexError) as e:
- logger.error(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Error finding index for best action {best_action}: {e}. "
- f"Cannot proceed with episode."
- )
- if seed in self.episodes:
- try:
- self.episodes[seed].env.close()
- except Exception as close_exc:
- logger.warning(
- f"[Collect Trajectory Seed: {seed}] "
- f"Exception closing env for aborted episode on "
- f"best_action index error: {close_exc}"
- )
- del self.episodes[seed]
- return self._ensure_trajectory_token_limit(ep.trajectory)
-
- alt_tokens: List[List[int]] = []
- alt_masks: List[List[int]] = []
- alt_messages: List[List[Message]] = []
- tokenization_failed_for_step = False
- for response in alt_responses:
- step_msgs: List[Message] = [
- {"role": m["role"], "content": m["content"]}
- for m in ep.message_history
- ]
- step_msgs.append({"role": "agent", "content": response})
-
- try:
- out = tokenize_for_trainer_multistep(self.tokenizer, step_msgs)
- alt_tokens.append(out["tokens"])
- alt_masks.append(out["masks"])
- alt_messages.append(step_msgs)
- except Exception as tokenization_error:
- logger.exception(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Critical tokenization error for response: {response[:100]}... "
- f"Error: {tokenization_error}. Aborting episode."
- )
- tokenization_failed_for_step = True
- break
-
- if tokenization_failed_for_step:
- logger.warning(
- f"[Collect Trajectory Seed: {seed}] Episode aborted at turn {turn+1} due to tokenization failure."
- )
- if seed in self.episodes:
- try:
- self.episodes[seed].env.close()
- except Exception as e:
- logger.warning(
- f"[Collect Trajectory Seed: {seed}] Exception closing env for aborted episode: {e}"
- )
- del self.episodes[seed]
- return self._ensure_trajectory_token_limit(ep.trajectory)
-
- expected_len = self.config.group_size
- if len(alt_tokens) != expected_len:
- alt_tokens.extend([[]] * (expected_len - len(alt_tokens)))
- if len(alt_masks) != expected_len:
- alt_masks.extend([[]] * (expected_len - len(alt_masks)))
- if len(alt_messages) != expected_len:
- alt_messages.extend(
- [
- [
- {
- "role": "system",
- "content": "Missing due to prior success but unexpected count",
- }
- ]
- ]
- * (expected_len - len(alt_messages))
- )
-
- env_action = 0 if best_action == -1 else best_action
- if best_action == -1:
- logger.warning(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Selected action was invalid format (-1). "
- f"Stepping env with 'stick' (0)."
- )
-
- try:
- obs, reward, term, trunc, info = ep.env.step(env_action)
- logger.info(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Stepped main env with action {env_action}. "
- f"Reward: {reward}, Term: {term}, Trunc: {trunc}"
- )
- except Exception as env_step_error:
- logger.exception(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Error stepping main environment with action {env_action}: {env_step_error}"
- )
- term = True
- reward = -1.0
- obs = None
-
- ep.actions.append(env_action)
- ep.step_rewards.append(reward)
-
- ep.total_env_reward += reward
-
- ep.num_total_actions += 1
- if best_action != -1:
- ep.num_correct_actions += 1
-
- logger.info(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
- f"Step Rewards: Env={reward:.2f}. "
- f"Running Totals: Env={ep.total_env_reward:.2f}."
- )
-
- ep.trajectory.append(
- BlackjackScoredDataGroup(
- overrides=[],
- seed=seed,
- tokens=alt_tokens,
- masks=alt_masks,
- scores=scores,
- messages=alt_messages,
- parsed_action=best_action,
- )
- )
-
- if term or trunc:
- logger.info(
- f"[Collect Trajectory Seed: {seed}] "
- f"Episode ended. Term={term}, Trunc={trunc}. "
- f"Final Reward: {reward}"
- )
- ep.message_history.append({"role": "agent", "content": best_response})
-
- if obs is not None:
- final_formatted_obs = self._format_observation(obs)
- logger.debug(
- f"[Collect Trajectory Seed: {seed}] "
- f"Final State: {final_formatted_obs} (Reward: {reward})"
- )
- else:
- logger.debug(
- f"[Collect Trajectory Seed: {seed}] "
- f"Episode terminated with error. (Reward: {reward})"
- )
-
- break
- else:
- response_for_history = self._truncate_thinking_for_history(
- best_response, self.config.max_think_chars_history
- )
- ep.message_history.append(
- {"role": "agent", "content": response_for_history}
- )
- formatted_obs = self._format_observation(obs)
- ep.message_history.append(
- {"role": "environment", "content": formatted_obs}
- )
- logger.debug(
- f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] New Observation: {formatted_obs}"
- )
-
- logger.info(
- f"[Collect Trajectory Seed: {seed}] "
- f"Finished episode after {len(ep.actions)} steps."
- )
- logger.info(
- f"[Collect Trajectory Seed: {seed}] "
- f"Final Totals: Env Reward={ep.total_env_reward:.2f}."
- )
- logger.info(
- f"[Collect Trajectory Seed: {seed}] "
- f"Action Accuracy: {ep.num_correct_actions}/{max(1, ep.num_total_actions)} "
- f"({ep.num_correct_actions/max(1, ep.num_total_actions):.2%})"
- )
-
- final_env_reward_for_outcome = 0
- if ep.step_rewards:
- final_env_reward_for_outcome = ep.step_rewards[-1]
- game_outcome = 0
- if final_env_reward_for_outcome > 0:
- game_outcome = 1
- elif final_env_reward_for_outcome < 0:
- game_outcome = -1
-
- episode_summary_metrics = {
- "seed": seed,
- "total_reward": ep.total_env_reward,
- "num_correct_actions": ep.num_correct_actions,
- "num_total_actions": ep.num_total_actions,
- "game_outcome": game_outcome,
- "num_steps": len(ep.actions),
- }
- self.completed_episode_metrics_buffer.append(episode_summary_metrics)
-
- if seed in self.episodes:
- try:
- self.episodes[seed].env.close()
- except Exception as e:
- logger.warning(
- f"[Collect Trajectory Seed: {seed}] "
- f"Exception closing env for episode: {e}"
- )
- del self.episodes[seed]
- logger.debug(
- f"[Collect Trajectory Seed: {seed}] "
- f"Cleared episode state from self.episodes."
- )
-
- return self._ensure_trajectory_token_limit(ep.trajectory)
-
- async def collect_trajectories(
- self, item: Tuple[int, int]
- ) -> Tuple[List[BlackjackScoredDataGroup], List[Tuple[int, int]]]:
- seed, _ = item
- traj = await self.collect_trajectory(seed)
- if traj:
- traj = self._ensure_trajectory_token_limit(traj)
-
- if not traj:
- logger.warning(
- f"[collect_trajectories] "
- f"All steps for seed {seed} were filtered out due to token limit "
- f"constraints. Returning empty trajectory."
- )
-
- return traj, []
-
- async def score(
- self,
- rollout_group_data: List[BlackjackScoredDataGroup],
- ) -> List[Optional[BlackjackScoredDataGroup]]:
- """
- Applies final scoring adjustments to a completed trajectory.
- If the game was a win (determined by replaying chosen actions), a bonus
- is applied to the best alternative at each step.
- Tie-breaking based on token length is applied subsequently.
-
- Args:
- rollout_group_data: The list of ScoredDataGroups representing the trajectory.
-
- Returns:
- The list of ScoredDataGroups with potentially adjusted scores.
- Returns a list containing None elements if input steps are invalid.
- """
- logger.info(
- f"score: Received rollout_group_data with {len(rollout_group_data)} "
- f"groups for scoring."
- )
-
- if not rollout_group_data:
- logger.warning(
- "score: Received empty rollout_group_data. Returning empty list."
- )
- return []
-
- if not all(
- isinstance(rgd, dict) for rgd in rollout_group_data if rgd is not None
- ):
- logger.error(
- "score: rollout_group_data contains non-dictionary elements. "
- "Cannot proceed."
- )
- return [None] * len(rollout_group_data)
-
- final_env_reward_for_outcome = 0.0
- is_win = False
- seed_for_outcome = None
- first_valid_step_for_seed = next(
- (rgd for rgd in rollout_group_data if rgd is not None and "seed" in rgd),
- None,
- )
-
- if not first_valid_step_for_seed:
- logger.warning(
- "score: Cannot determine game outcome, no valid step with seed found "
- "in rollout_group_data."
- )
- else:
- seed_for_outcome = first_valid_step_for_seed["seed"]
- logger.info(
- f"score [Seed: {seed_for_outcome}]: Starting game outcome replay."
- )
- try:
- temp_env_outcome = gymnasium.make(self.config.env_name)
- temp_obs_outcome, _ = temp_env_outcome.reset(seed=seed_for_outcome)
-
- for step_idx_outcome, step_group_for_outcome in enumerate(
- rollout_group_data
- ):
- if step_group_for_outcome is None:
- logger.warning(
- f"score [Seed: {seed_for_outcome}]: "
- f"Encountered None step_group at index {step_idx_outcome} "
- f"during outcome replay. Assuming non-win."
- )
- final_env_reward_for_outcome = 0.0
- break
-
- action_for_outcome_step = step_group_for_outcome.get(
- "parsed_action"
- )
-
- if action_for_outcome_step is None or action_for_outcome_step == -1:
- logger.warning(
- f"score [Seed: {seed_for_outcome}]: "
- f"Invalid action ({action_for_outcome_step}) found at step {step_idx_outcome} "
- f"during game outcome replay. Assuming non-win outcome for scoring."
- )
- final_env_reward_for_outcome = 0.0
- break
-
- logger.debug(
- f"score [Seed: {seed_for_outcome} Replay]: "
- f"Step {step_idx_outcome}, Action: {action_for_outcome_step}"
- )
- (
- temp_obs_outcome,
- step_reward_outcome,
- term_outcome,
- trunc_outcome,
- _,
- ) = temp_env_outcome.step(action_for_outcome_step)
- final_env_reward_for_outcome = step_reward_outcome
-
- if term_outcome or trunc_outcome:
- logger.info(
- f"score [Seed: {seed_for_outcome}]: "
- f"Game outcome replay ended at step {step_idx_outcome} "
- f"(action: {action_for_outcome_step}). "
- f"Final env reward for outcome: {final_env_reward_for_outcome}"
- )
- break
- else:
- logger.info(
- f"score [Seed: {seed_for_outcome}]: "
- f"Game outcome replay completed all steps. "
- f"Final env reward for outcome: {final_env_reward_for_outcome}"
- )
-
- temp_env_outcome.close()
- except Exception as e:
- logger.exception(
- f"score [Seed: {seed_for_outcome}]: "
- f"Error during game outcome replay: {e}. "
- f"Assuming non-win."
- )
- final_env_reward_for_outcome = 0.0
-
- if final_env_reward_for_outcome > 0:
- is_win = True
- logger.info(
- f"score [Seed: {seed_for_outcome}]: "
- f"Game outcome determined as WIN "
- f"(Final Env Reward: {final_env_reward_for_outcome}). "
- f"Win bonus (+1.0) will be applied to best alternative at each step."
- )
- else:
- logger.info(
- f"score [Seed: {seed_for_outcome}]: "
- f"Game outcome determined as NON-WIN "
- f"(Final Env Reward: {final_env_reward_for_outcome}). "
- f"No win bonus from game outcome will be applied."
- )
-
- processed_rollout_data: List[Optional[BlackjackScoredDataGroup]] = []
-
- for step_idx, original_step_group_untyped in enumerate(rollout_group_data):
- if original_step_group_untyped is None:
- logger.warning(f"score: Skipping None step_group at index {step_idx}.")
- processed_rollout_data.append(None)
- continue
-
- current_step_group: BlackjackScoredDataGroup = (
- original_step_group_untyped.copy()
- )
- step_seed = current_step_group.get("seed", "N/A")
-
- if current_step_group.get("scores") is None:
- logger.warning(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Scores are missing. Cannot apply win bonus or tie-breaking."
- )
- processed_rollout_data.append(current_step_group)
- continue
-
- original_scores = current_step_group["scores"]
- if not isinstance(original_scores, list) or not all(
- isinstance(s, (int, float)) for s in original_scores
- ):
- logger.warning(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"'scores' is not a list of numbers. "
- f"Skipping scoring for this step. Scores: {original_scores}"
- )
- processed_rollout_data.append(current_step_group)
- continue
-
- modified_scores = original_scores.copy()
-
- if is_win and modified_scores:
- try:
- max_score_in_step = -float("inf")
- valid_scores_for_max = [
- s for s in modified_scores if isinstance(s, (int, float))
- ]
- if not valid_scores_for_max:
- logger.warning(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"No valid numeric scores found to determine best alternative for win bonus."
- )
- else:
- max_score_in_step = max(valid_scores_for_max)
- best_alternative_idx_this_step = -1
- for idx, score_val in enumerate(modified_scores):
- if score_val == max_score_in_step:
- best_alternative_idx_this_step = idx
- break
-
- if best_alternative_idx_this_step != -1:
- win_bonus_amount = 1.0
- modified_scores[
- best_alternative_idx_this_step
- ] += win_bonus_amount
- logger.info(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Applied WIN bonus ({win_bonus_amount}) to alternative "
- f"{best_alternative_idx_this_step} "
- f"(Original score: {original_scores[best_alternative_idx_this_step]:.4f}, "
- f"New: {modified_scores[best_alternative_idx_this_step]:.4f})."
- )
- else:
- logger.warning(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Could not find index of max score {max_score_in_step} for win bonus. "
- f"This should not happen if scores exist."
- )
- except ValueError:
- score_msg = (
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Error finding max score for win bonus."
- )
- logger.warning(score_msg)
- logger.debug(f"Problematic scores: {modified_scores}")
- except Exception as e_bonus:
- logger.exception(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Unexpected error applying win bonus: {e_bonus}"
- )
-
- step_messages = current_step_group.get("messages")
- if not isinstance(step_messages, list) or len(modified_scores) != len(
- step_messages
- ):
- logger.warning(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Mismatch between scores ({len(modified_scores)}) and messages "
- f"({len(step_messages) if isinstance(step_messages, list) else 'not a list'}) "
- f"lengths, or messages missing. Skipping tie-breaking for this step."
- )
- elif modified_scores:
- token_lengths = []
- valid_messages_for_tiebreak = True
- for alt_msg_list_idx, alt_msg_list in enumerate(step_messages):
- if (
- not isinstance(alt_msg_list, list)
- or not alt_msg_list
- or not isinstance(alt_msg_list[-1], dict)
- or "content" not in alt_msg_list[-1]
- ):
- logger.warning(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Invalid message structure for alternative {alt_msg_list_idx} "
- f"during tie-breaking. Skipping tie-breaking for this step."
- )
- valid_messages_for_tiebreak = False
- break
- response_text = alt_msg_list[-1]["content"]
- try:
- token_lengths.append(len(self.tokenizer.encode(response_text)))
- except Exception as e_tok:
- logger.error(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Tokenization error for tie-breaking on alt {alt_msg_list_idx}: {e_tok}. "
- f"Defaulting token length to large value."
- )
- token_lengths.append(float("inf"))
-
- if valid_messages_for_tiebreak:
- score_groups = {}
- for idx, score_val in enumerate(modified_scores):
- if not isinstance(score_val, (int, float)):
- continue
- if score_val not in score_groups:
- score_groups[score_val] = []
- score_groups[score_val].append(idx)
-
- scores_after_tiebreak = modified_scores.copy()
- for score_val, indices_with_this_score in score_groups.items():
- if len(indices_with_this_score) > 1:
- if not all(
- idx < len(token_lengths)
- for idx in indices_with_this_score
- ):
- logger.warning(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Token length data incomplete for tied score {score_val}. "
- f"Indices: {indices_with_this_score}, "
- f"Token lengths count: {len(token_lengths)}. "
- f"Skipping tie-break for this group."
- )
- continue
-
- try:
- sorted_tied_indices = sorted(
- indices_with_this_score,
- key=lambda i: token_lengths[i],
- )
-
- for rank, tied_idx in enumerate(
- sorted_tied_indices[1:], 1
- ):
- penalty = 0.0001 * rank
- scores_after_tiebreak[tied_idx] -= penalty
- logger.debug(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Applied tie-break penalty {-penalty:.5f} to alternative index {tied_idx} "
- f"(original tied score {score_val:.4f}, token length rank {rank})."
- )
- except IndexError:
- logger.warning(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"IndexError during tie-breaking for score {score_val}. "
- f"Indices: {indices_with_this_score}. "
- f"Skipping tie-break for this group."
- )
- except Exception as e_tiebreak:
- logger.exception(
- f"score [Seed: {step_seed}, Step: {step_idx}]: "
- f"Unexpected error during tie-breaking for score {score_val}: {e_tiebreak}"
- )
- modified_scores = scores_after_tiebreak
-
- current_step_group["scores"] = modified_scores
- processed_rollout_data.append(current_step_group)
-
- logger.info(
- f"score: Finished scoring. Processed {len(processed_rollout_data)} step groups."
- )
- return processed_rollout_data
-
- async def setup(self):
- pass
-
- async def get_next_item(self) -> Tuple[int, int]:
- import random
-
- return (random.randint(0, 1000000), 0)
-
- async def rollout_and_score_eval(self, seed: int) -> Dict[str, Any]:
- """
- Run a single episode for evaluation and return detailed metrics.
- Does not use the best-of-n sampling, but a single completion per step.
- Cleans up the episode state after completion.
- """
- ep = self._get_or_create_episode(seed)
- max_turns = self.config.max_turns if self.config.max_turns is not None else 5
- logger.info(
- f"[Eval Rollout Seed: {seed}] Starting episode. Max turns: {max_turns}"
- )
-
- episode_metrics = {
- "seed": seed,
- "total_reward": 0.0,
- "num_steps": 0,
- "num_correct_actions": 0,
- "num_invalid_actions": 0,
- "actions_chosen": [],
- "game_outcome": 0,
- }
-
- for turn in range(max_turns):
- episode_metrics["num_steps"] = turn + 1
- messages_for_prompt = ep.message_history.copy()
-
- if self.config.thinking_active:
-                messages_for_prompt.append({"role": "agent", "content": "<think>\n"})
- else:
- messages_for_prompt.append({"role": "agent", "content": ""})
-
- prompt = self.tokenizer.apply_chat_template(
- messages_for_prompt, tokenize=False
- )
-
- try:
- completions = await self.server.completion(
- prompt=prompt,
- n=1,
- max_tokens=self.config.max_token_length,
- temperature=self.config.temperature,
- top_p=self.config.top_p,
- split="eval",
- )
- except Exception as api_error:
- logger.exception(
- f"[Eval Rollout Seed: {seed} Turn: {turn+1}] API Error: {api_error}"
- )
- break
-
- if not completions or not completions.choices:
- logger.error(
- f"[Eval Rollout Seed: {seed} Turn: {turn+1}] API did not return any choices. Aborting episode."
- )
- break
-
- response_text = (
- completions.choices[0].text
- if hasattr(completions.choices[0], "text")
- else getattr(completions.choices[0].message, "content", "")
- )
- full_response = (
-                ("<think>\n" + response_text)
- if self.config.thinking_active
- else response_text
- )
-
- parsed_action = self._parse_tool_call(full_response)
- episode_metrics["actions_chosen"].append(parsed_action)
-
- if parsed_action == -1:
- episode_metrics["num_invalid_actions"] += 1
- env_action = 0
- logger.warning(
- f"[Eval Rollout Seed: {seed} Turn: {turn+1}] Invalid action parsed. Defaulting to 'stick'."
- )
- else:
- episode_metrics["num_correct_actions"] += 1
- env_action = parsed_action
-
- try:
- obs, reward, term, trunc, info = ep.env.step(env_action)
- except Exception as env_step_error:
- logger.exception(
- f"[Eval Rollout Seed: {seed} Turn: {turn+1}] Error stepping env: {env_step_error}"
- )
- term = True
- reward = -1.0
- obs = None
-
- episode_metrics["total_reward"] += reward
-
- if term or trunc:
- episode_metrics["game_outcome"] = int(reward)
- logger.info(
- f"[Eval Rollout Seed: {seed}] Episode ended. Outcome Reward: {reward}"
- )
-
- ep.message_history.append({"role": "agent", "content": full_response})
-
- if obs is not None:
- final_formatted_obs = self._format_observation(obs)
- logger.debug(
- f"[Eval Rollout Seed: {seed}] "
- f"Final State: {final_formatted_obs} (Reward: {reward})"
- )
- else:
- logger.debug(
- f"[Eval Rollout Seed: {seed}] "
- f"Episode terminated with error. (Reward: {reward})"
- )
-
- break
- else:
- ep.message_history.append({"role": "agent", "content": full_response})
- formatted_obs = self._format_observation(obs)
- ep.message_history.append(
- {"role": "environment", "content": formatted_obs}
- )
-
- logger.info(
- f"[Eval Rollout Seed: {seed}] Finished episode. Metrics: {episode_metrics}"
- )
-
- if seed in self.episodes:
- try:
- self.episodes[seed].env.close()
- except Exception as e:
- logger.warning(
- f"[Eval Rollout Seed: {seed}] Exception closing env for episode: {e}"
- )
- del self.episodes[seed]
-
- return episode_metrics
-
- async def evaluate(self, *args, **kwargs):
- """Run evaluation episodes and aggregate metrics for logging."""
- if not self.config.use_wandb:
- logger.info("Skipping evaluation as wandb is not enabled.")
- return
-
- num_eval_episodes = self.config.eval_episodes
- logger.info(f"Starting evaluation for {num_eval_episodes} episodes.")
-
- eval_tasks = []
-
- for i in range(num_eval_episodes):
- eval_seed = random.randint(1000001, 2000000)
- eval_tasks.append(self.rollout_and_score_eval(eval_seed))
-
- all_episode_metrics = await tqdm_asyncio.gather(*eval_tasks)
-
- if not all_episode_metrics:
- logger.warning("No metrics collected from evaluation episodes.")
- return
-
- valid_metrics = [m for m in all_episode_metrics if m is not None]
- if not valid_metrics:
- logger.warning("All evaluation episodes resulted in None metrics.")
- return
-
- num_completed_episodes = len(valid_metrics)
-
- avg_total_env_reward = (
- sum(m["total_reward"] for m in valid_metrics) / num_completed_episodes
- )
- avg_num_turns = (
- sum(m["num_steps"] for m in valid_metrics) / num_completed_episodes
- )
-
- total_correct_actions = sum(m["num_correct_actions"] for m in valid_metrics)
- total_invalid_actions = sum(m["num_invalid_actions"] for m in valid_metrics)
- total_actions_taken = total_correct_actions + total_invalid_actions
- action_accuracy = (
- total_correct_actions / total_actions_taken
- if total_actions_taken > 0
- else 0
- )
- invalid_action_rate = (
- total_invalid_actions / total_actions_taken
- if total_actions_taken > 0
- else 0
- )
-
- wins = sum(1 for m in valid_metrics if m["game_outcome"] == 1)
- losses = sum(1 for m in valid_metrics if m["game_outcome"] == -1)
- draws = sum(1 for m in valid_metrics if m["game_outcome"] == 0)
-
- win_rate = wins / num_completed_episodes if num_completed_episodes > 0 else 0
- loss_rate = losses / num_completed_episodes if num_completed_episodes > 0 else 0
- draw_rate = draws / num_completed_episodes if num_completed_episodes > 0 else 0
-
- all_chosen_actions = [
- action for m in valid_metrics for action in m["actions_chosen"]
- ]
- count_hit = sum(1 for act in all_chosen_actions if act == 1)
- count_stick = sum(1 for act in all_chosen_actions if act == 0)
- count_error_actions = sum(1 for act in all_chosen_actions if act == -1)
- total_parsed_actions_in_eval = len(all_chosen_actions)
-
- self.eval_metrics = [
- ("eval/avg_total_reward", avg_total_env_reward),
- ("eval/avg_num_steps", avg_num_turns),
- ("eval/action_accuracy", action_accuracy),
- ("eval/invalid_action_rate", invalid_action_rate),
- ("eval/win_rate", win_rate),
- ("eval/loss_rate", loss_rate),
- ("eval/draw_rate", draw_rate),
- ("eval/num_wins", wins),
- ("eval/num_losses", losses),
- ("eval/num_draws", draws),
- ("eval/num_completed_episodes", num_completed_episodes),
- (
- "eval/hit_chosen_rate",
- (
- count_hit / total_parsed_actions_in_eval
- if total_parsed_actions_in_eval > 0
- else 0
- ),
- ),
- (
- "eval/stick_chosen_rate",
- (
- count_stick / total_parsed_actions_in_eval
- if total_parsed_actions_in_eval > 0
- else 0
- ),
- ),
- (
- "eval/error_action_chosen_rate",
- (
- count_error_actions / total_parsed_actions_in_eval
- if total_parsed_actions_in_eval > 0
- else 0
- ),
- ),
- ]
-
- logger.info(f"Evaluation completed. Aggregated metrics: {self.eval_metrics}")
-
- async def wandb_log(self, wandb_metrics: Optional[Dict[str, Any]] = None):
- """
- Log aggregated metrics from completed training episodes and call super().wandb_log.
- """
- if wandb_metrics is None:
- wandb_metrics = {}
-
- if self.completed_episode_metrics_buffer:
- num_episodes_in_buffer = len(self.completed_episode_metrics_buffer)
-
- avg_ep_env_reward = (
- sum(m["total_reward"] for m in self.completed_episode_metrics_buffer)
- / num_episodes_in_buffer
- )
-
- total_ep_correct_actions = sum(
- m["num_correct_actions"] for m in self.completed_episode_metrics_buffer
- )
- total_ep_actions = sum(
- m["num_total_actions"] for m in self.completed_episode_metrics_buffer
- )
- avg_ep_action_accuracy = (
- total_ep_correct_actions / total_ep_actions
- if total_ep_actions > 0
- else 0
- )
-
- avg_ep_num_steps = (
- sum(m["num_steps"] for m in self.completed_episode_metrics_buffer)
- / num_episodes_in_buffer
- )
-
- ep_wins = sum(
- 1
- for m in self.completed_episode_metrics_buffer
- if m["game_outcome"] == 1
- )
- ep_losses = sum(
- 1
- for m in self.completed_episode_metrics_buffer
- if m["game_outcome"] == -1
- )
- ep_draws = sum(
- 1
- for m in self.completed_episode_metrics_buffer
- if m["game_outcome"] == 0
- )
-
- ep_win_rate = (
- ep_wins / num_episodes_in_buffer if num_episodes_in_buffer > 0 else 0
- )
- ep_loss_rate = (
- ep_losses / num_episodes_in_buffer if num_episodes_in_buffer > 0 else 0
- )
- ep_draw_rate = (
- ep_draws / num_episodes_in_buffer if num_episodes_in_buffer > 0 else 0
- )
-
- wandb_metrics[
- f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_reward"
- ] = avg_ep_env_reward
- wandb_metrics[
- f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_action_accuracy"
- ] = avg_ep_action_accuracy
- wandb_metrics[
- f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_num_steps"
- ] = avg_ep_num_steps
- wandb_metrics[
- f"{self.wandb_prepend or 'blackjack'}_train/episode_win_rate"
- ] = ep_win_rate
- wandb_metrics[
- f"{self.wandb_prepend or 'blackjack'}_train/episode_loss_rate"
- ] = ep_loss_rate
- wandb_metrics[
- f"{self.wandb_prepend or 'blackjack'}_train/episode_draw_rate"
- ] = ep_draw_rate
- wandb_metrics[
- f"{self.wandb_prepend or 'blackjack'}_train/num_episodes_in_log_period"
- ] = num_episodes_in_buffer
-
- logger.info(
- f"Logging metrics for {num_episodes_in_buffer} completed training episodes."
- )
- self.completed_episode_metrics_buffer = []
- await super().wandb_log(wandb_metrics)
-
- @classmethod
- def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
- env_config = BlackjackEnvConfig(
- tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
- group_size=16,
- use_wandb=True,
- max_num_workers=128,
- rollout_server_url="http://localhost:8000",
- total_steps=2000,
- batch_size=1024,
- steps_per_eval=20,
- max_token_length=1024 * 16,
- inference_weight=1.0,
- wandb_name="fundamental_metric_prediction",
- data_path_to_save_groups=None,
- eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
- eval_limit_ratio=0.1,
- env_name="Blackjack-v1",
- temperature=0.7,
- top_p=0.9,
- max_turns=5,
- thinking_active=True,
- eval_episodes=100,
- max_think_chars_history=3000,
- max_trajectory_tokens=24576,
- debug_mode=False,
- )
- server_configs = [
- OpenaiConfig(
- model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
- base_url="http://localhost:9004/v1",
- api_key="x",
- num_requests_for_eval=256,
- )
- ]
- return env_config, server_configs
-
- @classmethod
- def cli(cls):
- super(BlackjackEnv, cls).cli()
-
- def _truncate_thinking_for_history(
- self, response_text: str, max_chars_fallback: int
- ) -> str:
-        """Helper to truncate the <think> block of a response for message history."""
- try:
-            think_start_tag = "<think>"
-            think_end_tag = "</think>"
-
- think_start_idx = response_text.find(think_start_tag)
- think_end_idx = response_text.find(think_end_tag)
-
- if (
- think_start_idx != -1
- and think_end_idx != -1
- and think_start_idx < think_end_idx
- ):
- part_before_content = response_text[
- : think_start_idx + len(think_start_tag)
- ]
- original_think_content = response_text[
- think_start_idx + len(think_start_tag) : think_end_idx
- ].strip()
- part_after_content = response_text[think_end_idx:]
-
- truncated_think_content = original_think_content
- is_truncated = False
-
- if not original_think_content:
- return response_text
-
- paragraphs = [
- p.strip() for p in original_think_content.split("\n\n") if p.strip()
- ]
- if len(paragraphs) > 0:
- last_paragraph = paragraphs[-1]
- if len(last_paragraph) < len(original_think_content):
- truncated_think_content = last_paragraph
- is_truncated = True
- elif len(original_think_content) > max_chars_fallback:
- truncated_think_content = original_think_content[
- -max_chars_fallback:
- ]
- is_truncated = True
- elif len(original_think_content) > max_chars_fallback:
- truncated_think_content = original_think_content[
- -max_chars_fallback:
- ]
- is_truncated = True
-
- if is_truncated and truncated_think_content:
- if not truncated_think_content.startswith("... "):
- truncated_think_content = (
- "... " + truncated_think_content.lstrip()
- )
-
- if (
- not truncated_think_content.strip()
- or truncated_think_content.strip() == "..."
- ):
- final_content_for_block = ""
- else:
- final_content_for_block = f"\n{truncated_think_content.strip()}\n"
-
- return f"{part_before_content.rstrip()}{final_content_for_block}{part_after_content.lstrip()}"
-
- return response_text
- except Exception as e:
- logger.error(
- f"Error in _truncate_thinking_for_history for text '{response_text[:200]}...': {e}",
- exc_info=True,
- )
- return response_text
-
- def _ensure_trajectory_token_limit(
- self, trajectory: List[BlackjackScoredDataGroup]
- ) -> List[BlackjackScoredDataGroup]:
- """
- Ensure token sequences in a trajectory don't exceed max_trajectory_tokens.
- Attempts to uniformly truncate older messages (preferably paired turns) from all alternatives within a step.
- The system prompt, last environment observation, and last agent response are preserved as a minimum.
- If a step still exceeds the limit after maximum possible truncation, it is discarded.
-
- Args:
- trajectory: List of BlackjackScoredDataGroup from an episode
-
- Returns:
- The trajectory with potentially truncated messages/tokens/masks or filtered steps
- """
- if not trajectory:
- return trajectory
-
- filtered_trajectory: List[BlackjackScoredDataGroup] = []
-
- for step_idx, original_step_data in enumerate(trajectory):
- if not (
- original_step_data.get("messages")
- and original_step_data.get("tokens")
- and original_step_data.get("masks")
- and original_step_data.get("seed") is not None
- ):
- logger.warning(
- f"[_ensure_trajectory_token_limit] Step {step_idx} "
- f"is missing critical data (messages, tokens, masks, or seed). Skipping."
- )
- continue
-
- max_initial_tokens = 0
- if original_step_data["tokens"]:
- max_initial_tokens = (
- max(
- len(alt_tokens)
- for alt_tokens in original_step_data["tokens"]
- if isinstance(alt_tokens, list)
- )
- if any(
- isinstance(alt_tokens, list)
- for alt_tokens in original_step_data["tokens"]
- )
- else 0
- )
-
- if max_initial_tokens <= self.config.max_trajectory_tokens:
- filtered_trajectory.append(original_step_data)
- logger.info(
- f"[_ensure_trajectory_token_limit] Step {step_idx} compliant. "
- f"Max tokens: {max_initial_tokens}"
- )
- continue
-
- logger.info(
- f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) "
- f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
- )
-
- working_messages = [
- msgs_list.copy() for msgs_list in original_step_data["messages"] or []
- ]
- working_tokens = [
- tkns_list.copy() for tkns_list in original_step_data["tokens"] or []
- ]
- working_masks = [
- msks_list.copy() for msks_list in original_step_data["masks"] or []
- ]
- max_current_tokens = max_initial_tokens
- num_alternatives = len(working_messages)
-
- if num_alternatives == 0:
- logger.warning(
- f"[_ensure_trajectory_token_limit] Step {step_idx} has no alternatives after copying. Skipping."
- )
- continue
-
- retokenization_error_this_step = False
- while max_current_tokens > self.config.max_trajectory_tokens:
- target_pop_counts_per_alt = []
- for alt_idx in range(num_alternatives):
- alt_msg_list = working_messages[alt_idx]
-
- num_preserved_at_end = 0
- if (
- len(alt_msg_list) > 1
- and alt_msg_list[-1]["role"] in UNMASKED_ROLES
- ):
- num_preserved_at_end = 1
- if (
- len(alt_msg_list) > 2
- and alt_msg_list[-2]["role"] == "environment"
- ):
- num_preserved_at_end = 2
-
- available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
-
- if available_to_pop <= 0:
- target_pop_counts_per_alt.append(0)
- else:
- can_pop_pair = (
- available_to_pop >= 2
- and len(alt_msg_list) > 2
- and alt_msg_list[1]["role"] == "environment"
- and alt_msg_list[2]["role"] in UNMASKED_ROLES
- )
- if can_pop_pair:
- target_pop_counts_per_alt.append(2)
- else:
- target_pop_counts_per_alt.append(1)
-
- positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
- if not positive_pop_counts:
- break
-
- min_pop_this_round = min(positive_pop_counts)
-
- temp_new_alt_tokens = []
- temp_new_alt_masks = []
- max_tokens_after_this_trunc = 0
-
- for alt_idx in range(num_alternatives):
- for _ in range(min_pop_this_round):
- if len(working_messages[alt_idx]) > 1:
- working_messages[alt_idx].pop(1)
- else:
- logger.error(
- f"[_ensure_trajectory_token_limit] Critical error during pop for "
- f"alt {alt_idx}, step {step_idx}. List too short."
- )
- retokenization_error_this_step = True
- break
- if retokenization_error_this_step:
- break
-
- try:
- tokenized_alt = tokenize_for_trainer_multistep(
- self.tokenizer, working_messages[alt_idx]
- )
- temp_new_alt_tokens.append(tokenized_alt["tokens"])
- temp_new_alt_masks.append(tokenized_alt["masks"])
- max_tokens_after_this_trunc = max(
- max_tokens_after_this_trunc, len(tokenized_alt["tokens"])
- )
- except Exception as e:
- logger.error(
- f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} "
- f"in step {step_idx} after truncation: {e}"
- )
- retokenization_error_this_step = True
- break
-
- if retokenization_error_this_step:
- break
-
- working_tokens = temp_new_alt_tokens
- working_masks = temp_new_alt_masks
- max_current_tokens = max_tokens_after_this_trunc
- logger.debug(
- f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_this_round}, "
- f"max tokens: {max_current_tokens}"
- )
-
- if (
- not retokenization_error_this_step
- and max_current_tokens <= self.config.max_trajectory_tokens
- ):
- updated_step_data: BlackjackScoredDataGroup = {
- "seed": original_step_data["seed"],
- "messages": working_messages,
- "tokens": working_tokens,
- "masks": working_masks,
- "scores": original_step_data.get("scores"),
- "parsed_action": original_step_data.get("parsed_action"),
- }
- filtered_trajectory.append(updated_step_data)
- logger.info(
- f"[_ensure_trajectory_token_limit] Step {step_idx} successfully processed. "
- f"Final max tokens: {max_current_tokens}"
- )
- else:
- logger.warning(
- f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
- f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) "
- f"or retokenization error occurred ({retokenization_error_this_step})."
- )
-
- if len(filtered_trajectory) < len(trajectory):
- logger.warning(
- f"[_ensure_trajectory_token_limit] Filtered out "
- f"{len(trajectory) - len(filtered_trajectory)} steps "
- f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}"
- )
- return filtered_trajectory
-
-
-if __name__ == "__main__":
- BlackjackEnv.cli()