Add tokenizer. Fix typing

Josh 2025-05-18 13:32:59 -07:00
parent 659247fc00
commit 94038876f4


@@ -1,6 +1,8 @@
# environments/hack0/accessibility_env/accessibility_env.py
import os # For API keys, etc.
from typing import List, Optional, Tuple # Common type hints, added Dict
from typing import Dict, List, Optional, Tuple # Common type hints, added Dict
from transformers.models.auto.tokenization_auto import AutoTokenizer
# Corrected imports for Atropos types
from atroposlib.envs.base import (
@@ -61,9 +63,39 @@ class AccessibilityEnv(BaseEnv):
async def setup(self):
print(f"[{self.name}] Setting up environment...")
# Load dataset, initialize tools (e.g., HTML parser) here
self.dataset = [] # Placeholder for your HTML snippets
try:
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.tokenizer_name, trust_remote_code=True
)
# It's good practice to set pad_token if it's not already set, common for GPT-like models
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
print(
f"[{self.name}] Tokenizer '{self.config.tokenizer_name}' loaded successfully."
)
except Exception as e:
print(
f"[{self.name}] Error loading tokenizer '{self.config.tokenizer_name}': {e}"
)
# Decide how to handle this - raise error, or try to proceed without tokenization (not ideal for Atropos)
# For now, let's allow it to proceed, but tokenization will fail later if tokenizer is None
# A better approach might be to raise an exception here if tokenizer is critical.
# raise RuntimeError(f"Failed to load tokenizer: {e}") from e
self.dataset = [
{
"id": "ex001",
"html": "<h1>Welcome</h1><img src='image.jpg'>",
"issues_to_fix": ["missing_alt_text"],
},
{
"id": "ex002",
"html": "<label>Name</label><input type='text' name='username'>",
"issues_to_fix": ["missing_for_attribute_on_label"],
},
] # Placeholder for your HTML snippets
self.iter = 0
print(f"[{self.name}] Setup complete.")
print(f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items.")
async def get_next_item(self) -> Optional[Item]:
if self.iter >= len(self.dataset):
@@ -124,10 +156,10 @@ class AccessibilityEnv(BaseEnv):
return item_data # This will be like {"html": "...", "id": "..."}
async def collect_trajectories(
self, item_data: Item
self, item: Item
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
# 'item_data' here is what get_next_item returned.
original_html = item_data["html"]
original_html = item["html"]
system_message_content = (
"You are an expert web developer specializing in accessibility. "
"Given the following HTML snippet, please make the minimal necessary modifications "
@@ -163,7 +195,7 @@
{
"full_exchange_messages": full_exchange_messages, # For tokenization
"llm_modified_html": llm_response_content, # For direct scoring
"original_html_info": item_data, # To know what to check against
"original_html_info": item, # To know what to check against
}
)
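For context, the full_exchange_messages list appended above is assumed to use the standard chat-format role/content dicts that tokenize_for_trainer consumes later; the actual construction sits outside this hunk. A minimal sketch of the assumed shape, reusing the variable names from the surrounding code:

    full_exchange_messages = [
        {"role": "system", "content": system_message_content},
        {"role": "user", "content": original_html},
        {"role": "assistant", "content": llm_response_content},
    ]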
@@ -177,72 +209,103 @@ class AccessibilityEnv(BaseEnv):
scored_data_group = await self.score(to_score_inputs)
return scored_data_group, [] # Assuming no backlog for now
async def score(self, rollout_group_data: List[dict]) -> Optional[ScoredDataGroup]:
# rollout_group_data is a list of dicts, each like:
# {
# "full_exchange_messages": [...],
# "llm_modified_html": "...",
# "original_html_info": {"html": "...", "id": "...", "issues": [...]}
# }
async def score(
self, rollout_group_data: List[dict]
) -> Optional[ScoredDataGroup]: # Return type is still ScoredDataGroup
print(f"[{self.name}] Scoring {len(rollout_group_data)} rollouts...")
scores_obj = ScoredDataGroup() # Use the Atropos defined type
# Initialize lists within scores_obj as per ScoredDataGroup structure
# (typically 'tokens', 'masks', 'scores', maybe 'logprobs')
scores_obj["tokens"] = []
scores_obj["masks"] = []
scores_obj["scores"] = []
# scores_obj["infos"] = [] # Optional for extra debug info
all_tokens: List[List[int]] = []
all_masks: List[List[int]] = []
all_scores: List[float] = []
# For TypedDict, optional fields that are not provided will simply not be keys in the dictionary.
# However, if we want to include them as None, we can. Let's prepare for that.
all_advantages: Optional[List[List[float]]] = (
None # Or initialize as [] if you might populate it
)
all_ref_logprobs: Optional[List[List[float]]] = None # Or initialize as []
all_messages_for_trainer: Optional[List[List[Dict]]] = (
None # Assuming Message is also a dict-like structure or TypedDict
)
for data_item in rollout_group_data:
llm_html = data_item["llm_modified_html"]
original_info = data_item["original_html_info"]
# Basic reward: 1.0 if fixed, -1.0 if not.
# This will be replaced with actual WCAG checks.
current_score = -1.0 # Default to failure
# ---- YOUR SCORING LOGIC HERE ----
# Example: (pseudo-code, requires BeautifulSoup and specific checks)
# violations_fixed = self.check_wcag_fixes(llm_html, original_info)
# if violations_fixed:
# current_score = 1.0
# For now, a placeholder:
current_score = -1.0
if "<img" in original_info["html"] and "alt=" in llm_html:
current_score = 1.0
elif "<label>" in original_info["html"] and "for=" in llm_html:
current_score = 1.0
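The placeholder scoring above only checks for substrings in the raw strings. The pseudo-code comment hints at a check_wcag_fixes helper; a minimal sketch of such a helper is shown below, not part of this commit, written as a standalone function for brevity (in the env it would be a method), assuming BeautifulSoup (bs4) is installed and that issues_to_fix uses the keys from the placeholder dataset:

    from bs4 import BeautifulSoup

    def check_wcag_fixes(llm_html: str, original_info: dict) -> bool:
        # Return True only if every flagged issue appears to be fixed in the LLM output.
        soup = BeautifulSoup(llm_html, "html.parser")
        for issue in original_info.get("issues_to_fix", []):
            if issue == "missing_alt_text":
                # Every <img> must carry a non-empty alt attribute.
                if any(not img.get("alt") for img in soup.find_all("img")):
                    return False
            elif issue == "missing_for_attribute_on_label":
                # Every <label> must reference its control via a for attribute.
                if any(not label.get("for") for label in soup.find_all("label")):
                    return False
        return True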
# Tokenize the full exchange for the trainer
# The 'tokenize_for_trainer' util expects a tuple/list of message dicts
tokenized_output = tokenize_for_trainer(
self.tokenizer,
data_item["full_exchange_messages"], # Pass the list of message dicts
)
try:
# Ensure self.tokenizer is initialized in __init__ or setup
if not hasattr(self, "tokenizer") or self.tokenizer is None:
print(f"[{self.name}] Error: Tokenizer not initialized.")
# Attempt to initialize it here if it makes sense, or ensure it's done in setup()
# from transformers import AutoTokenizer
# self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name, trust_remote_code=True)
# This is a fallback, better to ensure it's in setup()
# For now, let's assume it's there. If not, this will fail earlier or be caught by linter.
pass # Assuming tokenizer is initialized
tokenized_output = tokenize_for_trainer(
self.tokenizer, # Make sure self.tokenizer is loaded, e.g., in setup()
data_item["full_exchange_messages"],
)
except Exception as e:
print(f"[{self.name}] Error during tokenization: {e}. Skipping item.")
continue
# Ensure tokenized_output contains 'tokens' and 'masks'
if "tokens" not in tokenized_output or "masks" not in tokenized_output:
print(
f"[{self.name}] Warning: Tokenization did not return tokens/masks for an item. Skipping."
)
continue
scores_obj["tokens"].append(tokenized_output["tokens"])
scores_obj["masks"].append(tokenized_output["masks"])
scores_obj["scores"].append(current_score)
# scores_obj["infos"].append({"original_id": original_info["id"], "llm_output_preview": llm_html[:100]})
all_tokens.append(tokenized_output["tokens"])
all_masks.append(tokenized_output["masks"])
all_scores.append(current_score)
# Handle case where no valid items were scored
if not scores_obj["scores"]:
# If you were to populate optional fields, you'd do it here. For example:
# if "advantages" in tokenized_output: # Fictional example
# if all_advantages is None: all_advantages = []
# all_advantages.append(tokenized_output["advantages"])
if not all_scores:
print(f"[{self.name}] No valid items to score, returning None.")
return None
# Atropos convention: if all scores are identical, return None (no learning signal)
# This might be too strict for early testing. Consider enabling later.
# if len(set(scores_obj["scores"])) == 1 and len(scores_obj["scores"]) > 1 :
# print(f"[{self.name}] All scores are identical ({scores_obj['scores'][0]}), returning None.")
# return None
# print(f"[{self.name}] Scoring complete. Scores: {all_scores}") # Already printed if successful below
print(f"[{self.name}] Scoring complete. Scores: {scores_obj['scores']}")
return scores_obj
# Construct the dictionary that conforms to ScoredDataGroup TypedDict
# Mandatory fields:
data_to_return: ScoredDataGroup = {
"tokens": all_tokens,
"masks": all_masks,
"scores": all_scores,
"advantages": None,
"ref_logprobs": None,
"group_overrides": None,
"messages": None,
"overrides": None,
}
# Add optional fields if they have values (or if you explicitly want them as None)
# The TypedDict definition uses Optional[], so if a key is missing, it's fine.
# If you want to explicitly include them as None if not populated:
if all_advantages is not None:
data_to_return["advantages"] = all_advantages
if all_ref_logprobs is not None:
data_to_return["ref_logprobs"] = all_ref_logprobs
if all_messages_for_trainer is not None:
data_to_return["messages"] = all_messages_for_trainer
# group_overrides and overrides are also optional
print(
f"""[{self.name}] Scoring complete. Data to return (first score):
{data_to_return['scores'][0] if data_to_return['scores'] else 'N/A'}"""
)
return data_to_return
async def evaluate(
self,