Removed old code, added comments

This commit is contained in:
Shannon Sands 2025-05-10 08:39:52 +10:00
parent ba604d44f9
commit 0248cc1227
4 changed files with 3268 additions and 99 deletions

View file

@@ -1,3 +1,13 @@
#!/usr/bin/env python3
"""
BlackjackEnv: Trainer environment for Gymnasium Blackjack
This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
using function-call style actions. Extends BaseEnv.
Uses Monte Carlo sampling to estimate the value of the current state, similar to VinePPO
"""
import json
import logging
import random
@@ -26,7 +36,7 @@ class BlackjackEnvConfig(BaseEnvConfig):
max_trajectory_tokens: int = 24576
debug_mode: bool = False
group_size: int = 16
mc_samples: int = 3 # lowish K for MC value estimation
mc_samples: int = 3
class BlackjackScoredDataGroup(ScoredDataGroup):
@@ -130,23 +140,17 @@ class BlackjackEnv(BaseEnv):
current_env_reward = env_reward
if parsed_action == -1:
current_env_reward -= 0.5 # Penalty for invalid action format
current_env_reward -= 0.5
logger.debug(
f"[_score_response Seed: {episode_seed}] Penalty applied to env_reward for "
f"invalid action format (-0.5). Current env_reward: {current_env_reward:.4f}"
)
# env_w = self.config.environment_reward_weight # Removed, env reward is 100%
# combined_score = ( # Simplified
# env_w * current_env_reward
# ) + format_or_tool_call_reward_component
final_score = current_env_reward
logger.debug(
f"[_score_response Seed: {episode_seed}] Score Calculation: "
f"EnvReward(raw): {env_reward:.4f}, EnvReward(adj for invalid): {current_env_reward:.4f} "
# f"OutputFromRewardFunctions (already weighted): {format_or_tool_call_reward_component:.4f}, " # Removed
f"==> Final Score (from env): {final_score:.4f}"
)
return final_score
@@ -235,7 +239,6 @@ class BlackjackEnv(BaseEnv):
f"State s was already terminal. Value is 0."
)
all_rollout_returns.append(0.0)
# This means V(s_terminal) = 0, which is correct.
break
else:
rollout_reward_for_this_sample = 0.0
@@ -484,10 +487,9 @@ class BlackjackEnv(BaseEnv):
)
alt_value_next.append(0.0)
else:
alt_value_next.append(0.0) # V(terminal) = 0
alt_value_next.append(0.0)
for i in range(G):
# Advantage = R_combined + gamma * V_raw(s') - V_raw(s) (gamma=1)
advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
alt_advantages.append(advantage_i)
logger.debug(
@@ -841,9 +843,8 @@ class BlackjackEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
env_config = BlackjackEnvConfig(
# Fields from fundamental_prediction_environment.py's BaseEnvConfig init:
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=16, # Matches BlackjackEnvConfig default as well
group_size=16,
use_wandb=True,
max_num_workers=128,
rollout_server_url="http://localhost:8000",
@@ -852,30 +853,28 @@ class BlackjackEnv(BaseEnv):
steps_per_eval=20,
max_token_length=1024 * 16,
inference_weight=1.0,
wandb_name="fundamental_metric_prediction", # Strict: Use value from fundamental_prediction
wandb_name="fundamental_metric_prediction",
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
# BlackjackEnvConfig specific fields (those NOT in BaseEnvConfig from fundamental_prediction)
# using their defined defaults from BlackjackEnvConfig:
env_name="Blackjack-v1", # Default from BlackjackEnvConfig
temperature=0.7, # Default from BlackjackEnvConfig
top_p=0.9, # Default from BlackjackEnvConfig
max_turns=5, # Default from BlackjackEnvConfig
thinking_active=True, # Default from BlackjackEnvConfig
eval_episodes=100, # Default from BlackjackEnvConfig
max_think_chars_history=3000, # Default from BlackjackEnvConfig
max_trajectory_tokens=24576, # Default from BlackjackEnvConfig
debug_mode=False, # Default from BlackjackEnvConfig
mc_samples=3, # Default from BlackjackEnvConfig
env_name="Blackjack-v1",
temperature=0.7,
top_p=0.9,
max_turns=5,
thinking_active=True,
eval_episodes=100,
max_think_chars_history=3000,
max_trajectory_tokens=24576,
debug_mode=False,
mc_samples=3,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256, # From fundamental_prediction_environment.py
num_requests_for_eval=256,
)
]
return env_config, server_configs
@@ -973,7 +972,7 @@ class BlackjackEnv(BaseEnv):
and original_step_data.get("tokens")
and original_step_data.get("masks")
and original_step_data.get("seed") is not None
and original_step_data.get("parsed_actions") is not None # Specific to MC version
and original_step_data.get("parsed_actions") is not None
):
logger.warning(
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env "
@@ -981,7 +980,6 @@ class BlackjackEnv(BaseEnv):
)
continue
# Initial token calculation from original data
max_initial_tokens = 0
if original_step_data["tokens"]:
max_initial_tokens = max(
@@ -1017,7 +1015,7 @@ class BlackjackEnv(BaseEnv):
for alt_idx in range(num_alternatives):
alt_msg_list = working_messages[alt_idx]
num_preserved_at_end = 0
if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in ["agent", "assistant"] + UNMASKED_ROLES:
if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in ["agent", "assistant"]:
num_preserved_at_end = 1
if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment":
num_preserved_at_end = 2
@@ -1031,7 +1029,7 @@ class BlackjackEnv(BaseEnv):
available_to_pop >= 2 and
len(alt_msg_list) > 2 and
alt_msg_list[1]["role"] == "environment" and
alt_msg_list[2]["role"] in ["agent", "assistant"] + UNMASKED_ROLES
alt_msg_list[2]["role"] in ["agent", "assistant"]
)
if can_pop_pair:
target_pop_counts_per_alt.append(2)
@@ -1088,7 +1086,7 @@ class BlackjackEnv(BaseEnv):
"tokens": working_tokens,
"masks": working_masks,
"scores": original_step_data.get("scores"),
"parsed_actions": original_step_data.get("parsed_actions") # MC version specific
"parsed_actions": original_step_data.get("parsed_actions")
}
filtered_trajectory.append(updated_step_data)
logger.info(

View file

@@ -1,4 +1,3 @@
import argparse
import asyncio
import logging
import os
@@ -14,88 +13,63 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# def parse_arguments(): # Removed
# parser = argparse.ArgumentParser(description="Blackjack environment local server")
# parser.add_argument(
# "--config",
# type=str,
# default="blackjack_local",
# help="Configuration file name (without .yaml extension, relative to "
# "envs/gymnasium/configs), or full path to a YAML file.",
# )
# return parser.parse_args()
async def main():
logger.info("Starting Blackjack environment local debug runner")
# args = parse_arguments() # Removed
# Removed logic for config_name_or_path and BlackjackEnv.config_init
# Create hardcoded configurations for local debugging
env_config = BlackjackEnvConfig(
# BaseEnvConfig fields, tailored for debug
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=1, # Debug single generation path
group_size=1,
use_wandb=False,
wandb_name="blackjack_local_debug", # Explicitly set for debug
wandb_name="blackjack_local_debug",
max_num_workers=1,
rollout_server_url="http://localhost:8000", # Standard default
rollout_server_url="http://localhost:8000",
total_steps=1,
batch_size=1, # Consistent with 1 step, 1 worker, group_size 1
steps_per_eval=0, # No eval steps needed
max_token_length=1024 * 4, # Reduced for faster local debugging if necessary
batch_size=1,
steps_per_eval=0,
max_token_length=1024 * 4,
inference_weight=1.0,
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.NONE, # No evaluation in this script
eval_handling=EvalHandlingEnum.NONE,
eval_limit_ratio=0.0,
# BlackjackEnvConfig specific fields (from blackjack_env.py's definition or defaults)
env_name="Blackjack-v1",
temperature=0.2, # Lower temperature for more deterministic debug output
top_p=0.9, # Standard default
max_turns=5, # Standard default
temperature=0.2,
top_p=0.9,
max_turns=5,
thinking_active=True,
eval_episodes=0, # No evaluation episodes
eval_episodes=0,
max_think_chars_history=3000,
max_trajectory_tokens=24576,
debug_mode=True, # Enable debug logging from the environment
mc_samples=1, # With group_size=1, this means 1 MC rollout for V(s)
debug_mode=True,
mc_samples=1,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4.1-mini", # Ensure this is locally available if not mocked
base_url="https://api.openai.com/v1", # Explicitly set OpenAI base URL
api_key=os.getenv("OPENAI_API_KEY"), # Use env var or default
num_requests_for_eval=0, # No eval requests
model_name="gpt-4.1-nano",
base_url="https://api.openai.com/v1",
api_key=os.getenv("OPENAI_API_KEY"),
num_requests_for_eval=0,
)
]
logger.info("Using hardcoded debug configuration.")
logger.debug(f"Env Config: {env_config}")
logger.debug(f"Server Configs: {server_configs}")
# Create and set up the environment using the loaded configs
try:
env = BlackjackEnv(
config=env_config,
server_configs=server_configs,
slurm=False, # Explicitly false for local testing
slurm=False,
)
except Exception as e:
logger.exception(f"Failed to initialize BlackjackEnv: {e}")
return
# Run a single trajectory directly
logger.info("Running a single trajectory directly")
try:
await env.setup() # Setup the server connection etc.
await env.setup()
seed = random.randint(0, 1000000)
logger.info(f"Using seed: {seed}")
# Make sure the episode exists before collecting
# This also initializes the message history correctly
_ = env._get_or_create_episode(seed)
result_trajectory = await env.collect_trajectory(seed)
@@ -105,13 +79,9 @@ async def main():
episode_summary = None
if env.completed_episode_metrics_buffer:
# Assume the last entry in the buffer corresponds to the trajectory just run
episode_summary = env.completed_episode_metrics_buffer[-1]
# Optionally, clear the buffer if this script is only for single runs
# env.completed_episode_metrics_buffer.clear()
if episode_summary and episode_summary.get("seed") == seed:
# Print a final summary
logger.info("\n========== Episode Summary ==========")
logger.info(f"Seed: {episode_summary['seed']}")
logger.info(f"Total steps taken: {episode_summary['num_steps']}")
@@ -129,7 +99,6 @@ async def main():
f"Game Outcome: {outcome_str} (Reward: {episode_summary['total_reward']:.0f})"
)
# Calculate and log action accuracy based on EpisodeState fields
if episode_summary["num_total_actions"] > 0:
accuracy = episode_summary["num_correct_actions"] / max(
1, episode_summary["num_total_actions"]

View file

@@ -645,11 +645,11 @@ class BlackjackEnv(BaseEnv):
episode_summary_metrics = {
"seed": seed,
"total_env_reward": ep.total_env_reward,
"total_reward": ep.total_env_reward,
"num_correct_actions": ep.num_correct_actions,
"num_total_actions": ep.num_total_actions,
"game_outcome": game_outcome,
"num_steps_in_episode": len(ep.actions),
"num_steps": len(ep.actions),
}
self.completed_episode_metrics_buffer.append(episode_summary_metrics)
@@ -1054,8 +1054,8 @@ class BlackjackEnv(BaseEnv):
episode_metrics = {
"seed": seed,
"total_env_reward": 0.0,
"num_turns": 0,
"total_reward": 0.0,
"num_steps": 0,
"num_correct_actions": 0,
"num_invalid_actions": 0,
"actions_chosen": [],
@@ -1063,7 +1063,7 @@ class BlackjackEnv(BaseEnv):
}
for turn in range(max_turns):
episode_metrics["num_turns"] = turn + 1
episode_metrics["num_steps"] = turn + 1
messages_for_prompt = ep.message_history.copy()
if self.config.thinking_active:
@@ -1130,10 +1130,7 @@ class BlackjackEnv(BaseEnv):
reward = -1.0
obs = None
ep.actions.append(env_action)
ep.step_rewards.append(reward)
ep.total_env_reward += reward
episode_metrics["total_reward"] += reward
if term or trunc:
episode_metrics["game_outcome"] = int(reward)
@@ -1207,10 +1204,10 @@ class BlackjackEnv(BaseEnv):
num_completed_episodes = len(valid_metrics)
avg_total_env_reward = (
sum(m["total_env_reward"] for m in valid_metrics) / num_completed_episodes
sum(m["total_reward"] for m in valid_metrics) / num_completed_episodes
)
avg_num_turns = (
sum(m["num_turns"] for m in valid_metrics) / num_completed_episodes
sum(m["num_steps"] for m in valid_metrics) / num_completed_episodes
)
total_correct_actions = sum(m["num_correct_actions"] for m in valid_metrics)
@@ -1244,8 +1241,8 @@ class BlackjackEnv(BaseEnv):
total_parsed_actions_in_eval = len(all_chosen_actions)
self.eval_metrics = [
("eval/avg_total_env_reward", avg_total_env_reward),
("eval/avg_num_turns", avg_num_turns),
("eval/avg_total_reward", avg_total_env_reward),
("eval/avg_num_steps", avg_num_turns),
("eval/action_accuracy", action_accuracy),
("eval/invalid_action_rate", invalid_action_rate),
("eval/win_rate", win_rate),
@@ -1295,7 +1292,7 @@ class BlackjackEnv(BaseEnv):
avg_ep_env_reward = (
sum(
m["total_env_reward"] for m in self.completed_episode_metrics_buffer
m["total_reward"] for m in self.completed_episode_metrics_buffer
)
/ num_episodes_in_buffer
)
@@ -1314,8 +1311,7 @@ class BlackjackEnv(BaseEnv):
avg_ep_num_steps = (
sum(
m["num_steps_in_episode"]
for m in self.completed_episode_metrics_buffer
m["num_steps"] for m in self.completed_episode_metrics_buffer
)
/ num_episodes_in_buffer
)
@@ -1347,7 +1343,7 @@ class BlackjackEnv(BaseEnv):
)
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_env_reward"
f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_reward"
] = avg_ep_env_reward
wandb_metrics[
f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_action_accuracy"

3206
uv.lock generated Normal file

File diff suppressed because it is too large Load diff