Enhanced Pydantic Schema Following Environment with Dynamic Error Introduction and Editing Task Support (#185)

* New JSON env and documentation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Removed separate JSON generation environment

* Updated pydantic environment with edit functionality

* Error helper function

* Updated README

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed pre-commit issues

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Philip Lippmann 2025-07-11 02:44:16 +02:00 committed by GitHub
parent 6386a5e185
commit 17faebae03
GPG key ID: B5690EEEBB952194
3 changed files with 1631 additions and 46 deletions


@@ -71,6 +71,16 @@ from atroposlib.envs.base import (
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# Import editing functionality
try:
    from .error_introduction import (
        ErrorIntroductionConfig,
        introduce_error_for_pydantic,
    )
except ImportError:
    # Handle case when running as script or tests
    from error_introduction import ErrorIntroductionConfig, introduce_error_for_pydantic


class StructuredOutputFormat(Enum):
    JSON = "json"
@@ -121,6 +131,36 @@ class PydanticEnvConfig(BaseEnvConfig):
        le=1.0,
    )

    # Field for task type
    task_type: str = Field(
        default="generation", description="Task type: 'generation' or 'editing'"
    )

    # Error introduction configuration
    error_introduction_seed: Optional[int] = Field(
        default=None, description="Seed for deterministic error introduction"
    )
    error_types_enabled: List[str] = Field(
        default=[
            "type_error",
            "format_error",
            "constraint_error",
            "enum_error",
            "required_field_missing",
        ],
        description="Types of errors to introduce in editing tasks",
    )
    error_introduction_probability: float = Field(
        default=1.0,
        description="Probability of introducing an error (for editing tasks)",
    )
    max_errors_per_item: int = Field(
        default=1, description="Maximum number of errors to introduce per editing task"
    )
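
    # A hypothetical configuration sketch (values are illustrative and other
    # required BaseEnvConfig fields are omitted) showing how the new fields
    # combine to enable deterministic editing tasks:
    #
    #   config = PydanticEnvConfig(
    #       task_type="editing",
    #       error_introduction_seed=42,
    #       error_types_enabled=["type_error", "enum_error"],
    #       error_introduction_probability=0.8,
    #       max_errors_per_item=2,
    #   )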


class PydanticSchemaFollowingEnv(BaseEnv):
    env_config_cls = PydanticEnvConfig
@@ -515,6 +555,7 @@ class PydanticSchemaFollowingEnv(BaseEnv):
            self.logger.debug(
                f"Item problem_id: {dataset_item.get('problem_id', 'N/A')}"
            )
            self.logger.debug(f"Item task_type: {dataset_item.get('task_type', 'N/A')}")

        # Parse verification info to get model name
        try:
@@ -527,41 +568,12 @@ class PydanticSchemaFollowingEnv(BaseEnv):
        self.iter += 1

        # Extract the prompt from the dataset item
        user_content = dataset_item["prompt"]
        if self.debug_logging:
            self.logger.debug(f"Prompt length: {len(user_content)} characters")

        # Randomly select structured and container formats
        selected_structured_format = random.choice(self.supported_structured_formats)
        selected_container_format = random.choice(self.supported_container_formats)

        # Store the selections in the dataset_item
        dataset_item["selected_structured_format"] = selected_structured_format
        dataset_item["selected_container_format"] = selected_container_format

        if self.debug_logging:
            self.logger.debug(
                f"Selected structured format: {selected_structured_format.value}"
            )
            self.logger.debug(
                f"Selected container format: {selected_container_format.value}"
            )

        # Generate the system prompt
        current_system_prompt = self._generate_system_prompt(
            selected_structured_format, selected_container_format
        )

        # Create the message structure
        prompt_messages = [
            frozenset({"role": "system", "content": current_system_prompt}.items()),
            frozenset({"role": "user", "content": user_content}.items()),
        ]

        # Return the prompt and the full dataset item for scoring
        return tuple(prompt_messages), dataset_item
        # Check if this is an editing or generation task
        task_type = dataset_item.get("task_type", "generation")

        if task_type == "editing":
            return self._create_editing_item(dataset_item)
        else:
            return self._create_generation_item(dataset_item)

    def _extract_structured_data_response(
        self,
@@ -1746,6 +1758,391 @@ class PydanticSchemaFollowingEnv(BaseEnv):
            self.logger.debug(error_msg)
        return False, error_msg

    def _generate_valid_data_for_model(
        self, model_cls: Type[BaseModel]
    ) -> Dict[str, Any]:
        """
        Generate realistic valid data for a Pydantic model using field constraints and semantic hints.

        This method analyzes the model's fields, their types, constraints, and names to generate
        appropriate data that will pass model validation.
        """
        valid_data = {}
        model_fields = model_cls.model_fields

        for field_name, field_info in model_fields.items():
            # Skip optional fields sometimes for variability
            if not field_info.is_required() and random.random() < 0.3:
                continue

            field_value = self._generate_field_value(field_name, field_info)
            if field_value is not None:
                valid_data[field_name] = field_value

        # Validate the generated data to ensure it's correct
        try:
            model_cls.model_validate(valid_data)
            return valid_data
        except ValidationError as e:
            # If validation fails, provide minimal valid data
            if self.debug_logging:
                self.logger.warning(f"Generated data failed validation: {e}")
            return self._generate_minimal_valid_data(model_cls)
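
    # Usage sketch (hypothetical model, not part of this diff): given
    #
    #   class User(BaseModel):
    #       name: str
    #       age: int
    #       email: EmailStr
    #
    # a call to self._generate_valid_data_for_model(User) might return
    # {"name": "Alice Brown", "age": 47, "email": "demo12@test.org"},
    # with each value driven by the semantic hints in _generate_value_for_type.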

    def _generate_field_value(self, field_name: str, field_info) -> Any:
        """Generate a realistic value for a specific field based on its type and constraints."""
        annotation = field_info.annotation

        # Handle Union types (like Optional[T])
        if hasattr(annotation, "__origin__") and annotation.__origin__ is Union:
            # Get the first non-None type
            for arg in annotation.__args__:
                if arg is not type(None):
                    annotation = arg
                    break

        # Handle List types
        if hasattr(annotation, "__origin__") and annotation.__origin__ is list:
            item_type = annotation.__args__[0] if annotation.__args__ else str
            return [
                self._generate_value_for_type(item_type, f"{field_name}_item")
                for _ in range(random.randint(1, 3))
            ]

        # Handle Dict types
        if hasattr(annotation, "__origin__") and annotation.__origin__ is dict:
            key_type = annotation.__args__[0] if len(annotation.__args__) > 0 else str
            value_type = annotation.__args__[1] if len(annotation.__args__) > 1 else str
            return {
                self._generate_value_for_type(
                    key_type, "key"
                ): self._generate_value_for_type(value_type, f"{field_name}_value")
                for _ in range(random.randint(1, 2))
            }

        # Generate value based on type
        return self._generate_value_for_type(annotation, field_name)
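
    # Sketch of the container handling above (hypothetical field): a field
    # annotated List[int] and named "scores" yields 1-3 ints, each generated
    # via _generate_value_for_type(int, "scores_item"), e.g. [57, 3].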

    def _generate_value_for_type(
        self, type_annotation: Type, field_name: str = ""
    ) -> Any:
        """Generate a value for a specific type, using field name as hint for semantic meaning."""
        field_name_lower = field_name.lower()

        # Special Pydantic types
        if type_annotation is EmailStr or "email" in field_name_lower:
            domains = ["example.com", "test.org", "sample.net", "demo.io"]
            names = ["user", "admin", "test", "demo", "sample"]
            return f"{random.choice(names)}{random.randint(1, 99)}@{random.choice(domains)}"

        if (
            type_annotation is HttpUrl
            or "url" in field_name_lower
            or "link" in field_name_lower
        ):
            domains = ["example.com", "test.org", "sample.net"]
            paths = ["api/v1", "users/profile", "data/export", "files/download"]
            return f"https://{random.choice(domains)}/{random.choice(paths)}"

        if (
            type_annotation is UUID
            or "uuid" in field_name_lower
            or "id" in field_name_lower
        ):
            return str(uuid.uuid4())

        # Date/Time types
        if type_annotation is datetime:
            return datetime.now() - timedelta(days=random.randint(0, 365))

        if type_annotation is date:
            return (datetime.now() - timedelta(days=random.randint(0, 365))).date()

        # Numeric types with semantic hints
        if type_annotation is int or type_annotation is float:
            if "age" in field_name_lower:
                return random.randint(18, 80)
            elif "year" in field_name_lower:
                return random.randint(2000, 2024)
            elif "score" in field_name_lower or "rating" in field_name_lower:
                return random.randint(1, 100)
            elif "price" in field_name_lower or "cost" in field_name_lower:
                return round(random.uniform(10.0, 1000.0), 2)
            elif "count" in field_name_lower or "quantity" in field_name_lower:
                return random.randint(1, 100)
            else:
                return (
                    random.randint(1, 100)
                    if type_annotation is int
                    else round(random.uniform(1.0, 100.0), 2)
                )

        if type_annotation is Decimal:
            return Decimal(str(round(random.uniform(1.0, 100.0), 2)))

        if type_annotation is bool:
            return random.choice([True, False])

        # String types with semantic hints
        if type_annotation is str:
            if "name" in field_name_lower:
                first_names = ["John", "Jane", "Alice", "Bob", "Charlie", "Diana"]
                last_names = [
                    "Smith",
                    "Johnson",
                    "Williams",
                    "Brown",
                    "Jones",
                    "Garcia",
                ]
                if "first" in field_name_lower:
                    return random.choice(first_names)
                elif "last" in field_name_lower:
                    return random.choice(last_names)
                else:
                    return f"{random.choice(first_names)} {random.choice(last_names)}"
            elif "title" in field_name_lower:
                titles = [
                    "Senior Developer",
                    "Product Manager",
                    "Data Scientist",
                    "UX Designer",
                ]
                return random.choice(titles)
            elif "description" in field_name_lower:
                descriptions = [
                    "A comprehensive solution",
                    "High-quality product",
                    "Innovative approach",
                    "User-friendly design",
                ]
                return random.choice(descriptions)
            elif "address" in field_name_lower:
                return f"{random.randint(100, 9999)} Main St, City, State {random.randint(10000, 99999)}"
            elif "phone" in field_name_lower:
                return f"+1-{random.randint(100, 999)}-{random.randint(100, 999)}-{random.randint(1000, 9999)}"
            else:
                return f"sample_{field_name_lower}_{random.randint(1, 999)}"

        # Enum types
        if hasattr(type_annotation, "__bases__") and Enum in type_annotation.__bases__:
            return random.choice(list(type_annotation)).value

        # Nested BaseModel
        if (
            hasattr(type_annotation, "__bases__")
            and BaseModel in type_annotation.__bases__
        ):
            return self._generate_valid_data_for_model(type_annotation)

        # Fallback for unknown types
        return f"generated_value_{random.randint(1, 999)}"

    def _generate_minimal_valid_data(
        self, model_cls: Type[BaseModel]
    ) -> Dict[str, Any]:
        """Generate minimal valid data when constraint-based generation fails."""
        valid_data = {}
        model_fields = model_cls.model_fields

        for field_name, field_info in model_fields.items():
            if field_info.is_required():
                annotation = field_info.annotation

                # Handle Union types (like Optional[T])
                if hasattr(annotation, "__origin__") and annotation.__origin__ is Union:
                    for arg in annotation.__args__:
                        if arg is not type(None):
                            annotation = arg
                            break

                # Generate simple fallback values
                if annotation is int:
                    valid_data[field_name] = 1
                elif annotation is float:
                    valid_data[field_name] = 1.0
                elif annotation is bool:
                    valid_data[field_name] = True
                elif annotation is str or annotation is EmailStr:
                    valid_data[field_name] = "valid_example"
                elif annotation is HttpUrl:
                    valid_data[field_name] = "https://example.com"
                elif annotation is UUID:
                    valid_data[field_name] = str(uuid.uuid4())
                elif annotation is datetime:
                    valid_data[field_name] = datetime.now()
                elif annotation is date:
                    valid_data[field_name] = date.today()
                else:
                    valid_data[field_name] = "fallback_value"

        return valid_data

    def _generate_editing_task_prompt(
        self,
        dataset_item: Dict[str, Any],
        structured_format: StructuredOutputFormat,
        container_format: OutputContainerFormat,
    ) -> str:
        """Generate prompt for editing tasks."""
        # Get Pydantic schema and erroneous data
        verification_info = dataset_item.get("verification_info", "{}")
        erroneous_data = dataset_item.get("erroneous_data")

        if not verification_info or erroneous_data is None:
            raise ValueError(
                "Editing task requires both 'verification_info' and 'erroneous_data' fields"
            )

        try:
            verification_data = json.loads(verification_info)
            pydantic_config = verification_data["pydantic_config"]
        except (json.JSONDecodeError, KeyError):
            raise ValueError("Invalid verification_info format for editing task")

        # default=str keeps non-JSON-native values (datetime, Decimal, UUID)
        # produced by the generator from raising TypeError during serialization
        erroneous_data_str = json.dumps(erroneous_data, indent=2, default=str)
        base_prompt = self._generate_system_prompt(structured_format, container_format)

        editing_prompt = f"""
{base_prompt}

The following data contains errors and should be corrected to conform to the provided Pydantic model.

Pydantic Model:
```python
{pydantic_config}
```

Erroneous Data:
```json
{erroneous_data_str}
```

Please identify and fix the errors, then provide the corrected data in the requested format.
"""
        return editing_prompt

    def _create_editing_item(
        self, dataset_item: Dict[str, Any]
    ) -> Tuple[Tuple[frozenset, ...], Dict[str, Any]]:
        """Create an editing task item with appropriate prompts."""
        # Randomly select structured and container formats
        selected_structured_format = random.choice(self.supported_structured_formats)
        selected_container_format = random.choice(self.supported_container_formats)

        # Store the selections in the dataset_item
        dataset_item["selected_structured_format"] = selected_structured_format
        dataset_item["selected_container_format"] = selected_container_format

        if self.debug_logging:
            self.logger.debug(
                f"Creating editing task for problem_id: {dataset_item.get('problem_id', 'N/A')}"
            )
            self.logger.debug(
                f"Selected structured format: {selected_structured_format.value}"
            )
            self.logger.debug(
                f"Selected container format: {selected_container_format.value}"
            )

        # If erroneous_data is not already provided, generate it from a valid example
        if "erroneous_data" not in dataset_item:
            # Create valid data first, then introduce errors
            verification_info = dataset_item.get("verification_info", "{}")
            if verification_info:
                try:
                    verification_data = json.loads(verification_info)
                    pydantic_config = verification_data["pydantic_config"]
                    model_name = verification_data["model_name"]

                    # Create the Pydantic model
                    target_model_cls = self._create_pydantic_model_from_code(
                        pydantic_config, model_name
                    )

                    # Generate valid data using Pydantic constraints and field information
                    valid_data = self._generate_valid_data_for_model(target_model_cls)

                    # Introduce errors into the valid data using configuration
                    error_config = ErrorIntroductionConfig.from_env_config(
                        error_types_enabled=self.config.error_types_enabled,
                        max_errors_per_item=self.config.max_errors_per_item,
                        error_introduction_probability=self.config.error_introduction_probability,
                        error_introduction_seed=self.config.error_introduction_seed,
                    )
                    erroneous_data = introduce_error_for_pydantic(
                        valid_data, target_model_cls, config=error_config
                    )

                    if erroneous_data:
                        dataset_item["erroneous_data"] = erroneous_data
                    else:
                        # Fallback: use valid data with a simple error
                        dataset_item["erroneous_data"] = {
                            **valid_data,
                            "invalid_field": "should_not_exist",
                        }
                except Exception as e:
                    if self.debug_logging:
                        self.logger.warning(f"Could not generate erroneous data: {e}")
                    # Use a simple fallback
                    dataset_item["erroneous_data"] = {"error": "could_not_generate"}

        # Generate the editing task prompt
        prompt_content = self._generate_editing_task_prompt(
            dataset_item, selected_structured_format, selected_container_format
        )

        # Create the message structure
        prompt_messages = [
            frozenset({"role": "user", "content": prompt_content}.items()),
        ]

        return tuple(prompt_messages), dataset_item
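
    # End-to-end sketch of the editing flow (hypothetical dataset item;
    # the verification_info contents are illustrative):
    #
    #   item = {
    #       "task_type": "editing",
    #       "verification_info": '{"pydantic_config": "...", "model_name": "User"}',
    #   }
    #   messages, item = self._create_editing_item(item)
    #   # item["erroneous_data"] now holds generated valid data with up to
    #   # config.max_errors_per_item errors introduced into it.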

    def _create_generation_item(
        self, dataset_item: Dict[str, Any]
    ) -> Tuple[Tuple[frozenset, ...], Dict[str, Any]]:
        """Create a generation task item (existing logic extracted for clarity)."""
        # Extract the prompt from the dataset item
        user_content = dataset_item["prompt"]
        if self.debug_logging:
            self.logger.debug(f"Prompt length: {len(user_content)} characters")

        # Randomly select structured and container formats
        selected_structured_format = random.choice(self.supported_structured_formats)
        selected_container_format = random.choice(self.supported_container_formats)

        # Store the selections in the dataset_item
        dataset_item["selected_structured_format"] = selected_structured_format
        dataset_item["selected_container_format"] = selected_container_format

        if self.debug_logging:
            self.logger.debug(
                f"Selected structured format: {selected_structured_format.value}"
            )
            self.logger.debug(
                f"Selected container format: {selected_container_format.value}"
            )

        # Generate the system prompt
        current_system_prompt = self._generate_system_prompt(
            selected_structured_format, selected_container_format
        )

        # Create the message structure
        prompt_messages = [
            frozenset({"role": "system", "content": current_system_prompt}.items()),
            frozenset({"role": "user", "content": user_content}.items()),
        ]

        return tuple(prompt_messages), dataset_item


if __name__ == "__main__":
    PydanticSchemaFollowingEnv.cli()