Cleanup. End-to-end functionality in place

This commit is contained in:
Josh 2025-05-18 17:38:29 -07:00
parent 8ff2b02ce0
commit 904360a02e
4 changed files with 351 additions and 315 deletions

View file

@@ -1,32 +1,37 @@
# environments/hack0/accessibility_env/accessibility_env.py
import os # For API keys, etc.
from typing import Dict, List, Optional, Tuple # Common type hints, added Dict
import json
import os
from typing import Dict, List, Optional, Tuple
import tenacity
# from bs4 import BeautifulSoup
from bs4 import BeautifulSoup
from pydantic import Field
from transformers.models.auto.tokenization_auto import AutoTokenizer
# Corrected imports for Atropos types
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import ( # GameHistory might not be needed yet, Item is common
Item,
)
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from .accessibility_rules import (
AccessibilityRule,
LabelAssociationRule,
MissingAltTextRule,
)
class AccessibilityEnvConfig(BaseEnvConfig):
# Add any custom config fields specific to your env later
pass
dataset_path: str = Field(
default="data/accessibility_dataset.jsonl", # Default relative path
description="Path to the JSONL file containing the accessibility dataset.",
)
class AccessibilityEnv(BaseEnv):
name = "accessibility_env" # A unique name for your environment
config: AccessibilityEnvConfig
name = "accessibility_env"
def __init__(
self,
@@ -36,33 +41,44 @@ class AccessibilityEnv(BaseEnv):
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
# Initialize any env-specific attributes here
self.tokenizer = None
# Initialize your list of rule instances
self.accessibility_rules: List[AccessibilityRule] = [
MissingAltTextRule(),
LabelAssociationRule(),
]
# For quick lookup if needed, though iterating self.accessibility_rules is fine
self.rules_by_key = {rule.issue_key: rule for rule in self.accessibility_rules}
@classmethod
def config_init(cls) -> Tuple[AccessibilityEnvConfig, List[APIServerConfig]]:
current_dataset_size = 10
env_config = AccessibilityEnvConfig(
tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
group_size=1, # Smaller for faster testing initially
tokenizer_name="gpt2",
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=3, # For process mode, number of items to generate
batch_size=1, # Max items in a single call to score (related to group_size)
steps_per_eval=5,
total_steps=current_dataset_size,
batch_size=1,
steps_per_eval=current_dataset_size,
max_token_length=2048,
wandb_name="accessibility_llama_dev", # Dev run name
wandb_name="accessibility_openai_default_dev",
)
llama_api_key = os.environ.get("LLAMA_API_KEY")
if not llama_api_key:
print("WARNING: LLAMA_API_KEY environment variable not set!")
openai_api_key_from_env = os.environ.get("OPENAI_API_KEY")
if not openai_api_key_from_env:
print(
"WARNING (from config_init): OPENAI_API_KEY environment variable not set for default config!"
)
server_configs = [
APIServerConfig(
model_name="Llama-4-Maverick-17B-128E-Instruct-FP8",
base_url="https://api.llama.com/v1", # <<<---- Llama API base URL
api_key=llama_api_key,
num_requests_for_eval=16,
),
model_name="gpt-3.5-turbo",
api_key=openai_api_key_from_env,
)
]
return env_config, server_configs
@@ -71,27 +87,17 @@ class AccessibilityEnv(BaseEnv):
try:
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.tokenizer_name, trust_remote_code=True
) # tokenizer_name is 'gpt2'
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Set a default chat template if it's not already set
# This is crucial for tokenizers like 'gpt2' that don't have one by default.
if self.tokenizer.chat_template is None:
# A common, simple template. You might need to adjust based on how gpt-3.5-turbo expects chat.
# For gpt-3.5-turbo, the actual formatting is handled by the API,
# but for local tokenization for the trainer, we need *a* template.
# A basic template for generic tokenization:
self.tokenizer.chat_template = (
"{% for message in messages %}"
"{{ message['role'] + ': ' + message['content'] + '\\n' }}"
"{% endfor %}"
)
# A more structured Jinja template (e.g. a Llama- or ChatML-style one) would be
# appropriate if training with such a format; for generic token IDs for RL, the
# simple template above should be enough for tokenize_for_trainer.
print(
f"[{self.name}] Set a default chat_template for tokenizer '{self.config.tokenizer_name}'."
)
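# Illustrative note (not part of this commit): with the simple template above,
# tokenizer.apply_chat_template(messages, tokenize=False) renders plain
# "role: content" lines, e.g. for
#   [{"role": "system", "content": "Be helpful."}, {"role": "user", "content": "Hi"}]
# the rendered text is roughly "system: Be helpful.\nuser: Hi\n".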
@@ -105,83 +111,57 @@ class AccessibilityEnv(BaseEnv):
)
raise RuntimeError(f"Failed to load tokenizer: {e}") from e
self.dataset = [
{
"id": "ex001",
"html": "<h1>Welcome</h1><img src='image.jpg'>",
"issues_to_fix": ["missing_alt_text"],
},
{
"id": "ex002",
"html": "<label>Name</label><input type='text' name='username'>",
"issues_to_fix": ["missing_for_attribute_on_label"],
},
]
# Load dataset from file
self.dataset: List[Dict] = []
env_script_dir = os.path.dirname(os.path.abspath(__file__))
full_dataset_path = os.path.join(env_script_dir, self.config.dataset_path)
print(f"[{self.name}] Attempting to load dataset from: {full_dataset_path}")
try:
with open(full_dataset_path, "r", encoding="utf-8") as f:
for line in f:
if line.strip(): # Ensure line is not empty
self.dataset.append(json.loads(line))
if not self.dataset:
raise ValueError(
"Dataset file was empty or contained no valid JSON lines."
)
except FileNotFoundError:
print(f"[{self.name}] ERROR: Dataset file not found at {full_dataset_path}")
raise
except json.JSONDecodeError as e:
print(
f"[{self.name}] ERROR: Failed to decode JSON from {full_dataset_path}. Error: {e}"
)
raise
except Exception as e:
print(
f"[{self.name}] ERROR: An unexpected error occurred while loading dataset: {e}"
)
raise
self.iter = 0
print(f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items.")
print(
f"""[{self.name}] Setup complete. Loaded {len(self.dataset)}
items. Initialized {len(self.accessibility_rules)} accessibility rules."""
)
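# Illustrative dataset format (assumed from the previously hard-coded examples and the
# issue keys in accessibility_rules.py; the real data/accessibility_dataset.jsonl may differ).
# Each line is one JSON object:
#   {"id": "ex001", "html": "<h1>Welcome</h1><img src='image.jpg'>", "issues_to_fix": ["missing_alt_text"]}
#   {"id": "ex002", "html": "<label>Name</label><input type='text' name='username'>", "issues_to_fix": ["missing_label_for"]}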
async def get_next_item(self) -> Optional[Item]:
if self.iter >= len(self.dataset):
if (
self.iter >= self.config.total_steps
): # Stop after total_steps for 'process'
if self.iter >= self.config.total_steps:
return None
# Potentially loop dataset or handle running out of unique items
# For hackathon, just stopping might be fine if dataset is small
# and total_steps is matched to dataset size.
# self.iter = 0 # To loop
print(f"[{self.name}] Reached end of dataset or total_steps.")
return None
item_data = self.dataset[self.iter]
self.iter += 1
# Format item_data into the 'Item' structure Atropos expects
# Typically (prompt_messages_tuple, gold_answer_or_metadata_tuple)
# Example:
# user_prompt = {"role": "user", "content": f"Make this HTML accessible: {item_data['html_snippet']}"}
# system_prompt_content = "You are an AI assistant specializing in web accessibility. Modify the given
# HTML to meet WCAG AA standards. Output only the modified HTML."
# system_prompt = {"role": "system", "content": system_prompt_content}
# prompt_messages = (system_prompt, user_prompt) # This needs to be a tuple of dicts
# messages_for_item = tuple(frozenset(p.items()) for p in prompt_messages) # Atropos often expects this format
# return (messages_for_item, item_data.get('expected_outcome_or_id')) # Second part is for scoring reference
# Simpler start for prompt:
# prompt = (
# (
# {
# "role": "system",
# "content": "You are an AI assistant. Given HTML, make it more accessible.",
# },
# ),
# ({"role": "user", "content": f"Original HTML: {item_data['html']}"},),
# )
# This prompt structure might need adjustment based on how Atropos and the LLM API expect it.
# The gsm8k example has:
# user_message = {"role": "user", "content": item["question"]}
# chat_completions = await self.server.chat_completion(
# messages=[{"role": "system", "content": system_prompt}, user_message], ...
# So a list of dicts is passed to chat_completion.
# The 'Item' type for get_next_item is often a tuple: ( (message_part_1, message_part_2, ...),
# metadata_for_scoring )
# where each message_part is often a frozenset of items from a dict. This is a bit complex.
# Let's start with a simple string prompt and adapt.
# For now, let's assume item is (prompt_string, metadata_for_scoring)
# The `collect_trajectories` in coding_server.py takes `item: Item`
# and then accesses `item[0][0]` which implies item is nested.
# `prompt = tuple([frozenset({"role": "user", "content": next_item["description"]}.items())])`
# `return (prompt, answer)`
# So, first element of item is a tuple of frozensets.
# Let's simplify for now and refine based on Atropos internals if needed.
# We'll construct the messages list directly in collect_trajectories.
# So get_next_item can return the raw data needed.
return item_data # This will be like {"html": "...", "id": "..."}
return item_data
async def collect_trajectories(
self, item: Item
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
# 'item_data' here is what get_next_item returned.
original_html = item["html"]
system_message_content = (
"You are an expert web developer specializing in accessibility. "
@@ -198,262 +178,180 @@ class AccessibilityEnv(BaseEnv):
{"role": "user", "content": user_message_content},
]
try:
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size, # Number of completions
# `max_tokens` here is for the *completion* part, not the whole context.
# Your Llama API example used 256. Adjust as needed for HTML output.
max_tokens=1024, # Max tokens for the LLM's response
# temperature=0.7, # Optional: adjust for creativity vs. determinism
# model=self.server_configs[0].model_name # This should be picked up automatically from server_configs
# by the self.server object.
)
except tenacity.RetryError as retry_err: # Specifically catch RetryError
print(
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
)
print(f"[{self.name}] TENACITY RETRY ERROR during chat_completion call:")
print(f"[{self.name}] RetryError Details: {retry_err}")
# ... and the response details if available on 'e' ...
original_exception = None
if retry_err.last_attempt:
if retry_err.last_attempt.failed:
original_exception = retry_err.last_attempt.exception()
print(
f"[{self.name}] Last attempt failed. Original exception that caused retries:"
)
print(f"[{self.name}] Type: {type(original_exception)}")
print(
f"[{self.name}] Args: {original_exception.args if original_exception else 'N/A'}"
)
print(
f"[{self.name}] Full Str: {str(original_exception)}"
) # More direct string representation
else:
# This case is unusual for a RetryError due to failure
print(
f"""[{self.name}] Last attempt recorded but did not 'fail'.
Result: {retry_err.last_attempt.result()}"""
)
else:
print(
f"""[{self.name}] Could not get 'last_attempt' details from
RetryError object. Raw RetryError: {retry_err}"""
)
# Now, if we have the original_exception, try to get more details (like HTTP response)
if original_exception:
# Check if the original exception is an OpenAI/HTTPX style error
# by looking for a 'response' attribute.
if (
hasattr(original_exception, "response")
and original_exception.response is not None
):
response_obj = original_exception.response
status_code_text = "Status code N/A"
response_content_text = "Response content N/A"
if hasattr(response_obj, "status_code"):
status_code_text = str(response_obj.status_code)
print(
f"[{self.name}] Underlying API Response Status Code: {status_code_text}"
)
# Try to get JSON content first (common for API errors)
if hasattr(response_obj, "json") and callable(response_obj.json):
try:
response_json_parsed = (
response_obj.json()
) # Note: this might need to be awaited if response_obj.json is async
# but typically in an exception, it's already processed.
print(
f"[{self.name}] Underlying API Response JSON: {response_json_parsed}"
)
except Exception as json_e_inner:
print(
f"[{self.name}] Could not parse underlying API response as JSON: {json_e_inner}"
)
# Fallback to text if JSON parsing fails
if hasattr(response_obj, "text"):
response_content_text = response_obj.text
print(
f"[{self.name}] Underlying API Response Text: {response_content_text}"
)
elif hasattr(response_obj, "content"): # often bytes
try:
response_content_text = (
response_obj.content.decode()
)
print(
f"""[{self.name}] Underlying API Response
Content (decoded): {response_content_text}"""
)
except Exception:
response_content_text = str(response_obj.content)
print(
f"""[{self.name}] Underlying API Response Content
(raw bytes as str): {response_content_text}"""
)
# If no json() method, try .text or .content directly
elif hasattr(response_obj, "text"):
response_content_text = response_obj.text
print(
f"[{self.name}] Underlying API Response Text: {response_content_text}"
)
elif hasattr(response_obj, "content"):
try:
response_content_text = response_obj.content.decode()
print(
f"[{self.name}] Underlying API Response Content (decoded): {response_content_text}"
)
except Exception:
response_content_text = str(response_obj.content)
print(
f"""[{self.name}] Underlying API Response Content
(raw bytes as str): {response_content_text}"""
)
print(
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
)
print(
f"[{self.name}] Messages that were sent during the attempt resulting in RetryError: {messages}"
)
return None, []
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=1024,
)
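# Illustrative note (not part of this commit): with n=self.config.group_size the
# OpenAI-style response carries one candidate rewrite per choice, so
# chat_completions.choices[0].message.content holds the first modified-HTML attempt.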
to_score_inputs = []
for choice in chat_completions.choices:
llm_response_content = choice.message.content
# The 'messages' to store for scoring/tokenization should represent the full exchange
# that led to this specific llm_response_content.
# This includes the original system and user messages, and the assistant's response.
full_exchange_messages = messages + [
{"role": "assistant", "content": llm_response_content}
]
to_score_inputs.append(
{
"full_exchange_messages": full_exchange_messages, # For tokenization
"llm_modified_html": llm_response_content, # For direct scoring
"original_html_info": item, # To know what to check against
}
)
# The `score` method in Atropos expects a list where each element typically is
# (messages_tuple_for_tokenization, original_item_metadata_for_scoring_logic)
# We need to adapt `to_score_inputs` to what `self.score` will expect.
# Let's define that `self.score` will take this list of dicts directly.
# The `collect_trajectories` from the blog post returns `to_postprocess, to_backlog`
# where `to_postprocess` is the output of `self.score`.
if chat_completions is not None:
for choice in chat_completions.choices:
llm_response_content = choice.message.content
full_exchange_messages = messages + [
{"role": "assistant", "content": llm_response_content}
]
to_score_inputs.append(
{
"full_exchange_messages": full_exchange_messages,
"llm_modified_html": llm_response_content,
"original_html_info": item,
}
)
scored_data_group = await self.score(to_score_inputs)
return scored_data_group, [] # Assuming no backlog for now
async def score(
self, rollout_group_data: List[dict]
) -> Optional[ScoredDataGroup]: # Return type is still ScoredDataGroup
async def score(self, rollout_group_data: List[dict]) -> Optional[ScoredDataGroup]:
print(f"[{self.name}] Scoring {len(rollout_group_data)} rollouts...")
all_tokens: List[List[int]] = []
all_masks: List[List[int]] = []
all_scores: List[float] = []
# For TypedDict, optional fields that are not provided will simply not be keys in the dictionary.
# However, if we want to include them as None, we can. Let's prepare for that.
all_advantages: Optional[List[List[float]]] = (
None # Or initialize as [] if you might populate it
)
all_ref_logprobs: Optional[List[List[float]]] = None # Or initialize as []
all_messages_for_trainer: Optional[List[List[Dict]]] = (
None # Assuming Message is also a dict-like structure or TypedDict
)
# Initialize lists to store data for all successfully processed items in the batch
final_tokens_batch: List[List[int]] = []
final_masks_batch: List[List[int]] = []
final_scores_batch: List[float] = []
final_concatenated_dialogues_batch: List[str] = []
# Optional fields for ScoredDataGroup, will remain None for this basic setup
all_advantages: Optional[List[List[float]]] = None
all_ref_logprobs: Optional[List[List[float]]] = None
for data_item in rollout_group_data:
llm_html = data_item["llm_modified_html"]
llm_html_str = data_item["llm_modified_html"]
original_info = data_item["original_html_info"]
full_exchange_messages_list_of_dicts = data_item[
"full_exchange_messages"
] # This is List[Dict[str, str]]
current_score = -1.0
if "<img" in original_info["html"] and "alt=" in llm_html:
current_score = 1.0
elif "<label>" in original_info["html"] and "for=" in llm_html:
current_score = 1.0
current_item_score = 0.0
num_issues_actually_fixed = 0
issues_expected_to_fix = original_info.get("issues_to_fix", [])
num_issues_targeted = len(issues_expected_to_fix)
soup: Optional[BeautifulSoup] = None
can_proceed_with_rule_checks = False
try:
# Ensure self.tokenizer is initialized in __init__ or setup
if not hasattr(self, "tokenizer") or self.tokenizer is None:
print(f"[{self.name}] Error: Tokenizer not initialized.")
# Attempt to initialize it here if it makes sense, or ensure it's done in setup()
# from transformers import AutoTokenizer
# self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name, trust_remote_code=True)
# This is a fallback, better to ensure it's in setup()
# For now, let's assume it's there. If not, this will fail earlier or be caught by linter.
pass # Assuming tokenizer is initialized
soup = BeautifulSoup(llm_html_str, "lxml")
can_proceed_with_rule_checks = True
except Exception as e:
print(
f"[{self.name}] Item {original_info.get('id', 'N/A')}: Could not parse LLM output as HTML: {e}"
)
if can_proceed_with_rule_checks and soup is not None:
for rule_instance in self.accessibility_rules:
if rule_instance.issue_key in issues_expected_to_fix:
try:
if rule_instance.check(soup, original_info):
num_issues_actually_fixed += 1
print(
f"""[{self.name}] Item {original_info.get('id', 'N/A')}:
Rule '{rule_instance.issue_key}' PASSED."""
)
else:
print(
f"""[{self.name}] Item {original_info.get('id', 'N/A')}:
Rule '{rule_instance.issue_key}' FAILED."""
)
except Exception as rule_e:
print(
f"""[{self.name}] Item {original_info.get('id', 'N/A')}:
Error executing rule '{rule_instance.issue_key}': {rule_e}"""
)
# Determine score based on fixes and parseability
if num_issues_targeted > 0:
if not can_proceed_with_rule_checks: # Parsing failed
current_item_score = (
-1.0 * num_issues_targeted
) # Penalize per targeted issue if unparseable
elif num_issues_actually_fixed == num_issues_targeted:
current_item_score = 1.0 # All targeted issues fixed
elif (
num_issues_actually_fixed > 0
): # Some, but not all, targeted issues fixed
current_item_score = 0.8 * (
num_issues_actually_fixed / num_issues_targeted
)
else: # Parseable, but no targeted issues fixed
current_item_score = -0.5
else: # No issues were targeted for this item (e.g., input was considered good by dataset design)
if (
not can_proceed_with_rule_checks
): # LLM made a good input unparseable
current_item_score = -1.0
else: # Parseable, and no issues were targeted (good input remained good)
current_item_score = 0.0 # Neutral score
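# Worked example of the scheme above (illustrative): an item targeting 2 issues
# whose parseable output fixes only 1 scores 0.8 * (1/2) = 0.4; fixing both scores 1.0;
# fixing none scores -0.5; unparseable output scores -1.0 * 2 = -2.0.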
# Tokenization
try:
if not self.tokenizer:
raise ValueError("Tokenizer not initialized.")
tokenized_output = tokenize_for_trainer(
self.tokenizer, # Make sure self.tokenizer is loaded, e.g., in setup()
data_item["full_exchange_messages"],
self.tokenizer, full_exchange_messages_list_of_dicts
)
except Exception as e:
print(f"[{self.name}] Error during tokenization: {e}. Skipping item.")
continue
print(
f"""[{self.name}] Error during tokenization for item
{original_info.get('id', 'N/A')}: {e}. Skipping this item."""
)
continue # Skip to the next data_item in rollout_group_data
if "tokens" not in tokenized_output or "masks" not in tokenized_output:
print(
f"[{self.name}] Warning: Tokenization did not return tokens/masks for an item. Skipping."
f"""[{self.name}] Tokenization did not produce 'tokens' or
'masks' for item {original_info.get('id', 'N/A')}. Skipping this item."""
)
continue
continue # Skip to the next data_item
all_tokens.append(tokenized_output["tokens"])
all_masks.append(tokenized_output["masks"])
all_scores.append(current_score)
# If we reach here, scoring and tokenization for the current item were successful
final_tokens_batch.append(tokenized_output["tokens"])
final_masks_batch.append(tokenized_output["masks"])
final_scores_batch.append(current_item_score)
# If you were to populate optional fields, you'd do it here. For example:
# if "advantages" in tokenized_output: # Fictional example
# if all_advantages is None: all_advantages = []
# all_advantages.append(tokenized_output["advantages"])
if self.config.include_messages:
formatted_message_log = "".join(
f"{msg_dict['role']}: {msg_dict['content']}\n"
for msg_dict in full_exchange_messages_list_of_dicts
)
final_concatenated_dialogues_batch.append(formatted_message_log.strip())
if not all_scores:
print(f"[{self.name}] No valid items to score, returning None.")
# After processing all items in rollout_group_data
if (
not final_scores_batch
): # If all items were skipped (e.g., due to tokenization errors)
print(
f"""[{self.name}] No valid items to include in ScoredDataGroup
after processing all rollouts, returning None."""
)
return None
# print(f"[{self.name}] Scoring complete. Scores: {all_scores}") # Already printed if successful below
# Construct the dictionary that conforms to ScoredDataGroup TypedDict
# Mandatory fields:
data_to_return: ScoredDataGroup = {
"tokens": all_tokens,
"masks": all_masks,
"scores": all_scores,
"tokens": final_tokens_batch,
"masks": final_masks_batch,
"scores": final_scores_batch,
"advantages": all_advantages,
"ref_logprobs": all_ref_logprobs,
"group_overrides": {},
"messages": all_messages_for_trainer,
"messages": (
final_concatenated_dialogues_batch
if self.config.include_messages and final_concatenated_dialogues_batch
else None
), # type: ignore[assignment]
"overrides": None,
}
print(
f"""[{self.name}] Scoring complete. Data to return (first score):
{data_to_return['scores'][0] if data_to_return['scores'] else 'N/A'}"""
f"[{self.name}] Scoring batch complete. Final scores for batch: {data_to_return['scores']}"
)
return data_to_return
async def evaluate(
self,
): # Optional, might not be needed for hackathon 'process' focus
):
print(f"[{self.name}] Evaluate method called (placeholder).")
# Implement evaluation logic if you have a separate test set and metrics
pass
# --- Helper methods for scoring ---
# def check_wcag_fixes(self, modified_html: str, original_item_info: dict) -> bool:
# # Placeholder for your actual WCAG checking logic
# # e.g., using BeautifulSoup to parse modified_html
# # and checking against `original_item_info['issues_to_fix']`
# # from bs4 import BeautifulSoup
# # soup = BeautifulSoup(modified_html, 'html.parser')
# # ... logic ...
# return False
if __name__ == "__main__":
# This makes your environment runnable with `python accessibility_env.py process`

View file

@@ -0,0 +1,137 @@
from abc import ABC, abstractmethod
from bs4 import BeautifulSoup, Tag # Ensure Tag is imported
class AccessibilityRule(ABC):
"""Abstract base class for an accessibility rule checker."""
@property
@abstractmethod
def issue_key(self) -> str:
"""A unique string key identifying the type of issue this rule checks for.
This should match the keys used in the 'issues_to_fix' list in your dataset.
"""
pass
@abstractmethod
def check(self, soup: BeautifulSoup, original_html_info: dict) -> bool:
"""
Checks the provided HTML (parsed as a BeautifulSoup soup) for compliance
with this specific accessibility rule.
Args:
soup: The BeautifulSoup object representing the LLM's modified HTML.
original_html_info: A dictionary containing information about the original
HTML snippet, potentially including the original HTML string
if needed for context by some rules.
Returns:
True if the HTML passes this rule (i.e., the issue is fixed or wasn't present).
False if the HTML fails this rule (i.e., the issue persists or was introduced).
"""
pass
class MissingAltTextRule(AccessibilityRule):
@property
def issue_key(self) -> str:
return "missing_alt_text"
def check(self, soup: BeautifulSoup, original_html_info: dict) -> bool:
# Check if images were relevant in the first place for this item based on original_html_info
# This helps decide if an absence of images now is a failure or if the check is moot.
original_had_images = "<img" in original_html_info.get("html", "").lower()
img_tags = soup.find_all("img")
if original_had_images and not img_tags:
# Relevant images were present in original, but LLM removed all img tags.
return False
if not original_had_images and not img_tags:
# No images in original, no images in LLM output. Rule passes by default for this item.
return True
if (
not img_tags and original_had_images
): # Should be caught by the first check, but for clarity
return False
# If there are image tags, all must have valid alt text
for img_element in img_tags:
if isinstance(img_element, Tag):
alt_value = img_element.get("alt")
if not (isinstance(alt_value, str) and alt_value.strip() != ""):
return False # Found an image without a valid alt attribute
else:
# This case should ideally not happen if soup.find_all('img') works as expected
return False
return True # All images found have valid alt text
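# Illustrative markup (assumed, not from the dataset): this rule passes
#   <img src='image.jpg' alt='Company logo'>
# and fails
#   <img src='image.jpg'>  or  <img src='image.jpg' alt=''>
# because a missing or whitespace-only alt attribute is rejected by the strip() check above.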
class LabelAssociationRule(AccessibilityRule):
@property
def issue_key(self) -> str:
return "missing_label_for" # Or "label_association" if you prefer
def check(self, soup: BeautifulSoup, original_html_info: dict) -> bool:
original_had_labels = "<label" in original_html_info.get("html", "").lower()
label_tags = soup.find_all("label")
if original_had_labels and not label_tags:
return False
if not original_had_labels and not label_tags:
return True # Rule passes by default
if not label_tags and original_had_labels:
return False
for label_element in label_tags:
if isinstance(label_element, Tag):
for_value = label_element.get("for")
has_explicit_for_association = False
if isinstance(for_value, str):
for_attr_str = for_value.strip()
if for_attr_str != "":
# Check if an element with this ID exists and is an appropriate input type
target_element = soup.find(id=for_attr_str)
if target_element and target_element.name in [
"input",
"textarea",
"select",
"button",
"meter",
"output",
"progress",
"selectlist",
]:
has_explicit_for_association = True
# Check for implicit association (label contains the input element)
# Note: This is a simpler check. True implicit association has more nuance.
contains_input_directly = False
# Only direct children that are form controls for simplicity here
for child in label_element.children:
if isinstance(child, Tag) and child.name in [
"input",
"textarea",
"select",
"button",
]: # etc.
contains_input_directly = True
break
if not (has_explicit_for_association or contains_input_directly):
return False # Found a label not correctly associated
else:
return False
return True # All labels are correctly associated
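# Illustrative markup (assumed, not from the dataset): this rule passes explicit association,
#   <label for='username'>Name</label><input type='text' id='username'>
# and direct-child implicit association,
#   <label>Name <input type='text' name='username'></label>
# but fails a bare <label>Name</label> with no matching 'for' target and no nested control.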
# --- Add more rule classes here as needed ---
# e.g., class AriaAttributeRule(AccessibilityRule): ...
# e.g., class ContrastRule(AccessibilityRule): ...
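# Hedged sketch of one possible additional rule (hypothetical; it would only take effect
# if registered in AccessibilityEnv.accessibility_rules and its issue_key appeared in a
# dataset item's 'issues_to_fix' list):
class DocumentLanguageRule(AccessibilityRule):
    @property
    def issue_key(self) -> str:
        return "missing_document_language"

    def check(self, soup: BeautifulSoup, original_html_info: dict) -> bool:
        # Snippets without an <html> element are out of scope for this rule.
        html_tag = soup.find("html")
        if not isinstance(html_tag, Tag):
            return True
        lang_value = html_tag.get("lang")
        # Pass only if a non-empty lang attribute is declared, e.g. <html lang="en">.
        return isinstance(lang_value, str) and lang_value.strip() != ""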

View file

@@ -1 +1,2 @@
beautifulsoup4
beautifulsoup4>=4.0.0
lxml>=4.0.0
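# lxml provides the parser used by BeautifulSoup(..., "lxml") in accessibility_env.py's score()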