diff --git a/environments/hack0/accessibility_env/accessibility_env.py b/environments/hack0/accessibility_env/accessibility_env.py
index d46aebbb..79dbd756 100644
--- a/environments/hack0/accessibility_env/accessibility_env.py
+++ b/environments/hack0/accessibility_env/accessibility_env.py
@@ -1,32 +1,37 @@
-# environments/hack0/accessibility_env/accessibility_env.py
-import os # For API keys, etc.
-from typing import Dict, List, Optional, Tuple # Common type hints, added Dict
+import json
+import os
+from typing import Dict, List, Optional, Tuple
-import tenacity
-
-# from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup
+from pydantic import Field
from transformers.models.auto.tokenization_auto import AutoTokenizer
-# Corrected imports for Atropos types
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
-from atroposlib.type_definitions import ( # GameHistory might not be needed yet, Item is common
- Item,
-)
+from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+from .accessibility_rules import (
+ AccessibilityRule,
+ LabelAssociationRule,
+ MissingAltTextRule,
+)
+
class AccessibilityEnvConfig(BaseEnvConfig):
- # Add any custom config fields specific to your env later
- pass
+ dataset_path: str = Field(
+ default="data/accessibility_dataset.jsonl", # Default relative path
+ description="Path to the JSONL file containing the accessibility dataset.",
+ )
class AccessibilityEnv(BaseEnv):
- name = "accessibility_env" # A unique name for your environment
+ config: AccessibilityEnvConfig
+ name = "accessibility_env"
def __init__(
self,
@@ -36,33 +41,44 @@ class AccessibilityEnv(BaseEnv):
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
- # Initialize any env-specific attributes here
+ self.tokenizer = None
+
+ # Initialize your list of rule instances
+ self.accessibility_rules: List[AccessibilityRule] = [
+ MissingAltTextRule(),
+ LabelAssociationRule(),
+ ]
+
+ # For quick lookup if needed, though iterating self.accessibility_rules is fine
+ self.rules_by_key = {rule.issue_key: rule for rule in self.accessibility_rules}
@classmethod
def config_init(cls) -> Tuple[AccessibilityEnvConfig, List[APIServerConfig]]:
+ current_dataset_size = 10
+
env_config = AccessibilityEnvConfig(
- tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
- group_size=1, # Smaller for faster testing initially
+ tokenizer_name="gpt2",
+ group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
- total_steps=3, # For process mode, number of items to generate
- batch_size=1, # Max items in a single call to score (related to group_size)
- steps_per_eval=5,
+ total_steps=current_dataset_size,
+ batch_size=1,
+ steps_per_eval=current_dataset_size,
max_token_length=2048,
- wandb_name="accessibility_llama_dev", # Dev run name
+ wandb_name="accessibility_openai_default_dev",
)
- llama_api_key = os.environ.get("LLAMA_API_KEY")
- if not llama_api_key:
- print("WARNING: LLAMA_API_KEY environment variable not set!")
+ openai_api_key_from_env = os.environ.get("OPENAI_API_KEY")
+ if not openai_api_key_from_env:
+ print(
+ "WARNING (from config_init): OPENAI_API_KEY environment variable not set for default config!"
+ )
server_configs = [
APIServerConfig(
- model_name="Llama-4-Maverick-17B-128E-Instruct-FP8",
- base_url="https://api.llama.com/v1", # <<<---- Llama API base URL
- api_key=llama_api_key,
- num_requests_for_eval=16,
- ),
+ model_name="gpt-3.5-turbo",
+ api_key=openai_api_key_from_env,
+ )
]
return env_config, server_configs
@@ -71,27 +87,17 @@ class AccessibilityEnv(BaseEnv):
try:
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.tokenizer_name, trust_remote_code=True
- ) # tokenizer_name is 'gpt2'
+ )
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
- # Set a default chat template if it's not already set
- # This is crucial for tokenizers like 'gpt2' that don't have one by default.
if self.tokenizer.chat_template is None:
- # A common, simple template. You might need to adjust based on how gpt-3.5-turbo expects chat.
- # For gpt-3.5-turbo, the actual formatting is handled by the API,
- # but for local tokenization for the trainer, we need *a* template.
- # A basic template for generic tokenization:
self.tokenizer.chat_template = (
"{% for message in messages %}"
"{{ message['role'] + ': ' + message['content'] + '\\n' }}"
"{% endfor %}"
)
- # Alternatively, for many models, a more structured Jinja template like
- # the Llama or ChatML one might be used if you were training with such a format.
- # For just getting token IDs for a generic model for RL, the simple one above might suffice.
- # Or, if tokenize_for_trainer is smart, it might just concatenate.
- # Let's check if a simpler approach is needed for tokenize_for_trainer.
+
print(
f"[{self.name}] Set a default chat_template for tokenizer '{self.config.tokenizer_name}'."
)
@@ -105,83 +111,57 @@ class AccessibilityEnv(BaseEnv):
)
raise RuntimeError(f"Failed to load tokenizer: {e}") from e
- self.dataset = [
- {
- "id": "ex001",
- "html": "<img ...>",
- "issues_to_fix": ["missing_alt_text"],
- },
- {
- "id": "ex002",
- "html": "<label ...>",
- "issues_to_fix": ["missing_for_attribute_on_label"],
- },
- ]
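The two inline examples above move out to a JSONL dataset file, which is not included in this diff. A minimal sketch of how such a file could be generated, matching the shape the loader below expects — the records, field values, and `data/` location are illustrative assumptions, not repo contents:

```python
import json
import os

# Hypothetical sample records: "html" is read by collect_trajectories,
# and "issues_to_fix" presumably maps to rule issue_key values.
sample_items = [
    {"id": "ex001", "html": "<img src='hero.png'>",
     "issues_to_fix": ["missing_alt_text"]},
    {"id": "ex002", "html": "<label>Email</label><input type='text' id='email'>",
     "issues_to_fix": ["missing_for_attribute_on_label"]},
]

os.makedirs("data", exist_ok=True)
with open("data/accessibility_dataset.jsonl", "w", encoding="utf-8") as f:
    for record in sample_items:
        f.write(json.dumps(record) + "\n")  # one JSON object per line
```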
+ # Load dataset from file
+ self.dataset: List[Dict] = []
+ env_script_dir = os.path.dirname(os.path.abspath(__file__))
+ full_dataset_path = os.path.join(env_script_dir, self.config.dataset_path)
+
+ print(f"[{self.name}] Attempting to load dataset from: {full_dataset_path}")
+ try:
+ with open(full_dataset_path, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip(): # Ensure line is not empty
+ self.dataset.append(json.loads(line))
+ if not self.dataset:
+ raise ValueError(
+ "Dataset file was empty or contained no valid JSON lines."
+ )
+ except FileNotFoundError:
+ print(f"[{self.name}] ERROR: Dataset file not found at {full_dataset_path}")
+ raise
+ except json.JSONDecodeError as e:
+ print(
+ f"[{self.name}] ERROR: Failed to decode JSON from {full_dataset_path}. Error: {e}"
+ )
+ raise
+ except Exception as e:
+ print(
+ f"[{self.name}] ERROR: An unexpected error occurred while loading dataset: {e}"
+ )
+ raise
+
self.iter = 0
- print(f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items.")
+ print(
+ f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items. "
+ f"Initialized {len(self.accessibility_rules)} accessibility rules."
+ )
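The `.accessibility_rules` module imported at the top is not part of this diff. A minimal sketch of the interface the env appears to rely on — each rule carries an `issue_key` plus some check hook; the `check(soup)` method and everything beyond the three imported names are assumptions:

```python
from abc import ABC, abstractmethod

from bs4 import BeautifulSoup


class AccessibilityRule(ABC):
    # Matches the "issues_to_fix" strings in the dataset.
    issue_key: str = ""

    @abstractmethod
    def check(self, soup: BeautifulSoup) -> bool:
        """Return True if the parsed HTML satisfies this rule."""


class MissingAltTextRule(AccessibilityRule):
    issue_key = "missing_alt_text"

    def check(self, soup: BeautifulSoup) -> bool:
        # Every <img> must carry a non-empty alt attribute.
        return all(img.get("alt") for img in soup.find_all("img"))


class LabelAssociationRule(AccessibilityRule):
    issue_key = "missing_for_attribute_on_label"

    def check(self, soup: BeautifulSoup) -> bool:
        # Every <label> must point at a control via its for attribute.
        return all(label.get("for") for label in soup.find_all("label"))
```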
async def get_next_item(self) -> Optional[Item]:
if self.iter >= len(self.dataset):
- if (
- self.iter >= self.config.total_steps
- ): # Stop after total_steps for 'process'
+ if self.iter >= self.config.total_steps:
return None
- # Potentially loop dataset or handle running out of unique items
- # For hackathon, just stopping might be fine if dataset is small
- # and total_steps is matched to dataset size.
- # self.iter = 0 # To loop
+
print(f"[{self.name}] Reached end of dataset or total_steps.")
return None
item_data = self.dataset[self.iter]
self.iter += 1
- # Format item_data into the 'Item' structure Atropos expects
- # Typically (prompt_messages_tuple, gold_answer_or_metadata_tuple)
- # Example:
- # user_prompt = {"role": "user", "content": f"Make this HTML accessible: {item_data['html_snippet']}"}
- # system_prompt_content = "You are an AI assistant specializing in web accessibility. Modify the given
- # HTML to meet WCAG AA standards. Output only the modified HTML."
- # system_prompt = {"role": "system", "content": system_prompt_content}
- # prompt_messages = (system_prompt, user_prompt) # This needs to be a tuple of dicts
- # messages_for_item = tuple(frozenset(p.items()) for p in prompt_messages) # Atropos often expects this format
- # return (messages_for_item, item_data.get('expected_outcome_or_id')) # Second part is for scoring reference
- # Simpler start for prompt:
- # prompt = (
- # (
- # {
- # "role": "system",
- # "content": "You are an AI assistant. Given HTML, make it more accessible.",
- # },
- # ),
- # ({"role": "user", "content": f"Original HTML: {item_data['html']}"},),
- # )
- # This prompt structure might need adjustment based on how Atropos and the LLM API expect it.
- # The gsm8k example has:
- # user_message = {"role": "user", "content": item["question"]}
- # chat_completions = await self.server.chat_completion(
- # messages=[{"role": "system", "content": system_prompt}, user_message], ...
- # So a list of dicts is passed to chat_completion.
- # The 'Item' type for get_next_item is often a tuple: ( (message_part_1, message_part_2, ...),
- # metadata_for_scoring )
- # where each message_part is often a frozenset of items from a dict. This is a bit complex.
- # Let's start with a simple string prompt and adapt.
- # For now, let's assume item is (prompt_string, metadata_for_scoring)
- # The `collect_trajectories` in coding_server.py takes `item: Item`
- # and then accesses `item[0][0]` which implies item is nested.
- # `prompt = tuple([frozenset({"role": "user", "content": next_item["description"]}.items())])`
- # `return (prompt, answer)`
- # So, first element of item is a tuple of frozensets.
-
- # Let's simplify for now and refine based on Atropos internals if needed.
- # We'll construct the messages list directly in collect_trajectories.
- # So get_next_item can return the raw data needed.
- return item_data # This will be like {"html": "...", "id": "..."}
+ return item_data
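Since `get_next_item` hands back the raw dataset dict, prompt formatting happens later; the fallback chat template from `setup()` only matters when tokenizing locally for the trainer. A quick illustration of what that template renders (assuming a transformers version with `apply_chat_template`; the messages are made up):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token
tok.chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] + ': ' + message['content'] + '\\n' }}"
    "{% endfor %}"
)
rendered = tok.apply_chat_template(
    [
        {"role": "system", "content": "Make the HTML accessible."},
        {"role": "user", "content": "Original HTML: <img src='x.png'>"},
    ],
    tokenize=False,
)
# rendered == "system: Make the HTML accessible.\nuser: Original HTML: <img src='x.png'>\n"
```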
async def collect_trajectories(
self, item: Item
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
- # 'item_data' here is what get_next_item returned.
original_html = item["html"]
system_message_content = (
"You are an expert web developer specializing in accessibility. "
@@ -198,262 +178,180 @@ class AccessibilityEnv(BaseEnv):
{"role": "user", "content": user_message_content},
]
- try:
- chat_completions = await self.server.chat_completion(
- messages=messages,
- n=self.config.group_size, # Number of completions
- # `max_tokens` here is for the *completion* part, not the whole context.
- # Your Llama API example used 256. Adjust as needed for HTML output.
- max_tokens=1024, # Max tokens for the LLM's response
- # temperature=0.7, # Optional: adjust for creativity vs. determinism
- # model=self.server_configs[0].model_name # This should be picked up automatically from server_configs
- # by the self.server object.
- )
- except tenacity.RetryError as retry_err: # Specifically catch RetryError
- print(
- "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
- )
- print(f"[{self.name}] TENACITY RETRY ERROR during chat_completion call:")
- print(f"[{self.name}] RetryError Details: {retry_err}")
- # ... and the response details if available on 'e' ...
- original_exception = None
- if retry_err.last_attempt:
- if retry_err.last_attempt.failed:
- original_exception = retry_err.last_attempt.exception()
- print(
- f"[{self.name}] Last attempt failed. Original exception that caused retries:"
- )
- print(f"[{self.name}] Type: {type(original_exception)}")
- print(
- f"[{self.name}] Args: {original_exception.args if original_exception else 'N/A'}"
- )
- print(
- f"[{self.name}] Full Str: {str(original_exception)}"
- ) # More direct string representation
- else:
- # This case is unusual for a RetryError due to failure
- print(
- f"""[{self.name}] Last attempt recorded but did not 'fail'.
- Result: {retry_err.last_attempt.result()}"""
- )
- else:
- print(
- f"""[{self.name}] Could not get 'last_attempt' details from
- RetryError object. Raw RetryError: {retry_err}"""
- )
-
- # Now, if we have the original_exception, try to get more details (like HTTP response)
- if original_exception:
- # Check if the original exception is an OpenAI/HTTPX style error
- # by looking for a 'response' attribute.
- if (
- hasattr(original_exception, "response")
- and original_exception.response is not None
- ):
- response_obj = original_exception.response
- status_code_text = "Status code N/A"
- response_content_text = "Response content N/A"
-
- if hasattr(response_obj, "status_code"):
- status_code_text = str(response_obj.status_code)
-
- print(
- f"[{self.name}] Underlying API Response Status Code: {status_code_text}"
- )
-
- # Try to get JSON content first (common for API errors)
- if hasattr(response_obj, "json") and callable(response_obj.json):
- try:
- response_json_parsed = (
- response_obj.json()
- ) # Note: this might need to be awaited if response_obj.json is async
- # but typically in an exception, it's already processed.
- print(
- f"[{self.name}] Underlying API Response JSON: {response_json_parsed}"
- )
- except Exception as json_e_inner:
- print(
- f"[{self.name}] Could not parse underlying API response as JSON: {json_e_inner}"
- )
- # Fallback to text if JSON parsing fails
- if hasattr(response_obj, "text"):
- response_content_text = response_obj.text
- print(
- f"[{self.name}] Underlying API Response Text: {response_content_text}"
- )
- elif hasattr(response_obj, "content"): # often bytes
- try:
- response_content_text = (
- response_obj.content.decode()
- )
- print(
- f"""[{self.name}] Underlying API Response
- Content (decoded): {response_content_text}"""
- )
- except Exception:
- response_content_text = str(response_obj.content)
- print(
- f"""[{self.name}] Underlying API Response Content
- (raw bytes as str): {response_content_text}"""
- )
- # If no json() method, try .text or .content directly
- elif hasattr(response_obj, "text"):
- response_content_text = response_obj.text
- print(
- f"[{self.name}] Underlying API Response Text: {response_content_text}"
- )
- elif hasattr(response_obj, "content"):
- try:
- response_content_text = response_obj.content.decode()
- print(
- f"[{self.name}] Underlying API Response Content (decoded): {response_content_text}"
- )
- except Exception:
- response_content_text = str(response_obj.content)
- print(
- f"""[{self.name}] Underlying API Response Content
- (raw bytes as str): {response_content_text}"""
- )
-
- print(
- "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
- )
- print(
- f"[{self.name}] Messages that were sent during the attempt resulting in RetryError: {messages}"
- )
- return None, []
+ chat_completions = await self.server.chat_completion(
+ messages=messages,
+ n=self.config.group_size,
+ max_tokens=1024,
+ )
to_score_inputs = []
- for choice in chat_completions.choices:
- llm_response_content = choice.message.content
- # The 'messages' to store for scoring/tokenization should represent the full exchange
- # that led to this specific llm_response_content.
- # This includes the original system and user messages, and the assistant's response.
- full_exchange_messages = messages + [
- {"role": "assistant", "content": llm_response_content}
- ]
- to_score_inputs.append(
- {
- "full_exchange_messages": full_exchange_messages, # For tokenization
- "llm_modified_html": llm_response_content, # For direct scoring
- "original_html_info": item, # To know what to check against
- }
- )
-
- # The `score` method in Atropos expects a list where each element typically is
- # (messages_tuple_for_tokenization, original_item_metadata_for_scoring_logic)
- # We need to adapt `to_score_inputs` to what `self.score` will expect.
- # Let's define that `self.score` will take this list of dicts directly.
- # The `collect_trajectories` from the blog post returns `to_postprocess, to_backlog`
- # where `to_postprocess` is the output of `self.score`.
+ if chat_completions is not None:
+ for choice in chat_completions.choices:
+ llm_response_content = choice.message.content
+ full_exchange_messages = messages + [
+ {"role": "assistant", "content": llm_response_content}
+ ]
+ to_score_inputs.append(
+ {
+ "full_exchange_messages": full_exchange_messages,
+ "llm_modified_html": llm_response_content,
+ "original_html_info": item,
+ }
+ )
scored_data_group = await self.score(to_score_inputs)
return scored_data_group, [] # Assuming no backlog for now
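How each rollout is turned into a score is cut off at the end of this diff. One plausible shape for the per-rollout reward, reusing the hypothetical `check(soup)` interface sketched earlier and mapping the fraction of satisfied rules onto [-1.0, 1.0]:

```python
from bs4 import BeautifulSoup


def rule_based_score(llm_html: str, rules) -> float:
    # Hypothetical helper: parse the model's HTML output and score it
    # by the fraction of accessibility rules it satisfies.
    if not rules:
        return 0.0
    soup = BeautifulSoup(llm_html, "html.parser")
    passed = sum(1 for rule in rules if rule.check(soup))
    return 2.0 * passed / len(rules) - 1.0
```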
- async def score(
- self, rollout_group_data: List[dict]
- ) -> Optional[ScoredDataGroup]: # Return type is still ScoredDataGroup
+ async def score(self, rollout_group_data: List[dict]) -> Optional[ScoredDataGroup]:
print(f"[{self.name}] Scoring {len(rollout_group_data)} rollouts...")
- all_tokens: List[List[int]] = []
- all_masks: List[List[int]] = []
- all_scores: List[float] = []
- # For TypedDict, optional fields that are not provided will simply not be keys in the dictionary.
- # However, if we want to include them as None, we can. Let's prepare for that.
- all_advantages: Optional[List[List[float]]] = (
- None # Or initialize as [] if you might populate it
- )
- all_ref_logprobs: Optional[List[List[float]]] = None # Or initialize as []
- all_messages_for_trainer: Optional[List[List[Dict]]] = (
- None # Assuming Message is also a dict-like structure or TypedDict
- )
+ # Initialize lists to store data for all successfully processed items in the batch
+ final_tokens_batch: List[List[int]] = []
+ final_masks_batch: List[List[int]] = []
+ final_scores_batch: List[float] = []
+ final_concatenated_dialogues_batch: List[str] = []
+
+ # Optional fields for ScoredDataGroup; they remain None in this basic setup
+ all_advantages: Optional[List[List[float]]] = None
+ all_ref_logprobs: Optional[List[List[float]]] = None
for data_item in rollout_group_data:
- llm_html = data_item["llm_modified_html"]
+ llm_html_str = data_item["llm_modified_html"]
original_info = data_item["original_html_info"]
+ full_exchange_messages_list_of_dicts = data_item[
+ "full_exchange_messages"
+ ] # This is List[Dict[str, str]]
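The remainder of the loop is truncated below. Other Atropos environments pass the tokenizer and message list to `tokenize_for_trainer` and read back token/mask lists; if its signature matches that usage, the missing step presumably resembles:

```python
# Assumed usage, mirroring other Atropos envs; verify against the real
# tokenize_for_trainer signature before relying on this.
out_dict = tokenize_for_trainer(self.tokenizer, full_exchange_messages_list_of_dicts)
final_tokens_batch.append(out_dict["tokens"])
final_masks_batch.append(out_dict["masks"])
```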
- current_score = -1.0
- if "