mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Add tokenizer. Fix typing
This commit is contained in:
parent
659247fc00
commit
94038876f4
1 changed file with 112 additions and 49 deletions
|
|
@@ -1,6 +1,8 @@
|
|||
# environments/hack0/accessibility_env/accessibility_env.py
|
||||
import os # For API keys, etc.
|
||||
from typing import List, Optional, Tuple # Common type hints, added Dict
|
||||
from typing import Dict, List, Optional, Tuple # Common type hints, added Dict
|
||||
|
||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||
|
||||
# Corrected imports for Atropos types
|
||||
from atroposlib.envs.base import (
|
||||
|
|
@@ -61,9 +63,39 @@ class AccessibilityEnv(BaseEnv):
|
|||
async def setup(self):
|
||||
print(f"[{self.name}] Setting up environment...")
|
||||
# Load dataset, initialize tools (e.g., HTML parser) here
|
||||
self.dataset = [] # Placeholder for your HTML snippets
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.config.tokenizer_name, trust_remote_code=True
|
||||
)
|
||||
# It's good practice to set pad_token if it's not already set, common for GPT-like models
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
print(
|
||||
f"[{self.name}] Tokenizer '{self.config.tokenizer_name}' loaded successfully."
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[{self.name}] Error loading tokenizer '{self.config.tokenizer_name}': {e}"
|
||||
)
|
||||
# Decide how to handle this - raise error, or try to proceed without tokenization (not ideal for Atropos)
|
||||
# For now, let's allow it to proceed, but tokenization will fail later if tokenizer is None
|
||||
# A better approach might be to raise an exception here if tokenizer is critical.
|
||||
# raise RuntimeError(f"Failed to load tokenizer: {e}") from e
|
||||
|
||||
self.dataset = [
|
||||
{
|
||||
"id": "ex001",
|
||||
"html": "<h1>Welcome</h1><img src='image.jpg'>",
|
||||
"issues_to_fix": ["missing_alt_text"],
|
||||
},
|
||||
{
|
||||
"id": "ex002",
|
||||
"html": "<label>Name</label><input type='text' name='username'>",
|
||||
"issues_to_fix": ["missing_for_attribute_on_label"],
|
||||
},
|
||||
] # Placeholder for your HTML snippets
|
||||
self.iter = 0
|
||||
print(f"[{self.name}] Setup complete.")
|
||||
print(f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items.")
|
||||
|
||||
async def get_next_item(self) -> Optional[Item]:
|
||||
if self.iter >= len(self.dataset):
|
||||
|
|
@@ -124,10 +156,10 @@ class AccessibilityEnv(BaseEnv):
|
|||
return item_data # This will be like {"html": "...", "id": "..."}
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item_data: Item
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
# 'item_data' here is what get_next_item returned.
|
||||
original_html = item_data["html"]
|
||||
original_html = item["html"]
|
||||
system_message_content = (
|
||||
"You are an expert web developer specializing in accessibility. "
|
||||
"Given the following HTML snippet, please make the minimal necessary modifications "
|
||||
|
|
@@ -163,7 +195,7 @@ class AccessibilityEnv(BaseEnv):
|
|||
{
|
||||
"full_exchange_messages": full_exchange_messages, # For tokenization
|
||||
"llm_modified_html": llm_response_content, # For direct scoring
|
||||
"original_html_info": item_data, # To know what to check against
|
||||
"original_html_info": item, # To know what to check against
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@@ -177,72 +209,103 @@ class AccessibilityEnv(BaseEnv):
|
|||
scored_data_group = await self.score(to_score_inputs)
|
||||
return scored_data_group, [] # Assuming no backlog for now
|
||||
|
||||
async def score(self, rollout_group_data: List[dict]) -> Optional[ScoredDataGroup]:
|
||||
# rollout_group_data is a list of dicts, each like:
|
||||
# {
|
||||
# "full_exchange_messages": [...],
|
||||
# "llm_modified_html": "...",
|
||||
# "original_html_info": {"html": "...", "id": "...", "issues": [...]}
|
||||
# }
|
||||
async def score(
|
||||
self, rollout_group_data: List[dict]
|
||||
) -> Optional[ScoredDataGroup]: # Return type is still ScoredDataGroup
|
||||
print(f"[{self.name}] Scoring {len(rollout_group_data)} rollouts...")
|
||||
scores_obj = ScoredDataGroup() # Use the Atropos defined type
|
||||
# Initialize lists within scores_obj as per ScoredDataGroup structure
|
||||
# (typically 'tokens', 'masks', 'scores', maybe 'logprobs')
|
||||
scores_obj["tokens"] = []
|
||||
scores_obj["masks"] = []
|
||||
scores_obj["scores"] = []
|
||||
# scores_obj["infos"] = [] # Optional for extra debug info
|
||||
|
||||
all_tokens: List[List[int]] = []
|
||||
all_masks: List[List[int]] = []
|
||||
all_scores: List[float] = []
|
||||
# For TypedDict, optional fields that are not provided will simply not be keys in the dictionary.
|
||||
# However, if we want to include them as None, we can. Let's prepare for that.
|
||||
all_advantages: Optional[List[List[float]]] = (
|
||||
None # Or initialize as [] if you might populate it
|
||||
)
|
||||
all_ref_logprobs: Optional[List[List[float]]] = None # Or initialize as []
|
||||
all_messages_for_trainer: Optional[List[List[Dict]]] = (
|
||||
None # Assuming Message is also a dict-like structure or TypedDict
|
||||
)
|
||||
|
||||
for data_item in rollout_group_data:
|
||||
llm_html = data_item["llm_modified_html"]
|
||||
original_info = data_item["original_html_info"]
|
||||
|
||||
# Basic reward: 1.0 if fixed, -1.0 if not.
|
||||
# This will be replaced with actual WCAG checks.
|
||||
current_score = -1.0 # Default to failure
|
||||
# ---- YOUR SCORING LOGIC HERE ----
|
||||
# Example: (pseudo-code, requires BeautifulSoup and specific checks)
|
||||
# violations_fixed = self.check_wcag_fixes(llm_html, original_info)
|
||||
# if violations_fixed:
|
||||
# current_score = 1.0
|
||||
# For now, a placeholder:
|
||||
current_score = -1.0
|
||||
if "<img" in original_info["html"] and "alt=" in llm_html:
|
||||
current_score = 1.0
|
||||
elif "<label>" in original_info["html"] and "for=" in llm_html:
|
||||
current_score = 1.0
|
||||
|
||||
# Tokenize the full exchange for the trainer
|
||||
# The 'tokenize_for_trainer' util expects a tuple/list of message dicts
|
||||
tokenized_output = tokenize_for_trainer(
|
||||
self.tokenizer,
|
||||
data_item["full_exchange_messages"], # Pass the list of message dicts
|
||||
)
|
||||
try:
|
||||
# Ensure self.tokenizer is initialized in __init__ or setup
|
||||
if not hasattr(self, "tokenizer") or self.tokenizer is None:
|
||||
print(f"[{self.name}] Error: Tokenizer not initialized.")
|
||||
# Attempt to initialize it here if it makes sense, or ensure it's done in setup()
|
||||
# from transformers import AutoTokenizer
|
||||
# self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name, trust_remote_code=True)
|
||||
# This is a fallback, better to ensure it's in setup()
|
||||
# For now, let's assume it's there. If not, this will fail earlier or be caught by linter.
|
||||
pass # Assuming tokenizer is initialized
|
||||
|
||||
tokenized_output = tokenize_for_trainer(
|
||||
self.tokenizer, # Make sure self.tokenizer is loaded, e.g., in setup()
|
||||
data_item["full_exchange_messages"],
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during tokenization: {e}. Skipping item.")
|
||||
continue
|
||||
|
||||
# Ensure tokenized_output contains 'tokens' and 'masks'
|
||||
if "tokens" not in tokenized_output or "masks" not in tokenized_output:
|
||||
print(
|
||||
f"[{self.name}] Warning: Tokenization did not return tokens/masks for an item. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
scores_obj["tokens"].append(tokenized_output["tokens"])
|
||||
scores_obj["masks"].append(tokenized_output["masks"])
|
||||
scores_obj["scores"].append(current_score)
|
||||
# scores_obj["infos"].append({"original_id": original_info["id"], "llm_output_preview": llm_html[:100]})
|
||||
all_tokens.append(tokenized_output["tokens"])
|
||||
all_masks.append(tokenized_output["masks"])
|
||||
all_scores.append(current_score)
|
||||
|
||||
# Handle case where no valid items were scored
|
||||
if not scores_obj["scores"]:
|
||||
# If you were to populate optional fields, you'd do it here. For example:
|
||||
# if "advantages" in tokenized_output: # Fictional example
|
||||
# if all_advantages is None: all_advantages = []
|
||||
# all_advantages.append(tokenized_output["advantages"])
|
||||
|
||||
if not all_scores:
|
||||
print(f"[{self.name}] No valid items to score, returning None.")
|
||||
return None
|
||||
|
||||
# Atropos convention: if all scores are identical, return None (no learning signal)
|
||||
# This might be too strict for early testing. Consider enabling later.
|
||||
# if len(set(scores_obj["scores"])) == 1 and len(scores_obj["scores"]) > 1 :
|
||||
# print(f"[{self.name}] All scores are identical ({scores_obj['scores'][0]}), returning None.")
|
||||
# return None
|
||||
# print(f"[{self.name}] Scoring complete. Scores: {all_scores}") # Already printed if successful below
|
||||
|
||||
print(f"[{self.name}] Scoring complete. Scores: {scores_obj['scores']}")
|
||||
return scores_obj
|
||||
# Construct the dictionary that conforms to ScoredDataGroup TypedDict
|
||||
# Mandatory fields:
|
||||
data_to_return: ScoredDataGroup = {
|
||||
"tokens": all_tokens,
|
||||
"masks": all_masks,
|
||||
"scores": all_scores,
|
||||
"advantages": None,
|
||||
"ref_logprobs": None,
|
||||
"group_overrides": None,
|
||||
"messages": None,
|
||||
"overrides": None,
|
||||
}
|
||||
|
||||
# Add optional fields if they have values (or if you explicitly want them as None)
|
||||
# The TypedDict definition uses Optional[], so if a key is missing, it's fine.
|
||||
# If you want to explicitly include them as None if not populated:
|
||||
if all_advantages is not None:
|
||||
data_to_return["advantages"] = all_advantages
|
||||
if all_ref_logprobs is not None:
|
||||
data_to_return["ref_logprobs"] = all_ref_logprobs
|
||||
if all_messages_for_trainer is not None:
|
||||
data_to_return["messages"] = all_messages_for_trainer
|
||||
# group_overrides and overrides are also optional
|
||||
|
||||
print(
|
||||
f"""[{self.name}] Scoring complete. Data to return (first score):
|
||||
{data_to_return['scores'][0] if data_to_return['scores'] else 'N/A'}"""
|
||||
)
|
||||
return data_to_return
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue