Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)

Commit 0248cc1227 (parent ba604d44f9): "Removed old code, added comments"
4 changed files with 3268 additions and 99 deletions
@@ -1,3 +1,13 @@
+#!/usr/bin/env python3
+"""
+BlackjackEnv: Trainer environment for Gymnasium Blackjack
+
+This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
+using function-call style actions. Extends BaseEnv.
+
+Uses Monte Carlo sampling to estimate the value of the current state, similar to VinePPO
+"""
+
 import json
 import logging
 import random
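For reference, the VinePPO-style estimate named in the new docstring amounts to re-seeding the environment, replaying the actions taken so far, and averaging the returns of a few extra rollouts. A minimal sketch under those assumptions (a random rollout policy stands in for the LLM, and mc_value_estimate is a hypothetical name, not the class's actual method):

import gymnasium as gym

def mc_value_estimate(seed: int, action_prefix: list, k: int = 3) -> float:
    # Estimate V(s) as the mean return of k rollouts from state s. The state
    # is reached deterministically by re-seeding and replaying action_prefix;
    # rewards collected during the replay do not count toward V(s).
    returns = []
    for _ in range(k):
        env = gym.make("Blackjack-v1")
        env.reset(seed=seed)
        done = False
        for a in action_prefix:
            _, _, term, trunc, _ = env.step(a)
            if term or trunc:
                done = True
                break
        ret = 0.0
        while not done:
            _, r, term, trunc, _ = env.step(env.action_space.sample())
            ret += r
            done = term or trunc
        returns.append(ret)  # stays 0.0 if s was already terminal, i.e. V(s_terminal) = 0
    return sum(returns) / k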
@@ -26,7 +36,7 @@ class BlackjackEnvConfig(BaseEnvConfig):
     max_trajectory_tokens: int = 24576
     debug_mode: bool = False
     group_size: int = 16
-    mc_samples: int = 3  # lowish K for MC value estimation
+    mc_samples: int = 3


 class BlackjackScoredDataGroup(ScoredDataGroup):
@@ -130,23 +140,17 @@ class BlackjackEnv(BaseEnv):
         current_env_reward = env_reward

         if parsed_action == -1:
-            current_env_reward -= 0.5  # Penalty for invalid action format
+            current_env_reward -= 0.5
             logger.debug(
                 f"[_score_response Seed: {episode_seed}] Penalty applied to env_reward for "
                 f"invalid action format (-0.5). Current env_reward: {current_env_reward:.4f}"
             )

-        # env_w = self.config.environment_reward_weight  # Removed, env reward is 100%
-
-        # combined_score = (  # Simplified
-        #     env_w * current_env_reward
-        # ) + format_or_tool_call_reward_component
         final_score = current_env_reward

         logger.debug(
             f"[_score_response Seed: {episode_seed}] Score Calculation: "
             f"EnvReward(raw): {env_reward:.4f}, EnvReward(adj for invalid): {current_env_reward:.4f} "
-            # f"OutputFromRewardFunctions (already weighted): {format_or_tool_call_reward_component:.4f}, "  # Removed
             f"==> Final Score (from env): {final_score:.4f}"
         )
         return final_score
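Taken together, scoring is now purely environment-driven: the weighted format/tool-call component is gone, and an unparseable action (parsed_action == -1) simply costs a flat 0.5 off the raw environment reward. The surviving logic, reduced to a standalone sketch without the logging:

def score_response(env_reward: float, parsed_action: int) -> float:
    # parsed_action == -1 signals an invalid or unparseable action format.
    score = env_reward
    if parsed_action == -1:
        score -= 0.5  # flat penalty; no separate format-reward term remains
    return score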
@@ -235,7 +239,6 @@ class BlackjackEnv(BaseEnv):
                         f"State s was already terminal. Value is 0."
                     )
                     all_rollout_returns.append(0.0)
-                    # This means V(s_terminal) = 0, which is correct.
                     break
                 else:
                     rollout_reward_for_this_sample = 0.0
@@ -484,10 +487,9 @@ class BlackjackEnv(BaseEnv):
                     )
                     alt_value_next.append(0.0)
                 else:
-                    alt_value_next.append(0.0)  # V(terminal) = 0
+                    alt_value_next.append(0.0)

             for i in range(G):
-                # Advantage = R_combined + gamma * V_raw(s') - V_raw(s) (gamma=1)
                 advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
                 alt_advantages.append(advantage_i)
                 logger.debug(
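The deleted comment was the only statement of the estimator, so for the record: with gamma fixed at 1, each alternative's advantage is A_i = R_i + V(s'_i) - V(s), where V(s'_i) is the MC estimate after alternative i's action (0.0 for terminal states) and V(s) is the shared estimate of the current state. As a self-contained sketch:

def group_advantages(rewards, values_next, value_t):
    # rewards[i]: combined reward R_i for alternative i
    # values_next[i]: MC estimate V(s'_i), 0.0 if s'_i is terminal
    # value_t: MC estimate V(s) of the shared current state
    # A_i = R_i + gamma * V(s'_i) - V(s), with gamma = 1
    return [r + v - value_t for r, v in zip(rewards, values_next)]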
@@ -841,9 +843,8 @@ class BlackjackEnv(BaseEnv):
     @classmethod
     def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
         env_config = BlackjackEnvConfig(
-            # Fields from fundamental_prediction_environment.py's BaseEnvConfig init:
             tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
-            group_size=16,  # Matches BlackjackEnvConfig default as well
+            group_size=16,
             use_wandb=True,
             max_num_workers=128,
             rollout_server_url="http://localhost:8000",
@@ -852,30 +853,28 @@ class BlackjackEnv(BaseEnv):
             steps_per_eval=20,
             max_token_length=1024 * 16,
             inference_weight=1.0,
-            wandb_name="fundamental_metric_prediction",  # Strict: Use value from fundamental_prediction
+            wandb_name="fundamental_metric_prediction",
             data_path_to_save_groups=None,
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
             eval_limit_ratio=0.1,

-            # BlackjackEnvConfig specific fields (those NOT in BaseEnvConfig from fundamental_prediction)
-            # using their defined defaults from BlackjackEnvConfig:
-            env_name="Blackjack-v1",  # Default from BlackjackEnvConfig
-            temperature=0.7,  # Default from BlackjackEnvConfig
-            top_p=0.9,  # Default from BlackjackEnvConfig
-            max_turns=5,  # Default from BlackjackEnvConfig
-            thinking_active=True,  # Default from BlackjackEnvConfig
-            eval_episodes=100,  # Default from BlackjackEnvConfig
-            max_think_chars_history=3000,  # Default from BlackjackEnvConfig
-            max_trajectory_tokens=24576,  # Default from BlackjackEnvConfig
-            debug_mode=False,  # Default from BlackjackEnvConfig
-            mc_samples=3,  # Default from BlackjackEnvConfig
+            env_name="Blackjack-v1",
+            temperature=0.7,
+            top_p=0.9,
+            max_turns=5,
+            thinking_active=True,
+            eval_episodes=100,
+            max_think_chars_history=3000,
+            max_trajectory_tokens=24576,
+            debug_mode=False,
+            mc_samples=3,
         )
         server_configs = [
             OpenaiConfig(
                 model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                 base_url="http://localhost:9004/v1",
                 api_key="x",
-                num_requests_for_eval=256,  # From fundamental_prediction_environment.py
+                num_requests_for_eval=256,
             )
         ]
         return env_config, server_configs
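config_init is the classmethod a trainer calls to obtain a default (env_config, server_configs) pair; a plausible call site, assuming the constructor signature visible later in this diff, would be:

env_config, server_configs = BlackjackEnv.config_init()
env = BlackjackEnv(config=env_config, server_configs=server_configs, slurm=False)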
@@ -973,7 +972,7 @@ class BlackjackEnv(BaseEnv):
                 and original_step_data.get("tokens")
                 and original_step_data.get("masks")
                 and original_step_data.get("seed") is not None
-                and original_step_data.get("parsed_actions") is not None  # Specific to MC version
+                and original_step_data.get("parsed_actions") is not None
             ):
                 logger.warning(
                     f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env "
@@ -981,7 +980,6 @@ class BlackjackEnv(BaseEnv):
                 )
                 continue

-            # Initial token calculation from original data
            max_initial_tokens = 0
            if original_step_data["tokens"]:
                max_initial_tokens = max(
@@ -1017,7 +1015,7 @@ class BlackjackEnv(BaseEnv):
            for alt_idx in range(num_alternatives):
                alt_msg_list = working_messages[alt_idx]
                num_preserved_at_end = 0
-                if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in ["agent", "assistant"] + UNMASKED_ROLES:
+                if len(alt_msg_list) > 1 and alt_msg_list[-1]["role"] in ["agent", "assistant"]:
                    num_preserved_at_end = 1
                    if len(alt_msg_list) > 2 and alt_msg_list[-2]["role"] == "environment":
                        num_preserved_at_end = 2
@@ -1031,7 +1029,7 @@ class BlackjackEnv(BaseEnv):
                    available_to_pop >= 2 and
                    len(alt_msg_list) > 2 and
                    alt_msg_list[1]["role"] == "environment" and
-                    alt_msg_list[2]["role"] in ["agent", "assistant"] + UNMASKED_ROLES
+                    alt_msg_list[2]["role"] in ["agent", "assistant"]
                )
                if can_pop_pair:
                    target_pop_counts_per_alt.append(2)
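These two hunks drop UNMASKED_ROLES from the role checks guarding history trimming: the trimmer keeps the system prompt and the final agent turn (plus its preceding environment observation), then pops the oldest environment/agent pairs until the trajectory fits the token budget. A simplified sketch of that pruning loop, with count_tokens as a hypothetical stand-in for the tokenizer-based length check:

def trim_history(messages: list, max_tokens: int, count_tokens) -> list:
    # messages[0] is the system prompt; the trailing environment/agent
    # exchange is preserved, so only the oldest pair in between is popped.
    while count_tokens(messages) > max_tokens:
        can_pop_pair = (
            len(messages) > 4  # system + poppable pair + final exchange
            and messages[1]["role"] == "environment"
            and messages[2]["role"] in ("agent", "assistant")
        )
        if not can_pop_pair:
            break  # nothing safe left to remove
        del messages[1:3]  # drop the oldest environment/agent pair
    return messages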
@@ -1088,7 +1086,7 @@ class BlackjackEnv(BaseEnv):
                    "tokens": working_tokens,
                    "masks": working_masks,
                    "scores": original_step_data.get("scores"),
-                    "parsed_actions": original_step_data.get("parsed_actions")  # MC version specific
+                    "parsed_actions": original_step_data.get("parsed_actions")
                }
                filtered_trajectory.append(updated_step_data)
                logger.info(
@@ -1,4 +1,3 @@
-import argparse
 import asyncio
 import logging
 import os
@@ -14,88 +13,63 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


-# def parse_arguments():  # Removed
-#     parser = argparse.ArgumentParser(description="Blackjack environment local server")
-#     parser.add_argument(
-#         "--config",
-#         type=str,
-#         default="blackjack_local",
-#         help="Configuration file name (without .yaml extension, relative to "
-#         "envs/gymnasium/configs), or full path to a YAML file.",
-#     )
-#     return parser.parse_args()
-
-
 async def main():
     logger.info("Starting Blackjack environment local debug runner")

-    # args = parse_arguments()  # Removed
-
-    # Removed logic for config_name_or_path and BlackjackEnv.config_init
-
-    # Create hardcoded configurations for local debugging
     env_config = BlackjackEnvConfig(
-        # BaseEnvConfig fields, tailored for debug
         tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
-        group_size=1,  # Debug single generation path
+        group_size=1,
         use_wandb=False,
-        wandb_name="blackjack_local_debug",  # Explicitly set for debug
+        wandb_name="blackjack_local_debug",
         max_num_workers=1,
-        rollout_server_url="http://localhost:8000",  # Standard default
+        rollout_server_url="http://localhost:8000",
        total_steps=1,
-        batch_size=1,  # Consistent with 1 step, 1 worker, group_size 1
-        steps_per_eval=0,  # No eval steps needed
-        max_token_length=1024 * 4,  # Reduced for faster local debugging if necessary
+        batch_size=1,
+        steps_per_eval=0,
+        max_token_length=1024 * 4,
        inference_weight=1.0,
        data_path_to_save_groups=None,
-        eval_handling=EvalHandlingEnum.NONE,  # No evaluation in this script
+        eval_handling=EvalHandlingEnum.NONE,
        eval_limit_ratio=0.0,

-        # BlackjackEnvConfig specific fields (from blackjack_env.py's definition or defaults)
        env_name="Blackjack-v1",
-        temperature=0.2,  # Lower temperature for more deterministic debug output
-        top_p=0.9,  # Standard default
-        max_turns=5,  # Standard default
+        temperature=0.2,
+        top_p=0.9,
+        max_turns=5,
        thinking_active=True,
-        eval_episodes=0,  # No evaluation episodes
+        eval_episodes=0,
        max_think_chars_history=3000,
        max_trajectory_tokens=24576,
-        debug_mode=True,  # Enable debug logging from the environment
-        mc_samples=1,  # With group_size=1, this means 1 MC rollout for V(s)
+        debug_mode=True,
+        mc_samples=1,
    )

    server_configs = [
        OpenaiConfig(
-            model_name="gpt-4.1-mini",  # Ensure this is locally available if not mocked
-            base_url="https://api.openai.com/v1",  # Explicitly set OpenAI base URL
-            api_key=os.getenv("OPENAI_API_KEY"),  # Use env var or default
-            num_requests_for_eval=0,  # No eval requests
+            model_name="gpt-4.1-nano",
+            base_url="https://api.openai.com/v1",
+            api_key=os.getenv("OPENAI_API_KEY"),
+            num_requests_for_eval=0,
        )
    ]
    logger.info("Using hardcoded debug configuration.")
    logger.debug(f"Env Config: {env_config}")
    logger.debug(f"Server Configs: {server_configs}")

-    # Create and set up the environment using the loaded configs
    try:
        env = BlackjackEnv(
            config=env_config,
            server_configs=server_configs,
-            slurm=False,  # Explicitly false for local testing
+            slurm=False,
        )
    except Exception as e:
        logger.exception(f"Failed to initialize BlackjackEnv: {e}")
        return

-    # Run a single trajectory directly
    logger.info("Running a single trajectory directly")
    try:
-        await env.setup()  # Setup the server connection etc.
+        await env.setup()
        seed = random.randint(0, 1000000)
        logger.info(f"Using seed: {seed}")

        # Make sure the episode exists before collecting
        # This also initializes the message history correctly
        _ = env._get_or_create_episode(seed)

        result_trajectory = await env.collect_trajectory(seed)
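The runner builds its configs inline rather than parsing arguments; presumably the module closes with the usual asyncio entry point, e.g.:

if __name__ == "__main__":
    asyncio.run(main())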
@@ -105,13 +79,9 @@ async def main():

        episode_summary = None
        if env.completed_episode_metrics_buffer:
-            # Assume the last entry in the buffer corresponds to the trajectory just run
            episode_summary = env.completed_episode_metrics_buffer[-1]
-            # Optionally, clear the buffer if this script is only for single runs
-            # env.completed_episode_metrics_buffer.clear()

        if episode_summary and episode_summary.get("seed") == seed:
-            # Print a final summary
            logger.info("\n========== Episode Summary ==========")
            logger.info(f"Seed: {episode_summary['seed']}")
            logger.info(f"Total steps taken: {episode_summary['num_steps']}")
@@ -129,7 +99,6 @@ async def main():
                f"Game Outcome: {outcome_str} (Reward: {episode_summary['total_reward']:.0f})"
            )

-            # Calculate and log action accuracy based on EpisodeState fields
            if episode_summary["num_total_actions"] > 0:
                accuracy = episode_summary["num_correct_actions"] / max(
                    1, episode_summary["num_total_actions"]
@@ -645,11 +645,11 @@ class BlackjackEnv(BaseEnv):

        episode_summary_metrics = {
            "seed": seed,
-            "total_env_reward": ep.total_env_reward,
+            "total_reward": ep.total_env_reward,
            "num_correct_actions": ep.num_correct_actions,
            "num_total_actions": ep.num_total_actions,
            "game_outcome": game_outcome,
-            "num_steps_in_episode": len(ep.actions),
+            "num_steps": len(ep.actions),
        }
        self.completed_episode_metrics_buffer.append(episode_summary_metrics)
@@ -1054,8 +1054,8 @@ class BlackjackEnv(BaseEnv):

        episode_metrics = {
            "seed": seed,
-            "total_env_reward": 0.0,
-            "num_turns": 0,
+            "total_reward": 0.0,
+            "num_steps": 0,
            "num_correct_actions": 0,
            "num_invalid_actions": 0,
            "actions_chosen": [],
@@ -1063,7 +1063,7 @@ class BlackjackEnv(BaseEnv):
        }

        for turn in range(max_turns):
-            episode_metrics["num_turns"] = turn + 1
+            episode_metrics["num_steps"] = turn + 1
            messages_for_prompt = ep.message_history.copy()

            if self.config.thinking_active:
@@ -1130,10 +1130,7 @@ class BlackjackEnv(BaseEnv):
                reward = -1.0
                obs = None

-            ep.actions.append(env_action)
-            ep.step_rewards.append(reward)
-
-            ep.total_env_reward += reward
+            episode_metrics["total_reward"] += reward

            if term or trunc:
                episode_metrics["game_outcome"] = int(reward)
@@ -1207,10 +1204,10 @@ class BlackjackEnv(BaseEnv):
        num_completed_episodes = len(valid_metrics)

        avg_total_env_reward = (
-            sum(m["total_env_reward"] for m in valid_metrics) / num_completed_episodes
+            sum(m["total_reward"] for m in valid_metrics) / num_completed_episodes
        )
        avg_num_turns = (
-            sum(m["num_turns"] for m in valid_metrics) / num_completed_episodes
+            sum(m["num_steps"] for m in valid_metrics) / num_completed_episodes
        )

        total_correct_actions = sum(m["num_correct_actions"] for m in valid_metrics)
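This hunk and the next complete the metric-key rename (total_env_reward -> total_reward, num_turns and num_steps_in_episode -> num_steps); the aggregation itself is a plain mean over the completed-episode dicts, schematically:

def mean_of(key: str, metrics: list) -> float:
    # average one metric across completed episodes, guarding against empty lists
    return sum(m[key] for m in metrics) / max(1, len(metrics))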
@@ -1244,8 +1241,8 @@ class BlackjackEnv(BaseEnv):
        total_parsed_actions_in_eval = len(all_chosen_actions)

        self.eval_metrics = [
-            ("eval/avg_total_env_reward", avg_total_env_reward),
-            ("eval/avg_num_turns", avg_num_turns),
+            ("eval/avg_total_reward", avg_total_env_reward),
+            ("eval/avg_num_steps", avg_num_turns),
            ("eval/action_accuracy", action_accuracy),
            ("eval/invalid_action_rate", invalid_action_rate),
            ("eval/win_rate", win_rate),
@@ -1295,7 +1292,7 @@ class BlackjackEnv(BaseEnv):

        avg_ep_env_reward = (
            sum(
-                m["total_env_reward"] for m in self.completed_episode_metrics_buffer
+                m["total_reward"] for m in self.completed_episode_metrics_buffer
            )
            / num_episodes_in_buffer
        )
@@ -1314,8 +1311,7 @@ class BlackjackEnv(BaseEnv):

        avg_ep_num_steps = (
            sum(
-                m["num_steps_in_episode"]
-                for m in self.completed_episode_metrics_buffer
+                m["num_steps"] for m in self.completed_episode_metrics_buffer
            )
            / num_episodes_in_buffer
        )
@@ -1347,7 +1343,7 @@ class BlackjackEnv(BaseEnv):
        )

        wandb_metrics[
-            f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_env_reward"
+            f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_reward"
        ] = avg_ep_env_reward
        wandb_metrics[
            f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_action_accuracy"