diff --git a/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py b/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py
index 3eb323a1..e9874ca4 100644
--- a/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py
+++ b/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py
@@ -44,7 +44,10 @@ from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID
+import toml # Added import
import wandb
+import xmltodict # Added
+import yaml # Added import
from datasets import load_dataset
from pydantic import (
BaseModel,
@@ -69,20 +72,33 @@ from atroposlib.envs.base import (
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# System prompt for the LLM
-system_prompt = (
- "You are an AI assistant that generates JSON objects according to Pydantic schemas.\n"
- "You may use extremely long chains of thought to deeply consider the problem and deliberate "
- "with yourself via systematic reasoning processes to help come to a correct solution prior to answering. "
- "You should enclose your thoughts and internal monologue inside tags.\n\n"
- "CRITICAL: Your final JSON output MUST be enclosed within tags.\n"
- "The JSON must be valid and complete. Do not include any text after the closing tag.\n"
- "Example format:\n"
- "\nMy reasoning here...\n\n\n"
- '\n{"field1": "value1", "field2": "value2"}\n\n\n'
- "Ensure the generated JSON strictly adheres to the Pydantic model schema and any specific field "
- "requirements provided in the user prompt. Generate all required fields for the model, and "
- "include optional fields if they make sense in the context or are specified."
-)
+# system_prompt = (
+#     "You are an AI assistant that generates JSON objects according to Pydantic schemas.\n"
+#     "You may use extremely long chains of thought to deeply consider the problem and deliberate "
+#     "with yourself via systematic reasoning processes to help come to a correct solution prior to answering. "
+#     "You should enclose your thoughts and internal monologue inside <think></think> tags.\n\n"
+#     "CRITICAL: Your final JSON output MUST be enclosed within <json_output></json_output> tags.\n"
+#     "The JSON must be valid and complete. Do not include any text after the closing </json_output> tag.\n"
+#     "Example format:\n"
+#     "<think>\nMy reasoning here...\n</think>\n\n"
+#     '<json_output>\n{"field1": "value1", "field2": "value2"}\n</json_output>\n\n'
+#     "Ensure the generated JSON strictly adheres to the Pydantic model schema and any specific field "
+#     "requirements provided in the user prompt. Generate all required fields for the model, and "
+#     "include optional fields if they make sense in the context or are specified."
+# )
+
+
+class StructuredOutputFormat(Enum):
+ JSON = "json"
+ YAML = "yaml"
+ TOML = "toml"
+ XML = "xml"
+
+
+class OutputContainerFormat(Enum):
+    TAGGED = "tagged"  # e.g., <json_output>...</json_output>
+ NONE = "none" # Raw output
+ MARKDOWN = "markdown" # e.g., ```json ... ```
class PydanticEnvConfig(BaseEnvConfig):
@@ -106,6 +122,20 @@ class PydanticEnvConfig(BaseEnvConfig):
default=True,
description="Whether to include messages in the dataset for SFT data generation",
)
+ allowed_structured_formats: Optional[List[StructuredOutputFormat]] = Field(
+ default=None,
+ description="Optional list of StructuredOutputFormat enums to use for randomization. If None or empty, all supported formats are used.", # noqa: E501
+ )
+ allowed_container_formats: Optional[List[OutputContainerFormat]] = Field(
+ default=None,
+ description="Optional list of OutputContainerFormat enums to use for randomization. If None or empty, all supported formats are used.", # noqa: E501
+ )
+ eval_set_percentage: float = Field(
+ default=0.1,
+ description="Percentage of the dataset to use for the evaluation set (e.g., 0.1 for 10%).",
+ ge=0.0,
+ le=1.0,
+ )
class PydanticSchemaFollowingEnv(BaseEnv):
@@ -119,15 +149,8 @@ class PydanticSchemaFollowingEnv(BaseEnv):
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
- self.percent_correct_buffer = list() # Tracks 1.0 scores
- self.eval_metrics = list()
- self.rollouts_for_wandb = []
- self.dataset_items: List[Dict[str, Any]] = []
- self.model_cache: Dict[str, Type[BaseModel]] = (
- {}
- ) # Cache for dynamically created models
- # Set up debug logging
+ # Set up debug logging FIRST, as it's used by subsequent setup logic
self.debug_logging = getattr(config, "debug_logging", True)
if self.debug_logging:
self.logger = logging.getLogger(f"{self.__class__.__name__}")
@@ -144,6 +167,56 @@ class PydanticSchemaFollowingEnv(BaseEnv):
self.logger = logging.getLogger(f"{self.__class__.__name__}")
self.logger.addHandler(logging.NullHandler())
+ self.percent_correct_buffer = list() # Tracks 1.0 scores
+ self.eval_metrics = list()
+ self.rollouts_for_wandb = []
+ self.dataset_items: List[Dict[str, Any]] = []
+ self.model_cache: Dict[str, Type[BaseModel]] = (
+ {}
+ ) # Cache for dynamically created models
+
+ # Determine supported formats based on config or defaults
+ if (
+ config.allowed_structured_formats
+ and len(config.allowed_structured_formats) > 0
+ ):
+ self.supported_structured_formats = config.allowed_structured_formats
+ if self.debug_logging:
+ self.logger.info(
+ f"Using configured structured formats: {[f.value for f in self.supported_structured_formats]}"
+ )
+ else:
+ self.supported_structured_formats: List[StructuredOutputFormat] = [
+ StructuredOutputFormat.JSON,
+ StructuredOutputFormat.YAML,
+ StructuredOutputFormat.TOML,
+ StructuredOutputFormat.XML,
+ ]
+ if self.debug_logging:
+ self.logger.info(
+ f"Using default structured formats: {[f.value for f in self.supported_structured_formats]}"
+ )
+
+ if (
+ config.allowed_container_formats
+ and len(config.allowed_container_formats) > 0
+ ):
+ self.supported_container_formats = config.allowed_container_formats
+ if self.debug_logging:
+ self.logger.info(
+ f"Using configured container formats: {[f.value for f in self.supported_container_formats]}"
+ )
+ else:
+ self.supported_container_formats: List[OutputContainerFormat] = [
+ OutputContainerFormat.TAGGED,
+ OutputContainerFormat.NONE,
+ OutputContainerFormat.MARKDOWN,
+ ]
+ if self.debug_logging:
+ self.logger.info(
+ f"Using default container formats: {[f.value for f in self.supported_container_formats]}"
+ )
+
# Data dumping setup
self.run_uuid = str(uuid.uuid4())
@@ -167,6 +240,70 @@ class PydanticSchemaFollowingEnv(BaseEnv):
if config.dump_rollouts:
self.logger.info(f"Rollouts will be saved to: {self.datadumps_dir}")
+ def _generate_system_prompt(
+ self,
+ structured_format: StructuredOutputFormat,
+ container_format: OutputContainerFormat,
+ ) -> str:
+ """Generates a system prompt tailored to the selected output and container formats."""
+ prompt_lines = [
+ f"You are an AI assistant that generates structured data in {structured_format.value.upper()} format according to Pydantic schemas.", # noqa: E501
+ "You may use extremely long chains of thought to deeply consider the problem and deliberate "
+ "with yourself via systematic reasoning processes to help come to a correct solution prior to answering.",
+            "You should enclose your thoughts and internal monologue inside <think></think> tags.",
+ ]
+
+ example_output = ""
+ if structured_format == StructuredOutputFormat.JSON:
+            example_output = '{\n  "field1": "value1",\n  "field2": "value2"\n}'
+        elif structured_format == StructuredOutputFormat.YAML:
+            example_output = "field1: value1\nfield2: value2"
+        elif structured_format == StructuredOutputFormat.TOML:
+            example_output = 'field1 = "value1"\nfield2 = "value2"'
+        elif structured_format == StructuredOutputFormat.XML:
+            example_output = "<model_name>\n  <field1>value1</field1>\n  <field2>value2</field2>\n</model_name>"  # Assuming root tag is the model name  # noqa: E501
+
+ if container_format == OutputContainerFormat.TAGGED:
+ tag_name = f"{structured_format.value}_output"
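+            # e.g. structured_format JSON -> tag_name "json_output", so the model
+            # is asked to wrap its answer in <json_output>...</json_output>.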
+ prompt_lines.extend(
+ [
+                    f"CRITICAL: Your final {structured_format.value.upper()} output MUST be enclosed within <{tag_name}></{tag_name}> tags.",  # noqa: E501
+                    f"The {structured_format.value.upper()} must be valid and complete. Do not include any text after the closing </{tag_name}> tag.",  # noqa: E501
+                    "Example format:",
+                    "<think>\nMy reasoning here...\n</think>\n",
+                    f"<{tag_name}>\n{example_output}\n</{tag_name}>",
+ ]
+ )
+ elif container_format == OutputContainerFormat.MARKDOWN:
+ prompt_lines.extend(
+ [
+ f"CRITICAL: Your final {structured_format.value.upper()} output MUST be enclosed within a markdown code block (```).", # noqa: E501
+ f"The {structured_format.value.upper()} must be valid and complete.",
+ "Example format:",
+                    "<think>\nMy reasoning here...\n</think>\n",
+                    f"```{structured_format.value}\n{example_output}\n```",
+ ]
+ )
+ elif container_format == OutputContainerFormat.NONE:
+ prompt_lines.extend(
+ [
+                    f"CRITICAL: Your final {structured_format.value.upper()} output should be provided directly after the closing </think> tag, with no surrounding tags or markdown.",  # noqa: E501
+ f"The {structured_format.value.upper()} must be valid and complete.",
+ "Example format:",
+                    "<think>\nMy reasoning here...\n</think>\n",
+ example_output,
+ ]
+ )
+
+ prompt_lines.extend(
+ [
+ f"Ensure the generated {structured_format.value.upper()} strictly adheres to the Pydantic model schema and any specific field " # noqa: E501
+ "requirements provided in the user prompt. Generate all required fields for the model, and " # noqa: E501
+ "include optional fields if they make sense in the context or are specified." # noqa: E501
+ ]
+ )
+        return "\n".join(prompt_lines)
+
@classmethod
def config_init(cls) -> Tuple[PydanticEnvConfig, List[APIServerConfig]]:
"""Initialize configuration for the environment."""
@@ -175,9 +312,10 @@ class PydanticSchemaFollowingEnv(BaseEnv):
group_size=16,
use_wandb=True,
rollout_server_url="http://localhost:8000",
- total_steps=2000,
+ total_steps=250,
batch_size=1024,
steps_per_eval=20,
+ max_num_workers=16,
max_token_length=1024 * 12,
inference_weight=1.0,
wandb_name="pydantic_schema_following",
@@ -185,9 +323,19 @@ class PydanticSchemaFollowingEnv(BaseEnv):
eval_limit_ratio=0.1,
dataset_name="justus27/pydantic-adherance-test",
dataset_split="train",
- debug_logging=True, # Enable debug logging by default
- dump_rollouts=True, # Enable data dumping by default
- include_messages=True, # Ensure messages are included for SFT data generation
+ debug_logging=False,
+ dump_rollouts=False,
+ allowed_structured_formats=[
+ StructuredOutputFormat.JSON,
+ StructuredOutputFormat.YAML,
+ StructuredOutputFormat.TOML,
+ ],
+ allowed_container_formats=[
+ OutputContainerFormat.TAGGED,
+ OutputContainerFormat.NONE,
+ OutputContainerFormat.MARKDOWN,
+ ],
+ eval_set_percentage=0.005,
)
server_configs = [
APIServerConfig(
@@ -339,7 +487,9 @@ class PydanticSchemaFollowingEnv(BaseEnv):
self.logger.debug("Dataset shuffled")
# Split into train and test
- split_idx = int(len(self.dataset_items) * 0.90) # 90% train, 10% test
+ split_idx = int(
+ len(self.dataset_items) * (1.0 - self.config.eval_set_percentage)
+ )
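+        # e.g. with eval_set_percentage=0.1: first 90% -> train, last 10% -> test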
self.train_items = self.dataset_items[:split_idx]
self.test_items = self.dataset_items[split_idx:]
@@ -401,19 +551,47 @@ class PydanticSchemaFollowingEnv(BaseEnv):
if self.debug_logging:
self.logger.debug(f"Prompt length: {len(user_content)} characters")
+ # Randomly select structured and container formats
+ selected_structured_format = random.choice(self.supported_structured_formats)
+ selected_container_format = random.choice(self.supported_container_formats)
+
+ # Store the selections in the dataset_item
+ dataset_item["selected_structured_format"] = selected_structured_format
+ dataset_item["selected_container_format"] = selected_container_format
+
+ if self.debug_logging:
+ self.logger.debug(
+ f"Selected structured format: {selected_structured_format.value}"
+ )
+ self.logger.debug(
+ f"Selected container format: {selected_container_format.value}"
+ )
+
+ # Generate the system prompt
+ current_system_prompt = self._generate_system_prompt(
+ selected_structured_format, selected_container_format
+ )
+
# Create the message structure
prompt_messages = [
- frozenset({"role": "system", "content": system_prompt}.items()),
+ frozenset({"role": "system", "content": current_system_prompt}.items()),
frozenset({"role": "user", "content": user_content}.items()),
]
# Return the prompt and the full dataset item for scoring
return tuple(prompt_messages), dataset_item
- def _extract_json_response(self, text: str) -> Optional[str]:
-        """Extracts JSON content from <json_output> tags, with strict thinking tag validation."""
+ def _extract_structured_data_response(
+ self,
+ text: str,
+ container_format: OutputContainerFormat,
+ structured_format: StructuredOutputFormat,
+ ) -> Optional[str]:
+ """Extracts structured data content based on container and structured format, with strict thinking tag validation.""" # noqa: E501
if self.debug_logging:
- self.logger.debug(f"Extracting JSON from response (length: {len(text)})")
+ self.logger.debug(
+ f"Extracting {structured_format.value} from response (length: {len(text)}), container: {container_format.value}" # noqa: E501
+ )
# Ensure text is a string
if not isinstance(text, str):
@@ -423,103 +601,132 @@ class PydanticSchemaFollowingEnv(BaseEnv):
)
text = str(text)
- # First, validate thinking tags (similar to MCQA environment)
+ # 1. Validate thinking tags
        think_tags = re.findall(r"<think>", text, re.IGNORECASE)
        think_close_tags = re.findall(r"</think>", text, re.IGNORECASE)
- # Check for proper thinking tag structure
if len(think_tags) != 1 or len(think_close_tags) != 1:
if self.debug_logging:
self.logger.warning(
- f"Invalid thinking tag structure: {len(think_tags)} open tags, {len(think_close_tags)} close tags"
+ f"Invalid thinking tag structure: {len(think_tags)} open tags, {len(think_close_tags)} close tags. Full text: {text[:500]}" # noqa: E501
)
return None
# Split the text into thinking and response sections
-        parts = re.split(r"</think>", text, flags=re.IGNORECASE, maxsplit=1)
- if len(parts) != 2:
+ # Use a regex that captures the content before, between, and after think tags robustly
+ match_think_block = re.match(
+            r"(.*?)(<think>.*?</think>)(.*)", text, re.DOTALL | re.IGNORECASE
+ )
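+        # Illustrative (hypothetical) input: "x<think>cot</think>{...}" yields
+        # group(1)="x", group(2)="<think>cot</think>", group(3)="{...}".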
+ if not match_think_block:
if self.debug_logging:
self.logger.warning(
- "Could not split text into thinking and response sections"
+                    f"Could not find a complete <think>...</think> block. Full text: {text[:500]}"
)
return None
- thinking_section, response_section = parts
+        # thinking_section_plus_prefix = match_think_block.group(1)  # Content before <think>
+        # thinking_block_content = match_think_block.group(2)  # <think>...</think> block
+        response_section = match_think_block.group(3)  # Content after </think>
-        # Validate thinking section contains opening <think> tag
-        if "<think>" not in thinking_section.lower():
+        # Validate that the first <think> tag is the one we captured (no <think> before it)
+        if "<think>" in match_think_block.group(1).lower():
if self.debug_logging:
-                self.logger.warning("Thinking section missing opening <think> tag")
+                self.logger.warning(
+                    f"Nested or malformed <think> tags detected before main block. Full text: {text[:500]}"
+ )
return None
        # Check if there are any thinking tags in the response section (after </think>)
        if "<think>" in response_section.lower():
if self.debug_logging:
self.logger.warning(
-                    "Found <think> tags in response section after </think>"
+                    f"Found <think> tags in response section after </think>. Full text: {text[:500]}"
)
return None
- # Now extract JSON from the response section only
- match = re.search(
-            r"<json_output>\s*(.*?)\s*</json_output>",
- response_section,
- re.DOTALL | re.IGNORECASE,
- )
- if match:
- json_str = match.group(1).strip()
- if self.debug_logging:
- self.logger.debug(
- f"Found JSON output tags in response section, extracted {len(json_str)} characters"
- )
+ extracted_content: Optional[str] = None
- # Handle empty extraction
- if not json_str:
- if self.debug_logging:
- self.logger.warning("JSON output tags found but content is empty")
- return None
-
- # Validate JSON
- try:
- json.loads(json_str)
- if self.debug_logging:
- self.logger.debug("Extracted JSON is valid")
- return json_str
- except json.JSONDecodeError as e:
- if self.debug_logging:
- self.logger.warning(f"Extracted text is not valid JSON: {e}")
- return None
-
- # Fallback: Look for JSON in response section only (no thinking tag validation)
- if self.debug_logging:
- self.logger.debug(
-                "No <json_output> tags found in response section, trying fallback extraction"
- )
-
- # Look for JSON objects that start with { and end with } in response section only
- json_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}"
- matches = re.findall(json_pattern, response_section, re.DOTALL)
-
- for potential_json in matches:
- try:
- json.loads(potential_json.strip())
+ # 2. Extract based on container_format from response_section
+ if container_format == OutputContainerFormat.TAGGED:
+ tag_name = f"{structured_format.value}_output"
+ # Loosen regex to allow for attributes in the opening tag if any, and handle varying whitespace
+            pattern = rf"<{tag_name}[^>]*>\s*(.*?)\s*</{tag_name}>"
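+            # e.g. for tag_name "json_output" this matches
+            # '<json_output attr="x"> {...} </json_output>' (attributes hypothetical)
+            # and captures only the inner body.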
+ match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE)
+ if match:
+ extracted_content = match.group(1)
if self.debug_logging:
self.logger.debug(
- f"Fallback extraction successful from response section: {len(potential_json)} characters"
+ f"Extracted using TAGGED ({tag_name}): {len(extracted_content)} chars"
+ )
+ else:
+ if self.debug_logging:
+ self.logger.warning(
+ f"TAGGED format: Could not find <{tag_name}>...{tag_name}> tags in response section. Response section: {response_section[:300]}" # noqa: E501
)
- return potential_json.strip()
- except json.JSONDecodeError:
- continue
- if self.debug_logging:
- self.logger.warning("No valid JSON found with any extraction method")
- return None
+ elif container_format == OutputContainerFormat.MARKDOWN:
+ # Pattern to match ```language ... ``` or just ``` ... ```
+ # It captures the content within the backticks.
+ # It optionally matches the language specifier.
+ pattern = rf"^\s*```(?:{re.escape(structured_format.value)})?\s*\n(.*?)\n\s*```\s*$"
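+            # e.g. matches "```json\n{...}\n```"; the language specifier is optional,
+            # and only the fenced body is captured.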
+ match = re.search(
+ pattern, response_section.strip(), re.DOTALL | re.IGNORECASE
+ )
+ if match:
+ extracted_content = match.group(1)
+ if self.debug_logging:
+ self.logger.debug(
+ f"Extracted using MARKDOWN: {len(extracted_content)} chars"
+ )
+ else:
+                # Fallback for markdown: if a fenced block with a language specifier (```<lang>) is not found, try a plain ``` block  # noqa: E501
+ pattern_no_lang = r"^\s*```\s*\n(.*?)\n\s*```\s*$"
+ match_no_lang = re.search(
+ pattern_no_lang, response_section.strip(), re.DOTALL | re.IGNORECASE
+ )
+ if match_no_lang:
+ extracted_content = match_no_lang.group(1)
+ if self.debug_logging:
+ self.logger.debug(
+ f"Extracted using MARKDOWN (no lang specified): {len(extracted_content)} chars"
+ )
+ else:
+ if self.debug_logging:
+ self.logger.warning(
+ f"MARKDOWN format: Could not find ```...``` code block in response section. Response section: {response_section[:300]}" # noqa: E501
+ )
+
+ elif container_format == OutputContainerFormat.NONE:
+ extracted_content = response_section
+ if self.debug_logging:
+ self.logger.debug(
+ f"Extracted using NONE: {len(extracted_content)} chars"
+ )
+
+ # 3. Post-processing
+ if extracted_content is not None:
+ extracted_content = extracted_content.strip()
+ if not extracted_content: # Empty after stripping
+ if self.debug_logging:
+ self.logger.warning("Extracted content is empty after stripping.")
+ return None
+ if self.debug_logging:
+ self.logger.debug(
+ f"Successfully extracted content ({len(extracted_content)} chars). First 100: {extracted_content[:100]}" # noqa: E501
+ )
+ return extracted_content
+ else:
+ if self.debug_logging:
+ self.logger.warning(
+ f"Extraction failed for container type {container_format.value}. No content extracted."
+ )
+ return None
async def score(
self,
rollout_group_data: List[Tuple[Tuple[Dict[str, str], ...], Dict[str, Any]]],
) -> Optional[ScoredDataGroup]:
- """Score the rollouts based on Pydantic validation."""
+ """Score the rollouts based on Pydantic validation or other structural checks."""
if self.debug_logging:
self.logger.debug(f"Scoring {len(rollout_group_data)} rollouts")
@@ -529,124 +736,202 @@ class PydanticSchemaFollowingEnv(BaseEnv):
scores_obj["scores"] = list()
scores_obj["messages"] = list() # Add messages for data dumping
- # All items in rollout_group_data share the same dataset_item (item[1])
if not rollout_group_data:
if self.debug_logging:
self.logger.warning("No rollout data to score")
return None
dataset_item = rollout_group_data[0][1]
+ problem_id = dataset_item.get("problem_id", "N/A")
+ selected_structured_format = dataset_item["selected_structured_format"]
+ selected_container_format = dataset_item["selected_container_format"]
if self.debug_logging:
self.logger.debug(
- f"Scoring for problem_id: {dataset_item.get('problem_id', 'N/A')}"
+ f"Scoring for problem_id: {problem_id}, structured_format: {selected_structured_format.value}, container_format: {selected_container_format.value}" # noqa: E501
)
- # Extract verification info (pydantic config) and model name
verification_info = dataset_item["verification_info"]
-
- # Parse the verification info to get pydantic config and model name
try:
verification_data = json.loads(verification_info)
pydantic_config = verification_data["pydantic_config"]
model_name = verification_data["model_name"]
-
if self.debug_logging:
- self.logger.debug(f"Target model for scoring: {model_name}")
- self.logger.debug(
- f"Pydantic config length: {len(pydantic_config)} characters"
- )
+ self.logger.debug(f"Target Pydantic model for validation: {model_name}")
except (json.JSONDecodeError, KeyError) as e:
- error_msg = f"Error parsing verification_info: {e}"
+ error_msg = f"Error parsing verification_info for {problem_id}: {e}"
if self.debug_logging:
self.logger.error(error_msg)
print(error_msg)
return None
- # Create the Pydantic model dynamically
try:
target_model_cls = self._create_pydantic_model_from_code(
pydantic_config, model_name
)
except Exception as e:
- error_msg = f"Error creating Pydantic model: {e}"
+ error_msg = (
+ f"Error creating Pydantic model {model_name} for {problem_id}: {e}"
+ )
if self.debug_logging:
self.logger.error(error_msg)
print(error_msg)
return None
- # Score each rollout
valid_count = 0
invalid_count = 0
extraction_failures = 0
+ parsing_failures = 0
- # Shuffle to avoid bias in selection
random.shuffle(rollout_group_data)
for i, (item_messages, _) in enumerate(rollout_group_data):
- if self.debug_logging and i == 0:
- self.logger.debug(f"Scoring rollout {i+1}/{len(rollout_group_data)}")
- # Convert frozensets to dictionaries for easier access
messages_as_dicts = [dict(fs_message) for fs_message in item_messages]
- model_response_text = messages_as_dicts[-1]["content"] # LLM full response
+ model_response_text = messages_as_dicts[-1]["content"]
- if self.debug_logging and i == 0:
- self.logger.debug(
- f"Response length: {len(model_response_text)} characters"
- )
+ extracted_str = self._extract_structured_data_response(
+ model_response_text,
+ selected_container_format,
+ selected_structured_format,
+ )
- json_str = self._extract_json_response(model_response_text)
+ reward = 0.0
+ parsed_data = None
+ validation_error_msg = "Extraction failed"
- reward = 0.0 # Default score
+ if extracted_str:
+ try:
+ if selected_structured_format == StructuredOutputFormat.JSON:
+ parsed_data = json.loads(extracted_str)
+ elif selected_structured_format == StructuredOutputFormat.YAML:
+ parsed_data = yaml.safe_load(extracted_str)
+ elif selected_structured_format == StructuredOutputFormat.TOML:
+ parsed_data = toml.loads(extracted_str)
+ elif selected_structured_format == StructuredOutputFormat.XML:
+ try:
+ data_dict_outer = xmltodict.parse(extracted_str)
+ # Assumption: Pydantic model corresponds to the content *within* the single root XML tag.
+ # xmltodict.parse returns a dict like {'RootTag': actual_data_dict}.
+ # We extract actual_data_dict for validation. More complex XML might require
+ # specific xmltodict process_instructions or more sophisticated unwrapping.
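+                            # Illustrative (hypothetical) example:
+                            #   "<User><name>Ada</name></User>" parses to
+                            #   {"User": {"name": "Ada"}}; the inner dict is validated.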
+ if (
+ isinstance(data_dict_outer, dict)
+ and len(data_dict_outer) == 1
+ ):
+ root_key = list(data_dict_outer.keys())[0]
+ parsed_data = data_dict_outer[root_key]
+ if self.debug_logging:
+ self.logger.debug(
+                                    f"Rollout {i} (problem {problem_id}): XML parsed, root key '{root_key}', data: {str(parsed_data)[:100]}..."  # noqa: E501
+ )
+ else:
+ raise ValueError(
+ f"XML from xmltodict.parse was not a dict with a single root key as expected. Got: {type(data_dict_outer)}, Keys: {list(data_dict_outer.keys()) if isinstance(data_dict_outer, dict) else 'N/A'}" # noqa: E501
+ )
+                        except xmltodict.expat.ExpatError as e_xml_parse:
+                            # Treat as a parsing failure and fall through so this
+                            # rollout scores 0.0 instead of aborting the whole group.
+                            reward = 0.0
+                            parsing_failures += 1
+                            validation_error_msg = f"XML parsing failed: {e_xml_parse}"
+                            if self.debug_logging and i < 3:
+                                self.logger.debug(
+                                    f"Rollout {i} (problem {problem_id}): XML parsing failed (ExpatError): {e_xml_parse}. Extracted XML: {extracted_str[:100]}..."  # noqa: E501
+                                )
+                        except Exception as e_xml_generic:
+                            reward = 0.0
+                            parsing_failures += 1
+                            validation_error_msg = (
+                                f"XML processing error: {e_xml_generic}"
+                            )
+                            if self.debug_logging and i < 3:
+                                self.logger.debug(
+                                    f"Rollout {i} (problem {problem_id}): XML processing error: {e_xml_generic}. Extracted XML: {extracted_str[:100]}..."  # noqa: E501
+                                )
- if json_str:
- # Validate JSON against the Pydantic model
- is_valid, error_msg = self._validate_json_against_model(
- json_str, target_model_cls, dataset_item.get("problem_id", "N/A")
- )
+ if parsed_data is not None and selected_structured_format in [
+ StructuredOutputFormat.JSON,
+ StructuredOutputFormat.YAML,
+ StructuredOutputFormat.TOML,
+ StructuredOutputFormat.XML,
+ ]:
+ is_valid, pydantic_error_msg = (
+ self._validate_parsed_data_against_model(
+ parsed_data, target_model_cls, problem_id
+ )
+ )
+ if is_valid:
+ reward = 1.0
+ valid_count += 1
+ validation_error_msg = None
+ if self.debug_logging and i < 3:
+ self.logger.debug(
+ f"Rollout {i}: Pydantic validation successful for {selected_structured_format.value}" # noqa: E501
+ )
+ else:
+ reward = 0.0
+ invalid_count += 1
+ validation_error_msg = pydantic_error_msg
+ if self.debug_logging and i < 3:
+ self.logger.debug(
+ f"Rollout {i}: Pydantic validation failed for {selected_structured_format.value}. Error: {pydantic_error_msg}" # noqa: E501
+ )
- if is_valid:
- reward = 1.0 # Valid schema - full score
- valid_count += 1
- if self.debug_logging and i < 3: # Log first few successes
- self.logger.debug(f"Rollout {i}: Validation successful")
- else:
- reward = 0.0 # Validation failed
- invalid_count += 1
- if self.debug_logging and i < 3: # Log first few validation errors
- self.logger.debug(f"Rollout {i}: {error_msg}")
+ except json.JSONDecodeError as e_json:
+ reward = 0.0
+ parsing_failures += 1
+ validation_error_msg = f"JSON parsing failed: {e_json}"
+ if self.debug_logging and i < 3:
+ self.logger.debug(
+ f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
+ )
+ except yaml.YAMLError as e_yaml:
+ reward = 0.0
+ parsing_failures += 1
+ validation_error_msg = f"YAML parsing failed: {e_yaml}"
+ if self.debug_logging and i < 3:
+ self.logger.debug(
+ f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
+ )
+ except toml.TomlDecodeError as e_toml:
+ reward = 0.0
+ parsing_failures += 1
+ validation_error_msg = f"TOML parsing failed: {e_toml}"
+ if self.debug_logging and i < 3:
+ self.logger.debug(
+ f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
+ )
+ except Exception as e_parse: # Catch any other parsing related errors
+ reward = 0.0
+ parsing_failures += 1
+ validation_error_msg = f"Generic parsing failed for {selected_structured_format.value}: {e_parse}"
+ if self.debug_logging and i < 3:
+ self.logger.debug(
+ f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
+ )
else:
- reward = 0.0 # No JSON output found or extraction failed
+ reward = 0.0 # No structured data output found or extraction failed
extraction_failures += 1
- if self.debug_logging and i < 3: # Log first few failures
- self.logger.debug(f"Rollout {i}: JSON extraction failed")
+ # validation_error_msg is already "Extraction failed"
+ if self.debug_logging and i < 3:
+ self.logger.debug(
+ f"Rollout {i}: Extraction failed for {selected_structured_format.value} with container {selected_container_format.value}" # noqa: E501
+ )
- # Tokenize for training - convert frozensets to dicts for tokenizer
try:
- # Validate that messages_as_dicts is properly formatted
if not isinstance(messages_as_dicts, list):
if self.debug_logging:
self.logger.error(
f"Expected list for tokenization, got {type(messages_as_dicts)}"
)
continue
-
- # Validate each message has required keys
for msg_idx, msg in enumerate(messages_as_dicts):
if not isinstance(msg, dict):
if self.debug_logging:
self.logger.error(
f"Message {msg_idx} is not a dict: {type(msg)}"
)
- continue
+ continue # Skip this rollout if message format is incorrect
if "role" not in msg or "content" not in msg:
if self.debug_logging:
self.logger.error(
f"Message {msg_idx} missing required keys: {msg.keys()}"
)
- continue
- # Ensure content is a string
+ continue # Skip this rollout
if not isinstance(msg["content"], str):
if self.debug_logging:
self.logger.warning(
@@ -655,34 +940,36 @@ class PydanticSchemaFollowingEnv(BaseEnv):
msg["content"] = str(msg["content"])
out_dict = tokenize_for_trainer(
- self.tokenizer, messages_as_dicts, include_messages=True
+ self.tokenizer,
+ messages_as_dicts,
+ include_messages=self.config.include_messages, # Using config value
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
except Exception as e:
if self.debug_logging:
- self.logger.error(f"Tokenization failed for rollout {i}: {e}")
+ self.logger.error(
+ f"Tokenization failed for rollout {i} (problem: {problem_id}): {e}"
+ )
self.logger.debug(
f"Messages format: {[type(m) for m in messages_as_dicts]}"
)
continue
- if len([1 for i in masks if i != -100]) < 10: # Min context length
+ if len([1 for m_val in masks if m_val != -100]) < 10: # Min context length
if self.debug_logging:
self.logger.debug(
- "Skipping rollout due to insufficient context length"
+ f"Skipping rollout {i} (problem: {problem_id}) due to insufficient context length after tokenization." # noqa: E501
)
continue
scores_obj["tokens"].append(tokens)
scores_obj["masks"].append(masks)
scores_obj["scores"].append(reward)
- scores_obj["messages"].append(
- out_dict.get("messages", messages_as_dicts)
- ) # Store converted messages for dumping
+ # Store original messages (converted to dicts) if available in out_dict, else the modified ones
+ scores_obj["messages"].append(out_dict.get("messages", messages_as_dicts))
- # Track perfect scores for wandb
self.percent_correct_buffer.append(1.0 if reward == 1.0 else 0.0)
if len(scores_obj["tokens"]) >= self.config.group_size:
@@ -690,39 +977,43 @@ class PydanticSchemaFollowingEnv(BaseEnv):
if self.debug_logging:
self.logger.info(
- f"Scoring complete: {valid_count} valid, {invalid_count} invalid, {extraction_failures} extraction failures" # noqa: E501
+ f"Scoring complete for {problem_id} (Format: {selected_structured_format.value}, Container: {selected_container_format.value}): " # noqa: E501
+ f"{valid_count} valid (Pydantic), {invalid_count} invalid (Pydantic), "
+ f"{parsing_failures} parsing failures, {extraction_failures} extraction failures."
)
if scores_obj["scores"]:
avg_score = sum(scores_obj["scores"]) / len(scores_obj["scores"])
- self.logger.info(f"Average score for this batch: {avg_score:.3f}")
+ self.logger.info(
+ f"Average score for this batch ({problem_id}, Format: {selected_structured_format.value}, Container: {selected_container_format.value}): {avg_score:.3f}" # noqa: E501
+ )
if not scores_obj["tokens"]: # No valid examples processed
if self.debug_logging:
- self.logger.warning("No valid examples processed in this batch")
+ self.logger.warning(
+ f"No valid examples processed in this batch for {problem_id}"
+ )
return None
+ # This condition might need adjustment if 0.0 is a valid signal for non-Pydantic formats
if (
all(scores_obj["scores"][0] == score for score in scores_obj["scores"])
and scores_obj["scores"][0] != 1.0
):
if self.debug_logging:
self.logger.debug(
- "All scores are identical and not perfect, returning None for learning signal"
+ f"All scores are identical ({scores_obj['scores'][0]}) and not perfect for {problem_id}, returning None for learning signal." # noqa: E501
)
return None
- # Apply length penalty if average response length is too high and all scores are 1.0
if all(s == 1.0 for s in scores_obj["scores"]):
avg_len = sum(len(t) for t in scores_obj["tokens"]) / len(
scores_obj["tokens"]
)
- if (
- avg_len > self.config.max_token_length * 0.75
- ): # Penalize if too verbose even when correct
+ if avg_len > self.config.max_token_length * 0.75:
scores_obj["scores"] = [s * 0.9 for s in scores_obj["scores"]]
if self.debug_logging:
self.logger.debug(
- f"Applied length penalty: avg_len={avg_len}, penalty_threshold={self.config.max_token_length * 0.75}" # noqa: E501
+ f"Applied length penalty for {problem_id}: avg_len={avg_len}, penalty_threshold={self.config.max_token_length * 0.75}" # noqa: E501
)
return scores_obj
@@ -826,11 +1117,14 @@ class PydanticSchemaFollowingEnv(BaseEnv):
else []
)
score_for_rollout = scored_data["scores"][i]
-
# Extract the generated JSON from the assistant's response
if conversation_messages:
assistant_response = conversation_messages[-1].get("content", "")
- generated_json = self._extract_json_response(assistant_response)
+ generated_json = self._extract_structured_data_response(
+ assistant_response,
+ dataset_item["selected_container_format"],
+ dataset_item["selected_structured_format"],
+ )
else:
generated_json = None
@@ -928,16 +1222,29 @@ class PydanticSchemaFollowingEnv(BaseEnv):
print(error_msg)
async def rollout_and_score_eval(self, dataset_item: Dict[str, Any]) -> float:
- """Evaluate a single item from the test set."""
+ """Evaluate a single item from the test set with randomized formats."""
+ problem_id = dataset_item.get("problem_id", "N/A")
+
+ # Randomly select formats for evaluation consistency with training
+ selected_structured_format = random.choice(self.supported_structured_formats)
+ selected_container_format = random.choice(self.supported_container_formats)
+ # Store them in dataset_item if needed for logging or other parts, though not strictly for this function's direct logic # noqa: E501
+ dataset_item["selected_structured_format"] = selected_structured_format
+ dataset_item["selected_container_format"] = selected_container_format
+
if self.debug_logging:
self.logger.debug(
- f"Evaluating item: {dataset_item.get('problem_id', 'N/A')}"
+ f"Evaluating item: {problem_id}, structured: {selected_structured_format.value}, container: {selected_container_format.value}" # noqa: E501
)
user_content = dataset_item["prompt"]
+ current_system_prompt = self._generate_system_prompt(
+ selected_structured_format, selected_container_format
+ )
+
messages = [
- {"role": "system", "content": system_prompt},
+ {"role": "system", "content": current_system_prompt},
{"role": "user", "content": user_content},
]
@@ -946,28 +1253,39 @@ class PydanticSchemaFollowingEnv(BaseEnv):
)
if self.debug_logging:
- self.logger.debug(f"Eval prompt length: {len(prompt)} characters")
+ self.logger.debug(
+ f"Eval prompt length for {problem_id}: {len(prompt)} characters"
+ )
completion = await self.server.completion(
prompt=prompt,
n=1,
max_tokens=self.config.max_token_length,
temperature=0.1, # Lower temperature for eval
- split="eval",
+ split="eval", # Ensure correct server endpoint is used
)
model_response_text = completion.choices[0].text
if self.debug_logging:
self.logger.debug(
- f"Eval response length: {len(model_response_text)} characters"
+ f"Eval response length for {problem_id}: {len(model_response_text)} characters"
)
- json_str = self._extract_json_response(model_response_text)
+ extracted_str = self._extract_structured_data_response(
+ model_response_text, selected_container_format, selected_structured_format
+ )
score = 0.0
- # Extract verification info and create model
+ if not extracted_str:
+ if self.debug_logging:
+ self.logger.debug(
+ f"Eval extraction failed for {problem_id} (Format: {selected_structured_format.value}, Container: {selected_container_format.value})" # noqa: E501
+ )
+ return 0.0
+
+ # Attempt to parse and validate
try:
verification_info = dataset_item["verification_info"]
verification_data = json.loads(verification_info)
@@ -975,44 +1293,103 @@ class PydanticSchemaFollowingEnv(BaseEnv):
model_name = verification_data["model_name"]
if self.debug_logging:
- self.logger.debug(f"Eval target model: {model_name}")
+ self.logger.debug(
+ f"Eval target Pydantic model for {problem_id}: {model_name}"
+ )
target_model_cls = self._create_pydantic_model_from_code(
pydantic_config, model_name
)
- if json_str:
- # Validate JSON against the Pydantic model
- is_valid, error_msg = self._validate_json_against_model(
- json_str, target_model_cls, dataset_item.get("problem_id", "N/A")
- )
-
- if is_valid:
- score = 1.0 # Valid schema
+ parsed_data = None
+ if selected_structured_format == StructuredOutputFormat.JSON:
+ parsed_data = json.loads(extracted_str)
+ elif selected_structured_format == StructuredOutputFormat.YAML:
+ parsed_data = yaml.safe_load(extracted_str)
+ elif selected_structured_format == StructuredOutputFormat.TOML:
+ parsed_data = toml.loads(extracted_str)
+ elif selected_structured_format == StructuredOutputFormat.XML:
+ try:
+ data_dict_outer = xmltodict.parse(extracted_str)
+ # Assumption: Pydantic model corresponds to the content *within* the single root XML tag.
+ # xmltodict.parse returns a dict like {'RootTag': actual_data_dict}.
+ # We extract actual_data_dict for validation. More complex XML might require
+ # specific xmltodict process_instructions or more sophisticated unwrapping.
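+                # Illustrative (hypothetical): "<User><name>Ada</name></User>" parses
+                # to {"User": {"name": "Ada"}} and the inner dict is validated.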
+ if isinstance(data_dict_outer, dict) and len(data_dict_outer) == 1:
+ root_key = list(data_dict_outer.keys())[0]
+ parsed_data = data_dict_outer[root_key]
+ if self.debug_logging:
+ self.logger.debug(
+ f"Eval item (problem {problem_id}): XML parsed, root key '{root_key}', data: {str(parsed_data)[:100]}..." # noqa: E501
+ )
+ else:
+ raise ValueError(
+ f"XML from xmltodict.parse was not a dict with a single root key as expected. Got: {type(data_dict_outer)}, Keys: {list(data_dict_outer.keys()) if isinstance(data_dict_outer, dict) else 'N/A'}" # noqa: E501
+ )
+ except xmltodict.expat.ExpatError as e_xml_parse:
+ # No reward variable here, score is returned directly
if self.debug_logging:
self.logger.debug(
- f"Eval validation successful for {dataset_item.get('problem_id', 'N/A')}"
+ f"Eval item (problem {problem_id}): XML parsing failed (ExpatError): {e_xml_parse}. Extracted XML: {extracted_str[:100]}..." # noqa: E501
+ )
+ return 0.0 # Return score directly
+ except Exception as e_xml_generic:
+ if self.debug_logging:
+ self.logger.debug(
+ f"Eval item (problem {problem_id}): XML processing error: {e_xml_generic}. Extracted XML: {extracted_str[:100]}..." # noqa: E501
+ )
+ return 0.0 # Return score directly
+
+ if parsed_data is not None and selected_structured_format in [
+ StructuredOutputFormat.JSON,
+ StructuredOutputFormat.YAML,
+ StructuredOutputFormat.TOML,
+ StructuredOutputFormat.XML,
+ ]:
+ is_valid, pydantic_error_msg = self._validate_parsed_data_against_model(
+ parsed_data, target_model_cls, problem_id
+ )
+ if is_valid:
+ score = 1.0
+ if self.debug_logging:
+ self.logger.debug(
+ f"Eval Pydantic validation successful for {problem_id} (Format: {selected_structured_format.value})" # noqa: E501
)
else:
- score = 0.0 # Validation failed
+ score = 0.0 # Already 0.0 by default unless validation passes
if self.debug_logging:
self.logger.debug(
- f"Eval validation failed for {dataset_item.get('problem_id', 'N/A')}: {error_msg}"
+ f"Eval Pydantic validation failed for {problem_id} (Format: {selected_structured_format.value}): {pydantic_error_msg}" # noqa: E501
)
- else:
- score = 0.0 # No valid JSON extracted
- if self.debug_logging:
- self.logger.debug(
- f"Eval JSON extraction failed for {dataset_item.get('problem_id', 'N/A')}"
- )
- except Exception as e:
- score = 0.0 # Any error in model creation or setup
+ except json.JSONDecodeError as je:
+ if self.debug_logging:
+ self.logger.debug(
+ f"Eval JSON parsing failed for {problem_id} (Format: {selected_structured_format.value}): {je}. Extracted: {extracted_str[:100]}..." # noqa: E501
+ )
+ score = 0.0
+ except yaml.YAMLError as ye:
+ if self.debug_logging:
+ self.logger.debug(
+ f"Eval YAML parsing failed for {problem_id} (Format: {selected_structured_format.value}): {ye}. Extracted: {extracted_str[:100]}..." # noqa: E501
+ )
+ score = 0.0
+ except toml.TomlDecodeError as te:
+ if self.debug_logging:
+ self.logger.debug(
+ f"Eval TOML parsing failed for {problem_id} (Format: {selected_structured_format.value}): {te}. Extracted: {extracted_str[:100]}..." # noqa: E501
+ )
+ score = 0.0
+ except (
+ Exception
+ ) as e: # Catch errors in model creation, other parsing, or validation setup
+ score = (
+ 0.0 # Ensure score is 0.0 for any other unexpected error in this block
+ )
if self.debug_logging:
self.logger.error(
- f"Error in eval setup for {dataset_item.get('problem_id', 'N/A')}: {e}"
+ f"Error during eval scoring pipeline for {problem_id} (Format: {selected_structured_format.value}): {type(e).__name__}: {e}" # noqa: E501
)
-
return score
async def evaluate(self, *args, **kwargs):
@@ -1074,11 +1451,19 @@ class PydanticSchemaFollowingEnv(BaseEnv):
self.logger.debug("No scored data to log to wandb")
return
- dataset_item = item[1] if item else {}
+ # dataset_item contains selected_structured_format and selected_container_format
+ dataset_item = item[1] if item and len(item) > 1 else {}
+ problem_id = dataset_item.get("problem_id", "N/A")
+ selected_structured_format = dataset_item.get(
+ "selected_structured_format", StructuredOutputFormat.JSON
+ ) # Default if not found
+ selected_container_format = dataset_item.get(
+ "selected_container_format", OutputContainerFormat.TAGGED
+ ) # Default if not found
if self.debug_logging:
self.logger.debug(
- f"Logging rollouts for problem_id: {dataset_item.get('problem_id', 'N/A')}"
+ f"Logging rollouts for problem_id: {problem_id}, structured: {selected_structured_format.value}, container: {selected_container_format.value}" # noqa: E501
)
num_keep = self.config.num_rollouts_per_group_for_logging
@@ -1088,59 +1473,128 @@ class PydanticSchemaFollowingEnv(BaseEnv):
num_keep = min(num_keep, len(scored_data["tokens"]))
if self.debug_logging:
- self.logger.debug(f"Keeping {num_keep} rollouts for logging")
+ self.logger.debug(
+ f"Keeping {num_keep} rollouts for logging for {problem_id}"
+ )
rollout_batch = []
for i in range(num_keep):
- # Decode tokens to text for logging
- full_convo_text = self.tokenizer.decode(
- scored_data["tokens"][i], skip_special_tokens=True
+            # scored_data["messages"][i] should hold the list of message dicts for
+            # the i-th rollout; we want the raw assistant message for extraction
+            # rather than a re-decoded token string.
+
+ # Reconstruct full_convo_text from messages if possible, or use decoded tokens as fallback
+ raw_conversation_messages = scored_data["messages"][i]
+ assistant_response_text = ""
+ if (
+ isinstance(raw_conversation_messages, list)
+ and len(raw_conversation_messages) > 0
+ ):
+ # The last message should be the assistant's response
+ if (
+ isinstance(raw_conversation_messages[-1], dict)
+ and raw_conversation_messages[-1].get("role") == "assistant"
+ ):
+ assistant_response_text = raw_conversation_messages[-1].get(
+ "content", ""
+ )
+ else: # Try to find assistant message if not last, or if format is unexpected
+ for msg in reversed(raw_conversation_messages):
+ if isinstance(msg, dict) and msg.get("role") == "assistant":
+ assistant_response_text = msg.get("content", "")
+ break
+
+ if not assistant_response_text: # Fallback if proper message not found
+ # This decodes the entire conversation, including system/user prompts
+ assistant_response_text = self.tokenizer.decode(
+ scored_data["tokens"][i], skip_special_tokens=True
+ )
+ if self.debug_logging:
+ self.logger.warning(
+ f"WandB: Could not get raw assistant message for {problem_id}, using full decoded tokens for extraction." # noqa: E501
+ )
+
+ extracted_output = self._extract_structured_data_response(
+ assistant_response_text, # Use assistant_response_text (ideally raw, or decoded full convo)
+ selected_container_format,
+ selected_structured_format,
)
- extracted_json = self._extract_json_response(full_convo_text)
-
- # Extract model info from dataset item
try:
verification_info = dataset_item.get("verification_info", "{}")
verification_data = json.loads(verification_info)
- model_name = verification_data.get("model_name", "N/A")
- expected_json = verification_data.get("pydantic_config", "N/A")
+ pydantic_model_name = verification_data.get(
+ "model_name", "N/A"
+ ) # Renamed to avoid conflict
+ expected_schema_info = verification_data.get("pydantic_config", "N/A")
except (json.JSONDecodeError, KeyError):
- model_name = "N/A"
- expected_json = "N/A"
+ pydantic_model_name = "N/A"
+ expected_schema_info = "N/A"
if self.debug_logging:
- self.logger.debug("Could not extract model name for wandb logging")
+ self.logger.debug(
+ f"Could not extract model name/schema for wandb logging for {problem_id}"
+ )
+
+ # Construct the full conversation text for display from the messages list
+ # This ensures the system prompt (which might be dynamic) is correctly shown.
+ display_convo_text = ""
+ if isinstance(raw_conversation_messages, list):
+ try:
+ display_convo_text = self.tokenizer.apply_chat_template(
+ raw_conversation_messages,
+ tokenize=False,
+ add_generation_prompt=False,
+ )
+ except Exception as e_tmpl:
+ if self.debug_logging:
+ self.logger.warning(
+ f"WandB: Error applying chat template for {problem_id}: {e_tmpl}. Falling back to joining content." # noqa: E501
+ )
+ display_convo_text = "\n---\n".join(
+ [
+ str(msg.get("content", ""))
+ for msg in raw_conversation_messages
+ ]
+ )
+ else:
+ display_convo_text = (
+ assistant_response_text # Fallback to decoded full string
+ )
rollout_batch.append(
(
- full_convo_text, # Full conversation
+ display_convo_text, # Full conversation for display
scored_data["scores"][i],
- model_name,
- dataset_item.get("problem_id", "N/A"),
+ pydantic_model_name, # The Pydantic model name it was validated against
+ problem_id,
dataset_item.get("task_type", "N/A"),
+ selected_structured_format.value, # Log selected structured format
+ selected_container_format.value, # Log selected container format
(
- extracted_json
- if extracted_json
- else "Extraction failed or no JSON"
+ extracted_output
+ if extracted_output
+ else "Extraction failed or no output"
),
- (
- expected_json[:200] + "..."
- if len(expected_json) > 200
- else expected_json
- ), # Truncate for display
+ expected_schema_info[:200]
+ + ("..." if len(expected_schema_info) > 200 else ""),
)
)
if rollout_batch:
self.rollouts_for_wandb.append(rollout_batch)
if self.debug_logging:
- self.logger.debug(f"Added {len(rollout_batch)} rollouts to wandb queue")
+ self.logger.debug(
+ f"Added {len(rollout_batch)} rollouts to wandb queue for {problem_id}"
+ )
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
- removed = self.rollouts_for_wandb.pop(0)
- if self.debug_logging:
+ removed_count = 0
+ while len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
+ self.rollouts_for_wandb.pop(0)
+ removed_count += 1
+ if self.debug_logging and removed_count > 0:
self.logger.debug(
- f"Removed oldest rollout batch ({len(removed)} items) from wandb queue"
+ f"Removed {removed_count} oldest rollout batch(es) from wandb queue"
)
async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
@@ -1155,11 +1609,13 @@ class PydanticSchemaFollowingEnv(BaseEnv):
columns=[
"full_conversation",
"score",
- "model_name",
+ "pydantic_model_name",
"problem_id",
"task_type",
- "extracted_json",
- "expected_schema",
+ "selected_structured_format",
+ "selected_container_format",
+ "extracted_output",
+ "expected_schema_preview",
]
)
total_entries = 0
@@ -1258,17 +1714,17 @@ class PydanticSchemaFollowingEnv(BaseEnv):
if self.debug_logging:
self.logger.info("PydanticSchemaFollowingEnv closed.")
- def _validate_json_against_model(
- self, json_str: str, model_cls: Type[BaseModel], problem_id: str = "N/A"
+ def _validate_parsed_data_against_model(
+ self, parsed_data: Any, model_cls: Type[BaseModel], problem_id: str = "N/A"
) -> Tuple[bool, Optional[str]]:
"""
- Validate JSON string against a Pydantic model and return detailed error info.
+ Validate parsed data (e.g., a dictionary) against a Pydantic model and return detailed error info.
Returns:
Tuple of (is_valid, error_message)
"""
try:
- model_cls.model_validate_json(json_str)
+ model_cls.model_validate(parsed_data)
return True, None
except ValidationError as ve:
try:
@@ -1285,11 +1741,6 @@ class PydanticSchemaFollowingEnv(BaseEnv):
# Fallback if formatting ve.errors() fails for some reason
error_msg = f"Pydantic validation failed for {problem_id}. Error: {str(ve)[:250]}. (Additionally, formatting error details failed: {str(format_exc)[:50]})" # noqa: E501
- if self.debug_logging:
- self.logger.debug(error_msg)
- return False, error_msg
- except json.JSONDecodeError as je:
- error_msg = f"JSON decode failed for {problem_id}: {str(je)[:100]}"
if self.debug_logging:
self.logger.debug(error_msg)
return False, error_msg
@@ -1304,8 +1755,6 @@ class PydanticSchemaFollowingEnv(BaseEnv):
self.logger.error(
error_msg
) # Log as error due to its specific nature
- # Optionally log the problematic JSON for further inspection if not too large
- # self.logger.debug(f"Problematic JSON string for {problem_id} (first 500 chars): {json_str[:500]}")
else:
# Handle other TypeErrors normally
error_msg = f"Unexpected TypeError during validation for {problem_id}: {type(te).__name__}: {str(te)[:100]}" # noqa: E501