Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-28 17:29:30 +00:00
Enhanced Pydantic Schema Following Environment with Dynamic Error Introduction and Editing Task Support (#185)
* New JSON env and documentation
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Removed separate JSON generation environment
* Updated pydantic environment with edit functionality
* Added error helper function
* Updated README
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fixed pre-commit issues

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
6386a5e185
commit
17faebae03
3 changed files with 1631 additions and 46 deletions
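For orientation, the sketch below shows how the editing-task options introduced by this commit could be set on PydanticEnvConfig. It is a minimal illustration, not taken from this diff: the import path is hypothetical and the remaining BaseEnvConfig fields are assumed to have defaults or be supplied elsewhere. The five field names and their meanings come from the PydanticEnvConfig additions shown in the diff below.

```python
# Minimal sketch; the import path is an assumption, other BaseEnvConfig fields omitted.
from pydantic_schema_following_env import PydanticEnvConfig  # hypothetical module name

config = PydanticEnvConfig(
    task_type="editing",  # "generation" (default) or "editing"
    error_introduction_seed=1234,  # deterministic error injection
    error_types_enabled=["type_error", "enum_error"],  # subset of the supported error types
    error_introduction_probability=1.0,  # always corrupt the generated valid data
    max_errors_per_item=1,  # at most one injected error per editing item
)
```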
@@ -71,6 +71,16 @@ from atroposlib.envs.base import (
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# Import editing functionality
try:
    from .error_introduction import (
        ErrorIntroductionConfig,
        introduce_error_for_pydantic,
    )
except ImportError:
    # Handle case when running as script or tests
    from error_introduction import ErrorIntroductionConfig, introduce_error_for_pydantic


class StructuredOutputFormat(Enum):
    JSON = "json"

@@ -121,6 +131,36 @@ class PydanticEnvConfig(BaseEnvConfig):
        le=1.0,
    )

    # Field for task type
    task_type: str = Field(
        default="generation", description="Task type: 'generation' or 'editing'"
    )

    # Error introduction configuration
    error_introduction_seed: Optional[int] = Field(
        default=None, description="Seed for deterministic error introduction"
    )

    error_types_enabled: List[str] = Field(
        default=[
            "type_error",
            "format_error",
            "constraint_error",
            "enum_error",
            "required_field_missing",
        ],
        description="Types of errors to introduce in editing tasks",
    )

    error_introduction_probability: float = Field(
        default=1.0,
        description="Probability of introducing an error (for editing tasks)",
    )

    max_errors_per_item: int = Field(
        default=1, description="Maximum number of errors to introduce per editing task"
    )


class PydanticSchemaFollowingEnv(BaseEnv):
    env_config_cls = PydanticEnvConfig

@@ -515,6 +555,7 @@ class PydanticSchemaFollowingEnv(BaseEnv):
            self.logger.debug(
                f"Item problem_id: {dataset_item.get('problem_id', 'N/A')}"
            )
            self.logger.debug(f"Item task_type: {dataset_item.get('task_type', 'N/A')}")

        # Parse verification info to get model name
        try:

@@ -527,41 +568,12 @@

        self.iter += 1

        # Extract the prompt from the dataset item
        user_content = dataset_item["prompt"]

        if self.debug_logging:
            self.logger.debug(f"Prompt length: {len(user_content)} characters")

        # Randomly select structured and container formats
        selected_structured_format = random.choice(self.supported_structured_formats)
        selected_container_format = random.choice(self.supported_container_formats)

        # Store the selections in the dataset_item
        dataset_item["selected_structured_format"] = selected_structured_format
        dataset_item["selected_container_format"] = selected_container_format

        if self.debug_logging:
            self.logger.debug(
                f"Selected structured format: {selected_structured_format.value}"
            )
            self.logger.debug(
                f"Selected container format: {selected_container_format.value}"
            )

        # Generate the system prompt
        current_system_prompt = self._generate_system_prompt(
            selected_structured_format, selected_container_format
        )

        # Create the message structure
        prompt_messages = [
            frozenset({"role": "system", "content": current_system_prompt}.items()),
            frozenset({"role": "user", "content": user_content}.items()),
        ]

        # Return the prompt and the full dataset item for scoring
        return tuple(prompt_messages), dataset_item
        # Check if this is an editing or generation task
        task_type = dataset_item.get("task_type", "generation")
        if task_type == "editing":
            return self._create_editing_item(dataset_item)
        else:
            return self._create_generation_item(dataset_item)

    def _extract_structured_data_response(
        self,

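To make the new task dispatch above concrete, an editing-task dataset item could look roughly like the sketch below. Only the key names (problem_id, task_type, prompt, verification_info with pydantic_config and model_name, erroneous_data) come from this diff; the model code and values are invented for illustration.

```python
import json

# Hypothetical editing-task item; key names follow this diff, values are illustrative.
example_item = {
    "problem_id": "demo-001",
    "task_type": "editing",  # routes the item to _create_editing_item
    "prompt": "Produce a valid User record.",  # used by the generation path
    "verification_info": json.dumps(
        {
            "model_name": "User",
            "pydantic_config": (
                "from pydantic import BaseModel\n"
                "\n"
                "class User(BaseModel):\n"
                "    name: str\n"
                "    age: int\n"
            ),
        }
    ),
    # "erroneous_data" may be omitted: _create_editing_item then generates valid data
    # for the model and corrupts it via introduce_error_for_pydantic.
}
```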
@@ -1746,6 +1758,391 @@ class PydanticSchemaFollowingEnv(BaseEnv):
                self.logger.debug(error_msg)
            return False, error_msg

    def _generate_valid_data_for_model(
        self, model_cls: Type[BaseModel]
    ) -> Dict[str, Any]:
        """
        Generate realistic valid data for a Pydantic model using field constraints and semantic hints.

        This method analyzes the model's fields, their types, constraints, and names to generate
        appropriate data that will pass model validation.
        """
        valid_data = {}
        model_fields = model_cls.model_fields

        for field_name, field_info in model_fields.items():
            # Skip optional fields sometimes for variability
            if not field_info.is_required() and random.random() < 0.3:
                continue

            field_value = self._generate_field_value(field_name, field_info)
            if field_value is not None:
                valid_data[field_name] = field_value

        # Validate the generated data to ensure it's correct
        try:
            model_cls.model_validate(valid_data)
            return valid_data
        except ValidationError as e:
            # If validation fails, provide minimal valid data
            if self.debug_logging:
                self.logger.warning(f"Generated data failed validation: {e}")
            return self._generate_minimal_valid_data(model_cls)

    def _generate_field_value(self, field_name: str, field_info) -> Any:
        """Generate a realistic value for a specific field based on its type and constraints."""
        annotation = field_info.annotation

        # Handle Union types (like Optional[T])
        if hasattr(annotation, "__origin__") and annotation.__origin__ is Union:
            # Get the first non-None type
            for arg in annotation.__args__:
                if arg is not type(None):
                    annotation = arg
                    break

        # Handle List types
        if hasattr(annotation, "__origin__") and annotation.__origin__ is list:
            item_type = annotation.__args__[0] if annotation.__args__ else str
            return [
                self._generate_value_for_type(item_type, f"{field_name}_item")
                for _ in range(random.randint(1, 3))
            ]

        # Handle Dict types
        if hasattr(annotation, "__origin__") and annotation.__origin__ is dict:
            key_type = annotation.__args__[0] if len(annotation.__args__) > 0 else str
            value_type = annotation.__args__[1] if len(annotation.__args__) > 1 else str
            return {
                self._generate_value_for_type(
                    key_type, "key"
                ): self._generate_value_for_type(value_type, f"{field_name}_value")
                for _ in range(random.randint(1, 2))
            }

        # Generate value based on type
        return self._generate_value_for_type(annotation, field_name)

    def _generate_value_for_type(
        self, type_annotation: Type, field_name: str = ""
    ) -> Any:
        """Generate a value for a specific type, using field name as hint for semantic meaning."""
        field_name_lower = field_name.lower()

        # Special Pydantic types
        if type_annotation is EmailStr or "email" in field_name_lower:
            domains = ["example.com", "test.org", "sample.net", "demo.io"]
            names = ["user", "admin", "test", "demo", "sample"]
            return f"{random.choice(names)}{random.randint(1, 99)}@{random.choice(domains)}"

        if (
            type_annotation is HttpUrl
            or "url" in field_name_lower
            or "link" in field_name_lower
        ):
            domains = ["example.com", "test.org", "sample.net"]
            paths = ["api/v1", "users/profile", "data/export", "files/download"]
            return f"https://{random.choice(domains)}/{random.choice(paths)}"

        if (
            type_annotation is UUID
            or "uuid" in field_name_lower
            or "id" in field_name_lower
        ):
            return str(uuid.uuid4())

        # Date/Time types
        if type_annotation is datetime:
            return datetime.now() - timedelta(days=random.randint(0, 365))

        if type_annotation is date:
            return (datetime.now() - timedelta(days=random.randint(0, 365))).date()

        # Numeric types with semantic hints
        if type_annotation is int or type_annotation is float:
            if "age" in field_name_lower:
                return random.randint(18, 80)
            elif "year" in field_name_lower:
                return random.randint(2000, 2024)
            elif "score" in field_name_lower or "rating" in field_name_lower:
                return random.randint(1, 100)
            elif "price" in field_name_lower or "cost" in field_name_lower:
                return round(random.uniform(10.0, 1000.0), 2)
            elif "count" in field_name_lower or "quantity" in field_name_lower:
                return random.randint(1, 100)
            else:
                return (
                    random.randint(1, 100)
                    if type_annotation is int
                    else round(random.uniform(1.0, 100.0), 2)
                )

        if type_annotation is Decimal:
            return Decimal(str(round(random.uniform(1.0, 100.0), 2)))

        if type_annotation is bool:
            return random.choice([True, False])

        # String types with semantic hints
        if type_annotation is str:
            if "name" in field_name_lower:
                first_names = ["John", "Jane", "Alice", "Bob", "Charlie", "Diana"]
                last_names = [
                    "Smith",
                    "Johnson",
                    "Williams",
                    "Brown",
                    "Jones",
                    "Garcia",
                ]
                if "first" in field_name_lower:
                    return random.choice(first_names)
                elif "last" in field_name_lower:
                    return random.choice(last_names)
                else:
                    return f"{random.choice(first_names)} {random.choice(last_names)}"
            elif "title" in field_name_lower:
                titles = [
                    "Senior Developer",
                    "Product Manager",
                    "Data Scientist",
                    "UX Designer",
                ]
                return random.choice(titles)
            elif "description" in field_name_lower:
                descriptions = [
                    "A comprehensive solution",
                    "High-quality product",
                    "Innovative approach",
                    "User-friendly design",
                ]
                return random.choice(descriptions)
            elif "address" in field_name_lower:
                return f"{random.randint(100, 9999)} Main St, City, State {random.randint(10000, 99999)}"
            elif "phone" in field_name_lower:
                return f"+1-{random.randint(100, 999)}-{random.randint(100, 999)}-{random.randint(1000, 9999)}"
            else:
                return f"sample_{field_name_lower}_{random.randint(1, 999)}"

        # Enum types
        if hasattr(type_annotation, "__bases__") and Enum in type_annotation.__bases__:
            return random.choice(list(type_annotation)).value

        # Nested BaseModel
        if (
            hasattr(type_annotation, "__bases__")
            and BaseModel in type_annotation.__bases__
        ):
            return self._generate_valid_data_for_model(type_annotation)

        # Fallback for unknown types
        return f"generated_value_{random.randint(1, 999)}"

    def _generate_minimal_valid_data(
        self, model_cls: Type[BaseModel]
    ) -> Dict[str, Any]:
        """Generate minimal valid data when constraint-based generation fails."""
        valid_data = {}
        model_fields = model_cls.model_fields

        for field_name, field_info in model_fields.items():
            if field_info.is_required():
                annotation = field_info.annotation

                # Handle Union types (like Optional[T])
                if hasattr(annotation, "__origin__") and annotation.__origin__ is Union:
                    for arg in annotation.__args__:
                        if arg is not type(None):
                            annotation = arg
                            break

                # Generate simple fallback values
                if annotation is int:
                    valid_data[field_name] = 1
                elif annotation is float:
                    valid_data[field_name] = 1.0
                elif annotation is bool:
                    valid_data[field_name] = True
                elif annotation is str or annotation is EmailStr:
                    valid_data[field_name] = "valid_example"
                elif annotation is HttpUrl:
                    valid_data[field_name] = "https://example.com"
                elif annotation is UUID:
                    valid_data[field_name] = str(uuid.uuid4())
                elif annotation is datetime:
                    valid_data[field_name] = datetime.now()
                elif annotation is date:
                    valid_data[field_name] = date.today()
                else:
                    valid_data[field_name] = "fallback_value"

        return valid_data

    def _generate_editing_task_prompt(
        self,
        dataset_item: Dict[str, Any],
        structured_format: StructuredOutputFormat,
        container_format: OutputContainerFormat,
    ) -> str:
        """Generate prompt for editing tasks."""
        # Get Pydantic schema and erroneous data
        verification_info = dataset_item.get("verification_info", "{}")
        erroneous_data = dataset_item.get("erroneous_data")

        if not verification_info or erroneous_data is None:
            raise ValueError(
                "Editing task requires both 'verification_info' and 'erroneous_data' fields"
            )

        try:
            verification_data = json.loads(verification_info)
            pydantic_config = verification_data["pydantic_config"]
        except (json.JSONDecodeError, KeyError):
            raise ValueError("Invalid verification_info format for editing task")

        erroneous_data_str = json.dumps(erroneous_data, indent=2)

        base_prompt = self._generate_system_prompt(structured_format, container_format)

        editing_prompt = f"""
{base_prompt}

The following data contains errors and should be corrected to conform to the provided Pydantic model.

Pydantic Model:
```python
{pydantic_config}
```

Erroneous Data:
```json
{erroneous_data_str}
```

Please identify and fix the errors, then provide the corrected data in the requested format.
"""

        return editing_prompt

    def _create_editing_item(
        self, dataset_item: Dict[str, Any]
    ) -> Tuple[Tuple[frozenset, ...], Dict[str, Any]]:
        """Create an editing task item with appropriate prompts."""
        # Randomly select structured and container formats
        selected_structured_format = random.choice(self.supported_structured_formats)
        selected_container_format = random.choice(self.supported_container_formats)

        # Store the selections in the dataset_item
        dataset_item["selected_structured_format"] = selected_structured_format
        dataset_item["selected_container_format"] = selected_container_format

        if self.debug_logging:
            self.logger.debug(
                f"Creating editing task for problem_id: {dataset_item.get('problem_id', 'N/A')}"
            )
            self.logger.debug(
                f"Selected structured format: {selected_structured_format.value}"
            )
            self.logger.debug(
                f"Selected container format: {selected_container_format.value}"
            )

        # If erroneous_data is not already provided, generate it from a valid example
        if "erroneous_data" not in dataset_item:
            # Create valid data first, then introduce errors
            verification_info = dataset_item.get("verification_info", "{}")
            if verification_info:
                try:
                    verification_data = json.loads(verification_info)
                    pydantic_config = verification_data["pydantic_config"]
                    model_name = verification_data["model_name"]

                    # Create the Pydantic model
                    target_model_cls = self._create_pydantic_model_from_code(
                        pydantic_config, model_name
                    )

                    # Generate valid data using Pydantic constraints and field information
                    valid_data = self._generate_valid_data_for_model(target_model_cls)

                    # Introduce errors into the valid data using configuration
                    error_config = ErrorIntroductionConfig.from_env_config(
                        error_types_enabled=self.config.error_types_enabled,
                        max_errors_per_item=self.config.max_errors_per_item,
                        error_introduction_probability=self.config.error_introduction_probability,
                        error_introduction_seed=self.config.error_introduction_seed,
                    )

                    erroneous_data = introduce_error_for_pydantic(
                        valid_data, target_model_cls, config=error_config
                    )

                    if erroneous_data:
                        dataset_item["erroneous_data"] = erroneous_data
                    else:
                        # Fallback: use valid data with a simple error
                        dataset_item["erroneous_data"] = {
                            **valid_data,
                            "invalid_field": "should_not_exist",
                        }

                except Exception as e:
                    if self.debug_logging:
                        self.logger.warning(f"Could not generate erroneous data: {e}")
                    # Use a simple fallback
                    dataset_item["erroneous_data"] = {"error": "could_not_generate"}

        # Generate the editing task prompt
        prompt_content = self._generate_editing_task_prompt(
            dataset_item, selected_structured_format, selected_container_format
        )

        # Create the message structure
        prompt_messages = [
            frozenset({"role": "user", "content": prompt_content}.items()),
        ]

        return tuple(prompt_messages), dataset_item

    def _create_generation_item(
        self, dataset_item: Dict[str, Any]
    ) -> Tuple[Tuple[frozenset, ...], Dict[str, Any]]:
        """Create a generation task item (existing logic extracted for clarity)."""
        # Extract the prompt from the dataset item
        user_content = dataset_item["prompt"]

        if self.debug_logging:
            self.logger.debug(f"Prompt length: {len(user_content)} characters")

        # Randomly select structured and container formats
        selected_structured_format = random.choice(self.supported_structured_formats)
        selected_container_format = random.choice(self.supported_container_formats)

        # Store the selections in the dataset_item
        dataset_item["selected_structured_format"] = selected_structured_format
        dataset_item["selected_container_format"] = selected_container_format

        if self.debug_logging:
            self.logger.debug(
                f"Selected structured format: {selected_structured_format.value}"
            )
            self.logger.debug(
                f"Selected container format: {selected_container_format.value}"
            )

        # Generate the system prompt
        current_system_prompt = self._generate_system_prompt(
            selected_structured_format, selected_container_format
        )

        # Create the message structure
        prompt_messages = [
            frozenset({"role": "system", "content": current_system_prompt}.items()),
            frozenset({"role": "user", "content": user_content}.items()),
        ]

        return tuple(prompt_messages), dataset_item


if __name__ == "__main__":
    PydanticSchemaFollowingEnv.cli()
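Finally, since the commit message highlights the new error helper function, here is a hedged sketch of using it outside the environment. It relies only on the call shapes visible in _create_editing_item above (ErrorIntroductionConfig.from_env_config and introduce_error_for_pydantic); the example model, values, and printed output are invented for illustration.

```python
from pydantic import BaseModel

from error_introduction import ErrorIntroductionConfig, introduce_error_for_pydantic


class Product(BaseModel):  # illustrative model, not taken from this diff
    name: str
    price: float


valid_data = {"name": "Widget", "price": 19.99}

# Mirrors the call pattern used in _create_editing_item
error_config = ErrorIntroductionConfig.from_env_config(
    error_types_enabled=["type_error", "constraint_error"],
    max_errors_per_item=1,
    error_introduction_probability=1.0,
    error_introduction_seed=7,
)
erroneous_data = introduce_error_for_pydantic(valid_data, Product, config=error_config)
print(erroneous_data)  # corrupted copy of valid_data, e.g. a wrong type for "price"
```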