#!/usr/bin/env python3 """ WikipediaArticleCreatorEnv: Environment for training an LLM to research and create Wikipedia-style articles This environment uses web search and content extraction tools to enable multi-step research and article generation on arbitrary topics. """ import json import logging import os import random import re from typing import Dict, List, Optional, Tuple, Union # Load environment variables from .env file if present try: from dotenv import load_dotenv load_dotenv() except ImportError: logger = logging.getLogger(__name__) logger.warning( "python-dotenv not installed, environment variables must be set manually" ) import wandb from atroposlib.envs.base import ( APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, Item, ScoredDataGroup, ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer from atroposlib.utils.tool_call_parser import parse_tool_call from environments.community.wikipedia_research.tools.tavily_tools import ( TavilyExtractTool, TavilySearchTool, ) # Set up logging logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # Suppress logging from tool_call_parser module logging.getLogger("atroposlib.utils.tool_call_parser").setLevel(logging.ERROR) # System prompt for the Wikipedia article creation task SYSTEM_PROMPT = """ You are a skilled researcher and writer who creates accurate, neutral, and comprehensive Wikipedia-style articles. Your task is to research the given topic using web search and content extraction tools, and then write a well-structured Wikipedia article based on your findings. Follow these guidelines when creating your article: 1. Research thoroughly using the tools provided 2. Maintain a Neutral Point of View (NPOV) - present all significant viewpoints fairly 3. Structure your article with a clear introduction, organized sections, and a conclusion if appropriate 4. Cite reliable sources for factual claims 5. Use formal, encyclopedic language 6. Format your article in Markdown 7. IMPORTANT: Do not try to visit Wikipedia pages directly - they are blocked. Instead, search for information from other reputable sources During your work, you may: 1. Think through your research strategy and article planning 2. Search for information using web_search 3. Extract content from specific webpages using visit_page 4. Organize and synthesize information from multiple sources 5. Create a final Wikipedia-style article when you have sufficient information You should enclose your thoughts and internal monologue inside tags, and then use tools or provide your final output. IMPORTANT: When you have completed your research and are ready to provide the final article, format it as follows: Final Step: ```markdown [Your complete article in markdown format] ``` For tool calls, you MUST use tags with valid JSON inside. Always format exactly as shown: For web search: {"name": "web_search", "arguments": {"query": "example search query", "num_results": 5}} For webpage visits: {"name": "visit_page", "arguments": {"url": "https://example.com/page"}} The JSON structure is critical - it must be valid JSON with double quotes around all keys and string values. Always enclose your tool calls between and tags, and make sure the JSON is correctly formatted. """ class WikipediaArticleCreatorConfig(BaseEnvConfig): """Configuration for the WikipediaArticleCreator environment""" max_steps: int = 10 # Maximum research steps per article temperature: float = 0.7 # Sampling temperature thinking_active: bool = True # Enable thinking tags eval_topics: int = 30 # Number of topics for evaluation tool_timeout: float = 15.0 # Timeout for tool execution (seconds) tavily_api_key: Optional[str] = None # API key for Tavily (falls back to env var) min_article_sections: int = 3 # Minimum number of sections in final article max_article_tokens: int = 2048 # Maximum tokens in final article topics_file: str = "topics.json" # File containing research topics logging_active: bool = True # Enable detailed logging class EpisodeState: """ Maintains state for a single episode (article creation task) """ def __init__(self, episode_id: int, topic: str): self.episode_id = episode_id self.topic = topic # The research topic for this episode self.message_history: List[Dict] = [] # Stores all interactions self.tool_calls: List[Dict] = [] # Records tool calls made self.tool_results: List[Dict] = [] # Records tool results returned self.steps_taken: int = 0 # Number of steps in this episode self.is_terminal: bool = False # Whether episode has terminated self.final_article: Optional[str] = None # Final Wikipedia article in markdown self.research_facts: List[str] = ( [] ) # Important facts discovered during research self.score: float = 0.0 # Score for this episode class WikipediaArticleCreatorEnv(BaseEnv): """ Environment for training an LLM to research and create Wikipedia-style articles This environment: - Presents the model with a topic to research - Allows multi-step interactions using web_search and visit_page tools - Tracks research process and article quality - Rewards comprehensive, well-structured, and accurate articles """ def __init__( self, config: WikipediaArticleCreatorConfig, server_configs: List[APIServerConfig], slurm: bool = True, testing: bool = False, ): super().__init__(config, server_configs, slurm, testing) # Initialize environment self.config = config self.episodes: Dict[int, EpisodeState] = {} self.eval_metrics: List[Tuple[str, float]] = [] self.rollouts_for_wandb = [] # Set up tools tavily_key = config.tavily_api_key or os.environ.get("TAVILY_API_KEY") if not tavily_key: logger.warning( "No Tavily API key provided - tools will not function properly" ) self.search_tool = TavilySearchTool(api_key=tavily_key) self.extract_tool = TavilyExtractTool(api_key=tavily_key) # Tool definitions for the LLM self.tools = [ { "type": "function", "function": { "name": "web_search", "description": "Search the web for information on a topic", "parameters": { "type": "object", "properties": { "query": { "type": "string", "description": "The search query to perform.", }, "num_results": { "type": "integer", "description": "Number of results to return (default: 5, max: 10)", "default": 5, }, "filter_year": { "type": "string", "description": "Filter results to a specific year", "nullable": True, }, }, "required": ["query"], }, }, }, { "type": "function", "function": { "name": "visit_page", "description": "Extract content from a specific webpage", "parameters": { "type": "object", "properties": { "url": { "type": "string", "description": "The URL of the webpage to visit", } }, "required": ["url"], }, }, }, ] # Load topics if file exists self.topics = self._load_topics() self.iter = 0 self.article_quality_metrics: List[Dict[str, float]] = [] def _load_topics(self) -> List[str]: """Load research topics from wikipedia_articles.json or use defaults if file doesn't exist""" try: articles_path = os.path.join( os.path.dirname(__file__), "wikipedia_articles.json" ) if os.path.exists(articles_path): # The file is large, so we'll read it in chunks and extract just the titles topics = [] with open(articles_path, "r") as f: # Read opening bracket char = f.read(1) if char != "[": raise ValueError("Expected JSON array to start with '['") # Process articles one by one count = 0 max_topics = 100 # Limit to 100 topics while count < max_topics: article_json = "" brace_count = 0 in_article = False # Read until we find a complete article JSON object while True: char = f.read(1) if not char: # End of file break if char == "{" and not in_article: in_article = True brace_count = 1 article_json = "{" elif in_article: article_json += char if char == "{": brace_count += 1 elif char == "}": brace_count -= 1 if brace_count == 0: # Found complete article break if not article_json: break try: article = json.loads(article_json) title = article.get("title", "") if title and len(title) < 100: # Skip very long titles topics.append(title) count += 1 except json.JSONDecodeError: continue if topics: logger.info( f"Loaded {len(topics)} topics from wikipedia_articles.json" ) return topics except Exception as e: logger.warning(f"Failed to load topics from wikipedia_articles.json: {e}") # Default topics if file doesn't exist or loading fails default_topics = [ "History of artificial intelligence", "Climate change in the Arctic", "The Great Barrier Reef ecosystem", "Quantum computing principles", "Anti-black racism in the Arab World", "History of cryptography", "Renewable energy in developing countries", "Space exploration in the 21st century", "Traditional medicine systems around the world", "The evolution of human language", ] logger.info(f"Using {len(default_topics)} default topics") return default_topics @classmethod def config_init(cls) -> Tuple[WikipediaArticleCreatorConfig, List[APIServerConfig]]: """Initialize default configuration""" # Read environment variables (with defaults if not present) model_name = os.environ.get("MODEL_NAME", "gpt-4o") max_steps = int(os.environ.get("MAX_STEPS", "10")) temperature = float(os.environ.get("TEMPERATURE", "0.7")) # Determine if we're using an OpenAI model or a local model is_openai_model = model_name.startswith(("gpt-", "text-")) # Always use a standard HuggingFace tokenizer that's available # gpt2 is a good option for estimating OpenAI tokens tokenizer_name = "gpt2" env_config = WikipediaArticleCreatorConfig( tokenizer_name=tokenizer_name, group_size=1, use_wandb=True, rollout_server_url="http://localhost:8000", total_steps=1000, batch_size=512, steps_per_eval=20, max_token_length=1024 * 16, inference_weight=1.0, wandb_name="wikipedia_article_creator", eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, max_steps=max_steps, temperature=temperature, thinking_active=True, eval_topics=5, tool_timeout=15.0, tavily_api_key=os.environ.get("TAVILY_API_KEY"), # Load from environment min_article_sections=3, max_article_tokens=2048, topics_file="topics.json", logging_active=True, # Enable message history in the output for wandb logging include_messages=True, # Enable message history in the output for wandb logging num_rollouts_to_keep=32, # Keep enough conversations for good logging samples ) # Configure servers based on model type if is_openai_model: # OpenAI API configuration openai_api_key = os.environ.get("OPENAI_API_KEY") if not openai_api_key: logger.warning("OPENAI_API_KEY not found in environment variables.") server_configs = [ APIServerConfig( model_name=model_name, base_url=None, # Use default OpenAI base URL api_key=openai_api_key, num_max_requests_at_once=4, num_requests_for_eval=16, ), ] else: # Local model configuration server_configs = [ APIServerConfig( model_name=model_name, base_url="http://localhost:9004/v1", api_key="x", num_max_requests_at_once=8, num_requests_for_eval=64, ), ] return env_config, server_configs def _get_or_create_episode( self, episode_id: int, topic: Optional[str] = None ) -> EpisodeState: """Get an existing episode or create a new one""" if episode_id not in self.episodes: if topic is None: topic = random.choice(self.topics) ep = EpisodeState(episode_id, topic) # Initialize with system prompt ep.message_history = [{"role": "system", "content": SYSTEM_PROMPT}] # Add initial user prompt with the topic ep.message_history.append( { "role": "user", "content": f'Research and write a comprehensive Wikipedia-style article about: "{topic}"', } ) self.episodes[episode_id] = ep return self.episodes[episode_id] def _parse_tool_calls(self, response: str) -> List[Dict]: """Extract tool calls from model response""" tool_calls = [] logger.info("\n==== PARSING TOOL CALLS ====") # Try to find tool calls using regex first tool_call_pattern = r"\s*(.*?)\s*" raw_tool_calls = re.findall(tool_call_pattern, response, re.DOTALL) logger.info(f"Found {len(raw_tool_calls)} tool call tags in response") if raw_tool_calls: for i, raw_call in enumerate(raw_tool_calls): # Print with line numbers to see where newlines and other issues might be lines = raw_call.split("\n") logger.info(f"RAW TOOL CALL #{i+1} (multiline format):") for line_num, line in enumerate(lines): logger.info(f" Line {line_num+1}: {repr(line)}") # Also print the raw string representation logger.info(f"RAW TOOL CALL #{i+1} (repr): {repr(raw_call)}") try: # Clean up the raw call string - fix known issues from GPT-4 responses # 1. Remove extra closing braces that sometimes appear clean_call = re.sub(r"\}\s*\}", "}", raw_call) # 2. Try to extract just the valid JSON using regex if there are still issues json_pattern = r'(\{\s*"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{[^}]+\}\s*\})' json_match = re.search(json_pattern, clean_call) if json_match: clean_call = json_match.group(1) logger.info(f"Extracted cleaner JSON: {clean_call}") # Try to parse the cleaned JSON call_data = json.loads(clean_call) name = call_data.get("name") args = call_data.get("arguments", {}) # Validate that the tool exists if any(tool["function"]["name"] == name for tool in self.tools): logger.info(f"Parsed tool call: {name}, {args}") tool_calls.append({"name": name, "arguments": args}) else: logger.warning(f"Unknown tool name: {name}") except json.JSONDecodeError as e: # Only log this at INFO level to reduce verbosity in normal output if ( self.config.logging_active and hasattr(self, "process_mode") and not self.process_mode ): logger.warning(f"Failed to parse tool call JSON: {e}") else: logger.debug(f"Failed to parse tool call JSON: {e}") # Fallback to the library parser if no tool calls were found if not tool_calls: logger.info( "No tool calls found with regex, falling back to library parser" ) for tool in self.tools: name = tool["function"]["name"] parsed_name, parsed_args, is_error = parse_tool_call( response, [tool], ["tool_call"] ) if not is_error and parsed_name == name: # Only log detailed parsing in non-process mode to reduce verbosity if not hasattr(self, "process_mode") or not self.process_mode: logger.debug( f"Parsed tool call with library: {name}, {parsed_args}" ) tool_calls.append({"name": name, "arguments": parsed_args}) elif parsed_name and parsed_name != "-ERROR-": # Only log parsing failures for non-obvious errors in non-process mode if not hasattr(self, "process_mode") or not self.process_mode: logger.debug( f"Failed tool call parse: {parsed_name}, error: {is_error}" ) logger.info(f"Final parsed tool calls: {len(tool_calls)}") return tool_calls def _extract_final_article(self, response: str) -> Optional[str]: """Extract final Wikipedia article markdown if present""" # Regular expression to match content between Final Step: ```markdown and ``` tags pattern = r"Final Step:\s*```markdown\s*(.*?)```" match = re.search(pattern, response, re.DOTALL) if match: return match.group(1).strip() return None def _format_tool_results(self, tool_results: List[Dict]) -> str: """Format tool results as a user message""" if not tool_results: return "No results found." formatted_results = ["==== TOOL RESULTS ===="] for result in tool_results: tool_name = result.get("name", "unknown_tool") args = result.get("arguments", {}) data = result.get("data", []) if tool_name == "web_search": query = args.get("query", "") num_results = args.get("num_results", 5) formatted_results.append( f'[WEB SEARCH] query="{query}", num_results={num_results}\n' ) if isinstance(data, list): formatted_results.append(json.dumps(data, indent=2)) else: formatted_results.append("No results found.") elif tool_name == "visit_page": url = args.get("url", "") formatted_results.append(f'[PAGE EXTRACT] url="{url}"\n') if isinstance(data, dict): content = data.get("content", "") title = data.get("title", "") success = data.get("success", False) if success: formatted_results.append(f"Title: {title}") formatted_results.append(f"Content:\n{content[:2000]}...") if len(content) > 2000: formatted_results.append( "\n[Content truncated due to length]" ) else: error = data.get("error", "Unknown error") formatted_results.append(f"Error: {error}") else: formatted_results.append("Failed to retrieve page content.") formatted_results.append("==== END TOOL RESULTS ====") return "\n\n".join(formatted_results) def _extract_research_facts(self, tool_results: List[Dict], facts: List[str]): """Extract important facts from tool results for later evaluation""" for result in tool_results: tool_name = result.get("name", "") data = result.get("data", None) if tool_name == "web_search" and isinstance(data, list): for item in data: content = item.get("content", "") if content: # Simple sentence extraction - could be enhanced with NLP sentences = re.split(r"(?<=[.!?])\s+", content) for sentence in sentences: if len(sentence) > 30 and sentence not in facts: facts.append(sentence) elif tool_name == "visit_page" and isinstance(data, dict): content = data.get("content", "") if content: paragraphs = content.split("\n\n") for paragraph in paragraphs: if len(paragraph) > 50 and paragraph not in facts: facts.append(paragraph) async def _execute_tool_call(self, tool_call: Dict) -> Dict: """Execute a tool call and return the result""" tool_name = tool_call.get("name", "") arguments = tool_call.get("arguments", {}) result = {"name": tool_name, "arguments": arguments, "data": None} try: if tool_name == "web_search": query = arguments.get("query", "") num_results = min( arguments.get("num_results", 5), 10 ) # Limit to 10 max filter_year = arguments.get("filter_year", None) # If query is about Wikipedia, provide a helpful message if "wikipedia" in query.lower(): logger.info( "Query contains 'wikipedia' - providing guidance message" ) wikipedia_message = [ { "title": "Wikipedia Research Notice", "url": "https://example.com/wikipedia-notice", "content": ( "Instead of searching for Wikipedia articles directly, try searching for the " "actual topic or subject. The goal is to create a Wikipedia-style article " "using information from various reliable sources." ), "snippet": ( "Instead of searching for Wikipedia articles directly, try searching for " "the actual topic or subject." ), "date": None, } ] result["data"] = wikipedia_message else: search_results = self.search_tool.forward( query=query, num_results=num_results, filter_year=filter_year ) # Filter out Wikipedia URLs from search results filtered_results = [] for item in search_results: url = item.get("url", "").lower() if "wikipedia.org" not in url: filtered_results.append(item) else: logger.info( f"Filtered out Wikipedia URL from search results: {url}" ) # Add a notice if results were filtered if len(filtered_results) < len(search_results): logger.info( f"Filtered out {len(search_results) - len(filtered_results)} Wikipedia results" ) # Add a notice as the last result if we filtered anything if filtered_results: filtered_results.append( { "title": "Search Results Notice", "url": "https://example.com/search-notice", "content": ( "Some Wikipedia results were automatically filtered out. Please focus " "on using other reliable sources for your research." ), "snippet": ( "Wikipedia results were filtered. Use other reliable sources for " "your research." ), "date": None, } ) result["data"] = filtered_results elif tool_name == "visit_page": url = arguments.get("url", "") # Check if the URL is from Wikipedia and block it if "wikipedia.org" in url.lower(): logger.info(f"Blocking Wikipedia URL: {url}") result["data"] = { "url": url, "title": "Page Not Found", "content": ( "Wikipedia pages are not available in this environment. Please search for " "information from other sources." ), "success": False, "error": "Wikipedia pages are blocked in this environment.", } else: try: logger.info(f"Attempting to extract content from URL: {url}") page_data = self.extract_tool.forward(url=url) # Log success or partial success if page_data.get("success", False): content_length = len(page_data.get("content", "")) logger.info( f"Successfully extracted {content_length} characters from {url}" ) else: error_msg = page_data.get("error", "Unknown error") logger.error( f"Extraction reported failure: {error_msg} for URL: {url}" ) result["data"] = page_data except Exception as e: logger.error( f"Exception during content extraction from {url}: {str(e)}" ) import traceback logger.error( f"Extraction error traceback: {traceback.format_exc()}" ) result["data"] = { "url": url, "title": "Page Extraction Failed", "content": ( f"Failed to extract content from the page due to an error: {str(e)}" ), "success": False, "error": f"Exception during extraction: {str(e)}", } else: logger.warning(f"Unknown tool: {tool_name}") result["data"] = {"error": f"Unknown tool: {tool_name}"} except Exception as e: logger.error(f"Error executing tool {tool_name}: {e}") result["data"] = {"error": f"Tool execution failed: {str(e)}"} return result async def _get_model_response(self, messages: List[Dict]) -> str: """Get a response from the model for the current conversation state""" try: # Try to use chat_completion first (which works with OpenAI models) try: logger.info("Attempting to use chat_completion API") completion = await self.server.chat_completion( messages=messages, temperature=self.config.temperature, max_tokens=min( 4096, self.config.max_token_length ), # Ensure within OpenAI limits ) return completion.choices[0].message.content except (AttributeError, TypeError) as e: # If chat_completion fails, fall back to standard completion logger.info( f"Chat completion failed: {e}, falling back to standard completion" ) # For non-OpenAI models (local), use the standard completion API with tokenized prompt prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) completion = await self.server.completion( prompt=prompt, n=1, max_tokens=self.config.max_token_length, temperature=self.config.temperature, ) return completion.choices[0].text except Exception as e: logger.error(f"Error getting model response: {e}") import traceback logger.error(traceback.format_exc()) return "" async def _next_step(self, episode: EpisodeState) -> Tuple[bool, Dict]: """ Process one step of article research interaction Returns (is_terminal, step_data) """ # Get current conversation history messages = episode.message_history.copy() logger.info("\n==== REQUESTING MODEL RESPONSE ====") # Generate model response response = await self._get_model_response(messages) if not response: episode.is_terminal = True logger.info("No response received from model") return True, {"response": "", "tool_calls": [], "tool_results": []} logger.info("\n==== MODEL RESPONSE ====") # Print the raw response with repr to see exactly what's in it, including newlines and special chars print("\n\n==== RAW MODEL RESPONSE (repr) ====") print(repr(response)) print("==== END RAW MODEL RESPONSE ====\n\n") # Also log it normally logger.info(response) logger.info("==== END MODEL RESPONSE ====") # Check for final article final_article = self._extract_final_article(response) if final_article: logger.info("\n==== FINAL ARTICLE DETECTED ====") episode.is_terminal = True episode.final_article = final_article # Add response to history episode.message_history.append({"role": "assistant", "content": response}) return True, {"response": response, "tool_calls": [], "tool_results": []} # Extract tool calls for research tool_calls = self._parse_tool_calls(response) # Hide detailed tool call logging in process mode if tool_calls: if hasattr(self, "process_mode") and self.process_mode: # In process mode, just show a summary logger.info(f"Found {len(tool_calls)} tool calls") else: # In normal mode, show more detailed logging logger.info(f"\n==== EXECUTING {len(tool_calls)} TOOL CALLS ====") # Execute research tool calls tool_results = [] for i, tool_call in enumerate(tool_calls): tool_name = tool_call.get("name", "unknown") logger.info(f"Executing tool call {i+1}: {tool_name}") result = await self._execute_tool_call(tool_call) tool_results.append(result) # Add response and tool results to history episode.message_history.append({"role": "assistant", "content": response}) # Format tool results as a user message tool_results_message = self._format_tool_results(tool_results) episode.message_history.append( {"role": "user", "content": tool_results_message} ) # Update episode state episode.steps_taken += 1 episode.tool_calls.extend(tool_calls) episode.tool_results.extend(tool_results) # Extract and store research facts for later evaluation self._extract_research_facts(tool_results, episode.research_facts) # Check if max steps reached if episode.steps_taken >= self.config.max_steps: logger.info(f"\n==== MAX STEPS REACHED ({self.config.max_steps}) ====") episode.is_terminal = True return episode.is_terminal, { "response": response, "tool_calls": tool_calls, "tool_results": tool_results, } def _assess_article_quality( self, final_article: str, research_facts: List[str] ) -> Dict[str, float]: """ Evaluate the quality of the final article Returns a dictionary of quality metrics """ metrics = { "structure_score": 0.0, "comprehensiveness_score": 0.0, "fact_usage_score": 0.0, "overall_quality": 0.0, } # Basic structure analysis if not final_article: return metrics # Check for section headers sections = re.findall(r"^##?\s+.+$", final_article, re.MULTILINE) num_sections = len(sections) # Check for references references = re.findall( r"^##?\s*References", final_article, re.MULTILINE | re.IGNORECASE ) has_references = len(references) > 0 # Calculate structure score structure_score = 0.0 if num_sections >= self.config.min_article_sections: structure_score += 0.7 else: structure_score += 0.7 * (num_sections / self.config.min_article_sections) if has_references: structure_score += 0.3 metrics["structure_score"] = structure_score # Calculate comprehensiveness score based on length and section count article_length = len(final_article) comp_score = min(1.0, article_length / 3000) * 0.7 comp_score += min(1.0, num_sections / 5) * 0.3 metrics["comprehensiveness_score"] = comp_score # Calculate fact usage score # This is a simplistic approach - could be enhanced with NLP/semantic matching fact_usage = 0.0 if research_facts: facts_found = 0 for fact in research_facts: # Check if key phrases from the fact appear in the article key_phrases = [p for p in fact.split() if len(p) > 5] if key_phrases: for phrase in key_phrases[:5]: # Use up to 5 phrases per fact if phrase.lower() in final_article.lower(): facts_found += 1 break fact_usage = min(1.0, facts_found / len(research_facts)) metrics["fact_usage_score"] = fact_usage # Calculate overall quality overall = structure_score * 0.3 + comp_score * 0.4 + fact_usage * 0.3 metrics["overall_quality"] = overall return metrics async def collect_trajectories( self, item: Tuple[int, str] ) -> Tuple[List[ScoredDataGroup], List]: """ Manage full research trajectory collection Args: item: Tuple containing (episode_id, topic) Returns: Tuple of: - List of ScoredDataGroup objects: Scored data for training - List: Empty list (no backlog items) """ episode_id, topic = item # Get or create episode state episode = self._get_or_create_episode(episode_id, topic) # Detect if we're in process mode is_process_mode = hasattr(self, "process_mode") and self.process_mode trajectory_data: List[ScoredDataGroup] = [] # Run episode until terminal state while not episode.is_terminal: is_terminal, step_data = await self._next_step(episode) # Skip steps with no response or no tools used (unless terminal) response = step_data.get("response", "") if not response: continue # Create scored data for this step step_score = ScoredDataGroup() # Tokenize conversation up to this point tokenized = tokenize_for_trainer(self.tokenizer, episode.message_history) # Score based on tool usage (basic heuristic, improve in future) tool_calls = step_data.get("tool_calls", []) tool_results = step_data.get("tool_results", []) if is_terminal and episode.final_article: # Terminal step with article - score based on article quality quality_metrics = self._assess_article_quality( episode.final_article, episode.research_facts ) step_score["tokens"] = [tokenized["tokens"]] step_score["masks"] = [tokenized["masks"]] step_score["scores"] = [ quality_metrics["overall_quality"] * 2 - 1 ] # Scale to [-1, 1] # Record metrics for logging quality_metrics["topic"] = episode.topic quality_metrics["steps_taken"] = episode.steps_taken self.article_quality_metrics.append(quality_metrics) # If we're in process mode, perform factual accuracy evaluation if is_process_mode: try: # Import here to avoid circular imports import json from environments.hack0.wikipedia.article_evaluator import ( ArticleEvaluator, ) # Check if OpenAI API key is available openai_api_key = os.environ.get("OPENAI_API_KEY") if openai_api_key: # Initialize article evaluator evaluator = ArticleEvaluator(openai_api_key) # Load reference articles articles_path = os.path.join( os.path.dirname(__file__), "wikipedia_articles.json" ) if os.path.exists(articles_path): with open(articles_path, "r") as f: articles_data = json.load(f) topic = episode.topic generated_article = episode.final_article # Retrieve reference article content reference_content = evaluator.get_reference_article( articles_data, topic ) if reference_content: # Evaluate article factual accuracy - changed from async to sync evaluation_results = ( evaluator.evaluate_article_accuracy( reference_content=reference_content, generated_article=generated_article, ) ) # Calculate accuracy score accuracy_score = evaluator.calculate_accuracy_score( evaluation_results ) # Print statistics for this evaluation if ( evaluation_results and "statistics" in evaluation_results ): stats = evaluation_results["statistics"] print("\n" + "=" * 80) print( f"FACTUAL ACCURACY EVALUATION FOR: {topic}" ) print("=" * 80) print( f"CORRECT: {stats.get('correct_count', 0)} statements " f"({stats.get('pct_correct', 0):.1f}%)" ) print( f"INCORRECT: {stats.get('incorrect_count', 0)} statements " f"({stats.get('pct_incorrect', 0):.1f}%)" ) print( f"UNKNOWN: {stats.get('unknown_count', 0)} statements " f"({stats.get('pct_unknown', 0):.1f}%)" ) print( f"TOTAL: {stats.get('total_count', 0)} statements evaluated" ) print("-" * 80) # Remove duplicate raw scores since we're keeping everything in [-1, 1] now # Keep original scores in their native ranges # Original quality_score is in [0,1] range # Convert to [-1,1] range for consistency with accuracy score quality_score_scaled = ( quality_metrics["overall_quality"] * 2 - 1 ) # Accuracy score is already in [-1,1] range # Calculate combined score (simple average of the two scores) combined_score = ( quality_score_scaled + accuracy_score ) / 2 # This is already in [-1,1] range for ScoredDataGroup scaled_score = combined_score print( f"Original Quality Score: {quality_score_scaled:.4f} (range [-1, 1])" ) print( f"Factual Accuracy Score: {accuracy_score:.4f} (range [-1, 1])" ) print( f"Combined Final Score: {combined_score:.4f} (range [-1, 1])" ) print("=" * 80 + "\n") # Update the score in step_score step_score["scores"] = [scaled_score] # Add accuracy metrics to article_quality_metrics for wandb logging if ( evaluation_results and "statistics" in evaluation_results ): stats = evaluation_results["statistics"] accuracy_metrics = { "pct_correct": stats.get("pct_correct", 0), "pct_incorrect": stats.get( "pct_incorrect", 0 ), "pct_unknown": stats.get("pct_unknown", 0), "accuracy_score": accuracy_score, } # Update the last added quality metrics entry self.article_quality_metrics[-1].update( accuracy_metrics ) except Exception as e: print(f"Error evaluating article factual accuracy: {e}") import traceback print(traceback.format_exc()) elif tool_calls: # Non-terminal step with tool usage - score based on usefulness step_score["tokens"] = [tokenized["tokens"]] step_score["masks"] = [tokenized["masks"]] # Simple usefulness heuristic: # - Higher score for visiting pages than generic searches # - Higher score if results were found than if errors usefulness = 0.0 for result in tool_results: name = result.get("name", "") data = result.get("data", None) if name == "web_search" and isinstance(data, list) and data: usefulness = max(usefulness, 0.6) elif ( name == "visit_page" and isinstance(data, dict) and data.get("success", False) ): usefulness = max(usefulness, 0.8) step_score["scores"] = [usefulness * 2 - 1] # Scale to [-1, 1] else: # Step with no tool usage - low score step_score["tokens"] = [tokenized["tokens"]] step_score["masks"] = [tokenized["masks"]] step_score["scores"] = [-0.5] # Slight negative score # Add messages to the step_score to make them available for wandb logging and HTML rendering # We do this through the messages key which is supported in ScoredDataGroup and HTML rendering if self.config.include_messages: # For HTML rendering, we need to combine all messages into a single markdown string # This ensures the entire conversation appears as a single content item # First, create the complete conversation as one big markdown document # This is what will be shown in the HTML output complete_conversation = [] # Add the topic complete_conversation.append(f"# Wikipedia Article: {episode.topic}\n") # Include tool calls and research steps for i, msg in enumerate(episode.message_history): role = msg.get("role", "unknown") content = msg.get("content", "") # Skip system messages for cleaner output if role == "system": continue # Handle normal messages complete_conversation.append(f"## {role.upper()}") # Special handling for final article if role == "assistant" and "Final Step: ```markdown" in content: article_content = self._extract_final_article(content) if article_content: complete_conversation.append( content.split("Final Step: ```markdown")[0] ) # Add thinking/research complete_conversation.append("### FINAL ARTICLE") complete_conversation.append( f"```markdown\n{article_content}\n```" ) else: complete_conversation.append(content) else: complete_conversation.append(content) # If this is an assistant message that triggered tools, add the tool calls if role == "assistant" and i < len(episode.message_history) - 1: next_msg = episode.message_history[i + 1] if next_msg.get( "role" ) == "user" and "==== TOOL RESULTS ====" in next_msg.get( "content", "" ): # Extract tool name from the message tool_content = next_msg.get("content", "") if "[WEB SEARCH]" in tool_content: complete_conversation.append("### 🔍 SEARCH RESULTS") elif "[PAGE EXTRACT]" in tool_content: complete_conversation.append("### 📄 PAGE EXTRACT") complete_conversation.append( "```\n" + tool_content + "\n```" ) # Join everything into a single string with double newlines between sections full_conversation_markdown = "\n\n".join(complete_conversation) # Store the full conversation as a single message (for HTML rendering) step_score["messages"] = [full_conversation_markdown] # For process mode, we only want to keep the final state # This ensures we get a single group in the HTML output if is_process_mode: if is_terminal: # For terminal steps, keep only this step which has the full conversation trajectory_data = [step_score] else: # For intermediate steps in process mode, don't add to trajectory_data pass else: # Normal training mode - add all steps trajectory_data.append(step_score) # Don't delete the episode yet - we need it for wandb logging # Instead, mark it for deletion after wandb logging is complete # We'll actually clean it up after handle_send_to_api in handle_env return trajectory_data, [] async def score( self, rollout_group_data: List[ScoredDataGroup] ) -> List[ScoredDataGroup]: """ Enhanced scoring function that incorporates factual accuracy evaluation. Uses OpenAI models to evaluate the factual accuracy of the generated articles against reference articles from Wikipedia. """ try: # Import here to avoid circular imports import json from environments.hack0.wikipedia.article_evaluator import ArticleEvaluator # Check if OpenAI API key is available openai_api_key = os.environ.get("OPENAI_API_KEY") if not openai_api_key: logger.warning( "OPENAI_API_KEY not found. Skipping factual accuracy evaluation." ) return rollout_group_data # Initialize article evaluator evaluator = ArticleEvaluator(openai_api_key) # Load reference articles articles_path = os.path.join( os.path.dirname(__file__), "wikipedia_articles.json" ) if not os.path.exists(articles_path): logger.warning( f"Wikipedia articles file not found at {articles_path}. Skipping factual accuracy evaluation." ) return rollout_group_data with open(articles_path, "r") as f: articles_data = json.load(f) # Process each ScoredDataGroup for group in rollout_group_data: for i in range(len(group["tokens"])): # Check if this is a terminal step with a final article episode_id = i # Use index as a proxy for episode_id episode = self.episodes.get(episode_id) if episode and episode.is_terminal and episode.final_article: topic = episode.topic generated_article = episode.final_article # Retrieve reference article content reference_content = evaluator.get_reference_article( articles_data, topic ) if reference_content: # Evaluate article factual accuracy evaluation_results = ( await evaluator.evaluate_article_accuracy( reference_content=reference_content, generated_article=generated_article, ) ) # Calculate accuracy score accuracy_score = evaluator.calculate_accuracy_score( evaluation_results ) # Combine with existing quality metrics quality_metrics = self._assess_article_quality( final_article=generated_article, research_facts=episode.research_facts, ) # Adjust the overall quality to include factual accuracy # Original score is in [0,1], we'll combine it with accuracy_score [-1,1] combined_score = ( quality_metrics["overall_quality"] + (accuracy_score + 1) / 2 ) / 2 # Scale to [-1, 1] for compatibility with existing scoring scaled_score = combined_score * 2 - 1 # Update the score in the ScoredDataGroup group["scores"][i] = scaled_score # Print statistics for this evaluation if ( evaluation_results and "statistics" in evaluation_results ): stats = evaluation_results["statistics"] print("\n" + "=" * 80) print(f"FACTUAL ACCURACY EVALUATION FOR: {topic}") print("=" * 80) print( f"CORRECT: {stats.get('correct_count', 0)} statements " f"({stats.get('pct_correct', 0):.1f}%)" ) print( f"INCORRECT: {stats.get('incorrect_count', 0)} statements " f"({stats.get('pct_incorrect', 0):.1f}%)" ) print( f"UNKNOWN: {stats.get('unknown_count', 0)} statements " f"({stats.get('pct_unknown', 0):.1f}%)" ) print( f"TOTAL: {stats.get('total_count', 0)} statements evaluated" ) print("-" * 80) print( f"Factual Accuracy Score: {accuracy_score:.4f} (range [-1, 1])" ) print( f"Original Quality Score: {quality_metrics['overall_quality']:.4f} (range [0, 1])" ) print( f"Combined Final Score: {scaled_score:.4f} (range [-1, 1])" ) print("=" * 80 + "\n") # Add accuracy metrics to article_quality_metrics for wandb logging if ( evaluation_results and "statistics" in evaluation_results ): stats = evaluation_results["statistics"] accuracy_metrics = { "pct_correct": stats.get("pct_correct", 0), "pct_incorrect": stats.get("pct_incorrect", 0), "pct_unknown": stats.get("pct_unknown", 0), "accuracy_score": accuracy_score, } # Find the corresponding metrics entry and update it for metrics in self.article_quality_metrics: if metrics.get("topic") == topic: metrics.update(accuracy_metrics) break else: logger.warning( f"No reference article found for topic: {topic}" ) except Exception as e: logger.error(f"Error during factual accuracy evaluation: {e}") import traceback logger.error(traceback.format_exc()) return rollout_group_data async def setup(self): """Set up the environment - load topics, etc.""" pass async def get_next_item(self) -> Tuple[int, str]: """Get next episode ID and topic""" # Select a random topic topic = random.choice(self.topics) episode_id = self.iter self.iter += 1 return (episode_id, topic) async def evaluate(self, *args, **kwargs): """Run evaluation on test set of topics""" if not self.config.use_wandb: logger.info("Skipping evaluation as wandb is not enabled") return num_eval = min(self.config.eval_topics, len(self.topics)) eval_topics = random.sample(self.topics, num_eval) logger.info(f"Starting evaluation on {num_eval} topics") eval_metrics = { "avg_steps": 0.0, "avg_quality": 0.0, "avg_structure": 0.0, "avg_comprehensiveness": 0.0, "avg_fact_usage": 0.0, "completion_rate": 0.0, } completed_count = 0 total_steps = 0 quality_scores = { "overall": [], "structure": [], "comprehensiveness": [], "fact_usage": [], } # Run evaluation episodes for eval_idx, topic in enumerate(eval_topics): episode_id = 1000000 + eval_idx # High range for eval episodes episode = self._get_or_create_episode(episode_id, topic) # Run episode until terminal while not episode.is_terminal: is_terminal, _ = await self._next_step(episode) if is_terminal: break # Record metrics total_steps += episode.steps_taken if episode.final_article: completed_count += 1 quality_metrics = self._assess_article_quality( episode.final_article, episode.research_facts ) quality_scores["overall"].append(quality_metrics["overall_quality"]) quality_scores["structure"].append(quality_metrics["structure_score"]) quality_scores["comprehensiveness"].append( quality_metrics["comprehensiveness_score"] ) quality_scores["fact_usage"].append(quality_metrics["fact_usage_score"]) # Clean up episode if episode_id in self.episodes: del self.episodes[episode_id] # Calculate averages if num_eval > 0: eval_metrics["avg_steps"] = total_steps / num_eval eval_metrics["completion_rate"] = completed_count / num_eval if completed_count > 0: eval_metrics["avg_quality"] = ( sum(quality_scores["overall"]) / completed_count ) eval_metrics["avg_structure"] = ( sum(quality_scores["structure"]) / completed_count ) eval_metrics["avg_comprehensiveness"] = ( sum(quality_scores["comprehensiveness"]) / completed_count ) eval_metrics["avg_fact_usage"] = ( sum(quality_scores["fact_usage"]) / completed_count ) # Store metrics for wandb logging self.eval_metrics = [ ("eval/avg_steps", eval_metrics["avg_steps"]), ("eval/completion_rate", eval_metrics["completion_rate"]), ("eval/avg_quality", eval_metrics["avg_quality"]), ("eval/avg_structure", eval_metrics["avg_structure"]), ("eval/avg_comprehensiveness", eval_metrics["avg_comprehensiveness"]), ("eval/avg_fact_usage", eval_metrics["avg_fact_usage"]), ] async def add_rollouts_for_wandb( self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], item: Item = None, ): """ Save complete conversation histories to wandb This captures the full research and article creation process, including all tool calls and intermediate steps """ # Use the base implementation first for basic text and scores await super().add_rollouts_for_wandb(scored_data, item) # Now also save the complete conversation history if we have it if item is not None and isinstance(item, tuple) and len(item) > 0: episode_id = item[0] episode = self.episodes.get(episode_id) if episode and hasattr(episode, "message_history"): # Format the conversation with relevant metadata num_keep = self.config.num_rollouts_per_group_for_logging if num_keep == -1: num_keep = self.config.group_size # Add detailed conversation data to rollouts # We'll extract this in create_rollout_table to create a more detailed table for i in range(min(num_keep, len(scored_data["tokens"]))): # Add chat history to the most recent entry in rollouts_for_wandb if len(self.rollouts_for_wandb) > 0 and i < len( self.rollouts_for_wandb[-1] ): entry = list(self.rollouts_for_wandb[-1][i]) # Append the message history to the existing tuple entry.append( { "topic": episode.topic, "steps_taken": episode.steps_taken, "is_terminal": episode.is_terminal, "message_history": episode.message_history, "tool_calls": episode.tool_calls, "tool_results": episode.tool_results, } ) # Replace the tuple with our updated entry self.rollouts_for_wandb[-1][i] = tuple(entry) async def create_rollout_table(self, wandb_metrics): """ Create a detailed wandb table with complete conversation histories This expands on the base implementation to include full chat histories and research steps in a structured format """ if len(self.rollouts_for_wandb) > 0: # First create the basic table with text and scores basic_table = wandb.Table(columns=["text", "score"]) for group in self.rollouts_for_wandb: for item in group: # Check if this is a basic entry (just text and score) or has chat history if len(item) == 2: basic_table.add_data(item[0], item[1]) else: basic_table.add_data(item[0], item[1]) wandb_metrics["train/rollouts"] = basic_table # Create a detailed table with conversation histories # This will only include entries that have chat history detailed_table = wandb.Table( columns=[ "topic", "steps_taken", "score", "full_conversation", "tool_calls_count", "has_final_article", ] ) for group in self.rollouts_for_wandb: for item in group: # Check if this entry has chat history if len(item) > 2: conversation_data = item[2] # Extract conversation metadata topic = conversation_data.get("topic", "Unknown") steps_taken = conversation_data.get("steps_taken", 0) tool_calls = conversation_data.get("tool_calls", []) message_history = conversation_data.get("message_history", []) # Format full conversation as a string conversation_text = "\n\n".join( [ f"[{msg.get('role', 'unknown')}]\n{msg.get('content', '')}" for msg in message_history ] ) # Check if there's a final article has_final_article = ( "Final Step: ```markdown" in conversation_text ) detailed_table.add_data( topic, steps_taken, item[1], # Score conversation_text, len(tool_calls), has_final_article, ) wandb_metrics["train/detailed_conversations"] = detailed_table return wandb_metrics async def wandb_log(self, wandb_metrics: Optional[Dict] = None): """Log metrics to wandb""" if wandb_metrics is None: wandb_metrics = {} # Add eval metrics for metric_name, metric_value in self.eval_metrics: wandb_metrics[metric_name] = metric_value # Clear metrics for next round self.eval_metrics = [] # Add article quality metrics if available if self.article_quality_metrics: # Calculate averages avg_quality = sum( m["overall_quality"] for m in self.article_quality_metrics ) / len(self.article_quality_metrics) avg_steps = sum( m["steps_taken"] for m in self.article_quality_metrics ) / len(self.article_quality_metrics) wandb_metrics["train/avg_article_quality"] = avg_quality wandb_metrics["train/avg_steps_per_article"] = avg_steps # Convert to [-1, 1] range for consistency with other metrics wandb_metrics["train/article_quality_score"] = avg_quality * 2 - 1 # Add factual accuracy metrics if available if any("accuracy_score" in m for m in self.article_quality_metrics): # Calculate average accuracy metrics accuracy_metrics = [ m for m in self.article_quality_metrics if "accuracy_score" in m ] if accuracy_metrics: # Calculate raw accuracy statistics avg_accuracy_score = sum( m["accuracy_score"] for m in accuracy_metrics ) / len(accuracy_metrics) avg_pct_correct = sum( m.get("pct_correct", 0) for m in accuracy_metrics ) / len(accuracy_metrics) avg_pct_incorrect = sum( m.get("pct_incorrect", 0) for m in accuracy_metrics ) / len(accuracy_metrics) avg_pct_unknown = sum( m.get("pct_unknown", 0) for m in accuracy_metrics ) / len(accuracy_metrics) # Log raw factual accuracy metrics wandb_metrics["train/avg_factual_accuracy"] = avg_accuracy_score wandb_metrics["train/avg_pct_correct"] = avg_pct_correct wandb_metrics["train/avg_pct_incorrect"] = avg_pct_incorrect wandb_metrics["train/avg_pct_unknown"] = avg_pct_unknown # Calculate combined scores combined_scores = [] for m in accuracy_metrics: # Convert quality score from [0,1] to [-1,1] quality_score_scaled = m["overall_quality"] * 2 - 1 # Take average of quality and accuracy scores combined_score = ( quality_score_scaled + m["accuracy_score"] ) / 2 combined_scores.append(combined_score) # Calculate average combined score if combined_scores: avg_combined_score = sum(combined_scores) / len(combined_scores) wandb_metrics["train/avg_combined_score"] = avg_combined_score # Add a summary metric that includes both article quality and factual accuracy # This provides a comprehensive score for overall article quality including factual accuracy wandb_metrics["train/overall_article_score"] = ( avg_combined_score ) # Create a table of article metrics if wandb.run is not None: # Add factual accuracy columns if available columns = [ "topic", "steps", "overall_quality", "structure", "comprehensiveness", "fact_usage", ] # Check if we have factual accuracy metrics if any("accuracy_score" in m for m in self.article_quality_metrics): columns.extend( [ "factual_accuracy", "pct_correct", "pct_incorrect", "pct_unknown", "combined_score", ] ) table = wandb.Table(columns=columns) for metric in self.article_quality_metrics: row_data = [ metric["topic"], metric["steps_taken"], metric["overall_quality"], metric["structure_score"], metric["comprehensiveness_score"], metric["fact_usage_score"], ] # Add factual accuracy metrics if available if "accuracy_score" in metric: # Calculate combined score quality_score_scaled = metric["overall_quality"] * 2 - 1 combined_score = ( quality_score_scaled + metric["accuracy_score"] ) / 2 row_data.extend( [ metric.get("accuracy_score", 0), metric.get("pct_correct", 0), metric.get("pct_incorrect", 0), metric.get("pct_unknown", 0), combined_score, ] ) table.add_data(*row_data) wandb_metrics["train/article_quality"] = table # Clear for next round self.article_quality_metrics = [] await super().wandb_log(wandb_metrics) @classmethod def cli(cls): """Command-line interface entry point""" super().cli() if __name__ == "__main__": WikipediaArticleCreatorEnv.cli()