Integrate joshgarza's accessibility environment

- Merged accessibility environment from joshgarza:main
- Moved from environments/hack0/ to environments/community/
- Updated community README with detailed description of the accessibility auto-fixer
- Added note about missing dataset file
- Credited author @joshgarza with GitHub link

This commit is contained in:
Shannon Sands 2025-05-24 13:31:50 +10:00
parent 30ddc8a36d
commit 32cf5e3d42
7 changed files with 28 additions and 0 deletions


@@ -0,0 +1,85 @@
# Accessibility Auto-Fixer Environment for Atropos
**Team/Author:** Accessibility Bot / Josh Garza
**Track:** Objective (WCAG rules are specific and rule-based)
**wandb run:** https://wandb.ai/joshgarza-n-a/atropos/runs/tqpiiofa?nw=nwuserjoshgarza
## Environment Design and Motivation
### Problem Addressed
Web accessibility is crucial for ensuring that websites and web applications are usable by everyone, including people with disabilities. Manually auditing and fixing HTML to meet Web Content Accessibility Guidelines (WCAG) is time-consuming and requires specialized knowledge. This Atropos environment fine-tunes an LLM to automatically identify and apply minimal, correct fixes to HTML snippets to improve their WCAG compliance.
### Why This Is Important
Automating accessibility improvements reduces effort and cost, leading to more inclusive web experiences. A fine-tuned model can serve as a developer assistant, batch-processor for large codebases, or educational tool.
### How the Environment Works
1. **Input:**
   - HTML snippets from `data/accessibility_dataset.jsonl` (note: the dataset file is not included in this commit)
- Each snippet is tagged with WCAG issues to fix (e.g. `missing_alt_text`, `missing_label_for`)
2. **LLM Interaction:**
- Prompt the model (e.g. GPT-3.5-turbo) to output only the corrected HTML
3. **Scoring (Rule-Based):**
- Define `AccessibilityRule` classes (e.g. `MissingAltTextRule`, `LabelAssociationRule`, `LinkHasTextRule`)
   - Parse the LLM's output with BeautifulSoup
- Check each issue in `issues_to_fix` against the corresponding rule
- Assign a score:
- **+1.0** All targeted issues fixed correctly
     - **0.0–0.8** Some but not all targeted issues fixed (scaled by the fraction fixed)
     - **−0.5** Parseable HTML, but none of the targeted issues fixed
     - **−1.0** Unparseable HTML or regressions on targeted issues
4. **Output:**
- Rollouts compatible with Atropos (tokenized prompts/responses, masks, scores) for RL training
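The steps above assume one JSON object per dataset line. A hypothetical entry (the field names `id`, `html`, and `issues_to_fix` match what the environment code reads; the values are invented) can be parsed with the standard library:

```python
import json

# One hypothetical line of accessibility_dataset.jsonl (values invented;
# field names match those consumed by the environment's scorer)
line = '{"id": "img_001", "html": "<img src=\\"logo.png\\">", "issues_to_fix": ["missing_alt_text"]}'

item = json.loads(line)
print(item["issues_to_fix"])  # the issue keys checked by the rules
```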
### MVP: Targeted WCAG Criteria
1. **Images (`<img>`):** missing or empty `alt` attributes (WCAG 1.1.1)
2. **Form labels:** improper `<label for="…">` associations (WCAG 1.3.1, 3.3.2, 4.1.2)
3. **Links (`<a>`):** lacking discernible text or accessible name (`aria-label`/`aria-labelledby`) (WCAG 2.4.4, 4.1.2)
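Illustrative before/after snippets for the three criteria (invented examples, not drawn from the dataset):

```html
<!-- Before: missing alt text, orphaned label, link with no discernible text -->
<img src="chart.png">
<label>Email</label> <input type="email" id="email">
<a href="/settings"><i class="icon-gear"></i></a>

<!-- After: minimal fixes -->
<img src="chart.png" alt="Quarterly sales chart">
<label for="email">Email</label> <input type="email" id="email">
<a href="/settings" aria-label="Settings"><i class="icon-gear"></i></a>
```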
### Potential Impact
A well-trained model could catch and fix common accessibility errors early, streamlining development and improving inclusivity.
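The scoring rubric, including the per-issue penalty for unparseable output, can be sketched as a small pure function. This is a simplified reconstruction of the environment's scoring logic, not the code itself:

```python
def score_rollout(num_targeted: int, num_fixed: int, parseable: bool) -> float:
    """Simplified sketch of the scoring rubric described above."""
    if num_targeted == 0:
        # No issues targeted: a good input should simply stay parseable.
        return -1.0 if not parseable else 0.0
    if not parseable:
        return -1.0 * num_targeted  # penalize per targeted issue
    if num_fixed == num_targeted:
        return 1.0                  # all targeted issues fixed
    if num_fixed > 0:
        return 0.8 * (num_fixed / num_targeted)  # partial credit
    return -0.5                     # parseable, but nothing fixed
```

For example, fixing one of two targeted issues scores `0.8 * 0.5 = 0.4`.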
---
## Quickstart Documentation
### 1. Prerequisites
- **Python 3.10+**
- **OpenAI API Key** (export as `OPENAI_API_KEY`):
```bash
export OPENAI_API_KEY="sk-YourActualOpenAIKeyHere"
```
### 2. Setup
1. **Clone & enter environment directory**
```bash
cd environments/community/your_env_folder_name/
```
2. **Install dependencies**
```bash
pip install -r requirements.txt
pip install lxml
```
3. **Ensure Atropos core is installed**
```bash
# From the Atropos root:
pip install -e ".[dev]"
```
### 3. Running the Environment (process mode)
```bash
python -m environments.community.your_env_folder_name.accessibility_env process \
  --env.data_path_to_save_groups environments/community/your_env_folder_name/rollouts.jsonl \
--env.dataset_path data/accessibility_dataset.jsonl \
--env.total_steps 6 \
--env.group_size 1 \
--openai.model_name "gpt-3.5-turbo" \
--openai.api_key "$OPENAI_API_KEY"
```
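Once the run finishes, the saved groups can be inspected with a few lines of Python. This is a sketch: it assumes each JSONL line carries a top-level `scores` list (matching the `ScoredDataGroup` field name); the actual schema is whatever Atropos' process mode writes.

```python
import json

# Sketch: average the scores saved by the run above. The "scores" field
# per line is an assumption based on the ScoredDataGroup structure.
def mean_score(path: str) -> float:
    scores: list[float] = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                scores.extend(json.loads(line).get("scores") or [])
    return sum(scores) / len(scores) if scores else 0.0
```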


@@ -0,0 +1,358 @@
import json
import os
from typing import Dict, List, Optional, Tuple
from bs4 import BeautifulSoup
from pydantic import Field
from transformers.models.auto.tokenization_auto import AutoTokenizer
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from .accessibility_rules import (
AccessibilityRule,
LabelAssociationRule,
MissingAltTextRule,
)
class AccessibilityEnvConfig(BaseEnvConfig):
dataset_path: str = Field(
default="data/accessibility_dataset.jsonl", # Default relative path
description="Path to the JSONL file containing the accessibility dataset.",
)
class AccessibilityEnv(BaseEnv):
config: AccessibilityEnvConfig
name = "accessibility_env"
def __init__(
self,
config: AccessibilityEnvConfig,
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.tokenizer = None
# Initialize your list of rule instances
self.accessibility_rules: List[AccessibilityRule] = [
MissingAltTextRule(),
LabelAssociationRule(),
]
# For quick lookup if needed, though iterating self.accessibility_rules is fine
self.rules_by_key = {rule.issue_key: rule for rule in self.accessibility_rules}
@classmethod
def config_init(cls) -> Tuple[AccessibilityEnvConfig, List[APIServerConfig]]:
current_dataset_size = 10
env_config = AccessibilityEnvConfig(
tokenizer_name="gpt2",
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=current_dataset_size,
batch_size=1,
steps_per_eval=current_dataset_size,
max_token_length=2048,
wandb_name="accessibility_openai_default_dev",
)
openai_api_key_from_env = os.environ.get("OPENAI_API_KEY")
if not openai_api_key_from_env:
print(
"WARNING (from config_init): OPENAI_API_KEY environment variable not set for default config!"
)
server_configs = [
APIServerConfig(
model_name="gpt-3.5-turbo",
api_key=openai_api_key_from_env,
)
]
return env_config, server_configs
async def setup(self):
print(f"[{self.name}] Setting up environment...")
try:
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.tokenizer_name, trust_remote_code=True
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.tokenizer.chat_template is None:
self.tokenizer.chat_template = (
"{% for message in messages %}"
"{{ message['role'] + ': ' + message['content'] + '\\n' }}"
"{% endfor %}"
)
print(
f"[{self.name}] Set a default chat_template for tokenizer '{self.config.tokenizer_name}'."
)
print(
f"[{self.name}] Tokenizer '{self.config.tokenizer_name}' loaded successfully."
)
except Exception as e:
print(
f"[{self.name}] Error loading tokenizer '{self.config.tokenizer_name}': {e}"
)
raise RuntimeError(f"Failed to load tokenizer: {e}") from e
# Load dataset from file
self.dataset: List[Dict] = []
env_script_dir = os.path.dirname(os.path.abspath(__file__))
full_dataset_path = os.path.join(env_script_dir, self.config.dataset_path)
print(f"[{self.name}] Attempting to load dataset from: {full_dataset_path}")
try:
with open(full_dataset_path, "r", encoding="utf-8") as f:
for line in f:
if line.strip(): # Ensure line is not empty
self.dataset.append(json.loads(line))
if not self.dataset:
raise FileNotFoundError(
"Dataset file was empty or contained no valid JSON lines."
)
except FileNotFoundError:
print(f"[{self.name}] ERROR: Dataset file not found at {full_dataset_path}")
raise
except json.JSONDecodeError as e:
print(
f"[{self.name}] ERROR: Failed to decode JSON from {full_dataset_path}. Error: {e}"
)
raise
except Exception as e:
print(
f"[{self.name}] ERROR: An unexpected error occurred while loading dataset: {e}"
)
raise
self.iter = 0
print(
f"""[{self.name}] Setup complete. Loaded {len(self.dataset)}
items. Initialized {len(self.accessibility_rules)} accessibility rules."""
)
async def get_next_item(self) -> Optional[Item]:
if self.iter >= len(self.dataset):
if self.iter >= self.config.total_steps:
return None
print(f"[{self.name}] Reached end of dataset or total_steps.")
return None
item_data = self.dataset[self.iter]
self.iter += 1
return item_data
async def collect_trajectories(
self, item: Item
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
original_html = item["html"]
system_message_content = (
"You are an expert web developer specializing in accessibility. "
"Given the following HTML snippet, please make the minimal necessary modifications "
"to ensure it meets WCAG 2.1 AA standards for the issues present. "
"Output only the complete, modified HTML snippet. Do not include explanations unless explicitly asked."
)
user_message_content = (
f"Original HTML:\n```html\n{original_html}\n```\nModified HTML:"
)
messages = [
{"role": "system", "content": system_message_content},
{"role": "user", "content": user_message_content},
]
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=1024,
)
to_score_inputs = []
if chat_completions is not None:
for choice in chat_completions.choices:
llm_response_content = choice.message.content
full_exchange_messages = messages + [
{"role": "assistant", "content": llm_response_content}
]
to_score_inputs.append(
{
"full_exchange_messages": full_exchange_messages,
"llm_modified_html": llm_response_content,
"original_html_info": item,
}
)
scored_data_group = await self.score(to_score_inputs)
return scored_data_group, [] # Assuming no backlog for now
async def score(self, rollout_group_data: List[dict]) -> Optional[ScoredDataGroup]:
print(f"[{self.name}] Scoring {len(rollout_group_data)} rollouts...")
# Initialize lists to store data for all successfully processed items in the batch
final_tokens_batch: List[List[int]] = []
final_masks_batch: List[List[int]] = []
final_scores_batch: List[float] = []
final_concatenated_dialogues_batch: List[str] = []
# Optional fields for ScoredDataGroup, will remain None for this basic setup
all_advantages: Optional[List[List[float]]] = None
all_ref_logprobs: Optional[List[List[float]]] = None
for data_item in rollout_group_data:
llm_html_str = data_item["llm_modified_html"]
original_info = data_item["original_html_info"]
full_exchange_messages_list_of_dicts = data_item[
"full_exchange_messages"
] # This is List[Dict[str, str]]
current_item_score = 0.0
num_issues_actually_fixed = 0
issues_expected_to_fix = original_info.get("issues_to_fix", [])
num_issues_targeted = len(issues_expected_to_fix)
soup: Optional[BeautifulSoup] = None
can_proceed_with_rule_checks = False
try:
soup = BeautifulSoup(llm_html_str, "lxml")
can_proceed_with_rule_checks = True
except Exception as e:
print(
f"[{self.name}] Item {original_info.get('id', 'N/A')}: Could not parse LLM output as HTML: {e}"
)
if can_proceed_with_rule_checks and soup is not None:
for rule_instance in self.accessibility_rules:
if rule_instance.issue_key in issues_expected_to_fix:
try:
if rule_instance.check(soup, original_info):
num_issues_actually_fixed += 1
print(
f"""[{self.name}] Item {original_info.get('id', 'N/A')}:
Rule '{rule_instance.issue_key}' PASSED."""
)
else:
print(
f"""[{self.name}] Item {original_info.get('id', 'N/A')}:
Rule '{rule_instance.issue_key}' FAILED."""
)
except Exception as rule_e:
print(
f"""[{self.name}] Item {original_info.get('id', 'N/A')}:
Error executing rule '{rule_instance.issue_key}': {rule_e}"""
)
# Determine score based on fixes and parseability
if num_issues_targeted > 0:
if not can_proceed_with_rule_checks: # Parsing failed
current_item_score = (
-1.0 * num_issues_targeted
) # Penalize per targeted issue if unparseable
elif num_issues_actually_fixed == num_issues_targeted:
current_item_score = 1.0 # All targeted issues fixed
elif (
num_issues_actually_fixed > 0
): # Some, but not all, targeted issues fixed
current_item_score = 0.8 * (
num_issues_actually_fixed / num_issues_targeted
)
else: # Parseable, but no targeted issues fixed
current_item_score = -0.5
else: # No issues were targeted for this item (e.g., input was considered good by dataset design)
if (
not can_proceed_with_rule_checks
): # LLM made a good input unparseable
current_item_score = -1.0
else: # Parseable, and no issues were targeted (good input remained good)
current_item_score = 0.0 # Neutral score
# Tokenization
try:
if not self.tokenizer:
raise ValueError("Tokenizer not initialized.")
tokenized_output = tokenize_for_trainer(
self.tokenizer, full_exchange_messages_list_of_dicts
)
except Exception as e:
print(
f"""[{self.name}] Error during tokenization for item
{original_info.get('id', 'N/A')}: {e}. Skipping this item."""
)
continue # Skip to the next data_item in rollout_group_data
if "tokens" not in tokenized_output or "masks" not in tokenized_output:
print(
f"""[{self.name}] Tokenization did not produce 'tokens' or
'masks' for item {original_info.get('id', 'N/A')}. Skipping this item."""
)
continue # Skip to the next data_item
# If we reach here, scoring and tokenization for the current item were successful
final_tokens_batch.append(tokenized_output["tokens"])
final_masks_batch.append(tokenized_output["masks"])
final_scores_batch.append(current_item_score)
if self.config.include_messages:
formatted_message_log = "".join(
f"{msg_dict['role']}: {msg_dict['content']}\n"
for msg_dict in full_exchange_messages_list_of_dicts
)
final_concatenated_dialogues_batch.append(formatted_message_log.strip())
# After processing all items in rollout_group_data
if (
not final_scores_batch
): # If all items were skipped (e.g., due to tokenization errors)
print(
f"""[{self.name}] No valid items to include in ScoredDataGroup
after processing all rollouts, returning None."""
)
return None
data_to_return: ScoredDataGroup = {
"tokens": final_tokens_batch,
"masks": final_masks_batch,
"scores": final_scores_batch,
"advantages": all_advantages,
"ref_logprobs": all_ref_logprobs,
"group_overrides": {},
"messages": (
final_concatenated_dialogues_batch
if self.config.include_messages and final_concatenated_dialogues_batch
else None
), # type: ignore[assignment]
"overrides": None,
}
print(
f"[{self.name}] Scoring batch complete. Final scores for batch: {data_to_return['scores']}"
)
return data_to_return
async def evaluate(
self,
):
print(f"[{self.name}] Evaluate method called (placeholder).")
# Implement evaluation logic if you have a separate test set and metrics
pass
if __name__ == "__main__":
# This makes your environment runnable with `python accessibility_env.py process`
AccessibilityEnv.cli()


@@ -0,0 +1,137 @@
from abc import ABC, abstractmethod
from bs4 import BeautifulSoup, Tag # Ensure Tag is imported
class AccessibilityRule(ABC):
"""Abstract base class for an accessibility rule checker."""
@property
@abstractmethod
def issue_key(self) -> str:
"""A unique string key identifying the type of issue this rule checks for.
This should match the keys used in the 'issues_to_fix' list in your dataset.
"""
pass
@abstractmethod
def check(self, soup: BeautifulSoup, original_html_info: dict) -> bool:
"""
Checks the provided HTML (parsed as a BeautifulSoup soup) for compliance
with this specific accessibility rule.
Args:
soup: The BeautifulSoup object representing the LLM's modified HTML.
original_html_info: A dictionary containing information about the original
HTML snippet, potentially including the original HTML string
if needed for context by some rules.
Returns:
True if the HTML passes this rule (i.e., the issue is fixed or wasn't present).
False if the HTML fails this rule (i.e., the issue persists or was introduced).
"""
pass
class MissingAltTextRule(AccessibilityRule):
@property
def issue_key(self) -> str:
return "missing_alt_text"
def check(self, soup: BeautifulSoup, original_html_info: dict) -> bool:
# Check if images were relevant in the first place for this item based on original_html_info
# This helps decide if an absence of images now is a failure or if the check is moot.
original_had_images = "<img" in original_html_info.get("html", "").lower()
img_tags = soup.find_all("img")
if original_had_images and not img_tags:
# Relevant images were present in original, but LLM removed all img tags.
return False
if not original_had_images and not img_tags:
# No images in original, no images in LLM output. Rule passes by default for this item.
return True
if (
not img_tags and original_had_images
): # Should be caught by the first check, but for clarity
return False
# If there are image tags, all must have valid alt text
for img_element in img_tags:
if isinstance(img_element, Tag):
alt_value = img_element.get("alt")
if not (isinstance(alt_value, str) and alt_value.strip() != ""):
return False # Found an image without a valid alt attribute
else:
# This case should ideally not happen if soup.find_all('img') works as expected
return False
return True # All images found have valid alt text
class LabelAssociationRule(AccessibilityRule):
@property
def issue_key(self) -> str:
return "missing_label_for" # Or "label_association" if you prefer
def check(self, soup: BeautifulSoup, original_html_info: dict) -> bool:
original_had_labels = "<label" in original_html_info.get("html", "").lower()
label_tags = soup.find_all("label")
if original_had_labels and not label_tags:
return False
if not original_had_labels and not label_tags:
return True # Rule passes by default
if not label_tags and original_had_labels:
return False
for label_element in label_tags:
if isinstance(label_element, Tag):
for_value = label_element.get("for")
has_explicit_for_association = False
if isinstance(for_value, str):
for_attr_str = for_value.strip()
if for_attr_str != "":
# Check if an element with this ID exists and is an appropriate input type
target_element = soup.find(id=for_attr_str)
if target_element and target_element.name in [
"input",
"textarea",
"select",
"button",
"meter",
"output",
"progress",
]:
has_explicit_for_association = True
# Check for implicit association (label contains the input element)
# Note: This is a simpler check. True implicit association has more nuance.
contains_input_directly = False
# Only direct children that are form controls for simplicity here
for child in label_element.children:
if isinstance(child, Tag) and child.name in [
"input",
"textarea",
"select",
"button",
]: # etc.
contains_input_directly = True
break
if not (has_explicit_for_association or contains_input_directly):
return False # Found a label not correctly associated
else:
return False
return True # All labels are correctly associated
# --- Add more rule classes here as needed ---
# e.g., class AriaAttributeRule(AccessibilityRule): ...
# e.g., class ContrastRule(AccessibilityRule): ...
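For comparison, the alt-text check can also be sketched without BeautifulSoup using the standard library's `html.parser`. This is an illustrative reimplementation of the same idea, not part of the module:

```python
from html.parser import HTMLParser

class AltTextAudit(HTMLParser):
    """Collect <img> tags whose alt attribute is missing or blank (sketch only)."""
    def __init__(self):
        super().__init__()
        self.violations = []

    def handle_starttag(self, tag, attrs):
        if tag == "img":
            alt = dict(attrs).get("alt")
            if not (isinstance(alt, str) and alt.strip()):
                self.violations.append(dict(attrs).get("src", "<no src>"))

def missing_alt_passes(html: str) -> bool:
    auditor = AltTextAudit()
    auditor.feed(html)
    return not auditor.violations

print(missing_alt_passes('<img src="a.png" alt="A chart">'))  # True
print(missing_alt_passes('<img src="b.png" alt="">'))         # False
```

Self-closing tags (`<img … />`) are also covered, since `HTMLParser.handle_startendtag` delegates to `handle_starttag` by default.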


@@ -0,0 +1,2 @@
beautifulsoup4>=4.0.0
lxml>=4.0.0