# Source: atropos/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py
# (2123 lines, 91 KiB, Python)

"""
Pydantic Schema Following Environment
This environment trains models to generate JSON that adheres to Pydantic schemas.
It loads schemas dynamically from a HuggingFace dataset and validates model outputs.
Recent improvements (2025-06-01):
1. Fixed Pydantic ValidationError compatibility issues with proper exception handling
2. Enhanced JSON extraction with fallback methods for responses missing proper tags
3. Added comprehensive input validation for tokenization to prevent 'list' object errors
4. Improved system prompt with explicit formatting requirements and examples
5. Reduced max_token_length to prevent overly verbose responses
6. Added robust error handling throughout the scoring pipeline
7. Enhanced debug logging for better troubleshooting
8. MAJOR: Added strict thinking tag validation (similar to MCQA environment)
9. Enforces exactly one <think></think> section followed by <json_output></json_output>
10. Added detailed validation method with comprehensive error reporting
Key Features:
- Dynamic Pydantic model creation from dataset configurations
- Comprehensive data dumping for analysis
- Strict thinking tag validation for consistent response format
- Fallback JSON extraction for improved success rates
- Detailed debug logging and error handling
- Robust validation with detailed error messages
Response Format Requirements:
- Must use exactly ONE <think> opening tag and ONE </think> closing tag
- All reasoning must be inside the thinking tags
- JSON output must be in <json_output></json_output> tags after </think>
- No additional <think> tags allowed after the first </think> closing tag
"""
import asyncio
import json
import logging
import os
import random
import re
import uuid
from datetime import date, datetime, timedelta
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID
import toml # Added import
import wandb
import xmltodict # Added
import yaml # Added import
from datasets import load_dataset
from pydantic import (
BaseModel,
ConfigDict,
EmailStr,
Field,
HttpUrl,
ValidationError,
field_validator,
model_validator,
)
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
Item,
ScoredDataGroup,
)
# Import editing functionality
try:
from .error_introduction import (
ErrorIntroductionConfig,
introduce_error_for_pydantic,
)
except ImportError:
# Handle case when running as script or tests
from error_introduction import ErrorIntroductionConfig, introduce_error_for_pydantic
class StructuredOutputFormat(Enum):
    """Serialization formats the model can be asked to emit.

    The enum value doubles as the tag/code-fence language name used when
    building prompts and extracting responses (e.g. ``<json_output>``).
    """

    JSON = "json"
    YAML = "yaml"
    TOML = "toml"
    XML = "xml"
class OutputContainerFormat(Enum):
    """How the structured payload is wrapped in the model's response."""

    TAGGED = "tagged"  # wrapped in <fmt_output>...</fmt_output> tags
    NONE = "none"  # emitted bare, directly after </think>
    MARKDOWN = "markdown"  # wrapped in a ``` code fence
class PydanticEnvConfig(BaseEnvConfig):
    """Custom config class for PydanticSchemaFollowingEnv with additional parameters.

    Extends BaseEnvConfig with dataset selection, debug/dump toggles,
    format-randomization pools, eval split sizing, and error-introduction
    settings for editing tasks.
    """

    # --- Dataset selection ---
    dataset_name: str = Field(
        default="justus27/pydantic-adherance-test",
        description="Name of the HuggingFace dataset to load",
    )
    dataset_split: str = Field(
        default="train", description="Dataset split to use (train, test, validation)"
    )
    # --- Diagnostics / data capture ---
    debug_logging: bool = Field(
        default=True, description="Enable detailed debug logging"
    )
    dump_rollouts: bool = Field(
        default=False,
        description="Whether to dump rollouts to JSONL files for analysis",
    )
    include_messages: bool = Field(
        default=True,
        description="Whether to include messages in the dataset for SFT data generation",
    )
    # --- Per-item format randomization pools (None/empty => all supported) ---
    allowed_structured_formats: Optional[List[StructuredOutputFormat]] = Field(
        default=None,
        description="Optional list of StructuredOutputFormat enums to use for randomization. If None or empty, all supported formats are used.",  # noqa: E501
    )
    allowed_container_formats: Optional[List[OutputContainerFormat]] = Field(
        default=None,
        description="Optional list of OutputContainerFormat enums to use for randomization. If None or empty, all supported formats are used.",  # noqa: E501
    )
    eval_set_percentage: float = Field(
        default=0.1,
        description="Percentage of the dataset to use for the evaluation set (e.g., 0.1 for 10%).",
        ge=0.0,
        le=1.0,
    )
    # Field for task type
    task_type: str = Field(
        default="generation", description="Task type: 'generation' or 'editing'"
    )
    # Error introduction configuration
    error_introduction_seed: Optional[int] = Field(
        default=None, description="Seed for deterministic error introduction"
    )
    error_types_enabled: List[str] = Field(
        default=[
            "type_error",
            "format_error",
            "constraint_error",
            "enum_error",
            "required_field_missing",
        ],
        description="Types of errors to introduce in editing tasks",
    )
    error_introduction_probability: float = Field(
        default=1.0,
        description="Probability of introducing an error (for editing tasks)",
    )
    max_errors_per_item: int = Field(
        default=1, description="Maximum number of errors to introduce per editing task"
    )
class PydanticSchemaFollowingEnv(BaseEnv):
env_config_cls = PydanticEnvConfig
def __init__(
    self,
    config: PydanticEnvConfig,
    server_configs: List[APIServerConfig],
    slurm=True,
    testing=False,
):
    """Initialize the environment.

    Order matters here: debug logging is configured first because every
    subsequent setup step logs through ``self.logger``.

    Args:
        config: Environment configuration (dataset, format pools, dumping).
        server_configs: Inference server endpoints, forwarded to BaseEnv.
        slurm: Forwarded to BaseEnv (cluster scheduling mode).
        testing: Forwarded to BaseEnv (test mode).
    """
    super().__init__(config, server_configs, slurm, testing)
    # Set up debug logging FIRST, as it's used by subsequent setup logic
    self.debug_logging = getattr(config, "debug_logging", True)
    if self.debug_logging:
        self.logger = logging.getLogger(f"{self.__class__.__name__}")
        self.logger.setLevel(logging.DEBUG)
        # Only attach a handler once, even if several instances share
        # the same logger name.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.info("Debug logging enabled for PydanticSchemaFollowingEnv")
    else:
        # NullHandler makes every logger call a cheap no-op when
        # debugging is disabled.
        self.logger = logging.getLogger(f"{self.__class__.__name__}")
        self.logger.addHandler(logging.NullHandler())
    # Rolling buffers used for metric aggregation / wandb reporting.
    self.percent_correct_buffer = list()
    self.eval_metrics = list()
    self.rollouts_for_wandb = []
    self.dataset_items: List[Dict[str, Any]] = []
    # Cache of dynamically exec'd Pydantic model classes, keyed by
    # "model_name_<hash(config)>" (see _create_pydantic_model_from_code).
    self.model_cache: Dict[str, Type[BaseModel]] = {}
    # Determine supported formats based on config or defaults
    if (
        config.allowed_structured_formats
        and len(config.allowed_structured_formats) > 0
    ):
        self.supported_structured_formats = config.allowed_structured_formats
        if self.debug_logging:
            self.logger.info(
                f"Using configured structured formats: {[f.value for f in self.supported_structured_formats]}"
            )
    else:
        self.supported_structured_formats: List[StructuredOutputFormat] = [
            StructuredOutputFormat.JSON,
            StructuredOutputFormat.YAML,
            StructuredOutputFormat.TOML,
            StructuredOutputFormat.XML,
        ]
        if self.debug_logging:
            self.logger.info(
                f"Using default structured formats: {[f.value for f in self.supported_structured_formats]}"
            )
    if (
        config.allowed_container_formats
        and len(config.allowed_container_formats) > 0
    ):
        self.supported_container_formats = config.allowed_container_formats
        if self.debug_logging:
            self.logger.info(
                f"Using configured container formats: {[f.value for f in self.supported_container_formats]}"
            )
    else:
        self.supported_container_formats: List[OutputContainerFormat] = [
            OutputContainerFormat.TAGGED,
            OutputContainerFormat.NONE,
            OutputContainerFormat.MARKDOWN,
        ]
        if self.debug_logging:
            self.logger.info(
                f"Using default container formats: {[f.value for f in self.supported_container_formats]}"
            )
    # Data dumping setup
    self.run_uuid = str(uuid.uuid4())
    # Buffer for saving rollouts - each item group contains rollouts for one dataset item
    # RolloutDetail: conversation, score, expected_json, model_name, problem_id
    RolloutDetail = Dict[str, Union[List[Dict[str, str]], float, str]]
    ItemGroup = Dict[str, Union[str, List[RolloutDetail]]]
    self.rollouts_to_save_buffer: List[ItemGroup] = []
    self.processed_item_count = 0
    # Create datadumps directory relative to this file
    self.datadumps_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "datadumps"
    )
    self.save_file_batch_num = 0
    if self.debug_logging:
        self.logger.info(
            f"Data dumping {'enabled' if config.dump_rollouts else 'disabled'}"
        )
        if config.dump_rollouts:
            self.logger.info(f"Rollouts will be saved to: {self.datadumps_dir}")
def _generate_system_prompt(
    self,
    structured_format: StructuredOutputFormat,
    container_format: OutputContainerFormat,
) -> str:
    """Generate a system prompt tailored to the selected output and container formats.

    BUGFIX: the prompt was previously assembled with literal two-character
    backslash-n sequences (``"\\\\n".join(...)`` and ``\\\\n`` inside every
    example snippet) instead of real newlines, so the model received a
    single-line prompt littered with "\\n" text. All separators and example
    snippets now use actual newline characters.

    Args:
        structured_format: Serialization format the model must emit.
        container_format: How the payload must be wrapped in the response.

    Returns:
        The complete multi-line system prompt string.
    """
    prompt_lines = [
        f"You are an AI assistant that generates structured data in {structured_format.value.upper()} format according to Pydantic schemas.",  # noqa: E501
        "You may use extremely long chains of thought to deeply consider the problem and deliberate "
        "with yourself via systematic reasoning processes to help come to a correct solution prior to answering.",
        "You should enclose your thoughts and internal monologue inside <think> </think> tags.",
    ]
    # Minimal format-specific example embedded in the prompt.
    example_output = ""
    if structured_format == StructuredOutputFormat.JSON:
        example_output = '{\n  "field1": "value1",\n  "field2": "value2"\n}'
    elif structured_format == StructuredOutputFormat.YAML:
        example_output = "field1: value1\nfield2: value2"
    elif structured_format == StructuredOutputFormat.TOML:
        example_output = 'field1 = "value1"\nfield2 = "value2"'
    elif structured_format == StructuredOutputFormat.XML:
        # Assuming root tag is model name
        example_output = "<YourModelName>\n  <field1>value1</field1>\n  <field2>value2</field2>\n</YourModelName>"  # noqa: E501
    if container_format == OutputContainerFormat.TAGGED:
        tag_name = f"{structured_format.value}_output"
        prompt_lines.extend(
            [
                f"CRITICAL: Your final {structured_format.value.upper()} output MUST be enclosed within <{tag_name}> </{tag_name}> tags.",  # noqa: E501
                f"The {structured_format.value.upper()} must be valid and complete. Do not include any text after the closing </{tag_name}> tag.",  # noqa: E501
                "Example format:",
                "<think>\nMy reasoning here...\n</think>\n",
                f"<{tag_name}>\n{example_output}\n</{tag_name}>",
            ]
        )
    elif container_format == OutputContainerFormat.MARKDOWN:
        prompt_lines.extend(
            [
                f"CRITICAL: Your final {structured_format.value.upper()} output MUST be enclosed within a markdown code block (```).",  # noqa: E501
                f"The {structured_format.value.upper()} must be valid and complete.",
                "Example format:",
                "<think>\nMy reasoning here...\n</think>\n",
                f"```{structured_format.value}\n{example_output}\n```",
            ]
        )
    elif container_format == OutputContainerFormat.NONE:
        prompt_lines.extend(
            [
                f"CRITICAL: Your final {structured_format.value.upper()} output should be provided directly after the closing </think> tag, with no surrounding tags or markdown.",  # noqa: E501
                f"The {structured_format.value.upper()} must be valid and complete.",
                "Example format:",
                "<think>\nMy reasoning here...\n</think>\n",
                example_output,
            ]
        )
    prompt_lines.extend(
        [
            f"Ensure the generated {structured_format.value.upper()} strictly adheres to the Pydantic model schema and any specific field "  # noqa: E501
            "requirements provided in the user prompt. Generate all required fields for the model, and "  # noqa: E501
            "include optional fields if they make sense in the context or are specified."  # noqa: E501
        ]
    )
    return "\n".join(prompt_lines)
@classmethod
def config_init(cls) -> Tuple[PydanticEnvConfig, List[APIServerConfig]]:
    """Initialize configuration for the environment.

    Returns:
        A ``(env_config, server_configs)`` pair: default training
        hyperparameters plus a single local inference endpoint.
    """
    env_config = PydanticEnvConfig(
        tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        group_size=16,
        use_wandb=True,
        rollout_server_url="http://localhost:8000",
        total_steps=250,
        batch_size=1024,
        steps_per_eval=20,
        max_num_workers=16,
        max_token_length=1024 * 12,
        inference_weight=1.0,
        wandb_name="pydantic_schema_following",
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        eval_limit_ratio=0.1,
        dataset_name="justus27/pydantic-adherance-test",
        dataset_split="train",
        debug_logging=False,
        dump_rollouts=False,
        # NOTE: XML is deliberately absent from the default training pool
        # here, even though the env supports it.
        allowed_structured_formats=[
            StructuredOutputFormat.JSON,
            StructuredOutputFormat.YAML,
            StructuredOutputFormat.TOML,
        ],
        allowed_container_formats=[
            OutputContainerFormat.TAGGED,
            OutputContainerFormat.NONE,
            OutputContainerFormat.MARKDOWN,
        ],
        eval_set_percentage=0.005,
    )
    server_configs = [
        APIServerConfig(
            model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            base_url="http://localhost:9004/v1",
            api_key="x",
            num_max_requests_at_once=32,
            num_requests_for_eval=256,
        ),
    ]
    return env_config, server_configs
def _create_pydantic_model_from_code(
    self, pydantic_config: str, model_name: str
) -> Type[BaseModel]:
    """
    Dynamically create a Pydantic model from the provided configuration code.
    This executes the pydantic_config string and extracts the target model.

    SECURITY NOTE(review): ``exec`` runs arbitrary code taken from the
    dataset's ``pydantic_config`` field. This assumes the dataset is
    trusted; do not point this environment at untrusted data.

    Args:
        pydantic_config: Python source defining the model class(es).
        model_name: Name of the class to extract from the executed code.

    Returns:
        The resolved Pydantic model class (cached per config hash).

    Raises:
        ValueError: If ``model_name`` is not defined by the executed code.
        Exception: Re-raises any error from executing the config code.
    """
    if self.debug_logging:
        self.logger.debug(f"Creating Pydantic model '{model_name}' from config")
        self.logger.debug(f"Config length: {len(pydantic_config)} characters")
    # Check cache first. Keying on hash(pydantic_config) keeps differing
    # configs with the same model name from colliding within this process.
    cache_key = f"{model_name}_{hash(pydantic_config)}"
    if cache_key in self.model_cache:
        if self.debug_logging:
            self.logger.debug(f"Model '{model_name}' found in cache")
        return self.model_cache[cache_key]
    if self.debug_logging:
        self.logger.debug(
            f"Model '{model_name}' not in cache, creating new instance"
        )
    # Create a namespace for executing the pydantic config.
    # Pre-populates the names that dataset-provided model code commonly
    # references so it can run without its own imports.
    namespace = {
        "BaseModel": BaseModel,
        "model_validator": model_validator,
        "ConfigDict": ConfigDict,
        "ValidationError": ValidationError,
        "HttpUrl": HttpUrl,
        "EmailStr": EmailStr,
        "Field": Field,
        "field_validator": field_validator,
        "List": List,
        "Dict": Dict,
        "Optional": Optional,
        "Union": Union,
        "Any": Any,
        "Literal": getattr(__import__("typing"), "Literal", None),
        "datetime": datetime,
        "date": date,
        # `datetime` above is the class (from-import), so fetch the
        # `time` class off the module explicitly.
        "time": getattr(__import__("datetime"), "time"),
        "timedelta": timedelta,
        "Enum": Enum,
        "Decimal": Decimal,
        "UUID": UUID,
        # Add common imports that might be needed
        "typing": __import__("typing"),
        "json": json,
        "re": re,
    }
    try:
        # Execute the pydantic configuration code
        if self.debug_logging:
            self.logger.debug(f"Executing pydantic config for model '{model_name}'")
        exec(pydantic_config, namespace)
        # Extract the target model class
        if model_name in namespace:
            model_class = namespace[model_name]
            self.model_cache[cache_key] = model_class
            if self.debug_logging:
                self.logger.debug(
                    f"Successfully created and cached model '{model_name}'"
                )
                self.logger.debug(
                    f"Model fields: {list(model_class.model_fields.keys())}"
                )
            return model_class
        else:
            error_msg = (
                f"Model '{model_name}' not found in the executed pydantic config"
            )
            if self.debug_logging:
                self.logger.error(error_msg)
                self.logger.debug(
                    f"Available classes in namespace: {[k for k in namespace.keys() if isinstance(namespace[k], type)]}"  # noqa: E501
                )
            raise ValueError(error_msg)
    except Exception as e:
        error_msg = f"Error creating Pydantic model '{model_name}': {e}"
        if self.debug_logging:
            self.logger.error(error_msg)
            self.logger.debug(f"Pydantic config that failed:\n{pydantic_config}")
        print(error_msg)
        raise
async def setup(self):
    """Load the dataset and prepare tasks.

    Downloads the configured HuggingFace dataset, shuffles it, and splits
    it into ``self.train_items`` / ``self.test_items`` according to
    ``eval_set_percentage``. On any failure the environment deliberately
    falls back to empty item lists instead of raising, so construction
    succeeds and the error surfaces later via get_next_item.
    """
    if self.debug_logging:
        self.logger.info("Starting environment setup")
    try:
        # Load the dataset - you'll need to specify the correct dataset name/path
        dataset_name = getattr(
            self.config, "dataset_name", "justus27/pydantic-adherance-test"
        )
        dataset_split = getattr(self.config, "dataset_split", "train")
        if self.debug_logging:
            self.logger.info(
                f"Loading dataset: {dataset_name}, split: {dataset_split}"
            )
        # Load your dataset
        dataset = load_dataset(dataset_name, split=dataset_split)
        if self.debug_logging:
            self.logger.info(
                f"Dataset loaded successfully. Total items: {len(dataset)}"
            )
        # Convert to list for easier handling
        self.dataset_items = list(dataset)
        if self.debug_logging:
            self.logger.debug(
                f"Sample dataset item keys: {list(self.dataset_items[0].keys()) if self.dataset_items else 'No items'}"  # noqa: E501
            )
            if self.dataset_items:
                sample_item = self.dataset_items[0]
                self.logger.debug(
                    f"Sample problem_id: {sample_item.get('problem_id', 'N/A')}"
                )
                self.logger.debug(
                    f"Sample task_type: {sample_item.get('task_type', 'N/A')}"
                )
                self.logger.debug(
                    f"Sample prompt length: {len(sample_item.get('prompt', ''))}"
                )
        # Shuffle the dataset
        random.shuffle(self.dataset_items)
        if self.debug_logging:
            self.logger.debug("Dataset shuffled")
        # Split into train and test
        split_idx = int(
            len(self.dataset_items) * (1.0 - self.config.eval_set_percentage)
        )
        self.train_items = self.dataset_items[:split_idx]
        self.test_items = self.dataset_items[split_idx:]
        self.iter = 0
        if self.debug_logging:
            self.logger.info(
                f"Dataset split complete: {len(self.train_items)} training items, {len(self.test_items)} test items"
            )
            self.logger.info("Environment setup complete")
        print(
            f"PydanticSchemaFollowingEnv setup complete. {len(self.train_items)} training items, {len(self.test_items)} test items."  # noqa: E501
        )
    except Exception as e:
        error_msg = f"Error during setup: {e}"
        if self.debug_logging:
            self.logger.error(error_msg)
        print(error_msg)
        # Fallback to empty lists if dataset loading fails
        self.train_items = []
        self.test_items = []
        self.iter = 0
async def get_next_item(self) -> Tuple[Tuple[frozenset, ...], Dict[str, Any]]:
    """Get the next training item from the dataset (cyclic order).

    Delegates to the editing or generation item builder depending on the
    item's ``task_type``.

    Returns:
        A ``(prompt_messages, dataset_item)`` tuple from
        ``_create_editing_item`` or ``_create_generation_item``.

    Raises:
        ValueError: If no training items are available (setup failed or
            the dataset is empty).
    """
    if not self.train_items:
        error_msg = "No training items available. Setup might have failed or dataset is empty."
        if self.debug_logging:
            self.logger.error(error_msg)
        raise ValueError(error_msg)
    # Get the next item cyclically
    dataset_item = self.train_items[self.iter % len(self.train_items)]
    if self.debug_logging:
        self.logger.debug(
            f"Getting item {self.iter % len(self.train_items)} (iteration {self.iter})"
        )
        self.logger.debug(
            f"Item problem_id: {dataset_item.get('problem_id', 'N/A')}"
        )
        self.logger.debug(f"Item task_type: {dataset_item.get('task_type', 'N/A')}")
        # The verification-info parse below exists purely for logging, so it
        # is gated on debug_logging for consistency with the rest of the
        # class (and to skip a wasted json.loads per item when debug output
        # is disabled).
        try:
            verification_info = dataset_item.get("verification_info", "{}")
            verification_data = json.loads(verification_info)
            model_name = verification_data.get("model_name", "Unknown")
            self.logger.debug(f"Target model: {model_name}")
        except (json.JSONDecodeError, KeyError):
            self.logger.debug("Could not parse model name from verification_info")
    self.iter += 1
    # Check if this is an editing or generation task
    task_type = dataset_item.get("task_type", "generation")
    if task_type == "editing":
        return self._create_editing_item(dataset_item)
    else:
        return self._create_generation_item(dataset_item)
def _extract_structured_data_response(
    self,
    text: str,
    container_format: OutputContainerFormat,
    structured_format: StructuredOutputFormat,
) -> Optional[str]:
    """Extracts structured data content based on container and structured format, with strict thinking tag validation.

    Pipeline:
      1. Require exactly one ``<think>`` and one ``</think>`` tag, with no
         additional ``<think>`` before or after the block.
      2. Extract the payload from the text *after* ``</think>`` according to
         ``container_format`` (tagged / markdown fence / bare).
      3. Strip and return the payload, or None on any violation.

    Args:
        text: Full model response.
        container_format: Expected wrapping of the payload.
        structured_format: Expected serialization format (drives the tag
            name / code-fence language).

    Returns:
        The extracted payload string, or None if validation/extraction fails.
    """  # noqa: E501
    if self.debug_logging:
        self.logger.debug(
            f"Extracting {structured_format.value} from response (length: {len(text)}), container: {container_format.value}"  # noqa: E501
        )
    # Ensure text is a string
    if not isinstance(text, str):
        if self.debug_logging:
            self.logger.warning(
                f"Expected string but got {type(text)}, converting to string"
            )
        text = str(text)
    # 1. Validate thinking tags: exactly one open and one close tag.
    think_tags = re.findall(r"<think>", text, re.IGNORECASE)
    think_close_tags = re.findall(r"</think>", text, re.IGNORECASE)
    if len(think_tags) != 1 or len(think_close_tags) != 1:
        if self.debug_logging:
            self.logger.warning(
                f"Invalid thinking tag structure: {len(think_tags)} open tags, {len(think_close_tags)} close tags. Full text: {text[:500]}"  # noqa: E501
            )
        return None
    # Split the text into thinking and response sections
    # Use a regex that captures the content before, between, and after think tags robustly
    match_think_block = re.match(
        r"(.*?)(<think>.*?</think>)(.*)", text, re.DOTALL | re.IGNORECASE
    )
    if not match_think_block:
        if self.debug_logging:
            self.logger.warning(
                f"Could not find a complete <think>...</think> block. Full text: {text[:500]}"
            )
        return None
    # thinking_section_plus_prefix = match_think_block.group(1) # Content before <think>
    # thinking_block_content = match_think_block.group(2) # <think>...</think>
    response_section = match_think_block.group(3)  # Content after </think>
    # Validate that the first <think> tag is indeed the one we captured (no <think> before it)
    # NOTE(review): with the non-greedy prefix group this check appears
    # unreachable, but it is kept as defense in depth.
    if "<think>" in match_think_block.group(1).lower():
        if self.debug_logging:
            self.logger.warning(
                f"Nested or malformed <think> tags detected before main block. Full text: {text[:500]}"
            )
        return None
    # Check if there are any thinking tags in the response section (after </think>)
    if "<think>" in response_section.lower():
        if self.debug_logging:
            self.logger.warning(
                f"Found <think> tags in response section after </think>. Full text: {text[:500]}"
            )
        return None
    extracted_content: Optional[str] = None
    # 2. Extract based on container_format from response_section
    if container_format == OutputContainerFormat.TAGGED:
        tag_name = f"{structured_format.value}_output"
        # Loosen regex to allow for attributes in the opening tag if any, and handle varying whitespace
        pattern = rf"<{tag_name}[^>]*>\s*(.*?)\s*</{tag_name}>"
        match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE)
        if match:
            extracted_content = match.group(1)
            if self.debug_logging:
                self.logger.debug(
                    f"Extracted using TAGGED ({tag_name}): {len(extracted_content)} chars"
                )
        else:
            if self.debug_logging:
                self.logger.warning(
                    f"TAGGED format: Could not find <{tag_name}>...</{tag_name}> tags in response section. Response section: {response_section[:300]}"  # noqa: E501
                )
    elif container_format == OutputContainerFormat.MARKDOWN:
        # Pattern to match ```language ... ``` or just ``` ... ```
        # It captures the content within the backticks.
        # It optionally matches the language specifier.
        pattern = rf"^\s*```(?:{re.escape(structured_format.value)})?\s*\n(.*?)\n\s*```\s*$"
        match = re.search(
            pattern, response_section.strip(), re.DOTALL | re.IGNORECASE
        )
        if match:
            extracted_content = match.group(1)
            if self.debug_logging:
                self.logger.debug(
                    f"Extracted using MARKDOWN: {len(extracted_content)} chars"
                )
        else:
            # Fallback for markdown: if ```<format> content ``` is not found, try ``` content ```
            pattern_no_lang = r"^\s*```\s*\n(.*?)\n\s*```\s*$"
            match_no_lang = re.search(
                pattern_no_lang, response_section.strip(), re.DOTALL | re.IGNORECASE
            )
            if match_no_lang:
                extracted_content = match_no_lang.group(1)
                if self.debug_logging:
                    self.logger.debug(
                        f"Extracted using MARKDOWN (no lang specified): {len(extracted_content)} chars"
                    )
            else:
                if self.debug_logging:
                    self.logger.warning(
                        f"MARKDOWN format: Could not find ```...``` code block in response section. Response section: {response_section[:300]}"  # noqa: E501
                    )
    elif container_format == OutputContainerFormat.NONE:
        # Bare output: everything after </think> is the payload.
        extracted_content = response_section
        if self.debug_logging:
            self.logger.debug(
                f"Extracted using NONE: {len(extracted_content)} chars"
            )
    # 3. Post-processing
    if extracted_content is not None:
        extracted_content = extracted_content.strip()
        if not extracted_content:  # Empty after stripping
            if self.debug_logging:
                self.logger.warning("Extracted content is empty after stripping.")
            return None
        if self.debug_logging:
            self.logger.debug(
                f"Successfully extracted content ({len(extracted_content)} chars). First 100: {extracted_content[:100]}"  # noqa: E501
            )
        return extracted_content
    else:
        if self.debug_logging:
            self.logger.warning(
                f"Extraction failed for container type {container_format.value}. No content extracted."
            )
        return None
async def score(
    self,
    rollout_group_data: List[Dict[str, Any]],
) -> Optional[ScoredDataGroup]:
    """Score the rollouts based on Pydantic validation.

    Per rollout: extract the structured payload from the model response,
    parse it according to the item's structured format (JSON/YAML/TOML/XML),
    then validate the parsed data against the dynamically created Pydantic
    model. Reward is binary: 1.0 for a Pydantic-valid payload, 0.0 otherwise.

    BUGFIX: XML parse/shape failures previously executed ``return 0.0``,
    aborting scoring of the ENTIRE group and returning a float where an
    ``Optional[ScoredDataGroup]`` is expected (the "Return score directly"
    comments betrayed a copy-paste from an eval-style scorer). XML failures
    now raise inside the per-rollout try block and are counted as parsing
    failures for that rollout only, like the other formats.

    Returns:
        A populated ScoredDataGroup, or None when there is no learning
        signal (no rollouts, un-creatable target model, all-identical
        imperfect scores, or no usable examples).
    """
    if self.debug_logging:
        self.logger.debug(f"Scoring {len(rollout_group_data)} rollouts")
    scores_obj = ScoredDataGroup()
    scores_obj["tokens"] = list()
    scores_obj["masks"] = list()
    scores_obj["scores"] = list()
    scores_obj["messages"] = list()
    scores_obj["inference_logprobs"] = list()
    if not rollout_group_data:
        if self.debug_logging:
            self.logger.warning("No rollout data to score")
        return None
    # All rollouts in a group share the same dataset item and formats.
    dataset_item = rollout_group_data[0]["dataset_item"]
    problem_id = dataset_item.get("problem_id", "N/A")
    selected_structured_format = dataset_item["selected_structured_format"]
    selected_container_format = dataset_item["selected_container_format"]
    if self.debug_logging:
        self.logger.debug(
            f"Scoring for problem_id: {problem_id}, structured_format: {selected_structured_format.value}, container_format: {selected_container_format.value}"  # noqa: E501
        )
    verification_info = dataset_item["verification_info"]
    try:
        verification_data = json.loads(verification_info)
        pydantic_config = verification_data["pydantic_config"]
        model_name = verification_data["model_name"]
        if self.debug_logging:
            self.logger.debug(f"Target Pydantic model for validation: {model_name}")
    except (json.JSONDecodeError, KeyError) as e:
        error_msg = f"Error parsing verification_info for {problem_id}: {e}"
        if self.debug_logging:
            self.logger.error(error_msg)
        print(error_msg)
        return None
    try:
        target_model_cls = self._create_pydantic_model_from_code(
            pydantic_config, model_name
        )
    except Exception as e:
        error_msg = (
            f"Error creating Pydantic model {model_name} for {problem_id}: {e}"
        )
        if self.debug_logging:
            self.logger.error(error_msg)
        print(error_msg)
        return None
    valid_count = 0
    invalid_count = 0
    extraction_failures = 0
    parsing_failures = 0
    # Shuffle so the group_size cutoff below doesn't bias toward the first
    # completions returned by the server.
    random.shuffle(rollout_group_data)
    for i, rollout_item in enumerate(rollout_group_data):
        item_messages = rollout_item["messages"]
        tokens = rollout_item["tokens"]
        masks = rollout_item["masks"]
        logprobs = rollout_item["logprobs"]
        messages_as_dicts = [dict(fs_message) for fs_message in item_messages]
        model_response_text = messages_as_dicts[-1]["content"]
        extracted_str = self._extract_structured_data_response(
            model_response_text,
            selected_container_format,
            selected_structured_format,
        )
        reward = 0.0
        parsed_data = None
        # NOTE(review): validation_error_msg is set but not consumed in this
        # method; presumably used by rollout-dumping logic elsewhere — verify.
        validation_error_msg = "Extraction failed"
        if extracted_str:
            try:
                if selected_structured_format == StructuredOutputFormat.JSON:
                    parsed_data = json.loads(extracted_str)
                elif selected_structured_format == StructuredOutputFormat.YAML:
                    parsed_data = yaml.safe_load(extracted_str)
                elif selected_structured_format == StructuredOutputFormat.TOML:
                    parsed_data = toml.loads(extracted_str)
                elif selected_structured_format == StructuredOutputFormat.XML:
                    # xmltodict.parse returns {'RootTag': data}; the Pydantic
                    # model corresponds to the content *within* the single
                    # root tag, so unwrap it. Malformed XML (ExpatError) or an
                    # unexpected shape raises and is handled below as a
                    # per-rollout parsing failure (previously these paths
                    # returned 0.0, aborting the whole group).
                    data_dict_outer = xmltodict.parse(extracted_str)
                    if (
                        isinstance(data_dict_outer, dict)
                        and len(data_dict_outer) == 1
                    ):
                        root_key = next(iter(data_dict_outer))
                        parsed_data = data_dict_outer[root_key]
                        if self.debug_logging:
                            self.logger.debug(
                                f"Rollout {i} (problem {problem_id}): XML parsed, root key '{root_key}', data: {str(parsed_data)[:100]}..."  # noqa: E501
                            )
                    else:
                        raise ValueError(
                            f"XML from xmltodict.parse was not a dict with a single root key as expected. Got: {type(data_dict_outer)}, Keys: {list(data_dict_outer.keys()) if isinstance(data_dict_outer, dict) else 'N/A'}"  # noqa: E501
                        )
                # The format is always one of the four handled above, so a
                # non-None parse goes straight to Pydantic validation.
                if parsed_data is not None:
                    is_valid, pydantic_error_msg = (
                        self._validate_parsed_data_against_model(
                            parsed_data, target_model_cls, problem_id
                        )
                    )
                    if is_valid:
                        reward = 1.0
                        valid_count += 1
                        validation_error_msg = None
                        if self.debug_logging and i < 3:
                            self.logger.debug(
                                f"Rollout {i}: Pydantic validation successful for {selected_structured_format.value}"  # noqa: E501
                            )
                    else:
                        reward = 0.0
                        invalid_count += 1
                        validation_error_msg = pydantic_error_msg
                        if self.debug_logging and i < 3:
                            self.logger.debug(
                                f"Rollout {i}: Pydantic validation failed for {selected_structured_format.value}. Error: {pydantic_error_msg}"  # noqa: E501
                            )
            except json.JSONDecodeError as e_json:
                reward = 0.0
                parsing_failures += 1
                validation_error_msg = f"JSON parsing failed: {e_json}"
                if self.debug_logging and i < 3:
                    self.logger.debug(
                        f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
                    )
            except yaml.YAMLError as e_yaml:
                reward = 0.0
                parsing_failures += 1
                validation_error_msg = f"YAML parsing failed: {e_yaml}"
                if self.debug_logging and i < 3:
                    self.logger.debug(
                        f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
                    )
            except toml.TomlDecodeError as e_toml:
                reward = 0.0
                parsing_failures += 1
                validation_error_msg = f"TOML parsing failed: {e_toml}"
                if self.debug_logging and i < 3:
                    self.logger.debug(
                        f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
                    )
            except Exception as e_parse:
                # Catch-all for remaining parse errors, including the XML
                # ExpatError / ValueError raised above.
                reward = 0.0
                parsing_failures += 1
                validation_error_msg = f"Generic parsing failed for {selected_structured_format.value}: {e_parse}"
                if self.debug_logging and i < 3:
                    self.logger.debug(
                        f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
                    )
        else:
            reward = 0.0  # No structured data output found or extraction failed
            extraction_failures += 1
            # validation_error_msg is already "Extraction failed"
            if self.debug_logging and i < 3:
                self.logger.debug(
                    f"Rollout {i}: Extraction failed for {selected_structured_format.value} with container {selected_container_format.value}"  # noqa: E501
                )
        # Remove examples with insufficient context
        if len([1 for m_val in masks if m_val != -100]) < 10:  # Min context length
            if self.debug_logging:
                self.logger.debug(
                    f"Skipping rollout {i} (problem: {problem_id}) due to insufficient context length."  # noqa: E501
                )
            continue
        scores_obj["tokens"].append(tokens)
        scores_obj["masks"].append(masks)
        scores_obj["inference_logprobs"].append(logprobs)
        scores_obj["scores"].append(reward)
        # Store original messages (converted to dicts)
        scores_obj["messages"].append(messages_as_dicts)
        self.percent_correct_buffer.append(1.0 if reward == 1.0 else 0.0)
        if len(scores_obj["tokens"]) >= self.config.group_size:
            break
    if self.debug_logging:
        self.logger.info(
            f"Scoring complete for {problem_id} (Format: {selected_structured_format.value}, Container: {selected_container_format.value}): "  # noqa: E501
            f"{valid_count} valid (Pydantic), {invalid_count} invalid (Pydantic), "
            f"{parsing_failures} parsing failures, {extraction_failures} extraction failures."
        )
        if scores_obj["scores"]:
            avg_score = sum(scores_obj["scores"]) / len(scores_obj["scores"])
            self.logger.info(
                f"Average score for this batch ({problem_id}, Format: {selected_structured_format.value}, Container: {selected_container_format.value}): {avg_score:.3f}"  # noqa: E501
            )
    if not scores_obj["tokens"]:  # No valid examples processed
        if self.debug_logging:
            self.logger.warning(
                f"No valid examples processed in this batch for {problem_id}"
            )
        return None
    # All-identical imperfect scores carry no learning signal.
    if (
        all(scores_obj["scores"][0] == score for score in scores_obj["scores"])
        and scores_obj["scores"][0] != 1.0
    ):
        if self.debug_logging:
            self.logger.debug(
                f"All scores are identical ({scores_obj['scores'][0]}) and not perfect for {problem_id}, returning None for learning signal."  # noqa: E501
            )
        return None
    # Mild length penalty when every rollout is perfect but verbose, to
    # discourage padding out responses toward the token limit.
    if all(s == 1.0 for s in scores_obj["scores"]):
        avg_len = sum(len(t) for t in scores_obj["tokens"]) / len(
            scores_obj["tokens"]
        )
        if avg_len > self.config.max_token_length * 0.75:
            scores_obj["scores"] = [s * 0.9 for s in scores_obj["scores"]]
            if self.debug_logging:
                self.logger.debug(
                    f"Applied length penalty for {problem_id}: avg_len={avg_len}, penalty_threshold={self.config.max_token_length * 0.75}"  # noqa: E501
                )
    return scores_obj
async def collect_trajectories(
    self, item: Item
) -> Tuple[Optional[ScoredDataGroup], List]:
    """Collect trajectories for a given item.

    Generates ``group_size`` completions for the item's prompt, scores them
    with :meth:`score`, optionally buffers the rollouts for periodic JSONL
    dumping, and returns the scored group.

    Args:
        item: A ``(prompt_messages_tuple, dataset_item)`` pair. The prompt is
            a tuple of frozenset-encoded chat messages; ``dataset_item``
            carries the schema / verification metadata used for scoring.

    Returns:
        Tuple of (scored data group or ``None``, empty backlog list).
    """
    prompt_messages_tuple, dataset_item = item
    if self.debug_logging:
        self.logger.debug(
            f"Collecting trajectories for problem_id: {dataset_item.get('problem_id', 'N/A')}"
        )
    # Convert frozensets to dicts for the API call
    messages_for_api = [dict(fs_message) for fs_message in prompt_messages_tuple]
    prompt_str = self.tokenizer.apply_chat_template(
        messages_for_api, add_generation_prompt=True, tokenize=False
    )
    if self.debug_logging:
        self.logger.debug(f"Generated prompt length: {len(prompt_str)} characters")
        self.logger.debug(
            f"Requesting {self.config.group_size} completions with max_tokens={self.config.max_token_length}, temperature=0.9"  # noqa: E501
        )
    # One server session produces all group_size samples; per-sample token /
    # mask / logprob data is read back from the managed server's state nodes.
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions = await managed.completion(
            prompt=prompt_str,
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
            temperature=0.9,
        )
        state = managed.get_state()
        nodes = state["nodes"]
    if self.debug_logging:
        self.logger.debug(
            f"Received {len(completions.choices)} completions from server"
        )
    to_score_list = []
    for i, choice in enumerate(completions.choices):
        if self.debug_logging and i < 3:  # Log first few completions
            self.logger.debug(
                f"Completion {i} length: {len(choice.text)} characters"
            )
        # Create a full message list for this choice; messages stay
        # frozenset-encoded (hashable) until scoring converts them.
        current_trajectory_messages = list(prompt_messages_tuple)
        current_trajectory_messages.append(
            frozenset({"role": "assistant", "content": choice.text}.items())
        )
        # nodes[i] is assumed to align 1:1 with completions.choices[i]
        to_score_list.append(
            {
                "messages": tuple(current_trajectory_messages),
                "dataset_item": dataset_item,
                "tokens": nodes[i].tokens,
                "masks": nodes[i].masked_tokens,
                "logprobs": nodes[i].logprobs,
            }
        )
    scored_data = await self.score(to_score_list)
    if self.debug_logging:
        if scored_data:
            self.logger.debug(
                f"Scoring successful: {len(scored_data['scores'])} scored items"
            )
        else:
            self.logger.warning("Scoring returned None")
    # Log batch progress for data dumping
    current_batch_progress = self.processed_item_count % 100
    log_message_group_processed = (
        f"GROUP_PROC - Item Iter: {self.iter-1}, Scored Data Present: {bool(scored_data)}, "
        f"Dump Rollouts Cfg: {self.config.dump_rollouts}, "
        f"Total Items Processed (for save): {self.processed_item_count}, Batch Counter: {current_batch_progress}/99"
    )
    if self.debug_logging:
        self.logger.info(log_message_group_processed)
    # Data dumping logic
    if self.debug_logging:
        self.logger.info(
            f"COLLECT_TRAJ - dump_rollouts: {self.config.dump_rollouts}, "
            f"processed_item_count: {self.processed_item_count}, "
            f"current_buffer_size: {len(self.rollouts_to_save_buffer)}"
        )
    if scored_data and self.config.dump_rollouts:
        rollouts_for_current_item = []
        num_scored_rollouts = len(scored_data.get("scores", []))
        conversation_messages_batch = scored_data.get("messages", [])
        # Extract model info from dataset item
        try:
            verification_info = dataset_item.get("verification_info", "{}")
            verification_data = json.loads(verification_info)
            model_name = verification_data.get("model_name", "Unknown")
        except (json.JSONDecodeError, KeyError):
            model_name = "Unknown"
        for i in range(num_scored_rollouts):
            conversation_messages = (
                conversation_messages_batch[i]
                if i < len(conversation_messages_batch)
                else []
            )
            score_for_rollout = scored_data["scores"][i]
            # Extract the generated JSON from the assistant's response
            if conversation_messages:
                assistant_response = conversation_messages[-1].get("content", "")
                generated_json = self._extract_structured_data_response(
                    assistant_response,
                    dataset_item["selected_container_format"],
                    dataset_item["selected_structured_format"],
                )
            else:
                generated_json = None
            rollouts_for_current_item.append(
                {
                    "conversation": conversation_messages,
                    "score": score_for_rollout,
                    "expected_json": dataset_item.get("verification_info", ""),
                    "generated_json": generated_json,
                    "model_name": model_name,
                    "problem_id": dataset_item.get("problem_id", "N/A"),
                    "task_type": dataset_item.get("task_type", "N/A"),
                }
            )
        if rollouts_for_current_item:
            # Use problem_id as the source item ID
            source_item_id = dataset_item.get("problem_id", f"item_{self.iter-1}")
            item_data_to_save = {
                "item_id": source_item_id,
                "rollouts": rollouts_for_current_item,
            }
            self.rollouts_to_save_buffer.append(item_data_to_save)
            # Only groups that actually produced rollouts count toward the
            # 100-item save trigger below.
            self.processed_item_count += 1
            if self.debug_logging:
                self.logger.debug(
                    f"Added {len(rollouts_for_current_item)} rollouts for item {source_item_id}"
                )
    # Save batch every 100 processed items
    if (
        self.config.dump_rollouts
        and self.processed_item_count > 0
        and self.processed_item_count % 100 == 0
    ):
        log_msg = (
            f"Reached {self.processed_item_count} processed items. "
            f"Triggering save for {len(self.rollouts_to_save_buffer)} item groups."
        )
        if self.debug_logging:
            self.logger.info(log_msg)
        await self._save_rollouts_to_jsonl()
    return scored_data, []
async def _save_rollouts_to_jsonl(self):
"""Saves the buffered rollouts to a JSONL file in the datadumps directory."""
if not self.rollouts_to_save_buffer:
if self.debug_logging:
self.logger.info("No rollouts in buffer to save.")
return
try:
if not os.path.exists(self.datadumps_dir):
os.makedirs(self.datadumps_dir)
if self.debug_logging:
self.logger.info(f"Created directory: {self.datadumps_dir}")
except OSError as e:
error_msg = f"Error creating directory {self.datadumps_dir}: {e}"
if self.debug_logging:
self.logger.error(error_msg)
print(error_msg)
return
file_path = os.path.join(
self.datadumps_dir,
f"pydantic_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl",
)
try:
with open(file_path, "w") as f:
for rollout_dict in self.rollouts_to_save_buffer:
json.dump(rollout_dict, f)
f.write("\n")
success_msg = f"Successfully saved {len(self.rollouts_to_save_buffer)} rollouts to {file_path}"
if self.debug_logging:
self.logger.info(success_msg)
print(success_msg)
self.rollouts_to_save_buffer.clear()
self.save_file_batch_num += 1
except IOError as e:
error_msg = f"Error writing rollouts to {file_path}: {e}"
if self.debug_logging:
self.logger.error(error_msg)
print(error_msg)
except Exception as e:
error_msg = f"Unexpected error saving rollouts to {file_path}: {e}"
if self.debug_logging:
self.logger.error(error_msg)
print(error_msg)
async def rollout_and_score_eval(self, dataset_item: Dict[str, Any]) -> float:
    """Evaluate a single item from the test set with randomized formats.

    Samples one low-temperature completion for the item's prompt, extracts
    the structured payload, parses it in the randomly chosen format
    (JSON / YAML / TOML / XML), and validates the parsed data against the
    Pydantic model described by the item's ``verification_info``.

    Args:
        dataset_item: Test-set row with at least ``prompt`` and
            ``verification_info`` (a JSON string holding ``pydantic_config``
            and ``model_name``). Mutated in place to record the formats used.

    Returns:
        1.0 if extraction, parsing, and Pydantic validation all succeed,
        otherwise 0.0.
    """
    problem_id = dataset_item.get("problem_id", "N/A")
    # Randomly select formats for evaluation consistency with training
    selected_structured_format = random.choice(self.supported_structured_formats)
    selected_container_format = random.choice(self.supported_container_formats)
    # Store them in dataset_item if needed for logging or other parts, though not strictly for this function's direct logic  # noqa: E501
    dataset_item["selected_structured_format"] = selected_structured_format
    dataset_item["selected_container_format"] = selected_container_format
    if self.debug_logging:
        self.logger.debug(
            f"Evaluating item: {problem_id}, structured: {selected_structured_format.value}, container: {selected_container_format.value}"  # noqa: E501
        )
    user_content = dataset_item["prompt"]
    # System prompt depends on the selected formats, so it must be rebuilt
    # per item rather than cached.
    current_system_prompt = self._generate_system_prompt(
        selected_structured_format, selected_container_format
    )
    messages = [
        {"role": "system", "content": current_system_prompt},
        {"role": "user", "content": user_content},
    ]
    prompt = self.tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    if self.debug_logging:
        self.logger.debug(
            f"Eval prompt length for {problem_id}: {len(prompt)} characters"
        )
    # Low temperature: evaluation wants the model's best single attempt.
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.1,
            split="eval",
        )
        model_response_text = completion.choices[0].text
    if self.debug_logging:
        self.logger.debug(
            f"Eval response length for {problem_id}: {len(model_response_text)} characters"
        )
    extracted_str = self._extract_structured_data_response(
        model_response_text, selected_container_format, selected_structured_format
    )
    score = 0.0
    if not extracted_str:
        if self.debug_logging:
            self.logger.debug(
                f"Eval extraction failed for {problem_id} (Format: {selected_structured_format.value}, Container: {selected_container_format.value})"  # noqa: E501
            )
        return 0.0
    # Attempt to parse and validate
    try:
        verification_info = dataset_item["verification_info"]
        verification_data = json.loads(verification_info)
        pydantic_config = verification_data["pydantic_config"]
        model_name = verification_data["model_name"]
        if self.debug_logging:
            self.logger.debug(
                f"Eval target Pydantic model for {problem_id}: {model_name}"
            )
        target_model_cls = self._create_pydantic_model_from_code(
            pydantic_config, model_name
        )
        parsed_data = None
        if selected_structured_format == StructuredOutputFormat.JSON:
            parsed_data = json.loads(extracted_str)
        elif selected_structured_format == StructuredOutputFormat.YAML:
            parsed_data = yaml.safe_load(extracted_str)
        elif selected_structured_format == StructuredOutputFormat.TOML:
            parsed_data = toml.loads(extracted_str)
        elif selected_structured_format == StructuredOutputFormat.XML:
            try:
                data_dict_outer = xmltodict.parse(extracted_str)
                # Assumption: Pydantic model corresponds to the content *within* the single root XML tag.
                # xmltodict.parse returns a dict like {'RootTag': actual_data_dict}.
                # We extract actual_data_dict for validation. More complex XML might require
                # specific xmltodict process_instructions or more sophisticated unwrapping.
                if isinstance(data_dict_outer, dict) and len(data_dict_outer) == 1:
                    root_key = list(data_dict_outer.keys())[0]
                    parsed_data = data_dict_outer[root_key]
                    if self.debug_logging:
                        self.logger.debug(
                            f"Eval item (problem {problem_id}): XML parsed, root key '{root_key}', data: {str(parsed_data)[:100]}..."  # noqa: E501
                        )
                else:
                    raise ValueError(
                        f"XML from xmltodict.parse was not a dict with a single root key as expected. Got: {type(data_dict_outer)}, Keys: {list(data_dict_outer.keys()) if isinstance(data_dict_outer, dict) else 'N/A'}"  # noqa: E501
                    )
            except xmltodict.expat.ExpatError as e_xml_parse:
                # No reward variable here, score is returned directly
                if self.debug_logging:
                    self.logger.debug(
                        f"Eval item (problem {problem_id}): XML parsing failed (ExpatError): {e_xml_parse}. Extracted XML: {extracted_str[:100]}..."  # noqa: E501
                    )
                return 0.0  # Return score directly
            except Exception as e_xml_generic:
                if self.debug_logging:
                    self.logger.debug(
                        f"Eval item (problem {problem_id}): XML processing error: {e_xml_generic}. Extracted XML: {extracted_str[:100]}..."  # noqa: E501
                    )
                return 0.0  # Return score directly
        if parsed_data is not None and selected_structured_format in [
            StructuredOutputFormat.JSON,
            StructuredOutputFormat.YAML,
            StructuredOutputFormat.TOML,
            StructuredOutputFormat.XML,
        ]:
            is_valid, pydantic_error_msg = self._validate_parsed_data_against_model(
                parsed_data, target_model_cls, problem_id
            )
            if is_valid:
                score = 1.0
                if self.debug_logging:
                    self.logger.debug(
                        f"Eval Pydantic validation successful for {problem_id} (Format: {selected_structured_format.value})"  # noqa: E501
                    )
            else:
                score = 0.0  # Already 0.0 by default unless validation passes
                if self.debug_logging:
                    self.logger.debug(
                        f"Eval Pydantic validation failed for {problem_id} (Format: {selected_structured_format.value}): {pydantic_error_msg}"  # noqa: E501
                    )
    except json.JSONDecodeError as je:
        if self.debug_logging:
            self.logger.debug(
                f"Eval JSON parsing failed for {problem_id} (Format: {selected_structured_format.value}): {je}. Extracted: {extracted_str[:100]}..."  # noqa: E501
            )
        score = 0.0
    except yaml.YAMLError as ye:
        if self.debug_logging:
            self.logger.debug(
                f"Eval YAML parsing failed for {problem_id} (Format: {selected_structured_format.value}): {ye}. Extracted: {extracted_str[:100]}..."  # noqa: E501
            )
        score = 0.0
    except toml.TomlDecodeError as te:
        if self.debug_logging:
            self.logger.debug(
                f"Eval TOML parsing failed for {problem_id} (Format: {selected_structured_format.value}): {te}. Extracted: {extracted_str[:100]}..."  # noqa: E501
            )
        score = 0.0
    except (
        Exception
    ) as e:  # Catch errors in model creation, other parsing, or validation setup
        score = (
            0.0  # Ensure score is 0.0 for any other unexpected error in this block
        )
        if self.debug_logging:
            self.logger.error(
                f"Error during eval scoring pipeline for {problem_id} (Format: {selected_structured_format.value}): {type(e).__name__}: {e}"  # noqa: E501
            )
    return score
async def evaluate(self, *args, **kwargs):
    """Run evaluation on the test set.

    Scores up to 50 items from ``self.test_items`` via
    :meth:`rollout_and_score_eval` and appends ``eval/avg_score`` and
    ``eval/percent_perfect`` tuples to ``self.eval_metrics``.
    """
    if self.debug_logging:
        self.logger.info("Starting evaluation")
    if not self.test_items:
        warning_msg = "No test items available for evaluation."
        if self.debug_logging:
            self.logger.warning(warning_msg)
        print(warning_msg)
        # Bug fix: this early-return path previously emitted a stale
        # "eval/percent_correct" key that the normal path never produces.
        # Emit the same metric keys as a normal run so downstream metric
        # consumers always see a consistent schema.
        self.eval_metrics.append(("eval/avg_score", 0.0))
        self.eval_metrics.append(("eval/percent_perfect", 0.0))
        return
    # Use a subset for faster evaluation
    items_to_eval = self.test_items[: min(len(self.test_items), 50)]
    if self.debug_logging:
        self.logger.info(f"Evaluating {len(items_to_eval)} items from test set")
    eval_results = await tqdm_asyncio.gather(
        *[self.rollout_and_score_eval(item) for item in items_to_eval]
    )
    # Calculate metrics (each result is 1.0 only on a fully valid output)
    perfect_scores = sum(1 for score in eval_results if score == 1.0)
    if eval_results:
        avg_score = sum(eval_results) / len(eval_results)
        percent_perfect = perfect_scores / len(eval_results)
    else:
        avg_score = 0.0
        percent_perfect = 0.0
    self.eval_metrics.append(("eval/avg_score", avg_score))
    self.eval_metrics.append(("eval/percent_perfect", percent_perfect))
    if self.debug_logging:
        self.logger.info(
            f"Evaluation complete: avg_score={avg_score:.3f}, percent_perfect={percent_perfect:.3f}"
        )
        self.logger.info(f"Perfect scores: {perfect_scores}/{len(eval_results)}")
    print(
        f"Evaluation complete. Avg Score: {avg_score:.3f}, Percent Perfect (1.0): {percent_perfect:.3f}"
    )
async def add_rollouts_for_wandb(
    self,
    scored_data: Optional[ScoredDataGroup],
    item: Item = None,
):
    """Add rollouts to wandb logging.

    Extracts up to ``num_rollouts_per_group_for_logging`` rollouts from the
    scored group, recovers each assistant response (with a decoded-token
    fallback), re-runs structured-data extraction for display, and queues
    one display tuple per rollout in ``self.rollouts_for_wandb``. The queue
    is trimmed to ``num_rollouts_to_keep`` batches.
    """
    if self.debug_logging:
        self.logger.debug("Adding rollouts for wandb logging")
    if scored_data is None or not scored_data["tokens"]:
        if self.debug_logging:
            self.logger.debug("No scored data to log to wandb")
        return
    # dataset_item contains selected_structured_format and selected_container_format
    dataset_item = item[1] if item and len(item) > 1 else {}
    problem_id = dataset_item.get("problem_id", "N/A")
    selected_structured_format = dataset_item.get(
        "selected_structured_format", StructuredOutputFormat.JSON
    )
    selected_container_format = dataset_item.get(
        "selected_container_format", OutputContainerFormat.TAGGED
    )
    if self.debug_logging:
        self.logger.debug(
            f"Logging rollouts for problem_id: {problem_id}, structured: {selected_structured_format.value}, container: {selected_container_format.value}"  # noqa: E501
        )
    # -1 means "keep everything in the group".
    num_keep = self.config.num_rollouts_per_group_for_logging
    if num_keep == -1:
        num_keep = len(scored_data["tokens"])
    else:
        num_keep = min(num_keep, len(scored_data["tokens"]))
    if self.debug_logging:
        self.logger.debug(
            f"Keeping {num_keep} rollouts for logging for {problem_id}"
        )
    rollout_batch = []
    for i in range(num_keep):
        # The full_convo_text comes from scored_data["messages"], which should be the tokenized version.
        # We need the raw model output for extraction.
        # scored_data["messages"][i] should be the list of message dicts for the i-th rollout
        # Reconstruct full_convo_text from messages if possible, or use decoded tokens as fallback
        raw_conversation_messages = scored_data["messages"][i]
        assistant_response_text = ""
        if (
            isinstance(raw_conversation_messages, list)
            and len(raw_conversation_messages) > 0
        ):
            # The last message should be the assistant's response
            if (
                isinstance(raw_conversation_messages[-1], dict)
                and raw_conversation_messages[-1].get("role") == "assistant"
            ):
                assistant_response_text = raw_conversation_messages[-1].get(
                    "content", ""
                )
            else:  # Try to find assistant message if not last, or if format is unexpected
                for msg in reversed(raw_conversation_messages):
                    if isinstance(msg, dict) and msg.get("role") == "assistant":
                        assistant_response_text = msg.get("content", "")
                        break
        if not assistant_response_text:  # Fallback if proper message not found
            # This decodes the entire conversation, including system/user prompts
            assistant_response_text = self.tokenizer.decode(
                scored_data["tokens"][i], skip_special_tokens=True
            )
            if self.debug_logging:
                self.logger.warning(
                    f"WandB: Could not get raw assistant message for {problem_id}, using full decoded tokens for extraction."  # noqa: E501
                )
        extracted_output = self._extract_structured_data_response(
            assistant_response_text,
            selected_container_format,
            selected_structured_format,
        )
        # Metadata for the table columns; best-effort, never fatal.
        try:
            verification_info = dataset_item.get("verification_info", "{}")
            verification_data = json.loads(verification_info)
            pydantic_model_name = verification_data.get("model_name", "N/A")
            expected_schema_info = verification_data.get("pydantic_config", "N/A")
        except (json.JSONDecodeError, KeyError):
            pydantic_model_name = "N/A"
            expected_schema_info = "N/A"
            if self.debug_logging:
                self.logger.debug(
                    f"Could not extract model name/schema for wandb logging for {problem_id}"
                )
        # Construct the full conversation text for display from the messages list
        # This ensures the system prompt (which might be dynamic) is correctly shown.
        display_convo_text = ""
        if isinstance(raw_conversation_messages, list):
            try:
                display_convo_text = self.tokenizer.apply_chat_template(
                    raw_conversation_messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
            except Exception as e_tmpl:
                if self.debug_logging:
                    self.logger.warning(
                        f"WandB: Error applying chat template for {problem_id}: {e_tmpl}. Falling back to joining content."  # noqa: E501
                    )
                display_convo_text = "\n---\n".join(
                    [
                        str(msg.get("content", ""))
                        for msg in raw_conversation_messages
                    ]
                )
        else:
            display_convo_text = assistant_response_text
        # Tuple order must match the wandb.Table column order in
        # create_rollout_table.
        rollout_batch.append(
            (
                display_convo_text,
                scored_data["scores"][i],
                pydantic_model_name,
                problem_id,
                dataset_item.get("task_type", "N/A"),
                selected_structured_format.value,
                selected_container_format.value,
                (
                    extracted_output
                    if extracted_output
                    else "Extraction failed or no output"
                ),
                expected_schema_info[:200]
                + ("..." if len(expected_schema_info) > 200 else ""),
            )
        )
    if rollout_batch:
        self.rollouts_for_wandb.append(rollout_batch)
        if self.debug_logging:
            self.logger.debug(
                f"Added {len(rollout_batch)} rollouts to wandb queue for {problem_id}"
            )
    # Bound the queue so memory does not grow without limit.
    if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
        removed_count = 0
        while len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
            self.rollouts_for_wandb.pop(0)
            removed_count += 1
        if self.debug_logging and removed_count > 0:
            self.logger.debug(
                f"Removed {removed_count} oldest rollout batch(es) from wandb queue"
            )
async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
    """Create wandb table for rollout visualization.

    Flattens every queued rollout batch into a single ``wandb.Table`` under
    ``train/rollouts`` (only when there is anything to show), then empties
    the queue so each rollout is logged at most once.
    """
    if self.debug_logging:
        self.logger.debug(
            f"Creating wandb rollout table with {len(self.rollouts_for_wandb)} batches"
        )
    if self.rollouts_for_wandb:
        # Column order matches the tuples built in add_rollouts_for_wandb.
        column_names = [
            "full_conversation",
            "score",
            "pydantic_model_name",
            "problem_id",
            "task_type",
            "selected_structured_format",
            "selected_container_format",
            "extracted_output",
            "expected_schema_preview",
        ]
        table = wandb.Table(columns=column_names)
        flattened_rows = [
            row for batch in self.rollouts_for_wandb for row in batch
        ]
        for row in flattened_rows:
            table.add_data(*row)
        wandb_metrics["train/rollouts"] = table
        if self.debug_logging:
            self.logger.debug(
                f"Created wandb table with {len(flattened_rows)} total entries"
            )
    # The queue is consumed once per logging cycle.
    self.rollouts_for_wandb = []
    return wandb_metrics
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
    """Log metrics to wandb.

    Computes ``train/percent_perfect`` from the rolling buffer, drains any
    accumulated eval metrics, attaches the rollout table, and delegates the
    actual logging to the superclass.
    """
    if self.debug_logging:
        self.logger.debug("Logging metrics to wandb")
    metrics = {} if wandb_metrics is None else wandb_metrics
    buffer = self.percent_correct_buffer
    if buffer:
        percent_perfect = sum(buffer) / len(buffer)
        metrics["train/percent_perfect"] = percent_perfect
        if self.debug_logging:
            self.logger.debug(
                f"Train percent perfect: {percent_perfect:.3f} (from {len(buffer)} samples)"
            )
    else:
        metrics["train/percent_perfect"] = 0.0
        if self.debug_logging:
            self.logger.debug("No percent_correct_buffer data available")
    # The buffer is consumed once per log cycle.
    self.percent_correct_buffer = list()
    # Add eval metrics
    if self.eval_metrics:
        if self.debug_logging:
            self.logger.debug(
                f"Adding {len(self.eval_metrics)} eval metrics to wandb"
            )
        for metric_key, metric_value in self.eval_metrics:
            metrics[metric_key] = metric_value
            if self.debug_logging:
                self.logger.debug(f"Eval metric: {metric_key} = {metric_value}")
    self.eval_metrics = list()
    # Create rollout table (if any rollouts were collected)
    metrics = await self.create_rollout_table(metrics)
    if self.debug_logging:
        self.logger.debug(f"Final wandb metrics: {list(metrics.keys())}")
    await super().wandb_log(metrics)
async def close(self):
    """Clean up and save any remaining rollouts before exiting.

    Flushes the rollout buffer (when dumping is enabled), then calls the
    superclass ``close`` if one exists, awaiting it when it is a coroutine.
    """
    if self.debug_logging:
        self.logger.info(
            "Closing PydanticSchemaFollowingEnv. Attempting to save any remaining rollouts..."
        )
    should_flush = self.config.dump_rollouts and self.rollouts_to_save_buffer
    if should_flush:
        if self.debug_logging:
            self.logger.info(
                f"Found {len(self.rollouts_to_save_buffer)} rollouts in buffer. Saving now."
            )
        await self._save_rollouts_to_jsonl()
    elif self.debug_logging:
        self.logger.info("No rollouts in buffer to save upon closing.")
    # Call the superclass's close method more robustly: it may be missing,
    # synchronous, or asynchronous.
    parent_close = getattr(super(), "close", None)
    if callable(parent_close):
        try:
            if asyncio.iscoroutinefunction(parent_close):
                await parent_close()
            else:
                parent_close()
        except Exception as e_super_close:
            if self.debug_logging:
                self.logger.error(f"Error during super().close(): {e_super_close}")
    elif self.debug_logging:
        self.logger.debug(
            "No callable super().close() method found or it does not exist."
        )
    if self.debug_logging:
        self.logger.info("PydanticSchemaFollowingEnv closed.")
def _validate_parsed_data_against_model(
    self, parsed_data: Any, model_cls: Type[BaseModel], problem_id: str = "N/A"
) -> Tuple[bool, Optional[str]]:
    """
    Validate parsed data (e.g., a dictionary) against a Pydantic model and return detailed error info.

    Args:
        parsed_data: The already-parsed payload (typically a dict) to check.
        model_cls: Dynamically created Pydantic model class to validate against.
        problem_id: Identifier used only to label log/error messages.

    Returns:
        Tuple of (is_valid, error_message). ``error_message`` is ``None`` on
        success and a truncated human-readable description on failure.
    """
    try:
        model_cls.model_validate(parsed_data)
        return True, None
    except ValidationError as ve:
        try:
            # Attempt to get structured error data, which is generally more robust
            error_details = json.dumps(ve.errors(), indent=2)
            # Truncate potentially long error details
            max_detail_len = 250
            if len(error_details) > max_detail_len:
                error_details_str = f"{error_details[:max_detail_len]}..."
            else:
                error_details_str = error_details
            error_msg = f"Pydantic validation failed for {problem_id} with {len(ve.errors())} error(s):\n{error_details_str}"  # noqa: E501
        except Exception as format_exc:
            # Fallback if formatting ve.errors() fails for some reason
            error_msg = f"Pydantic validation failed for {problem_id}. Error: {str(ve)[:250]}. (Additionally, formatting error details failed: {str(format_exc)[:50]})"  # noqa: E501
        if self.debug_logging:
            self.logger.debug(error_msg)
        return False, error_msg
    except TypeError as te:
        # Specifically check for the "model" keyword argument issue in ValidationError.__new__
        # (seen when dataset-supplied schema code mixes Pydantic V1 and V2 idioms).
        if (
            "ValidationError.__new__() got an unexpected keyword argument 'model'"
            in str(te)
        ):
            error_msg = f"Pydantic internal TypeError for {problem_id}: {str(te)[:200]}. This may indicate a Pydantic V1/V2 compatibility issue within the dynamic schema definition from the dataset."  # noqa: E501
            if self.debug_logging:
                self.logger.error(
                    error_msg
                )  # Log as error due to its specific nature
        else:
            # Handle other TypeErrors normally
            error_msg = f"Unexpected TypeError during validation for {problem_id}: {type(te).__name__}: {str(te)[:100]}"  # noqa: E501
            if self.debug_logging:
                self.logger.debug(error_msg)
        return False, error_msg
    except Exception as e:
        # Catch any other unexpected exceptions during validation
        error_msg = f"Generic unexpected validation error for {problem_id}: {type(e).__name__}: {str(e)[:100]}"
        if self.debug_logging:
            self.logger.debug(error_msg)
        return False, error_msg
def _generate_valid_data_for_model(
self, model_cls: Type[BaseModel]
) -> Dict[str, Any]:
"""
Generate realistic valid data for a Pydantic model using field constraints and semantic hints.
This method analyzes the model's fields, their types, constraints, and names to generate
appropriate data that will pass model validation.
"""
valid_data = {}
model_fields = model_cls.model_fields
for field_name, field_info in model_fields.items():
# Skip optional fields sometimes for variability
if not field_info.is_required() and random.random() < 0.3:
continue
field_value = self._generate_field_value(field_name, field_info)
if field_value is not None:
valid_data[field_name] = field_value
# Validate the generated data to ensure it's correct
try:
model_cls.model_validate(valid_data)
return valid_data
except ValidationError as e:
# If validation fails, provide minimal valid data
if self.debug_logging:
self.logger.warning(f"Generated data failed validation: {e}")
return self._generate_minimal_valid_data(model_cls)
def _generate_field_value(self, field_name: str, field_info) -> Any:
"""Generate a realistic value for a specific field based on its type and constraints."""
annotation = field_info.annotation
# Handle Union types (like Optional[T])
if hasattr(annotation, "__origin__") and annotation.__origin__ is Union:
# Get the first non-None type
for arg in annotation.__args__:
if arg is not type(None):
annotation = arg
break
# Handle List types
if hasattr(annotation, "__origin__") and annotation.__origin__ is list:
item_type = annotation.__args__[0] if annotation.__args__ else str
return [
self._generate_value_for_type(item_type, f"{field_name}_item")
for _ in range(random.randint(1, 3))
]
# Handle Dict types
if hasattr(annotation, "__origin__") and annotation.__origin__ is dict:
key_type = annotation.__args__[0] if len(annotation.__args__) > 0 else str
value_type = annotation.__args__[1] if len(annotation.__args__) > 1 else str
return {
self._generate_value_for_type(
key_type, "key"
): self._generate_value_for_type(value_type, f"{field_name}_value")
for _ in range(random.randint(1, 2))
}
# Generate value based on type
return self._generate_value_for_type(annotation, field_name)
def _generate_value_for_type(
    self, type_annotation: Type, field_name: str = ""
) -> Any:
    """Generate a value for a specific type, using field name as hint for semantic meaning.

    Dispatch order matters: special Pydantic string types, date/time,
    numerics, Decimal, bool, plain str (all with name-based hints), then
    Enum and nested BaseModel, with a generic string fallback.
    """
    field_name_lower = field_name.lower()
    # Special Pydantic types
    if type_annotation is EmailStr or "email" in field_name_lower:
        domains = ["example.com", "test.org", "sample.net", "demo.io"]
        names = ["user", "admin", "test", "demo", "sample"]
        return f"{random.choice(names)}{random.randint(1, 99)}@{random.choice(domains)}"
    if (
        type_annotation is HttpUrl
        or "url" in field_name_lower
        or "link" in field_name_lower
    ):
        domains = ["example.com", "test.org", "sample.net"]
        paths = ["api/v1", "users/profile", "data/export", "files/download"]
        return f"https://{random.choice(domains)}/{random.choice(paths)}"
    if (
        type_annotation is UUID
        or "uuid" in field_name_lower
        or "id" in field_name_lower
    ):
        return str(uuid.uuid4())
    # Date/Time types
    if type_annotation is datetime:
        return datetime.now() - timedelta(days=random.randint(0, 365))
    if type_annotation is date:
        return (datetime.now() - timedelta(days=random.randint(0, 365))).date()
    # Numeric types with semantic hints
    if type_annotation is int or type_annotation is float:
        if "age" in field_name_lower:
            return random.randint(18, 80)
        elif "year" in field_name_lower:
            return random.randint(2000, 2024)
        elif "score" in field_name_lower or "rating" in field_name_lower:
            return random.randint(1, 100)
        elif "price" in field_name_lower or "cost" in field_name_lower:
            return round(random.uniform(10.0, 1000.0), 2)
        elif "count" in field_name_lower or "quantity" in field_name_lower:
            return random.randint(1, 100)
        else:
            return (
                random.randint(1, 100)
                if type_annotation is int
                else round(random.uniform(1.0, 100.0), 2)
            )
    if type_annotation is Decimal:
        return Decimal(str(round(random.uniform(1.0, 100.0), 2)))
    if type_annotation is bool:
        return random.choice([True, False])
    # String types with semantic hints
    if type_annotation is str:
        if "name" in field_name_lower:
            first_names = ["John", "Jane", "Alice", "Bob", "Charlie", "Diana"]
            last_names = [
                "Smith",
                "Johnson",
                "Williams",
                "Brown",
                "Jones",
                "Garcia",
            ]
            if "first" in field_name_lower:
                return random.choice(first_names)
            elif "last" in field_name_lower:
                return random.choice(last_names)
            else:
                return f"{random.choice(first_names)} {random.choice(last_names)}"
        elif "title" in field_name_lower:
            titles = [
                "Senior Developer",
                "Product Manager",
                "Data Scientist",
                "UX Designer",
            ]
            return random.choice(titles)
        elif "description" in field_name_lower:
            descriptions = [
                "A comprehensive solution",
                "High-quality product",
                "Innovative approach",
                "User-friendly design",
            ]
            return random.choice(descriptions)
        elif "address" in field_name_lower:
            return f"{random.randint(100, 9999)} Main St, City, State {random.randint(10000, 99999)}"
        elif "phone" in field_name_lower:
            return f"+1-{random.randint(100, 999)}-{random.randint(100, 999)}-{random.randint(1000, 9999)}"
        else:
            return f"sample_{field_name_lower}_{random.randint(1, 999)}"
    # Enum types. Bug fix: the previous check (`Enum in type_annotation.__bases__`)
    # only detected *direct* Enum subclasses and silently missed any enum that
    # derives from an intermediate base; issubclass covers the whole hierarchy.
    if isinstance(type_annotation, type) and issubclass(type_annotation, Enum):
        return random.choice(list(type_annotation)).value
    # Nested BaseModel. Same fix: issubclass instead of direct-__bases__
    # membership so models with intermediate base classes are recursed into.
    if isinstance(type_annotation, type) and issubclass(type_annotation, BaseModel):
        return self._generate_valid_data_for_model(type_annotation)
    # Fallback for unknown types
    return f"generated_value_{random.randint(1, 999)}"
def _generate_minimal_valid_data(
self, model_cls: Type[BaseModel]
) -> Dict[str, Any]:
"""Generate minimal valid data when constraint-based generation fails."""
valid_data = {}
model_fields = model_cls.model_fields
for field_name, field_info in model_fields.items():
if field_info.is_required():
annotation = field_info.annotation
# Handle Union types (like Optional[T])
if hasattr(annotation, "__origin__") and annotation.__origin__ is Union:
for arg in annotation.__args__:
if arg is not type(None):
annotation = arg
break
# Generate simple fallback values
if annotation is int:
valid_data[field_name] = 1
elif annotation is float:
valid_data[field_name] = 1.0
elif annotation is bool:
valid_data[field_name] = True
elif annotation is str or annotation is EmailStr:
valid_data[field_name] = "valid_example"
elif annotation is HttpUrl:
valid_data[field_name] = "https://example.com"
elif annotation is UUID:
valid_data[field_name] = str(uuid.uuid4())
elif annotation is datetime:
valid_data[field_name] = datetime.now()
elif annotation is date:
valid_data[field_name] = date.today()
else:
valid_data[field_name] = "fallback_value"
return valid_data
def _generate_editing_task_prompt(
    self,
    dataset_item: Dict[str, Any],
    structured_format: StructuredOutputFormat,
    container_format: OutputContainerFormat,
) -> str:
    """Generate prompt for editing tasks.

    Builds on the standard system prompt for the selected formats, then
    appends the Pydantic model source and a pretty-printed copy of the
    erroneous data the model must correct.

    Raises:
        ValueError: If ``verification_info`` or ``erroneous_data`` is
            missing, or ``verification_info`` is not valid JSON with a
            ``pydantic_config`` key.
    """
    # Get Pydantic schema and erroneous data
    verification_info = dataset_item.get("verification_info", "{}")
    erroneous_data = dataset_item.get("erroneous_data")
    # `is None` (not truthiness): an empty dict/list is still valid data to fix.
    if not verification_info or erroneous_data is None:
        raise ValueError(
            "Editing task requires both 'verification_info' and 'erroneous_data' fields"
        )
    try:
        verification_data = json.loads(verification_info)
        pydantic_config = verification_data["pydantic_config"]
    except (json.JSONDecodeError, KeyError):
        raise ValueError("Invalid verification_info format for editing task")
    erroneous_data_str = json.dumps(erroneous_data, indent=2)
    base_prompt = self._generate_system_prompt(structured_format, container_format)
    editing_prompt = f"""
{base_prompt}
The following data contains errors and should be corrected to conform to the provided Pydantic model.
Pydantic Model:
```python
{pydantic_config}
```
Erroneous Data:
```json
{erroneous_data_str}
```
Please identify and fix the errors, then provide the corrected data in the requested format.
"""
    return editing_prompt
def _create_editing_item(
    self, dataset_item: Dict[str, Any]
) -> Tuple[Tuple[frozenset, ...], Dict[str, Any]]:
    """Build a single editing-task item.

    Picks output formats at random, ensures the item carries erroneous data
    (synthesizing it from a valid example when absent), and returns the
    prompt as hashable frozenset messages alongside the (mutated) item.
    """
    # Pick the output formats for this item at random and record them on
    # the item so scoring can recover the choices later.
    structured_fmt = random.choice(self.supported_structured_formats)
    container_fmt = random.choice(self.supported_container_formats)
    dataset_item["selected_structured_format"] = structured_fmt
    dataset_item["selected_container_format"] = container_fmt
    if self.debug_logging:
        self.logger.debug(
            f"Creating editing task for problem_id: {dataset_item.get('problem_id', 'N/A')}"
        )
        self.logger.debug(f"Selected structured format: {structured_fmt.value}")
        self.logger.debug(f"Selected container format: {container_fmt.value}")
    # Synthesize erroneous data only when the dataset does not supply it:
    # generate a valid example from the schema, then corrupt it.
    if "erroneous_data" not in dataset_item:
        info = dataset_item.get("verification_info", "{}")
        if info:
            try:
                parsed = json.loads(info)
                model_cls = self._create_pydantic_model_from_code(
                    parsed["pydantic_config"], parsed["model_name"]
                )
                clean_data = self._generate_valid_data_for_model(model_cls)
                err_cfg = ErrorIntroductionConfig.from_env_config(
                    error_types_enabled=self.config.error_types_enabled,
                    max_errors_per_item=self.config.max_errors_per_item,
                    error_introduction_probability=self.config.error_introduction_probability,
                    error_introduction_seed=self.config.error_introduction_seed,
                )
                broken = introduce_error_for_pydantic(
                    clean_data, model_cls, config=err_cfg
                )
                if broken:
                    dataset_item["erroneous_data"] = broken
                else:
                    # Fallback: valid data plus one extra, disallowed field.
                    dataset_item["erroneous_data"] = {
                        **clean_data,
                        "invalid_field": "should_not_exist",
                    }
            except Exception as e:
                # Best-effort: never fail item creation over synthetic data.
                if self.debug_logging:
                    self.logger.warning(f"Could not generate erroneous data: {e}")
                dataset_item["erroneous_data"] = {"error": "could_not_generate"}
    prompt_content = self._generate_editing_task_prompt(
        dataset_item, structured_fmt, container_fmt
    )
    # A single user message, frozen for hashability.
    return (
        (frozenset({"role": "user", "content": prompt_content}.items()),),
        dataset_item,
    )
def _create_generation_item(
    self, dataset_item: Dict[str, Any]
) -> Tuple[Tuple[frozenset, ...], Dict[str, Any]]:
    """Build a single generation-task item.

    Selects output formats at random, records them on the item, and returns
    a (system, user) message pair as hashable frozensets plus the item.
    """
    # The user turn is the raw dataset prompt.
    user_content = dataset_item["prompt"]
    if self.debug_logging:
        self.logger.debug(f"Prompt length: {len(user_content)} characters")
    # Pick the output formats for this item at random and record them on
    # the item so scoring can recover the choices later.
    structured_fmt = random.choice(self.supported_structured_formats)
    container_fmt = random.choice(self.supported_container_formats)
    dataset_item["selected_structured_format"] = structured_fmt
    dataset_item["selected_container_format"] = container_fmt
    if self.debug_logging:
        self.logger.debug(f"Selected structured format: {structured_fmt.value}")
        self.logger.debug(f"Selected container format: {container_fmt.value}")
    # System prompt is derived from the selected formats.
    system_prompt = self._generate_system_prompt(structured_fmt, container_fmt)
    messages = (
        frozenset({"role": "system", "content": system_prompt}.items()),
        frozenset({"role": "user", "content": user_content}.items()),
    )
    return messages, dataset_item
# Script entry point: delegate to the environment's CLI runner.
if __name__ == "__main__":
    PydanticSchemaFollowingEnv.cli()