diff --git a/environments/answer_format_environment/README.md b/environments/answer_format_environment/README.md new file mode 100644 index 00000000..b592adfb --- /dev/null +++ b/environments/answer_format_environment/README.md @@ -0,0 +1,473 @@ +# Answer Format Environment + +A comprehensive training environment for teaching language models to generate responses in specific structured formats. This environment focuses on **format adherence** rather than answer correctness, using randomized format requirements and corresponding parsers to train models on structured response generation. + +## 🎯 Overview + +The Answer Format Environment trains models to: +- Generate responses in 150+ different structured formats +- Follow strict thinking tag discipline (``) +- Parse and validate format compliance +- Handle multiple dataset types with appropriate format selection +- Maintain equivalent training ratios across formats (optional) + +**Key Philosophy**: This environment scores based on format compliance (1.0 for correct format, 0.0 for incorrect), not answer accuracy. It teaches models to follow formatting instructions precisely. 
+ +## ✨ Key Features + +### 🔄 **Randomized Format Selection** +- 150+ supported answer formats across multiple categories +- Weighted format selection (70% simple, 30% complex) +- Dataset type-aware format filtering (generic, math_only, code_only) +- Dynamic compositor system for complex structured responses + +### 🧠 **Thinking Tag Validation** +- Enforces exactly one `` section per response +- All reasoning must be contained within thinking tags +- Strict validation prevents multiple thinking sections +- Answer must appear after `` in specified format + +### 📊 **Comprehensive Data Management** +- Multi-dataset support with automatic shuffling +- Configurable train/eval splits +- Extensive data dumping with group-level statistics +- Failed rollout tracking and analysis +- WandB integration with detailed metrics + +### ⚖️ **Equivalent Ratio Enforcement** +- Optional system to ensure balanced training across formats +- Pauses formats after reaching success threshold +- Prevents format bias in training data +- Comprehensive monitoring and status reporting + +### 🔍 **Advanced Monitoring** +- Format success rate tracking +- Group-level performance statistics +- Failure case analysis and logging +- Real-time format balance monitoring + +## 📋 Supported Format Categories + +### **Basic Structured Data** +```json +{"answer": "content"} // JSON +answer: content // YAML +answer = "content" // TOML +``` + +### **XML/HTML Tags** +```xml +content // XML +Final Answer: content // XML with prefix +content // Output tags +content // Result tags +``` + +### **LaTeX Formats** +```latex +\boxed{content} // Text-friendly boxed +$\boxed{expression}$ // Math-only boxed +\begin{align} expression \end{align} // Math alignment +$\text{answer}$ // Text in math mode +``` + +### **Natural Language** +``` +The answer is: content +Final answer: content +In conclusion: content +Therefore: content +``` + +### **Programming Formats** +```python +print("answer") // Python print 
+console.log("answer") // JavaScript console +# answer // Python comment +return "answer" // Return statement +``` + +### **Complex Multi-Tag Formats** +```xml +... +... +... +... +``` + +### **Dynamic Compositor Formats** +Randomly combines 3-6 components in XML, JSON, YAML, or TOML: +- Analysis components (problem_analysis, requirements_analysis, etc.) +- Reasoning components (logical_reasoning, step_by_step, etc.) +- Planning components (approach, methodology, etc.) +- Technical components (implementation, code_structure, etc.) +- Evaluation components (validation, testing, etc.) +- Output components (final_answer, conclusion, etc.) + +## 🚀 Quick Start + +### Basic Configuration + +```python +from atropos.environments.answer_format_environment import AnswerFormatEnv, AnswerFormatEnvConfig + +# Simple configuration +config = AnswerFormatEnvConfig( + dataset_configs=[ + { + "name": "your_dataset", + "split": "train", + "sample_size": 1000, + "prompt_field": "question", + "answer_field": "answer", + "dataset_type": "generic" + } + ], + debug_logging=True, + dump_rollouts=True, + eval_set_percentage=0.1 +) + +# Initialize environment +env = AnswerFormatEnv(config, server_configs) +``` + +### Multi-Dataset Configuration + +```python +config = AnswerFormatEnvConfig( + dataset_configs=[ + { + "name": "teknium/OpenHermes-2.5", + "split": "train", + "sample_size": 5000, + "prompt_field": "conversations", + "answer_field": "conversations", + "metadata_fields": ["source"], + "dataset_type": "generic" + }, + { + "name": "gsm8k", + "split": "train", + "sample_size": 2000, + "prompt_field": "question", + "answer_field": "answer", + "dataset_type": "math_only" + }, + { + "name": "NousResearch/AcademicMCQA", + "split": "train", + "sample_size": 5000, + "prompt_field": "prompt", + "answer_field": "ground_truth", + "metadata_fields": ["answer", "options"], + "dataset_type": "generic" + } + ], + ensure_equivalent_ratios=True, + format_group_threshold=50, + 
dump_failed_rollouts=True +) +``` + +## ⚙️ Configuration Options + +### **Dataset Configuration** +```python +dataset_configs: List[Dict[str, Any]] = [ + { + "name": str, # Dataset name or HuggingFace path + "split": str, # Dataset split ("train", "test", etc.) + "sample_size": int, # Number of samples to use + "prompt_field": str, # Field containing prompts/questions + "answer_field": str, # Field containing answers + "metadata_fields": List[str], # Additional fields to preserve + "dataset_type": str # "generic", "math_only", or "code_only" + } +] +``` + +### **Core Settings** +```python +debug_logging: bool = True # Enable detailed logging +dump_rollouts: bool = True # Save rollouts to JSONL +dump_failed_rollouts: bool = True # Save failed rollouts separately +rollout_save_score_threshold: float = 0.0 # Minimum score to save rollouts +eval_set_percentage: float = 0.1 # Evaluation set percentage +``` + +### **Format Control** +```python +supported_formats: Optional[List[AnswerFormat]] = None # Filter to specific formats +ensure_equivalent_ratios: bool = False # Enable ratio enforcement +format_group_threshold: int = 50 # Success threshold per format +``` + +## 📊 Dataset Types & Format Selection + +### **Generic Datasets** (`dataset_type: "generic"`) +- **Available Formats**: All basic formats (JSON, XML, natural language, brackets, etc.) +- **Use Case**: General conversation, QA, instruction following, MCQA +- **Examples**: OpenHermes, Alpaca, AcademicMCQA, general chat datasets + +### **Math-Only Datasets** (`dataset_type: "math_only"`) +- **Available Formats**: Generic formats + LaTeX math expressions +- **Additional Formats**: `$\boxed{}$`, `\begin{align}`, `$\text{}$`, etc. 
+- **Use Case**: Mathematical problem solving +- **Examples**: GSM8K, MATH, mathematical reasoning datasets + +### **Code-Only Datasets** (`dataset_type: "code_only"`) +- **Available Formats**: Generic formats + programming-specific formats +- **Additional Formats**: `print()`, `console.log()`, comments, return statements +- **Use Case**: Code generation, programming problems +- **Examples**: HumanEval, MBPP, coding datasets + +## 🎲 Dynamic Compositor System + +The dynamic compositor creates complex structured responses by randomly combining components: + +### **Component Categories** +- **Analysis**: `problem_analysis`, `requirements_analysis`, `context_analysis` +- **Reasoning**: `logical_reasoning`, `step_by_step`, `causal_reasoning` +- **Planning**: `approach`, `methodology`, `strategy` +- **Technical**: `implementation`, `code_structure`, `algorithm` +- **Evaluation**: `validation`, `testing`, `verification` +- **Output**: `final_answer`, `conclusion`, `summary` + +### **Output Formats** +- **XML**: `content` +- **JSON**: `{"component_name": "content"}` +- **YAML**: `component_name: content` +- **TOML**: `component_name = "content"` + +### **Example Dynamic Format** +```xml +Understanding the requirements... +Step by step analysis... +My methodology will be... +Here's the solution... +42 +``` + +## 📈 Monitoring & Analytics + +### **WandB Metrics** +- `train/percent_correct`: Overall format compliance rate +- `train/format_success_rate_{format}`: Per-format success rates +- `train/format_usage_count_{format}`: Usage frequency per format +- `train/equivalent_ratio_paused_formats`: Number of paused formats +- `train/group_success_rate`: Percentage of successful groups +- `train/failed_groups_count`: Number of completely failed groups + +### **Group-Level Statistics** +``` +Format: json_confidence | Group average score: 0.7500 | 12/16 correct (75.0%) +Format: latex_boxed_math | Group average score: 0.0000 | 0/16 correct (0.0%) (All failures!) 
+Format: natural_language_answer | Group average score: 1.0000 | 16/16 correct (100.0%) (Perfect group!) +``` + +### **Data Dumps** +- **Regular Rollouts**: `answer_format_rollouts_{uuid}_{batch}.jsonl` +- **Failed Rollouts**: `answer_format_failed_rollouts_{uuid}_{batch}.jsonl` +- **Metadata**: Format type, scores, conversation history, timestamps +- **Batch Size**: 100 groups per file + +## ⚖️ Equivalent Ratio Enforcement + +Optional system to ensure balanced training across all formats: + +### **How It Works** +1. Tracks successful groups per format +2. Pauses formats that reach threshold (default: 50 successful groups) +3. Continues training other formats until they catch up +4. Resumes paused formats when balance is restored + +### **Configuration** +```python +ensure_equivalent_ratios=True, # Enable the system +format_group_threshold=50, # Success threshold per format +``` + +### **Monitoring** +```python +status = env.get_equivalent_ratio_status() +print(f"Paused formats: {status['paused_formats']}") +print(f"Active formats: {status['active_formats']}") +print(f"Progress: {status['format_progress']}") +``` + +## 🔧 Advanced Usage + +### **Custom Format Filtering** +```python +from atropos.environments.answer_format_environment import AnswerFormat + +config = AnswerFormatEnvConfig( + supported_formats=[ + AnswerFormat.JSON, + AnswerFormat.XML, + AnswerFormat.LATEX_BOXED, + AnswerFormat.NATURAL_LANGUAGE_ANSWER + ] +) +``` + +### **Evaluation Mode** +```python +# Run evaluation on held-out set +eval_score = await env.evaluate() +print(f"Evaluation format compliance: {eval_score}") +``` + +### **Supported Dataset Formats** + +#### **OpenHermes Conversations** +```python +{ + "name": "teknium/OpenHermes-2.5", + "prompt_field": "conversations", + "answer_field": "conversations" +} +``` + +#### **GSM8K Math Problems** +```python +{ + "name": "gsm8k", + "prompt_field": "question", + "answer_field": "answer" +} +``` + +#### **AcademicMCQA Multiple Choice** +```python 
+{ + "name": "NousResearch/AcademicMCQA", + "prompt_field": "prompt", + "answer_field": "ground_truth", + "metadata_fields": ["answer", "options"] +} +``` + +## 📝 Response Format Requirements + +### **Thinking Tags** +- **Required**: Exactly one `` opening and one `` closing tag +- **Content**: All reasoning must be inside thinking tags +- **Placement**: Answer must appear after `` in specified format +- **Validation**: No additional thinking tags allowed after first `` + +### **Example Valid Response** +``` + +Let me analyze this problem step by step. +First, I need to understand what's being asked... +The solution involves calculating... + + +{"answer": "42"} +``` + +### **Example Invalid Response** +``` +Some reasoning +{"answer": "42"} +More reasoning // ❌ Additional thinking tags not allowed +``` + +## 🚨 Common Issues & Solutions + +### **Format Validation Failures** +- **Issue**: Response doesn't match expected format +- **Solution**: Check regex patterns and ensure exact format compliance +- **Debug**: Enable `debug_logging=True` for detailed validation info + +### **Thinking Tag Violations** +- **Issue**: Multiple thinking sections or missing tags +- **Solution**: Ensure exactly one `` section per response +- **Debug**: Check thinking tag validation logs + +### **Dataset Loading Errors** +- **Issue**: Dataset not found or field missing +- **Solution**: Verify dataset name, split, and field names +- **Debug**: Check dataset configuration and field mappings + +### **Memory Issues with Large Datasets** +- **Issue**: Out of memory with large sample sizes +- **Solution**: Reduce `sample_size` or use streaming datasets +- **Debug**: Monitor memory usage and adjust batch sizes + +## 🔍 Debugging & Development + +### **Enable Debug Logging** +```python +config = AnswerFormatEnvConfig(debug_logging=True) +``` + +### **Check Format Status** +```python +# Get equivalent ratio status +status = env.get_equivalent_ratio_status() + +# Check format success rates +for 
format_name, success_rate in env.format_success_rates.items(): + print(f"{format_name}: {success_rate:.2%}") +``` + +### **Analyze Failed Rollouts** +```python +# Failed rollouts are automatically saved when dump_failed_rollouts=True +# Check the datadumps directory for analysis files +``` + +## 📚 Technical Details + +### **Scoring System** +- **Format Compliance**: 1.0 for correct format, 0.0 for incorrect +- **No Answer Accuracy**: Content correctness is not evaluated +- **Group Scoring**: Average of all rollouts in a group +- **Success Threshold**: Groups with >0.0 average are considered successful + +### **Format Validation** +- **Regex Patterns**: Each format has specific validation patterns +- **Exact Matching**: Strict compliance required for scoring +- **Content Extraction**: Validated content is extracted for consistency + +### **Dynamic Format Generation** +- **Component Selection**: Random selection of 3-6 components +- **Format Templates**: XML, JSON, YAML, TOML output formats +- **Validation Storage**: Components stored for precise validation + +### **Special Dataset Handling** +- **GSM8K**: Extracts numerical answers from `####` separated format +- **AcademicMCQA**: Uses ground truth letters (A, B, C, D) as answers +- **OpenHermes**: Extracts from conversation format with role-based parsing + +## 🤝 Contributing + +### **Adding New Formats** +1. Add enum value to `AnswerFormat` +2. Add system prompt instruction +3. Add validation regex pattern +4. Add content extraction logic +5. Test with sample responses + +### **Adding New Dataset Types** +1. Define dataset type in configuration +2. Add format filtering logic +3. Update format selection methods +4. Test with representative datasets + +### **Adding New Datasets** +1. Add special handling in `setup()` method +2. Define field mappings in configuration +3. Add metadata extraction logic +4. Test dataset loading and processing + +## 📄 License + +This environment is part of the Atropos training framework. 
See the main repository for license information. + +--- + +**Need Help?** Check the debug logs, enable verbose logging, or review the comprehensive monitoring metrics for troubleshooting guidance. \ No newline at end of file diff --git a/environments/answer_format_environment/answer_format_environment.py b/environments/answer_format_environment/answer_format_environment.py new file mode 100644 index 00000000..e78503c3 --- /dev/null +++ b/environments/answer_format_environment/answer_format_environment.py @@ -0,0 +1,4258 @@ +""" +Answer Format Environment + +This environment trains models to generate responses in specific formats. +It focuses on format adherence rather than answer correctness, using randomized +format requirements and corresponding parsers. + +Key Features: +- Randomized answer format selection from 150+ supported formats +- Strict thinking tag validation (exactly one section) +- Format-specific parsers for validation +- Support for multiple input datasets that get shuffled together +- Dataset type-aware format selection (generic, math_only, code_only) +- Dynamic compositor system for complex structured responses +- Comprehensive data dumping and logging following environment conventions +- Format compliance scoring (1.0 for correct format, 0.0 for incorrect) +- Format success rate tracking and monitoring +- Weighted format selection for balanced training +- Optional equivalent ratio enforcement (stops generating formats after N successful groups) + +Supported Answer Formats: +- Basic structured data: JSON, YAML, TOML (with confidence scores) +- XML/HTML tags: , , , nested variants +- LaTeX: \boxed{}, math mode expressions, align blocks, matrices +- Markdown: code blocks, bold, italic, headers, quotes +- Bracket notation: [], [[]], {}, (), <> +- Natural language: "The answer is:", "Final answer:", 15+ variants +- Programming formats: print(), console.log(), comments, docstrings +- Custom delimiters: |answer|, #answer#, _answer_, ~answer~ +- Complex 
multi-tag formats: coding workflows, math derivations, research formats +- Dynamic compositor formats: randomly combined XML/JSON/YAML/TOML structures +- Special formats: TextArena [A], arrows =>, colons Answer: + +Dataset Type Support: +- generic: All basic formats (JSON, XML, natural language, etc.) +- math_only: Generic formats + LaTeX math expressions, math workflows +- code_only: Generic formats + programming-specific formats, code workflows + +Response Format Requirements: +- Must use exactly ONE opening tag and ONE closing tag +- All reasoning must be inside the thinking tags +- Answer must be in the specified format after +- No additional tags allowed after the first closing tag + +Example Usage: + env_config = AnswerFormatEnvConfig( + dataset_configs=[ + { + "name": "your_dataset", + "split": "train", + "sample_size": 1000, + "prompt_field": "question", + "answer_field": "answer", + "dataset_type": "generic" # or "math_only", "code_only" + } + ], + supported_formats=[AnswerFormat.JSON, AnswerFormat.XML], # Optional filter + eval_set_percentage=0.1, + debug_logging=True, + ensure_equivalent_ratios=True, # Enable equivalent ratio enforcement + format_group_threshold=50, # Stop each format after 50 successful groups + ) +""" + +import json +import logging +import os +import random +import re +import uuid +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset +from pydantic import Field +from tqdm.asyncio import tqdm_asyncio + +import wandb +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + EvalHandlingEnum, + Item, + ScoredDataGroup, +) +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + + +class AnswerFormat(Enum): + """Enumeration of supported answer formats.""" + + # Basic structured data formats (answer only) + JSON = "json" # {"answer": "content"} + JSON_ARRAY = "json_array" # ["answer"] + JSON_SIMPLE = "json_simple" # "answer" + YAML = 
"yaml" # answer: content + YAML_LIST = "yaml_list" # - content + TOML = "toml" # answer = "content" + TOML_SECTION = "toml_section" # [response] \n answer = "content" + + # Structured data with confidence scores + JSON_CONFIDENCE = "json_confidence" # {"answer": "content", "confidence": 0.9} + YAML_CONFIDENCE = "yaml_confidence" # answer: content \n confidence: 0.9 + TOML_CONFIDENCE = "toml_confidence" # answer = "content" \n confidence = 0.9 + + # XML/HTML tag variations (XML now uses answer tags) + XML = "xml" # content + XML_FINAL_ANSWER = "xml_final_answer" # Final Answer: content + OUTPUT_TAGS = "output_tags" # + RESULT_TAGS = "result_tags" # + RESPONSE_TAGS = "response_tags" # + FINAL_ANSWER_TAGS = "final_answer_tags" # + SOLUTION_TAGS = "solution_tags" # + CONCLUSION_TAGS = "conclusion_tags" # + REPLY_TAGS = "reply_tags" # + NESTED_RESPONSE_ANSWER = "nested_response_answer" # explanation\ncontent + NESTED_SOLUTION_ANSWER = "nested_solution_answer" # explanation\ncontent + NESTED_OUTPUT_RESULT = ( + "nested_output_result" # explanation\ncontent + ) + NESTED_ANALYSIS_CONCLUSION = "nested_analysis_conclusion" # explanation\ncontent # noqa: E501 + NESTED_REASONING_ANSWER = "nested_reasoning_answer" # explanation\ncontent + + # LaTeX formats (text-friendly) + LATEX_BOXED = "latex_boxed" # \boxed{} - can contain text + LATEX_TEXTBF = "latex_textbf" # \textbf{} - bold text + LATEX_TEXTIT = "latex_textit" # \textit{} - italic text + LATEX_UNDERLINE = "latex_underline" # \underline{} - underlined text + + # LaTeX formats (math-only - for math datasets) + LATEX_BOXED_MATH = "latex_boxed_math" # $\boxed{}$ - math mode + LATEX_ALIGN = "latex_align" # \begin{align} \end{align} - math mode + LATEX_EQUATION = "latex_equation" # \begin{equation} \end{equation} - math mode + LATEX_DISPLAYMATH = "latex_displaymath" # \[ \] - math mode + LATEX_INLINE_MATH = "latex_inline_math" # $ $ - math mode + LATEX_TEXT_MATH = "latex_text_math" # $\text{answer}$ - text in math mode + 
LATEX_MATHRM = "latex_mathrm" # $\mathrm{answer}$ - roman text in math + LATEX_THEREFORE = "latex_therefore" # $\therefore answer$ - therefore symbol + LATEX_IMPLIES = "latex_implies" # $\implies answer$ - implies symbol + LATEX_EQUIV = "latex_equiv" # $answer \equiv value$ - equivalence + LATEX_MATRIX = "latex_matrix" # \begin{matrix} answer \end{matrix} + LATEX_PMATRIX = "latex_pmatrix" # \begin{pmatrix} answer \end{pmatrix} + + # Markdown formats + MARKDOWN_CODE = "markdown_code" # ``` + MARKDOWN_BOLD = "markdown_bold" # **text** + MARKDOWN_ITALIC = "markdown_italic" # *text* + MARKDOWN_HEADER = "markdown_header" # ## Answer + MARKDOWN_QUOTE = "markdown_quote" # > answer + + # Bracket and delimiter formats + SQUARE_BRACKETS = "square_brackets" # [answer] + DOUBLE_SQUARE_BRACKETS = "double_square_brackets" # [[answer]] + CURLY_BRACES = "curly_braces" # {answer} + PARENTHESES = "parentheses" # (answer) + ANGLE_BRACKETS = "angle_brackets" # + + # Natural language patterns + NATURAL_LANGUAGE_ANSWER = "natural_language_answer" # "The answer is:" + NATURAL_LANGUAGE_FINAL = "natural_language_final" # "Final answer:" + NATURAL_LANGUAGE_CONCLUSION = "natural_language_conclusion" # "In conclusion:" + NATURAL_LANGUAGE_THEREFORE = "natural_language_therefore" # "Therefore:" + NATURAL_LANGUAGE_RESULT = "natural_language_result" # "The result is:" + + # Additional natural language patterns + NATURAL_LANGUAGE_BEST = "natural_language_best" # "The best answer is:" + NATURAL_LANGUAGE_MY_FINAL = "natural_language_my_final" # "My final answer is:" + NATURAL_LANGUAGE_CORRECT = "natural_language_correct" # "The correct answer is:" + NATURAL_LANGUAGE_SOLUTION = "natural_language_solution" # "The solution is:" + NATURAL_LANGUAGE_RESPONSE = "natural_language_response" # "My response is:" + NATURAL_LANGUAGE_ULTIMATELY = "natural_language_ultimately" # "Ultimately:" + NATURAL_LANGUAGE_THUS = "natural_language_thus" # "Thus:" + NATURAL_LANGUAGE_HENCE = "natural_language_hence" # "Hence:" 
+ NATURAL_LANGUAGE_CONSEQUENTLY = "natural_language_consequently" # "Consequently:" + NATURAL_LANGUAGE_TO_SUMMARIZE = "natural_language_to_summarize" # "To summarize:" + NATURAL_LANGUAGE_IN_SUMMARY = "natural_language_in_summary" # "In summary:" + NATURAL_LANGUAGE_OVERALL = "natural_language_overall" # "Overall:" + NATURAL_LANGUAGE_FINAL_VERDICT = ( + "natural_language_final_verdict" # "Final verdict:" + ) + NATURAL_LANGUAGE_BOTTOM_LINE = "natural_language_bottom_line" # "Bottom line:" + NATURAL_LANGUAGE_KEY_POINT = "natural_language_key_point" # "The key point is:" + + # Special formats + TEXTARENA_FORMAT = "textarena_format" # [A], [B], [C], [D] + COLON_FORMAT = "colon_format" # Answer: content + ARROW_FORMAT = "arrow_format" # => answer or -> answer + + # HTML formats + HTML_CODE = "html_code" # + HTML_PRE = "html_pre" #

+    HTML_SPAN = "html_span"  # <span>answer</span>
+    HTML_DIV = "html_div"  # <div>answer</div>
+ HTML_P = "html_p" #

+ + # Multiple structured tags + MULTIPLE_TAGS = ( + "multiple_tags" # + ) + + # Complex multi-tag formats for specific domains + COMPLEX_CODING_FORMAT = "complex_coding_format" + COMPLEX_CODING_SIMPLE = "complex_coding_simple" + COMPLEX_CODING_MINIMAL = "complex_coding_minimal" + COMPLEX_MATH_FORMAT = "complex_math_format" + COMPLEX_MATH_SIMPLE = "complex_math_simple" + COMPLEX_GENERAL_FORMAT = "complex_general_format" + COMPLEX_GENERAL_SIMPLE = ( + "complex_general_simple" + ) + COMPLEX_RESEARCH_FORMAT = "complex_research_format" + + # Advanced scratchpad and classification formats + COMPLEX_SCRATCHPAD_FULL = "complex_scratchpad_full" + COMPLEX_SCRATCHPAD_SIMPLE = "complex_scratchpad_simple" + COMPLEX_CLASSIFICATION_FORMAT = ( + "complex_classification_format" + ) + COMPLEX_ANALYSIS_WITH_ANSWER = "complex_analysis_with_answer" + COMPLEX_EVALUATION_FORMAT = "complex_evaluation_format" + + # Dynamic compositor formats - randomly combine components in different output formats + # COMMENTED OUT: These formats are proving too difficult to parse reliably + # DYNAMIC_SCRATCHPAD_XML = "dynamic_scratchpad_xml" # Random XML scratchpad components + # DYNAMIC_SCRATCHPAD_JSON = "dynamic_scratchpad_json" # Random JSON scratchpad components + # DYNAMIC_SCRATCHPAD_YAML = "dynamic_scratchpad_yaml" # Random YAML scratchpad components + # DYNAMIC_SCRATCHPAD_TOML = "dynamic_scratchpad_toml" # Random TOML scratchpad components + # DYNAMIC_ANALYSIS_XML = "dynamic_analysis_xml" # Random XML analysis components + # DYNAMIC_ANALYSIS_JSON = "dynamic_analysis_json" # Random JSON analysis components + # DYNAMIC_WORKFLOW_XML = "dynamic_workflow_xml" # Random XML workflow components + # DYNAMIC_WORKFLOW_JSON = "dynamic_workflow_json" # Random JSON workflow components + + # Custom delimiters + PIPE_DELIMITED = "pipe_delimited" # |answer| + HASH_DELIMITED = "hash_delimited" # #answer# + UNDERSCORE_DELIMITED = "underscore_delimited" # _answer_ + TILDE_DELIMITED = "tilde_delimited" # ~answer~ + + 
# Programming-style formats + FUNCTION_CALL = "function_call" # answer() + VARIABLE_ASSIGNMENT = "variable_assignment" # answer = "content" + RETURN_STATEMENT = "return_statement" # return "answer" + + # Additional easy-to-parse formats + EQUALS_FORMAT = "equals_format" # = answer + DASH_FORMAT = "dash_format" # - answer + PLUS_FORMAT = "plus_format" # + answer + STAR_FORMAT = "star_format" # * answer + PERCENT_FORMAT = "percent_format" # % answer + AMPERSAND_FORMAT = "ampersand_format" # & answer + AT_FORMAT = "at_format" # @ answer + EXCLAMATION_FORMAT = "exclamation_format" # ! answer + QUESTION_FORMAT = "question_format" # ? answer (for when answer is a question) + SEMICOLON_FORMAT = "semicolon_format" # ; answer + DOUBLE_COLON_FORMAT = "double_colon_format" # :: answer + TRIPLE_DASH_FORMAT = "triple_dash_format" # --- answer + DOUBLE_ARROW_FORMAT = "double_arrow_format" # >> answer + TRIPLE_ARROW_FORMAT = "triple_arrow_format" # >>> answer + BACKTICK_FORMAT = "backtick_format" # `answer` + DOUBLE_BACKTICK_FORMAT = "double_backtick_format" # ``answer`` + QUOTE_FORMAT = "quote_format" # "answer" + SINGLE_QUOTE_FORMAT = "single_quote_format" # 'answer' + TRIPLE_QUOTE_FORMAT = "triple_quote_format" # """answer""" + ANSWER_IS_FORMAT = "answer_is_format" # ANSWER IS: content + SOLUTION_IS_FORMAT = "solution_is_format" # SOLUTION IS: content + RESULT_IS_FORMAT = "result_is_format" # RESULT IS: content + OUTPUT_IS_FORMAT = "output_is_format" # OUTPUT IS: content + + # Code-specific formats (for code datasets) + PYTHON_PRINT = "python_print" # print("answer") + JAVASCRIPT_CONSOLE = "javascript_console" # console.log("answer") + PYTHON_COMMENT = "python_comment" # # answer + JAVASCRIPT_COMMENT = "javascript_comment" # // answer + C_COMMENT = "c_comment" # /* answer */ + SHELL_ECHO = "shell_echo" # echo "answer" + SHELL_OUTPUT = "shell_output" # $ answer + PYTHON_DOCSTRING = "python_docstring" # """answer""" + INI_FORMAT = "ini_format" # [section]\nanswer=value + 
ENV_FORMAT = "env_format" # ANSWER=value + + +class AnswerFormatEnvConfig(BaseEnvConfig): + """Custom config class for AnswerFormatEnv with additional parameters.""" + + dataset_configs: List[Dict[str, Any]] = Field( + default=[ + { + "name": "teknium/OpenHermes-2.5", + "split": "train", + "sample_size": 1000, + "prompt_field": "conversations", + "answer_field": "conversations", + "metadata_fields": ["source"], + "dataset_type": "generic", # Options: "generic", "math_only", "code_only" + }, + { + "name": "gsm8k", + "split": "train", + "sample_size": 2000, + "prompt_field": "question", + "answer_field": "answer", + "metadata_fields": [], + "dataset_type": "math_only", # GSM8K is a math dataset + }, + ], + description="List of dataset configurations to load and combine. Each can specify dataset_type: 'generic', 'math_only', or 'code_only'", # noqa: E501 + ) + + debug_logging: bool = Field( + default=True, description="Enable detailed debug logging" + ) + + suppress_base_env_logs: bool = Field( + default=True, + description="Suppress verbose base environment logs (like status dict updates)", + ) + + dump_rollouts: bool = Field( + default=False, + description="Whether to dump rollouts to JSONL files for analysis", + ) + + dump_failed_rollouts: bool = Field( + default=False, + description="Whether to dump failed rollouts (all 0 scores) to JSONL files for debugging", + ) + + rollout_save_score_threshold: float = Field( + default=0.0, + description="Minimum score threshold for saving rollouts (0.0 saves all)", + ge=0.0, + le=1.0, + ) + + eval_set_percentage: float = Field( + default=0.1, + description="Percentage of the dataset to use for evaluation (e.g., 0.1 for 10%)", + ge=0.0, + le=1.0, + ) + + supported_formats: Optional[List[AnswerFormat]] = Field( + default=None, + description="Optional list of AnswerFormat enums to use. 
If None, all formats are used.", + ) + + ensure_equivalent_ratios: bool = Field( + default=False, + description="Ensure equivalent ratios of successful groups across all formats by pausing formats that reach the threshold", # noqa: E501 + ) + + format_group_threshold: int = Field( + default=50, + description="Number of successful groups per format before pausing that format (only used when ensure_equivalent_ratios=True)", # noqa: E501 + ge=1, + le=1000, + ) + + seed: int = Field( + default_factory=lambda: random.randint(20, 9999), + description="Random seed used for dataset shuffling and other random operations", + ge=1, + le=99999, + ) + + +class AnswerFormatEnv(BaseEnv): + """Environment for training models on answer format adherence.""" + + env_config_cls = AnswerFormatEnvConfig + + def __init__( + self, + config: AnswerFormatEnvConfig, + server_configs: List[APIServerConfig], + slurm=True, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + + # Set up debug logging + self.debug_logging = getattr(config, "debug_logging", True) + if self.debug_logging: + self.logger = logging.getLogger(f"{self.__class__.__name__}") + self.logger.setLevel(logging.DEBUG) + # Prevent propagation to avoid duplicate messages + self.logger.propagate = False + if not self.logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.info("Debug logging enabled for AnswerFormatEnv") + else: + self.logger = logging.getLogger(f"{self.__class__.__name__}") + self.logger.addHandler(logging.NullHandler()) + self.logger.propagate = False + + # Suppress base environment logs if requested + if getattr(config, "suppress_base_env_logs", True): + # Set the base environment logger to WARNING level to suppress INFO logs + base_logger = logging.getLogger("atroposlib.envs.base") + 
base_logger.setLevel(logging.WARNING) + + self.percent_correct_buffer = list() + self.eval_metrics = list() + self.rollouts_for_wandb = [] + self.dataset_items: List[Dict[str, Any]] = [] + + # Track format success rates + self.format_success_counts = {} + self.format_total_counts = {} + + # Track successful groups per format for equivalent ratio enforcement + self.format_successful_groups = {} # format_name -> count of successful groups + self.format_group_threshold = config.format_group_threshold + self.ensure_equivalent_ratios = config.ensure_equivalent_ratios + + # For saving failed rollouts (all 0 scores) for debugging + self.failed_rollouts_to_save_buffer: List[Dict[str, Any]] = [] + self.failed_save_file_batch_num = 0 + + # Group-level statistics tracking + self.group_statistics = { + "total_groups": 0, + "successful_groups": 0, + "failed_groups": 0, + "average_scores": [], + "format_distribution": {}, + } + + # Define format categories by dataset type + self.math_only_formats = { + AnswerFormat.LATEX_BOXED_MATH, + AnswerFormat.LATEX_ALIGN, + AnswerFormat.LATEX_EQUATION, + AnswerFormat.LATEX_DISPLAYMATH, + AnswerFormat.LATEX_INLINE_MATH, + AnswerFormat.LATEX_TEXT_MATH, + AnswerFormat.LATEX_MATHRM, + AnswerFormat.LATEX_THEREFORE, + AnswerFormat.LATEX_IMPLIES, + AnswerFormat.LATEX_EQUIV, + AnswerFormat.LATEX_MATRIX, + AnswerFormat.LATEX_PMATRIX, + AnswerFormat.COMPLEX_MATH_FORMAT, + AnswerFormat.COMPLEX_MATH_SIMPLE, + } + + self.code_only_formats = { + AnswerFormat.SHELL_OUTPUT, + AnswerFormat.COMPLEX_CODING_FORMAT, + AnswerFormat.COMPLEX_CODING_SIMPLE, + AnswerFormat.COMPLEX_CODING_MINIMAL, + } + + # Generic formats work with any content type + all_formats = set(AnswerFormat) + self.generic_formats = ( + all_formats - self.math_only_formats - self.code_only_formats + ) + + # Validate configuration + self._validate_config(config) + + # Store seed for dataset shuffling and other random operations + self.seed = config.seed + + # Store base supported formats (will 
be filtered per item based on dataset type)
        if config.supported_formats and len(config.supported_formats) > 0:
            self.base_supported_formats = config.supported_formats
            if self.debug_logging:
                self.logger.info(
                    f"Using configured base formats: {[f.value for f in self.base_supported_formats]}"
                )
        else:
            # No explicit list configured: every AnswerFormat is eligible.
            self.base_supported_formats = list(AnswerFormat)
            if self.debug_logging:
                self.logger.info(
                    f"Using all formats as base: {len(self.base_supported_formats)} formats"
                )

        # Data dumping setup
        self.run_uuid = str(uuid.uuid4())
        self.rollouts_to_save_buffer: List[Dict[str, Any]] = []
        self.processed_item_count = 0

        # Dynamic format component storage
        # Dynamic formats are commented out - no storage needed

        # Create datadumps directory (path only; directory creation happens elsewhere)
        self.datadumps_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "datadumps"
        )
        self.save_file_batch_num = 0

        if self.debug_logging:
            self.logger.info(
                f"Data dumping {'enabled' if config.dump_rollouts else 'disabled'}"
            )
            if config.dump_rollouts:
                self.logger.info(f"Rollouts will be saved to: {self.datadumps_dir}")

    def _validate_config(self, config: AnswerFormatEnvConfig) -> None:
        """Validate the configuration for common issues.

        Checks that at least one dataset config exists and that each has a
        'name' field, normalizes unknown dataset_type values to 'generic'
        (with a warning), and validates eval_set_percentage, the
        supported_formats entries, and format_group_threshold.

        Raises:
            ValueError: on any hard configuration error.
        """
        # Check dataset configurations
        if not config.dataset_configs:
            raise ValueError("At least one dataset configuration is required")

        for i, dataset_config in enumerate(config.dataset_configs):
            if "name" not in dataset_config:
                raise ValueError(f"Dataset config {i} missing 'name' field")

            dataset_type = dataset_config.get("dataset_type", "generic")
            if dataset_type not in ["generic", "math_only", "code_only"]:
                # Soft failure: warn and coerce in place rather than abort.
                self.logger.warning(
                    f"Unknown dataset_type '{dataset_type}' in config {i}, using 'generic'"
                )
                dataset_config["dataset_type"] = "generic"

        # Check eval percentage
        if not (0.0 <= config.eval_set_percentage <= 1.0):
            raise ValueError(
                f"eval_set_percentage must be between 0.0 and 1.0, got {config.eval_set_percentage}"
            )

        # Check supported formats
        if config.supported_formats:
            invalid_formats = [
                f for f in config.supported_formats if not isinstance(f, AnswerFormat)
            ]
            if invalid_formats:
                raise ValueError(
                    f"Invalid formats in supported_formats: {invalid_formats}"
                )

        # Check equivalent ratio parameters
        if config.format_group_threshold < 1:
            raise ValueError(
                f"format_group_threshold must be at least 1, got {config.format_group_threshold}"
            )

        if self.debug_logging:
            self.logger.info("Configuration validation passed")
            if config.ensure_equivalent_ratios:
                self.logger.info(
                    f"Equivalent ratio enforcement enabled with threshold of {config.format_group_threshold} successful groups per format"  # noqa: E501
                )

    def _should_exclude_format_for_balance(self, answer_format: AnswerFormat) -> bool:
        """Check if a format should be excluded due to equivalent ratio enforcement.

        A format is excluded once it has accumulated at least
        format_group_threshold successful groups while the minimum count
        across tracked formats is still below the threshold. Always returns
        False when enforcement is disabled.
        """
        if not self.ensure_equivalent_ratios:
            return False

        format_name = answer_format.value
        successful_groups = self.format_successful_groups.get(format_name, 0)

        # If this format has reached the threshold, check if others are still below
        if successful_groups >= self.format_group_threshold:
            # Check if there are any formats still below the threshold
            min_successful_groups = (
                min(self.format_successful_groups.values())
                if self.format_successful_groups
                else 0
            )
            if min_successful_groups < self.format_group_threshold:
                if self.debug_logging:
                    self.logger.debug(
                        f"Excluding format {format_name} (has {successful_groups}, min is {min_successful_groups})"
                    )
                return True

        return False

    def _get_balanced_formats(
        self, available_formats: List[AnswerFormat]
    ) -> List[AnswerFormat]:
        """Filter formats based on equivalent ratio requirements.

        Returns available_formats unchanged when enforcement is disabled, or
        when filtering would exclude every format (fail-open fallback).
        """
        if not self.ensure_equivalent_ratios:
            return available_formats

        # Filter out formats that have reached the threshold while others haven't
        balanced_formats = [
            f
            for f in available_formats
            if not self._should_exclude_format_for_balance(f)
        ]

        # If all formats are excluded (shouldn't happen in practice), allow all
        if not balanced_formats:
            if self.debug_logging:
                self.logger.warning(
                    "All formats excluded by equivalent ratio enforcement, allowing all formats"
                )
            return available_formats

        if self.debug_logging and len(balanced_formats) != len(available_formats):
            excluded_count = len(available_formats) - len(balanced_formats)
            self.logger.debug(
                f"Equivalent ratio enforcement: excluded {excluded_count} formats that reached threshold"
            )

        return balanced_formats

    def get_equivalent_ratio_status(self) -> Dict[str, Any]:
        """Get current status of equivalent ratio enforcement for debugging/monitoring.

        Returns:
            {"enabled": False} when enforcement is off; otherwise a dict with
            threshold/progress counters, min/max/avg successful-group counts,
            and the top/bottom formats by successful groups (only when at
            least one format has recorded data).
        """
        if not self.ensure_equivalent_ratios:
            return {"enabled": False}

        total_formats = len(self.base_supported_formats)
        formats_at_threshold = sum(
            1
            for count in self.format_successful_groups.values()
            if count >= self.format_group_threshold
        )

        status = {
            "enabled": True,
            "threshold": self.format_group_threshold,
            "total_formats": total_formats,
            "formats_with_data": len(self.format_successful_groups),
            "formats_at_threshold": formats_at_threshold,
            "completion_percentage": (
                (formats_at_threshold / total_formats) * 100 if total_formats > 0 else 0
            ),
        }

        if self.format_successful_groups:
            successful_counts = list(self.format_successful_groups.values())
            status.update(
                {
                    "min_successful_groups": min(successful_counts),
                    "max_successful_groups": max(successful_counts),
                    "avg_successful_groups": sum(successful_counts)
                    / len(successful_counts),
                }
            )

            # Top 5 formats by successful groups
            sorted_formats = sorted(
                self.format_successful_groups.items(), key=lambda x: x[1], reverse=True
            )
            status["top_formats"] = sorted_formats[:5]

            # Bottom 5 formats by successful groups (that have any data)
            status["bottom_formats"] = (
                sorted_formats[-5:] if len(sorted_formats) >= 5 else sorted_formats
            )

        return status

    def _generate_system_prompt(self, answer_format: AnswerFormat) -> str:
        """Generate a system prompt for the specified answer format."""
        # NOTE(review): the thinking-tag markup appears to have been stripped
        # from these strings during extraction — confirm the original
        # "<think>"-style tags against the source before relying on this text.
        base_prompt = (
            "You are an AI assistant that provides helpful responses. "
            "You may use extremely long chains of thought to deeply consider the "
            "problem and deliberate with yourself via systematic reasoning processes "
            "to help come to a correct solution prior to answering. "
            "You should enclose your thoughts and internal monologue inside tags. "
            "CRITICAL FORMAT REQUIREMENT: After your thinking, you must provide your answer in the EXACT format specified below. "  # noqa: E501
            "Use the specified format EXACTLY ONCE and ONLY ONCE. Do not use the format multiple times or in any other way."  # noqa: E501
        )

        format_instructions = {
            # Basic structured data formats (answer only)
            AnswerFormat.JSON: (
                "CRITICAL: After your thinking, provide your answer as a JSON object with an 'answer' field EXACTLY ONCE. "  # noqa: E501
                "Use this format only once in your entire response. "
                'Example: {"answer": "your response here"}'
            ),
            AnswerFormat.JSON_ARRAY: (
                "CRITICAL: After your thinking, provide your answer as a JSON array with one element EXACTLY ONCE. "  # noqa: E501
                "Use this format only once in your entire response. "
                'Example: ["your response here"]'
            ),
            AnswerFormat.JSON_SIMPLE: (
                "CRITICAL: After your thinking, provide your answer as a simple JSON string EXACTLY ONCE. "  # noqa: E501
                "Use this format only once in your entire response. "
                'Example: "your response here"'
            ),
            AnswerFormat.YAML: (
                "CRITICAL: After your thinking, provide your answer in YAML format with an 'answer' field EXACTLY ONCE. "  # noqa: E501
                "Use this format only once in your entire response. "
                "Example:\nanswer: your response here"
            ),
            AnswerFormat.YAML_LIST: (
                "CRITICAL: After your thinking, provide your answer as a YAML list with one item EXACTLY ONCE. 
" # noqa: E501 + "Use this format only once in your entire response. " + "Example:\n- your response here" + ), + AnswerFormat.TOML: ( + "CRITICAL: After your thinking, provide your answer in TOML format with an 'answer' field EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " + 'Example:\nanswer = "your response here"' + ), + AnswerFormat.TOML_SECTION: ( + "CRITICAL: After your thinking, provide your answer in TOML format with a section EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " + 'Example:\n[response]\nanswer = "your response here"' + ), + # Structured data with confidence scores + AnswerFormat.JSON_CONFIDENCE: ( + "CRITICAL: After your thinking, provide your answer as JSON with answer and confidence fields EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " + 'Example: {"answer": "your response here", "confidence": 0.9}' + ), + AnswerFormat.YAML_CONFIDENCE: ( + "CRITICAL: After your thinking, provide your answer in YAML with answer and confidence fields EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example:\nanswer: your response here\nconfidence: 0.9" + ), + AnswerFormat.TOML_CONFIDENCE: ( + "CRITICAL: After your thinking, provide your answer in TOML with answer and confidence fields EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + 'Example:\nanswer = "your response here"\nconfidence = 0.9' + ), + # XML/HTML tag variations (XML now uses answer tags) + AnswerFormat.XML: ( + "CRITICAL: After your thinking, provide your answer enclosed in tags EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: your response here" + ), + AnswerFormat.XML_FINAL_ANSWER: ( + "CRITICAL: After your thinking, provide your answer enclosed in tags with 'Final Answer:' prefix EXACTLY ONCE. 
" # noqa: E501 + "Use this format only once in your entire response. " + "Example: Final Answer: your response here" + ), + AnswerFormat.OUTPUT_TAGS: ( + "CRITICAL: After your thinking, provide your answer enclosed in tags. " # noqa: E501 + "Example: your response here" + ), + AnswerFormat.RESULT_TAGS: ( + "CRITICAL: After your thinking, provide your answer enclosed in tags. " # noqa: E501 + "Example: your response here" + ), + AnswerFormat.RESPONSE_TAGS: ( + "CRITICAL: After your thinking, provide your answer enclosed in tags. " # noqa: E501 + "Example: your response here" + ), + AnswerFormat.FINAL_ANSWER_TAGS: ( + "CRITICAL: After your thinking, provide your answer enclosed in tags. " # noqa: E501 + "Example: your response here" + ), + AnswerFormat.SOLUTION_TAGS: ( + "CRITICAL: After your thinking, provide your answer enclosed in tags. " # noqa: E501 + "Example: your response here" + ), + AnswerFormat.CONCLUSION_TAGS: ( + "CRITICAL: After your thinking, provide your answer enclosed in tags. " # noqa: E501 + "Example: your response here" + ), + AnswerFormat.REPLY_TAGS: ( + "CRITICAL: After your thinking, provide your answer enclosed in tags. " # noqa: E501 + "Example: your response here" + ), # noqa: E501 + AnswerFormat.NESTED_RESPONSE_ANSWER: ( + "CRITICAL: After your thinking, provide explanation and answer in nested response tags EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: Brief explanation of your reasoning\nyour final answer here" # noqa: E501 + ), + AnswerFormat.NESTED_SOLUTION_ANSWER: ( + "CRITICAL: After your thinking, provide explanation and answer in nested solution tags EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. 
" # noqa: E501 + "Example: Brief explanation of the approach\nyour final answer here" # noqa: E501 + ), + AnswerFormat.NESTED_OUTPUT_RESULT: ( + "CRITICAL: After your thinking, provide explanation and result in nested output tags EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: Brief explanation of the process\nyour final result here" # noqa: E501 + ), + AnswerFormat.NESTED_ANALYSIS_CONCLUSION: ( + "CRITICAL: After your thinking, provide analysis and conclusion in nested tags EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: Brief analysis of the problem\nyour final conclusion here" # noqa: E501 + ), + AnswerFormat.NESTED_REASONING_ANSWER: ( + "CRITICAL: After your thinking, provide reasoning and answer in nested tags EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: Brief reasoning summary\nyour final answer here" # noqa: E501 + ), + # LaTeX formats (text-friendly) + AnswerFormat.LATEX_BOXED: ( + "CRITICAL: After your thinking, provide your answer using LaTeX boxed notation EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: \\boxed{your answer here}" + ), + AnswerFormat.LATEX_TEXTBF: ( + "CRITICAL: After your thinking, provide your answer using LaTeX bold text notation EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: \\textbf{your answer here}" + ), + AnswerFormat.LATEX_TEXTIT: ( + "CRITICAL: After your thinking, provide your answer using LaTeX italic text notation EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: \\textit{your answer here}" + ), + AnswerFormat.LATEX_UNDERLINE: ( + "CRITICAL: After your thinking, provide your answer using LaTeX underline notation EXACTLY ONCE. 
" # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: \\underline{your answer here}" + ), + # LaTeX formats (math-only - for mathematical content) + AnswerFormat.LATEX_BOXED_MATH: ( + "CRITICAL: After your thinking, provide your answer using LaTeX math boxed notation EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for mathematical expressions. " # noqa: E501 + "Example: $\\boxed{x = 42}$" + ), + AnswerFormat.LATEX_ALIGN: ( + "CRITICAL: After your thinking, provide your answer within LaTeX align blocks EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for mathematical expressions. " # noqa: E501 + "Example:\n\\begin{align}\nx &= 42 \\\\\ny &= x + 1\n\\end{align}" + ), + AnswerFormat.LATEX_EQUATION: ( + "CRITICAL: After your thinking, provide your answer within LaTeX equation blocks EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for mathematical expressions. " # noqa: E501 + "Example:\n\\begin{equation}\nx = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}\n\\end{equation}" # noqa: E501 + ), + AnswerFormat.LATEX_DISPLAYMATH: ( + "CRITICAL: After your thinking, provide your answer within LaTeX display math delimiters EXACTLY ONCE. " # noqa: E501 + "Example: \\[x = \\frac{a + b}{c}\\]" + ), + AnswerFormat.LATEX_INLINE_MATH: ( + "CRITICAL: After your thinking, provide your answer within LaTeX inline math delimiters EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for mathematical expressions. " # noqa: E501 + "Example: $x = 42$" + ), + AnswerFormat.LATEX_TEXT_MATH: ( + "CRITICAL: After your thinking, provide your answer using LaTeX text within math mode EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for text answers in mathematical context. 
" # noqa: E501 + "Example: $\\text{your answer here}$" + ), + AnswerFormat.LATEX_MATHRM: ( + "CRITICAL: After your thinking, provide your answer using LaTeX mathrm within math mode EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for roman text in mathematical context. " # noqa: E501 + "Example: $\\mathrm{your answer here}$" + ), + AnswerFormat.LATEX_THEREFORE: ( + "CRITICAL: After your thinking, provide your answer using LaTeX therefore symbol EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for mathematical conclusions. " # noqa: E501 + "Example: $\\therefore \\text{your answer here}$" + ), + AnswerFormat.LATEX_IMPLIES: ( + "CRITICAL: After your thinking, provide your answer using LaTeX implies symbol EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for mathematical implications. " # noqa: E501 + "Example: $\\implies \\text{your answer here}$" + ), + AnswerFormat.LATEX_EQUIV: ( + "CRITICAL: After your thinking, provide your answer using LaTeX equivalence symbol EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for mathematical equivalences. " # noqa: E501 + "Example: $\\text{answer} \\equiv \\text{your solution here}$" + ), + AnswerFormat.LATEX_MATRIX: ( + "CRITICAL: After your thinking, provide your answer within LaTeX matrix environment EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for matrix/array answers. " # noqa: E501 + "Example: $\\begin{matrix} \\text{your answer here} \\end{matrix}$" + ), + AnswerFormat.LATEX_PMATRIX: ( + "CRITICAL: After your thinking, provide your answer within LaTeX pmatrix environment EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for parenthesized matrix answers. 
" # noqa: E501 + "Example: $\\begin{pmatrix} \\text{your answer here} \\end{pmatrix}$" + ), + # Markdown formats + AnswerFormat.MARKDOWN_CODE: ( + "CRITICAL: After your thinking, provide your answer in a markdown code block without language specification. " # noqa: E501 + "Example:\n```\nyour answer here\n```" + ), + AnswerFormat.MARKDOWN_BOLD: ( + "CRITICAL: After your thinking, provide your answer in bold markdown formatting. " + "Example: **your answer here**" + ), + AnswerFormat.MARKDOWN_ITALIC: ( + "CRITICAL: After your thinking, provide your answer in italic markdown formatting. " + "Example: *your answer here*" + ), + AnswerFormat.MARKDOWN_HEADER: ( + "CRITICAL: After your thinking, provide your answer as a markdown header. " + "Example: ## your answer here" + ), + AnswerFormat.MARKDOWN_QUOTE: ( + "CRITICAL: After your thinking, provide your answer as a markdown quote. " + "Example: > your answer here" + ), + # Bracket and delimiter formats + AnswerFormat.SQUARE_BRACKETS: ( + "CRITICAL: After your thinking, provide your answer enclosed in square brackets. " + "Example: [your answer here]" + ), + AnswerFormat.DOUBLE_SQUARE_BRACKETS: ( + "CRITICAL: After your thinking, provide your answer enclosed in double square brackets. " + "Example: [[your answer here]]" + ), + AnswerFormat.CURLY_BRACES: ( + "CRITICAL: After your thinking, provide your answer enclosed in curly braces. " + "Example: {your answer here}" + ), + AnswerFormat.PARENTHESES: ( + "CRITICAL: After your thinking, provide your answer enclosed in parentheses. " + "Example: (your answer here)" + ), + AnswerFormat.ANGLE_BRACKETS: ( + "CRITICAL: After your thinking, provide your answer enclosed in angle brackets. " + "Example: " + ), + # Natural language patterns + AnswerFormat.NATURAL_LANGUAGE_ANSWER: ( + "CRITICAL: After your thinking, provide your answer using 'The answer is:' pattern. 
" + "Example: The answer is: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_FINAL: ( + "CRITICAL: After your thinking, provide your answer using 'Final answer:' pattern. " + "Example: Final answer: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_CONCLUSION: ( + "CRITICAL: After your thinking, provide your answer using 'In conclusion:' pattern. " + "Example: In conclusion: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_THEREFORE: ( + "CRITICAL: After your thinking, provide your answer using 'Therefore:' pattern. " + "Example: Therefore: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_RESULT: ( + "CRITICAL: After your thinking, provide your answer using 'The result is:' pattern. " + "Example: The result is: your response here" + ), + # Additional natural language patterns + AnswerFormat.NATURAL_LANGUAGE_BEST: ( + "CRITICAL: After your thinking, provide your answer using 'The best answer is:' pattern. " + "Example: The best answer is: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_MY_FINAL: ( + "CRITICAL: After your thinking, provide your answer using 'My final answer is:' pattern. " + "Example: My final answer is: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_CORRECT: ( + "CRITICAL: After your thinking, provide your answer using 'The correct answer is:' pattern. " + "Example: The correct answer is: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_SOLUTION: ( + "CRITICAL: After your thinking, provide your answer using 'The solution is:' pattern. " + "Example: The solution is: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_RESPONSE: ( + "CRITICAL: After your thinking, provide your answer using 'My response is:' pattern. " + "Example: My response is: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_ULTIMATELY: ( + "CRITICAL: After your thinking, provide your answer using 'Ultimately:' pattern. 
" + "Example: Ultimately: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_THUS: ( + "CRITICAL: After your thinking, provide your answer using 'Thus:' pattern. " + "Example: Thus: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_HENCE: ( + "CRITICAL: After your thinking, provide your answer using 'Hence:' pattern. " + "Example: Hence: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_CONSEQUENTLY: ( + "CRITICAL: After your thinking, provide your answer using 'Consequently:' pattern. " + "Example: Consequently: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_TO_SUMMARIZE: ( + "CRITICAL: After your thinking, provide your answer using 'To summarize:' pattern. " + "Example: To summarize: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_IN_SUMMARY: ( + "CRITICAL: After your thinking, provide your answer using 'In summary:' pattern. " + "Example: In summary: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_OVERALL: ( + "CRITICAL: After your thinking, provide your answer using 'Overall:' pattern. " + "Example: Overall: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_FINAL_VERDICT: ( + "CRITICAL: After your thinking, provide your answer using 'Final verdict:' pattern. " + "Example: Final verdict: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_BOTTOM_LINE: ( + "CRITICAL: After your thinking, provide your answer using 'Bottom line:' pattern. " + "Example: Bottom line: your response here" + ), + AnswerFormat.NATURAL_LANGUAGE_KEY_POINT: ( + "CRITICAL: After your thinking, provide your answer using 'The key point is:' pattern. " # noqa: E501 + "Example: The key point is: your response here" # noqa: E501 + ), + # Special formats + AnswerFormat.TEXTARENA_FORMAT: ( + "CRITICAL: After your thinking, provide your answer in TextArena format using square brackets with a letter. 
" # noqa: E501 + "Example: [A] or [B] or [C] or [D]" + ), + AnswerFormat.COLON_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with 'Answer:' prefix. " + "Example: Answer: your response here" + ), + AnswerFormat.ARROW_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with arrow notation. " + "Example: => your answer here" + ), + # HTML formats + AnswerFormat.HTML_CODE: ( + "CRITICAL: After your thinking, provide your answer within HTML code tags. " + "Example: your answer here" + ), + AnswerFormat.HTML_PRE: ( + "CRITICAL: After your thinking, provide your answer within HTML pre tags. " + "Example:
your answer here
" + ), + AnswerFormat.HTML_SPAN: ( + "CRITICAL: After your thinking, provide your answer within HTML span tags. " + "Example: your answer here" + ), + AnswerFormat.HTML_DIV: ( + "CRITICAL: After your thinking, provide your answer within HTML div tags. " + "Example:
your answer here
" + ), + AnswerFormat.HTML_P: ( + "CRITICAL: After your thinking, provide your answer within HTML p tags. " + "Example:

your answer here

" + ), + # Multiple structured tags + AnswerFormat.MULTIPLE_TAGS: ( + "CRITICAL: After your thinking, provide your response using multiple structured tags. " + "Example:\nyour reasoning\nyour final answer\nadditional context" # noqa: E501 + ), + # Complex multi-tag formats for specific domains + AnswerFormat.COMPLEX_CODING_FORMAT: ( + "CRITICAL: After your thinking, provide your response in the following structured format for coding problems. " # noqa: E501 + "Note: The and other sections should summarize your previous thinking, not repeat the detailed thought process. " # noqa: E501 + "Format:\n" # noqa: E501 + "Restate the problem clearly\n" # noqa: E501 + "\nFirst key insight\nSecond key insight\n\n" # noqa: E501 + "\nFirst step\nSecond step\n\n" # noqa: E501 + "\nFirst schema\nSecond schema\n\n" # noqa: E501 + "UML workflow diagram in natural language\n" # noqa: E501 + "Internal critique and validation\n" # noqa: E501 + "Your code solution\n" # noqa: E501 + "Explanation of the code\n" # noqa: E501 + "Unit test code" # noqa: E501 + ), + AnswerFormat.COMPLEX_CODING_SIMPLE: ( + "CRITICAL: After your thinking, provide your response in the following structured format for coding problems. " # noqa: E501 + "Note: The sections should summarize your previous thinking, not repeat detailed thought processes. " # noqa: E501 + "Format:\n" # noqa: E501 + "Restate the problem\n" # noqa: E501 + "\nKey insight 1\nKey insight 2\n\n" # noqa: E501 + "\nStep 1\nStep 2\n\n" # noqa: E501 + "Your code solution\n" # noqa: E501 + "Code explanation" # noqa: E501 + ), + AnswerFormat.COMPLEX_CODING_MINIMAL: ( + "CRITICAL: After your thinking, provide your response in the following structured format for coding problems. " # noqa: E501 + "Note: The sections should summarize your previous thinking. 
" # noqa: E501 + "Format:\n" # noqa: E501 + "Problem analysis summary\n" # noqa: E501 + "Solution approach\n" # noqa: E501 + "Your code solution\n" # noqa: E501 + "Test cases or validation" # noqa: E501 + ), + AnswerFormat.COMPLEX_MATH_FORMAT: ( + "CRITICAL: After your thinking, provide your response in the following structured format for math problems. " # noqa: E501 + "Note: The sections should summarize your previous thinking, not repeat detailed calculations. " # noqa: E501 + "Format:\n" # noqa: E501 + "Restate the mathematical problem\n" # noqa: E501 + "\nFirst mathematical insight\nSecond insight\n\n" # noqa: E501 + "Mathematical approach and strategy\n" # noqa: E501 + "Step-by-step mathematical derivation\n" # noqa: E501 + "Final mathematical solution\n" # noqa: E501 + "Verification of the solution" # noqa: E501 + ), + AnswerFormat.COMPLEX_MATH_SIMPLE: ( + "CRITICAL: After your thinking, provide your response in the following structured format for math problems. " # noqa: E501 + "Note: The sections should summarize your previous thinking. " # noqa: E501 + "Format:\n" # noqa: E501 + "Analysis of the mathematical problem\n" # noqa: E501 + "Step-by-step solution process\n" # noqa: E501 + "The final mathematical answer" # noqa: E501 + ), + AnswerFormat.COMPLEX_GENERAL_FORMAT: ( + "CRITICAL: After your thinking, provide your response in the following structured format for general problems. " # noqa: E501 + "Note: The sections should summarize your previous thinking, not repeat detailed analysis. " # noqa: E501 + "Format:\n" # noqa: E501 + "Restate the problem or question\n" # noqa: E501 + "\nFirst key aspect\nSecond key aspect\n\n" # noqa: E501 + "Logical reasoning process\n" # noqa: E501 + "Your conclusion or answer\n" # noqa: E501 + "Reflection on the solution quality" # noqa: E501 + ), + AnswerFormat.COMPLEX_GENERAL_SIMPLE: ( + "CRITICAL: After your thinking, provide your response in the following structured format. 
" # noqa: E501 + "Note: The sections should summarize your previous thinking. " # noqa: E501 + "Format:\n" # noqa: E501 + "Problem or question analysis\n" # noqa: E501 + "Key reasoning points\n" # noqa: E501 + "Your final conclusion or answer" # noqa: E501 + ), + AnswerFormat.COMPLEX_RESEARCH_FORMAT: ( + "CRITICAL: After your thinking, provide your response in the following research-style structured format. " # noqa: E501 + "Note: The sections should summarize your previous thinking and analysis. " # noqa: E501 + "Format:\n" # noqa: E501 + "Research hypothesis or question\n" # noqa: E501 + "Approach and methodology\n" # noqa: E501 + "Data analysis and findings\n" # noqa: E501 + "Key findings and results\n" # noqa: E501 + "Research conclusion" # noqa: E501 + ), + # Advanced scratchpad and classification formats + AnswerFormat.COMPLEX_SCRATCHPAD_FULL: ( + "CRITICAL: After your thinking, provide your response in the following comprehensive scratchpad format. " # noqa: E501 + "Note: The sections should summarize your previous thinking, not repeat detailed analysis. " # noqa: E501 + "Format:\n" # noqa: E501 + "\n" # noqa: E501 + "Restate the problem clearly\n" # noqa: E501 + "\nFirst insight\nSecond insight\n\n" # noqa: E501 + "\nFirst step\nCritique of step 1\nSecond step\nCritique of step 2\n\n" # noqa: E501 + "\nFirst citation with bibtext in markdown codeblock\n\n" # noqa: E501 + "\nFirst schema\n\n" # noqa: E501 + "UML workflow diagram in natural language\n" # noqa: E501 + "Harsh internal critique of the plan\n" # noqa: E501 + "\nRevised step 1\nRevised step 2\n\n" # noqa: E501 + "\n" # noqa: E501 + "Detailed solution\n" # noqa: E501 + "Solution explanation\n" # noqa: E501 + "Unit test code (if applicable)" # noqa: E501 + ), + AnswerFormat.COMPLEX_SCRATCHPAD_SIMPLE: ( + "CRITICAL: After your thinking, provide your response in the following simple scratchpad format. " # noqa: E501 + "Note: The sections should summarize your previous thinking. 
" # noqa: E501 + "Format:\n" # noqa: E501 + "\n" # noqa: E501 + "Restate the problem\n" # noqa: E501 + "\nKey insight 1\nKey insight 2\n\n" # noqa: E501 + "\nStep 1\nStep 2\n\n" # noqa: E501 + "Mermaid workflow diagram in natural language\n" # noqa: E501 + "\n" # noqa: E501 + "Then provide your solution following the plan step by step." # noqa: E501 + ), + AnswerFormat.COMPLEX_CLASSIFICATION_FORMAT: ( + "CRITICAL: After your thinking, provide your response in the following classification format. " # noqa: E501 + "Analyze the given text and provide binary classifications with explanations. " # noqa: E501 + "Format:\n" # noqa: E501 + "Explanation for completeness assessment\n" # noqa: E501 + "TRUE or FALSE\n" # noqa: E501 + "Explanation for grammar assessment\n" # noqa: E501 + "TRUE or FALSE\n" # noqa: E501 + "Explanation for scientific nature assessment\n" # noqa: E501 + "TRUE or FALSE\n" # noqa: E501 + "Explanation for programming nature assessment\n" # noqa: E501 + "TRUE or FALSE" # noqa: E501 + ), + AnswerFormat.COMPLEX_ANALYSIS_WITH_ANSWER: ( + "CRITICAL: After your thinking, provide your response in the following analysis format that ends with a specific answer. " # noqa: E501 + "Note: The sections should summarize your previous thinking. " # noqa: E501 + "Format:\n" # noqa: E501 + "Comprehensive analysis of the problem\n" # noqa: E501 + "Approach and methods used\n" # noqa: E501 + "Key findings and insights\n" # noqa: E501 + "Your final answer in the format: ANSWER IS: [your answer here]" # noqa: E501 + ), + AnswerFormat.COMPLEX_EVALUATION_FORMAT: ( + "CRITICAL: After your thinking, provide your response in the following evaluation format. " # noqa: E501 + "Note: The sections should summarize your previous thinking and analysis. 
" # noqa: E501 + "Format:\n" # noqa: E501 + "Evaluation criteria and standards\n" # noqa: E501 + "\nAssessment of first criterion\nScore for criterion 1 (1-10)\nAssessment of second criterion\nScore for criterion 2 (1-10)\n\n" # noqa: E501 + "Overall score (1-10)\n" # noqa: E501 + "Final judgment and recommendation" # noqa: E501 + ), + # Custom delimiters + AnswerFormat.PIPE_DELIMITED: ( + "CRITICAL: After your thinking, provide your answer enclosed in pipe delimiters. " # noqa: E501 + "Example: |your answer here|" # noqa: E501 + ), + AnswerFormat.HASH_DELIMITED: ( + "CRITICAL: After your thinking, provide your answer enclosed in hash delimiters. " # noqa: E501 + "Example: #your answer here#" # noqa: E501 + ), + AnswerFormat.UNDERSCORE_DELIMITED: ( + "CRITICAL: After your thinking, provide your answer enclosed in underscore delimiters. " # noqa: E501 + "Example: _your answer here_" # noqa: E501 + ), + AnswerFormat.TILDE_DELIMITED: ( + "CRITICAL: After your thinking, provide your answer enclosed in tilde delimiters. " # noqa: E501 + "Example: ~your answer here~" # noqa: E501 + ), + # Programming-style formats + AnswerFormat.FUNCTION_CALL: ( + "CRITICAL: After your thinking, provide your answer as a function call with quoted string EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + 'Example: answer("your response here")' # noqa: E501 + ), + AnswerFormat.VARIABLE_ASSIGNMENT: ( + "CRITICAL: After your thinking, provide your answer as a variable assignment with quoted string EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + 'Example: answer = "your response here"' # noqa: E501 + ), + AnswerFormat.RETURN_STATEMENT: ( + "CRITICAL: After your thinking, provide your answer as a return statement with quoted string EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. 
" # noqa: E501 + 'Example: return "your response here"' # noqa: E501 + ), + # Additional easy-to-parse formats + AnswerFormat.EQUALS_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with an equals sign prefix. " + "Example: = your response here" + ), + AnswerFormat.DASH_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with a dash prefix. " + "Example: - your response here" + ), + AnswerFormat.PLUS_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with a plus sign prefix. " + "Example: + your response here" + ), + AnswerFormat.STAR_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with a star prefix. " + "Example: * your response here" + ), + AnswerFormat.PERCENT_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with a percent sign prefix. " + "Example: % your response here" + ), + AnswerFormat.AMPERSAND_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with an ampersand prefix. " + "Example: & your response here" + ), + AnswerFormat.AT_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with an at sign prefix. " + "Example: @ your response here" + ), + AnswerFormat.EXCLAMATION_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with an exclamation mark prefix. " + "Example: ! your response here" + ), + AnswerFormat.QUESTION_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with a question mark prefix. " + "Example: ? your response here" + ), + AnswerFormat.SEMICOLON_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with a semicolon prefix. " + "Example: ; your response here" + ), + AnswerFormat.DOUBLE_COLON_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with double colon prefix. " + "Example: :: your response here" + ), + AnswerFormat.TRIPLE_DASH_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with triple dash prefix. 
" + "Example: --- your response here" + ), + AnswerFormat.DOUBLE_ARROW_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with double arrow prefix. " + "Example: >> your response here" + ), + AnswerFormat.TRIPLE_ARROW_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with triple arrow prefix. " + "Example: >>> your response here" + ), + AnswerFormat.BACKTICK_FORMAT: ( + "CRITICAL: After your thinking, provide your answer enclosed in single backticks. " + "Example: `your response here`" + ), + AnswerFormat.DOUBLE_BACKTICK_FORMAT: ( + "CRITICAL: After your thinking, provide your answer enclosed in double backticks. " + "Example: ``your response here``" + ), + AnswerFormat.QUOTE_FORMAT: ( + "CRITICAL: After your thinking, provide your answer enclosed in double quotes. " + 'Example: "your response here"' + ), + AnswerFormat.SINGLE_QUOTE_FORMAT: ( + "CRITICAL: After your thinking, provide your answer enclosed in single quotes. " + "Example: 'your response here'" + ), + AnswerFormat.TRIPLE_QUOTE_FORMAT: ( + "CRITICAL: After your thinking, provide your answer enclosed in triple quotes. " + 'Example: """your response here"""' + ), + AnswerFormat.ANSWER_IS_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with 'ANSWER IS:' prefix in all caps. " # noqa: E501 + "Example: ANSWER IS: your response here" # noqa: E501 + ), + AnswerFormat.SOLUTION_IS_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with 'SOLUTION IS:' prefix in all caps. " # noqa: E501 + "Example: SOLUTION IS: your response here" # noqa: E501 + ), + AnswerFormat.RESULT_IS_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with 'RESULT IS:' prefix in all caps. " # noqa: E501 + "Example: RESULT IS: your response here" # noqa: E501 + ), + AnswerFormat.OUTPUT_IS_FORMAT: ( + "CRITICAL: After your thinking, provide your answer with 'OUTPUT IS:' prefix in all caps. 
" # noqa: E501 + "Example: OUTPUT IS: your response here" # noqa: E501 + ), + # Code-specific formats (for code datasets) + AnswerFormat.PYTHON_PRINT: ( + "CRITICAL: After your thinking, provide your answer as a Python print statement EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + 'Example: print("your answer here")' # noqa: E501 + ), + AnswerFormat.JAVASCRIPT_CONSOLE: ( + "CRITICAL: After your thinking, provide your answer as a JavaScript console.log statement EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + 'Example: console.log("your answer here")' # noqa: E501 + ), + AnswerFormat.PYTHON_COMMENT: ( + "CRITICAL: After your thinking, provide your answer as a Python comment EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: # your answer here" # noqa: E501 + ), + AnswerFormat.JAVASCRIPT_COMMENT: ( + "CRITICAL: After your thinking, provide your answer as a JavaScript comment EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: // your answer here" # noqa: E501 + ), + AnswerFormat.C_COMMENT: ( + "CRITICAL: After your thinking, provide your answer as a C-style block comment EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example: /* your answer here */" # noqa: E501 + ), + AnswerFormat.SHELL_ECHO: ( + "CRITICAL: After your thinking, provide your answer as a shell echo command EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + 'Example: echo "your answer here"' # noqa: E501 + ), + AnswerFormat.SHELL_OUTPUT: ( + "CRITICAL: After your thinking, provide your answer as shell command output EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. This format is for code-related content. 
" # noqa: E501 + "Example: $ your answer here" # noqa: E501 + ), + AnswerFormat.PYTHON_DOCSTRING: ( + "CRITICAL: After your thinking, provide your answer as a Python docstring EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + 'Example: """your answer here"""' # noqa: E501 + ), + AnswerFormat.INI_FORMAT: ( + "CRITICAL: After your thinking, provide your answer in INI configuration format EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + "Example:\n[section]\nanswer = your response here" # noqa: E501 + ), + AnswerFormat.ENV_FORMAT: ( + "CRITICAL: After your thinking, provide your answer as an environment variable EXACTLY ONCE. " # noqa: E501 + "Use this format only once in your entire response. " # noqa: E501 + 'Example: ANSWER="your response here"' # noqa: E501 + ), + } + + # Dynamic formats are commented out - all formats use static instructions + instruction = format_instructions.get( + answer_format, "Provide your answer after thinking." + ) + + return f"{base_prompt}\n\n{instruction}\n\nREMEMBER: Use the specified format EXACTLY ONCE and ONLY ONCE. Do not repeat the format or use it multiple times anywhere in your response." # noqa: E501 + + def _generate_dynamic_format_instruction( + self, answer_format: AnswerFormat + ) -> Tuple[str, List[str]]: + """Generate dynamic format instructions by randomly selecting components. 
def _generate_dynamic_format_instruction(
    self, answer_format: AnswerFormat
) -> Tuple[str, List[str]]:
    """Generate dynamic format instructions by randomly selecting components.

    Returns:
        Tuple of (instruction_text, selected_components)
    """
    # Component pools, grouped by the kind of content they describe.
    analysis = [
        "restatement",
        "problem_analysis",
        "context_analysis",
        "requirements_analysis",
    ]
    reasoning = [
        "reasoning",
        "insights",
        "key_points",
        "considerations",
        "assumptions",
    ]
    planning = ["plan", "approach", "methodology", "strategy", "steps"]
    technical = [
        "schemas",
        "diagrams",
        "models",
        "architecture",
        "specifications",
    ]
    execution = [
        "implementation",
        "solution",
        "code",
        "execution",
        "results",
    ]
    reflection = [
        "reflection",
        "evaluation",
        "verification",
        "testing",
        "validation",
    ]

    fmt_name = answer_format.value

    # Pick the candidate pool from the format family. The concatenation
    # order matters: random.sample draws positionally from this list.
    if "scratchpad" in fmt_name:
        # Scratchpad formats get a mix of analysis, reasoning, and planning
        pool = analysis + reasoning + planning
    elif "analysis" in fmt_name:
        # Analysis formats focus on analysis and reasoning
        pool = analysis + reasoning + reflection
    elif "workflow" in fmt_name:
        # Workflow formats get planning, execution, and reflection
        pool = planning + execution + reflection
    else:
        # Default: mix of all
        pool = analysis + reasoning + planning + technical

    # Draw 3-5 components, capped at the pool size. randint is evaluated
    # before sample, preserving the original RNG call order.
    selected_components = random.sample(
        pool, min(random.randint(3, 5), len(pool))
    )

    # Route to the structured-format generator matching the family name;
    # the token order (xml, json, yaml, toml) preserves original priority.
    builders = (
        ("xml", self._generate_xml_format_instruction),
        ("json", self._generate_json_format_instruction),
        ("yaml", self._generate_yaml_format_instruction),
        ("toml", self._generate_toml_format_instruction),
    )
    for token, builder in builders:
        if token in fmt_name:
            instruction = builder(selected_components)
            break
    else:
        instruction = f"Please structure your response using the following components: {', '.join(selected_components)}"  # noqa

    # Add component info to instruction for debugging
    component_list = ", ".join(selected_components)

    if self.debug_logging:
        self.logger.debug(
            f"Generated dynamic format {fmt_name} with components: {component_list}"  # noqa
        )

    return (
        f"{instruction}\n\nSelected components: {component_list}",
        selected_components,
    )
self._generate_json_format_instruction(selected_components) + elif "yaml" in answer_format.value: + instruction = self._generate_yaml_format_instruction(selected_components) # noqa + elif "toml" in answer_format.value: + instruction = self._generate_toml_format_instruction(selected_components) # noqa + else: + instruction = f"Please structure your response using the following components: {', '.join(selected_components)}" # noqa + + # Add component info to instruction for debugging + component_list = ", ".join(selected_components) + + if self.debug_logging: + self.logger.debug( + f"Generated dynamic format {answer_format.value} with components: {component_list}" # noqa + ) + + return ( + f"{instruction}\n\nSelected components: {component_list}", + selected_components, + ) + + def _generate_xml_format_instruction(self, components: List[str]) -> str: + """Generate XML format instruction with selected components.""" + component_examples = [] + for comp in components: + tag_name = comp.upper() + if comp in ["reasoning", "insights"]: + component_examples.append( + f"<{tag_name}>\nFirst key point\nSecond key point\n" # noqa + ) + elif comp in ["plan", "steps"]: + component_examples.append( + f"<{tag_name}>\nFirst step\nSecond step\n" # noqa + ) + elif comp in ["schemas", "models"]: + component_examples.append( + f"<{tag_name}>\nFirst schema\nSecond schema\n" # noqa + ) + else: + component_examples.append( + f"<{tag_name}>Content for {comp}" + ) + + format_example = "\n".join(component_examples) + return ( + f"CRITICAL: After your thinking, provide your response using the following XML format with these specific components. " # noqa + f"Note: The sections should summarize your previous thinking. 
" # noqa + f"Format:\n{format_example}" + ) + + def _generate_json_format_instruction(self, components: List[str]) -> str: + """Generate JSON format instruction with selected components.""" + json_fields = [] + for comp in components: + if comp in ["reasoning", "insights"]: + json_fields.append( + f'"{comp}": {{"point_1": "First key point", "point_2": "Second key point"}}' # noqa + ) + elif comp in ["plan", "steps"]: + json_fields.append( + f'"{comp}": {{"step_1": "First step", "step_2": "Second step"}}' # noqa + ) + elif comp in ["schemas", "models"]: + json_fields.append( + f'"{comp}": {{"schema_1": "First schema", "schema_2": "Second schema"}}' # noqa + ) + else: + json_fields.append(f'"{comp}": "Content for {comp}"') + + format_example = "{\n " + ",\n ".join(json_fields) + "\n}" + return ( + f"CRITICAL: After your thinking, provide your response as a JSON object with these specific fields. " # noqa + f"Note: The sections should summarize your previous thinking. " + f"Format:\n{format_example}" + ) + + def _generate_yaml_format_instruction(self, components: List[str]) -> str: + """Generate YAML format instruction with selected components.""" + yaml_fields = [] + for comp in components: + if comp in ["reasoning", "insights"]: + yaml_fields.append( + f"{comp}:\n point_1: First key point\n point_2: Second key point" # noqa + ) + elif comp in ["plan", "steps"]: + yaml_fields.append( + f"{comp}:\n step_1: First step\n step_2: Second step" # noqa + ) + elif comp in ["schemas", "models"]: + yaml_fields.append( + f"{comp}:\n schema_1: First schema\n schema_2: Second schema" # noqa + ) + else: + yaml_fields.append(f"{comp}: Content for {comp}") + + format_example = "\n".join(yaml_fields) + return ( + f"CRITICAL: After your thinking, provide your response in YAML format with these specific fields. " # noqa + f"Note: The sections should summarize your previous thinking. 
" # noqa + f"Format:\n{format_example}" + ) + + def _generate_toml_format_instruction(self, components: List[str]) -> str: + """Generate TOML format instruction with selected components.""" + toml_fields = [] + for comp in components: + if comp in ["reasoning", "insights"]: + toml_fields.append( + f'[{comp}]\npoint_1 = "First key point"\npoint_2 = "Second key point"' # noqa + ) + elif comp in ["plan", "steps"]: + toml_fields.append( + f'[{comp}]\nstep_1 = "First step"\nstep_2 = "Second step"' # noqa + ) + elif comp in ["schemas", "models"]: + toml_fields.append( + f'[{comp}]\nschema_1 = "First schema"\nschema_2 = "Second schema"' # noqa + ) + else: + toml_fields.append(f'{comp} = "Content for {comp}"') + + format_example = "\n\n".join(toml_fields) + return ( + f"CRITICAL: After your thinking, provide your response in TOML format with these specific fields. " # noqa + f"Note: The sections should summarize your previous thinking. " # noqa + f"Format:\n{format_example}" + ) + + def _get_formats_for_dataset_type(self, dataset_type: str) -> List[AnswerFormat]: + """Get appropriate formats for a given dataset type.""" + if dataset_type == "math_only": + # Math datasets can use generic + math-only formats + available_formats = list(self.generic_formats | self.math_only_formats) + elif dataset_type == "code_only": + # Code datasets can use generic + code-only formats + available_formats = list(self.generic_formats | self.code_only_formats) + else: # "generic" or any other type + # Generic datasets use only generic formats + available_formats = list(self.generic_formats) + + # Filter by base supported formats + filtered_formats = [ + f for f in available_formats if f in self.base_supported_formats + ] + + # Apply equivalent ratio filter if enabled + final_formats = self._get_balanced_formats(filtered_formats) + + if self.debug_logging: + self.logger.debug( + f"Dataset type '{dataset_type}': {len(filtered_formats)} available formats, " + f"{len(final_formats)} after 
equivalent ratio filtering" + ) + + return final_formats + + def _get_dynamic_format_patterns( + self, answer_format: AnswerFormat, stored_components: List[str] + ) -> List[str]: + """Generate specific validation patterns for dynamic formats based on selected components.""" # noqa + if not stored_components: + return [] + + patterns = [] + + if "xml" in answer_format.value: + # Generate specific XML tag patterns for each component + for component in stored_components: + tag_name = component.upper() + patterns.append(f"<{tag_name}>.*?") + elif "json" in answer_format.value: + # For JSON, we'll validate that it's a valid JSON object with the expected fields + patterns.append(r"\{.*?\}") # Basic JSON object pattern + elif "yaml" in answer_format.value: + # Generate specific YAML field patterns for each component + for component in stored_components: + patterns.append(f"^{component}\\s*:.*$") + elif "toml" in answer_format.value: + # Generate specific TOML field patterns for each component + for component in stored_components: + patterns.append(f"^{component}\\s*=.*$") + + return patterns + + def _validate_format_appears_exactly_once( + self, text: str, answer_format: AnswerFormat + ) -> bool: + """Validate that the specified format appears exactly once in the response section.""" + # Define patterns for each format to count occurrences + format_patterns = { + # Basic structured data formats (answer only) + AnswerFormat.JSON: [r'\{\s*"answer"\s*:\s*"[^"]*"\s*\}'], + AnswerFormat.JSON_ARRAY: [r'\[\s*"[^"]*"\s*\]'], + AnswerFormat.JSON_SIMPLE: [r'"[^"]*"'], + AnswerFormat.YAML: [r"^answer\s*:\s*.+$"], + AnswerFormat.YAML_LIST: [r"^-\s*.+$"], + AnswerFormat.TOML: [r"^answer\s*=\s*.+$"], + AnswerFormat.TOML_SECTION: [r"^\[response\]", r"^answer\s*=\s*.+$"], + # Structured data with confidence scores + AnswerFormat.JSON_CONFIDENCE: [ + r'\{\s*"answer"\s*:\s*"[^"]*"\s*,\s*"confidence"\s*:\s*[0-9.]+\s*\}' + ], + AnswerFormat.YAML_CONFIDENCE: [ + r"^answer\s*:\s*.+$", + 
r"^confidence\s*:\s*[0-9.]+$", + ], + AnswerFormat.TOML_CONFIDENCE: [ + r"^answer\s*=\s*.+$", + r"^confidence\s*=\s*[0-9.]+$", + ], + # XML/HTML tag variations (XML now uses answer tags) + AnswerFormat.XML: [r".*?"], + AnswerFormat.XML_FINAL_ANSWER: [r"Final Answer:.*?"], + AnswerFormat.OUTPUT_TAGS: [r".*?"], + AnswerFormat.RESULT_TAGS: [r".*?"], + AnswerFormat.RESPONSE_TAGS: [r".*?"], + AnswerFormat.FINAL_ANSWER_TAGS: [r".*?"], + AnswerFormat.SOLUTION_TAGS: [r".*?"], + AnswerFormat.CONCLUSION_TAGS: [r".*?"], + AnswerFormat.REPLY_TAGS: [r".*?"], + AnswerFormat.NESTED_RESPONSE_ANSWER: [ + r".*?.*?" + ], + AnswerFormat.NESTED_SOLUTION_ANSWER: [ + r".*?.*?" + ], + AnswerFormat.NESTED_OUTPUT_RESULT: [ + r".*?.*?" + ], + AnswerFormat.NESTED_ANALYSIS_CONCLUSION: [ + r".*?.*?" + ], + AnswerFormat.NESTED_REASONING_ANSWER: [ + r".*?.*?" + ], + # LaTeX formats (text-friendly) + AnswerFormat.LATEX_BOXED: [r"\\boxed\{[^}]+\}"], + AnswerFormat.LATEX_TEXTBF: [r"\\textbf\{[^}]+\}"], + AnswerFormat.LATEX_TEXTIT: [r"\\textit\{[^}]+\}"], + AnswerFormat.LATEX_UNDERLINE: [r"\\underline\{[^}]+\}"], + # LaTeX formats (math-only) + AnswerFormat.LATEX_BOXED_MATH: [r"\$\\boxed\{[^}]+\}\$"], + AnswerFormat.LATEX_ALIGN: [r"\\begin\{align\}.*?\\end\{align\}"], + AnswerFormat.LATEX_EQUATION: [r"\\begin\{equation\}.*?\\end\{equation\}"], + AnswerFormat.LATEX_DISPLAYMATH: [r"\\\\?\[.*?\\\\?\]"], + AnswerFormat.LATEX_INLINE_MATH: [r"\$[^$]+\$"], + AnswerFormat.LATEX_TEXT_MATH: [r"\$\\text\{[^}]+\}\$"], + AnswerFormat.LATEX_MATHRM: [r"\$\\mathrm\{[^}]+\}\$"], + AnswerFormat.LATEX_THEREFORE: [r"\$\\therefore[^$]+\$"], + AnswerFormat.LATEX_IMPLIES: [r"\$\\implies[^$]+\$"], + AnswerFormat.LATEX_EQUIV: [r"\$[^$]*\\equiv[^$]*\$"], + AnswerFormat.LATEX_MATRIX: [r"\$\\begin\{matrix\}.*?\\end\{matrix\}\$"], + AnswerFormat.LATEX_PMATRIX: [r"\$\\begin\{pmatrix\}.*?\\end\{pmatrix\}\$"], + # Markdown formats + AnswerFormat.MARKDOWN_CODE: [r"```\s*\n.*?\n```"], + AnswerFormat.MARKDOWN_BOLD: [r"\*\*[^*]+\*\*"], 
+ AnswerFormat.MARKDOWN_ITALIC: [r"\*[^*]+\*"], + AnswerFormat.MARKDOWN_HEADER: [r"^##\s*.+?(?:\n|$)"], + AnswerFormat.MARKDOWN_QUOTE: [r"^>\s*.+?(?:\n|$)"], + # Bracket and delimiter formats + AnswerFormat.SQUARE_BRACKETS: [r"\[[^\]]+\]"], + AnswerFormat.DOUBLE_SQUARE_BRACKETS: [r"\[\[[^\]]+\]\]"], + AnswerFormat.CURLY_BRACES: [r"\{[^}]+\}"], + AnswerFormat.PARENTHESES: [r"\([^)]+\)"], + AnswerFormat.ANGLE_BRACKETS: [r"<[^>]+>"], + # Natural language patterns + AnswerFormat.NATURAL_LANGUAGE_ANSWER: [r"The answer is:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_FINAL: [r"Final answer:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_CONCLUSION: [ + r"In conclusion:?\s*.+?(?:\n|$)" + ], + AnswerFormat.NATURAL_LANGUAGE_THEREFORE: [r"Therefore:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_RESULT: [r"The result is:?\s*.+?(?:\n|$)"], + # Additional natural language patterns + AnswerFormat.NATURAL_LANGUAGE_BEST: [r"The best answer is:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_MY_FINAL: [ + r"My final answer is:?\s*.+?(?:\n|$)" + ], + AnswerFormat.NATURAL_LANGUAGE_CORRECT: [ + r"The correct answer is:?\s*.+?(?:\n|$)" + ], + AnswerFormat.NATURAL_LANGUAGE_SOLUTION: [ + r"The solution is:?\s*.+?(?:\n|$)" + ], + AnswerFormat.NATURAL_LANGUAGE_RESPONSE: [r"My response is:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_ULTIMATELY: [r"Ultimately:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_THUS: [r"Thus:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_HENCE: [r"Hence:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_CONSEQUENTLY: [ + r"Consequently:?\s*.+?(?:\n|$)" + ], + AnswerFormat.NATURAL_LANGUAGE_TO_SUMMARIZE: [ + r"To summarize:?\s*.+?(?:\n|$)" + ], + AnswerFormat.NATURAL_LANGUAGE_IN_SUMMARY: [r"In summary:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_OVERALL: [r"Overall:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_FINAL_VERDICT: [ + r"Final verdict:?\s*.+?(?:\n|$)" + ], + AnswerFormat.NATURAL_LANGUAGE_BOTTOM_LINE: [r"Bottom 
line:?\s*.+?(?:\n|$)"], + AnswerFormat.NATURAL_LANGUAGE_KEY_POINT: [ + r"The key point is:?\s*.+?(?:\n|$)" + ], + # Special formats + AnswerFormat.TEXTARENA_FORMAT: [r"\[[A-Da-d]\]"], + AnswerFormat.COLON_FORMAT: [r"Answer:?\s*.+?(?:\n|$)"], + AnswerFormat.ARROW_FORMAT: [ + r"=>\s*.+?(?:\n|$)", + r"->\s*.+?(?:\n|$)", + r"→\s*.+?(?:\n|$)", + ], + # HTML formats + AnswerFormat.HTML_CODE: [r".*?"], + AnswerFormat.HTML_PRE: [r"
.*?
"], + AnswerFormat.HTML_SPAN: [r".*?"], + AnswerFormat.HTML_DIV: [r"
.*?
"], + AnswerFormat.HTML_P: [r"

.*?

"], + # Multiple structured tags - special case + AnswerFormat.MULTIPLE_TAGS: [ + r".*?", + r".*?", + r".*?", + ], + # Complex multi-tag formats + AnswerFormat.COMPLEX_CODING_FORMAT: [ + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_CODING_SIMPLE: [ + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_CODING_MINIMAL: [ + r".*?", + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_MATH_FORMAT: [ + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_MATH_SIMPLE: [ + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_GENERAL_FORMAT: [ + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_GENERAL_SIMPLE: [ + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_RESEARCH_FORMAT: [ + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + ], + # Advanced scratchpad and classification formats + AnswerFormat.COMPLEX_SCRATCHPAD_FULL: [ + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_SCRATCHPAD_SIMPLE: [ + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_CLASSIFICATION_FORMAT: [ + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_ANALYSIS_WITH_ANSWER: [ + r".*?", + r".*?", + r".*?", + r".*?", + ], + AnswerFormat.COMPLEX_EVALUATION_FORMAT: [ + r".*?", + r".*?", + r".*?", + r".*?", + ], + # Dynamic formats - COMMENTED OUT: These formats are proving too difficult to parse reliably + # AnswerFormat.DYNAMIC_SCRATCHPAD_XML: [r'<[A-Z_]+>.*?'], + # AnswerFormat.DYNAMIC_SCRATCHPAD_JSON: [r'\{.*?\}'], + # AnswerFormat.DYNAMIC_SCRATCHPAD_YAML: [r'^[a-z_]+:.*$'], + # AnswerFormat.DYNAMIC_SCRATCHPAD_TOML: [r'^[a-z_]+\s*=.*$', r'^\[[a-z_]+\]$'], + # AnswerFormat.DYNAMIC_ANALYSIS_XML: [r'<[A-Z_]+>.*?'], + # AnswerFormat.DYNAMIC_ANALYSIS_JSON: [r'\{.*?\}'], + # 
AnswerFormat.DYNAMIC_WORKFLOW_XML: [r'<[A-Z_]+>.*?'], + # AnswerFormat.DYNAMIC_WORKFLOW_JSON: [r'\{.*?\}'], + # Custom delimiters + AnswerFormat.PIPE_DELIMITED: [r"\|[^|]+\|"], + AnswerFormat.HASH_DELIMITED: [r"#[^#]+#"], + AnswerFormat.UNDERSCORE_DELIMITED: [r"_[^_]+_"], + AnswerFormat.TILDE_DELIMITED: [r"~[^~]+~"], + # Programming-style formats + AnswerFormat.FUNCTION_CALL: [r"answer\([^)]*\)"], + AnswerFormat.VARIABLE_ASSIGNMENT: [r"answer\s*=\s*[^;\n]+"], + AnswerFormat.RETURN_STATEMENT: [r"return\s+[^;\n]+"], + # Additional easy-to-parse formats + AnswerFormat.EQUALS_FORMAT: [r"^=\s*.+?(?:\n|$)"], + AnswerFormat.DASH_FORMAT: [r"^-\s*.+?(?:\n|$)"], + AnswerFormat.PLUS_FORMAT: [r"^\+\s*.+?(?:\n|$)"], + AnswerFormat.STAR_FORMAT: [r"^\*\s*.+?(?:\n|$)"], + AnswerFormat.PERCENT_FORMAT: [r"^%\s*.+?(?:\n|$)"], + AnswerFormat.AMPERSAND_FORMAT: [r"^&\s*.+?(?:\n|$)"], + AnswerFormat.AT_FORMAT: [r"^@\s*.+?(?:\n|$)"], + AnswerFormat.EXCLAMATION_FORMAT: [r"^!\s*.+?(?:\n|$)"], + AnswerFormat.QUESTION_FORMAT: [r"^\?\s*.+?(?:\n|$)"], + AnswerFormat.SEMICOLON_FORMAT: [r"^;\s*.+?(?:\n|$)"], + AnswerFormat.DOUBLE_COLON_FORMAT: [r"^::\s*.+?(?:\n|$)"], + AnswerFormat.TRIPLE_DASH_FORMAT: [r"^---\s*.+?(?:\n|$)"], + AnswerFormat.DOUBLE_ARROW_FORMAT: [r"^>>\s*.+?(?:\n|$)"], + AnswerFormat.TRIPLE_ARROW_FORMAT: [r"^>>>\s*.+?(?:\n|$)"], + AnswerFormat.BACKTICK_FORMAT: [r"`[^`]+`"], + AnswerFormat.DOUBLE_BACKTICK_FORMAT: [r"``[^`]+``"], + AnswerFormat.QUOTE_FORMAT: [r'"[^"]+"'], + AnswerFormat.SINGLE_QUOTE_FORMAT: [r"'[^']+'"], + AnswerFormat.TRIPLE_QUOTE_FORMAT: [r'"""[^"]+"""'], + AnswerFormat.ANSWER_IS_FORMAT: [r"ANSWER IS:?\s*.+?(?:\n|$)"], + AnswerFormat.SOLUTION_IS_FORMAT: [r"SOLUTION IS:?\s*.+?(?:\n|$)"], + AnswerFormat.RESULT_IS_FORMAT: [r"RESULT IS:?\s*.+?(?:\n|$)"], + AnswerFormat.OUTPUT_IS_FORMAT: [r"OUTPUT IS:?\s*.+?(?:\n|$)"], + # Code-specific formats + AnswerFormat.PYTHON_PRINT: [r'print\s*\(\s*"[^"]*"\s*\)'], + AnswerFormat.JAVASCRIPT_CONSOLE: 
[r'console\.log\s*\(\s*"[^"]*"\s*\)'], + AnswerFormat.PYTHON_COMMENT: [r"^#\s*.+?(?:\n|$)"], + AnswerFormat.JAVASCRIPT_COMMENT: [r"^//\s*.+?(?:\n|$)"], + AnswerFormat.C_COMMENT: [r"/\*.*?\*/"], + AnswerFormat.SHELL_ECHO: [r'echo\s+"[^"]*"'], + AnswerFormat.SHELL_OUTPUT: [r"^\$\s*.+?(?:\n|$)"], + AnswerFormat.PYTHON_DOCSTRING: [r'"""[^"]*"""'], + AnswerFormat.INI_FORMAT: [r"^\[.*?\]", r"^.*?\s*=\s*.+?(?:\n|$)"], + AnswerFormat.ENV_FORMAT: [r'^[A-Z_]+\s*=\s*"[^"]*"'], + } + + patterns = format_patterns.get(answer_format, []) + + # Dynamic formats are commented out - no special handling needed + + if not patterns: + return True # If no patterns defined, assume valid + + total_matches = 0 + for pattern in patterns: + matches = re.findall( + pattern, text, re.DOTALL | re.IGNORECASE | re.MULTILINE + ) + total_matches += len(matches) + + # Special cases for formats that require multiple elements + if answer_format == AnswerFormat.MULTIPLE_TAGS: + return 2 <= total_matches <= 5 + elif answer_format == AnswerFormat.COMPLEX_CODING_FORMAT: + return total_matches == 9 # All 9 required tags + elif answer_format == AnswerFormat.COMPLEX_CODING_SIMPLE: + return total_matches == 5 # All 5 required tags + elif answer_format == AnswerFormat.COMPLEX_CODING_MINIMAL: + return total_matches == 4 # All 4 required tags + elif answer_format == AnswerFormat.COMPLEX_MATH_FORMAT: + return total_matches == 6 # All 6 required tags + elif answer_format == AnswerFormat.COMPLEX_MATH_SIMPLE: + return total_matches == 3 # All 3 required tags + elif answer_format == AnswerFormat.COMPLEX_GENERAL_FORMAT: + return total_matches == 5 # All 5 required tags + elif answer_format == AnswerFormat.COMPLEX_GENERAL_SIMPLE: + return total_matches == 3 # All 3 required tags + elif answer_format == AnswerFormat.COMPLEX_RESEARCH_FORMAT: + return total_matches == 5 # All 5 required tags + elif answer_format == AnswerFormat.COMPLEX_SCRATCHPAD_FULL: + return total_matches == 11 # All 11 required tags + elif 
answer_format == AnswerFormat.COMPLEX_SCRATCHPAD_SIMPLE: + return total_matches == 5 # All 5 required tags + elif answer_format == AnswerFormat.COMPLEX_CLASSIFICATION_FORMAT: + return ( + total_matches == 8 + ) # All 8 required tags (4 explanations + 4 classifications) + elif answer_format == AnswerFormat.COMPLEX_ANALYSIS_WITH_ANSWER: + return total_matches == 4 # All 4 required tags + elif answer_format == AnswerFormat.COMPLEX_EVALUATION_FORMAT: + return total_matches == 4 # All 4 required tags + # Dynamic formats are commented out - no special handling needed + elif answer_format in [ + AnswerFormat.YAML_CONFIDENCE, + AnswerFormat.TOML_CONFIDENCE, + ]: + return total_matches == 2 # Need both answer and confidence fields + elif answer_format in [AnswerFormat.TOML_SECTION, AnswerFormat.INI_FORMAT]: + return total_matches == 2 # Need both section and key=value + elif answer_format in [ + AnswerFormat.NESTED_RESPONSE_ANSWER, + AnswerFormat.NESTED_SOLUTION_ANSWER, + AnswerFormat.NESTED_OUTPUT_RESULT, + AnswerFormat.NESTED_ANALYSIS_CONCLUSION, + AnswerFormat.NESTED_REASONING_ANSWER, + ]: + return ( + total_matches == 1 + ) # Each nested format should appear exactly once as a complete unit + + # For all other formats, we want exactly 1 occurrence + return total_matches == 1 + + def _extract_answer_content( + self, + text: str, + answer_format: AnswerFormat, + dataset_item: Dict[str, Any] = None, + ) -> Optional[str]: + """Extract answer content based on the specified format with strict thinking tag validation.""" + if self.debug_logging: + self.logger.debug( + f"Extracting {answer_format.value} from response (length: {len(text)})" + ) + + # Ensure text is a string + if not isinstance(text, str): + if self.debug_logging: + self.logger.warning( + f"Expected string but got {type(text)}, converting to string" + ) + text = str(text) + + # 1. 
Validate thinking tags + think_tags = re.findall(r"", text, re.IGNORECASE) + think_close_tags = re.findall(r"", text, re.IGNORECASE) + + if len(think_tags) != 1 or len(think_close_tags) != 1: + if self.debug_logging: + self.logger.warning( + f"Invalid thinking tag structure: {len(think_tags)} open tags, {len(think_close_tags)} close tags" + ) + return None + + # Split the text into thinking and response sections + match_think_block = re.match( + r"(.*?)(.*?)(.*)", text, re.DOTALL | re.IGNORECASE + ) + if not match_think_block: + if self.debug_logging: + self.logger.warning( + "Could not find a complete ... block" + ) + return None + + response_section = match_think_block.group(3) # Content after + + # Validate that there are no additional thinking tags in the response section + if "" in response_section.lower(): + if self.debug_logging: + self.logger.warning( + "Found tags in response section after " + ) + return None + + # 2. Validate that the format appears exactly once in the response section + if not self._validate_format_appears_exactly_once( + response_section, answer_format + ): + if self.debug_logging: + self.logger.warning( + f"Format {answer_format.value} does not appear exactly once in response section" + ) + return None + + # 3. 
Extract based on answer_format from response_section + extracted_content = None + + # Basic structured data formats (answer only) + if answer_format == AnswerFormat.JSON: + # Look for JSON object with "answer" field + match = re.search( + r'\{\s*"answer"\s*:\s*"([^"]*)"\s*\}', response_section, re.DOTALL + ) + if match: + try: + json.loads(match.group(0)) + extracted_content = match.group(1) # Extract just the answer value + except json.JSONDecodeError: + pass + + elif answer_format == AnswerFormat.JSON_ARRAY: + # Look for JSON array with one string element + match = re.search(r'\[\s*"([^"]*)"\s*\]', response_section, re.DOTALL) + if match: + try: + json.loads(match.group(0)) + extracted_content = match.group(1) # Extract the array element + except json.JSONDecodeError: + pass + + elif answer_format == AnswerFormat.JSON_SIMPLE: + # Look for simple JSON string + match = re.search(r'"([^"]*)"', response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.YAML: + # Look for YAML answer field + match = re.search( + r"^answer\s*:\s*(.+?)(?:\n|$)", response_section, re.MULTILINE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.YAML_LIST: + # Look for YAML list item + match = re.search(r"^-\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.TOML: + # Look for TOML answer field + match = re.search( + r'^answer\s*=\s*"([^"]*)"', response_section, re.MULTILINE + ) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.TOML_SECTION: + # Look for TOML section with answer field + section_match = re.search( + r'^\[response\].*?^answer\s*=\s*"([^"]*)"', + response_section, + re.MULTILINE | re.DOTALL, + ) + if section_match: + extracted_content = section_match.group(1) + + # Structured data with confidence scores + elif answer_format == 
AnswerFormat.JSON_CONFIDENCE: + # Look for JSON with answer and confidence fields + match = re.search( + r'\{\s*"answer"\s*:\s*"([^"]*)"\s*,\s*"confidence"\s*:\s*[0-9.]+\s*\}', + response_section, + re.DOTALL, + ) + if match: + try: + json.loads(match.group(0)) + extracted_content = match.group(1) # Extract just the answer value + except json.JSONDecodeError: + pass + + elif answer_format == AnswerFormat.YAML_CONFIDENCE: + # Look for YAML with answer and confidence fields + answer_match = re.search( + r"^answer\s*:\s*(.+?)(?:\n|$)", response_section, re.MULTILINE + ) + confidence_match = re.search( + r"^confidence\s*:\s*[0-9.]+", response_section, re.MULTILINE + ) + if answer_match and confidence_match: + extracted_content = answer_match.group(1).strip() + + elif answer_format == AnswerFormat.TOML_CONFIDENCE: + # Look for TOML with answer and confidence fields + answer_match = re.search( + r'^answer\s*=\s*"([^"]*)"', response_section, re.MULTILINE + ) + confidence_match = re.search( + r"^confidence\s*=\s*[0-9.]+", response_section, re.MULTILINE + ) + if answer_match and confidence_match: + extracted_content = answer_match.group(1) + + # XML/HTML tag variations (XML now uses answer tags) + elif answer_format == AnswerFormat.XML: + match = re.search( + r"(.*?)", response_section, re.DOTALL | re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.XML_FINAL_ANSWER: + match = re.search( + r"Final Answer:\s*(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.OUTPUT_TAGS: + match = re.search( + r"(.*?)", response_section, re.DOTALL | re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.RESULT_TAGS: + match = re.search( + r"(.*?)", response_section, re.DOTALL | re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif 
answer_format == AnswerFormat.RESPONSE_TAGS: + match = re.search( + r"(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.FINAL_ANSWER_TAGS: + match = re.search( + r"(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.SOLUTION_TAGS: + match = re.search( + r"(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.CONCLUSION_TAGS: + match = re.search( + r"(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.REPLY_TAGS: + match = re.search( + r"(.*?)", response_section, re.DOTALL | re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NESTED_RESPONSE_ANSWER: + match = re.search( + r".*?(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NESTED_SOLUTION_ANSWER: + match = re.search( + r".*?(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NESTED_OUTPUT_RESULT: + match = re.search( + r".*?(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NESTED_ANALYSIS_CONCLUSION: + match = re.search( + r".*?(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NESTED_REASONING_ANSWER: + match = re.search( + r".*?(.*?)", + response_section, + re.DOTALL | re.IGNORECASE, + ) + if match: + extracted_content = 
match.group(1).strip() + + # LaTeX formats (text-friendly) + elif answer_format == AnswerFormat.LATEX_BOXED: + match = re.search(r"\\boxed\{([^}]+)\}", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.LATEX_TEXTBF: + match = re.search(r"\\textbf\{([^}]+)\}", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.LATEX_TEXTIT: + match = re.search(r"\\textit\{([^}]+)\}", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.LATEX_UNDERLINE: + match = re.search(r"\\underline\{([^}]+)\}", response_section) + if match: + extracted_content = match.group(1) + + # LaTeX formats (math-only) + elif answer_format == AnswerFormat.LATEX_BOXED_MATH: + match = re.search(r"\$\\boxed\{([^}]+)\}\$", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.LATEX_ALIGN: + match = re.search( + r"\\begin\{align\}(.*?)\\end\{align\}", response_section, re.DOTALL + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.LATEX_EQUATION: + match = re.search( + r"\\begin\{equation\}(.*?)\\end\{equation\}", + response_section, + re.DOTALL, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.LATEX_DISPLAYMATH: + match = re.search(r"\\\\?\[(.*?)\\\\?\]", response_section, re.DOTALL) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.LATEX_INLINE_MATH: + match = re.search(r"\$([^$]+)\$", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.LATEX_TEXT_MATH: + match = re.search(r"\$\\text\{([^}]+)\}\$", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.LATEX_MATHRM: + match = re.search(r"\$\\mathrm\{([^}]+)\}\$", response_section) + if 
match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.LATEX_THEREFORE: + match = re.search(r"\$\\therefore\s*(.+?)\$", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.LATEX_IMPLIES: + match = re.search(r"\$\\implies\s*(.+?)\$", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.LATEX_EQUIV: + match = re.search(r"\$(.+?)\\equiv\s*(.+?)\$", response_section) + if match: + # Extract the part after the equiv symbol + extracted_content = match.group(2).strip() + + elif answer_format == AnswerFormat.LATEX_MATRIX: + match = re.search( + r"\$\\begin\{matrix\}(.*?)\\end\{matrix\}\$", + response_section, + re.DOTALL, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.LATEX_PMATRIX: + match = re.search( + r"\$\\begin\{pmatrix\}(.*?)\\end\{pmatrix\}\$", + response_section, + re.DOTALL, + ) + if match: + extracted_content = match.group(1).strip() + + # Markdown formats + elif answer_format == AnswerFormat.MARKDOWN_CODE: + match = re.search(r"```\s*\n(.*?)\n```", response_section, re.DOTALL) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.MARKDOWN_BOLD: + match = re.search(r"\*\*([^*]+)\*\*", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.MARKDOWN_ITALIC: + match = re.search(r"\*([^*]+)\*", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.MARKDOWN_HEADER: + match = re.search(r"^##\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.MARKDOWN_QUOTE: + match = re.search(r"^>\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + # Bracket and delimiter formats + 
elif answer_format == AnswerFormat.SQUARE_BRACKETS: + match = re.search(r"\[([^\]]+)\]", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.DOUBLE_SQUARE_BRACKETS: + match = re.search(r"\[\[([^\]]+)\]\]", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.CURLY_BRACES: + match = re.search(r"\{([^}]+)\}", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.PARENTHESES: + match = re.search(r"\(([^)]+)\)", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.ANGLE_BRACKETS: + match = re.search(r"<([^>]+)>", response_section) + if match: + extracted_content = match.group(1) + + # Natural language patterns + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_ANSWER: + match = re.search( + r"The answer is:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_FINAL: + match = re.search( + r"Final answer:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_CONCLUSION: + match = re.search( + r"In conclusion:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_THEREFORE: + match = re.search( + r"Therefore:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_RESULT: + match = re.search( + r"The result is:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + # Additional natural language patterns + elif answer_format == 
AnswerFormat.NATURAL_LANGUAGE_BEST: + match = re.search( + r"The best answer is:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_MY_FINAL: + match = re.search( + r"My final answer is:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_CORRECT: + match = re.search( + r"The correct answer is:?\s*(.+?)(?:\n|$)", + response_section, + re.IGNORECASE, + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_SOLUTION: + match = re.search( + r"The solution is:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_RESPONSE: + match = re.search( + r"My response is:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_ULTIMATELY: + match = re.search( + r"Ultimately:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_THUS: + match = re.search( + r"Thus:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_HENCE: + match = re.search( + r"Hence:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_CONSEQUENTLY: + match = re.search( + r"Consequently:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == 
AnswerFormat.NATURAL_LANGUAGE_TO_SUMMARIZE: + match = re.search( + r"To summarize:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_IN_SUMMARY: + match = re.search( + r"In summary:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_OVERALL: + match = re.search( + r"Overall:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_FINAL_VERDICT: + match = re.search( + r"Final verdict:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_BOTTOM_LINE: + match = re.search( + r"Bottom line:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.NATURAL_LANGUAGE_KEY_POINT: + match = re.search( + r"The key point is:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + # Special formats + elif answer_format == AnswerFormat.TEXTARENA_FORMAT: + match = re.search(r"\[([A-Da-d])\]", response_section) + if match: + extracted_content = match.group(1) + + elif answer_format == AnswerFormat.COLON_FORMAT: + match = re.search( + r"Answer:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.ARROW_FORMAT: + patterns = [ + r"=>\s*(.+?)(?:\n|$)", + r"->\s*(.+?)(?:\n|$)", + r"→\s*(.+?)(?:\n|$)", + ] + for pattern in patterns: + match = re.search(pattern, response_section) + if match: + extracted_content = match.group(1).strip() + break + + # HTML formats + elif answer_format == 
AnswerFormat.HTML_CODE: + match = re.search( + r"(.*?)", response_section, re.DOTALL | re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.HTML_PRE: + match = re.search( + r"
(.*?)
", response_section, re.DOTALL | re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.HTML_SPAN: + match = re.search( + r"(.*?)", response_section, re.DOTALL | re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.HTML_DIV: + match = re.search( + r"
(.*?)
", response_section, re.DOTALL | re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.HTML_P: + match = re.search( + r"

(.*?)

", response_section, re.DOTALL | re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + # Multiple structured tags + elif answer_format == AnswerFormat.MULTIPLE_TAGS: + tags_found = {} + tag_patterns = [ + ("theory", r"(.*?)"), + ("answer", r"(.*?)"), + ("explanation", r"(.*?)"), + ("reasoning", r"(.*?)"), + ("conclusion", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) >= 2: # At least 2 tags found + extracted_content = json.dumps(tags_found) + + # Complex multi-tag formats + elif answer_format == AnswerFormat.COMPLEX_CODING_FORMAT: + tags_found = {} + tag_patterns = [ + ("restatement", r"(.*?)"), + ("reasoning", r"(.*?)"), + ("plan", r"(.*?)"), + ("schemas", r"(.*?)"), + ("diagram", r"(.*?)"), + ("reflection", r"(.*?)"), + ("solution", r"(.*?)"), + ("explanation", r"(.*?)"), + ("unit_test", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 9: # All required tags found + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_CODING_SIMPLE: + tags_found = {} + tag_patterns = [ + ("restatement", r"(.*?)"), + ("reasoning", r"(.*?)"), + ("plan", r"(.*?)"), + ("solution", r"(.*?)"), + ("explanation", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 5: # All required tags found + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_CODING_MINIMAL: + tags_found = {} + tag_patterns = [ + ("analysis", r"(.*?)"), + ("approach", r"(.*?)"), + ("solution", r"(.*?)"), + ("test", r"(.*?)"), + ] + 
for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 4: # All required tags found + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_MATH_FORMAT: + tags_found = {} + tag_patterns = [ + ("restatement", r"(.*?)"), + ("reasoning", r"(.*?)"), + ("approach", r"(.*?)"), + ("derivation", r"(.*?)"), + ("solution", r"(.*?)"), + ("verification", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 6: # All required tags found + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_MATH_SIMPLE: + tags_found = {} + tag_patterns = [ + ("problem_analysis", r"(.*?)"), + ("solution_steps", r"(.*?)"), + ("final_answer", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 3: # All required tags found + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_GENERAL_FORMAT: + tags_found = {} + tag_patterns = [ + ("restatement", r"(.*?)"), + ("analysis", r"(.*?)"), + ("reasoning", r"(.*?)"), + ("conclusion", r"(.*?)"), + ("reflection", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 5: # All required tags found + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_GENERAL_SIMPLE: + tags_found = {} + tag_patterns = [ + ("analysis", r"(.*?)"), + ("reasoning", r"(.*?)"), + ("conclusion", r"(.*?)"), + ] + for 
tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 3: # All required tags found + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_RESEARCH_FORMAT: + tags_found = {} + tag_patterns = [ + ("hypothesis", r"(.*?)"), + ("methodology", r"(.*?)"), + ("analysis", r"(.*?)"), + ("findings", r"(.*?)"), + ("conclusion", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 5: # All required tags found + extracted_content = json.dumps(tags_found) + + # Advanced scratchpad and classification formats + elif answer_format == AnswerFormat.COMPLEX_SCRATCHPAD_FULL: + tags_found = {} + tag_patterns = [ + ("scratchpad", r"(.*?)"), + ("restatement", r"(.*?)"), + ("reasoning", r"(.*?)"), + ("plan", r"(.*?)"), + ("citations", r"(.*?)"), + ("schemas", r"(.*?)"), + ("diagram", r"(.*?)"), + ("reflection", r"(.*?)"), + ("revised_plan", r"(.*?)"), + ("solution", r"(.*?)"), + ("explanation", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 11: # All required tags found + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_SCRATCHPAD_SIMPLE: + tags_found = {} + tag_patterns = [ + ("scratchpad", r"(.*?)"), + ("restatement", r"(.*?)"), + ("reasoning", r"(.*?)"), + ("plan", r"(.*?)"), + ("diagram", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 5: # All required tags found + 
extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_CLASSIFICATION_FORMAT: + tags_found = {} + tag_patterns = [ + ( + "completeness_explanation", + r"(.*?)", + ), + ( + "completeness_classification", + r"(.*?)", + ), + ( + "grammar_explanation", + r"(.*?)", + ), + ( + "grammar_classification", + r"(.*?)", + ), + ( + "scientific_explanation", + r"(.*?)", + ), + ( + "scientific_classification", + r"(.*?)", + ), + ( + "programming_explanation", + r"(.*?)", + ), + ( + "programming_classification", + r"(.*?)", + ), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 8: # All required tags found + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_ANALYSIS_WITH_ANSWER: + tags_found = {} + tag_patterns = [ + ("analysis", r"(.*?)"), + ("methodology", r"(.*?)"), + ("findings", r"(.*?)"), + ("answer", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if len(tags_found) == 4: # All required tags found + # Extract the actual answer from the ANSWER tag + answer_content = tags_found.get("answer", "") + answer_match = re.search( + r"ANSWER IS:\s*(.+)", answer_content, re.IGNORECASE + ) + if answer_match: + extracted_content = answer_match.group(1).strip() + else: + extracted_content = json.dumps(tags_found) + + elif answer_format == AnswerFormat.COMPLEX_EVALUATION_FORMAT: + tags_found = {} + tag_patterns = [ + ("criteria", r"(.*?)"), + ("assessment", r"(.*?)"), + ("overall_score", r"(.*?)"), + ("judgment", r"(.*?)"), + ] + for tag_name, pattern in tag_patterns: + match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE) + if match: + tags_found[tag_name] = match.group(1).strip() + + if 
len(tags_found) == 4: # All required tags found + extracted_content = json.dumps(tags_found) + + # Dynamic formats are commented out - no extraction logic needed + + # Custom delimiters + elif answer_format == AnswerFormat.PIPE_DELIMITED: + match = re.search(r"\|([^|]+)\|", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.HASH_DELIMITED: + match = re.search(r"#([^#]+)#", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.UNDERSCORE_DELIMITED: + match = re.search(r"_([^_]+)_", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.TILDE_DELIMITED: + match = re.search(r"~([^~]+)~", response_section) + if match: + extracted_content = match.group(1).strip() + + # Programming-style formats + elif answer_format == AnswerFormat.FUNCTION_CALL: + patterns = [ + r'answer\("([^"]+)"\)', + r"answer\(\'([^\']+)\'\)", + r"answer\(([^)]+)\)", + ] + for pattern in patterns: + match = re.search(pattern, response_section) + if match: + extracted_content = match.group(1).strip() + break + + elif answer_format == AnswerFormat.VARIABLE_ASSIGNMENT: + patterns = [ + r'answer\s*=\s*"([^"]+)"', + r"answer\s*=\s*\'([^\']+)\'", + r"answer\s*=\s*([^;\n]+)", + ] + for pattern in patterns: + match = re.search(pattern, response_section) + if match: + extracted_content = match.group(1).strip() + break + + elif answer_format == AnswerFormat.RETURN_STATEMENT: + patterns = [ + r'return\s+"([^"]+)"', + r"return\s+\'([^\']+)\'", + r"return\s+([^;\n]+)", + ] + for pattern in patterns: + match = re.search(pattern, response_section) + if match: + extracted_content = match.group(1).strip() + break + + # Additional easy-to-parse formats + elif answer_format == AnswerFormat.EQUALS_FORMAT: + match = re.search(r"^=\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + 
elif answer_format == AnswerFormat.DASH_FORMAT: + match = re.search(r"^-\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.PLUS_FORMAT: + match = re.search(r"^\+\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.STAR_FORMAT: + match = re.search(r"^\*\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.PERCENT_FORMAT: + match = re.search(r"^%\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.AMPERSAND_FORMAT: + match = re.search(r"^&\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.AT_FORMAT: + match = re.search(r"^@\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.EXCLAMATION_FORMAT: + match = re.search(r"^!\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.QUESTION_FORMAT: + match = re.search(r"^\?\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.SEMICOLON_FORMAT: + match = re.search(r"^;\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.DOUBLE_COLON_FORMAT: + match = re.search(r"^::\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.TRIPLE_DASH_FORMAT: + match = re.search(r"^---\s*(.+?)(?:\n|$)", response_section, 
re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.DOUBLE_ARROW_FORMAT: + match = re.search(r"^>>\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.TRIPLE_ARROW_FORMAT: + match = re.search(r"^>>>\s*(.+?)(?:\n|$)", response_section, re.MULTILINE) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.BACKTICK_FORMAT: + match = re.search(r"`([^`]+)`", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.DOUBLE_BACKTICK_FORMAT: + match = re.search(r"``([^`]+)``", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.QUOTE_FORMAT: + match = re.search(r'"([^"]+)"', response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.SINGLE_QUOTE_FORMAT: + match = re.search(r"'([^']+)'", response_section) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.TRIPLE_QUOTE_FORMAT: + match = re.search(r'"""([^"]+)"""', response_section, re.DOTALL) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.ANSWER_IS_FORMAT: + match = re.search( + r"ANSWER IS:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.SOLUTION_IS_FORMAT: + match = re.search( + r"SOLUTION IS:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.RESULT_IS_FORMAT: + match = re.search( + r"RESULT IS:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE + ) + if match: + extracted_content = match.group(1).strip() + + elif answer_format == AnswerFormat.OUTPUT_IS_FORMAT: + match = 
re.search(
            r"OUTPUT IS:?\s*(.+?)(?:\n|$)", response_section, re.IGNORECASE
        )
        if match:
            extracted_content = match.group(1).strip()

        # Code-specific formats
        elif answer_format == AnswerFormat.PYTHON_PRINT:
            # Only double-quoted string arguments are accepted by this pattern.
            match = re.search(r'print\s*\(\s*"([^"]*)"\s*\)', response_section)
            if match:
                extracted_content = match.group(1)

        elif answer_format == AnswerFormat.JAVASCRIPT_CONSOLE:
            match = re.search(r'console\.log\s*\(\s*"([^"]*)"\s*\)', response_section)
            if match:
                extracted_content = match.group(1)

        elif answer_format == AnswerFormat.PYTHON_COMMENT:
            # First line starting with '#' anywhere in the section.
            match = re.search(r"^#\s*(.+?)(?:\n|$)", response_section, re.MULTILINE)
            if match:
                extracted_content = match.group(1).strip()

        elif answer_format == AnswerFormat.JAVASCRIPT_COMMENT:
            match = re.search(r"^//\s*(.+?)(?:\n|$)", response_section, re.MULTILINE)
            if match:
                extracted_content = match.group(1).strip()

        elif answer_format == AnswerFormat.C_COMMENT:
            # DOTALL so multi-line /* ... */ comments are captured whole.
            match = re.search(r"/\*(.*?)\*/", response_section, re.DOTALL)
            if match:
                extracted_content = match.group(1).strip()

        elif answer_format == AnswerFormat.SHELL_ECHO:
            match = re.search(r'echo\s+"([^"]*)"', response_section)
            if match:
                extracted_content = match.group(1)

        elif answer_format == AnswerFormat.SHELL_OUTPUT:
            match = re.search(r"^\$\s*(.+?)(?:\n|$)", response_section, re.MULTILINE)
            if match:
                extracted_content = match.group(1).strip()

        elif answer_format == AnswerFormat.PYTHON_DOCSTRING:
            match = re.search(r'"""([^"]*)"""', response_section, re.DOTALL)
            if match:
                extracted_content = match.group(1).strip()

        elif answer_format == AnswerFormat.INI_FORMAT:
            # Look for section and key=value pattern
            # NOTE(review): the section header and the key=value line are matched
            # independently; a value appearing before the [section] still passes.
            section_match = re.search(r"^\[.*?\]", response_section, re.MULTILINE)
            value_match = re.search(
                r"^.*?\s*=\s*(.+?)(?:\n|$)", response_section, re.MULTILINE
            )
            if section_match and value_match:
                extracted_content = value_match.group(1).strip()

        elif answer_format == AnswerFormat.ENV_FORMAT:
            # UPPER_SNAKE_CASE key with a double-quoted value, e.g. ANSWER="x".
            match = re.search(
                r'^[A-Z_]+\s*=\s*"([^"]*)"', response_section, re.MULTILINE
            )
            if match:
                extracted_content = match.group(1)

        # None signals "format not adhered to"; callers score that as 0.0.
        if extracted_content is not None:
            if self.debug_logging:
                self.logger.debug(
                    f"Successfully extracted {answer_format.value} content: {extracted_content[:100]}{'...' if len(extracted_content) > 100 else ''}"  # noqa
                )
            return extracted_content
        else:
            if self.debug_logging:
                self.logger.warning(
                    f"Failed to extract {answer_format.value} format from response"
                )
            return None

    @classmethod
    def config_init(cls) -> Tuple[AnswerFormatEnvConfig, List[APIServerConfig]]:
        """Initialize configuration for the environment.

        Returns:
            A tuple of (environment config, list of inference-server configs).
            Dataset entries in ``dataset_configs`` declare which fields hold the
            prompt/answer and a ``dataset_type`` used for format filtering.
        """
        env_config = AnswerFormatEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=16,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=250,
            batch_size=1024,
            steps_per_eval=20,
            max_token_length=1024 * 12,
            inference_weight=1.0,
            wandb_name="answer_format_adherence",
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            eval_limit_ratio=0.1,
            debug_logging=True,
            dump_rollouts=True,
            dump_failed_rollouts=True,
            rollout_save_score_threshold=0.0,
            eval_set_percentage=0.1,
            ensure_equivalent_ratios=False,
            format_group_threshold=10,
            suppress_base_env_logs=True,
            ensure_scores_are_not_same=True,
            dataset_configs=[
                {
                    "name": "NousResearch/AcademicMCQA",
                    "split": "train",
                    "sample_size": 50000,
                    "prompt_field": "prompt",
                    "answer_field": "ground_truth",
                    "metadata_fields": ["answer", "options"],
                    "dataset_type": "generic",
                },
                {
                    "name": "gsm8k",
                    "split": "train",
                    "sample_size": 20000,
                    "prompt_field": "question",
                    "answer_field": "answer",
                    "metadata_fields": [],
                    "dataset_type": "math_only",
                },
                # Example of how to add math dataset:
                # {
                #     "name": "hendrycks/competition_math",
                #     "split": "train",
                #     "sample_size": 200,
                #     "prompt_field": "problem",
                #     "answer_field": "solution",
                #     "metadata_fields": ["level", "type"],
                #     "dataset_type": "math_only"
                # },
                # Example of how to add code dataset:
                # {
                #     "name": "openai/humaneval",
                #     "split": "test",
                #     "sample_size": 100,
                #     "prompt_field": "prompt",
                #     "answer_field": "canonical_solution",
                #     "metadata_fields": ["task_id"],
                #     "dataset_type": "code_only"
                # }
            ],
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_max_requests_at_once=32,
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    def _extract_prompt_and_answer_from_conversations(
        self, conversations: List[Dict[str, str]]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Extract prompt and answer from conversation format.

        Supports both ShareGPT-style keys (``from``/``value``) and OpenAI-style
        keys (``role``/``content``). Returns the FIRST human/user message and
        the FIRST gpt/assistant message; returns (None, None) on bad input.
        """
        if not conversations or not isinstance(conversations, list):
            return None, None

        # Find human and assistant messages
        human_msg = None
        assistant_msg = None

        for msg in conversations:
            if isinstance(msg, dict):
                role = msg.get("from", msg.get("role", ""))
                content = msg.get("value", msg.get("content", ""))

                if role in ["human", "user"] and not human_msg:
                    human_msg = content
                elif role in ["gpt", "assistant"] and not assistant_msg:
                    assistant_msg = content

        return human_msg, assistant_msg

    async def setup(self):
        """Load and combine datasets.

        Loads every entry in ``config.dataset_configs`` (skipping any that
        fail to load), normalizes items to {prompt, answer, metadata,
        dataset_type}, shuffles the combined pool, and splits it into
        ``train_items``/``test_items`` by ``eval_set_percentage``.
        """
        if self.debug_logging:
            self.logger.info("Starting environment setup")

        all_datasets = []

        for dataset_config in self.config.dataset_configs:
            try:
                dataset_name = dataset_config["name"]
                split = dataset_config.get("split", "train")
                sample_size = dataset_config.get("sample_size", None)

                if self.debug_logging:
                    self.logger.info(f"Loading dataset: {dataset_name}, split: {split}")

                # Load dataset with error handling
                try:
                    if dataset_name == "gsm8k":
                        # Special handling for GSM8K dataset
                        dataset = load_dataset("gsm8k", "main", split=split)
                    else:
                        dataset = load_dataset(dataset_name, split=split)
                except Exception as dataset_load_error:
                    if self.debug_logging:
                        self.logger.error(
                            f"Failed to load dataset {dataset_name}: {dataset_load_error}"
                        )
                    continue  # Skip this dataset and try the next one

                if sample_size and len(dataset) > sample_size:
                    # Deterministic subsample given self.seed.
                    dataset = dataset.shuffle(seed=self.seed).select(range(sample_size))

                # Convert to our standard format
                processed_items = []
                for item in dataset:
                    prompt = None
                    answer = None
                    metadata = {}

                    # Extract prompt and answer based on field configuration
                    prompt_field = dataset_config.get("prompt_field", "prompt")
                    answer_field = dataset_config.get("answer_field", "answer")

                    if (
                        prompt_field == "conversations"
                        and answer_field == "conversations"
                    ):
                        # Handle conversation format
                        conversations = item.get("conversations", [])
                        prompt, answer = (
                            self._extract_prompt_and_answer_from_conversations(
                                conversations
                            )
                        )
                    elif dataset_name == "gsm8k":
                        # Special handling for GSM8K dataset
                        prompt = item.get(prompt_field)
                        raw_answer = item.get(answer_field)
                        # Extract the final numerical answer from GSM8K format
                        if raw_answer and "####" in raw_answer:
                            # GSM8K uses #### to separate explanation from final answer
                            answer = (
                                raw_answer.split("####")[-1].strip().replace(",", "")
                            )
                        elif raw_answer and "#" in raw_answer:
                            # Fallback for other # patterns
                            answer = raw_answer.split("#")[-1].strip().replace(",", "")
                        else:
                            answer = raw_answer
                    elif dataset_name == "NousResearch/AcademicMCQA":
                        # Special handling for AcademicMCQA dataset
                        prompt = item.get(prompt_field)  # "prompt" field
                        correct_answer_index = item.get(
                            "answer"
                        )  # Index of correct answer
                        ground_truth_letter = item.get(
                            "ground_truth"
                        )  # Letter (A, B, C, D)
                        options = item.get("options", [])  # List of answer options

                        # Use the ground truth letter as the answer for format training
                        # The format environment will train on generating the letter in various formats
                        answer = ground_truth_letter

                        # Store additional metadata for MCQA
                        metadata["correct_answer_index"] = correct_answer_index
                        metadata["ground_truth_letter"] = ground_truth_letter
                        metadata["options"] = options
                        if (
                            correct_answer_index is not None
                            and correct_answer_index < len(options)
                        ):
                            metadata["correct_answer_text"] = options[
                                correct_answer_index
                            ]
                    else:
                        # Handle direct field access
                        prompt = item.get(prompt_field)
                        answer = item.get(answer_field)

                    # Extract metadata
                    metadata_fields = dataset_config.get("metadata_fields", [])
                    for field in metadata_fields:
                        if field in item:
                            metadata[field] = item[field]

                    metadata["dataset_name"] = dataset_name

                    # Get dataset type (default to "generic" if not specified)
                    dataset_type = dataset_config.get("dataset_type", "generic")

                    # Items with a falsy prompt or answer are silently dropped.
                    if prompt and answer:
                        processed_items.append(
                            {
                                "prompt": prompt,
                                "answer": answer,
                                "metadata": metadata,
                                "dataset_type": dataset_type,
                            }
                        )

                if self.debug_logging:
                    self.logger.info(
                        f"Processed {len(processed_items)} items from {dataset_name}"
                    )

                all_datasets.extend(processed_items)

            except Exception as e:
                error_msg = f"Error loading dataset {dataset_config.get('name', 'unknown')}: {e}"
                if self.debug_logging:
                    self.logger.error(error_msg)
                print(error_msg)

        if not all_datasets:
            raise ValueError("No datasets could be loaded successfully")

        # Shuffle all items together
        random.shuffle(all_datasets)
        self.dataset_items = all_datasets

        # Split into train and test
        split_idx = int(
            len(self.dataset_items) * (1.0 - self.config.eval_set_percentage)
        )
        self.train_items = self.dataset_items[:split_idx]
        self.test_items = self.dataset_items[split_idx:]

        self.iter = 0

        if self.debug_logging:
            # Log dataset type distribution
            type_counts = {}
            for item in self.dataset_items:
                dataset_type = item.get("dataset_type", "generic")
                type_counts[dataset_type] = type_counts.get(dataset_type, 0) + 1

            self.logger.info(
                f"Setup complete: {len(self.train_items)} training items, {len(self.test_items)} test items"
            )
            self.logger.info(f"Dataset type distribution: {type_counts}")

            # Log format availability by type
            for dataset_type in type_counts.keys():
                available_formats = self._get_formats_for_dataset_type(dataset_type)
                self.logger.info(
                    f"Dataset type '{dataset_type}': {len(available_formats)} available formats"
                )

        print(
            f"AnswerFormatEnv setup complete. {len(self.train_items)} training items, {len(self.test_items)} test items."  # noqa
        )

        # Print dataset type distribution
        # NOTE(review): duplicates the debug-logging computation above so the
        # summary prints even when debug_logging is off.
        type_counts = {}
        for item in self.dataset_items:
            dataset_type = item.get("dataset_type", "generic")
            type_counts[dataset_type] = type_counts.get(dataset_type, 0) + 1
        print(f"Dataset types: {type_counts}")

    async def get_next_item(self) -> Tuple[Tuple[frozenset, ...], Dict[str, Any]]:
        """Get the next training item with randomized format based on dataset type.

        Returns a (prompt_messages, dataset_item) pair where prompt_messages is
        a tuple of frozenset-encoded chat messages (system + user) and the
        dataset_item carries the chosen ``selected_format`` for later scoring.
        """
        if not self.train_items:
            raise ValueError("No training items available")

        # Dynamic formats are commented out - no state management needed

        # Get the next item cyclically
        dataset_item = self.train_items[self.iter % len(self.train_items)]
        self.iter += 1

        # Get appropriate formats for this dataset type
        dataset_type = dataset_item.get("dataset_type", "generic")
        available_formats = self._get_formats_for_dataset_type(dataset_type)

        if not available_formats:
            # Fallback to generic formats if no formats available
            available_formats = list(self.generic_formats)
            if self.debug_logging:
                self.logger.warning(
                    f"No formats available for dataset type '{dataset_type}', using generic formats"
                )

        # Randomly select an answer format from available formats with optional weighting
        # Simple formats get higher weight for better training balance
        if len(available_formats) > 10:  # Only apply weighting if we have many formats
            simple_formats = [
                f
                for f in available_formats
                if not f.value.startswith(("complex_", "dynamic_"))
            ]
            complex_formats = [
                f
                for f in available_formats
                if f.value.startswith(("complex_", "dynamic_"))
            ]

            # 70% chance for simple formats, 30% for complex formats
            if random.random() < 0.7 and simple_formats:
                selected_format = random.choice(simple_formats)
            elif complex_formats:
                selected_format = random.choice(complex_formats)
            else:
                selected_format = random.choice(available_formats)
        else:
            selected_format = random.choice(available_formats)

        if self.debug_logging:
            self.logger.debug(
                f"Item {self.iter-1}: dataset_type='{dataset_type}', "
                f"selected_format='{selected_format.value}' from {len(available_formats)} available"
            )

        # Store the selected format in the item for scoring
        # NOTE(review): this mutates the shared train item in place, so the
        # stored format persists until the next epoch overwrites it — confirm
        # this is intentional rather than copying the item.
        dataset_item["selected_format"] = selected_format

        # Dynamic formats are commented out - no component generation needed

        # Generate system prompt for the selected format
        system_prompt = self._generate_system_prompt(selected_format)

        # Create the message structure (frozensets keep messages hashable).
        prompt_messages = [
            frozenset({"role": "system", "content": system_prompt}.items()),
            frozenset({"role": "user", "content": dataset_item["prompt"]}.items()),
        ]

        return tuple(prompt_messages), dataset_item

    async def score(
        self,
        rollout_group_data: List[Tuple[Tuple[Dict[str, str], ...], Dict[str, Any]]],
    ) -> Optional[ScoredDataGroup]:
        """Score rollouts based on format adherence.

        Each rollout earns 1.0 if the selected format parses, else 0.0.
        Returns None when the batch is empty, when nothing tokenizes, or when
        all scores are identical (no learning signal).
        """
        if self.debug_logging:
            self.logger.debug(f"Scoring {len(rollout_group_data)} rollouts")

        scores_obj = ScoredDataGroup()
        scores_obj["tokens"] = list()
        scores_obj["masks"] = list()
        scores_obj["scores"] = list()

        if not rollout_group_data:
            return None

        # All rollouts in a group share one dataset item / selected format.
        dataset_item = rollout_group_data[0][1]
        selected_format = dataset_item["selected_format"]
        format_name = selected_format.value

        if self.debug_logging:
            self.logger.debug(f"Scoring for format: {format_name}")

        random.shuffle(rollout_group_data)
        failed_rollouts_this_group = []

        for item_messages, _ in rollout_group_data:
            messages_as_dicts = [dict(fs_message) for fs_message in item_messages]
            model_response_text = messages_as_dicts[-1]["content"]

            # Extract content based on the selected format
            extracted_content = self._extract_answer_content(
                model_response_text, selected_format, dataset_item
            )

            # Score: 1.0 if format is correct, 0.0 if not
            reward = 1.0 if extracted_content is not None else 0.0
            is_failed = reward == 0.0

            # Track format success rates
            self.format_total_counts[format_name] = (
                self.format_total_counts.get(format_name, 0) + 1
            )
            if reward == 1.0:
                self.format_success_counts[format_name] = (
                    self.format_success_counts.get(format_name, 0) + 1
                )

            try:
                # Validate message format for tokenization
                for msg_idx, msg in enumerate(messages_as_dicts):
                    if not isinstance(msg, dict):
                        if self.debug_logging:
                            self.logger.error(
                                f"Message {msg_idx} is not a dict: {type(msg)}"
                            )
                        continue
                    if "role" not in msg or "content" not in msg:
                        if self.debug_logging:
                            self.logger.error(
                                f"Message {msg_idx} missing required keys: {msg.keys()}"
                            )
                        continue
                    if not isinstance(msg["content"], str):
                        msg["content"] = str(msg["content"])

                out_dict = tokenize_for_trainer(
                    self.tokenizer,
                    messages_as_dicts,
                    include_messages=self.config.include_messages,
                )
                tokens = out_dict["tokens"]
                masks = out_dict["masks"]

            except Exception as e:
                if self.debug_logging:
                    self.logger.error(f"Tokenization failed: {e}")
                continue

            # Remove examples with insufficient context
            # (fewer than 10 unmasked trainable tokens).
            if len([1 for m_val in masks if m_val != -100]) < 10:
                continue

            scores_obj["tokens"].append(tokens)
            scores_obj["masks"].append(masks)
            scores_obj["scores"].append(reward)

            # Track failed rollouts for debugging
            if is_failed and self.config.dump_failed_rollouts:
                failed_rollouts_this_group.append(
                    {
                        "conversation": messages_as_dicts,
                        "score": reward,
                        "selected_format": format_name,
                        "metadata": dataset_item.get("metadata", {}),
                        "extracted_content": extracted_content,
                        "failure_reason": "format_parsing_failed",
                    }
                )

            self.percent_correct_buffer.append(reward)

            if len(scores_obj["tokens"]) >= self.config.group_size:
                break

        if not scores_obj["tokens"]:
            if self.debug_logging:
                self.logger.warning("No valid examples processed in this batch")
            return None

        # Calculate and log group-level statistics
        group_scores = scores_obj["scores"]
        average_score = sum(group_scores) / len(group_scores) if group_scores else 0.0

        # Update group statistics
        self.group_statistics["total_groups"] += 1
        self.group_statistics["average_scores"].append(average_score)
        self.group_statistics["format_distribution"][format_name] = (
            self.group_statistics["format_distribution"].get(format_name, 0) + 1
        )

        if average_score > 0.0:
            self.group_statistics["successful_groups"] += 1
        else:
            self.group_statistics["failed_groups"] += 1

        # Log group performance with detailed information
        if self.debug_logging:
            correct_count = sum(1 for score in group_scores if score == 1.0)
            incorrect_count = len(group_scores) - correct_count
            correct_percentage = (
                (correct_count / len(group_scores)) * 100 if group_scores else 0.0
            )

            if correct_count == 0:
                self.logger.info(
                    f"Format: {format_name} | Group average score: {average_score:.4f} | {correct_count}/{len(group_scores)} correct ({correct_percentage:.1f}%) (All failures in this group!)"  # noqa
                )
            elif incorrect_count == 0:
                self.logger.info(
                    f"Format: {format_name} | Group average score: {average_score:.4f} | {correct_count}/{len(group_scores)} correct ({correct_percentage:.1f}%) (Perfect group!)"  # noqa
                )
            else:
                self.logger.info(
                    f"Format: {format_name} | Group average score: {average_score:.4f} | {correct_count}/{len(group_scores)} correct ({correct_percentage:.1f}%)"  # noqa
                )

        # Handle failed rollouts dumping
        if failed_rollouts_this_group and self.config.dump_failed_rollouts:
            failed_item_data_to_save = {
                "item_id": f"failed_group_{self.group_statistics['total_groups']}",
                "format": format_name,
                "group_average_score": average_score,
                "rollouts": failed_rollouts_this_group,
                "group_metadata": {
                    "total_rollouts": len(group_scores),
                    "failed_rollouts": len(failed_rollouts_this_group),
                    "dataset_type": dataset_item.get("dataset_type", "unknown"),
                },
            }
            self.failed_rollouts_to_save_buffer.append(failed_item_data_to_save)

            # Save failed rollouts every 50 groups
            if (
                self.config.dump_failed_rollouts
                and len(self.failed_rollouts_to_save_buffer) >= 50
            ):
                if self.debug_logging:
                    self.logger.info(
                        f"Saving batch of {len(self.failed_rollouts_to_save_buffer)} failed rollout groups"
                    )
                await self._save_failed_rollouts_to_jsonl()

        # Check if all scores are the same (no learning signal)
        if all(group_scores[0] == score for score in group_scores):
            if self.debug_logging:
                self.logger.debug(
                    "All scores are identical, returning None for learning signal"
                )
            return None

        # Track successful groups for equivalent ratio enforcement
        if self.ensure_equivalent_ratios:
            # Count this as a successful group if we have any successful examples
            if any(score > 0.0 for score in group_scores):
                self.format_successful_groups[format_name] = (
                    self.format_successful_groups.get(format_name, 0) + 1
                )

                if self.debug_logging:
                    successful_count = self.format_successful_groups[format_name]
                    remaining_until_threshold = (
                        self.format_group_threshold - successful_count
                    )

                    # Always log when a group is added, showing progress toward threshold
                    if remaining_until_threshold > 0:
                        self.logger.info(
                            f"Format {format_name} successful group #{successful_count} - {remaining_until_threshold} more needed to reach threshold ({self.format_group_threshold})"  # noqa
                        )

                    # Log when a format reaches the threshold
                    if successful_count == self.format_group_threshold:
                        self.logger.info(
                            f"Format {format_name} reached threshold of {self.format_group_threshold} successful groups - will be excluded from future rounds until others catch up"  # noqa
                        )

        # Apply length penalty if all responses are correct
        # NOTE(review): this branch looks unreachable — an all-1.0 group makes
        # every score identical, so the identical-scores check above already
        # returned None. Confirm whether the penalty was meant to run first.
        if all(s == 1.0 for s in group_scores):
            avg_len = sum(len(t) for t in scores_obj["tokens"]) / len(
                scores_obj["tokens"]
            )
            if avg_len > self.config.max_token_length * 0.75:
                scores_obj["scores"] = [s * 0.9 for s in scores_obj["scores"]]
                if self.debug_logging:
                    self.logger.debug(f"Applied length penalty: avg_len={avg_len}")

        return scores_obj

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
        """Collect trajectories for a given item.

        Samples ``group_size`` completions for the item's prompt, scores them
        for format adherence, and (optionally) buffers high-quality groups for
        JSONL dumping. Returns (scored_data, []) — no backlog items.
        """
        prompt_messages_tuple, dataset_item = item

        if self.debug_logging:
            self.logger.debug(
                f"Collecting trajectories for format: {dataset_item['selected_format'].value}"
            )

        # Convert frozensets to dicts for the API call
        messages_for_api = [dict(fs_message) for fs_message in prompt_messages_tuple]

        prompt_str = self.tokenizer.apply_chat_template(
            messages_for_api, add_generation_prompt=True, tokenize=False
        )

        completions = await self.server.completion(
            prompt=prompt_str,
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
            temperature=0.9,
        )

        to_score_list = []
        for choice in completions.choices:
            # Create a full message list for this choice
            current_trajectory_messages = list(prompt_messages_tuple)
            current_trajectory_messages.append(
                frozenset({"role": "assistant", "content": choice.text}.items())
            )
            to_score_list.append((tuple(current_trajectory_messages), dataset_item))

        scored_data = await self.score(to_score_list)

        # Data dumping logic - following reasoning gym pattern exactly
        if scored_data and self.config.dump_rollouts:
            # Only save groups that have at least one rollout with score > threshold
            # NOTE(review): the group gate uses strict '>' but the per-rollout
            # filter below uses '>='; with threshold=0.0 failed rollouts in a
            # passing group are still saved — confirm this is intended.
            group_scores = scored_data.get("scores", [])
            threshold = self.config.rollout_save_score_threshold
            if any(score > threshold for score in group_scores):
                if self.debug_logging:
                    self.logger.debug(
                        f"Saving group with scores: {[f'{s:.3f}' for s in group_scores]} (has high-quality rollout, threshold: {threshold})"  # noqa
                    )
                rollouts_with_scores_to_save = []

                num_scored_rollouts = len(group_scores)
                for i in range(num_scored_rollouts):
                    # Get conversation messages directly from to_score_list like reasoning gym does
                    conversation_messages = [
                        dict(fs_msg) for fs_msg in to_score_list[i][0]
                    ]
                    score_for_rollout = group_scores[i]

                    # Only save rollouts that meet the score threshold
                    if score_for_rollout >= threshold:
                        rollouts_with_scores_to_save.append(
                            {
                                "conversation": conversation_messages,
                                "score": score_for_rollout,
                            }
                        )

                if rollouts_with_scores_to_save:
                    item_data_to_save = {
                        "item_id": f"item_{self.iter-1}",
                        "format": dataset_item["selected_format"].value,
                        "rollouts": rollouts_with_scores_to_save,
                        "group_metadata": {
                            "total_rollouts": num_scored_rollouts,
                            "saved_rollouts": len(rollouts_with_scores_to_save),
                            "dataset_type": dataset_item.get("dataset_type", "unknown"),
                            "group_statistics": {
                                "min_score": min(group_scores) if group_scores else 0.0,
                                "max_score": max(group_scores) if group_scores else 0.0,
                                "avg_score": (
                                    sum(group_scores) / len(group_scores)
                                    if group_scores
                                    else 0.0
                                ),
                            },
                        },
                    }
                    self.rollouts_to_save_buffer.append(item_data_to_save)
                    self.processed_item_count += 1

                    # Calculate progress toward next save (like reasoning gym does)
                    current_batch_progress = self.processed_item_count % 100
                    if current_batch_progress == 0:
                        current_batch_progress = 100

                    # Log progress every 10 items or when we hit the save threshold
                    if (
                        current_batch_progress % 10 == 0
                        or current_batch_progress == 100
                    ):
                        if self.debug_logging:
                            self.logger.info(
                                f"Data dump progress: {current_batch_progress}/100 items buffered "
                                f"(Total processed: {self.processed_item_count}, Buffer size: {len(self.rollouts_to_save_buffer)})"  # noqa
                            )

                    # Save batch every 100 processed items
                    if (
                        self.config.dump_rollouts
                        and self.processed_item_count > 0
                        and self.processed_item_count % 100 == 0
                    ):
                        if self.debug_logging:
                            log_msg = (
                                f"Reached {self.processed_item_count} processed items. "
                                f"Triggering save for {len(self.rollouts_to_save_buffer)} rollouts"  # noqa
                            )
                            self.logger.info(log_msg)
                        await self._save_rollouts_to_jsonl()
            else:
                max_score = max(group_scores) if group_scores else 0.0
                if self.debug_logging:
                    self.logger.debug(
                        f"Skipping group save - no high-quality rollouts (max score: {max_score:.3f}, threshold: {threshold})"  # noqa
                    )

        return scored_data, []

    async def _save_rollouts_to_jsonl(self):
        """Save rollouts to JSONL file.

        Writes the current buffer to a new numbered file under
        ``datadumps_dir`` and clears the buffer on success.
        """
        if not self.rollouts_to_save_buffer:
            return

        try:
            if not os.path.exists(self.datadumps_dir):
                os.makedirs(self.datadumps_dir)
        except OSError as e:
            if self.debug_logging:
                self.logger.error(f"Error creating directory {self.datadumps_dir}: {e}")
            return

        file_path = os.path.join(
            self.datadumps_dir,
            f"answer_format_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl",
        )

        try:
            with open(file_path, "w") as f:
                for rollout_dict in self.rollouts_to_save_buffer:
                    json.dump(rollout_dict, f)
                    f.write("\n")

            if self.debug_logging:
                self.logger.info(
                    f"Saved {len(self.rollouts_to_save_buffer)} rollouts to {file_path}"
                )

            self.rollouts_to_save_buffer.clear()
            self.save_file_batch_num += 1

        except Exception as e:
            if self.debug_logging:
                self.logger.error(f"Error saving rollouts: {e}")

    async def _save_failed_rollouts_to_jsonl(self):
        """Save failed rollouts to JSONL file for debugging."""
        if not self.failed_rollouts_to_save_buffer:
            if self.debug_logging:
                self.logger.info("No failed rollouts in buffer to save.")
            return

        try:
            if not os.path.exists(self.datadumps_dir):
                os.makedirs(self.datadumps_dir)
        except OSError as e:
            if self.debug_logging:
                self.logger.error(f"Error creating directory {self.datadumps_dir}: {e}")
            return

        file_path = os.path.join(
            self.datadumps_dir,
            f"answer_format_FAILED_rollouts_{self.run_uuid}_{self.failed_save_file_batch_num:04d}.jsonl",
        )

        try:
            with open(file_path, "w") as f:
                for rollout_dict in self.failed_rollouts_to_save_buffer:
                    json.dump(rollout_dict, f)
                    f.write("\n")

            if self.debug_logging:
                self.logger.info(
                    f"Successfully saved {len(self.failed_rollouts_to_save_buffer)} FAILED rollouts to {file_path}"
                )

            self.failed_rollouts_to_save_buffer.clear()
            self.failed_save_file_batch_num += 1

        except Exception as e:
            if self.debug_logging:
                self.logger.error(f"Error writing failed rollouts to {file_path}: {e}")
            else:
                print(
                    f"An unexpected error occurred while saving failed rollouts to {file_path}: {e}"
                )

    async def rollout_and_score_eval(self, test_item: Dict[str, Any]) -> float:
        """Evaluate a single test item.

        Returns 1.0 when a low-temperature completion adheres to a randomly
        selected format for the item's dataset type, else 0.0.
        """
        # Get appropriate formats for this dataset type
        dataset_type = test_item.get("dataset_type", "generic")
        available_formats = self._get_formats_for_dataset_type(dataset_type)

        if not available_formats:
            # Fallback to generic formats if no formats available
            available_formats = list(self.generic_formats)

        # Randomly select format for evaluation
        selected_format = random.choice(available_formats)
        test_item["selected_format"] = selected_format

        system_prompt = self._generate_system_prompt(selected_format)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": test_item["prompt"]},
        ]

        prompt = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )

        completion = await self.server.completion(
            prompt=prompt,
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.1,
            split="eval",
        )

        model_response = completion.choices[0].text
        # NOTE(review): called with 2 args here while score() passes a third
        # (dataset_item) — confirm _extract_answer_content's third parameter
        # has a default value.
        extracted_content = self._extract_answer_content(
            model_response, selected_format
        )

        return 1.0 if extracted_content is not None else 0.0

    async def evaluate(self, *args, **kwargs):
        """Run evaluation on the test set (capped at 50 items for speed)."""
        if self.debug_logging:
            self.logger.info("Starting evaluation")

        if not self.test_items:
            self.eval_metrics.append(("eval/percent_correct", 0.0))
            return

        # Use subset for faster evaluation
        items_to_eval = self.test_items[: min(len(self.test_items), 50)]

        eval_results = await tqdm_asyncio.gather(
            *[self.rollout_and_score_eval(item) for item in items_to_eval]
        )

        if eval_results:
            avg_score = sum(eval_results) / len(eval_results)
        else:
            avg_score = 0.0

        self.eval_metrics.append(("eval/percent_correct", avg_score))

        if self.debug_logging:
            self.logger.info(f"Evaluation complete: avg_score={avg_score:.3f}")

    async def add_rollouts_for_wandb(
        self,
        scored_data: Optional[ScoredDataGroup],
        item: Item = None,
    ):
        """Add rollouts to wandb logging.

        Buffers up to ``num_rollouts_per_group_for_logging`` decoded rollouts
        per group; the buffer is bounded by ``num_rollouts_to_keep`` (FIFO).
        """
        if scored_data is None or not scored_data["tokens"]:
            return

        dataset_item = item[1] if item and len(item) > 1 else {}
        selected_format = dataset_item.get("selected_format", AnswerFormat.JSON)

        num_keep = self.config.num_rollouts_per_group_for_logging
        if num_keep == -1:
            num_keep = len(scored_data["tokens"])
        else:
            num_keep = min(num_keep, len(scored_data["tokens"]))

        rollout_batch = []
        for i in range(num_keep):
            # Get the conversation text by decoding tokens (like reasoning gym does)
            display_convo_text = self.tokenizer.decode(
                scored_data["tokens"][i], skip_special_tokens=True
            )

            rollout_batch.append(
                (
                    display_convo_text,
                    scored_data["scores"][i],
                    selected_format.value,
                    dataset_item.get("metadata", {}).get("dataset_name", "unknown"),
                )
            )

        if rollout_batch:
            self.rollouts_for_wandb.append(rollout_batch)

        if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
            self.rollouts_for_wandb.pop(0)

    async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
        """Create wandb table for rollout visualization.

        Consumes (and clears) the buffered rollouts; the table lands under
        the ``train/rollouts`` key.
        """
        if self.rollouts_for_wandb:
            table = wandb.Table(
                columns=[
                    "full_conversation",
                    "score",
                    "selected_format",
                    "dataset_name",
                ]
            )
            for group in self.rollouts_for_wandb:
                for entry in group:
                    table.add_data(*entry)
            wandb_metrics["train/rollouts"] = table

        self.rollouts_for_wandb = []
        return wandb_metrics

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log metrics to wandb with clear, intuitive categories."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # === CORE PERFORMANCE METRICS ===
        if self.percent_correct_buffer:
            percent_correct = sum(self.percent_correct_buffer) / len(
                self.percent_correct_buffer
            )
            wandb_metrics["performance/accuracy"] = percent_correct
        else:
            wandb_metrics["performance/accuracy"] = 0.0

        self.percent_correct_buffer = list()

        # Add eval metrics
        for key, value in self.eval_metrics:
            # Rename eval metrics to be clearer
            clean_key = key.replace("eval/percent_correct", "performance/eval_accuracy")
            wandb_metrics[clean_key] = value
        self.eval_metrics = list()

        # === GROUP STATISTICS ===
        if self.group_statistics["total_groups"] > 0:
            total = self.group_statistics["total_groups"]
            successful = self.group_statistics["successful_groups"]
            failed = self.group_statistics["failed_groups"]

            wandb_metrics["groups/total_processed"] = total
            wandb_metrics["groups/successful"] = successful
            wandb_metrics["groups/failed"] = failed
            wandb_metrics["groups/success_rate"] = successful / total

            if self.group_statistics["average_scores"]:
                all_scores = self.group_statistics["average_scores"]
                recent_scores = all_scores[
                    -50:
                ]  # Last 50 groups for recent performance

                wandb_metrics["groups/avg_score_overall"] = sum(all_scores) / len(
                    all_scores
                )
                wandb_metrics["groups/avg_score_recent"] = sum(recent_scores) / len(
                    recent_scores
                )
                wandb_metrics["groups/best_score_recent"] = max(recent_scores)
                wandb_metrics["groups/worst_score_recent"] = min(recent_scores)

        # === FORMAT PERFORMANCE SUMMARY ===
        if self.format_total_counts:
            # Calculate overall format statistics
            total_attempts = sum(self.format_total_counts.values())
            total_successes = sum(self.format_success_counts.values())

            wandb_metrics["formats/total_attempts"] = total_attempts
            wandb_metrics["formats/total_successes"] = total_successes
            wandb_metrics["formats/overall_success_rate"] = (
                total_successes / total_attempts if total_attempts > 0 else 0.0
            )

            # Count formats by performance tier
            format_success_rates = {}
            for format_name, total_count in self.format_total_counts.items():
                if total_count >= 5:  # Only consider formats with enough data
                    success_count = self.format_success_counts.get(format_name, 0)
                    success_rate = success_count / total_count
                    format_success_rates[format_name] = success_rate

            if format_success_rates:
                rates = list(format_success_rates.values())
                wandb_metrics["formats/num_formats_tested"] = len(format_success_rates)
                wandb_metrics["formats/avg_success_rate"] = sum(rates) / len(rates)
                wandb_metrics["formats/best_success_rate"] = max(rates)
                wandb_metrics["formats/worst_success_rate"] = min(rates)

                # Performance tiers
                high_performing = sum(1 for rate in rates if rate >= 0.8)
                medium_performing = sum(1 for rate in rates if 0.5 <= rate < 0.8)
                low_performing = sum(1 for rate in rates if rate < 0.5)

                wandb_metrics["formats/high_performing_count"] = high_performing
                wandb_metrics["formats/medium_performing_count"] = medium_performing
                wandb_metrics["formats/low_performing_count"] = low_performing

        # === BALANCED TRAINING PROGRESS (only if enabled) ===
        if self.ensure_equivalent_ratios and self.format_successful_groups:
            successful_counts = list(self.format_successful_groups.values())
            threshold = self.format_group_threshold

            # Overall progress
            formats_at_threshold = sum(
                1 for count in successful_counts if count >= threshold
            )
            total_formats = len(self.base_supported_formats)
            completion_pct = (formats_at_threshold / total_formats) * 100

            wandb_metrics["balance/completion_percentage"] = completion_pct
            wandb_metrics["balance/formats_completed"] = formats_at_threshold
            wandb_metrics["balance/formats_total"] = total_formats

            # Progress distribution
            wandb_metrics["balance/min_successful_groups"] = min(successful_counts)
            wandb_metrics["balance/max_successful_groups"] = max(successful_counts)
            wandb_metrics["balance/avg_successful_groups"] = sum(
                successful_counts
            ) / len(successful_counts)

            # How many formats are at different progress levels
            quarter_threshold = threshold * 0.25
            half_threshold = threshold * 0.5
            three_quarter_threshold = threshold * 0.75

            at_quarter = sum(
                1 for count in successful_counts if count >= quarter_threshold
            )
            at_half = sum(1 for count in successful_counts if count >= half_threshold)
            at_three_quarter = sum(
                1 for count in successful_counts if count >= three_quarter_threshold
            )

            wandb_metrics["balance/formats_25pct_progress"] = at_quarter
            wandb_metrics["balance/formats_50pct_progress"] = at_half
            wandb_metrics["balance/formats_75pct_progress"] = at_three_quarter

        # === DATASET DISTRIBUTION ===
        if self.group_statistics["format_distribution"]:
            total_uses = sum(self.group_statistics["format_distribution"].values())

            # Just track the most and least used formats for high-level monitoring
            sorted_formats = sorted(
                self.group_statistics["format_distribution"].items(),
                key=lambda x: x[1],
                reverse=True,
            )

            if sorted_formats:
                most_used_format, most_used_count = sorted_formats[0]
                least_used_format, least_used_count = sorted_formats[-1]

                wandb_metrics["distribution/most_used_format_pct"] = (
                    most_used_count / total_uses
                ) * 100
                wandb_metrics["distribution/least_used_format_pct"] = (
                    least_used_count / total_uses
                ) * 100
                wandb_metrics["distribution/usage_ratio"] = most_used_count / max(
                    least_used_count, 1
                )

        # Create rollout table for detailed inspection
        wandb_metrics = await self.create_rollout_table(wandb_metrics)

        await super().wandb_log(wandb_metrics)

    async def close(self):
        """Clean up and save any remaining rollouts."""
        if self.debug_logging:
            self.logger.info("Closing AnswerFormatEnv")

        # Log final statistics
        if self.group_statistics["total_groups"] > 0:
            success_rate = (
                self.group_statistics["successful_groups"]
                / self.group_statistics["total_groups"]
            )
            avg_score = sum(self.group_statistics["average_scores"]) / len(
                self.group_statistics["average_scores"]
            )
            self.logger.info("Final Statistics:")
            self.logger.info(
                f" Total groups processed: {self.group_statistics['total_groups']}"
            )
            self.logger.info(
                f" Successful groups: {self.group_statistics['successful_groups']}"
            )
            self.logger.info(
                f" Failed groups: {self.group_statistics['failed_groups']}"
            )
            self.logger.info(f" Overall success rate: {success_rate:.4f}")
            self.logger.info(f" Overall average score: {avg_score:.4f}")

        # Save any remaining rollouts
        if self.config.dump_rollouts and self.rollouts_to_save_buffer:
            await self._save_rollouts_to_jsonl()

        # Save any remaining failed rollouts
        if self.config.dump_failed_rollouts and self.failed_rollouts_to_save_buffer:
            if self.debug_logging:
                self.logger.info(
                    f"Found {len(self.failed_rollouts_to_save_buffer)} failed rollouts in buffer. Saving now."
                )
            await self._save_failed_rollouts_to_jsonl()
        elif self.debug_logging:
            self.logger.info("No failed rollouts in buffer to save upon closing.")

        # Call superclass close if it exists
        if hasattr(super(), "close"):
            await super().close()


if __name__ == "__main__":
    AnswerFormatEnv.cli()