# In your meteorology_forecast_env.py file:
import asyncio
import time
import json
import os
import random
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, Any
import logging
import traceback  # Import traceback for more detailed error logging

import wandb
from pydantic import Field
import httpx

# Assuming APIServer and ServerManager are imported correctly from atroposlib.
# For this standalone example, define dummy classes if the package is not available.
try:
    from atroposlib.envs.base import (
        APIServerConfig,
        BaseEnv,
        BaseEnvConfig,
        EvalHandlingEnum,
        Item,
        ScoredDataGroup,
        APIServer,
    )
    from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
except ImportError:
    # Dummy classes (keep these if you were using them for standalone testing)
    class BaseEnvConfig:
        pass

    class APIServerConfig:
        def __init__(self, model_name, base_url, api_key, timeout=1200,
                     num_max_requests_at_once=512, num_requests_for_eval=64,
                     rolling_buffer_length=1000, server_type='openai',
                     n_kwarg_is_ignored=False, health_check=True):
            self.model_name = model_name
            self.base_url = base_url
            self.api_key = api_key
            self.timeout = timeout
            # etc.

    class BaseEnv:
        def __init__(self, config, server_configs, slurm, testing):
            self.config = config
            self.server = type('ServerManager', (), {'servers': [APIServer(sc) for sc in server_configs]})()
            self.tokenizer = type('Tokenizer', (), {'apply_chat_template': lambda *args, **kwargs: ""})()
            self.testing = testing

        def save_checkpoint(self, step, data):
            pass

        async def wandb_log(self, metrics):
            pass

        @classmethod
        def cli(cls):  # Add dummy cli
            print("Dummy CLI called")

    class Item:
        pass

    class ScoredDataGroup(dict):
        pass

    class EvalHandlingEnum:
        LIMIT_TRAIN = "limit_train"
        STOP_TRAIN = "stop_train"  # Added STOP_TRAIN

    class APIServer:
        def __init__(self, config: APIServerConfig):
            self.config = config  # Ensure config is taken

        async def chat_completion(self, **kwargs):
            print(f"Dummy APIServer chat_completion called with model: {kwargs.get('model', self.config.model_name)}")

            # Simulate a response structure
            class DummyMessage:
                def __init__(self, content):
                    self.content = content

            class DummyChoice:
                def __init__(self, content):
                    self.message = DummyMessage(content)

            class DummyCompletionResponse:
                def __init__(self, choices_content_list):
                    self.choices = [DummyChoice(c) for c in choices_content_list]

            if self.config.model_name.startswith("google"):  # Simulate judge
                return DummyCompletionResponse([
                    "REASONING_SCORE: 3\nTOOL_CALL_SCORE: 1\nFORECAST_SUMMARY_SCORE: 1\n"
                    "TOTAL_SCORE: 5\nJUSTIFICATION: Dummy judge output"
                ])
            else:  # Simulate agent
                return DummyCompletionResponse(["Dummy agent thinking\nFORECAST_SUMMARY: Dummy forecast"])

        async def completion(self, **kwargs):
            return type('CompletionResponse', (), {'choices': []})()

    def tokenize_for_trainer(tokenizer, messages, max_length):
        return {"tokens": [1, 2, 3], "masks": [1, 1, 1]}  # Ensure it returns non-empty

    print("Warning: atroposlib not fully found, using dummy classes for some components.")


# --- Setup Module-Level Logger (consistent with BaseEnv) ---
logger = logging.getLogger(__name__)
# Ensure basicConfig is called if not configured elsewhere, e.g., by atroposlib.
if not logging.getLogger().hasHandlers():  # Check if root logger has handlers
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')


# --- Configuration (MetRLConfig) remains the same ---
class MetRLConfig(BaseEnvConfig):
    tokenizer_name: str = Field(default="Qwen/Qwen3-8B")
    group_size: int = Field(default=2)
    use_wandb: bool = Field(default=True)
    max_num_workers: int = Field(default=64)
    rollout_server_url: str = Field(default="http://localhost:8000")
    total_steps: int = Field(default=2000)
    batch_size: int = Field(default=-1)  # Default
    steps_per_eval: int = Field(default=100)  # Default
    max_token_length: int = Field(default=2048)  # Default
    inference_weight: float = Field(default=1.0)
    wandb_name: Optional[str] = Field(default=None)
    data_path_to_save_groups: Optional[str] = Field(default='data/MeteorologyForecastRL.jsonl')  # Default
    eval_handling: EvalHandlingEnum = Field(default=EvalHandlingEnum.STOP_TRAIN)  # Default
    eval_limit_ratio: float = Field(default=0.5)  # Default
    num_eval_samples: int = Field(default=20)
    num_rollouts_to_log: int = Field(default=10)
    min_items_sent_before_logging: int = Field(default=2)  # Default
    include_messages: bool = Field(default=True)  # Default
    num_rollouts_to_keep: int = Field(default=32)  # Default
    num_rollouts_per_group_for_logging: int = Field(default=1)  # Default
    ensure_scores_are_not_same: bool = Field(default=False)  # Default
    max_eval_workers: int = Field(default=16)  # Default
    max_num_workers_per_node: int = Field(default=8)  # Default
    max_batches_offpolicy: int = Field(default=3)  # Default
    sounding_data_root: str = Field(
        default="/Users/dev/hackathon/atropos/environments/hack0/data/",
        description="Root directory for all sounding and AFD data."
    )
    target_date: str = Field(
        default="20250314",
        description="The specific date to load data for (YYYYMMDD format)."
    )
    judge_model_name: str = Field(
        default="google/gemini-2.5-flash-preview",
        description="Identifier for the Judge model on OpenRouter."
    )
    judge_api_key_env_var: str = Field(
        default="OPENROUTER_API_KEY",
        description="Environment variable name for OpenRouter API key for the Judge."
    )
    judge_base_url: str = Field(
        default="https://openrouter.ai/api/v1",
        description="Base URL for the OpenRouter API (for Judge)."
    )
    nwp_models_to_use: List[str] = Field(
        default=["RAP"],
        description="List of NWP models to use (e.g., RAP, HRRR)."
    )
    forecast_hours_to_sample: List[int] = Field(
        default=[6, 9, 12, 15, 18],
        description="Which forecast hours (UTC) from the model run to provide to the LLM."
    )
    target_forecast_hour_offset: int = Field(
        default=1,
        description="Offset from the latest provided sounding hour to set the target forecast time."
    )
    max_afds_for_judge: int = Field(
        default=3,
        description="Maximum number of AFD files to provide to the judge model."
    )
    max_reasoning_tokens_llm: int = Field(
        default=3000,
        description="Max tokens for the agent LLM's generation."
    )
    max_tokens_judge: int = Field(
        default=2000,
        description="Max tokens for the judge model's generation."
    )


# --- Prompts (AGENT_SYSTEM_PROMPT, AGENT_USER_PROMPT_TEMPLATE, etc.) remain the same ---
AGENT_SYSTEM_PROMPT = """You are a highly skilled AI meteorologist. Your task is to analyze numerical weather prediction (NWP) model sounding data for a specific location and time period.
Based on your analysis, you must:
1. Provide a detailed step-by-step reasoning process. This should include identifying trends, interpreting meteorological parameters, and connecting them to potential weather phenomena.
2. If you determine that additional real-time observational data is crucial for a more accurate assessment, specify the tools you would use.
   For each tool, output a line in the exact format:
   TOOL_CALL: {"tool_name": "tool_name_here", "arguments": {"param1": "value1", ...}}
   Available conceptual tools: get_surface_observations, get_latest_radar_imagery, get_satellite_imagery, get_upper_air_sounding.
3. Conclude with a concise forecast summary for the specified target time. Start this summary with "FORECAST_SUMMARY: ".

Analyze the provided data thoroughly. Your reasoning should be comprehensive."""

AGENT_USER_PROMPT_TEMPLATE = """Please analyze the following NWP model sounding data for station {location_id}.
The soundings provided are from the {model_name} model, run on {run_date_full_z}, valid at the following UTC times: {sounding_times_str}.
Your goal is to make a preliminary forecast assessment focusing on severe weather potential for {location_id} around {target_forecast_time_utc}.

Sounding Data:
{soundings_json_blob}

Remember to include your reasoning, any TOOL_CALL: {{"tool_name": "tool_name_here", "arguments": {{"param1": "value1", ...}}}} lines, and a final FORECAST_SUMMARY: statement."""

JUDGE_SYSTEM_PROMPT = """You are an expert meteorologist acting as a judge. You will evaluate an AI assistant's analysis of model sounding data.
The AI was asked to provide reasoning, call tools if necessary, and give a forecast summary.
You will be given the AI's output and relevant Area Forecast Discussions (AFDs) from human forecasters for context.

Your evaluation should focus on:
1. **Meteorological Soundness of Reasoning (0-5 points):**
   * Correct interpretation of sounding parameters and trends.
   * Logical connections between data and potential weather.
   * Avoidance of meteorological fallacies or hallucinations.
   * Depth and detail of the thought process.
2. **Tool Call Relevance & Justification (0-3 points):**
   * Were the tools called (if any) appropriate given the AI's reasoning and the model data?
   * Would these tools genuinely help a meteorologist in this situation?
   * Were critical tool calls missed?
3. **Forecast Summary Quality (0-2 points):**
   * Clarity and conciseness.
   * Alignment with the AI's own reasoning and the provided AFDs (or sensible deviation if model data strongly suggested it).

Provide a numerical score for each category and a total score (sum of the three, max 10.0).
Also, provide a brief overall justification for your scores.

Your output MUST be in the following exact format:
REASONING_SCORE: {0-5 score}
TOOL_CALL_SCORE: {0-3 score}
FORECAST_SUMMARY_SCORE: {0-2 score}
TOTAL_SCORE: {sum of scores, e.g., 7.5}
JUSTIFICATION: {Your brief textual justification here.}"""

JUDGE_USER_PROMPT_TEMPLATE = """AI Assistant's Output:
---
{llm_full_output}
---

Contextual Area Forecast Discussions (AFDs):
---
{afds_blob}
---

Please evaluate the AI assistant's output based on the criteria and provide your scores and justification in the specified format."""


class MeteorologyForecastRLEnv(BaseEnv):
    env_config_cls = MetRLConfig
    name = "MeteorologyForecastRL"

    def __init__(
        self,
        config: MetRLConfig,
        server_configs: List[APIServerConfig],
        slurm=True,      # Default based on your CLI help
        testing=False,   # Default based on your CLI help
    ):
        super().__init__(config, server_configs, slurm, testing)
        # self.config: MetRLConfig = self.config  # Redundant if super().__init__ sets self.config.
        # Ensure self.config is correctly typed if BaseEnv makes it generic.
        if not isinstance(self.config, MetRLConfig):  # Check type
            logger.warning(
                f"self.config in __init__ is not MetRLConfig, type: {type(self.config)}. "
                "This might indicate an issue with BaseEnv or pydantic_cli setup."
            )

        self.locations_data: List[Dict[str, Any]] = []
        self.agent_llm_server: Optional[APIServer] = None
        self.judge_server: Optional[APIServer] = None

        if not hasattr(self.server, 'servers') or not self.server.servers:  # Added check for empty servers list
            logger.error("CRITICAL: ServerManager (self.server) does not have a 'servers' attribute or it's empty!")
        elif len(self.server.servers) > 1:
            self.agent_llm_server = self.server.servers[0]
            self.judge_server = self.server.servers[1]
            if self.agent_llm_server and hasattr(self.agent_llm_server, 'config') and self.agent_llm_server.config:  # check config exists
                logger.info(f"Agent server: {self.agent_llm_server.config.model_name} @ {self.agent_llm_server.config.base_url}")
            else:
                logger.warning("Agent LLM server or its config is not properly initialized for logging.")
            if self.judge_server and hasattr(self.judge_server, 'config') and self.judge_server.config:  # check config exists
                logger.info(f"Judge server: {self.judge_server.config.model_name} @ {self.judge_server.config.base_url}")
            else:
                logger.warning("Judge server or its config is not properly initialized for logging.")
        elif len(self.server.servers) == 1:
            logger.warning(
                "Only 1 API server configured in ServerManager. Agent and Judge will use the same server."
            )
            self.agent_llm_server = self.server.servers[0]
            self.judge_server = self.server.servers[0]
            if self.agent_llm_server and hasattr(self.agent_llm_server, 'config') and self.agent_llm_server.config:  # check config exists
                logger.info(f"Agent/Judge server: {self.agent_llm_server.config.model_name} @ {self.agent_llm_server.config.base_url}")
            else:
                logger.warning("Agent/Judge server or its config is not properly initialized for logging.")

        self.current_item_index: int = 0
        self.iter: int = 0
        self.judge_total_scores_buffer: List[float] = []
        self.judge_reasoning_scores_buffer: List[float] = []
        self.judge_tool_scores_buffer: List[float] = []
        self.judge_forecast_scores_buffer: List[float] = []
        self.rollouts_for_wandb_custom: List[Tuple[str, str, str, str, float, str]] = []
        self.eval_metrics_buffer: List[Dict[str, float]] = []

    @classmethod
    def config_init(cls) -> Tuple[MetRLConfig, List[APIServerConfig]]:
        # This method is usually called by pydantic_cli; the CLI arguments will override these defaults.
        env_config = MetRLConfig()  # Initialize with defaults, CLI will override.

        # Get API keys and base URLs from the environment or use defaults from MetRLConfig.
        # Agent server config (server_configs[0]):
        # These might be overridden by CLI --openai.model_name, --openai.base_url etc. for the *first* server.
        # However, atroposlib's pydantic_cli might handle multiple server configs differently.
        # The provided help suggests a single --openai.* block, which implies it might apply to all servers
        # or only the first. We'll assume here it populates the first server, and the second uses MetRLConfig defaults.
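        # Example shell setup for local testing (values are illustrative, not defaults this code enforces;
        # the environment-variable names below are exactly the ones read a few lines down):
        #   export AGENT_LLM_MODEL_NAME="Qwen/Qwen3-8B"
        #   export AGENT_LLM_BASE_URL="http://localhost:8080/v1"   # e.g. a local vLLM OpenAI-compatible server
        #   export AGENT_LLM_API_KEY="EMPTY_KEY_IF_LOCAL_VLLM"
        #   export OPENROUTER_API_KEY="sk-or-..."                   # judge key; name comes from judge_api_key_env_var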
        agent_model_name = os.environ.get("AGENT_LLM_MODEL_NAME", env_config.tokenizer_name)  # Prioritize env var
        agent_api_key = os.environ.get("AGENT_LLM_API_KEY", "EMPTY_KEY_IF_LOCAL_VLLM")
        agent_base_url = os.environ.get("AGENT_LLM_BASE_URL", "http://localhost:8080/v1")  # Example vLLM

        judge_api_key = os.environ.get(env_config.judge_api_key_env_var)
        if not judge_api_key:
            # This warning is fine as it's at class method execution, not instance init;
            # logger is not available at class level directly here, so use logging directly.
            logging.warning(f"Environment variable {env_config.judge_api_key_env_var} not set for Judge API.")

        server_configs = [
            APIServerConfig(
                model_name=agent_model_name,  # This should be the agent model
                base_url=agent_base_url,
                api_key=agent_api_key,
                # num_requests_for_eval=64,  # these are often part of APIServer, not its config
            ),
            APIServerConfig(
                model_name=env_config.judge_model_name,
                base_url=env_config.judge_base_url,
                api_key=judge_api_key,
                # num_requests_for_eval=64,
            ),
        ]
        # logger.info(f"config_init: env_config={env_config}")  # For debugging
        # logger.info(f"config_init: server_configs={server_configs}")  # For debugging
        return env_config, server_configs

    # --- setup, get_next_item, _parse_llm_output, _parse_judge_output remain the same ---
    async def setup(self):
        logger.info(f"Setting up {self.name or self.__class__.__name__}...")
        data_root = Path(self.config.sounding_data_root)
        date_path = data_root / self.config.target_date
        if not date_path.is_dir():
            logger.error(f"Target date directory not found: {date_path}")
            return

        available_locations = [loc.name for loc in date_path.iterdir() if loc.is_dir()]
        logger.info(f"Found {len(available_locations)} locations for date {self.config.target_date}: {available_locations}")

        for loc_id in available_locations:
            loc_path = date_path / loc_id
            soundings_for_item = []
            sounding_times_for_item = []

            if not self.config.nwp_models_to_use or not self.config.forecast_hours_to_sample:
                logger.warning(f"NWP models or forecast hours to sample is empty in config. Skipping {loc_id}")
                continue
            selected_model = self.config.nwp_models_to_use[0]

            for hour_z in self.config.forecast_hours_to_sample:
                fname = f"{loc_id}_{selected_model}_{self.config.target_date}{hour_z:02d}Z.buf_default_llm_optimized.jsonl"
                sounding_file_path = loc_path / fname
                if sounding_file_path.exists():
                    try:
                        with open(sounding_file_path, 'r') as f:
                            line = f.readline()
                            if line:
                                soundings_for_item.append(json.loads(line))
                                sounding_times_for_item.append(f"{hour_z:02d}00Z")
                    except Exception as e:
                        logger.warning(f"Could not load or parse sounding file {sounding_file_path}: {e}")
                else:
                    logger.debug(f"Sounding file not found: {sounding_file_path}")

            if not soundings_for_item:
                logger.debug(f"No valid soundings found for {loc_id} on {self.config.target_date}. Skipping.")
                continue

            afd_files = sorted([f for f in loc_path.glob("AFD_*.txt")])
            selected_afd_texts = []
            if afd_files:
                if len(afd_files) <= self.config.max_afds_for_judge:
                    indices_to_take = list(range(len(afd_files)))
                else:
                    indices_to_take = sorted(list(set([0, len(afd_files) // 2, len(afd_files) - 1])))
                    indices_to_take = indices_to_take[:self.config.max_afds_for_judge]
                for i in indices_to_take:
                    try:
                        with open(afd_files[i], 'r', encoding='utf-8', errors='replace') as f:  # Specify encoding and error handling
                            afd_text = f.read()
                        # Remove common control characters, especially ETX (\x03 or \u0003)
                        cleaned_afd_text = ''.join(c for c in afd_text if c.isprintable() or c.isspace())
                        # Or more specifically for \u0003:
                        # cleaned_afd_text = afd_text.replace('\u0003', '')
                        selected_afd_texts.append(cleaned_afd_text)
                    except Exception as e:
                        logger.warning(f"Could not read or clean AFD file {afd_files[i]}: {e}")

            if not sounding_times_for_item:
                logger.warning(f"No sounding times available for {loc_id}, cannot determine target forecast time. Skipping.")
                continue
            latest_sounding_hour_str = sounding_times_for_item[-1][:2]
            if not latest_sounding_hour_str.isdigit():
                logger.warning(f"Could not parse latest sounding hour from {sounding_times_for_item[-1]} for {loc_id}. Skipping.")
                continue
            latest_sounding_hour = int(latest_sounding_hour_str)
            target_hour = latest_sounding_hour + self.config.target_forecast_hour_offset
            target_forecast_time_utc = f"{target_hour:02d}00Z on {self.config.target_date[4:6]}/{self.config.target_date[6:8]}/{self.config.target_date[0:4]}"

            run_time_str = "UnknownRunTime"
            if soundings_for_item and 'tm' in soundings_for_item[0] and '/' in soundings_for_item[0]['tm']:
                try:
                    run_time_str = soundings_for_item[0]['tm'].split('/')[1][:2] + "Z"
                except IndexError:
                    logger.warning(f"Could not parse run time from 'tm' field: {soundings_for_item[0]['tm']} for {loc_id}")
            run_date_full_z = f"{self.config.target_date} at {run_time_str}"

            item_data = {
                "case_id": f"{self.config.target_date}_{loc_id}",
                "location_id": loc_id,
                "model_name": selected_model,
                "run_date_full_z": run_date_full_z,
                "target_forecast_time_utc": target_forecast_time_utc,
                "model_soundings_data": soundings_for_item,
                "sounding_times_str": ", ".join(sounding_times_for_item),
                "afd_texts": selected_afd_texts,
            }
            self.locations_data.append(item_data)

        if not self.locations_data:
            logger.error("No data loaded. Environment cannot proceed.")
        else:
            logger.info(f"Successfully prepared {len(self.locations_data)} items for processing.")
        if self.locations_data:  # Ensure not empty before shuffling
            random.shuffle(self.locations_data)
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        if data is None:
            data = {}
        data["current_item_index"] = self.current_item_index
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def get_next_item(self) -> Optional[Dict[str, Any]]:
        if not self.locations_data:
            logger.warning("No locations data available in get_next_item.")
            return None
        if self.current_item_index >= len(self.locations_data):
            logger.info("Cycled through all available location data. "
                        "Re-shuffling and resetting index.")
            random.shuffle(self.locations_data)
            self.current_item_index = 0
        if not self.locations_data:
            return None
        item_to_return = self.locations_data[self.current_item_index]
        self.current_item_index += 1
        self.iter += 1
        return item_to_return

    def _parse_llm_output(self, llm_text: str) -> Dict[str, Any]:
        think_content = ""
        tool_calls = []
        forecast_summary = ""

        # Extract the model's <think>...</think> block, if present.
        think_match = re.search(r"<think>(.*?)</think>", llm_text, re.DOTALL | re.IGNORECASE)
        if think_match:
            think_content = think_match.group(1).strip()

        for line in llm_text.splitlines():
            line_upper = line.strip().upper()
            if line_upper.startswith("TOOL_CALL:"):
                try:
                    tool_json_str = line.strip()[len("TOOL_CALL:"):].strip()
                    if tool_json_str.startswith("{") and tool_json_str.endswith("}"):
                        tool_calls.append(json.loads(tool_json_str))
                    else:
                        logger.warning(f"Skipping malformed TOOL_CALL: {line}")
                except json.JSONDecodeError:
                    logger.warning(f"Could not parse TOOL_CALL JSON: {line}")
            elif line_upper.startswith("FORECAST_SUMMARY:"):
                forecast_summary = line.strip()[len("FORECAST_SUMMARY:"):].strip()

        return {
            "think_content": think_content,
            "tool_calls": tool_calls,
            "forecast_summary": forecast_summary,
        }

    def _parse_judge_output(self, judge_text: str) -> Dict[str, Any]:
        scores = {
            "reasoning": 0.0,
            "tool_call": 0.0,
            "forecast_summary": 0.0,
            "total": 0.0,
        }
        justification = "No justification provided or parse error."
        patterns = {
            "reasoning": r"REASONING_SCORE:\s*([0-9.]+)",
            "tool_call": r"TOOL_CALL_SCORE:\s*([0-9.]+)",
            "forecast_summary": r"FORECAST_SUMMARY_SCORE:\s*([0-9.]+)",
            "total": r"TOTAL_SCORE:\s*([0-9.]+)",
        }
        for key, pattern in patterns.items():
            match = re.search(pattern, judge_text, re.IGNORECASE)
            if match:
                try:
                    scores[key] = float(match.group(1))
                except ValueError:
                    logger.warning(f"Could not parse score for {key} from: {match.group(1)}")

        just_match = re.search(r"JUSTIFICATION:\s*(.*)", judge_text, re.DOTALL | re.IGNORECASE)
        if just_match:
            justification = just_match.group(1).strip()

        calculated_total = round(scores["reasoning"] + scores["tool_call"] + scores["forecast_summary"], 2)
        if abs(scores["total"] - calculated_total) > 0.1:
            if scores["total"] != 0.0:
                logger.warning(f"Parsed total score {scores['total']} differs from sum of components {calculated_total}. Using sum.")
            scores["total"] = calculated_total

        return {"scores": scores, "justification": justification}

    async def collect_trajectories(self, item: Dict[str, Any]) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        if item is None:
            logger.warning("Received None item in collect_trajectories.")
            return None, []
        case_id = item.get('case_id', 'Unknown')
        logger.info(f"ITEM {case_id}: Starting collect_trajectories.")

        soundings_blob = json.dumps(item.get("model_soundings_data", []), indent=2)
        agent_user_prompt = AGENT_USER_PROMPT_TEMPLATE.format(
            location_id=item.get("location_id", "N/A"),
            model_name=item.get("model_name", "N/A"),
            run_date_full_z=item.get("run_date_full_z", "N/A"),
            sounding_times_str=item.get("sounding_times_str", "N/A"),
            target_forecast_time_utc=item.get("target_forecast_time_utc", "N/A"),
            soundings_json_blob=soundings_blob
        )
        agent_messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT},
                          {"role": "user", "content": agent_user_prompt}]

        if not self.agent_llm_server:
            logger.error(f"ITEM {case_id}: Agent LLM server not available.")
            return None, []

        logger.info(f"ITEM {case_id}: About to call Agent LLM...")
        agent_chat_completions_obj = None  # Initialize
        try:
            start_time = time.time()
            agent_chat_completions_obj = await self.agent_llm_server.chat_completion(
                messages=agent_messages,
                model=self.agent_llm_server.config.model_name if self.agent_llm_server.config else "unknown_agent_model",
                n=self.config.group_size,
                max_tokens=self.config.max_reasoning_tokens_llm,
                temperature=0.7,
                stop=["<|im_end|>", "<|endoftext|>", "<|eot_id|>"]
            )
            choices_received = len(agent_chat_completions_obj.choices) if agent_chat_completions_obj and hasattr(agent_chat_completions_obj, 'choices') else 0
            logger.info(f"ITEM {case_id}: Agent LLM call completed. Time: {time.time() - start_time:.2f}s. Choices received: {choices_received}")
        except Exception as e:
            logger.error(f"ITEM {case_id}: Agent LLM call failed: {e}")
            logger.error(traceback.format_exc())
            return None, []

        if not agent_chat_completions_obj or not hasattr(agent_chat_completions_obj, 'choices') or not agent_chat_completions_obj.choices:
            logger.warning(f"ITEM {case_id}: No choices received from Agent LLM or malformed response. Response: {agent_chat_completions_obj}")
            return None, []

        scored_data_group = ScoredDataGroup(tokens=[], masks=[], scores=[], overrides=[])
        afd_context_blob = "\n\n---\n\n".join(item.get("afd_texts", [])) if item.get("afd_texts") else "No AFDs provided for this case."

        for agent_choice_idx, agent_choice in enumerate(agent_chat_completions_obj.choices):
            choice_id = f"{case_id}_choice_{agent_choice_idx}"
            logger.info(f"ITEM {choice_id}: Processing agent choice.")
            llm_full_output_text = ""  # Initialize
            if agent_choice and hasattr(agent_choice, 'message') and agent_choice.message and hasattr(agent_choice.message, 'content'):
                llm_full_output_text = agent_choice.message.content
                logger.debug(f"ITEM {choice_id}: Agent output received (first 200 chars): {llm_full_output_text[:200]}...")
            else:
                logger.warning(
                    f"ITEM {choice_id}: Agent choice did not contain expected message content. Skipping this choice. "
                    f"Details: {agent_choice}"
                )
                continue

            parsed_llm_out = self._parse_llm_output(llm_full_output_text)
            logger.info(f"ITEM {choice_id}: Parsed agent output.")

            judge_user_prompt = JUDGE_USER_PROMPT_TEMPLATE.format(llm_full_output=llm_full_output_text, afds_blob=afd_context_blob)
            judge_messages = [{"role": "system", "content": JUDGE_SYSTEM_PROMPT},
                              {"role": "user", "content": judge_user_prompt}]

            final_score = 0.0
            judge_justification_text = "Judge call not made or failed."
            judge_parsed_scores = {}

            if not self.judge_server:
                logger.error(f"ITEM {choice_id}: Judge server not available. Assigning default score.")
            else:
                logger.info(f"ITEM {choice_id}: About to call Judge LLM...")
                # Log the request payload for the judge
                request_payload_for_judge = {
                    "messages": judge_messages,
                    "model": self.judge_server.config.model_name if self.judge_server.config else "unknown_judge_model",
                    "max_tokens": self.config.max_tokens_judge,
                    "temperature": 0.2,
                    "n": 1,
                    # Add any other parameters being sent by APIServer implicitly or explicitly
                }
                logger.info(f"ITEM {choice_id}: Judge LLM Request Payload: {json.dumps(request_payload_for_judge, indent=2, ensure_ascii=False)}")
                try:
                    judge_start_time = time.time()
                    judge_completion_obj = await self.judge_server.chat_completion(
                        messages=judge_messages,
                        max_tokens=self.config.max_tokens_judge,
                        temperature=0.2,
                        n=1,
                        model=self.judge_server.config.model_name if self.judge_server.config else "unknown_judge_model"
                    )
                    logger.info(f"ITEM {choice_id}: Judge LLM call completed. Time: {time.time() - judge_start_time:.2f}s")

                    if judge_completion_obj and hasattr(judge_completion_obj, 'choices') and judge_completion_obj.choices and \
                            judge_completion_obj.choices[0] and hasattr(judge_completion_obj.choices[0], 'message') and \
                            judge_completion_obj.choices[0].message and hasattr(judge_completion_obj.choices[0].message, 'content'):
                        judge_output_text = judge_completion_obj.choices[0].message.content
                        logger.info(f"ITEM {choice_id}: Judge output received (first 200 chars): {judge_output_text[:200]}...")
                        parsed_judge_out = self._parse_judge_output(judge_output_text)
                        final_score = parsed_judge_out["scores"]["total"]
                        judge_justification_text = parsed_judge_out["justification"]
                        judge_parsed_scores = parsed_judge_out["scores"]
                        logger.info(f"ITEM {choice_id}: Parsed judge output. Score: {final_score}")
                        self.judge_total_scores_buffer.append(final_score)
                        self.judge_reasoning_scores_buffer.append(judge_parsed_scores.get("reasoning", 0.0))
                        self.judge_tool_scores_buffer.append(judge_parsed_scores.get("tool_call", 0.0))
                        self.judge_forecast_scores_buffer.append(judge_parsed_scores.get("forecast_summary", 0.0))
                    else:
                        logger.error(f"ITEM {choice_id}: Judge LLM response was empty or malformed. Details: {judge_completion_obj}")
                # ***** START OF MODIFIED/ADDED ERROR HANDLING BLOCK *****
                except httpx.HTTPStatusError as http_err:
                    logger.error(f"ITEM {choice_id}: Judge LLM call failed with HTTPStatusError: {http_err.response.status_code}")
                    logger.error(f"Request URL: {http_err.request.url}")
                    # Log request body (already logged above, but useful for context here)
                    logger.error(f"Request Body (for context):\n{json.dumps(request_payload_for_judge, indent=2, ensure_ascii=False)}")
                    logger.error(f"Response Headers:\n{http_err.response.headers}")
                    try:
                        # Attempt to read the response body text.
                        # For async responses that were never read, this might require await response.aread(),
                        # but .text should be available if the response was processed by httpx.
                        response_body_str = http_err.response.text
                        logger.error(f"Response Body:\n{response_body_str}")
                    except Exception as resp_e:
                        logger.error(f"Could not decode or access response body text from HTTPStatusError: {resp_e}")
                    logger.error(traceback.format_exc())
                # ***** END OF MODIFIED/ADDED ERROR HANDLING BLOCK *****
                except Exception as e:  # General catch-all
                    logger.error(f"ITEM {choice_id}: Judge LLM call failed with general Exception: {e}")
                    logger.error(traceback.format_exc())

            logger.info(f"ITEM {choice_id}: About to tokenize full trajectory...")
            full_trajectory_messages = agent_messages + [{"role": "assistant", "content": llm_full_output_text}]

            # Ensure self.tokenizer and self.config.max_token_length are available
            tokenized_output = None
            if hasattr(self, 'tokenizer') and self.tokenizer and hasattr(self.config, 'max_token_length'):
                tokenized_output = tokenize_for_trainer(self.tokenizer, full_trajectory_messages, self.config.max_token_length)
                logger.info(f"ITEM {choice_id}: Tokenization complete. Tokens length: {len(tokenized_output.get('tokens', [])) if tokenized_output else 'N/A'}")
            else:
                logger.error(f"ITEM {choice_id}: Tokenizer or max_token_length not available. Skipping tokenization.")

            if tokenized_output and tokenized_output.get("tokens") and len(tokenized_output["tokens"]) > 0:
                # Ensure scored_data_group is a dict-like object that supports append or assignment
                if not isinstance(scored_data_group, dict) and not hasattr(scored_data_group, 'append'):
                    logger.error(f"ITEM {choice_id}: scored_data_group is not a dict or list-like object. Type: {type(scored_data_group)}")
                else:
                    if isinstance(scored_data_group.get("tokens"), list):
                        scored_data_group["tokens"].append(tokenized_output["tokens"])
                    if isinstance(scored_data_group.get("masks"), list):
                        scored_data_group["masks"].append(tokenized_output["masks"])
                    if isinstance(scored_data_group.get("scores"), list):
                        scored_data_group["scores"].append(final_score)
                    item_overrides = {
                        "case_id": case_id,
                        "llm_think": parsed_llm_out["think_content"],
                        "llm_tools": str(parsed_llm_out["tool_calls"]),
                        "llm_summary": parsed_llm_out["forecast_summary"],
                        "judge_justification": judge_justification_text,
                        "judge_score_reasoning": judge_parsed_scores.get("reasoning", 0.0),
                        "judge_score_tool": judge_parsed_scores.get("tool_call", 0.0),
                        "judge_score_forecast": judge_parsed_scores.get("forecast_summary", 0.0),
                    }
                    if isinstance(scored_data_group.get("overrides"), list):
                        scored_data_group["overrides"].append(item_overrides)
                    if hasattr(self, 'rollouts_for_wandb_custom') and isinstance(self.rollouts_for_wandb_custom, list):
                        self.rollouts_for_wandb_custom.append((
                            agent_user_prompt[:300] + "...",
                            parsed_llm_out["think_content"][:500] + "...",
                            str(parsed_llm_out["tool_calls"])[:300] + "...",
                            parsed_llm_out["forecast_summary"][:300] + "...",
                            final_score,
                            judge_justification_text[:500] + "..."
                        ))
                    logger.info(f"ITEM {choice_id}: Added to scored_data_group and rollouts_for_wandb_custom.")
            else:
                logger.warning(f"ITEM {choice_id}: Tokenization failed or produced no tokens. Agent output was: {llm_full_output_text[:200]}...")

        if not scored_data_group.get("tokens"):
            logger.warning(f"ITEM {case_id}: No valid trajectories collected for this item.")
            return None, []

        logger.info(f"ITEM {case_id}: Finished processing all choices. "
                    f"Returning scored data group with {len(scored_data_group['tokens'])} trajectories.")
        return scored_data_group, []
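    # A minimal sketch of the ScoredDataGroup returned by collect_trajectories (values are illustrative;
    # the real token/mask lists come from tokenize_for_trainer and the scores from the judge):
    #   {
    #       "tokens":    [[...], [...]],          # one token-id list per agent choice
    #       "masks":     [[...], [...]],          # matching loss masks
    #       "scores":    [7.5, 4.0],              # judge TOTAL_SCORE per choice
    #       "overrides": [{"case_id": ..., "llm_think": ..., ...}, {...}],
    #   }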
    # --- evaluate and wandb_log remain the same or with similar logging detail if needed ---
    async def evaluate(self, *args, **kwargs):
        logger.info(f"Starting evaluation for {self.name or self.__class__.__name__}...")
        self.eval_metrics_buffer.clear()
        if not self.locations_data:
            logger.warning("No data available for evaluation.")
            return

        # Ensure config attributes exist before accessing
        num_eval_samples = getattr(self.config, "num_eval_samples", 20)
        eval_items_to_process = min(len(self.locations_data), num_eval_samples)
        if eval_items_to_process == 0:
            logger.warning("Not enough data or num_eval_samples is 0. Skipping evaluation.")
            return

        eval_list = []
        if self.locations_data:  # ensure locations_data is not empty
            eval_list = random.sample(self.locations_data, k=eval_items_to_process)

        for eval_idx, eval_item_data in enumerate(eval_list):
            case_id = eval_item_data.get('case_id', f"UnknownEval_{eval_idx}")
            logger.info(f"EVAL ITEM {case_id}: Starting evaluation.")
            soundings_blob = json.dumps(eval_item_data["model_soundings_data"], indent=2)
            agent_user_prompt = AGENT_USER_PROMPT_TEMPLATE.format(
                location_id=eval_item_data["location_id"],
                model_name=eval_item_data["model_name"],
                run_date_full_z=eval_item_data["run_date_full_z"],
                sounding_times_str=eval_item_data["sounding_times_str"],
                target_forecast_time_utc=eval_item_data["target_forecast_time_utc"],
                soundings_json_blob=soundings_blob)
            agent_messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT},
                              {"role": "user", "content": agent_user_prompt}]

            if not self.agent_llm_server:
                logger.error(f"EVAL ITEM {case_id}: Agent LLM server not available.")
                continue

            llm_full_output_text = ""  # Initialize
            logger.info(f"EVAL ITEM {case_id}: About to call Agent LLM for evaluation.")
            try:
                agent_start_time = time.time()
                agent_completion_obj = await self.agent_llm_server.chat_completion(
                    messages=agent_messages,
                    n=1,
                    max_tokens=self.config.max_reasoning_tokens_llm,
                    temperature=0.1,  # Low temp for eval
                    stop=["<|eot_id|>", "<|im_end|>", "<|endoftext|>"],
                    model=self.agent_llm_server.config.model_name if self.agent_llm_server.config else "unknown_agent_model"
                )
                logger.info(f"EVAL ITEM {case_id}: Agent LLM call completed. Time: {time.time() - agent_start_time:.2f}s")
                if agent_completion_obj.choices and agent_completion_obj.choices[0].message and agent_completion_obj.choices[0].message.content:
                    llm_full_output_text = agent_completion_obj.choices[0].message.content
                    logger.debug(f"EVAL ITEM {case_id}: Agent output received (first 200): {llm_full_output_text[:200]}")
                else:
                    logger.error(f"EVAL ITEM {case_id}: Agent LLM response was empty or malformed. Details: {agent_completion_obj}")
                    continue
            except Exception as e:
                logger.error(f"EVAL ITEM {case_id}: Agent LLM call failed: {e}")
                logger.error(traceback.format_exc())
                continue

            afd_context_blob = "\n\n---\n\n".join(eval_item_data["afd_texts"]) or "No AFDs."
            judge_user_prompt = JUDGE_USER_PROMPT_TEMPLATE.format(llm_full_output=llm_full_output_text, afds_blob=afd_context_blob)
            judge_messages = [{"role": "system", "content": JUDGE_SYSTEM_PROMPT},
                              {"role": "user", "content": judge_user_prompt}]

            if not self.judge_server:
                logger.warning(f"EVAL ITEM {case_id}: Judge server not available.")
                continue

            logger.info(f"EVAL ITEM {case_id}: About to call Judge LLM for evaluation.")
            try:
                judge_start_time = time.time()
                judge_completion_obj = await self.judge_server.chat_completion(
                    messages=judge_messages,
                    max_tokens=self.config.max_tokens_judge,
                    temperature=0.1,  # Low temp for eval
                    n=1,
                    model=self.judge_server.config.model_name if self.judge_server.config else "unknown_judge_model"
                )
                logger.info(f"EVAL ITEM {case_id}: Judge LLM call completed. Time: {time.time() - judge_start_time:.2f}s")
                if judge_completion_obj.choices and judge_completion_obj.choices[0].message and judge_completion_obj.choices[0].message.content:
                    judge_output_text = judge_completion_obj.choices[0].message.content
                    logger.debug(f"EVAL ITEM {case_id}: Judge output received (first 200): {judge_output_text[:200]}")
                    parsed_judge_out = self._parse_judge_output(judge_output_text)
                    self.eval_metrics_buffer.append(parsed_judge_out["scores"])
                    logger.info(f"EVAL ITEM {case_id}: Judge score {parsed_judge_out['scores']['total']} added to buffer.")
                else:
                    logger.error(f"EVAL ITEM {case_id}: Judge LLM response was empty or malformed. Details: {judge_completion_obj}")
            except httpx.HTTPStatusError as http_err:  # Added specific error handling for eval as well
                logger.error(f"EVAL ITEM {case_id}: Judge LLM call failed with HTTPStatusError: {http_err.response.status_code}")
                logger.error(f"Request URL: {http_err.request.url}")
                # Consider logging the request payload for eval too if needed for debugging eval failures
                logger.error(f"Response Headers:\n{http_err.response.headers}")
                try:
                    response_body_str = http_err.response.text
                    logger.error(f"Response Body:\n{response_body_str}")
                except Exception as resp_e:
                    logger.error(f"Could not decode or access response body text from HTTPStatusError during eval: {resp_e}")
                logger.error(traceback.format_exc())
            except Exception as e:
                logger.error(f"EVAL ITEM {case_id}: Judge LLM call failed: {e}")
                logger.error(traceback.format_exc())

        logger.info(f"Evaluation completed for {len(self.eval_metrics_buffer)} items.")

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if not self.config.use_wandb or not wandb.run:  # Check if wandb is active
            logger.debug("WandB logging skipped (disabled or no active run).")
            if hasattr(super(), 'wandb_log'):  # Call super if it exists, even if not logging locally
                await super().wandb_log(wandb_metrics if wandb_metrics else {})
            return

        if wandb_metrics is None:
            wandb_metrics = {}
        logger.info("Preparing metrics for WandB log...")

        if self.judge_total_scores_buffer:
            wandb_metrics["train/avg_judge_total_score"] = sum(self.judge_total_scores_buffer) / len(self.judge_total_scores_buffer)
            wandb_metrics["train/avg_judge_reasoning_score"] = sum(self.judge_reasoning_scores_buffer) / len(self.judge_reasoning_scores_buffer)
            wandb_metrics["train/avg_judge_tool_score"] = sum(self.judge_tool_scores_buffer) / len(self.judge_tool_scores_buffer)
            wandb_metrics["train/avg_judge_forecast_score"] = sum(self.judge_forecast_scores_buffer) / len(self.judge_forecast_scores_buffer)
            logger.info(f"Train scores (total): {wandb_metrics['train/avg_judge_total_score']:.2f} from {len(self.judge_total_scores_buffer)} samples.")
            self.judge_total_scores_buffer.clear()
            self.judge_reasoning_scores_buffer.clear()
            self.judge_tool_scores_buffer.clear()
            self.judge_forecast_scores_buffer.clear()

        if self.eval_metrics_buffer:
            avg_eval_total = sum(s['total'] for s in self.eval_metrics_buffer) / len(self.eval_metrics_buffer)
            avg_eval_reasoning = sum(s['reasoning'] for s in self.eval_metrics_buffer) / len(self.eval_metrics_buffer)
            avg_eval_tool = sum(s['tool_call'] for s in self.eval_metrics_buffer) / len(self.eval_metrics_buffer)
            avg_eval_forecast = sum(s['forecast_summary'] for s in self.eval_metrics_buffer) / len(self.eval_metrics_buffer)
            wandb_metrics["eval/avg_judge_total_score"] = avg_eval_total
            wandb_metrics["eval/avg_judge_reasoning_score"] = avg_eval_reasoning
            wandb_metrics["eval/avg_judge_tool_score"] = avg_eval_tool
            wandb_metrics["eval/avg_judge_forecast_score"] = avg_eval_forecast
            logger.info(f"Eval scores (total): {avg_eval_total:.2f} from {len(self.eval_metrics_buffer)} samples.")
            self.eval_metrics_buffer.clear()

        if self.rollouts_for_wandb_custom:
            if wandb.run:  # Double check wandb is active
                table = wandb.Table(columns=["Prompt Hint", "LLM Think", "LLM Tools", "LLM Summary", "Judge Score", "Judge Justification"])
                num_to_log = min(len(self.rollouts_for_wandb_custom), getattr(self.config, "num_rollouts_to_log", 10))  # Use getattr for safety
                sample_to_log = []
                if num_to_log > 0 and self.rollouts_for_wandb_custom:
                    sample_to_log = random.sample(self.rollouts_for_wandb_custom, k=min(num_to_log, len(self.rollouts_for_wandb_custom)))
                for P, T, O, S, Sc, J in sample_to_log:
                    table.add_data(P, T, O, S, Sc, J)
                if sample_to_log:
                    wandb_metrics["train/detailed_rollouts"] = table
                    logger.info(f"Logged {len(sample_to_log)} rollouts to WandB table.")
            self.rollouts_for_wandb_custom.clear()

        if wandb_metrics:
            logger.info(f"Logging to WandB: {list(wandb_metrics.keys())}")
            await super().wandb_log(wandb_metrics)
        else:
            logger.info("No new metrics to log to WandB in this step.")
            if hasattr(super(), 'wandb_log'):  # Call super even if no new metrics locally
                await super().wandb_log({})


if __name__ == "__main__":
    try:
        MeteorologyForecastRLEnv.cli()
    except Exception as e:
        logger.critical(f"CRITICAL Error during CLI execution: {e}")  # Use critical for top-level crash
        logger.critical(traceback.format_exc())
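# A minimal example of the judge reply that _parse_judge_output expects (scores are illustrative only):
#
#   REASONING_SCORE: 4
#   TOOL_CALL_SCORE: 2
#   FORECAST_SUMMARY_SCORE: 1.5
#   TOTAL_SCORE: 7.5
#   JUSTIFICATION: Sound trend analysis; the radar imagery request was well justified.
#
# Example local run (a sketch only -- the exact subcommands/flags come from the atroposlib CLI that
# MeteorologyForecastRLEnv.cli() wires up; the --openai.* names below are the ones referenced in
# config_init's comments, so check `--help` before relying on them):
#
#   python meteorology_forecast_env.py --help
#   python meteorology_forecast_env.py \
#       --openai.model_name Qwen/Qwen3-8B \
#       --openai.base_url http://localhost:8080/v1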