mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
893 lines
35 KiB
Python
893 lines
35 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
BlackjackEnv: Trainer environment for Gymnasium Blackjack
|
|
|
|
This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
|
|
using function-call style actions. Extends BaseEnv.
|
|
|
|
Loosely inspired by VinePPO, but state values are computed with an exact
recursive search instead of Monte Carlo rollouts. This is feasible because the
action space is tiny (hit or stick), episodes are short, and a seeded
environment is deterministic, so any state can be reproduced by replaying its
action history.
|
|
"""
|
|
|
|
import copy
|
|
import json
|
|
import logging
|
|
import random
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import gymnasium
|
|
from tqdm.asyncio import tqdm_asyncio
|
|
|
|
from atroposlib.envs.base import (
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
EvalHandlingEnum,
|
|
OpenaiConfig,
|
|
ScoredDataGroup,
|
|
)
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
from atroposlib.utils.message_history_utils import truncate_thinking
|
|
from atroposlib.utils.tool_call_parser import parse_tool_call
|
|
from atroposlib.utils.best_of_n_selection import select_best_index
|
|
|
|
# Module-wide logger; BlackjackEnv.__init__ raises or lowers its level based on
# config.debug_mode.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class BlackjackEnvConfig(BaseEnvConfig):
    # Configuration for BlackjackEnv, layered on top of BaseEnvConfig.

    # Gymnasium environment id passed to gymnasium.make().
    env_name: str = "Blackjack-v1"
    # Sampling parameters forwarded to the completion server.
    temperature: float = 0.7
    top_p: float = 0.9
    # Maximum agent turns per episode; a falsy value falls back to 5.
    max_turns: Optional[int] = 5
    wandb_name: str = "blackjack"
    # When True, every agent turn is pre-filled with "<think>\n".
    thinking_active: bool = True
    # Number of episodes played by evaluate().
    eval_episodes: int = 100
    # Character budget handed to truncate_thinking() for agent turns kept in history.
    max_think_chars_history: int = 3000
    max_trajectory_tokens: int = 24576 #seq_len of RL trainer
    # When True, the module logger is set to DEBUG.
    debug_mode: bool = False
    # Number of alternatives sampled per turn (the "n" in best-of-n).
    group_size: int = 16
    # Weight of the shorter-response bonus added in _score_response().
    tiebreak_token_factor: float = 0.01
|
|
|
|
|
|
class BlackjackScoredDataGroup(ScoredDataGroup):
    # One training group per turn: the alternatives sampled for a single state.

    # Seed of the episode this group was collected from.
    seed: int
    # Token ids and loss masks of each alternative's conversation
    # (as produced by tokenize_for_trainer).
    tokens: Optional[List[List[int]]] = None
    masks: Optional[List[List[int]]] = None
    # Per-alternative advantages computed in BlackjackEnv._collect_trajectory.
    scores: Optional[List[float]] = None
    # Conversation of each alternative, ending with its agent turn and, when
    # available, the following environment observation.
    messages: Optional[List[List[Dict]]] = None
    # Parsed action per alternative: 1 = hit, 0 = stick, -1 = unparseable.
    parsed_actions: Optional[List[int]] = None
|
|
|
|
|
|
class EpisodeState:
    """Mutable bookkeeping for a single live Blackjack episode."""

    def __init__(self, seed: int, env: gymnasium.Env):
        # Identity of the episode: the reset seed and the environment it drives.
        self.seed, self.env = seed, env

        # Conversation so far (system prompt, observations and agent turns).
        self.message_history: List[Dict] = []

        # Environment actions actually applied, and the reward each one produced.
        self.actions: List[int] = []
        self.step_rewards: List[float] = []

        # Aggregate counters; total_reward is filled in once the episode ends.
        self.total_reward: float = 0.0
        self.num_steps: int = 0
        self.num_correct_actions: int = 0
        self.num_total_actions: int = 0
|
|
|
|
|
|
class BlackjackEnv(BaseEnv):
    """Best-of-n trainer environment wrapping Gymnasium Blackjack.

    Each turn samples ``group_size`` candidate responses from the policy. Every
    candidate is simulated on a freshly seeded replay of the episode and scored
    with an advantage ``r + V(s') - V(s)``. The whole group is emitted as one
    training example, and the real episode advances with the best candidate.
    """

    def __init__(
        self,
        config: BlackjackEnvConfig,
        server_configs: List[OpenaiConfig],
        slurm: bool = True,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # Live episodes keyed by the seed they were reset with.
        self.episodes: Dict[int, EpisodeState] = {}
        self.debug_mode = config.debug_mode
        # Summaries of finished training episodes, flushed by wandb_log().
        self.completed_episode_metrics_buffer: List[Dict[str, float]] = []
        if self.debug_mode:
            logger.setLevel(logging.DEBUG)
        elif logger.level == logging.NOTSET or logger.level > logging.WARNING:
            logger.setLevel(logging.WARNING)

        # Tool schema shown to the model in the system prompt and handed to
        # parse_tool_call().
        self.tools = [
            {
                "type": "function",
                "function": {
                    "name": "take_action",
                    "description": "Choose to 'hit' or 'stick' in Blackjack.",
                    "parameters": {
                        "action": {"type": "string", "enum": ["hit", "stick"]}
                    },
                },
            }
        ]

        tools_json = json.dumps(self.tools)
        self.system_prompt = (
            "You are an AI agent playing Blackjack who uses extreme long chains of thought "
            "to carefully consider the probabilities and optimal strategy. "
            "You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
            "You should enclose your thoughts and internal monologue inside <think> </think> tags, and then "
            "provide your decision using the take_action function call. You may use extremely long chains "
            "of thought to carefully consider the probabilities and optimal strategy.\n\n"
            f"<tools>\n{tools_json}\n</tools>\n\n"
            "For your function call, return a JSON object with function name and arguments "
            "within <tool_call> </tool_call> tags with the following schema:\n"
            '<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
            "Your full answer format should be:\n"
            "<think>\n[Your detailed reasoning process about whether to hit or stick]\n</think>\n\n"
            '<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n\n'
            "Remember to carefully consider the probabilities and optimal strategy for Blackjack."
        )

    def _get_or_create_episode(self, seed: int) -> EpisodeState:
        """Return the live episode for ``seed``, creating and resetting it if needed.

        A new episode starts with the system prompt followed by the initial
        observation.
        """
        if seed not in self.episodes:
            env = gymnasium.make(self.config.env_name)
            obs, _ = env.reset(seed=seed)
            ep = EpisodeState(seed, env)
            ep.message_history = [{"role": "system", "content": self.system_prompt}]
            ep.message_history.append(
                {"role": "environment", "content": self._format_observation(obs)}
            )
            self.episodes[seed] = ep
        return self.episodes[seed]

    def _format_observation(self, obs: Tuple[int, int, int]) -> str:
        """Render a ``(player_sum, dealer_card, usable_ace)`` observation as text."""
        player_sum, dealer_card, usable_ace = obs
        return f"Your hand sum is {player_sum}. Dealer showing: {dealer_card}. You have a usable ace: {usable_ace}."

    def _score_response(
        self,
        env_reward: float,
        response_text: str,
        parsed_action: int,
        episode_seed: int,
    ) -> float:
        """Score one agent response from the env reward plus formatting terms.

        Components:
          * the raw environment reward;
          * +0.2 for a parseable action, -0.2 for an unparseable one (-1);
          * +0.2 for a non-blank ``<think>...</think>`` block, 0 for an empty or
            whitespace-only block, -0.2 when no block is found;
          * a small bonus, up to ``tiebreak_token_factor``, that favours
            shorter responses.

        ``episode_seed`` is accepted for interface compatibility and is unused.
        """
        score = env_reward * 1.0

        score += -0.2 if parsed_action == -1 else 0.2

        # re.DOTALL is required: the prompt pre-fills "<think>\n", so the think
        # block always spans lines, and "." would otherwise never match "\n".
        match = re.search(r"<think>(.*?)</think>", response_text or "", re.DOTALL)
        if match is None:
            score -= 0.2
        elif match.group(1).strip():
            score += 0.2

        num_tokens = len(self.tokenizer.encode(response_text)) if response_text else 0

        # Tie-break and small length penalty: shorter responses earn slightly more.
        if self.config.max_token_length > 0:
            token_ratio = min(1.0, num_tokens / self.config.max_token_length)
            score += self.config.tiebreak_token_factor * (1.0 - token_ratio)
        return score

    def _parse_tool_call(self, response: str) -> int:
        """Map a response's ``take_action`` tool call to an env action.

        Returns:
            1 for "hit", 0 for "stick", -1 when the response is empty, the tool
            call cannot be parsed, or the action is not recognised.
        """
        if not response:
            logger.warning(
                "Attempted to parse an empty response string. Returning invalid action (-1)."
            )
            return -1

        parsed_name, parsed_args, is_error = parse_tool_call(
            response, self.tools, ["tool_call"]
        )
        if is_error:
            error_detail = (
                parsed_name
                if isinstance(parsed_name, str) and parsed_name
                else "Parser indicated error, but no specific message was returned in the typical error slot."
            )
            logger.warning(
                f"Failed to parse tool call. Full response: '{response}'. Error detail: {error_detail}"
            )
            return -1

        # Guard against non-dict arguments and non-string action values
        # (e.g. null), which would otherwise raise here.
        raw_action = parsed_args.get("action", "") if isinstance(parsed_args, dict) else ""
        action = str(raw_action).lower()
        if action == "hit":
            return 1
        if action == "stick":
            return 0
        logger.warning(
            f"Successfully parsed tool call, but action is invalid. Action: '{action}'. "
            f"Full response: '{response}'. Parsed args: {parsed_args}"
        )
        return -1

    async def _sample_response(self, messages: List[Dict], n: int = 1) -> List[str]:
        """Request ``n`` completions for ``messages``; returns [] on API errors."""
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        try:
            completions = await self.server.completion(
                prompt=prompt,
                n=n,
                max_tokens=self.config.max_token_length,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
            )
            return [choice.text for choice in completions.choices]
        except Exception as e:
            logger.error(f"API error during completion: {e}")
            return []

    async def _estimate_value(
        self,
        episode_seed_for_sim: int,
        env_actions_to_replay: List[int],
    ) -> float:
        """Compute the value V*(s) of the state reached by replaying actions.

        A fresh environment is reset with ``episode_seed_for_sim`` and stepped
        through ``env_actions_to_replay`` to reach state ``s``. Both actions are
        then explored exhaustively on deep copies of the environment: "stick"
        ends the episode, and "hit" recurses into the successor state.

        NOTE(review): each deep copy carries the environment's RNG state, so a
        branch sees the card the seeded game would actually deal. The result is
        the best return achievable for this particular seed, not an expectation
        over unseen cards.

        Args:
            episode_seed_for_sim: Seed of the original episode, used to recreate
                it deterministically.
            env_actions_to_replay: Environment actions (0 = stick, 1 = hit) taken
                to reach the current state s.

        Returns:
            V*(s). Returns 0.0 if the replay already ends the episode, or if the
            computation fails (the failure is logged).
        """
        # Memo table local to this call, keyed by observation.
        v_star_cache: Dict[Tuple[int, int, int], float] = {}

        def _get_v_star_recursive(
            obs_tuple: Tuple[int, int, int], current_env: gymnasium.Env
        ) -> float:
            player_sum, _dealer_card, _usable_ace = obs_tuple

            # A bust hand is a loss.
            if player_sum > 21:
                return -1.0
            if obs_tuple in v_star_cache:
                return v_star_cache[obs_tuple]

            # Stick (action 0) is terminal, so its reward is the final outcome.
            env_for_stick = copy.deepcopy(current_env)
            _, q_star_stick, _, _, _ = env_for_stick.step(0)

            # Hit (action 1): either the game ends, or we recurse into s'.
            env_for_hit = copy.deepcopy(current_env)
            obs_hit, reward_hit, term_hit, trunc_hit, _ = env_for_hit.step(1)
            if term_hit or trunc_hit:
                q_star_hit = reward_hit
            else:
                q_star_hit = reward_hit + _get_v_star_recursive(obs_hit, env_for_hit)

            v_star = max(q_star_stick, q_star_hit)
            v_star_cache[obs_tuple] = v_star
            return v_star

        sim_env = None
        try:
            sim_env = gymnasium.make(self.config.env_name)
            current_obs, _ = sim_env.reset(seed=episode_seed_for_sim)

            # Replay the history to reach the current state s_t.
            for action_idx, prev_action in enumerate(env_actions_to_replay):
                current_obs, _, term_replay, trunc_replay, _ = sim_env.step(prev_action)
                if term_replay or trunc_replay:
                    logger.debug(
                        f"[_estimate_value] State became terminal during action replay "
                        f"(action {action_idx+1}/{len(env_actions_to_replay)} of prev_actions). Value is 0."
                    )
                    return 0.0

            return _get_v_star_recursive(current_obs, sim_env)
        except Exception as e:
            logger.error(
                f"[_estimate_value] Error during exact value"
                f" calculation for seed {episode_seed_for_sim}, "
                f"actions {env_actions_to_replay}: {e}",
                exc_info=True,
            )
            return 0.0
        finally:
            if sim_env is not None:
                sim_env.close()

    def _simulate_alternative(
        self, ep: EpisodeState, env_action: int, log_prefix: str
    ) -> Tuple[float, bool, Optional[Tuple[int, int, int]]]:
        """Apply ``env_action`` to a fresh, seeded replay of ``ep``.

        Returns:
            ``(raw_reward, is_terminal, next_obs)``. If replaying ``ep.actions``
            already ends the game (a state mismatch with the live episode), the
            result is ``(0.0, True, None)``. Exceptions from the environment
            propagate to the caller; the simulated env is always closed.
        """
        sim_env = gymnasium.make(self.config.env_name)
        try:
            sim_env.reset(seed=ep.seed)
            for prev_action in ep.actions:
                _, _, term_replay, trunc_replay, _ = sim_env.step(prev_action)
                if term_replay or trunc_replay:
                    logger.error(
                        f"{log_prefix} Sim env terminated during replay. State mismatch?"
                    )
                    return 0.0, True, None
            next_obs, reward, term, trunc, _ = sim_env.step(env_action)
            return reward, bool(term or trunc), next_obs
        finally:
            sim_env.close()

    def _exceeds_token_limit(self, alt_tokens: List[List[int]]) -> bool:
        """Return True if any alternative is longer than ``max_trajectory_tokens``.

        A missing or non-positive limit disables the check.
        """
        limit = self.config.max_trajectory_tokens
        if limit is None or limit <= 0:
            return False
        return any(len(tokens) > limit for tokens in alt_tokens)

    def _close_episode(self, seed: int, log_prefix: Optional[str] = None) -> None:
        """Close and forget the live episode for ``seed``; a no-op if absent."""
        ep = self.episodes.pop(seed, None)
        if ep is None:
            return
        prefix = log_prefix or f"[Collect Trajectory Seed: {seed}]"
        try:
            ep.env.close()
        except Exception as e_close:
            logger.warning(f"{prefix} Exception closing final env: {e_close}")

    async def _collect_trajectory(self, seed: int) -> List[BlackjackScoredDataGroup]:
        """Play ONE episode, scoring G sampled alternatives at every turn.

        Steps for each turn:
          1. Estimate V(s_t).
          2. Sample ``group_size`` responses.
          3. Simulate each response on a seeded replay of the episode.
          4. Score each alternative with
             ``A_i = combined_reward_i + V(s'_i) - V(s_t)``, using values from
             :meth:`_estimate_value` (exact seeded search, not Monte Carlo).
          5. Record the group.
          6. Step the real episode with the highest-advantage alternative; ties
             go to the alternative with fewer tokens.

        Returns:
            One BlackjackScoredDataGroup per turn. Turns where any alternative
            exceeds ``max_trajectory_tokens`` are left out. The result is empty
            if the episode could not be created. The episode is always closed
            and removed from ``self.episodes``.
        """
        G = self.config.group_size
        max_turns = self.config.max_turns or 5

        trajectory_data_for_trainer: List[BlackjackScoredDataGroup] = []
        episode_summary_metrics: Optional[Dict[str, Any]] = None

        logger.info(
            f"[Collect Trajectory Seed: {seed}] Starting trajectory. Group size G={G}."
        )

        try:
            ep = self._get_or_create_episode(seed)
        except Exception as e:
            logger.error(
                f"[Collect Trajectory Seed: {seed}] Failed to create/get episode: {e}",
                exc_info=True,
            )
            return []

        try:
            for turn in range(max_turns):
                turn_prefix = f"[Collect Trajectory Seed: {seed} Turn: {turn+1}]"
                current_state_messages = ep.message_history.copy()
                logger.debug(
                    f"[Collect Trajectory Seed: {seed} Turn: {turn+1}/{max_turns}] "
                    f"Current state history length: {len(current_state_messages)}"
                )

                try:
                    value_t = await self._estimate_value(
                        episode_seed_for_sim=ep.seed, env_actions_to_replay=ep.actions
                    )
                    logger.debug(f"{turn_prefix} Estimated V(s_t) = {value_t:.4f}")
                except Exception as e_vt:
                    logger.error(
                        f"{turn_prefix} Error estimating V(s_t): {e_vt}",
                        exc_info=True,
                    )
                    break

                agent_prompt_content = "<think>\n" if self.config.thinking_active else ""
                messages_for_llm = current_state_messages + [
                    {"role": "agent", "content": agent_prompt_content}
                ]

                try:
                    responses = await self._sample_response(messages_for_llm, n=G)
                except Exception as e_sample:
                    logger.error(
                        f"{turn_prefix} Error sampling responses: {e_sample}",
                        exc_info=True,
                    )
                    break
                if len(responses) != G:
                    logger.error(
                        f"{turn_prefix} Expected {G} responses, got {len(responses)}. "
                        f"Aborting trajectory."
                    )
                    break

                alt_full_responses: List[str] = []
                alt_parsed_actions: List[int] = []
                alt_env_actions: List[int] = []
                alt_raw_rewards: List[float] = []
                alt_combined_rewards: List[float] = []
                alt_next_state_msgs: List[List[Dict]] = []
                alt_is_terminal: List[bool] = []
                alt_tokens: List[List[int]] = []
                alt_masks: List[List[int]] = []

                for i, llm_output_only in enumerate(responses):
                    alt_prefix = (
                        f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}]"
                    )
                    full_agent_response = agent_prompt_content + llm_output_only
                    parsed_action = self._parse_tool_call(full_agent_response)
                    # Unparseable responses are played as "stick" (0).
                    env_action = parsed_action if parsed_action != -1 else 0
                    state_plus_response = current_state_messages + [
                        {"role": "agent", "content": full_agent_response}
                    ]

                    # Compute everything first and append once afterwards, so a
                    # failure part-way through cannot leave the alt_* lists with
                    # different lengths.
                    try:
                        raw_reward, is_terminal, next_obs = self._simulate_alternative(
                            ep, env_action, alt_prefix
                        )
                        combined_reward = self._score_response(
                            raw_reward, full_agent_response, parsed_action, ep.seed
                        )
                        next_state_msgs = state_plus_response
                        if next_obs is not None:
                            next_state_msgs = state_plus_response + [
                                {
                                    "role": "environment",
                                    "content": self._format_observation(next_obs),
                                }
                            ]
                        tokenized = tokenize_for_trainer(self.tokenizer, next_state_msgs)
                        tokens, masks = tokenized["tokens"], tokenized["masks"]
                    except Exception as e_sim:
                        logger.error(
                            f"{alt_prefix} Error simulating action {env_action}: {e_sim}",
                            exc_info=True,
                        )
                        raw_reward, is_terminal, combined_reward = 0.0, True, -1.0
                        next_state_msgs = state_plus_response
                        tokens, masks = [], []

                    alt_full_responses.append(full_agent_response)
                    alt_parsed_actions.append(parsed_action)
                    alt_env_actions.append(env_action)
                    alt_raw_rewards.append(raw_reward)
                    alt_is_terminal.append(is_terminal)
                    alt_combined_rewards.append(combined_reward)
                    alt_next_state_msgs.append(next_state_msgs)
                    alt_tokens.append(tokens)
                    alt_masks.append(masks)

                alt_value_next: List[float] = []
                for i, (env_action, is_terminal) in enumerate(
                    zip(alt_env_actions, alt_is_terminal)
                ):
                    if is_terminal:
                        # Terminal successor states have no future value.
                        alt_value_next.append(0.0)
                        continue
                    try:
                        value_next_i = await self._estimate_value(
                            episode_seed_for_sim=ep.seed,
                            env_actions_to_replay=ep.actions + [env_action],
                        )
                    except Exception as e_vn:
                        logger.error(
                            f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
                            f"Error estimating V(s'): {e_vn}",
                            exc_info=True,
                        )
                        value_next_i = 0.0
                    alt_value_next.append(value_next_i)

                alt_advantages: List[float] = []
                for i in range(G):
                    # Using advantages instead of raw scores implicitly does some
                    # credit assignment. Could maybe do bonus on a win too and/or
                    # apply a discount factor to alts in winning trajectories.
                    advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
                    alt_advantages.append(advantage_i)
                    logger.debug(
                        f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
                        f"CombinedR={alt_combined_rewards[i]:.2f}, V_t={value_t:.2f}, "
                        f"V_t+1={alt_value_next[i]:.2f} => Advantage={advantage_i:.2f}"
                    )

                # The trainer cannot consume sequences longer than its context,
                # so over-long groups are dropped. Truncating them would
                # desynchronise tokens/masks from the recorded messages.
                if self._exceeds_token_limit(alt_tokens):
                    logger.warning(
                        f"{turn_prefix} At least one alternative exceeds "
                        f"max_trajectory_tokens={self.config.max_trajectory_tokens}; "
                        f"dropping this step from the training data."
                    )
                else:
                    trajectory_data_for_trainer.append(
                        BlackjackScoredDataGroup(
                            seed=ep.seed,
                            tokens=alt_tokens,
                            masks=alt_masks,
                            scores=alt_advantages,
                            messages=alt_next_state_msgs,
                            parsed_actions=alt_parsed_actions,
                        )
                    )

                # Token lengths break ties in favour of shorter alternatives.
                alt_token_lengths = [len(tkns) for tkns in alt_tokens]
                best_advantage_idx = select_best_index(
                    primary_scores=alt_advantages,
                    secondary_scores=alt_token_lengths,
                    primary_higher_is_better=True,
                    secondary_lower_is_better=True,
                )

                chosen_advantage_for_log = alt_advantages[best_advantage_idx]
                chosen_token_length_for_log = alt_token_lengths[best_advantage_idx]
                logger.debug(
                    f"{turn_prefix} "
                    f"Selected Alt {best_advantage_idx} "
                    f"(Adv: {chosen_advantage_for_log:.2f}, "
                    f"Tokens: {chosen_token_length_for_log}) "
                    f"from {G} alternatives using select_best_index."
                )

                chosen_env_action = alt_env_actions[best_advantage_idx]
                chosen_full_response = alt_full_responses[best_advantage_idx]
                chosen_raw_env_reward = alt_raw_rewards[best_advantage_idx]
                chosen_is_terminal = alt_is_terminal[best_advantage_idx]
                chosen_parsed_action = alt_parsed_actions[best_advantage_idx]

                logger.info(
                    f"{turn_prefix} Chosen action to step env: "
                    f"{chosen_env_action} (from Alt {best_advantage_idx} with "
                    f"Adv {chosen_advantage_for_log:.2f})"
                )

                ep.num_total_actions += 1
                if chosen_parsed_action != -1:
                    ep.num_correct_actions += 1

                # Keep past thinking short so the history stays within budget.
                ep.message_history = current_state_messages
                response_for_history = truncate_thinking(
                    chosen_full_response, self.tokenizer, self.config.max_think_chars_history
                )
                ep.message_history.append(
                    {"role": "agent", "content": response_for_history}
                )

                try:
                    main_obs, main_reward, main_term, main_trunc, _ = ep.env.step(
                        chosen_env_action
                    )
                    if abs(main_reward - chosen_raw_env_reward) > 1e-6:
                        logger.warning(
                            f"{turn_prefix} "
                            f"Mismatch between simulated reward ({chosen_raw_env_reward}) and "
                            f"main env step reward ({main_reward}) for chosen action {chosen_env_action}."
                        )
                    if (main_term or main_trunc) != chosen_is_terminal:
                        logger.warning(
                            f"{turn_prefix} "
                            f"Mismatch between simulated terminal state ({chosen_is_terminal}) and "
                            f"main env step terminal state ({(main_term or main_trunc)}) "
                            f"for chosen action {chosen_env_action}."
                        )

                    term, trunc = main_term, main_trunc
                    ep.actions.append(chosen_env_action)
                    ep.step_rewards.append(main_reward)
                    ep.num_steps += 1

                    if main_obs is not None:
                        ep.message_history.append(
                            {
                                "role": "environment",
                                "content": self._format_observation(main_obs),
                            }
                        )
                except Exception as e_main_step:
                    logger.error(
                        f"{turn_prefix} "
                        f"Error stepping MAIN environment with chosen action {chosen_env_action}: {e_main_step}",
                        exc_info=True,
                    )
                    term, trunc = True, True

                if term or trunc:
                    ep.total_reward = sum(ep.step_rewards)
                    logger.info(
                        f"[Collect Trajectory Seed: {seed}] Trajectory ended. "
                        f"Term={term}, Trunc={trunc}. Total raw env reward: {ep.total_reward}"
                    )
                    break

            final_raw_reward = sum(ep.step_rewards) if ep.step_rewards else 0.0
            logger.info(
                f"[Collect Trajectory Seed: {seed}] Finished collecting trajectory. "
                f"Steps collected: {len(trajectory_data_for_trainer)}, "
                f"Final raw reward: {final_raw_reward:.2f}"
            )

            if final_raw_reward > 0:
                game_outcome = 1
            elif final_raw_reward < 0:
                game_outcome = -1
            else:
                game_outcome = 0
            episode_summary_metrics = {
                "seed": ep.seed,
                "total_reward": final_raw_reward,
                "num_steps": ep.num_steps,
                "num_correct_actions": ep.num_correct_actions,
                "num_total_actions": ep.num_total_actions,
                "game_outcome": game_outcome,
            }
            self.completed_episode_metrics_buffer.append(episode_summary_metrics)
            logger.debug(
                f"[Collect Trajectory Seed: {seed}] Added episode summary to buffer: {episode_summary_metrics}"
            )
        finally:
            # Always release the live env, even if a turn raised unexpectedly.
            self._close_episode(seed)

        return trajectory_data_for_trainer

    async def score(
        self, rollout_group_data: List[BlackjackScoredDataGroup]
    ) -> List[Optional[BlackjackScoredDataGroup]]:
        """Pass rollout data through unchanged.

        The 'scores' field of each BlackjackScoredDataGroup already holds the
        A*(s,a) advantages computed during collection. Additional scoring terms,
        e.g. bonuses for the winning action trajectory, could be added here.

        Args:
            rollout_group_data: Collected BlackjackScoredDataGroup objects.

        Returns:
            The same list of BlackjackScoredDataGroup objects.
        """
        logger.info(f"[Score] Processing {len(rollout_group_data)} steps.")
        return rollout_group_data

    async def collect_trajectories(
        self, item: Tuple[int, int]
    ) -> Tuple[List[BlackjackScoredDataGroup], List[Tuple[int, int]]]:
        """Collect trajectories for training.

        Args:
            item: Tuple of (seed, group index); only the seed is used.

        Returns:
            Tuple of (collected BlackjackScoredDataGroup list, backlog items),
            where the backlog is always empty.
        """
        seed, _ = item
        traj = await self._collect_trajectory(seed)
        if not traj:
            logger.warning(f"[Collect Trajectories] Empty trajectory for seed {seed}.")
        return traj, []

    async def setup(self):
        pass

    async def get_next_item(self) -> Tuple[int, int]:
        """Return a random training seed (0..1,000,000) with group index 0."""
        return (random.randint(0, 1000000), 0)

    async def rollout_and_score_eval(self, seed: int) -> Dict[str, float]:
        """Run one evaluation episode with a single sampled response per turn.

        Unparseable actions are counted as invalid and played as "stick". The
        episode is always closed and removed from ``self.episodes``, even if a
        turn raises.
        """
        ep = self._get_or_create_episode(seed)
        max_turns = self.config.max_turns or 5
        metrics = {
            "seed": seed,
            "total_reward": 0.0,
            "num_turns": 0,
            "num_correct_actions": 0,
            "num_invalid_actions": 0,
            "game_outcome": 0,
        }

        try:
            for turn in range(max_turns):
                agent_prompt_content = "<think>\n" if self.config.thinking_active else ""
                messages = ep.message_history.copy()
                messages.append({"role": "agent", "content": agent_prompt_content})

                responses = await self._sample_response(messages, n=1)
                if not responses:
                    logger.error(
                        f"[Eval Seed: {seed}, Turn: {turn+1}] No response. Aborting."
                    )
                    break

                full_agent_response = agent_prompt_content + responses[0]

                action = self._parse_tool_call(full_agent_response)
                if action == -1:
                    metrics["num_invalid_actions"] += 1
                    action = 0
                else:
                    metrics["num_correct_actions"] += 1

                try:
                    obs, reward, term, trunc, _ = ep.env.step(action)
                except Exception as e:
                    logger.error(f"[Eval Seed: {seed}, Turn: {turn+1}] Env error: {e}")
                    obs, reward, term, trunc = None, -1.0, True, False

                metrics["total_reward"] += reward
                metrics["num_turns"] = turn + 1

                response_for_history = truncate_thinking(
                    full_agent_response, self.tokenizer, self.config.max_think_chars_history
                )
                ep.message_history.append(
                    {"role": "agent", "content": response_for_history}
                )
                if obs:
                    ep.message_history.append(
                        {"role": "environment", "content": self._format_observation(obs)}
                    )

                if term or trunc:
                    metrics["game_outcome"] = int(reward)
                    logger.info(f"[Eval Seed: {seed}] Episode ended. Outcome: {reward}")
                    break
        finally:
            self._close_episode(seed, log_prefix=f"[Eval Seed: {seed}]")

        return metrics

    async def evaluate(self, *args, **kwargs):
        """Play ``eval_episodes`` episodes and store aggregate metrics.

        Evaluation is skipped when wandb is disabled. Results are stored in
        ``self.eval_metrics``.
        """
        if not self.config.use_wandb:
            logger.info("Skipping evaluation as wandb is not enabled.")
            return
        num_eval_episodes = self.config.eval_episodes
        logger.info(f"Starting evaluation for {num_eval_episodes} episodes.")
        # Distinct seeds: episodes are keyed by seed, so a duplicate would make
        # two concurrent rollouts share (and corrupt) one EpisodeState.
        eval_seeds = random.sample(range(1000001, 2000001), num_eval_episodes)
        eval_tasks = [self.rollout_and_score_eval(s) for s in eval_seeds]
        all_metrics = await tqdm_asyncio.gather(*eval_tasks)
        valid_metrics = [m for m in all_metrics if m]
        if not valid_metrics:
            logger.warning("No valid evaluation metrics.")
            return

        num_completed = len(valid_metrics)
        avg_total_reward = sum(m["total_reward"] for m in valid_metrics) / num_completed
        avg_num_turns = sum(m["num_turns"] for m in valid_metrics) / num_completed
        total_correct = sum(m["num_correct_actions"] for m in valid_metrics)
        total_invalid = sum(m["num_invalid_actions"] for m in valid_metrics)
        total_actions = total_correct + total_invalid
        action_accuracy = total_correct / total_actions if total_actions > 0 else 0

        def _outcome_rate(outcome: int) -> float:
            return sum(1 for m in valid_metrics if m["game_outcome"] == outcome) / num_completed

        self.eval_metrics = [
            ("eval/avg_total_reward", avg_total_reward),
            ("eval/avg_num_turns", avg_num_turns),
            ("eval/action_accuracy", action_accuracy),
            ("eval/win_rate", _outcome_rate(1)),
            ("eval/loss_rate", _outcome_rate(-1)),
            ("eval/draw_rate", _outcome_rate(0)),
            ("eval/num_completed_episodes", num_completed),
        ]
        logger.info(f"Evaluation completed. Metrics: {self.eval_metrics}")

    async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
        """Add buffered training-episode summaries to ``wandb_metrics`` and log them.

        The buffer is emptied after its metrics are added.
        """
        if wandb_metrics is None:
            wandb_metrics = {}
        buffer = self.completed_episode_metrics_buffer
        if buffer:
            num_episodes = len(buffer)
            prefix = f"{self.wandb_prepend or 'blackjack'}_train"
            wandb_metrics[f"{prefix}/avg_episode_reward"] = (
                sum(m["total_reward"] for m in buffer) / num_episodes
            )
            wandb_metrics[f"{prefix}/avg_episode_steps"] = (
                sum(m["num_steps"] for m in buffer) / num_episodes
            )
            wandb_metrics[f"{prefix}/episode_win_rate"] = (
                sum(1 for m in buffer if m["game_outcome"] == 1) / num_episodes
            )
            wandb_metrics[f"{prefix}/num_episodes"] = num_episodes
            self.completed_episode_metrics_buffer = []
        await super().wandb_log(wandb_metrics)

    @classmethod
    def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
        """Return the default environment config and inference server configs."""
        env_config = BlackjackEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=16,
            use_wandb=True,
            max_num_workers=128,
            rollout_server_url="http://localhost:8000",
            total_steps=2000,
            batch_size=1024,
            steps_per_eval=20,
            max_token_length=1024 * 16,
            inference_weight=1.0,
            # Was "fundamental_metric_prediction", a copy-paste from another env.
            wandb_name="blackjack",
            data_path_to_save_groups=None,
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            eval_limit_ratio=0.1,
            env_name="Blackjack-v1",
            temperature=0.7,
            top_p=0.9,
            max_turns=5,
            thinking_active=True,
            eval_episodes=100,
            max_think_chars_history=3000,
            max_trajectory_tokens=24576,
            debug_mode=False,
            tiebreak_token_factor=0.01,
        )
        server_configs = [
            OpenaiConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_requests_for_eval=256,
            )
        ]
        return env_config, server_configs

    @classmethod
    def cli(cls):
        super().cli()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand control to the environment's CLI classmethod.
    BlackjackEnv.cli()
|