diff --git a/atroposlib/utils/tokenize_for_trainer.py b/atroposlib/utils/tokenize_for_trainer.py
index 8d9ea3dc..c1187fe1 100644
--- a/atroposlib/utils/tokenize_for_trainer.py
+++ b/atroposlib/utils/tokenize_for_trainer.py
@@ -4,7 +4,7 @@ from transformers import PreTrainedTokenizer
 from atroposlib.type_definitions import Message
 
-# Roles that should be masked in the loss calculation (not used for training)
-UNMASKED_ROLES = ["assistant"]
+# Roles that are NOT masked in the loss calculation (i.e., trained on)
+UNMASKED_ROLES = ["assistant", "agent"]
 
 
 def tokenize_for_trainer(
diff --git a/atroposlib/utils/tool_call_parser.py b/atroposlib/utils/tool_call_parser.py
new file mode 100644
index 00000000..070acdc1
--- /dev/null
+++ b/atroposlib/utils/tool_call_parser.py
@@ -0,0 +1,105 @@
+"""
+Tool call parser helper for extracting and validating tool calls from LLM responses.
+"""
+
+import json
+import logging
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+logger = logging.getLogger(__name__)
+
+
+def extract_tool_call(text: str, preferred_tags: Optional[List[str]] = None) -> Optional[str]:
+    """
+    Extract the content within tool call tags.
+
+    Args:
+        text: The text to extract the tool call from
+        preferred_tags: The tag names to look for (default: ['tool_call'])
+
+    Returns:
+        The extracted content, or None if no tool call is found
+    """
+    preferred_tags = preferred_tags or ["tool_call"]
+    for tag in preferred_tags:
+        pattern = f"<{tag}>(.*?)</{tag}>"
+        matches = re.findall(pattern, text, re.DOTALL)
+        if matches:
+            return matches[0].strip()
+    return None
+
+
+def parse_tool_call(
+    response: str,
+    available_tools: Optional[List[Dict]] = None,
+    preferred_tags: Optional[List[str]] = None,
+) -> Tuple[str, Dict[str, Any], bool]:
+    """
+    Parse a tool call from an LLM response.
+
+    Args:
+        response: The LLM response text to parse
+        available_tools: Optional list of available tools for validation
+        preferred_tags: The tag names to look for (default: ['tool_call'])
+
+    Returns:
+        Tuple of (tool_name, arguments, is_error)
+        - tool_name: Name of the called tool, or "-ERROR-" if invalid
+        - arguments: Dictionary of arguments provided to the tool
+        - is_error: Boolean indicating whether parsing failed
+    """
+    # Extract content from the <tool_call> tags
+    tool_call_content = extract_tool_call(response, preferred_tags)
+
+    if not tool_call_content:
+        logger.warning(f"No tool call found in response: {response}")
+        return "-ERROR-", {}, True
+
+    # Parse JSON
+    try:
+        # Best-effort handling of single-quoted pseudo-JSON
+        tool_call_content = tool_call_content.replace("'", '"')
+
+        tool_call = json.loads(tool_call_content, strict=False)
+
+        # Extract tool name and arguments
+        tool_name = tool_call.get("name", "")
+        arguments = tool_call.get("arguments", {})
+
+        # Validate tool existence if tools are provided
+        if available_tools:
+            valid_tool_names = set()
+            for tool in available_tools:
+                if isinstance(tool, dict):
+                    if "name" in tool:
+                        valid_tool_names.add(tool["name"])
+                    elif "function" in tool and "name" in tool["function"]:
+                        valid_tool_names.add(tool["function"]["name"])
+
+            if not tool_name or tool_name not in valid_tool_names:
+                return "-ERROR-", arguments, True
+
+        logger.debug(f"Parsed tool call: {tool_name}, {arguments}")
+        return tool_name, arguments, False
+
+    except Exception as json_error:
+        logger.error(f"Failed to parse tool call: {json_error}", exc_info=True)
+        return "-ERROR-", {}, True
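+
+
+# Illustrative usage (the response text and tool schema here are hypothetical):
+#
+#     tools = [{"name": "guess_letter", "arguments": {"letter": "string"}}]
+#     response = 'Okay. <tool_call>{"name": "guess_letter", "arguments": {"letter": "e"}}</tool_call>'
+#     parse_tool_call(response, available_tools=tools)
+#     # -> ("guess_letter", {"letter": "e"}, False)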
+
+
+def format_tool_call_for_hangman(tool_name: str, arguments: Dict[str, Any]) -> str:
+    """
+    Format a parsed tool call into the format expected by Hangman.
+
+    Args:
+        tool_name: The name of the tool to call
+        arguments: Dictionary of arguments
+
+    Returns:
+        The formatted tool call string (e.g., "[letter]" or "[word]")
+    """
+    if tool_name == "guess_letter" and "letter" in arguments:
+        return f"[{arguments['letter']}]"
+    elif tool_name == "guess_word" and "word" in arguments:
+        return f"[{arguments['word']}]"
+    else:
+        return "-ERROR-"
diff --git a/environments/game_environments/gymnasium/blackjack_env.py b/environments/game_environments/gymnasium/blackjack_env.py
index 1472a393..efe857d1 100644
--- a/environments/game_environments/gymnasium/blackjack_env.py
+++ b/environments/game_environments/gymnasium/blackjack_env.py
@@ -1,17 +1,13 @@
 import json
 import logging
-import os
 import random
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import gymnasium
 import numpy as np
-import yaml
 from tqdm.asyncio import tqdm_asyncio
 
-from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
-from atroposlib.envs.reward_fns import registry
-from atroposlib.envs.reward_fns.combined_reward import CombinedReward
+from atroposlib.envs.base import (
+    BaseEnv,
+    BaseEnvConfig,
+    EvalHandlingEnum,
+    OpenaiConfig,
+    ScoredDataGroup,
+)
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 from atroposlib.utils.tool_call_parser import parse_tool_call
@@ -31,8 +27,6 @@ class BlackjackEnvConfig(BaseEnvConfig):
     debug_mode: bool = False
     group_size: int = 16
     mc_samples: int = 3  # lowish K for MC value estimation
-    reward_functions: List[Union[str, Dict[str, Any]]] = []
-    environment_reward_weight: float = 0.5
 
 
 class BlackjackScoredDataGroup(ScoredDataGroup):
@@ -88,8 +82,6 @@ class BlackjackEnv(BaseEnv):
             }
         ]
 
-        self.reward_function = self._initialize_reward_function()
-
         tools_json = json.dumps(self.tools)
         self.system_prompt = (
             "You are an AI agent playing Blackjack who uses extreme long chains of thought "
@@ -108,54 +100,6 @@ class BlackjackEnv(BaseEnv):
             "Remember to carefully consider the probabilities and optimal strategy for Blackjack."
         )
 
-    def _initialize_reward_function(self):
-        """Initialize the reward function for scoring based on self.config.reward_functions."""
-        if hasattr(self.config, "reward_functions") and self.config.reward_functions:
-            reward_configs = self.config.reward_functions
-            logger.info(
-                f"[_initialize_reward_function] Initializing with reward_functions "
-                f"from config: {reward_configs}"
-            )
-
-            if not reward_configs:
-                logger.warning(
-                    "[_initialize_reward_function] reward_functions list is empty "
-                    "after access. No reward function will be active."
-                )
-                return None
-
-            if len(reward_configs) == 1:
-                try:
-                    logger.debug(
-                        f"[_initialize_reward_function] Creating single reward function from: {reward_configs[0]}"
-                    )
-                    return registry.create(reward_configs[0])
-                except Exception as e:
-                    logger.error(
-                        f"[_initialize_reward_function] Failed to create single reward function from config "
-                        f"{reward_configs[0]}: {e}",
-                        exc_info=True,
-                    )
-                    return None
-            elif len(reward_configs) > 1:
-                try:
-                    logger.debug(
-                        f"[_initialize_reward_function] Creating CombinedReward function from: {reward_configs}"
-                    )
-                    return CombinedReward(rewards=reward_configs)
-                except Exception as e:
-                    logger.error(
-                        f"[_initialize_reward_function] Failed to create CombinedReward function from configs: {e}",
-                        exc_info=True,
-                    )
-                    return None
-        else:
-            logger.info(
-                "[_initialize_reward_function] No 'reward_functions' key in config or it's empty. "
-                "No specific reward function (like format/tool_call) will be active."
- ) - return None - def _get_or_create_episode(self, seed: int) -> EpisodeState: if seed not in self.episodes: env = gymnasium.make(self.config.env_name) @@ -180,60 +124,32 @@ class BlackjackEnv(BaseEnv): episode_seed: int, ) -> float: """ - Calculates a combined score for a single agent response based on environment and format rewards. + Calculates a score for a single agent response based purely on environment reward + and a penalty for invalid action format. """ - format_or_tool_call_reward_component = 0.0 current_env_reward = env_reward if parsed_action == -1: - current_env_reward -= 0.5 + current_env_reward -= 0.5 # Penalty for invalid action format logger.debug( f"[_score_response Seed: {episode_seed}] Penalty applied to env_reward for " - f"invalid action format (-0.5). Current env_reward: {current_env_reward}" + f"invalid action format (-0.5). Current env_reward: {current_env_reward:.4f}" ) - if self.reward_function: - messages_for_reward_func: List[List[Dict[str, str]]] = [ - [{"role": "agent", "content": response_text}] - ] - try: - reward_func_output_list = self.reward_function(messages_for_reward_func) - if reward_func_output_list and len(reward_func_output_list) > 0: - format_or_tool_call_reward_component = reward_func_output_list[0] - logger.debug( - f"[_score_response Seed: {episode_seed}] Output from self.reward_function " - f"(e.g., format/tool_call): {format_or_tool_call_reward_component:.4f}" - ) - else: - logger.warning( - f"[_score_response Seed: {episode_seed}] self.reward_function returned " - f"empty or invalid result: {reward_func_output_list}" - ) - except Exception as e: - logger.error( - f"[_score_response Seed: {episode_seed}] Error calculating reward via " - f"self.reward_function: {e}", - exc_info=True, - ) - else: - logger.debug( - f"[_score_response Seed: {episode_seed}] No self.reward_function active, " - f"format_or_tool_call_reward_component is 0." 
-            )
 
-        env_w = self.config.environment_reward_weight
-
-        combined_score = (
-            env_w * current_env_reward
-        ) + format_or_tool_call_reward_component
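+        # Scoring is now intentionally simple: the raw environment reward, with the
+        # flat -0.5 penalty (applied above) when no valid action could be parsed.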
from {cfg_path}") - else: - logger.warning( - f"Config not found at {cfg_path}. Using default settings." - ) - raw_yaml_data = {} - - env_conf_data = raw_yaml_data.copy() - server_configs_list = env_conf_data.pop("server_configs", []) - if "blackjack" in env_conf_data: - env_conf_data.update(env_conf_data.pop("blackjack")) - env_conf = BlackjackEnvConfig(**env_conf_data) - - server_confs = [] - for sc_data in server_configs_list: - if not isinstance(sc_data, dict): - continue - params = sc_data.copy() - params["api_key"] = params.get( - "api_key", os.getenv("OPENAI_API_KEY", "x") - ) - params["model_name"] = params.get( - "model_name", - os.getenv( - "OPENAI_MODEL", "NousResearch/DeepHermes-3-Llama-3-8B-Preview" - ), - ) - params["base_url"] = params.get( - "base_url", os.getenv("OPENAI_API_BASE", "http://localhost:9004/v1") - ) - server_confs.append(OpenaiConfig(**params)) - if not server_confs: - server_confs = [ - OpenaiConfig( - model_name=os.getenv( - "OPENAI_MODEL", - "NousResearch/DeepHermes-3-Llama-3-8B-Preview", - ), - base_url=os.getenv( - "OPENAI_API_BASE", "http://localhost:9004/v1" - ), - api_key=os.getenv("OPENAI_API_KEY", "x"), - ) - ] - return env_conf, server_confs - except Exception as e: - logger.error(f"Error loading config from {cfg_path}: {e}") - return BlackjackEnvConfig(), [ - OpenaiConfig( - model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", - base_url="http://localhost:9004/v1", - api_key="x", - ) - ] + ] + return env_config, server_configs def _truncate_thinking_for_history(self, response_text: str, max_chars: int) -> str: """Helper to truncate the block of a response for message history.""" diff --git a/environments/game_environments/gymnasium/blackjack_no_mc_env.py b/environments/game_environments/gymnasium/blackjack_no_mc_env.py index decea942..e6dca5c1 100644 --- a/environments/game_environments/gymnasium/blackjack_no_mc_env.py +++ b/environments/game_environments/gymnasium/blackjack_no_mc_env.py @@ -12,17 +12,13 @@ but may not be as effective at learning correct strategy (it's effectively a ser import json import logging -import os import random from typing import Any, Dict, List, Optional, Tuple, Union import gymnasium -import yaml from tqdm.asyncio import tqdm_asyncio -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup -from atroposlib.envs.reward_fns import registry -from atroposlib.envs.reward_fns.combined_reward import CombinedReward +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup, EvalHandlingEnum from atroposlib.type_definitions import Message from atroposlib.utils.tokenize_for_trainer import UNMASKED_ROLES, tokenize_for_trainer from atroposlib.utils.tool_call_parser import parse_tool_call @@ -44,10 +40,6 @@ class BlackjackEnvConfig(BaseEnvConfig): thinking_active: bool = True eval_episodes: int = 100 - reward_functions: List[Union[str, Dict[str, Any]]] = [] - format_reward_weight: float = 0.2 - environment_reward_weight: float = 0.8 - batch_size: int = 1024 max_think_chars_history: int = 3000 # Should be higher than the max tokens to allow for multiple turns @@ -81,8 +73,6 @@ class EpisodeState: self.step_rewards: List[float] = [] self.trajectory: List[BlackjackScoredDataGroup] = [] self.total_env_reward: float = 0.0 - self.total_format_reward: float = 0.0 - self.total_combined_reward: float = 0.0 self.num_correct_actions: int = 0 self.num_total_actions: int = 0 @@ -123,8 +113,6 @@ class BlackjackEnv(BaseEnv): } ] - self.reward_function = 
-
         tools_json = json.dumps(self.tools)
         self.system_prompt = (
             "You are an AI agent playing Blackjack who uses extreme long chains of thought "
@@ -145,43 +133,6 @@ class BlackjackEnv(BaseEnv):
             "Remember to carefully consider the probabilities and optimal strategy for Blackjack."
         )
 
-    def _initialize_reward_function(self):
-        """Initialize the combined reward function for scoring."""
-        if hasattr(self.config, "reward_functions") and self.config.reward_functions:
-            reward_configs = []
-            for reward_func in self.config.reward_functions:
-                if isinstance(reward_func, str):
-                    if reward_func == "format":
-                        format_config = {
-                            "type": "format",
-                            "weight": self.config.format_reward_weight,
-                            "params": {
-                                "preferred_tags": ["think", "tool_call"],
-                            },
-                        }
-                        reward_configs.append(format_config)
-                    elif reward_func == "tool_calling":
-                        tool_calling_config = {
-                            "type": "tool_calling",
-                            "weight": self.config.format_reward_weight,
-                            "params": {
-                                "tools": self.tools,
-                                "preferred_tags": ["tool_call"],
-                                "check_arguments": True,
-                            },
-                        }
-                        reward_configs.append(tool_calling_config)
-                    else:
-                        reward_configs.append(reward_func)
-                else:
-                    reward_configs.append(reward_func)
-
-            if len(reward_configs) == 1:
-                return registry.create(reward_configs[0])
-            elif len(reward_configs) > 1:
-                return CombinedReward(rewards=reward_configs, normalization="none")
-        return None
-
     def _get_or_create_episode(self, seed: int) -> EpisodeState:
         """Retrieve existing or create a new episode state keyed by seed."""
         if seed not in self.episodes:
@@ -232,22 +183,11 @@ class BlackjackEnv(BaseEnv):
         response_text: str,
         parsed_action: int,
         episode_seed: int,
-        update_episode_totals: bool = False,
     ) -> float:
         """
-        Calculates a combined score for a single agent response based on environment and format rewards.
-
-        Args:
-            env_reward: The raw reward obtained from simulating this action in the environment.
-            response_text: The full text response generated by the agent (including <think>).
-            parsed_action: The action parsed from the response (0=stick, 1=hit, -1=error).
-            episode_seed: The seed of the current episode.
-            update_episode_totals: Flag (currently unused here) to indicate if episode totals should be updated.
-
-        Returns:
-            The combined score.
+        Calculates a score for a single agent response based purely on environment reward
+        and a penalty for invalid action format.
         """
-        format_reward = 0.0
         current_env_reward = env_reward
 
         if parsed_action == -1:
@@ -256,34 +196,15 @@ class BlackjackEnv(BaseEnv):
                 f"[_score_response Seed: {episode_seed}] Penalty applied for invalid action format (-0.5)."
) - if self.reward_function: - format_completions = [[{"role": "agent", "content": response_text}]] - try: - format_rewards = self.reward_function(format_completions) - if format_rewards and len(format_rewards) > 0: - format_reward = format_rewards[0] - logger.debug( - f"[_score_response Seed: {episode_seed}] Format reward calculated: {format_reward:.4f}" - ) - except Exception as e: - logger.error( - f"[_score_response Seed: {episode_seed}] Error calculating format reward: {e}" - ) - - env_weight = self.config.environment_reward_weight - format_weight = self.config.format_reward_weight - combined_reward = (env_weight * current_env_reward) + ( - format_weight * format_reward - ) + final_score = current_env_reward logger.debug( f"[_score_response Seed: {episode_seed}] Final Score Calculation: " f"Env Reward (raw): {env_reward:.4f}, " - f"Env Reward (adjusted): {current_env_reward:.4f}, " - f"Format Reward: {format_reward:.4f}, " - f"Combined Reward: {combined_reward:.4f}" + f"Env Reward (adjusted for invalid): {current_env_reward:.4f}, " + f"==> Final Score (from env): {final_score:.4f}" ) - return combined_reward + return final_score async def _select_best_action( self, episode: EpisodeState, actions: List[int], responses: List[str] @@ -299,7 +220,7 @@ class BlackjackEnv(BaseEnv): Returns: A tuple containing: - The best action selected (0, 1, or -1). - - A list of combined scores for each action/response. + - A list of scores for each action/response. """ if len(actions) != len(responses): logger.error( @@ -348,7 +269,6 @@ class BlackjackEnv(BaseEnv): response_text=response_text, parsed_action=action, episode_seed=episode.seed, - update_episode_totals=False, ) scores[idx] = combined_score token_lengths[idx] = len(self.tokenizer.encode(response_text)) @@ -640,25 +560,7 @@ class BlackjackEnv(BaseEnv): ep.actions.append(env_action) ep.step_rewards.append(reward) - format_reward_chosen = 0.0 - if self.reward_function: - format_completions = [[{"role": "agent", "content": best_response}]] - try: - format_rewards = self.reward_function(format_completions) - if format_rewards and len(format_rewards) > 0: - format_reward_chosen = format_rewards[0] - except Exception as e: - logger.error( - f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] " - f"Error re-calculating format reward for chosen action: {e}" - ) - ep.total_env_reward += reward - ep.total_format_reward += format_reward_chosen - combined_reward_step = (self.config.environment_reward_weight * reward) + ( - self.config.format_reward_weight * format_reward_chosen - ) - ep.total_combined_reward += combined_reward_step ep.num_total_actions += 1 if best_action != -1: @@ -666,10 +568,8 @@ class BlackjackEnv(BaseEnv): logger.info( f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] " - f"Step Rewards: Env={reward:.2f}, Format={format_reward_chosen:.2f}, " - f"Combined={combined_reward_step:.2f}. " - f"Running Totals: Env={ep.total_env_reward:.2f}, " - f"Format={ep.total_format_reward:.2f}, Combined={ep.total_combined_reward:.2f}" + f"Step Rewards: Env={reward:.2f}. " + f"Running Totals: Env={ep.total_env_reward:.2f}." ) ep.trajectory.append( @@ -726,9 +626,7 @@ class BlackjackEnv(BaseEnv): ) logger.info( f"[Collect Trajectory Seed: {seed}] " - f"Final Totals: Env Reward={ep.total_env_reward:.2f}, " - f"Format Reward={ep.total_format_reward:.2f}, " - f"Combined Reward={ep.total_combined_reward:.2f}" + f"Final Totals: Env Reward={ep.total_env_reward:.2f}." 
         )
         logger.info(
             f"[Collect Trajectory Seed: {seed}] "
@@ -748,8 +646,6 @@ class BlackjackEnv(BaseEnv):
         episode_summary_metrics = {
             "seed": seed,
             "total_env_reward": ep.total_env_reward,
-            "total_format_reward": ep.total_format_reward,
-            "total_combined_reward": ep.total_combined_reward,
             "num_correct_actions": ep.num_correct_actions,
             "num_total_actions": ep.num_total_actions,
             "game_outcome": game_outcome,
@@ -1159,8 +1055,6 @@ class BlackjackEnv(BaseEnv):
         episode_metrics = {
             "seed": seed,
             "total_env_reward": 0.0,
-            "total_format_reward": 0.0,
-            "total_combined_reward": 0.0,
             "num_turns": 0,
             "num_correct_actions": 0,
             "num_invalid_actions": 0,
@@ -1236,24 +1130,10 @@ class BlackjackEnv(BaseEnv):
                 reward = -1.0
                 obs = None
 
-            format_reward_step = 0.0
-            if self.reward_function:
-                format_completions = [[{"role": "agent", "content": full_response}]]
-                try:
-                    format_rewards = self.reward_function(format_completions)
-                    if format_rewards and len(format_rewards) > 0:
-                        format_reward_step = format_rewards[0]
-                except Exception as e:
-                    logger.error(
-                        f"[Eval Rollout Seed: {seed} Turn: {turn+1}] Error calculating format reward: {e}"
-                    )
+            ep.actions.append(env_action)
+            ep.step_rewards.append(reward)
 
             episode_metrics["total_env_reward"] += reward
-            episode_metrics["total_format_reward"] += format_reward_step
-            combined_reward_step = (self.config.environment_reward_weight * reward) + (
-                self.config.format_reward_weight * format_reward_step
-            )
-            episode_metrics["total_combined_reward"] += combined_reward_step
+            ep.total_env_reward += reward
 
             if term or trunc:
                 episode_metrics["game_outcome"] = int(reward)
@@ -1329,14 +1209,6 @@ class BlackjackEnv(BaseEnv):
         avg_total_env_reward = (
             sum(m["total_env_reward"] for m in valid_metrics) / num_completed_episodes
         )
-        avg_total_format_reward = (
-            sum(m["total_format_reward"] for m in valid_metrics)
-            / num_completed_episodes
-        )
-        avg_total_combined_reward = (
-            sum(m["total_combined_reward"] for m in valid_metrics)
-            / num_completed_episodes
-        )
         avg_num_turns = (
             sum(m["num_turns"] for m in valid_metrics) / num_completed_episodes
         )
@@ -1373,8 +1245,6 @@ class BlackjackEnv(BaseEnv):
 
         self.eval_metrics = [
             ("eval/avg_total_env_reward", avg_total_env_reward),
-            ("eval/avg_total_format_reward", avg_total_format_reward),
-            ("eval/avg_total_combined_reward", avg_total_combined_reward),
             ("eval/avg_num_turns", avg_num_turns),
             ("eval/action_accuracy", action_accuracy),
             ("eval/invalid_action_rate", invalid_action_rate),
@@ -1429,20 +1299,6 @@ class BlackjackEnv(BaseEnv):
             )
             / num_episodes_in_buffer
         )
-        avg_ep_format_reward = (
-            sum(
-                m["total_format_reward"]
-                for m in self.completed_episode_metrics_buffer
-            )
-            / num_episodes_in_buffer
-        )
-        avg_ep_combined_reward = (
-            sum(
-                m["total_combined_reward"]
-                for m in self.completed_episode_metrics_buffer
-            )
-            / num_episodes_in_buffer
-        )
         total_ep_correct_actions = sum(
             m["num_correct_actions"]
             for m in self.completed_episode_metrics_buffer
@@ -1493,12 +1349,6 @@ class BlackjackEnv(BaseEnv):
         wandb_metrics[
             f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_env_reward"
         ] = avg_ep_env_reward
-        wandb_metrics[
-            f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_format_reward"
-        ] = avg_ep_format_reward
-        wandb_metrics[
-            f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_combined_reward"
-        ] = avg_ep_combined_reward
         wandb_metrics[
             f"{self.wandb_prepend or 'blackjack'}_train/avg_episode_action_accuracy"
         ] = avg_ep_action_accuracy
@@ -1525,128 +1375,45 @@ class BlackjackEnv(BaseEnv):
         await super().wandb_log(wandb_metrics)
 
     @classmethod
-    def config_init(
-        cls, config_name_or_path: Optional[str] = None
-    ) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
-        """Load settings from the local configs directory or an absolute path."""
-        current_dir = os.path.dirname(os.path.abspath(__file__))
-        default_config_filename = "blackjack_default.yaml"
+    def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
+        env_config = BlackjackEnvConfig(
+            # BaseEnvConfig fields, mirroring fundamental_prediction_environment.py:
+            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+            group_size=16,
+            use_wandb=True,
+            max_num_workers=128,
+            rollout_server_url="http://localhost:8000",
+            total_steps=2000,
+            batch_size=1024,
+            steps_per_eval=20,
+            max_token_length=1024 * 16,
+            inference_weight=1.0,
+            wandb_name="fundamental_metric_prediction",  # Strict: use value from fundamental_prediction
+            data_path_to_save_groups=None,
+            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
+            eval_limit_ratio=0.1,
 
-        if config_name_or_path is None:
-            cfg_path = os.path.join(current_dir, "configs", default_config_filename)
-            logger.info(f"No config specified, using default: {cfg_path}")
-        elif os.path.isabs(config_name_or_path):
-            cfg_path = config_name_or_path
-            logger.info(f"Absolute config path provided: {cfg_path}")
-            if not os.path.splitext(cfg_path)[1]:
-                logger.warning(
-                    f"Absolute config path {cfg_path} seems to be missing a file extension."
-                )
-        else:
-            config_filename = config_name_or_path
-            if not config_name_or_path.endswith(".yaml"):
-                config_filename += ".yaml"
-            cfg_path = os.path.join(current_dir, "configs", config_filename)
-            logger.info(
-                f"Relative config name '{config_name_or_path}' provided, resolving to: {cfg_path}"
+            # BlackjackEnvConfig (no_mc) specific fields, using their defined defaults:
+            env_name="Blackjack-v1",
+            temperature=0.7,
+            top_p=0.9,
+            max_turns=5,
+            thinking_active=True,
+            eval_episodes=100,
+            max_think_chars_history=3000,
+            max_trajectory_tokens=24576,
+            debug_mode=False,
+        )
+        server_configs = [
+            OpenaiConfig(
+                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+                base_url="http://localhost:9004/v1",
+                api_key="x",
+                num_requests_for_eval=256,  # From fundamental_prediction_environment.py
             )
-
-        logger.debug(f"Final config path to check for existence: {cfg_path}")
-
-        raw_yaml_data = {}
-        try:
-            if os.path.exists(cfg_path):
-                with open(cfg_path) as f:
-                    raw_yaml_data = yaml.safe_load(f) or {}
-                logger.info(f"Loaded config from {cfg_path}")
-            else:
-                logger.warning(
-                    f"Config file not found at {cfg_path}, "
-                    "using default BlackjackEnvConfig settings and default server config."
- ) - - env_conf_data = raw_yaml_data.copy() - server_configs_list_from_yaml = env_conf_data.pop("server_configs", []) - - if "blackjack" in env_conf_data: - blackjack_overrides = env_conf_data.pop("blackjack") - if isinstance(blackjack_overrides, dict): - env_conf_data.update(blackjack_overrides) - else: - logger.warning( - f"'blackjack' section in config YAML is not a dictionary " - f"(type: {type(blackjack_overrides)}), ignoring." - ) - - env_conf = BlackjackEnvConfig(**env_conf_data) - logger.debug(f"Initialized BlackjackEnvConfig: {env_conf}") - - server_confs = [] - if isinstance(server_configs_list_from_yaml, list): - for sc_data in server_configs_list_from_yaml: - if not isinstance(sc_data, dict): - logger.warning( - f"Skipping non-dictionary item in server_configs: {sc_data}" - ) - continue - - current_params = sc_data.copy() - - resolved_api_key = sc_data.get("api_key") - if resolved_api_key is None or resolved_api_key == "": - resolved_api_key = os.getenv("OPENAI_API_KEY") - if resolved_api_key is None or resolved_api_key == "": - resolved_api_key = "x" - current_params["api_key"] = resolved_api_key - - if "model_name" not in current_params: - current_params["model_name"] = os.getenv( - "OPENAI_MODEL", - "NousResearch/DeepHermes-3-Llama-3-8B-Preview", - ) - if "base_url" not in current_params: - current_params["base_url"] = os.getenv( - "OPENAI_API_BASE", "http://localhost:9004/v1" - ) - - server_confs.append(OpenaiConfig(**current_params)) - elif "server_configs" not in raw_yaml_data: - logger.warning( - "No 'server_configs' key found in YAML, creating default Blackjack server config." - ) - server_confs = [ - OpenaiConfig( - model_name=os.getenv( - "OPENAI_MODEL", - "NousResearch/DeepHermes-3-Llama-3-8B-Preview", - ), - base_url=os.getenv( - "OPENAI_API_BASE", "http://localhost:9004/v1" - ), - api_key=os.getenv("OPENAI_API_KEY", "x"), - ) - ] - - return env_conf, server_confs - - except Exception as e: - cfg_path_for_log = cfg_path if "cfg_path" in locals() else "unknown path" - logger.exception( - f"Error loading or parsing config from {cfg_path_for_log}: {e}" - ) - logger.warning( - "Falling back to default Blackjack configurations due to error." - ) - return BlackjackEnvConfig(), [ - OpenaiConfig( - model_name=os.getenv( - "OPENAI_MODEL", "NousResearch/DeepHermes-3-Llama-3-8B-Preview" - ), - base_url=os.getenv("OPENAI_API_BASE", "http://localhost:9004/v1"), - api_key=os.getenv("OPENAI_API_KEY", "x"), - num_requests_for_eval=1, - ) - ] + ] + return env_config, server_configs @classmethod def cli(cls):
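
Note: after this change, config_init() on both Blackjack environments no longer reads
YAML files or environment variables; it returns hardcoded settings. A minimal sketch of
what a caller now gets (illustrative only; values are taken from the diff above):

    env_config, server_configs = BlackjackEnv.config_init()
    env_config.env_name          # "Blackjack-v1"
    env_config.wandb_name        # "fundamental_metric_prediction" (copied from fundamental_prediction)
    server_configs[0].base_url   # "http://localhost:9004/v1"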