mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
858 lines
No EOL
48 KiB
Python
858 lines
No EOL
48 KiB
Python
# In your meteorology_forecast_env.py file:
|
|
|
|
import asyncio
|
|
import time
|
|
import json
|
|
import os
|
|
import random
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple, Union, Any
|
|
import logging
|
|
import traceback # Import traceback for more detailed error logging
|
|
|
|
import wandb
|
|
from pydantic import Field
|
|
import httpx
|
|
|
|
# Assuming APIServer and ServerManager are imported correctly from atroposlib
|
|
# For this standalone example, let's define dummy classes if not available
|
|
try:
    # Prefer the real atroposlib implementations when the package is installed.
    from atroposlib.envs.base import (
        APIServerConfig,
        BaseEnv,
        BaseEnvConfig,
        EvalHandlingEnum,
        Item,
        ScoredDataGroup,
        APIServer
    )
    from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
except ImportError:
    # Dummy classes (keep these if you were using them for standalone testing)
    # NOTE(review): these stand-ins only mimic the handful of attributes and
    # methods this file actually touches; they are not a faithful reproduction
    # of the atroposlib API.
    class BaseEnvConfig: pass  # minimal placeholder; MetRLConfig's Fields still come from pydantic
    class APIServerConfig:
        # Mirrors the real constructor signature; only the first four arguments
        # are actually stored on the instance in this dummy.
        def __init__(self, model_name, base_url, api_key, timeout=1200, num_max_requests_at_once=512, num_requests_for_eval=64, rolling_buffer_length=1000, server_type='openai', n_kwarg_is_ignored=False, health_check=True):
            self.model_name = model_name
            self.base_url = base_url
            self.api_key = api_key
            self.timeout = timeout # etc.
    class BaseEnv:
        # Bare-bones base env: stores config, builds a fake ServerManager with
        # one APIServer per server config, and a no-op tokenizer.
        def __init__(self, config, server_configs, slurm, testing):
            self.config = config
            # APIServer is resolved at call time, so it is fine that the dummy
            # APIServer class is defined further down in this except block.
            self.server = type('ServerManager', (), {'servers': [APIServer(sc) for sc in server_configs]})()
            self.tokenizer = type('Tokenizer', (), {'apply_chat_template': lambda *args, **kwargs: ""})()
            self.testing = testing
        def save_checkpoint(self, step, data): pass  # no-op; subclasses may extend
        async def wandb_log(self, metrics): pass  # no-op stand-in for wandb logging
        @classmethod
        def cli(cls): print("Dummy CLI called") # Add dummy cli
    class Item: pass
    class ScoredDataGroup(dict): pass  # behaves as a plain dict in the dummy path
    class EvalHandlingEnum: LIMIT_TRAIN = "limit_train"; STOP_TRAIN = "stop_train" # Added STOP_TRAIN
    class APIServer:
        def __init__(self, config: APIServerConfig): self.config = config # Ensure config is taken
        async def chat_completion(self, **kwargs):
            # Returns a canned object shaped like an OpenAI chat completion so
            # the rest of the pipeline can run without any network access.
            print(f"Dummy APIServer chat_completion called with model: {kwargs.get('model', self.config.model_name)}")
            # Simulate a response structure
            class DummyMessage:
                def __init__(self, content): self.content = content
            class DummyChoice:
                def __init__(self, content): self.message = DummyMessage(content)
            class DummyCompletionResponse:
                def __init__(self, choices_content_list):
                    self.choices = [DummyChoice(c) for c in choices_content_list]
            # Model names starting with "google" are treated as the judge
            # (matches the default judge_model_name); anything else replies
            # like the forecasting agent.
            if self.config.model_name.startswith("google"): # Simulate judge
                return DummyCompletionResponse(["REASONING_SCORE: 3\nTOOL_CALL_SCORE: 1\nFORECAST_SUMMARY_SCORE: 1\nTOTAL_SCORE: 5\nJUSTIFICATION: Dummy judge output"])
            else: # Simulate agent
                return DummyCompletionResponse(["<think>Dummy agent thinking</think>\nFORECAST_SUMMARY: Dummy forecast"])
        async def completion(self, **kwargs): return type('CompletionResponse', (), {'choices': []})()

    def tokenize_for_trainer(tokenizer, messages, max_length): return {"tokens": [1,2,3], "masks": [1,1,1]} # Ensure it returns non-empty
    print("Warning: atroposlib not fully found, using dummy classes for some components.")
|
|
|
|
|
|
# --- Setup Module-Level Logger (consistent with BaseEnv) ---
logger = logging.getLogger(__name__)

# If nothing has configured logging yet (e.g. atroposlib has not installed a
# handler on the root logger), fall back to a basic default configuration.
_root_logger = logging.getLogger()
if not _root_logger.hasHandlers():
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
|
|
|
|
|
|
# --- Configuration (MetRLConfig) remains the same ---
|
|
class MetRLConfig(BaseEnvConfig):
    """Configuration for the MeteorologyForecastRL environment.

    The first group of fields are generic atroposlib trainer/rollout knobs
    (CLI arguments may override them); the second group configures where the
    sounding/AFD data lives, which judge model scores rollouts, and the
    generation limits for the agent and judge calls.
    """

    # --- Generic trainer / rollout settings ---
    tokenizer_name: str = Field(default="Qwen/Qwen3-8B")
    group_size: int = Field(default=2)  # completions requested per item (passed as n to the agent call)
    use_wandb: bool = Field(default=True)
    max_num_workers: int = Field(default=64)
    rollout_server_url: str = Field(default="http://localhost:8000")
    total_steps: int = Field(default=2000)
    batch_size: int = Field(default=-1) # Default
    steps_per_eval: int = Field(default=100) # Default f
    max_token_length: int = Field(default=2048) # Default; cap for tokenize_for_trainer
    inference_weight: float = Field(default=1.0)
    wandb_name: Optional[str] = Field(default=None)
    data_path_to_save_groups: Optional[str] = Field(default='data/MeteorologyForecastRL.jsonl') # Default
    eval_handling: EvalHandlingEnum = Field(default=EvalHandlingEnum.STOP_TRAIN) # Default
    eval_limit_ratio: float = Field(default=0.5) # Default
    num_eval_samples: int = Field(default=20)  # upper bound on items sampled in evaluate()
    num_rollouts_to_log: int = Field(default=10)
    min_items_sent_before_logging: int = Field(default=2) # Default
    include_messages: bool = Field(default=True) # Default
    num_rollouts_to_keep: int = Field(default=32) # Default
    num_rollouts_per_group_for_logging: int = Field(default=1) # Default
    ensure_scores_are_not_same: bool = Field(default=False) # Default
    max_eval_workers: int = Field(default=16) # Default
    max_num_workers_per_node: int = Field(default=8) # Default
    max_batches_offpolicy: int = Field(default=3) # Default

    # --- Environment-specific settings (data layout, judge endpoint, limits) ---
    sounding_data_root: str = Field(
        default="/Users/dev/hackathon/atropos/environments/hack0/data/",
        description="Root directory for all sounding and AFD data."
    )
    target_date: str = Field(
        default="20250314",
        description="The specific date to load data for (YYYYMMDD format)."
    )
    judge_model_name: str = Field(
        default="google/gemini-2.5-flash-preview",
        description="Identifier for the Judge model on OpenRouter."
    )
    judge_api_key_env_var: str = Field(
        default="OPENROUTER_API_KEY",
        description="Environment variable name for OpenRouter API key for the Judge."
    )
    judge_base_url: str = Field(
        default="https://openrouter.ai/api/v1",
        description="Base URL for the OpenRouter API (for Judge)."
    )
    nwp_models_to_use: List[str] = Field(
        default=["RAP"],
        description="List of NWP models to use (e.g., RAP, HRRR)."
    )
    forecast_hours_to_sample: List[int] = Field(
        default=[6, 9, 12, 15, 18],
        description="Which forecast hours (UTC) from the model run to provide to the LLM."
    )
    target_forecast_hour_offset: int = Field(
        default=1,
        description="Offset from the latest provided sounding hour to set the target forecast time."
    )
    max_afds_for_judge: int = Field(
        default=3,
        description="Maximum number of AFD files to provide to the judge model."
    )
    max_reasoning_tokens_llm: int = Field(
        default=3000,
        description="Max tokens for the agent LLM's generation."
    )
    max_tokens_judge: int = Field(
        default=2000,
        description="Max tokens for the judge model's generation."
    )
|
|
|
|
# --- Prompts (AGENT_SYSTEM_PROMPT, AGENT_USER_PROMPT_TEMPLATE, etc.) remain the same ---

# System prompt for the forecasting agent. It instructs the model to emit the
# exact markers that _parse_llm_output() extracts: "TOOL_CALL: {...}" lines and
# a final "FORECAST_SUMMARY: " line.
# NOTE(review): this string is sent verbatim (it is never .format()-ed), so the
# doubled braces below reach the model literally as "{{" / "}}" — confirm that
# is intended, or the example format shown to the model is over-escaped.
AGENT_SYSTEM_PROMPT = """You are a highly skilled AI meteorologist. Your task is to analyze numerical weather prediction (NWP) model sounding data for a specific location and time period.
Based on your analysis, you must:
1. Provide a detailed step-by-step reasoning process. This should include identifying trends, interpreting meteorological parameters, and connecting them to potential weather phenomena.
2. If you determine that additional real-time observational data is crucial for a more accurate assessment, specify the tools you would use. For each tool, output a line in the exact format: TOOL_CALL: {{"tool_name": "tool_name_here", "arguments": {{"param1": "value1", ...}}}}
Available conceptual tools: get_surface_observations, get_latest_radar_imagery, get_satellite_imagery, get_upper_air_sounding.
3. Conclude with a concise forecast summary for the specified target time. Start this summary with "FORECAST_SUMMARY: ".

Analyze the provided data thoroughly. Your reasoning should be comprehensive."""

# User prompt for the agent; .format()-ed in collect_trajectories()/evaluate()
# with the item's location, model run metadata and the soundings JSON blob.
# The doubled braces here correctly collapse to single braces after .format().
AGENT_USER_PROMPT_TEMPLATE = """Please analyze the following NWP model sounding data for station {location_id}.
The soundings provided are from the {model_name} model, run on {run_date_full_z}, valid at the following UTC times: {sounding_times_str}.
Your goal is to make a preliminary forecast assessment focusing on severe weather potential for {location_id} around {target_forecast_time_utc}.

Sounding Data:
{soundings_json_blob}

Remember to include your reasoning, any TOOL_CALL: {{"tool_name": "tool_name_here", "arguments": {{"param1": "value1", ...}}}} lines, and a final FORECAST_SUMMARY: statement."""

# System prompt for the judge model. Defines the 5/3/2-point rubric whose
# labelled scores _parse_judge_output() extracts via regex.
# NOTE(review): this string is also sent verbatim (never .format()-ed), so the
# quadrupled braces reach the judge literally as "{{{{...}}}}" rather than as
# "{...}" placeholders — confirm this is the intended rendering.
JUDGE_SYSTEM_PROMPT = """You are an expert meteorologist acting as a judge. You will evaluate an AI assistant's analysis of model sounding data.
The AI was asked to provide reasoning, call tools if necessary, and give a forecast summary.
You will be given the AI's output and relevant Area Forecast Discussions (AFDs) from human forecasters for context.

Your evaluation should focus on:
1. **Meteorological Soundness of Reasoning (0-5 points):**
    * Correct interpretation of sounding parameters and trends.
    * Logical connections between data and potential weather.
    * Avoidance of meteorological fallacies or hallucinations.
    * Depth and detail of the thought process.
2. **Tool Call Relevance & Justification (0-3 points):**
    * Were the tools called (if any) appropriate given the AI's reasoning and the model data?
    * Would these tools genuinely help a meteorologist in this situation?
    * Were critical tool calls missed?
3. **Forecast Summary Quality (0-2 points):**
    * Clarity and conciseness.
    * Alignment with the AI's own reasoning and the provided AFDs (or sensible deviation if model data strongly suggested it).

Provide a numerical score for each category and a total score (sum of the three, max 10.0). Also, provide a brief overall justification for your scores.
Your output MUST be in the following exact format:
REASONING_SCORE: {{{{0-5 score}}}}
TOOL_CALL_SCORE: {{{{0-3 score}}}}
FORECAST_SUMMARY_SCORE: {{{{0-2 score}}}}
TOTAL_SCORE: {{{{sum of scores, e.g., 7.5}}}}
JUSTIFICATION: {{{{Your brief textual justification here.}}}}"""

# User prompt for the judge; .format()-ed with the agent's raw output and the
# concatenated AFD texts for the same case.
JUDGE_USER_PROMPT_TEMPLATE = """AI Assistant's Output:
---
{llm_full_output}
---

Contextual Area Forecast Discussions (AFDs):
---
{afds_blob}
---

Please evaluate the AI assistant's output based on the criteria and provide your scores and justification in the specified format."""
|
|
|
|
|
|
class MeteorologyForecastRLEnv(BaseEnv):
    """RL environment that prompts an agent LLM to forecast from NWP sounding
    data and scores each rollout with a judge LLM given human AFDs as context.
    """
    # Config class used by the CLI machinery to parse environment options.
    env_config_cls = MetRLConfig
    # Environment name (also used in setup/eval log messages).
    name = "MeteorologyForecastRL"
|
|
|
|
    def __init__(
        self,
        config: MetRLConfig,
        server_configs: List[APIServerConfig],
        slurm=True, # Default based on your CLI help
        testing=False, # Default based on your CLI help
    ):
        """Wire up the environment: assign agent/judge API servers and reset
        the per-run score buffers and rollout logs.

        ``server_configs`` is expected to contain the agent server first and
        the judge server second; with a single entry both roles share it.
        """
        super().__init__(config, server_configs, slurm, testing)

        # self.config: MetRLConfig = self.config # This is redundant if super().__init__ sets self.config
        # Ensure self.config is correctly typed if BaseEnv makes it generic
        if not isinstance(self.config, MetRLConfig): # Check type
            logger.warning(f"self.config in __init__ is not MetRLConfig, type: {type(self.config)}. This might indicate an issue with BaseEnv or pydantic_cli setup.")

        # One prepared item per (target_date, location); populated by setup().
        self.locations_data: List[Dict[str, Any]] = []
        self.agent_llm_server: Optional[APIServer] = None
        self.judge_server: Optional[APIServer] = None

        # Server assignment: servers[0] -> agent, servers[1] -> judge; if only
        # one server is configured both roles fall back to it.
        if not hasattr(self.server, 'servers') or not self.server.servers: # Added check for empty servers list
            logger.error("CRITICAL: ServerManager (self.server) does not have a 'servers' attribute or it's empty!")
        elif len(self.server.servers) > 1:
            self.agent_llm_server = self.server.servers[0]
            self.judge_server = self.server.servers[1]
            if self.agent_llm_server and hasattr(self.agent_llm_server, 'config') and self.agent_llm_server.config: # check config exists
                logger.info(f"Agent server: {self.agent_llm_server.config.model_name} @ {self.agent_llm_server.config.base_url}")
            else:
                logger.warning("Agent LLM server or its config is not properly initialized for logging.")

            if self.judge_server and hasattr(self.judge_server, 'config') and self.judge_server.config: # check config exists
                logger.info(f"Judge server: {self.judge_server.config.model_name} @ {self.judge_server.config.base_url}")
            else:
                logger.warning("Judge server or its config is not properly initialized for logging.")

        elif len(self.server.servers) == 1:
            logger.warning(
                "Only 1 API server configured in ServerManager. Agent and Judge will use the same server."
            )
            self.agent_llm_server = self.server.servers[0]
            self.judge_server = self.server.servers[0]
            if self.agent_llm_server and hasattr(self.agent_llm_server, 'config') and self.agent_llm_server.config: # check config exists
                logger.info(f"Agent/Judge server: {self.agent_llm_server.config.model_name} @ {self.agent_llm_server.config.base_url}")
            else:
                logger.warning("Agent/Judge server or its config is not properly initialized for logging.")

        # Iteration bookkeeping; also persisted by save_checkpoint().
        self.current_item_index: int = 0
        self.iter: int = 0
        # Rolling buffers of judge scores collected in collect_trajectories().
        self.judge_total_scores_buffer: List[float] = []
        self.judge_reasoning_scores_buffer: List[float] = []
        self.judge_tool_scores_buffer: List[float] = []
        self.judge_forecast_scores_buffer: List[float] = []
        # (prompt, think, tools, summary, score, justification) tuples for W&B.
        self.rollouts_for_wandb_custom: List[Tuple[str, str, str, str, float, str]] = []
        self.eval_metrics_buffer: List[Dict[str, float]] = []
|
|
|
|
|
|
@classmethod
|
|
def config_init(cls) -> Tuple[MetRLConfig, List[APIServerConfig]]:
|
|
# This method is usually called by pydantic_cli.
|
|
# The CLI arguments will override these defaults.
|
|
env_config = MetRLConfig() # Initialize with defaults, CLI will override
|
|
|
|
# Get API keys and base URLs from environment or use defaults from MetRLConfig
|
|
# Agent server config (server_configs[0])
|
|
# These might be overridden by CLI --openai.model_name, --openai.base_url etc. for the *first* server
|
|
# However, atroposlib's pydantic_cli might handle multiple server configs differently.
|
|
# The provided help suggests a single --openai.* block, which implies it might apply to all servers
|
|
# or only the first. We'll assume here it populates the first server, and the second uses MetRLConfig defaults.
|
|
|
|
agent_model_name = os.environ.get("AGENT_LLM_MODEL_NAME", env_config.tokenizer_name) # Prioritize env var
|
|
agent_api_key = os.environ.get("AGENT_LLM_API_KEY", "EMPTY_KEY_IF_LOCAL_VLLM")
|
|
agent_base_url = os.environ.get("AGENT_LLM_BASE_URL", "http://localhost:8080/v1") # Example vLLM
|
|
|
|
judge_api_key = os.environ.get(env_config.judge_api_key_env_var)
|
|
if not judge_api_key:
|
|
# This print is fine as it's at class method execution, not instance init
|
|
# logger is not available at class level directly here, so print is okay or use logging.getLogger
|
|
logging.warning(f"Environment variable {env_config.judge_api_key_env_var} not set for Judge API.")
|
|
|
|
server_configs = [
|
|
APIServerConfig(
|
|
model_name=agent_model_name, # This should be the agent model
|
|
base_url=agent_base_url,
|
|
api_key=agent_api_key,
|
|
# num_requests_for_eval=64, # these are often part of APIServer not its config
|
|
),
|
|
APIServerConfig(
|
|
model_name=env_config.judge_model_name,
|
|
base_url=env_config.judge_base_url,
|
|
api_key=judge_api_key,
|
|
# num_requests_for_eval=64,
|
|
)
|
|
]
|
|
# logger.info(f"config_init: env_config={env_config}") # For debugging
|
|
# logger.info(f"config_init: server_configs={server_configs}") # For debugging
|
|
return env_config, server_configs
|
|
|
|
# --- setup, get_next_item, _parse_llm_output, _parse_judge_output remain the same ---
|
|
    async def setup(self):
        """Scan the data directory for the target date and build one item per
        location containing its soundings, sounding times and cleaned AFD texts.

        Populates ``self.locations_data`` (shuffled) and resets ``self.iter``.
        Locations without any readable sounding file are skipped.
        """
        logger.info(f"Setting up {self.name or self.__class__.__name__}...")
        data_root = Path(self.config.sounding_data_root)
        date_path = data_root / self.config.target_date

        if not date_path.is_dir():
            logger.error(f"Target date directory not found: {date_path}")
            return

        # Each subdirectory of the date directory is one station/location.
        available_locations = [loc.name for loc in date_path.iterdir() if loc.is_dir()]
        logger.info(f"Found {len(available_locations)} locations for date {self.config.target_date}: {available_locations}")

        for loc_id in available_locations:
            loc_path = date_path / loc_id
            soundings_for_item = []
            sounding_times_for_item = []

            if not self.config.nwp_models_to_use or not self.config.forecast_hours_to_sample:
                logger.warning(f"NWP models or forecast hours to sample is empty in config. Skipping {loc_id}")
                continue
            # Only the first configured NWP model is used.
            selected_model = self.config.nwp_models_to_use[0]

            # Load one sounding (first JSONL line) per sampled forecast hour.
            for hour_z in self.config.forecast_hours_to_sample:
                fname = f"{loc_id}_{selected_model}_{self.config.target_date}{hour_z:02d}Z.buf_default_llm_optimized.jsonl"
                sounding_file_path = loc_path / fname
                if sounding_file_path.exists():
                    try:
                        with open(sounding_file_path, 'r') as f:
                            line = f.readline()
                            if line:
                                soundings_for_item.append(json.loads(line))
                                sounding_times_for_item.append(f"{hour_z:02d}00Z")
                    except Exception as e:
                        logger.warning(f"Could not load or parse sounding file {sounding_file_path}: {e}")
                else:
                    logger.debug(f"Sounding file not found: {sounding_file_path}")

            if not soundings_for_item:
                logger.debug(f"No valid soundings found for {loc_id} on {self.config.target_date}. Skipping.")
                continue

            # Pick up to max_afds_for_judge AFD files: all of them when few,
            # otherwise first / middle / last of the sorted list.
            afd_files = sorted([f for f in loc_path.glob("AFD_*.txt")])
            selected_afd_texts = []
            if afd_files:
                if len(afd_files) <= self.config.max_afds_for_judge:
                    indices_to_take = list(range(len(afd_files)))
                else:
                    indices_to_take = sorted(list(set([0, len(afd_files) // 2, len(afd_files) - 1])))
                    indices_to_take = indices_to_take[:self.config.max_afds_for_judge]

                for i in indices_to_take:
                    try:
                        with open(afd_files[i], 'r', encoding='utf-8', errors='replace') as f: # Specify encoding and error handling
                            afd_text = f.read()
                            # Remove common control characters, especially ETX (\x03 or \u0003)
                            cleaned_afd_text = ''.join(c for c in afd_text if c.isprintable() or c.isspace())
                            # Or more specifically for \u0003:
                            # cleaned_afd_text = afd_text.replace('\u0003', '')
                            selected_afd_texts.append(cleaned_afd_text)
                    except Exception as e:
                        logger.warning(f"Could not read or clean AFD file {afd_files[i]}: {e}")

            if not sounding_times_for_item:
                logger.warning(f"No sounding times available for {loc_id}, cannot determine target forecast time. Skipping.")
                continue

            # Target time = latest sounding hour + configured offset.
            latest_sounding_hour_str = sounding_times_for_item[-1][:2]
            if not latest_sounding_hour_str.isdigit():
                logger.warning(f"Could not parse latest sounding hour from {sounding_times_for_item[-1]} for {loc_id}. Skipping.")
                continue
            latest_sounding_hour = int(latest_sounding_hour_str)

            # NOTE(review): no modulo-24 here — with a late sounding hour plus
            # the offset this can produce e.g. "2400Z"; confirm whether a
            # wrap-around (and date advance) is needed.
            target_hour = latest_sounding_hour + self.config.target_forecast_hour_offset
            target_forecast_time_utc = f"{target_hour:02d}00Z on {self.config.target_date[4:6]}/{self.config.target_date[6:8]}/{self.config.target_date[0:4]}"

            # Model run time parsed from the first sounding's 'tm' field
            # (expected to look like "<date>/<HHMM>" — TODO confirm format).
            run_time_str = "UnknownRunTime"
            if soundings_for_item and 'tm' in soundings_for_item[0] and '/' in soundings_for_item[0]['tm']:
                try:
                    run_time_str = soundings_for_item[0]['tm'].split('/')[1][:2] + "Z"
                except IndexError:
                    logger.warning(f"Could not parse run time from 'tm' field: {soundings_for_item[0]['tm']} for {loc_id}")
            run_date_full_z = f"{self.config.target_date} at {run_time_str}"

            item_data = {
                "case_id": f"{self.config.target_date}_{loc_id}",
                "location_id": loc_id,
                "model_name": selected_model,
                "run_date_full_z": run_date_full_z,
                "target_forecast_time_utc": target_forecast_time_utc,
                "model_soundings_data": soundings_for_item,
                "sounding_times_str": ", ".join(sounding_times_for_item),
                "afd_texts": selected_afd_texts
            }
            self.locations_data.append(item_data)

        if not self.locations_data:
            logger.error("No data loaded. Environment cannot proceed.")
        else:
            logger.info(f"Successfully prepared {len(self.locations_data)} items for processing.")
            if self.locations_data: # Ensure not empty before shuffling
                random.shuffle(self.locations_data)
        self.iter = 0
|
|
|
|
|
|
def save_checkpoint(self, step, data=None):
|
|
if data is None: data = {}
|
|
data["current_item_index"] = self.current_item_index
|
|
data["iter"] = self.iter
|
|
super().save_checkpoint(step, data)
|
|
|
|
async def get_next_item(self) -> Optional[Dict[str, Any]]:
|
|
if not self.locations_data:
|
|
logger.warning("No locations data available in get_next_item.")
|
|
return None
|
|
if self.current_item_index >= len(self.locations_data):
|
|
logger.info("Cycled through all available location data. Re-shuffling and resetting index.")
|
|
random.shuffle(self.locations_data)
|
|
self.current_item_index = 0
|
|
if not self.locations_data: return None
|
|
|
|
item_to_return = self.locations_data[self.current_item_index]
|
|
self.current_item_index += 1
|
|
self.iter +=1
|
|
return item_to_return
|
|
|
|
def _parse_llm_output(self, llm_text: str) -> Dict[str, Any]:
|
|
think_content = ""
|
|
tool_calls = []
|
|
forecast_summary = ""
|
|
|
|
think_match = re.search(r"<think>(.*?)</think>", llm_text, re.DOTALL | re.IGNORECASE)
|
|
if think_match:
|
|
think_content = think_match.group(1).strip()
|
|
|
|
for line in llm_text.splitlines():
|
|
line_upper = line.strip().upper()
|
|
if line_upper.startswith("TOOL_CALL:"):
|
|
try:
|
|
tool_json_str = line.strip()[len("TOOL_CALL:"):].strip()
|
|
if tool_json_str.startswith("{") and tool_json_str.endswith("}"):
|
|
tool_calls.append(json.loads(tool_json_str))
|
|
else:
|
|
logger.warning(f"Skipping malformed TOOL_CALL: {line}")
|
|
except json.JSONDecodeError:
|
|
logger.warning(f"Could not parse TOOL_CALL JSON: {line}")
|
|
elif line_upper.startswith("FORECAST_SUMMARY:"):
|
|
forecast_summary = line.strip()[len("FORECAST_SUMMARY:"):].strip()
|
|
|
|
return {
|
|
"think_content": think_content,
|
|
"tool_calls": tool_calls,
|
|
"forecast_summary": forecast_summary
|
|
}
|
|
|
|
|
|
def _parse_judge_output(self, judge_text: str) -> Dict[str, Any]:
|
|
scores = {
|
|
"reasoning": 0.0, "tool_call": 0.0, "forecast_summary": 0.0, "total": 0.0
|
|
}
|
|
justification = "No justification provided or parse error."
|
|
|
|
patterns = {
|
|
"reasoning": r"REASONING_SCORE:\s*([0-9.]+)",
|
|
"tool_call": r"TOOL_CALL_SCORE:\s*([0-9.]+)",
|
|
"forecast_summary": r"FORECAST_SUMMARY_SCORE:\s*([0-9.]+)",
|
|
"total": r"TOTAL_SCORE:\s*([0-9.]+)"
|
|
}
|
|
|
|
for key, pattern in patterns.items():
|
|
match = re.search(pattern, judge_text, re.IGNORECASE)
|
|
if match:
|
|
try:
|
|
scores[key] = float(match.group(1))
|
|
except ValueError:
|
|
logger.warning(f"Could not parse score for {key} from: {match.group(1)}")
|
|
|
|
just_match = re.search(r"JUSTIFICATION:\s*(.*)", judge_text, re.DOTALL | re.IGNORECASE)
|
|
if just_match:
|
|
justification = just_match.group(1).strip()
|
|
|
|
calculated_total = round(scores["reasoning"] + scores["tool_call"] + scores["forecast_summary"], 2)
|
|
if abs(scores["total"] - calculated_total) > 0.1 :
|
|
if scores["total"] != 0.0:
|
|
logger.warning(f"Parsed total score {scores['total']} differs from sum of components {calculated_total}. Using sum.")
|
|
scores["total"] = calculated_total
|
|
|
|
return {"scores": scores, "justification": justification}
|
|
|
|
|
|
async def collect_trajectories(self, item: Dict[str, Any]) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
|
case_id = item.get('case_id', 'Unknown')
|
|
logger.info(f"ITEM {case_id}: Starting collect_trajectories.")
|
|
|
|
if item is None:
|
|
logger.warning(f"ITEM {case_id}: Received None item in collect_trajectories.")
|
|
return None, []
|
|
|
|
soundings_blob = json.dumps(item.get("model_soundings_data", []), indent=2)
|
|
agent_user_prompt = AGENT_USER_PROMPT_TEMPLATE.format(
|
|
location_id=item.get("location_id", "N/A"),
|
|
model_name=item.get("model_name", "N/A"),
|
|
run_date_full_z=item.get("run_date_full_z", "N/A"),
|
|
sounding_times_str=item.get("sounding_times_str", "N/A"),
|
|
target_forecast_time_utc=item.get("target_forecast_time_utc", "N/A"),
|
|
soundings_json_blob=soundings_blob
|
|
)
|
|
agent_messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}, {"role": "user", "content": agent_user_prompt}]
|
|
|
|
if not self.agent_llm_server:
|
|
logger.error(f"ITEM {case_id}: Agent LLM server not available.")
|
|
return None, []
|
|
|
|
logger.info(f"ITEM {case_id}: About to call Agent LLM...")
|
|
agent_chat_completions_obj = None # Initialize
|
|
try:
|
|
start_time = time.time()
|
|
agent_chat_completions_obj = await self.agent_llm_server.chat_completion(
|
|
messages=agent_messages,
|
|
model=self.agent_llm_server.config.model_name if self.agent_llm_server.config else "unknown_agent_model",
|
|
n=self.config.group_size,
|
|
max_tokens=self.config.max_reasoning_tokens_llm,
|
|
temperature=0.7,
|
|
stop=["<|im_end|>", "<|endoftext|>", "<|eot_id|>"]
|
|
)
|
|
choices_received = len(agent_chat_completions_obj.choices) if agent_chat_completions_obj and hasattr(agent_chat_completions_obj, 'choices') else 0
|
|
logger.info(f"ITEM {case_id}: Agent LLM call completed. Time: {time.time() - start_time:.2f}s. Choices received: {choices_received}")
|
|
except Exception as e:
|
|
logger.error(f"ITEM {case_id}: Agent LLM call failed: {e}")
|
|
logger.error(traceback.format_exc())
|
|
return None, []
|
|
|
|
if not agent_chat_completions_obj or not hasattr(agent_chat_completions_obj, 'choices') or not agent_chat_completions_obj.choices:
|
|
logger.warning(f"ITEM {case_id}: No choices received from Agent LLM or malformed response. Response: {agent_chat_completions_obj}")
|
|
return None, []
|
|
|
|
scored_data_group = ScoredDataGroup(tokens=[], masks=[], scores=[], overrides=[])
|
|
afd_context_blob = "\n\n---\n\n".join(item.get("afd_texts", [])) if item.get("afd_texts") else "No AFDs provided for this case."
|
|
|
|
for agent_choice_idx, agent_choice in enumerate(agent_chat_completions_obj.choices):
|
|
choice_id = f"{case_id}_choice_{agent_choice_idx}"
|
|
logger.info(f"ITEM {choice_id}: Processing agent choice.")
|
|
|
|
llm_full_output_text = "" # Initialize
|
|
if agent_choice and hasattr(agent_choice, 'message') and agent_choice.message and hasattr(agent_choice.message, 'content'):
|
|
llm_full_output_text = agent_choice.message.content
|
|
logger.debug(f"ITEM {choice_id}: Agent output received (first 200 chars): {llm_full_output_text[:200]}...")
|
|
else:
|
|
logger.warning(f"ITEM {choice_id}: Agent choice did not contain expected message content. Skipping this choice. Details: {agent_choice}")
|
|
continue
|
|
|
|
parsed_llm_out = self._parse_llm_output(llm_full_output_text)
|
|
logger.info(f"ITEM {choice_id}: Parsed agent output.")
|
|
|
|
judge_user_prompt = JUDGE_USER_PROMPT_TEMPLATE.format(llm_full_output=llm_full_output_text, afds_blob=afd_context_blob)
|
|
judge_messages = [{"role": "system", "content": JUDGE_SYSTEM_PROMPT}, {"role": "user", "content": judge_user_prompt}]
|
|
|
|
final_score = 0.0
|
|
judge_justification_text = "Judge call not made or failed."
|
|
judge_parsed_scores = {}
|
|
|
|
if not self.judge_server:
|
|
logger.error(f"ITEM {choice_id}: Judge server not available. Assigning default score.")
|
|
else:
|
|
logger.info(f"ITEM {choice_id}: About to call Judge LLM...")
|
|
|
|
# Log the request payload for the judge
|
|
request_payload_for_judge = {
|
|
"messages": judge_messages,
|
|
"model": self.judge_server.config.model_name if self.judge_server.config else "unknown_judge_model",
|
|
"max_tokens": self.config.max_tokens_judge,
|
|
"temperature": 0.2,
|
|
"n": 1
|
|
# Add any other parameters being sent by APIServer implicitly or explicitly
|
|
}
|
|
logger.info(f"ITEM {choice_id}: Judge LLM Request Payload: {json.dumps(request_payload_for_judge, indent=2, ensure_ascii=False)}")
|
|
|
|
|
|
try:
|
|
judge_start_time = time.time()
|
|
judge_completion_obj = await self.judge_server.chat_completion(
|
|
messages=judge_messages,
|
|
max_tokens=self.config.max_tokens_judge,
|
|
temperature=0.2,
|
|
n=1,
|
|
model=self.judge_server.config.model_name if self.judge_server.config else "unknown_judge_model"
|
|
)
|
|
logger.info(f"ITEM {choice_id}: Judge LLM call completed. Time: {time.time() - judge_start_time:.2f}s")
|
|
|
|
if judge_completion_obj and hasattr(judge_completion_obj, 'choices') and judge_completion_obj.choices and \
|
|
judge_completion_obj.choices[0] and hasattr(judge_completion_obj.choices[0], 'message') and \
|
|
judge_completion_obj.choices[0].message and hasattr(judge_completion_obj.choices[0].message, 'content'):
|
|
judge_output_text = judge_completion_obj.choices[0].message.content
|
|
logger.info(f"ITEM {choice_id}: Judge output received (first 200 chars): {judge_output_text[:200]}...")
|
|
parsed_judge_out = self._parse_judge_output(judge_output_text)
|
|
final_score = parsed_judge_out["scores"]["total"]
|
|
judge_justification_text = parsed_judge_out["justification"]
|
|
judge_parsed_scores = parsed_judge_out["scores"]
|
|
logger.info(f"ITEM {choice_id}: Parsed judge output. Score: {final_score}")
|
|
self.judge_total_scores_buffer.append(final_score)
|
|
self.judge_reasoning_scores_buffer.append(judge_parsed_scores.get("reasoning", 0.0))
|
|
self.judge_tool_scores_buffer.append(judge_parsed_scores.get("tool_call", 0.0))
|
|
self.judge_forecast_scores_buffer.append(judge_parsed_scores.get("forecast_summary", 0.0))
|
|
else:
|
|
logger.error(f"ITEM {choice_id}: Judge LLM response was empty or malformed. Details: {judge_completion_obj}")
|
|
|
|
# ***** START OF MODIFIED/ADDED ERROR HANDLING BLOCK *****
|
|
except httpx.HTTPStatusError as http_err:
|
|
logger.error(f"ITEM {choice_id}: Judge LLM call failed with HTTPStatusError: {http_err.response.status_code}")
|
|
logger.error(f"Request URL: {http_err.request.url}")
|
|
|
|
# Log request body (already logged above, but useful for context here)
|
|
logger.error(f"Request Body (for context):\n{json.dumps(request_payload_for_judge, indent=2, ensure_ascii=False)}")
|
|
|
|
logger.error(f"Response Headers:\n{http_err.response.headers}")
|
|
try:
|
|
# Attempt to read response body text.
|
|
# For async responses, if not already read, it might require await response.aread()
|
|
# but .text should be available if the response was processed by httpx.
|
|
response_body_str = http_err.response.text
|
|
logger.error(f"Response Body:\n{response_body_str}")
|
|
except Exception as resp_e:
|
|
logger.error(f"Could not decode or access response body text from HTTPStatusError: {resp_e}")
|
|
logger.error(traceback.format_exc())
|
|
# ***** END OF MODIFIED/ADDED ERROR HANDLING BLOCK *****
|
|
|
|
except Exception as e: # General catch-all
|
|
logger.error(f"ITEM {choice_id}: Judge LLM call failed with general Exception: {e}")
|
|
logger.error(traceback.format_exc())
|
|
|
|
logger.info(f"ITEM {choice_id}: About to tokenize full trajectory...")
|
|
full_trajectory_messages = agent_messages + [{"role": "assistant", "content": llm_full_output_text}]
|
|
|
|
# Ensure self.tokenizer and self.config.max_token_length are available
|
|
tokenized_output = None
|
|
if hasattr(self, 'tokenizer') and self.tokenizer and hasattr(self.config, 'max_token_length'):
|
|
tokenized_output = tokenize_for_trainer(self.tokenizer, full_trajectory_messages, self.config.max_token_length)
|
|
logger.info(f"ITEM {choice_id}: Tokenization complete. Tokens length: {len(tokenized_output.get('tokens', [])) if tokenized_output else 'N/A'}")
|
|
else:
|
|
logger.error(f"ITEM {choice_id}: Tokenizer or max_token_length not available. Skipping tokenization.")
|
|
|
|
|
|
if tokenized_output and tokenized_output.get("tokens") and len(tokenized_output["tokens"]) > 0:
|
|
# Ensure scored_data_group is a dict-like object that supports append or assignment
|
|
if not isinstance(scored_data_group, dict) and not hasattr(scored_data_group, 'append'):
|
|
logger.error(f"ITEM {choice_id}: scored_data_group is not a dict or list-like object. Type: {type(scored_data_group)}")
|
|
else:
|
|
if isinstance(scored_data_group.get("tokens"), list): scored_data_group["tokens"].append(tokenized_output["tokens"])
|
|
if isinstance(scored_data_group.get("masks"), list): scored_data_group["masks"].append(tokenized_output["masks"])
|
|
if isinstance(scored_data_group.get("scores"), list): scored_data_group["scores"].append(final_score)
|
|
|
|
item_overrides = {
|
|
"case_id": case_id, "llm_think": parsed_llm_out["think_content"],
|
|
"llm_tools": str(parsed_llm_out["tool_calls"]), "llm_summary": parsed_llm_out["forecast_summary"],
|
|
"judge_justification": judge_justification_text,
|
|
"judge_score_reasoning": judge_parsed_scores.get("reasoning", 0.0),
|
|
"judge_score_tool": judge_parsed_scores.get("tool_call", 0.0),
|
|
"judge_score_forecast": judge_parsed_scores.get("forecast_summary", 0.0),
|
|
}
|
|
if isinstance(scored_data_group.get("overrides"), list): scored_data_group["overrides"].append(item_overrides)
|
|
|
|
if hasattr(self, 'rollouts_for_wandb_custom') and isinstance(self.rollouts_for_wandb_custom, list):
|
|
self.rollouts_for_wandb_custom.append((
|
|
agent_user_prompt[:300]+"...", parsed_llm_out["think_content"][:500]+"...",
|
|
str(parsed_llm_out["tool_calls"])[:300]+"...", parsed_llm_out["forecast_summary"][:300]+"...",
|
|
final_score, judge_justification_text[:500]+"..." ))
|
|
logger.info(f"ITEM {choice_id}: Added to scored_data_group and rollouts_for_wandb_custom.")
|
|
else:
|
|
logger.warning(f"ITEM {choice_id}: Tokenization failed or produced no tokens. Agent output was: {llm_full_output_text[:200]}...")
|
|
|
|
if not scored_data_group.get("tokens"):
|
|
logger.warning(f"ITEM {case_id}: No valid trajectories collected for this item.")
|
|
return None, []
|
|
|
|
logger.info(f"ITEM {case_id}: Finished processing all choices. Returning scored data group with {len(scored_data_group['tokens'])} trajectories.")
|
|
return scored_data_group, []
|
|
|
|
|
|
|
|
# --- evaluate and wandb_log remain the same or with similar logging detail if needed ---
|
|
    async def evaluate(self, *args, **kwargs):
        """Run a one-shot evaluation pass over a random sample of forecast cases.

        For each sampled case: build the agent prompt from the case's sounding
        data, query the agent LLM for a forecast, then ask the judge LLM to
        score that forecast against the case's AFD texts. Parsed judge scores
        are appended to ``self.eval_metrics_buffer`` (consumed later by
        ``wandb_log``). Items that fail at any stage are logged and skipped;
        the method itself never raises for a per-item failure.

        ``*args``/``**kwargs`` are accepted for interface compatibility with the
        base class and are ignored here.
        """
        logger.info(f"Starting evaluation for {self.name or self.__class__.__name__}...")
        # Fresh buffer each pass so eval averages reflect only this run.
        self.eval_metrics_buffer.clear()

        if not self.locations_data:
            logger.warning("No data available for evaluation.")
            return

        # Ensure config attributes exist before accessing (older configs may
        # predate `num_eval_samples`; default to 20).
        num_eval_samples = getattr(self.config, "num_eval_samples", 20)
        eval_items_to_process = min(len(self.locations_data), num_eval_samples)

        if eval_items_to_process == 0:
            logger.warning("Not enough data or num_eval_samples is 0. Skipping evaluation.")
            return

        eval_list = []
        if self.locations_data : # ensure locations_data is not empty
            # k <= len(locations_data) is guaranteed by the min() above.
            eval_list = random.sample(self.locations_data, k=eval_items_to_process)

        for eval_idx, eval_item_data in enumerate(eval_list):
            case_id = eval_item_data.get('case_id', f"UnknownEval_{eval_idx}")
            logger.info(f"EVAL ITEM {case_id}: Starting evaluation.")

            # NOTE(review): these keys are accessed unguarded — assumes every
            # item in locations_data carries the full schema; confirm upstream.
            soundings_blob = json.dumps(eval_item_data["model_soundings_data"], indent=2)
            agent_user_prompt = AGENT_USER_PROMPT_TEMPLATE.format(
                location_id=eval_item_data["location_id"], model_name=eval_item_data["model_name"],
                run_date_full_z=eval_item_data["run_date_full_z"], sounding_times_str=eval_item_data["sounding_times_str"],
                target_forecast_time_utc=eval_item_data["target_forecast_time_utc"], soundings_json_blob=soundings_blob)

            agent_messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}, {"role": "user", "content": agent_user_prompt}]

            if not self.agent_llm_server:
                logger.error(f"EVAL ITEM {case_id}: Agent LLM server not available.")
                continue

            llm_full_output_text = "" # Initialize
            logger.info(f"EVAL ITEM {case_id}: About to call Agent LLM for evaluation.")
            try:
                agent_start_time = time.time()
                agent_completion_obj = await self.agent_llm_server.chat_completion(
                    messages=agent_messages,
                    n=1,
                    max_tokens=self.config.max_reasoning_tokens_llm,
                    temperature=0.1, # Low temp for eval
                    stop=["<|eot_id|>", "<|im_end|>", "<|endoftext|>"],
                    model=self.agent_llm_server.config.model_name if self.agent_llm_server.config else "unknown_agent_model"
                )
                logger.info(f"EVAL ITEM {case_id}: Agent LLM call completed. Time: {time.time() - agent_start_time:.2f}s")
                # Defensive chained check: any missing piece of the response
                # structure means the item is skipped, not crashed on.
                if agent_completion_obj.choices and agent_completion_obj.choices[0].message and agent_completion_obj.choices[0].message.content:
                    llm_full_output_text = agent_completion_obj.choices[0].message.content
                    logger.debug(f"EVAL ITEM {case_id}: Agent output received (first 200): {llm_full_output_text[:200]}")
                else:
                    logger.error(f"EVAL ITEM {case_id}: Agent LLM response was empty or malformed. Details: {agent_completion_obj}")
                    continue
            except Exception as e:
                # Per-item isolation: one failed agent call must not abort the pass.
                logger.error(f"EVAL ITEM {case_id}: Agent LLM call failed: {e}")
                logger.error(traceback.format_exc())
                continue

            # Join all AFD texts for judge context; falls back to the literal
            # "No AFDs." when the list is empty (join yields "" which is falsy).
            afd_context_blob = "\n\n---\n\n".join(eval_item_data["afd_texts"]) or "No AFDs."
            judge_user_prompt = JUDGE_USER_PROMPT_TEMPLATE.format(llm_full_output=llm_full_output_text, afds_blob=afd_context_blob)
            judge_messages = [{"role": "system", "content": JUDGE_SYSTEM_PROMPT}, {"role": "user", "content": judge_user_prompt}]

            if not self.judge_server:
                logger.warning(f"EVAL ITEM {case_id}: Judge server not available.")
                continue

            logger.info(f"EVAL ITEM {case_id}: About to call Judge LLM for evaluation.")
            try:
                judge_start_time = time.time()
                judge_completion_obj = await self.judge_server.chat_completion(
                    messages=judge_messages,
                    max_tokens=self.config.max_tokens_judge,
                    temperature=0.1, # Low temp for eval
                    n=1,
                    model=self.judge_server.config.model_name if self.judge_server.config else "unknown_judge_model"
                )
                logger.info(f"EVAL ITEM {case_id}: Judge LLM call completed. Time: {time.time() - judge_start_time:.2f}s")
                if judge_completion_obj.choices and judge_completion_obj.choices[0].message and judge_completion_obj.choices[0].message.content:
                    judge_output_text = judge_completion_obj.choices[0].message.content
                    logger.debug(f"EVAL ITEM {case_id}: Judge output received (first 200): {judge_output_text[:200]}")
                    parsed_judge_out = self._parse_judge_output(judge_output_text)
                    # Only the parsed score dict is buffered; wandb_log averages it later.
                    self.eval_metrics_buffer.append(parsed_judge_out["scores"])
                    logger.info(f"EVAL ITEM {case_id}: Judge score {parsed_judge_out['scores']['total']} added to buffer.")
                else:
                    logger.error(f"EVAL ITEM {case_id}: Judge LLM response was empty or malformed. Details: {judge_completion_obj}")
            except httpx.HTTPStatusError as http_err: # Added specific error handling for eval as well
                # HTTP-level failures get extra forensic detail (status, URL,
                # headers, body) to make server-side judge issues diagnosable.
                logger.error(f"EVAL ITEM {case_id}: Judge LLM call failed with HTTPStatusError: {http_err.response.status_code}")
                logger.error(f"Request URL: {http_err.request.url}")
                # Consider logging request payload for eval too if needed for debugging eval failures
                logger.error(f"Response Headers:\n{http_err.response.headers}")
                try:
                    response_body_str = http_err.response.text
                    logger.error(f"Response Body:\n{response_body_str}")
                except Exception as resp_e:
                    # Accessing .text can itself fail (e.g. decode errors on a
                    # partially-read body); never let diagnostics crash the loop.
                    logger.error(f"Could not decode or access response body text from HTTPStatusError during eval: {resp_e}")
                logger.error(traceback.format_exc())
            except Exception as e:
                logger.error(f"EVAL ITEM {case_id}: Judge LLM call failed: {e}")
                logger.error(traceback.format_exc())

        logger.info(f"Evaluation completed for {len(self.eval_metrics_buffer)} items.")
|
|
|
|
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
|
if not self.config.use_wandb or not wandb.run: # Check if wandb is active
|
|
logger.debug("WandB logging skipped (disabled or no active run).")
|
|
if hasattr(super(), 'wandb_log'): # Call super if it exists, even if not logging locally
|
|
await super().wandb_log(wandb_metrics if wandb_metrics else {})
|
|
return
|
|
|
|
if wandb_metrics is None: wandb_metrics = {}
|
|
logger.info("Preparing metrics for WandB log...")
|
|
|
|
if self.judge_total_scores_buffer:
|
|
wandb_metrics["train/avg_judge_total_score"] = sum(self.judge_total_scores_buffer) / len(self.judge_total_scores_buffer)
|
|
wandb_metrics["train/avg_judge_reasoning_score"] = sum(self.judge_reasoning_scores_buffer) / len(self.judge_reasoning_scores_buffer)
|
|
wandb_metrics["train/avg_judge_tool_score"] = sum(self.judge_tool_scores_buffer) / len(self.judge_tool_scores_buffer)
|
|
wandb_metrics["train/avg_judge_forecast_score"] = sum(self.judge_forecast_scores_buffer) / len(self.judge_forecast_scores_buffer)
|
|
logger.info(f"Train scores (total): {wandb_metrics['train/avg_judge_total_score']:.2f} from {len(self.judge_total_scores_buffer)} samples.")
|
|
self.judge_total_scores_buffer.clear(); self.judge_reasoning_scores_buffer.clear(); self.judge_tool_scores_buffer.clear(); self.judge_forecast_scores_buffer.clear()
|
|
|
|
if self.eval_metrics_buffer:
|
|
avg_eval_total = sum(s['total'] for s in self.eval_metrics_buffer) / len(self.eval_metrics_buffer)
|
|
avg_eval_reasoning = sum(s['reasoning'] for s in self.eval_metrics_buffer) / len(self.eval_metrics_buffer)
|
|
avg_eval_tool = sum(s['tool_call'] for s in self.eval_metrics_buffer) / len(self.eval_metrics_buffer)
|
|
avg_eval_forecast = sum(s['forecast_summary'] for s in self.eval_metrics_buffer) / len(self.eval_metrics_buffer)
|
|
wandb_metrics["eval/avg_judge_total_score"] = avg_eval_total
|
|
wandb_metrics["eval/avg_judge_reasoning_score"] = avg_eval_reasoning
|
|
wandb_metrics["eval/avg_judge_tool_score"] = avg_eval_tool
|
|
wandb_metrics["eval/avg_judge_forecast_score"] = avg_eval_forecast
|
|
logger.info(f"Eval scores (total): {avg_eval_total:.2f} from {len(self.eval_metrics_buffer)} samples.")
|
|
self.eval_metrics_buffer.clear()
|
|
|
|
if self.rollouts_for_wandb_custom:
|
|
if wandb.run: # Double check wandb is active
|
|
table = wandb.Table(columns=["Prompt Hint", "LLM Think", "LLM Tools", "LLM Summary", "Judge Score", "Judge Justification"])
|
|
num_to_log = min(len(self.rollouts_for_wandb_custom), getattr(self.config, "num_rollouts_to_log", 10)) # Use getattr for safety
|
|
|
|
sample_to_log = []
|
|
if num_to_log > 0 and self.rollouts_for_wandb_custom:
|
|
sample_to_log = random.sample(self.rollouts_for_wandb_custom, k=min(num_to_log, len(self.rollouts_for_wandb_custom)))
|
|
|
|
for P, T, O, S, Sc, J in sample_to_log: table.add_data(P, T, O, S, Sc, J)
|
|
if sample_to_log:
|
|
wandb_metrics["train/detailed_rollouts"] = table
|
|
logger.info(f"Logged {len(sample_to_log)} rollouts to WandB table.")
|
|
self.rollouts_for_wandb_custom.clear()
|
|
|
|
if wandb_metrics:
|
|
logger.info(f"Logging to WandB: {list(wandb_metrics.keys())}")
|
|
await super().wandb_log(wandb_metrics)
|
|
else:
|
|
logger.info("No new metrics to log to WandB in this step.")
|
|
if hasattr(super(), 'wandb_log'): # Call super even if no new metrics locally
|
|
await super().wandb_log({})
|
|
|
|
|
|
def _main() -> None:
    """Entry point: run the environment CLI, logging any top-level crash."""
    try:
        MeteorologyForecastRLEnv.cli()
    except Exception as e:
        # A failure here takes down the whole process, hence CRITICAL level
        # plus the full traceback for post-mortem debugging.
        logger.critical(f"CRITICAL Error during CLI execution: {e}")
        logger.critical(traceback.format_exc())


if __name__ == "__main__":
    _main()