diff --git a/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py b/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py
index 3eb323a1..e9874ca4 100644
--- a/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py
+++ b/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py
@@ -44,7 +44,10 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 from uuid import UUID
 
+import toml  # Added import
 import wandb
+import xmltodict  # Added import
+import yaml  # Added import
 from datasets import load_dataset
 from pydantic import (
     BaseModel,
@@ -69,20 +72,33 @@ from atroposlib.envs.base import (
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 # System prompt for the LLM
-system_prompt = (
-    "You are an AI assistant that generates JSON objects according to Pydantic schemas.\n"
-    "You may use extremely long chains of thought to deeply consider the problem and deliberate "
-    "with yourself via systematic reasoning processes to help come to a correct solution prior to answering. "
-    "You should enclose your thoughts and internal monologue inside <think> tags.\n\n"
-    "CRITICAL: Your final JSON output MUST be enclosed within <json_output> tags.\n"
-    "The JSON must be valid and complete. Do not include any text after the closing tag.\n"
-    "Example format:\n"
-    "<think>\nMy reasoning here...\n</think>\n\n"
-    '<json_output>\n{"field1": "value1", "field2": "value2"}\n</json_output>\n\n'
-    "Ensure the generated JSON strictly adheres to the Pydantic model schema and any specific field "
-    "requirements provided in the user prompt. Generate all required fields for the model, and "
-    "include optional fields if they make sense in the context or are specified."
-)
+
+
+class StructuredOutputFormat(Enum):
+    JSON = "json"
+    YAML = "yaml"
+    TOML = "toml"
+    XML = "xml"
+
+
+class OutputContainerFormat(Enum):
+    TAGGED = "tagged"  # e.g., <json_output>...</json_output>
+    NONE = "none"  # Raw output
+    MARKDOWN = "markdown"  # e.g., ```json ... ```
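+
+# Each StructuredOutputFormat is paired with a parser further down in
+# score() / rollout_and_score_eval(). A minimal sketch of that pairing
+# (the PARSERS name is illustrative, not part of this module's API):
+#
+#     PARSERS = {
+#         StructuredOutputFormat.JSON: json.loads,
+#         StructuredOutputFormat.YAML: yaml.safe_load,
+#         StructuredOutputFormat.TOML: toml.loads,
+#         StructuredOutputFormat.XML: xmltodict.parse,
+#     }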
 
 
 class PydanticEnvConfig(BaseEnvConfig):
@@ -106,6 +122,20 @@
         default=True,
         description="Whether to include messages in the dataset for SFT data generation",
     )
+    allowed_structured_formats: Optional[List[StructuredOutputFormat]] = Field(
+        default=None,
+        description="Optional list of StructuredOutputFormat enums to use for randomization. If None or empty, all supported formats are used.",  # noqa: E501
+    )
+    allowed_container_formats: Optional[List[OutputContainerFormat]] = Field(
+        default=None,
+        description="Optional list of OutputContainerFormat enums to use for randomization. If None or empty, all supported formats are used.",  # noqa: E501
+    )
+    eval_set_percentage: float = Field(
+        default=0.1,
+        description="Fraction of the dataset to use for the evaluation set (e.g., 0.1 for 10%).",
+        ge=0.0,
+        le=1.0,
+    )
 
 
 class PydanticSchemaFollowingEnv(BaseEnv):
@@ -119,15 +149,8 @@
     def __init__(
         self,
         config: PydanticEnvConfig,
         server_configs: List[APIServerConfig],
         slurm=True,
         testing=False,
     ):
         super().__init__(config, server_configs, slurm, testing)
-        self.percent_correct_buffer = list()  # Tracks 1.0 scores
-        self.eval_metrics = list()
-        self.rollouts_for_wandb = []
-        self.dataset_items: List[Dict[str, Any]] = []
-        self.model_cache: Dict[str, Type[BaseModel]] = (
-            {}
-        )  # Cache for dynamically created models
 
-        # Set up debug logging
+        # Set up debug logging FIRST, as it's used by subsequent setup logic
         self.debug_logging = getattr(config, "debug_logging", True)
         if self.debug_logging:
             self.logger = logging.getLogger(f"{self.__class__.__name__}")
@@ -144,6 +167,56 @@
             self.logger = logging.getLogger(f"{self.__class__.__name__}")
             self.logger.addHandler(logging.NullHandler())
 
+        self.percent_correct_buffer = list()  # Tracks 1.0 scores
+        self.eval_metrics = list()
+        self.rollouts_for_wandb = []
+        self.dataset_items: List[Dict[str, Any]] = []
+        self.model_cache: Dict[str, Type[BaseModel]] = (
+            {}
+        )  # Cache for dynamically created models
+
+        # Determine supported formats based on config or defaults
+        if (
+            config.allowed_structured_formats
+            and len(config.allowed_structured_formats) > 0
+        ):
+            self.supported_structured_formats = config.allowed_structured_formats
+            if self.debug_logging:
+                self.logger.info(
+                    f"Using configured structured formats: {[f.value for f in self.supported_structured_formats]}"
+                )
+        else:
+            self.supported_structured_formats: List[StructuredOutputFormat] = [
+                StructuredOutputFormat.JSON,
+                StructuredOutputFormat.YAML,
+                StructuredOutputFormat.TOML,
+                StructuredOutputFormat.XML,
+            ]
+            if self.debug_logging:
+                self.logger.info(
+                    f"Using default structured formats: {[f.value for f in self.supported_structured_formats]}"
+                )
+
+        if (
+            config.allowed_container_formats
+            and len(config.allowed_container_formats) > 0
+        ):
+            self.supported_container_formats = config.allowed_container_formats
+            if self.debug_logging:
+                self.logger.info(
+                    f"Using configured container formats: {[f.value for f in self.supported_container_formats]}"
+                )
+        else:
+            self.supported_container_formats: List[OutputContainerFormat] = [
+                OutputContainerFormat.TAGGED,
+                OutputContainerFormat.NONE,
+                OutputContainerFormat.MARKDOWN,
+            ]
+            if self.debug_logging:
+                self.logger.info(
+                    f"Using default container formats: {[f.value for f in self.supported_container_formats]}"
+                )
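+
+        # For example, allowed_structured_formats=[StructuredOutputFormat.JSON]
+        # with allowed_container_formats=[OutputContainerFormat.MARKDOWN] makes
+        # every episode request a ```json fenced block, while leaving both
+        # fields as None samples uniformly over all format/container pairs.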
schemas.", # noqa: E501 + "You may use extremely long chains of thought to deeply consider the problem and deliberate " + "with yourself via systematic reasoning processes to help come to a correct solution prior to answering.", + "You should enclose your thoughts and internal monologue inside tags.", + ] + + example_output = "" + if structured_format == StructuredOutputFormat.JSON: + example_output = '{\\n "field1": "value1",\\n "field2": "value2"\\n}' + elif structured_format == StructuredOutputFormat.YAML: + example_output = "field1: value1\\nfield2: value2" + elif structured_format == StructuredOutputFormat.TOML: + example_output = 'field1 = "value1"\\nfield2 = "value2"' + elif structured_format == StructuredOutputFormat.XML: + example_output = "\\n value1\\n value2\\n" # Assuming root tag is model name # noqa: E501 + + if container_format == OutputContainerFormat.TAGGED: + tag_name = f"{structured_format.value}_output" + prompt_lines.extend( + [ + f"CRITICAL: Your final {structured_format.value.upper()} output MUST be enclosed within <{tag_name}> tags.", # noqa: E501 + f"The {structured_format.value.upper()} must be valid and complete. Do not include any text after the closing tag.", # noqa: E501 + "Example format:", + "\\nMy reasoning here...\\n\\n", + f"<{tag_name}>\\n{example_output}\\n", + ] + ) + elif container_format == OutputContainerFormat.MARKDOWN: + prompt_lines.extend( + [ + f"CRITICAL: Your final {structured_format.value.upper()} output MUST be enclosed within a markdown code block (```).", # noqa: E501 + f"The {structured_format.value.upper()} must be valid and complete.", + "Example format:", + "\\nMy reasoning here...\\n\\n", + f"```{structured_format.value}\\n{example_output}\\n```", + ] + ) + elif container_format == OutputContainerFormat.NONE: + prompt_lines.extend( + [ + f"CRITICAL: Your final {structured_format.value.upper()} output should be provided directly after the closing tag, with no surrounding tags or markdown.", # noqa: E501 + f"The {structured_format.value.upper()} must be valid and complete.", + "Example format:", + "\\nMy reasoning here...\\n\\n", + example_output, + ] + ) + + prompt_lines.extend( + [ + f"Ensure the generated {structured_format.value.upper()} strictly adheres to the Pydantic model schema and any specific field " # noqa: E501 + "requirements provided in the user prompt. Generate all required fields for the model, and " # noqa: E501 + "include optional fields if they make sense in the context or are specified." 
+
     @classmethod
     def config_init(cls) -> Tuple[PydanticEnvConfig, List[APIServerConfig]]:
         """Initialize configuration for the environment."""
@@ -175,9 +312,10 @@
             group_size=16,
             use_wandb=True,
             rollout_server_url="http://localhost:8000",
-            total_steps=2000,
+            total_steps=250,
             batch_size=1024,
             steps_per_eval=20,
+            max_num_workers=16,
             max_token_length=1024 * 12,
             inference_weight=1.0,
             wandb_name="pydantic_schema_following",
@@ -185,9 +323,19 @@
             eval_limit_ratio=0.1,
             dataset_name="justus27/pydantic-adherance-test",
             dataset_split="train",
-            debug_logging=True,  # Enable debug logging by default
-            dump_rollouts=True,  # Enable data dumping by default
-            include_messages=True,  # Ensure messages are included for SFT data generation
+            debug_logging=False,
+            dump_rollouts=False,
+            allowed_structured_formats=[
+                StructuredOutputFormat.JSON,
+                StructuredOutputFormat.YAML,
+                StructuredOutputFormat.TOML,
+            ],
+            allowed_container_formats=[
+                OutputContainerFormat.TAGGED,
+                OutputContainerFormat.NONE,
+                OutputContainerFormat.MARKDOWN,
+            ],
+            eval_set_percentage=0.005,
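+            # e.g. 0.005 reserves 0.5% of items for evaluation: a 10,000-item
+            # dataset keeps 9,950 items for training and 50 for eval.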
         )
         server_configs = [
             APIServerConfig(
@@ -339,7 +487,9 @@
             self.logger.debug("Dataset shuffled")
 
         # Split into train and test
-        split_idx = int(len(self.dataset_items) * 0.90)  # 90% train, 10% test
+        split_idx = int(
+            len(self.dataset_items) * (1.0 - self.config.eval_set_percentage)
+        )
         self.train_items = self.dataset_items[:split_idx]
         self.test_items = self.dataset_items[split_idx:]
@@ -401,19 +551,47 @@
         if self.debug_logging:
             self.logger.debug(f"Prompt length: {len(user_content)} characters")
 
+        # Randomly select structured and container formats
+        selected_structured_format = random.choice(self.supported_structured_formats)
+        selected_container_format = random.choice(self.supported_container_formats)
+
+        # Store the selections in the dataset_item
+        dataset_item["selected_structured_format"] = selected_structured_format
+        dataset_item["selected_container_format"] = selected_container_format
+
+        if self.debug_logging:
+            self.logger.debug(
+                f"Selected structured format: {selected_structured_format.value}"
+            )
+            self.logger.debug(
+                f"Selected container format: {selected_container_format.value}"
+            )
+
+        # Generate the system prompt
+        current_system_prompt = self._generate_system_prompt(
+            selected_structured_format, selected_container_format
+        )
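+        # The format selections stored on dataset_item above travel with the
+        # rollout, so score() and the wandb logging path judge the response
+        # against the same structured/container pair that was prompted here.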
+
         # Create the message structure
         prompt_messages = [
-            frozenset({"role": "system", "content": system_prompt}.items()),
+            frozenset({"role": "system", "content": current_system_prompt}.items()),
             frozenset({"role": "user", "content": user_content}.items()),
         ]
 
         # Return the prompt and the full dataset item for scoring
         return tuple(prompt_messages), dataset_item
 
-    def _extract_json_response(self, text: str) -> Optional[str]:
-        """Extracts JSON content from <json_output> tags, with strict thinking tag validation."""
+    def _extract_structured_data_response(
+        self,
+        text: str,
+        container_format: OutputContainerFormat,
+        structured_format: StructuredOutputFormat,
+    ) -> Optional[str]:
+        """Extracts structured data content based on container and structured format, with strict thinking tag validation."""  # noqa: E501
         if self.debug_logging:
-            self.logger.debug(f"Extracting JSON from response (length: {len(text)})")
+            self.logger.debug(
+                f"Extracting {structured_format.value} from response (length: {len(text)}), container: {container_format.value}"  # noqa: E501
+            )
 
         # Ensure text is a string
         if not isinstance(text, str):
@@ -423,103 +601,132 @@
             )
             text = str(text)
 
-        # First, validate thinking tags (similar to MCQA environment)
+        # 1. Validate thinking tags
         think_tags = re.findall(r"<think>", text, re.IGNORECASE)
         think_close_tags = re.findall(r"</think>", text, re.IGNORECASE)
 
-        # Check for proper thinking tag structure
         if len(think_tags) != 1 or len(think_close_tags) != 1:
             if self.debug_logging:
                 self.logger.warning(
-                    f"Invalid thinking tag structure: {len(think_tags)} open tags, {len(think_close_tags)} close tags"
+                    f"Invalid thinking tag structure: {len(think_tags)} open tags, {len(think_close_tags)} close tags. Full text: {text[:500]}"  # noqa: E501
                 )
             return None
 
         # Split the text into thinking and response sections
-        parts = re.split(r"</think>", text, flags=re.IGNORECASE, maxsplit=1)
-        if len(parts) != 2:
+        # Use a regex that captures the content before, between, and after think tags robustly
+        match_think_block = re.match(
+            r"(.*?)(<think>.*?</think>)(.*)", text, re.DOTALL | re.IGNORECASE
+        )
+        if not match_think_block:
             if self.debug_logging:
                 self.logger.warning(
-                    "Could not split text into thinking and response sections"
+                    f"Could not find a complete <think>...</think> block. Full text: {text[:500]}"
                 )
             return None
 
-        thinking_section, response_section = parts
+        # thinking_section_plus_prefix = match_think_block.group(1)  # Content before <think>
+        # thinking_block_content = match_think_block.group(2)  # <think>...</think>
+        response_section = match_think_block.group(3)  # Content after </think>
 
-        # Validate thinking section contains opening tag
-        if "<think>" not in thinking_section.lower():
+        # Validate that the first <think> tag is indeed the one we captured (no <think> before it)
+        if "<think>" in match_think_block.group(1).lower():
             if self.debug_logging:
-                self.logger.warning("Thinking section missing opening <think> tag")
+                self.logger.warning(
+                    f"Nested or malformed <think> tags detected before main block. Full text: {text[:500]}"
+                )
             return None
 
         # Check if there are any thinking tags in the response section (after </think>)
         if "<think>" in response_section.lower():
             if self.debug_logging:
                 self.logger.warning(
-                    "Found <think> tags in response section after </think>"
+                    f"Found <think> tags in response section after </think>. Full text: {text[:500]}"
                )
             return None
 
-        # Now extract JSON from the response section only
-        match = re.search(
-            r"<json_output>\s*(.*?)\s*</json_output>",
-            response_section,
-            re.DOTALL | re.IGNORECASE,
-        )
-        if match:
-            json_str = match.group(1).strip()
-            if self.debug_logging:
-                self.logger.debug(
-                    f"Found JSON output tags in response section, extracted {len(json_str)} characters"
-                )
+        extracted_content: Optional[str] = None
 
-            # Handle empty extraction
-            if not json_str:
-                if self.debug_logging:
-                    self.logger.warning("JSON output tags found but content is empty")
-                return None
-
-            # Validate JSON
-            try:
-                json.loads(json_str)
-                if self.debug_logging:
-                    self.logger.debug("Extracted JSON is valid")
-                return json_str
-            except json.JSONDecodeError as e:
-                if self.debug_logging:
-                    self.logger.warning(f"Extracted text is not valid JSON: {e}")
-                return None
-
-        # Fallback: Look for JSON in response section only (no thinking tag validation)
-        if self.debug_logging:
-            self.logger.debug(
-                "No <json_output> tags found in response section, trying fallback extraction"
-            )
-
-        # Look for JSON objects that start with { and end with } in response section only
-        json_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}"
-        matches = re.findall(json_pattern, response_section, re.DOTALL)
-
-        for potential_json in matches:
-            try:
-                json.loads(potential_json.strip())
-                if self.debug_logging:
-                    self.logger.debug(
-                        f"Fallback extraction successful from response section: {len(potential_json)} characters"
-                    )
-                return potential_json.strip()
-            except json.JSONDecodeError:
-                continue
-
-        if self.debug_logging:
-            self.logger.warning("No valid JSON found with any extraction method")
-        return None
+        # 2. Extract based on container_format from response_section
+        if container_format == OutputContainerFormat.TAGGED:
+            tag_name = f"{structured_format.value}_output"
+            # Loosen regex to allow for attributes in the opening tag if any, and handle varying whitespace
+            pattern = rf"<{tag_name}[^>]*>\s*(.*?)\s*</{tag_name}>"
+            match = re.search(pattern, response_section, re.DOTALL | re.IGNORECASE)
+            if match:
+                extracted_content = match.group(1)
+                if self.debug_logging:
+                    self.logger.debug(
+                        f"Extracted using TAGGED ({tag_name}): {len(extracted_content)} chars"
+                    )
+            else:
+                if self.debug_logging:
+                    self.logger.warning(
+                        f"TAGGED format: Could not find <{tag_name}>...</{tag_name}> tags in response section. Response section: {response_section[:300]}"  # noqa: E501
+                    )
+
+        elif container_format == OutputContainerFormat.MARKDOWN:
+            # Pattern to match ```language ... ``` or just ``` ... ```.
+            # It captures the content within the backticks and optionally
+            # matches the language specifier.
+            pattern = rf"^\s*```(?:{re.escape(structured_format.value)})?\s*\n(.*?)\n\s*```\s*$"
+            match = re.search(
+                pattern, response_section.strip(), re.DOTALL | re.IGNORECASE
+            )
+            if match:
+                extracted_content = match.group(1)
+                if self.debug_logging:
+                    self.logger.debug(
+                        f"Extracted using MARKDOWN: {len(extracted_content)} chars"
+                    )
+            else:
+                # Fallback for markdown: if a ```lang ... ``` block is not found, try a bare ``` ... ``` block
+                pattern_no_lang = r"^\s*```\s*\n(.*?)\n\s*```\s*$"
+                match_no_lang = re.search(
+                    pattern_no_lang, response_section.strip(), re.DOTALL | re.IGNORECASE
+                )
+                if match_no_lang:
+                    extracted_content = match_no_lang.group(1)
+                    if self.debug_logging:
+                        self.logger.debug(
+                            f"Extracted using MARKDOWN (no lang specified): {len(extracted_content)} chars"
+                        )
+                else:
+                    if self.debug_logging:
+                        self.logger.warning(
+                            f"MARKDOWN format: Could not find ```...``` code block in response section. Response section: {response_section[:300]}"  # noqa: E501
+                        )
+
+        elif container_format == OutputContainerFormat.NONE:
+            extracted_content = response_section
+            if self.debug_logging:
+                self.logger.debug(
+                    f"Extracted using NONE: {len(extracted_content)} chars"
+                )
+
+        # 3. Post-processing
+        if extracted_content is not None:
+            extracted_content = extracted_content.strip()
+            if not extracted_content:  # Empty after stripping
+                if self.debug_logging:
+                    self.logger.warning("Extracted content is empty after stripping.")
+                return None
+            if self.debug_logging:
+                self.logger.debug(
+                    f"Successfully extracted content ({len(extracted_content)} chars). First 100: {extracted_content[:100]}"  # noqa: E501
+                )
+            return extracted_content
+        else:
+            if self.debug_logging:
+                self.logger.warning(
+                    f"Extraction failed for container type {container_format.value}. No content extracted."
+                )
+            return None
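+
+    # Example behaviour, assuming a well-formed TAGGED + JSON response:
+    #   _extract_structured_data_response(
+    #       '<think>plan</think>\n<json_output>{"a": 1}</json_output>',
+    #       OutputContainerFormat.TAGGED,
+    #       StructuredOutputFormat.JSON,
+    #   )  ->  '{"a": 1}'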
 
     async def score(
         self,
         rollout_group_data: List[Tuple[Tuple[Dict[str, str], ...], Dict[str, Any]]],
     ) -> Optional[ScoredDataGroup]:
-        """Score the rollouts based on Pydantic validation."""
+        """Score the rollouts based on Pydantic validation or other structural checks."""
         if self.debug_logging:
             self.logger.debug(f"Scoring {len(rollout_group_data)} rollouts")
@@ -529,124 +736,202 @@
         scores_obj["scores"] = list()
         scores_obj["messages"] = list()  # Add messages for data dumping
 
-        # All items in rollout_group_data share the same dataset_item (item[1])
         if not rollout_group_data:
             if self.debug_logging:
                 self.logger.warning("No rollout data to score")
             return None
 
         dataset_item = rollout_group_data[0][1]
+        problem_id = dataset_item.get("problem_id", "N/A")
+        selected_structured_format = dataset_item["selected_structured_format"]
+        selected_container_format = dataset_item["selected_container_format"]
 
         if self.debug_logging:
             self.logger.debug(
-                f"Scoring for problem_id: {dataset_item.get('problem_id', 'N/A')}"
+                f"Scoring for problem_id: {problem_id}, structured_format: {selected_structured_format.value}, container_format: {selected_container_format.value}"  # noqa: E501
             )
 
-        # Extract verification info (pydantic config) and model name
         verification_info = dataset_item["verification_info"]
-
-        # Parse the verification info to get pydantic config and model name
         try:
             verification_data = json.loads(verification_info)
             pydantic_config = verification_data["pydantic_config"]
             model_name = verification_data["model_name"]
-
             if self.debug_logging:
-                self.logger.debug(f"Target model for scoring: {model_name}")
-                self.logger.debug(
-                    f"Pydantic config length: {len(pydantic_config)} characters"
-                )
+                self.logger.debug(f"Target Pydantic model for validation: {model_name}")
         except (json.JSONDecodeError, KeyError) as e:
-            error_msg = f"Error parsing verification_info: {e}"
+            error_msg = f"Error parsing verification_info for {problem_id}: {e}"
             if self.debug_logging:
                 self.logger.error(error_msg)
             print(error_msg)
             return None
 
-        # Create the Pydantic model dynamically
         try:
             target_model_cls = self._create_pydantic_model_from_code(
                 pydantic_config, model_name
             )
         except Exception as e:
-            error_msg = f"Error creating Pydantic model: {e}"
+            error_msg = (
+                f"Error creating Pydantic model {model_name} for {problem_id}: {e}"
+            )
             if self.debug_logging:
                 self.logger.error(error_msg)
             print(error_msg)
             return None
 
-        # Score each rollout
         valid_count = 0
         invalid_count = 0
         extraction_failures = 0
+        parsing_failures = 0
 
-        # Shuffle to avoid bias in selection
         random.shuffle(rollout_group_data)
         for i, (item_messages, _) in enumerate(rollout_group_data):
-            if self.debug_logging and i == 0:
-                self.logger.debug(f"Scoring rollout {i+1}/{len(rollout_group_data)}")
-
             # Convert frozensets to dictionaries for easier access
             messages_as_dicts = [dict(fs_message) for fs_message in item_messages]
-            model_response_text = messages_as_dicts[-1]["content"]  # LLM full response
+            model_response_text = messages_as_dicts[-1]["content"]
 
-            if self.debug_logging and i == 0:
-                self.logger.debug(
-                    f"Response length: {len(model_response_text)} characters"
-                )
+            extracted_str = self._extract_structured_data_response(
+                model_response_text,
+                selected_container_format,
+                selected_structured_format,
+            )
 
-            json_str = self._extract_json_response(model_response_text)
+            reward = 0.0
+            parsed_data = None
+            validation_error_msg = "Extraction failed"
 
-            reward = 0.0  # Default score
+            if extracted_str:
+                try:
+                    if selected_structured_format == StructuredOutputFormat.JSON:
+                        parsed_data = json.loads(extracted_str)
+                    elif selected_structured_format == StructuredOutputFormat.YAML:
+                        parsed_data = yaml.safe_load(extracted_str)
+                    elif selected_structured_format == StructuredOutputFormat.TOML:
+                        parsed_data = toml.loads(extracted_str)
+                    elif selected_structured_format == StructuredOutputFormat.XML:
+                        try:
+                            data_dict_outer = xmltodict.parse(extracted_str)
+                            # Assumption: Pydantic model corresponds to the content *within* the single root XML tag.
+                            # xmltodict.parse returns a dict like {'RootTag': actual_data_dict}.
+                            # We extract actual_data_dict for validation. More complex XML might require
+                            # specific xmltodict process_instructions or more sophisticated unwrapping.
+                            if (
+                                isinstance(data_dict_outer, dict)
+                                and len(data_dict_outer) == 1
+                            ):
+                                root_key = list(data_dict_outer.keys())[0]
+                                parsed_data = data_dict_outer[root_key]
+                                if self.debug_logging:
+                                    self.logger.debug(
+                                        f"Rollout {i} (problem {problem_id}): XML parsed, root key '{root_key}', data: {str(parsed_data)[:100]}..."  # noqa: E501
+                                    )
+                            else:
+                                raise ValueError(
+                                    f"XML from xmltodict.parse was not a dict with a single root key as expected. Got: {type(data_dict_outer)}, Keys: {list(data_dict_outer.keys()) if isinstance(data_dict_outer, dict) else 'N/A'}"  # noqa: E501
+                                )
+                        except xmltodict.expat.ExpatError as e_xml_parse:
+                            # Do not return here: leave parsed_data as None so this
+                            # rollout counts as a parsing failure and the loop keeps
+                            # scoring the rest of the group.
+                            parsing_failures += 1
+                            validation_error_msg = f"XML parsing failed: {e_xml_parse}"
+                            if self.debug_logging:
+                                self.logger.debug(
+                                    f"Rollout {i} (problem {problem_id}): XML parsing failed (ExpatError): {e_xml_parse}. Extracted XML: {extracted_str[:100]}..."  # noqa: E501
+                                )
+                        except Exception as e_xml_generic:
+                            parsing_failures += 1
+                            validation_error_msg = (
+                                f"XML processing error: {e_xml_generic}"
+                            )
+                            if self.debug_logging:
+                                self.logger.debug(
+                                    f"Rollout {i} (problem {problem_id}): XML processing error: {e_xml_generic}. Extracted XML: {extracted_str[:100]}..."  # noqa: E501
+                                )
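+
+                    # All four formats are funneled into plain Python objects so a
+                    # single Pydantic path can validate them below via
+                    # model_cls.model_validate(...), rather than
+                    # model_validate_json(), which only accepts JSON strings.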
+                    if parsed_data is not None and selected_structured_format in [
+                        StructuredOutputFormat.JSON,
+                        StructuredOutputFormat.YAML,
+                        StructuredOutputFormat.TOML,
+                        StructuredOutputFormat.XML,
+                    ]:
+                        is_valid, pydantic_error_msg = (
+                            self._validate_parsed_data_against_model(
+                                parsed_data, target_model_cls, problem_id
+                            )
+                        )
+                        if is_valid:
+                            reward = 1.0
+                            valid_count += 1
+                            validation_error_msg = None
+                            if self.debug_logging and i < 3:
+                                self.logger.debug(
+                                    f"Rollout {i}: Pydantic validation successful for {selected_structured_format.value}"  # noqa: E501
+                                )
+                        else:
+                            reward = 0.0
+                            invalid_count += 1
+                            validation_error_msg = pydantic_error_msg
+                            if self.debug_logging and i < 3:
+                                self.logger.debug(
+                                    f"Rollout {i}: Pydantic validation failed for {selected_structured_format.value}. Error: {pydantic_error_msg}"  # noqa: E501
+                                )
 
-            if json_str:
-                # Validate JSON against the Pydantic model
-                is_valid, error_msg = self._validate_json_against_model(
-                    json_str, target_model_cls, dataset_item.get("problem_id", "N/A")
-                )
-
-                if is_valid:
-                    reward = 1.0  # Valid schema - full score
-                    valid_count += 1
-                    if self.debug_logging and i < 3:  # Log first few successes
-                        self.logger.debug(f"Rollout {i}: Validation successful")
-                else:
-                    reward = 0.0  # Validation failed
-                    invalid_count += 1
-                    if self.debug_logging and i < 3:  # Log first few validation errors
-                        self.logger.debug(f"Rollout {i}: {error_msg}")
+                except json.JSONDecodeError as e_json:
+                    reward = 0.0
+                    parsing_failures += 1
+                    validation_error_msg = f"JSON parsing failed: {e_json}"
+                    if self.debug_logging and i < 3:
+                        self.logger.debug(
+                            f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
+                        )
+                except yaml.YAMLError as e_yaml:
+                    reward = 0.0
+                    parsing_failures += 1
+                    validation_error_msg = f"YAML parsing failed: {e_yaml}"
+                    if self.debug_logging and i < 3:
+                        self.logger.debug(
+                            f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
+                        )
+                except toml.TomlDecodeError as e_toml:
+                    reward = 0.0
+                    parsing_failures += 1
+                    validation_error_msg = f"TOML parsing failed: {e_toml}"
+                    if self.debug_logging and i < 3:
+                        self.logger.debug(
+                            f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
+                        )
+                except Exception as e_parse:  # Catch any other parsing related errors
+                    reward = 0.0
+                    parsing_failures += 1
+                    validation_error_msg = f"Generic parsing failed for {selected_structured_format.value}: {e_parse}"  # noqa: E501
+                    if self.debug_logging and i < 3:
+                        self.logger.debug(
+                            f"Rollout {i}: {validation_error_msg}. Extracted: {extracted_str[:100]}..."
+                        )
             else:
-                reward = 0.0  # No JSON output found or extraction failed
+                reward = 0.0  # No structured data output found or extraction failed
                 extraction_failures += 1
-                if self.debug_logging and i < 3:  # Log first few failures
-                    self.logger.debug(f"Rollout {i}: JSON extraction failed")
+                # validation_error_msg is already "Extraction failed"
+                if self.debug_logging and i < 3:
+                    self.logger.debug(
+                        f"Rollout {i}: Extraction failed for {selected_structured_format.value} with container {selected_container_format.value}"  # noqa: E501
+                    )
 
-            # Tokenize for training - convert frozensets to dicts for tokenizer
             try:
-                # Validate that messages_as_dicts is properly formatted
                 if not isinstance(messages_as_dicts, list):
                     if self.debug_logging:
                         self.logger.error(
                             f"Expected list for tokenization, got {type(messages_as_dicts)}"
                         )
                     continue
-
-                # Validate each message has required keys
                 for msg_idx, msg in enumerate(messages_as_dicts):
                     if not isinstance(msg, dict):
                         if self.debug_logging:
                             self.logger.error(
                                 f"Message {msg_idx} is not a dict: {type(msg)}"
                            )
-                        continue
+                        continue  # Skip this malformed message
                     if "role" not in msg or "content" not in msg:
                         if self.debug_logging:
                             self.logger.error(
                                 f"Message {msg_idx} missing required keys: {msg.keys()}"
                             )
-                        continue
-                    # Ensure content is a string
+                        continue  # Skip this malformed message
                     if not isinstance(msg["content"], str):
                         if self.debug_logging:
                             self.logger.warning(
@@ -655,34 +940,36 @@
                             msg["content"] = str(msg["content"])
 
                 out_dict = tokenize_for_trainer(
-                    self.tokenizer, messages_as_dicts, include_messages=True
+                    self.tokenizer,
+                    messages_as_dicts,
+                    include_messages=self.config.include_messages,  # Use the config value
                )
                 tokens = out_dict["tokens"]
                 masks = out_dict["masks"]
             except Exception as e:
                 if self.debug_logging:
-                    self.logger.error(f"Tokenization failed for rollout {i}: {e}")
+                    self.logger.error(
+                        f"Tokenization failed for rollout {i} (problem: {problem_id}): {e}"
+                    )
                     self.logger.debug(
                         f"Messages format: {[type(m) for m in messages_as_dicts]}"
                    )
                 continue
 
+            # Count trainable tokens; mask values of -100 are ignored by the loss
            if len([1 for m_val in masks if m_val != -100]) < 10:  # Min context length
                 if self.debug_logging:
                     self.logger.debug(
-                        "Skipping rollout due to insufficient context length"
+                        f"Skipping rollout {i} (problem: {problem_id}) due to insufficient context length after tokenization."  # noqa: E501
                     )
                 continue
 
             scores_obj["tokens"].append(tokens)
             scores_obj["masks"].append(masks)
             scores_obj["scores"].append(reward)
-            scores_obj["messages"].append(
-                out_dict.get("messages", messages_as_dicts)
-            )  # Store converted messages for dumping
+            # Store original messages (converted to dicts) if available in out_dict, else the modified ones
+            scores_obj["messages"].append(out_dict.get("messages", messages_as_dicts))
 
-            # Track perfect scores for wandb
             self.percent_correct_buffer.append(1.0 if reward == 1.0 else 0.0)
 
             if len(scores_obj["tokens"]) >= self.config.group_size:
                 break
 
         if self.debug_logging:
             self.logger.info(
-                f"Scoring complete: {valid_count} valid, {invalid_count} invalid, {extraction_failures} extraction failures"  # noqa: E501
+                f"Scoring complete for {problem_id} (Format: {selected_structured_format.value}, Container: {selected_container_format.value}): "  # noqa: E501
+                f"{valid_count} valid (Pydantic), {invalid_count} invalid (Pydantic), "
+                f"{parsing_failures} parsing failures, {extraction_failures} extraction failures."
            )
             if scores_obj["scores"]:
                 avg_score = sum(scores_obj["scores"]) / len(scores_obj["scores"])
-                self.logger.info(f"Average score for this batch: {avg_score:.3f}")
+                self.logger.info(
+                    f"Average score for this batch ({problem_id}, Format: {selected_structured_format.value}, Container: {selected_container_format.value}): {avg_score:.3f}"  # noqa: E501
+                )
 
         if not scores_obj["tokens"]:  # No valid examples processed
             if self.debug_logging:
-                self.logger.warning("No valid examples processed in this batch")
+                self.logger.warning(
+                    f"No valid examples processed in this batch for {problem_id}"
+                )
             return None
 
+        # This condition might need adjustment if 0.0 is a valid signal for non-Pydantic formats
         if (
             all(scores_obj["scores"][0] == score for score in scores_obj["scores"])
             and scores_obj["scores"][0] != 1.0
         ):
             if self.debug_logging:
                 self.logger.debug(
-                    "All scores are identical and not perfect, returning None for learning signal"
+                    f"All scores are identical ({scores_obj['scores'][0]}) and not perfect for {problem_id}, returning None for learning signal."  # noqa: E501
                 )
             return None
 
-        # Apply length penalty if average response length is too high and all scores are 1.0
+        # Apply a mild length penalty when every rollout is correct but verbose
         if all(s == 1.0 for s in scores_obj["scores"]):
             avg_len = sum(len(t) for t in scores_obj["tokens"]) / len(
                 scores_obj["tokens"]
             )
-            if (
-                avg_len > self.config.max_token_length * 0.75
-            ):  # Penalize if too verbose even when correct
+            if avg_len > self.config.max_token_length * 0.75:
                 scores_obj["scores"] = [s * 0.9 for s in scores_obj["scores"]]
                 if self.debug_logging:
                     self.logger.debug(
-                        f"Applied length penalty: avg_len={avg_len}, penalty_threshold={self.config.max_token_length * 0.75}"  # noqa: E501
+                        f"Applied length penalty for {problem_id}: avg_len={avg_len}, penalty_threshold={self.config.max_token_length * 0.75}"  # noqa: E501
                     )
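+                # With the default max_token_length of 1024 * 12 = 12288, the
+                # 0.75 threshold is 9216 tokens; all-correct groups averaging
+                # above that are scaled from 1.0 down to 0.9.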
 
         return scores_obj
 
@@ -826,11 +1117,14 @@
                 else []
             )
             score_for_rollout = scored_data["scores"][i]
-            # Extract the generated JSON from the assistant's response
 
             if conversation_messages:
                 assistant_response = conversation_messages[-1].get("content", "")
-                generated_json = self._extract_json_response(assistant_response)
+                generated_json = self._extract_structured_data_response(
+                    assistant_response,
+                    dataset_item["selected_container_format"],
+                    dataset_item["selected_structured_format"],
+                )
             else:
                 generated_json = None
 
@@ -928,16 +1222,29 @@
             print(error_msg)
 
     async def rollout_and_score_eval(self, dataset_item: Dict[str, Any]) -> float:
-        """Evaluate a single item from the test set."""
+        """Evaluate a single item from the test set with randomized formats."""
+        problem_id = dataset_item.get("problem_id", "N/A")
+
+        # Randomly select formats, for consistency with training
+        selected_structured_format = random.choice(self.supported_structured_formats)
+        selected_container_format = random.choice(self.supported_container_formats)
+        # Store them on dataset_item for logging and other consumers, though this
+        # function does not read them back directly.
+        dataset_item["selected_structured_format"] = selected_structured_format
+        dataset_item["selected_container_format"] = selected_container_format
+
         if self.debug_logging:
             self.logger.debug(
-                f"Evaluating item: {dataset_item.get('problem_id', 'N/A')}"
+                f"Evaluating item: {problem_id}, structured: {selected_structured_format.value}, container: {selected_container_format.value}"  # noqa: E501
            )
 
         user_content = dataset_item["prompt"]
+        current_system_prompt = self._generate_system_prompt(
+            selected_structured_format, selected_container_format
+        )
+
         messages = [
-            {"role": "system", "content": system_prompt},
+            {"role": "system", "content": current_system_prompt},
             {"role": "user", "content": user_content},
         ]
 
         prompt = self.tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
 
         if self.debug_logging:
-            self.logger.debug(f"Eval prompt length: {len(prompt)} characters")
+            self.logger.debug(
+                f"Eval prompt length for {problem_id}: {len(prompt)} characters"
+            )
 
         completion = await self.server.completion(
             prompt=prompt,
             n=1,
             max_tokens=self.config.max_token_length,
             temperature=0.1,  # Lower temperature for eval
-            split="eval",
+            split="eval",  # Ensure the correct server endpoint is used
         )
 
         model_response_text = completion.choices[0].text
         if self.debug_logging:
             self.logger.debug(
-                f"Eval response length: {len(model_response_text)} characters"
+                f"Eval response length for {problem_id}: {len(model_response_text)} characters"
             )
 
-        json_str = self._extract_json_response(model_response_text)
+        extracted_str = self._extract_structured_data_response(
+            model_response_text, selected_container_format, selected_structured_format
+        )
 
         score = 0.0
 
-        # Extract verification info and create model
+        if not extracted_str:
+            if self.debug_logging:
+                self.logger.debug(
+                    f"Eval extraction failed for {problem_id} (Format: {selected_structured_format.value}, Container: {selected_container_format.value})"  # noqa: E501
+                )
+            return 0.0
+
+        # Attempt to parse and validate
         try:
             verification_info = dataset_item["verification_info"]
             verification_data = json.loads(verification_info)
@@ -975,44 +1293,103 @@
             pydantic_config = verification_data["pydantic_config"]
             model_name = verification_data["model_name"]
 
             if self.debug_logging:
-                self.logger.debug(f"Eval target model: {model_name}")
+                self.logger.debug(
+                    f"Eval target Pydantic model for {problem_id}: {model_name}"
+                )
 
             target_model_cls = self._create_pydantic_model_from_code(
                 pydantic_config, model_name
             )
 
-            if json_str:
-                # Validate JSON against the Pydantic model
-                is_valid, error_msg = self._validate_json_against_model(
-                    json_str, target_model_cls, dataset_item.get("problem_id", "N/A")
-                )
+            parsed_data = None
+            if selected_structured_format == StructuredOutputFormat.JSON:
+                parsed_data = json.loads(extracted_str)
+            elif selected_structured_format == StructuredOutputFormat.YAML:
+                parsed_data = yaml.safe_load(extracted_str)
+            elif selected_structured_format == StructuredOutputFormat.TOML:
+                parsed_data = toml.loads(extracted_str)
+            elif selected_structured_format == StructuredOutputFormat.XML:
+                try:
+                    data_dict_outer = xmltodict.parse(extracted_str)
+                    # Assumption: Pydantic model corresponds to the content *within* the single root XML tag.
+                    # xmltodict.parse returns a dict like {'RootTag': actual_data_dict}.
+                    # We extract actual_data_dict for validation. More complex XML might require
+                    # specific xmltodict process_instructions or more sophisticated unwrapping.
+                    if isinstance(data_dict_outer, dict) and len(data_dict_outer) == 1:
+                        root_key = list(data_dict_outer.keys())[0]
+                        parsed_data = data_dict_outer[root_key]
+                        if self.debug_logging:
+                            self.logger.debug(
+                                f"Eval item (problem {problem_id}): XML parsed, root key '{root_key}', data: {str(parsed_data)[:100]}..."  # noqa: E501
+                            )
+                    else:
+                        raise ValueError(
+                            f"XML from xmltodict.parse was not a dict with a single root key as expected. Got: {type(data_dict_outer)}, Keys: {list(data_dict_outer.keys()) if isinstance(data_dict_outer, dict) else 'N/A'}"  # noqa: E501
+                        )
+                except xmltodict.expat.ExpatError as e_xml_parse:
+                    # No reward variable here; the score is returned directly
+                    if self.debug_logging:
+                        self.logger.debug(
+                            f"Eval item (problem {problem_id}): XML parsing failed (ExpatError): {e_xml_parse}. Extracted XML: {extracted_str[:100]}..."  # noqa: E501
+                        )
+                    return 0.0
+                except Exception as e_xml_generic:
+                    if self.debug_logging:
+                        self.logger.debug(
+                            f"Eval item (problem {problem_id}): XML processing error: {e_xml_generic}. Extracted XML: {extracted_str[:100]}..."  # noqa: E501
+                        )
+                    return 0.0
 
+            if parsed_data is not None and selected_structured_format in [
+                StructuredOutputFormat.JSON,
+                StructuredOutputFormat.YAML,
+                StructuredOutputFormat.TOML,
+                StructuredOutputFormat.XML,
+            ]:
+                is_valid, pydantic_error_msg = self._validate_parsed_data_against_model(
+                    parsed_data, target_model_cls, problem_id
+                )
+                if is_valid:
+                    score = 1.0
                     if self.debug_logging:
                         self.logger.debug(
-                            f"Eval validation successful for {dataset_item.get('problem_id', 'N/A')}"
+                            f"Eval Pydantic validation successful for {problem_id} (Format: {selected_structured_format.value})"  # noqa: E501
                         )
                 else:
-                    score = 0.0  # Validation failed
+                    score = 0.0  # Already 0.0 by default unless validation passes
                     if self.debug_logging:
                         self.logger.debug(
-                            f"Eval validation failed for {dataset_item.get('problem_id', 'N/A')}: {error_msg}"
+                            f"Eval Pydantic validation failed for {problem_id} (Format: {selected_structured_format.value}): {pydantic_error_msg}"  # noqa: E501
                         )
-            else:
-                score = 0.0  # No valid JSON extracted
-                if self.debug_logging:
-                    self.logger.debug(
-                        f"Eval JSON extraction failed for {dataset_item.get('problem_id', 'N/A')}"
-                    )
-
-        except Exception as e:
-            score = 0.0  # Any error in model creation or setup
+        except json.JSONDecodeError as je:
+            if self.debug_logging:
+                self.logger.debug(
+                    f"Eval JSON parsing failed for {problem_id} (Format: {selected_structured_format.value}): {je}. Extracted: {extracted_str[:100]}..."  # noqa: E501
+                )
+            score = 0.0
+        except yaml.YAMLError as ye:
+            if self.debug_logging:
+                self.logger.debug(
+                    f"Eval YAML parsing failed for {problem_id} (Format: {selected_structured_format.value}): {ye}. Extracted: {extracted_str[:100]}..."  # noqa: E501
+                )
+            score = 0.0
+        except toml.TomlDecodeError as te:
+            if self.debug_logging:
+                self.logger.debug(
+                    f"Eval TOML parsing failed for {problem_id} (Format: {selected_structured_format.value}): {te}. Extracted: {extracted_str[:100]}..."  # noqa: E501
+                )
+            score = 0.0
+        except Exception as e:
+            # Catch errors in model creation, other parsing, or validation setup
+            score = 0.0  # Ensure score is 0.0 for any unexpected error in this block
             if self.debug_logging:
                 self.logger.error(
-                    f"Error in eval setup for {dataset_item.get('problem_id', 'N/A')}: {e}"
+                    f"Error during eval scoring pipeline for {problem_id} (Format: {selected_structured_format.value}): {type(e).__name__}: {e}"  # noqa: E501
                 )
 
         return score
 
     async def evaluate(self, *args, **kwargs):
@@ -1074,11 +1451,19 @@
                 self.logger.debug("No scored data to log to wandb")
             return
 
-        dataset_item = item[1] if item else {}
+        # dataset_item contains selected_structured_format and selected_container_format
+        dataset_item = item[1] if item and len(item) > 1 else {}
+        problem_id = dataset_item.get("problem_id", "N/A")
+        selected_structured_format = dataset_item.get(
+            "selected_structured_format", StructuredOutputFormat.JSON
+        )  # Default if not found
+        selected_container_format = dataset_item.get(
+            "selected_container_format", OutputContainerFormat.TAGGED
+        )  # Default if not found
 
         if self.debug_logging:
             self.logger.debug(
-                f"Logging rollouts for problem_id: {dataset_item.get('problem_id', 'N/A')}"
+                f"Logging rollouts for problem_id: {problem_id}, structured: {selected_structured_format.value}, container: {selected_container_format.value}"  # noqa: E501
            )
 
         num_keep = self.config.num_rollouts_per_group_for_logging
@@ -1088,59 +1473,128 @@
         num_keep = min(num_keep, len(scored_data["tokens"]))
 
         if self.debug_logging:
-            self.logger.debug(f"Keeping {num_keep} rollouts for logging")
+            self.logger.debug(
+                f"Keeping {num_keep} rollouts for logging for {problem_id}"
+            )
 
         rollout_batch = []
         for i in range(num_keep):
-            # Decode tokens to text for logging
-            full_convo_text = self.tokenizer.decode(
-                scored_data["tokens"][i], skip_special_tokens=True
-            )
+            # scored_data["messages"][i] holds the message dicts for the i-th
+            # rollout; we want the raw assistant output for extraction, falling
+            # back to the decoded token stream only if no assistant message is found.
+            raw_conversation_messages = scored_data["messages"][i]
+            assistant_response_text = ""
+            if (
+                isinstance(raw_conversation_messages, list)
+                and len(raw_conversation_messages) > 0
+            ):
+                # The last message should be the assistant's response
+                if (
+                    isinstance(raw_conversation_messages[-1], dict)
+                    and raw_conversation_messages[-1].get("role") == "assistant"
+                ):
+                    assistant_response_text = raw_conversation_messages[-1].get(
+                        "content", ""
+                    )
+                else:
+                    # Try to find the assistant message if it is not last, or if the format is unexpected
+                    for msg in reversed(raw_conversation_messages):
+                        if isinstance(msg, dict) and msg.get("role") == "assistant":
+                            assistant_response_text = msg.get("content", "")
+                            break
+
+            if not assistant_response_text:  # Fallback if no proper message is found
+                # This decodes the entire conversation, including system/user prompts
+                assistant_response_text = self.tokenizer.decode(
+                    scored_data["tokens"][i], skip_special_tokens=True
+                )
+                if self.debug_logging:
+                    self.logger.warning(
+                        f"WandB: Could not get raw assistant message for {problem_id}, using full decoded tokens for extraction."  # noqa: E501
+                    )
 
+            extracted_output = self._extract_structured_data_response(
+                assistant_response_text,  # Ideally the raw assistant message, else the decoded conversation
+                selected_container_format,
+                selected_structured_format,
+            )
 
-            extracted_json = self._extract_json_response(full_convo_text)
-
-            # Extract model info from dataset item
             try:
                 verification_info = dataset_item.get("verification_info", "{}")
                 verification_data = json.loads(verification_info)
-                model_name = verification_data.get("model_name", "N/A")
-                expected_json = verification_data.get("pydantic_config", "N/A")
+                # Renamed to avoid clashing with the model_name used elsewhere
+                pydantic_model_name = verification_data.get("model_name", "N/A")
+                expected_schema_info = verification_data.get("pydantic_config", "N/A")
             except (json.JSONDecodeError, KeyError):
-                model_name = "N/A"
-                expected_json = "N/A"
+                pydantic_model_name = "N/A"
+                expected_schema_info = "N/A"
                 if self.debug_logging:
-                    self.logger.debug("Could not extract model name for wandb logging")
+                    self.logger.debug(
+                        f"Could not extract model name/schema for wandb logging for {problem_id}"
+                    )
+
+            # Construct the full conversation text for display from the messages list.
+            # This ensures the system prompt (which might be dynamic) is correctly shown.
+            display_convo_text = ""
+            if isinstance(raw_conversation_messages, list):
+                try:
+                    display_convo_text = self.tokenizer.apply_chat_template(
+                        raw_conversation_messages,
+                        tokenize=False,
+                        add_generation_prompt=False,
+                    )
+                except Exception as e_tmpl:
+                    if self.debug_logging:
+                        self.logger.warning(
+                            f"WandB: Error applying chat template for {problem_id}: {e_tmpl}. Falling back to joining content."  # noqa: E501
+                        )
+                    display_convo_text = "\n---\n".join(
+                        [
+                            str(msg.get("content", ""))
+                            for msg in raw_conversation_messages
+                        ]
+                    )
+            else:
+                display_convo_text = (
+                    assistant_response_text  # Fallback to the decoded full string
+                )
 
             rollout_batch.append(
                 (
-                    full_convo_text,  # Full conversation
+                    display_convo_text,  # Full conversation for display
                     scored_data["scores"][i],
-                    model_name,
-                    dataset_item.get("problem_id", "N/A"),
+                    pydantic_model_name,  # The Pydantic model name it was validated against
+                    problem_id,
                     dataset_item.get("task_type", "N/A"),
+                    selected_structured_format.value,  # Log the selected structured format
+                    selected_container_format.value,  # Log the selected container format
                     (
-                        extracted_json
-                        if extracted_json
-                        else "Extraction failed or no JSON"
+                        extracted_output
+                        if extracted_output
+                        else "Extraction failed or no output"
                     ),
-                    (
-                        expected_json[:200] + "..."
-                        if len(expected_json) > 200
-                        else expected_json
-                    ),  # Truncate for display
+                    expected_schema_info[:200]
+                    + ("..." if len(expected_schema_info) > 200 else ""),  # Truncate for display
                )
            )
 
         if rollout_batch:
             self.rollouts_for_wandb.append(rollout_batch)
             if self.debug_logging:
-                self.logger.debug(f"Added {len(rollout_batch)} rollouts to wandb queue")
+                self.logger.debug(
+                    f"Added {len(rollout_batch)} rollouts to wandb queue for {problem_id}"
+                )
 
         if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
-            removed = self.rollouts_for_wandb.pop(0)
-            if self.debug_logging:
+            removed_count = 0
+            while len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
+                self.rollouts_for_wandb.pop(0)
+                removed_count += 1
+            if self.debug_logging and removed_count > 0:
                 self.logger.debug(
-                    f"Removed oldest rollout batch ({len(removed)} items) from wandb queue"
+                    f"Removed {removed_count} oldest rollout batch(es) from wandb queue"
                )
 
     async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
@@ -1155,11 +1609,13 @@
                 columns=[
                     "full_conversation",
                     "score",
-                    "model_name",
+                    "pydantic_model_name",
                     "problem_id",
                     "task_type",
-                    "extracted_json",
-                    "expected_schema",
+                    "selected_structured_format",
+                    "selected_container_format",
+                    "extracted_output",
+                    "expected_schema_preview",
                 ]
             )
             total_entries = 0
@@ -1258,17 +1714,17 @@
         if self.debug_logging:
             self.logger.info("PydanticSchemaFollowingEnv closed.")
 
-    def _validate_json_against_model(
-        self, json_str: str, model_cls: Type[BaseModel], problem_id: str = "N/A"
+    def _validate_parsed_data_against_model(
+        self, parsed_data: Any, model_cls: Type[BaseModel], problem_id: str = "N/A"
     ) -> Tuple[bool, Optional[str]]:
         """
-        Validate JSON string against a Pydantic model and return detailed error info.
+        Validate parsed data (e.g., a dictionary) against a Pydantic model and return detailed error info.
 
         Returns:
             Tuple of (is_valid, error_message)
         """
         try:
-            model_cls.model_validate_json(json_str)
+            model_cls.model_validate(parsed_data)
             return True, None
         except ValidationError as ve:
             try:
@@ -1285,11 +1741,6 @@
                 # Fallback if formatting ve.errors() fails for some reason
                 error_msg = f"Pydantic validation failed for {problem_id}. Error: {str(ve)[:250]}. (Additionally, formatting error details failed: {str(format_exc)[:50]})"  # noqa: E501
 
-            if self.debug_logging:
-                self.logger.debug(error_msg)
-            return False, error_msg
-        except json.JSONDecodeError as je:
-            error_msg = f"JSON decode failed for {problem_id}: {str(je)[:100]}"
             if self.debug_logging:
                 self.logger.debug(error_msg)
             return False, error_msg
@@ -1304,8 +1755,6 @@
                 self.logger.error(
                     error_msg
                 )  # Log as error due to its specific nature
-                # Optionally log the problematic JSON for further inspection if not too large
-                # self.logger.debug(f"Problematic JSON string for {problem_id} (first 500 chars): {json_str[:500]}")
             else:
                 # Handle other TypeErrors normally
                 error_msg = f"Unexpected TypeError during validation for {problem_id}: {type(te).__name__}: {str(te)[:100]}"  # noqa: E501