atropos/environments/game_environments/gymnasium/blackjack_no_mc_env.py

"""
BlackjackEnv: Trainer environment for Gymnasium Blackjack
This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
using function-call style actions. Extends BaseEnv.
Alternative formulation of BlackjackEnv that uses a best-of-n approach to select actions
and no Monte Carlo sampling (direct bonus for winning trajectory). Much faster to train,
but may not be as effective at learning correct strategy (it's effectively a series of bandits).
"""
import json
import logging
import random
from typing import Any, Dict, List, Optional, Tuple
import gymnasium
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
OpenaiConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import Message
from atroposlib.utils.tokenize_for_trainer import (
UNMASKED_ROLES,
tokenize_for_trainer_multistep,
)
from atroposlib.utils.tool_call_parser import parse_tool_call
logger = logging.getLogger(__name__)
class BlackjackEnvConfig(BaseEnvConfig):
"""
Configuration for the Blackjack environment trainer.
"""
env_name: str = "Blackjack-v1"
temperature: float = 0.7
top_p: float = 0.9
max_turns: Optional[int] = 5
wandb_name: str = "blackjack"
thinking_active: bool = True
eval_episodes: int = 100
batch_size: int = 1024
max_think_chars_history: int = 3000
max_trajectory_tokens: int = 24576
debug_mode: bool = False
class BlackjackScoredDataGroup(ScoredDataGroup):
"""
Represents the scored data for a single step in a Blackjack trajectory, potentially including multiple alternatives.
"""
seed: int
tokens: Optional[List[List[int]]] = None
masks: Optional[List[List[int]]] = None
scores: Optional[List[float]] = None
messages: Optional[List[List[Message]]] = None
parsed_action: Optional[int] = None
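# Shape sketch (values illustrative): one group per time step, one entry per
# sampled alternative, e.g. for group_size=2:
#   BlackjackScoredDataGroup(seed=42, tokens=[[...], [...]], masks=[[...], [...]],
#                            scores=[1.0, -0.5], messages=[msgs_a, msgs_b],
#                            parsed_action=1)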
class EpisodeState:
"""
Stores per-episode state: gym env, history, actions, rewards, trajectory.
"""
def __init__(self, seed: int, env: gymnasium.Env):
self.seed: int = seed
self.env: gymnasium.Env = env
self.message_history: List[Message] = []
self.actions: List[int] = []
self.step_rewards: List[float] = []
self.trajectory: List[BlackjackScoredDataGroup] = []
self.total_env_reward: float = 0.0
self.num_correct_actions: int = 0
self.num_total_actions: int = 0
class BlackjackEnv(BaseEnv):
"""
Trainer environment for Gymnasium Blackjack using a best-of-n approach with function-call style actions.
"""
def __init__(
self,
config: BlackjackEnvConfig,
server_configs: List[OpenaiConfig],
slurm: bool = True,
testing: bool = False,
):
super().__init__(config, server_configs, slurm, testing)
self.episodes: Dict[int, EpisodeState] = {}
self.debug_mode = config.debug_mode
self.completed_episode_metrics_buffer: List[Dict[str, Any]] = []
if self.debug_mode:
logger.setLevel(logging.DEBUG)
else:
if logger.level == logging.NOTSET or logger.level > logging.WARNING:
logger.setLevel(logging.WARNING)
self.tools = [
{
"type": "function",
"function": {
"name": "take_action",
"description": "Choose to 'hit' or 'stick' in Blackjack.",
"parameters": {
"action": {"type": "string", "enum": ["hit", "stick"]}
},
},
}
]
tools_json = json.dumps(self.tools)
        self.system_prompt = (
            "You are an AI agent playing Blackjack who uses extremely long chains of thought "
            "to carefully consider the probabilities and optimal strategy. "
            "You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
            "You should enclose your thoughts and internal monologue inside <think> </think> tags, and then "
            "provide your decision using the take_action function call.\n\n"
f"<tools>\n{tools_json}\n</tools>\n\n"
"For your function call, return a JSON object with function name and arguments "
"within <tool_call> </tool_call> tags with the following schema:\n"
'<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
"Your answer format should be:\n"
"<think>\n"
"[Your detailed reasoning process about whether to hit or stick]\n"
"</think>\n\n"
'<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n\n'
"Remember to carefully consider the probabilities and optimal strategy for Blackjack."
)
def _get_or_create_episode(self, seed: int) -> EpisodeState:
"""Retrieve existing or create a new episode state keyed by seed."""
if seed not in self.episodes:
env = gymnasium.make(self.config.env_name)
obs, _ = env.reset(seed=seed)
ep = EpisodeState(seed, env)
ep.message_history = [{"role": "system", "content": self.system_prompt}]
formatted = self._format_observation(obs)
ep.message_history.append({"role": "environment", "content": formatted})
self.episodes[seed] = ep
return self.episodes[seed]
def _format_observation(self, obs: Tuple[int, int, int]) -> str:
"""Convert Blackjack observation to text for LLM."""
player_sum, dealer_card, usable_ace = obs
return (
f"Your hand sum is {player_sum}. "
f"Dealer showing: {dealer_card}. "
f"You have a usable ace: {usable_ace}."
)
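    # e.g. obs (14, 10, 0) ->
    #   "Your hand sum is 14. Dealer showing: 10. You have a usable ace: 0."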
def _parse_tool_call(self, response: str) -> int:
"""Extract 'hit'/'stick' and map to action 1/0."""
tool_name, arguments, is_error = parse_tool_call(
response, self.tools, ["tool_call"]
)
        # Routine parse results are logged at debug level; failures below use warning.
        logger.debug(
            f"Parsed tool call: name={tool_name}, args={arguments}, error={is_error}"
        )
if is_error:
logger.warning(f"Failed to parse tool call from response: {response}")
return -1
action = arguments.get("action", "").lower()
if action == "hit":
return 1
elif action == "stick":
return 0
else:
logger.warning(f"Invalid action value: {action}")
return -1
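    # Expected response shape (illustrative):
    #   <think> ... reasoning ... </think>
    #   <tool_call>{"arguments": {"action": "hit"}, "name": "take_action"}</tool_call>
    # "hit" -> 1, "stick" -> 0, anything unparseable -> -1.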
def _score_response(
self,
env_reward: float,
response_text: str,
parsed_action: int,
episode_seed: int,
) -> float:
"""
Calculates a score for a single agent response based purely on environment reward
and a penalty for invalid action format.
"""
current_env_reward = env_reward
if parsed_action == -1:
current_env_reward -= 0.5
logger.debug(
f"[_score_response Seed: {episode_seed}] Penalty applied for invalid action format (-0.5)."
)
final_score = current_env_reward
logger.debug(
f"[_score_response Seed: {episode_seed}] Final Score Calculation: "
f"Env Reward (raw): {env_reward:.4f}, "
f"Env Reward (adjusted for invalid): {current_env_reward:.4f}, "
f"==> Final Score (from env): {final_score:.4f}"
)
return final_score
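    # Worked example: a simulated 'hit' that busts scores -1.0 (the raw env
    # reward); an unparseable response (parsed_action == -1, with env reward
    # treated as 0.0 by the caller) scores 0.0 - 0.5 = -0.5.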
async def _select_best_action(
self, episode: EpisodeState, actions: List[int], responses: List[str]
) -> Tuple[int, List[float]]:
"""
Simulates and scores multiple candidate actions to select the best one.
Args:
episode: The current episode state.
actions: A list of parsed actions (0, 1, or -1) corresponding to the responses.
responses: A list of full agent responses (<think>...</think><tool_call>...</tool_call>).
Returns:
A tuple containing:
- The best action selected (0, 1, or -1).
- A list of scores for each action/response.
"""
if len(actions) != len(responses):
logger.error(
f"[_select_best_action Seed: {episode.seed}] "
f"Mismatch between actions ({len(actions)}) and responses ({len(responses)}) count."
)
default_action = next((a for a in actions if a != -1), -1)
return default_action, [-10.0] * len(actions)
scores = [0.0] * len(actions)
token_lengths = [0] * len(actions)
try:
for idx, (action, response_text) in enumerate(zip(actions, responses)):
sim_env = gymnasium.make(self.config.env_name)
sim_obs, sim_info = sim_env.reset(seed=episode.seed)
valid_sim = True
for past_action in episode.actions:
sim_obs, _, term, trunc, sim_info = sim_env.step(past_action)
if term or trunc:
logger.warning(
f"[_select_best_action Seed: {episode.seed}] "
f"Episode terminated during history replay before simulating action {idx}. "
f"Assigning low score."
)
valid_sim = False
break
if not valid_sim:
scores[idx] = -10.0
continue
if action == -1:
env_reward_sim = 0.0
else:
_obs_sim, env_reward_sim, term_sim, trunc_sim, _info_sim = (
sim_env.step(action)
)
logger.debug(
f"[_select_best_action Seed: {episode.seed}] Sim Action {idx} "
f"(val:{action}) -> Reward:{env_reward_sim}, Term:{term_sim}"
)
combined_score = self._score_response(
env_reward=env_reward_sim,
response_text=response_text,
parsed_action=action,
episode_seed=episode.seed,
)
scores[idx] = combined_score
token_lengths[idx] = len(self.tokenizer.encode(response_text))
except Exception as e:
logger.exception(
f"[_select_best_action Seed: {episode.seed}] "
f"Error during action simulation/scoring: {e}"
)
default_action = next((a for a in actions if a != -1), -1)
return default_action, [-10.0] * len(actions)
best_score = float("-inf")
best_action = -1
best_action_idx = -1
if scores:
best_score = max(scores)
potential_best_indices = [
i for i, score in enumerate(scores) if score == best_score
]
valid_indices = [i for i in potential_best_indices if actions[i] != -1]
if valid_indices:
if len(valid_indices) > 1:
try:
best_action_idx = min(
valid_indices, key=lambda i: token_lengths[i]
)
logger.debug(
f"[_select_best_action Seed: {episode.seed}] "
f"Tie-breaking valid actions based on token length. Chosen index: {best_action_idx}"
)
except IndexError:
logger.warning(
f"[_select_best_action Seed: {episode.seed}] "
f"IndexError during token length tie-breaking. Defaulting to first valid index."
)
best_action_idx = valid_indices[0]
else:
best_action_idx = valid_indices[0]
elif potential_best_indices:
best_action_idx = potential_best_indices[0]
logger.debug(
f"[_select_best_action Seed: {episode.seed}] "
f"All best scores correspond to invalid actions. Choosing first: index {best_action_idx}"
)
else:
logger.error(
f"[_select_best_action Seed: {episode.seed}] "
f"No potential best indices found despite scores existing. Returning default action -1."
)
best_action_idx = -1
if best_action_idx != -1:
best_action = actions[best_action_idx]
else:
best_action = -1
logger.info(
f"[_select_best_action Seed: {episode.seed}] Selected action: {best_action} "
f"(Index: {best_action_idx}, "
f"Score: {scores[best_action_idx] if best_action_idx != -1 else 'N/A'}) "
f"from scores: {['{:.4f}'.format(s) for s in scores]}"
)
else:
logger.error(
f"[_select_best_action Seed: {episode.seed}] No scores calculated. Returning default action -1."
)
return best_action, scores
async def collect_trajectory(self, seed: int) -> List[BlackjackScoredDataGroup]:
"""
Run a single episode from the given seed, using a best-of-n approach each step.
Refactored to use _select_best_action.
Returns a list of BlackjackScoredDataGroup, one per time step.
"""
ep = self._get_or_create_episode(seed)
max_turns = self.config.max_turns if self.config.max_turns is not None else 5
logger.info(
f"[Collect Trajectory Seed: {seed}] Starting episode. Max turns: {max_turns}"
)
for turn in range(max_turns):
logger.debug(
f"[Collect Trajectory Seed: {seed}] Starting Turn {turn + 1}/{max_turns}"
)
messages_for_prompt = ep.message_history.copy()
if self.config.thinking_active:
messages_for_prompt.append({"role": "agent", "content": "<think>\n"})
else:
messages_for_prompt.append({"role": "agent", "content": ""})
prompt = self.tokenizer.apply_chat_template(
messages_for_prompt, tokenize=False
)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Prompting LLM..."
)
try:
completions = await self.server.completion(
prompt=prompt,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=self.config.temperature,
top_p=self.config.top_p,
)
except Exception as api_error:
logger.exception(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"API Error during self.server.completion: {api_error}"
)
return self._ensure_trajectory_token_limit(ep.trajectory)
if (
not completions
or not completions.choices
or len(completions.choices) != self.config.group_size
):
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"API did not return the expected number of choices "
f"({self.config.group_size} vs {len(completions.choices) if completions else 0}). "
f"Aborting episode."
)
return self._ensure_trajectory_token_limit(ep.trajectory)
alt_actions: List[int] = []
alt_responses: List[str] = []
for choice_idx, choice in enumerate(completions.choices):
response_text = (
choice.text
if hasattr(choice, "text")
else getattr(choice.message, "content", "")
)
full_response = (
("<think>\n" + response_text)
if self.config.thinking_active
else response_text
)
alt_responses.append(full_response)
parsed_act = self._parse_tool_call(full_response)
alt_actions.append(parsed_act)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Choice {choice_idx}: Parsed Action={parsed_act}, Response Length={len(full_response)}"
)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Selecting best action..."
)
best_action, scores = await self._select_best_action(
ep, alt_actions, alt_responses
)
best_action_idx = -1
try:
best_score_val = max(scores)
possible_indices = [
i
for i, (act, score) in enumerate(zip(alt_actions, scores))
if act == best_action and score == best_score_val
]
if possible_indices:
best_action_idx = possible_indices[0]
logger.info(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Best action selected: {best_action} "
f"(Index: {best_action_idx}), "
f"Score: {scores[best_action_idx]:.4f}"
)
else:
logger.warning(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Could not find index for best action {best_action} with score {best_score_val}. "
f"Trying first occurrence of action."
)
best_action_idx = alt_actions.index(best_action)
logger.info(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Fallback - Best action selected: {best_action} (Index: {best_action_idx}), "
f"Score: {scores[best_action_idx]:.4f}"
)
best_response = alt_responses[best_action_idx]
except (ValueError, IndexError) as e:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Error finding index for best action {best_action}: {e}. "
f"Cannot proceed with episode."
)
if seed in self.episodes:
try:
self.episodes[seed].env.close()
except Exception as close_exc:
logger.warning(
f"[Collect Trajectory Seed: {seed}] "
f"Exception closing env for aborted episode on "
f"best_action index error: {close_exc}"
)
del self.episodes[seed]
return self._ensure_trajectory_token_limit(ep.trajectory)
alt_tokens: List[List[int]] = []
alt_masks: List[List[int]] = []
alt_messages: List[List[Message]] = []
tokenization_failed_for_step = False
for response in alt_responses:
step_msgs: List[Message] = [
{"role": m["role"], "content": m["content"]}
for m in ep.message_history
]
step_msgs.append({"role": "agent", "content": response})
try:
out = tokenize_for_trainer_multistep(self.tokenizer, step_msgs)
alt_tokens.append(out["tokens"])
alt_masks.append(out["masks"])
alt_messages.append(step_msgs)
except Exception as tokenization_error:
logger.exception(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Critical tokenization error for response: {response[:100]}... "
f"Error: {tokenization_error}. Aborting episode."
)
tokenization_failed_for_step = True
break
if tokenization_failed_for_step:
logger.warning(
f"[Collect Trajectory Seed: {seed}] Episode aborted at turn {turn+1} due to tokenization failure."
)
if seed in self.episodes:
try:
self.episodes[seed].env.close()
except Exception as e:
logger.warning(
f"[Collect Trajectory Seed: {seed}] Exception closing env for aborted episode: {e}"
)
del self.episodes[seed]
return self._ensure_trajectory_token_limit(ep.trajectory)
expected_len = self.config.group_size
if len(alt_tokens) != expected_len:
alt_tokens.extend([[]] * (expected_len - len(alt_tokens)))
if len(alt_masks) != expected_len:
alt_masks.extend([[]] * (expected_len - len(alt_masks)))
if len(alt_messages) != expected_len:
alt_messages.extend(
[
[
{
"role": "system",
"content": "Missing due to prior success but unexpected count",
}
]
]
* (expected_len - len(alt_messages))
)
env_action = 0 if best_action == -1 else best_action
if best_action == -1:
logger.warning(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Selected action was invalid format (-1). "
f"Stepping env with 'stick' (0)."
)
try:
obs, reward, term, trunc, info = ep.env.step(env_action)
logger.info(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Stepped main env with action {env_action}. "
f"Reward: {reward}, Term: {term}, Trunc: {trunc}"
)
            except Exception as env_step_error:
                logger.exception(
                    f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
                    f"Error stepping main environment with action {env_action}: {env_step_error}"
                )
                term = True
                trunc = False  # must be defined: the termination branch below logs it
                reward = -1.0
                obs = None
ep.actions.append(env_action)
ep.step_rewards.append(reward)
ep.total_env_reward += reward
ep.num_total_actions += 1
if best_action != -1:
ep.num_correct_actions += 1
logger.info(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Step Rewards: Env={reward:.2f}. "
f"Running Totals: Env={ep.total_env_reward:.2f}."
)
ep.trajectory.append(
BlackjackScoredDataGroup(
overrides=[],
seed=seed,
tokens=alt_tokens,
masks=alt_masks,
scores=scores,
messages=alt_messages,
parsed_action=best_action,
)
)
if term or trunc:
logger.info(
f"[Collect Trajectory Seed: {seed}] "
f"Episode ended. Term={term}, Trunc={trunc}. "
f"Final Reward: {reward}"
)
ep.message_history.append({"role": "agent", "content": best_response})
if obs is not None:
final_formatted_obs = self._format_observation(obs)
logger.debug(
f"[Collect Trajectory Seed: {seed}] "
f"Final State: {final_formatted_obs} (Reward: {reward})"
)
else:
logger.debug(
f"[Collect Trajectory Seed: {seed}] "
f"Episode terminated with error. (Reward: {reward})"
)
break
else:
response_for_history = self._truncate_thinking_for_history(
best_response, self.config.max_think_chars_history
)
ep.message_history.append(
{"role": "agent", "content": response_for_history}
)
formatted_obs = self._format_observation(obs)
ep.message_history.append(
{"role": "environment", "content": formatted_obs}
)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] New Observation: {formatted_obs}"
)
logger.info(
f"[Collect Trajectory Seed: {seed}] "
f"Finished episode after {len(ep.actions)} steps."
)
logger.info(
f"[Collect Trajectory Seed: {seed}] "
f"Final Totals: Env Reward={ep.total_env_reward:.2f}."
)
logger.info(
f"[Collect Trajectory Seed: {seed}] "
f"Action Accuracy: {ep.num_correct_actions}/{max(1, ep.num_total_actions)} "
f"({ep.num_correct_actions/max(1, ep.num_total_actions):.2%})"
)
final_env_reward_for_outcome = 0
if ep.step_rewards:
final_env_reward_for_outcome = ep.step_rewards[-1]
game_outcome = 0
if final_env_reward_for_outcome > 0:
game_outcome = 1
elif final_env_reward_for_outcome < 0:
game_outcome = -1
episode_summary_metrics = {
"seed": seed,
"total_reward": ep.total_env_reward,
"num_correct_actions": ep.num_correct_actions,
"num_total_actions": ep.num_total_actions,
"game_outcome": game_outcome,
"num_steps": len(ep.actions),
}
self.completed_episode_metrics_buffer.append(episode_summary_metrics)
if seed in self.episodes:
try:
self.episodes[seed].env.close()
except Exception as e:
logger.warning(
f"[Collect Trajectory Seed: {seed}] "
f"Exception closing env for episode: {e}"
)
del self.episodes[seed]
logger.debug(
f"[Collect Trajectory Seed: {seed}] "
f"Cleared episode state from self.episodes."
)
return self._ensure_trajectory_token_limit(ep.trajectory)
async def collect_trajectories(
self, item: Tuple[int, int]
) -> Tuple[List[BlackjackScoredDataGroup], List[Tuple[int, int]]]:
seed, _ = item
traj = await self.collect_trajectory(seed)
if traj:
traj = self._ensure_trajectory_token_limit(traj)
if not traj:
logger.warning(
f"[collect_trajectories] "
f"All steps for seed {seed} were filtered out due to token limit "
f"constraints. Returning empty trajectory."
)
return traj, []
async def score(
self,
rollout_group_data: List[BlackjackScoredDataGroup],
) -> List[Optional[BlackjackScoredDataGroup]]:
"""
Applies final scoring adjustments to a completed trajectory.
If the game was a win (determined by replaying chosen actions), a bonus
is applied to the best alternative at each step.
Tie-breaking based on token length is applied subsequently.
Args:
rollout_group_data: The list of ScoredDataGroups representing the trajectory.
Returns:
The list of ScoredDataGroups with potentially adjusted scores.
Returns a list containing None elements if input steps are invalid.
"""
logger.info(
f"score: Received rollout_group_data with {len(rollout_group_data)} "
f"groups for scoring."
)
if not rollout_group_data:
logger.warning(
"score: Received empty rollout_group_data. Returning empty list."
)
return []
if not all(
isinstance(rgd, dict) for rgd in rollout_group_data if rgd is not None
):
logger.error(
"score: rollout_group_data contains non-dictionary elements. "
"Cannot proceed."
)
return [None] * len(rollout_group_data)
final_env_reward_for_outcome = 0.0
is_win = False
seed_for_outcome = None
first_valid_step_for_seed = next(
(rgd for rgd in rollout_group_data if rgd is not None and "seed" in rgd),
None,
)
if not first_valid_step_for_seed:
logger.warning(
"score: Cannot determine game outcome, no valid step with seed found "
"in rollout_group_data."
)
else:
seed_for_outcome = first_valid_step_for_seed["seed"]
logger.info(
f"score [Seed: {seed_for_outcome}]: Starting game outcome replay."
)
try:
temp_env_outcome = gymnasium.make(self.config.env_name)
temp_obs_outcome, _ = temp_env_outcome.reset(seed=seed_for_outcome)
for step_idx_outcome, step_group_for_outcome in enumerate(
rollout_group_data
):
if step_group_for_outcome is None:
logger.warning(
f"score [Seed: {seed_for_outcome}]: "
f"Encountered None step_group at index {step_idx_outcome} "
f"during outcome replay. Assuming non-win."
)
final_env_reward_for_outcome = 0.0
break
action_for_outcome_step = step_group_for_outcome.get(
"parsed_action"
)
if action_for_outcome_step is None or action_for_outcome_step == -1:
logger.warning(
f"score [Seed: {seed_for_outcome}]: "
f"Invalid action ({action_for_outcome_step}) found at step {step_idx_outcome} "
f"during game outcome replay. Assuming non-win outcome for scoring."
)
final_env_reward_for_outcome = 0.0
break
logger.debug(
f"score [Seed: {seed_for_outcome} Replay]: "
f"Step {step_idx_outcome}, Action: {action_for_outcome_step}"
)
(
temp_obs_outcome,
step_reward_outcome,
term_outcome,
trunc_outcome,
_,
) = temp_env_outcome.step(action_for_outcome_step)
final_env_reward_for_outcome = step_reward_outcome
if term_outcome or trunc_outcome:
logger.info(
f"score [Seed: {seed_for_outcome}]: "
f"Game outcome replay ended at step {step_idx_outcome} "
f"(action: {action_for_outcome_step}). "
f"Final env reward for outcome: {final_env_reward_for_outcome}"
)
break
else:
logger.info(
f"score [Seed: {seed_for_outcome}]: "
f"Game outcome replay completed all steps. "
f"Final env reward for outcome: {final_env_reward_for_outcome}"
)
temp_env_outcome.close()
except Exception as e:
logger.exception(
f"score [Seed: {seed_for_outcome}]: "
f"Error during game outcome replay: {e}. "
f"Assuming non-win."
)
final_env_reward_for_outcome = 0.0
if final_env_reward_for_outcome > 0:
is_win = True
logger.info(
f"score [Seed: {seed_for_outcome}]: "
f"Game outcome determined as WIN "
f"(Final Env Reward: {final_env_reward_for_outcome}). "
f"Win bonus (+1.0) will be applied to best alternative at each step."
)
else:
logger.info(
f"score [Seed: {seed_for_outcome}]: "
f"Game outcome determined as NON-WIN "
f"(Final Env Reward: {final_env_reward_for_outcome}). "
f"No win bonus from game outcome will be applied."
)
processed_rollout_data: List[Optional[BlackjackScoredDataGroup]] = []
for step_idx, original_step_group_untyped in enumerate(rollout_group_data):
if original_step_group_untyped is None:
logger.warning(f"score: Skipping None step_group at index {step_idx}.")
processed_rollout_data.append(None)
continue
current_step_group: BlackjackScoredDataGroup = (
original_step_group_untyped.copy()
)
step_seed = current_step_group.get("seed", "N/A")
if current_step_group.get("scores") is None:
logger.warning(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Scores are missing. Cannot apply win bonus or tie-breaking."
)
processed_rollout_data.append(current_step_group)
continue
original_scores = current_step_group["scores"]
if not isinstance(original_scores, list) or not all(
isinstance(s, (int, float)) for s in original_scores
):
logger.warning(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"'scores' is not a list of numbers. "
f"Skipping scoring for this step. Scores: {original_scores}"
)
processed_rollout_data.append(current_step_group)
continue
modified_scores = original_scores.copy()
if is_win and modified_scores:
try:
max_score_in_step = -float("inf")
valid_scores_for_max = [
s for s in modified_scores if isinstance(s, (int, float))
]
if not valid_scores_for_max:
logger.warning(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"No valid numeric scores found to determine best alternative for win bonus."
)
else:
max_score_in_step = max(valid_scores_for_max)
best_alternative_idx_this_step = -1
for idx, score_val in enumerate(modified_scores):
if score_val == max_score_in_step:
best_alternative_idx_this_step = idx
break
if best_alternative_idx_this_step != -1:
win_bonus_amount = 1.0
modified_scores[
best_alternative_idx_this_step
] += win_bonus_amount
logger.info(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Applied WIN bonus ({win_bonus_amount}) to alternative "
f"{best_alternative_idx_this_step} "
f"(Original score: {original_scores[best_alternative_idx_this_step]:.4f}, "
f"New: {modified_scores[best_alternative_idx_this_step]:.4f})."
)
else:
logger.warning(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Could not find index of max score {max_score_in_step} for win bonus. "
f"This should not happen if scores exist."
)
except ValueError:
score_msg = (
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Error finding max score for win bonus."
)
logger.warning(score_msg)
logger.debug(f"Problematic scores: {modified_scores}")
except Exception as e_bonus:
logger.exception(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Unexpected error applying win bonus: {e_bonus}"
)
step_messages = current_step_group.get("messages")
if not isinstance(step_messages, list) or len(modified_scores) != len(
step_messages
):
logger.warning(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Mismatch between scores ({len(modified_scores)}) and messages "
f"({len(step_messages) if isinstance(step_messages, list) else 'not a list'}) "
f"lengths, or messages missing. Skipping tie-breaking for this step."
)
elif modified_scores:
token_lengths = []
valid_messages_for_tiebreak = True
for alt_msg_list_idx, alt_msg_list in enumerate(step_messages):
if (
not isinstance(alt_msg_list, list)
or not alt_msg_list
or not isinstance(alt_msg_list[-1], dict)
or "content" not in alt_msg_list[-1]
):
logger.warning(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Invalid message structure for alternative {alt_msg_list_idx} "
f"during tie-breaking. Skipping tie-breaking for this step."
)
valid_messages_for_tiebreak = False
break
response_text = alt_msg_list[-1]["content"]
try:
token_lengths.append(len(self.tokenizer.encode(response_text)))
except Exception as e_tok:
logger.error(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Tokenization error for tie-breaking on alt {alt_msg_list_idx}: {e_tok}. "
f"Defaulting token length to large value."
)
token_lengths.append(float("inf"))
if valid_messages_for_tiebreak:
score_groups = {}
for idx, score_val in enumerate(modified_scores):
if not isinstance(score_val, (int, float)):
continue
if score_val not in score_groups:
score_groups[score_val] = []
score_groups[score_val].append(idx)
scores_after_tiebreak = modified_scores.copy()
for score_val, indices_with_this_score in score_groups.items():
if len(indices_with_this_score) > 1:
if not all(
idx < len(token_lengths)
for idx in indices_with_this_score
):
logger.warning(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Token length data incomplete for tied score {score_val}. "
f"Indices: {indices_with_this_score}, "
f"Token lengths count: {len(token_lengths)}. "
f"Skipping tie-break for this group."
)
continue
try:
sorted_tied_indices = sorted(
indices_with_this_score,
key=lambda i: token_lengths[i],
)
for rank, tied_idx in enumerate(
sorted_tied_indices[1:], 1
):
penalty = 0.0001 * rank
scores_after_tiebreak[tied_idx] -= penalty
logger.debug(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Applied tie-break penalty {-penalty:.5f} to alternative index {tied_idx} "
f"(original tied score {score_val:.4f}, token length rank {rank})."
)
except IndexError:
logger.warning(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"IndexError during tie-breaking for score {score_val}. "
f"Indices: {indices_with_this_score}. "
f"Skipping tie-break for this group."
)
except Exception as e_tiebreak:
logger.exception(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Unexpected error during tie-breaking for score {score_val}: {e_tiebreak}"
)
modified_scores = scores_after_tiebreak
current_step_group["scores"] = modified_scores
processed_rollout_data.append(current_step_group)
logger.info(
f"score: Finished scoring. Processed {len(processed_rollout_data)} step groups."
)
return processed_rollout_data
async def setup(self):
pass
    async def get_next_item(self) -> Tuple[int, int]:
        """Sample a random seed for the next training episode."""
        # `random` is already imported at module level; no local import needed.
        return (random.randint(0, 1000000), 0)
async def rollout_and_score_eval(self, seed: int) -> Dict[str, Any]:
"""
Run a single episode for evaluation and return detailed metrics.
Does not use the best-of-n sampling, but a single completion per step.
Cleans up the episode state after completion.
"""
ep = self._get_or_create_episode(seed)
max_turns = self.config.max_turns if self.config.max_turns is not None else 5
logger.info(
f"[Eval Rollout Seed: {seed}] Starting episode. Max turns: {max_turns}"
)
episode_metrics = {
"seed": seed,
"total_reward": 0.0,
"num_steps": 0,
"num_correct_actions": 0,
"num_invalid_actions": 0,
"actions_chosen": [],
"game_outcome": 0,
}
for turn in range(max_turns):
episode_metrics["num_steps"] = turn + 1
messages_for_prompt = ep.message_history.copy()
if self.config.thinking_active:
messages_for_prompt.append({"role": "agent", "content": "<think>\n"})
else:
messages_for_prompt.append({"role": "agent", "content": ""})
prompt = self.tokenizer.apply_chat_template(
messages_for_prompt, tokenize=False
)
try:
completions = await self.server.completion(
prompt=prompt,
n=1,
max_tokens=self.config.max_token_length,
temperature=self.config.temperature,
top_p=self.config.top_p,
split="eval",
)
except Exception as api_error:
logger.exception(
f"[Eval Rollout Seed: {seed} Turn: {turn+1}] API Error: {api_error}"
)
break
if not completions or not completions.choices:
logger.error(
f"[Eval Rollout Seed: {seed} Turn: {turn+1}] API did not return any choices. Aborting episode."
)
break
response_text = (
completions.choices[0].text
if hasattr(completions.choices[0], "text")
else getattr(completions.choices[0].message, "content", "")
)
full_response = (
("<think>\n" + response_text)
if self.config.thinking_active
else response_text
)
parsed_action = self._parse_tool_call(full_response)
episode_metrics["actions_chosen"].append(parsed_action)
if parsed_action == -1:
episode_metrics["num_invalid_actions"] += 1
env_action = 0
logger.warning(
f"[Eval Rollout Seed: {seed} Turn: {turn+1}] Invalid action parsed. Defaulting to 'stick'."
)
else:
episode_metrics["num_correct_actions"] += 1
env_action = parsed_action
try:
obs, reward, term, trunc, info = ep.env.step(env_action)
            except Exception as env_step_error:
                logger.exception(
                    f"[Eval Rollout Seed: {seed} Turn: {turn+1}] Error stepping env: {env_step_error}"
                )
                term = True
                trunc = False  # keep defined for the `term or trunc` check below
                reward = -1.0
                obs = None
episode_metrics["total_reward"] += reward
if term or trunc:
episode_metrics["game_outcome"] = int(reward)
logger.info(
f"[Eval Rollout Seed: {seed}] Episode ended. Outcome Reward: {reward}"
)
ep.message_history.append({"role": "agent", "content": full_response})
if obs is not None:
final_formatted_obs = self._format_observation(obs)
logger.debug(
f"[Eval Rollout Seed: {seed}] "
f"Final State: {final_formatted_obs} (Reward: {reward})"
)
else:
logger.debug(
f"[Eval Rollout Seed: {seed}] "
f"Episode terminated with error. (Reward: {reward})"
)
break
else:
ep.message_history.append({"role": "agent", "content": full_response})
formatted_obs = self._format_observation(obs)
ep.message_history.append(
{"role": "environment", "content": formatted_obs}
)
logger.info(
f"[Eval Rollout Seed: {seed}] Finished episode. Metrics: {episode_metrics}"
)
if seed in self.episodes:
try:
self.episodes[seed].env.close()
except Exception as e:
logger.warning(
f"[Eval Rollout Seed: {seed}] Exception closing env for episode: {e}"
)
del self.episodes[seed]
return episode_metrics
async def evaluate(self, *args, **kwargs):
"""Run evaluation episodes and aggregate metrics for logging."""
if not self.config.use_wandb:
logger.info("Skipping evaluation as wandb is not enabled.")
return
num_eval_episodes = self.config.eval_episodes
logger.info(f"Starting evaluation for {num_eval_episodes} episodes.")
eval_tasks = []
        for _ in range(num_eval_episodes):
eval_seed = random.randint(1000001, 2000000)
eval_tasks.append(self.rollout_and_score_eval(eval_seed))
all_episode_metrics = await tqdm_asyncio.gather(*eval_tasks)
if not all_episode_metrics:
logger.warning("No metrics collected from evaluation episodes.")
return
valid_metrics = [m for m in all_episode_metrics if m is not None]
if not valid_metrics:
logger.warning("All evaluation episodes resulted in None metrics.")
return
num_completed_episodes = len(valid_metrics)
avg_total_env_reward = (
sum(m["total_reward"] for m in valid_metrics) / num_completed_episodes
)
avg_num_turns = (
sum(m["num_steps"] for m in valid_metrics) / num_completed_episodes
)
total_correct_actions = sum(m["num_correct_actions"] for m in valid_metrics)
total_invalid_actions = sum(m["num_invalid_actions"] for m in valid_metrics)
total_actions_taken = total_correct_actions + total_invalid_actions
action_accuracy = (
total_correct_actions / total_actions_taken
if total_actions_taken > 0
else 0
)
invalid_action_rate = (
total_invalid_actions / total_actions_taken
if total_actions_taken > 0
else 0
)
wins = sum(1 for m in valid_metrics if m["game_outcome"] == 1)
losses = sum(1 for m in valid_metrics if m["game_outcome"] == -1)
draws = sum(1 for m in valid_metrics if m["game_outcome"] == 0)
win_rate = wins / num_completed_episodes if num_completed_episodes > 0 else 0
loss_rate = losses / num_completed_episodes if num_completed_episodes > 0 else 0
draw_rate = draws / num_completed_episodes if num_completed_episodes > 0 else 0
all_chosen_actions = [
action for m in valid_metrics for action in m["actions_chosen"]
]
count_hit = sum(1 for act in all_chosen_actions if act == 1)
count_stick = sum(1 for act in all_chosen_actions if act == 0)
count_error_actions = sum(1 for act in all_chosen_actions if act == -1)
total_parsed_actions_in_eval = len(all_chosen_actions)
self.eval_metrics = [
("eval/avg_total_reward", avg_total_env_reward),
("eval/avg_num_steps", avg_num_turns),
("eval/action_accuracy", action_accuracy),
("eval/invalid_action_rate", invalid_action_rate),
("eval/win_rate", win_rate),
("eval/loss_rate", loss_rate),
("eval/draw_rate", draw_rate),
("eval/num_wins", wins),
("eval/num_losses", losses),
("eval/num_draws", draws),
("eval/num_completed_episodes", num_completed_episodes),
(
"eval/hit_chosen_rate",
(
count_hit / total_parsed_actions_in_eval
if total_parsed_actions_in_eval > 0
else 0
),
),
(
"eval/stick_chosen_rate",
(
count_stick / total_parsed_actions_in_eval
if total_parsed_actions_in_eval > 0
else 0
),
),
(
"eval/error_action_chosen_rate",
(
count_error_actions / total_parsed_actions_in_eval
if total_parsed_actions_in_eval > 0
else 0
),
),
]
logger.info(f"Evaluation completed. Aggregated metrics: {self.eval_metrics}")
async def wandb_log(self, wandb_metrics: Optional[Dict[str, Any]] = None):
"""
Log aggregated metrics from completed training episodes and call super().wandb_log.
"""
if wandb_metrics is None:
wandb_metrics = {}
if self.completed_episode_metrics_buffer:
num_episodes_in_buffer = len(self.completed_episode_metrics_buffer)
avg_ep_env_reward = (
sum(m["total_reward"] for m in self.completed_episode_metrics_buffer)
/ num_episodes_in_buffer
)
total_ep_correct_actions = sum(
m["num_correct_actions"] for m in self.completed_episode_metrics_buffer
)
total_ep_actions = sum(
m["num_total_actions"] for m in self.completed_episode_metrics_buffer
)
avg_ep_action_accuracy = (
total_ep_correct_actions / total_ep_actions
if total_ep_actions > 0
else 0
)
avg_ep_num_steps = (
sum(m["num_steps"] for m in self.completed_episode_metrics_buffer)
/ num_episodes_in_buffer
)
ep_wins = sum(
1
for m in self.completed_episode_metrics_buffer
if m["game_outcome"] == 1
)
ep_losses = sum(
1
for m in self.completed_episode_metrics_buffer
if m["game_outcome"] == -1
)
ep_draws = sum(
1
for m in self.completed_episode_metrics_buffer
if m["game_outcome"] == 0
)
ep_win_rate = (
ep_wins / num_episodes_in_buffer if num_episodes_in_buffer > 0 else 0
)
ep_loss_rate = (
ep_losses / num_episodes_in_buffer if num_episodes_in_buffer > 0 else 0
)
ep_draw_rate = (
ep_draws / num_episodes_in_buffer if num_episodes_in_buffer > 0 else 0
)
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_reward"
] = avg_ep_env_reward
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_action_accuracy"
] = avg_ep_action_accuracy
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_num_steps"
] = avg_ep_num_steps
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/episode_win_rate"
] = ep_win_rate
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/episode_loss_rate"
] = ep_loss_rate
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/episode_draw_rate"
] = ep_draw_rate
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/num_episodes_in_log_period"
] = num_episodes_in_buffer
logger.info(
f"Logging metrics for {num_episodes_in_buffer} completed training episodes."
)
self.completed_episode_metrics_buffer = []
await super().wandb_log(wandb_metrics)
@classmethod
def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
env_config = BlackjackEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=16,
use_wandb=True,
max_num_workers=128,
rollout_server_url="http://localhost:8000",
total_steps=2000,
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 16,
inference_weight=1.0,
wandb_name="fundamental_metric_prediction",
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
env_name="Blackjack-v1",
temperature=0.7,
top_p=0.9,
max_turns=5,
thinking_active=True,
eval_episodes=100,
max_think_chars_history=3000,
max_trajectory_tokens=24576,
debug_mode=False,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256,
)
]
return env_config, server_configs
@classmethod
def cli(cls):
super(BlackjackEnv, cls).cli()
def _truncate_thinking_for_history(
self, response_text: str, max_chars_fallback: int
) -> str:
"""Helper to truncate the <think> block of a response for message history."""
try:
think_start_tag = "<think>"
think_end_tag = "</think>"
think_start_idx = response_text.find(think_start_tag)
think_end_idx = response_text.find(think_end_tag)
if (
think_start_idx != -1
and think_end_idx != -1
and think_start_idx < think_end_idx
):
part_before_content = response_text[
: think_start_idx + len(think_start_tag)
]
original_think_content = response_text[
think_start_idx + len(think_start_tag) : think_end_idx
].strip()
part_after_content = response_text[think_end_idx:]
truncated_think_content = original_think_content
is_truncated = False
if not original_think_content:
return response_text
paragraphs = [
p.strip() for p in original_think_content.split("\n\n") if p.strip()
]
if len(paragraphs) > 0:
last_paragraph = paragraphs[-1]
if len(last_paragraph) < len(original_think_content):
truncated_think_content = last_paragraph
is_truncated = True
elif len(original_think_content) > max_chars_fallback:
truncated_think_content = original_think_content[
-max_chars_fallback:
]
is_truncated = True
elif len(original_think_content) > max_chars_fallback:
truncated_think_content = original_think_content[
-max_chars_fallback:
]
is_truncated = True
if is_truncated and truncated_think_content:
if not truncated_think_content.startswith("... "):
truncated_think_content = (
"... " + truncated_think_content.lstrip()
)
if (
not truncated_think_content.strip()
or truncated_think_content.strip() == "..."
):
final_content_for_block = ""
else:
final_content_for_block = f"\n{truncated_think_content.strip()}\n"
return f"{part_before_content.rstrip()}{final_content_for_block}{part_after_content.lstrip()}"
return response_text
except Exception as e:
logger.error(
f"Error in _truncate_thinking_for_history for text '{response_text[:200]}...': {e}",
exc_info=True,
)
return response_text
def _ensure_trajectory_token_limit(
self, trajectory: List[BlackjackScoredDataGroup]
) -> List[BlackjackScoredDataGroup]:
"""
Ensure token sequences in a trajectory don't exceed max_trajectory_tokens.
Attempts to uniformly truncate older messages (preferably paired turns) from all alternatives within a step.
The system prompt, last environment observation, and last agent response are preserved as a minimum.
If a step still exceeds the limit after maximum possible truncation, it is discarded.
Args:
trajectory: List of BlackjackScoredDataGroup from an episode
Returns:
The trajectory with potentially truncated messages/tokens/masks or filtered steps
"""
if not trajectory:
return trajectory
filtered_trajectory: List[BlackjackScoredDataGroup] = []
for step_idx, original_step_data in enumerate(trajectory):
if not (
original_step_data.get("messages")
and original_step_data.get("tokens")
and original_step_data.get("masks")
and original_step_data.get("seed") is not None
):
logger.warning(
f"[_ensure_trajectory_token_limit] Step {step_idx} "
f"is missing critical data (messages, tokens, masks, or seed). Skipping."
)
continue
max_initial_tokens = 0
if original_step_data["tokens"]:
max_initial_tokens = (
max(
len(alt_tokens)
for alt_tokens in original_step_data["tokens"]
if isinstance(alt_tokens, list)
)
if any(
isinstance(alt_tokens, list)
for alt_tokens in original_step_data["tokens"]
)
else 0
)
if max_initial_tokens <= self.config.max_trajectory_tokens:
filtered_trajectory.append(original_step_data)
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} compliant. "
f"Max tokens: {max_initial_tokens}"
)
continue
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} (max tokens: {max_initial_tokens}) "
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
)
working_messages = [
msgs_list.copy() for msgs_list in original_step_data["messages"] or []
]
working_tokens = [
tkns_list.copy() for tkns_list in original_step_data["tokens"] or []
]
working_masks = [
msks_list.copy() for msks_list in original_step_data["masks"] or []
]
max_current_tokens = max_initial_tokens
num_alternatives = len(working_messages)
if num_alternatives == 0:
logger.warning(
f"[_ensure_trajectory_token_limit] Step {step_idx} has no alternatives after copying. Skipping."
)
continue
retokenization_error_this_step = False
while max_current_tokens > self.config.max_trajectory_tokens:
target_pop_counts_per_alt = []
for alt_idx in range(num_alternatives):
alt_msg_list = working_messages[alt_idx]
num_preserved_at_end = 0
if (
len(alt_msg_list) > 1
and alt_msg_list[-1]["role"] in UNMASKED_ROLES
):
num_preserved_at_end = 1
if (
len(alt_msg_list) > 2
and alt_msg_list[-2]["role"] == "environment"
):
num_preserved_at_end = 2
available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
if available_to_pop <= 0:
target_pop_counts_per_alt.append(0)
else:
can_pop_pair = (
available_to_pop >= 2
and len(alt_msg_list) > 2
and alt_msg_list[1]["role"] == "environment"
and alt_msg_list[2]["role"] in UNMASKED_ROLES
)
if can_pop_pair:
target_pop_counts_per_alt.append(2)
else:
target_pop_counts_per_alt.append(1)
positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
if not positive_pop_counts:
break
min_pop_this_round = min(positive_pop_counts)
temp_new_alt_tokens = []
temp_new_alt_masks = []
max_tokens_after_this_trunc = 0
for alt_idx in range(num_alternatives):
for _ in range(min_pop_this_round):
if len(working_messages[alt_idx]) > 1:
working_messages[alt_idx].pop(1)
else:
logger.error(
f"[_ensure_trajectory_token_limit] Critical error during pop for "
f"alt {alt_idx}, step {step_idx}. List too short."
)
retokenization_error_this_step = True
break
if retokenization_error_this_step:
break
try:
tokenized_alt = tokenize_for_trainer_multistep(
self.tokenizer, working_messages[alt_idx]
)
temp_new_alt_tokens.append(tokenized_alt["tokens"])
temp_new_alt_masks.append(tokenized_alt["masks"])
max_tokens_after_this_trunc = max(
max_tokens_after_this_trunc, len(tokenized_alt["tokens"])
)
except Exception as e:
logger.error(
f"[_ensure_trajectory_token_limit] Error re-tokenizing alt {alt_idx} "
f"in step {step_idx} after truncation: {e}"
)
retokenization_error_this_step = True
break
if retokenization_error_this_step:
break
working_tokens = temp_new_alt_tokens
working_masks = temp_new_alt_masks
max_current_tokens = max_tokens_after_this_trunc
logger.debug(
f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_this_round}, "
f"max tokens: {max_current_tokens}"
)
if (
not retokenization_error_this_step
and max_current_tokens <= self.config.max_trajectory_tokens
):
updated_step_data: BlackjackScoredDataGroup = {
"seed": original_step_data["seed"],
"messages": working_messages,
"tokens": working_tokens,
"masks": working_masks,
"scores": original_step_data.get("scores"),
"parsed_action": original_step_data.get("parsed_action"),
}
filtered_trajectory.append(updated_step_data)
logger.info(
f"[_ensure_trajectory_token_limit] Step {step_idx} successfully processed. "
f"Final max tokens: {max_current_tokens}"
)
else:
logger.warning(
f"[_ensure_trajectory_token_limit] Discarding step {step_idx}. "
f"Max tokens ({max_current_tokens}) still exceed limit ({self.config.max_trajectory_tokens}) "
f"or retokenization error occurred ({retokenization_error_this_step})."
)
if len(filtered_trajectory) < len(trajectory):
logger.warning(
f"[_ensure_trajectory_token_limit] Filtered out "
f"{len(trajectory) - len(filtered_trajectory)} steps "
f"due to token limit constraints. Original: {len(trajectory)}, Filtered: {len(filtered_trajectory)}"
)
return filtered_trajectory
if __name__ == "__main__":
BlackjackEnv.cli()