mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Cleanup. End-to-end functionality in place
This commit is contained in:
parent
8ff2b02ce0
commit
904360a02e
4 changed files with 351 additions and 315 deletions
|
|
@ -1,32 +1,37 @@
|
|||
# environments/hack0/accessibility_env/accessibility_env.py
|
||||
import os # For API keys, etc.
|
||||
from typing import Dict, List, Optional, Tuple # Common type hints, added Dict
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import tenacity
|
||||
|
||||
# from bs4 import BeautifulSoup
|
||||
from bs4 import BeautifulSoup
|
||||
from pydantic import Field
|
||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||
|
||||
# Corrected imports for Atropos types
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import ( # GameHistory might not be needed yet, Item is common
|
||||
Item,
|
||||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
from .accessibility_rules import (
|
||||
AccessibilityRule,
|
||||
LabelAssociationRule,
|
||||
MissingAltTextRule,
|
||||
)
|
||||
|
||||
|
||||
class AccessibilityEnvConfig(BaseEnvConfig):
|
||||
# Add any custom config fields specific to your env later
|
||||
pass
|
||||
dataset_path: str = Field(
|
||||
default="data/accessibility_dataset.jsonl", # Default relative path
|
||||
description="Path to the JSONL file containing the accessibility dataset.",
|
||||
)
|
||||
|
||||
|
||||
class AccessibilityEnv(BaseEnv):
|
||||
name = "accessibility_env" # A unique name for your environment
|
||||
config: AccessibilityEnvConfig
|
||||
name = "accessibility_env"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -36,33 +41,44 @@ class AccessibilityEnv(BaseEnv):
|
|||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
# Initialize any env-specific attributes here
|
||||
self.tokenizer = None
|
||||
|
||||
# Initialize your list of rule instances
|
||||
self.accessibility_rules: List[AccessibilityRule] = [
|
||||
MissingAltTextRule(),
|
||||
LabelAssociationRule(),
|
||||
]
|
||||
|
||||
# For quick lookup if needed, though iterating self.accessibility_rules is fine
|
||||
self.rules_by_key = {rule.issue_key: rule for rule in self.accessibility_rules}
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[AccessibilityEnvConfig, List[APIServerConfig]]:
|
||||
current_dataset_size = 10
|
||||
|
||||
env_config = AccessibilityEnvConfig(
|
||||
tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
|
||||
group_size=1, # Smaller for faster testing initially
|
||||
tokenizer_name="gpt2",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=3, # For process mode, number of items to generate
|
||||
batch_size=1, # Max items in a single call to score (related to group_size)
|
||||
steps_per_eval=5,
|
||||
total_steps=current_dataset_size,
|
||||
batch_size=1,
|
||||
steps_per_eval=current_dataset_size,
|
||||
max_token_length=2048,
|
||||
wandb_name="accessibility_llama_dev", # Dev run name
|
||||
wandb_name="accessibility_openai_default_dev",
|
||||
)
|
||||
|
||||
llama_api_key = os.environ.get("LLAMA_API_KEY")
|
||||
if not llama_api_key:
|
||||
print("WARNING: LLAMA_API_KEY environment variable not set!")
|
||||
openai_api_key_from_env = os.environ.get("OPENAI_API_KEY")
|
||||
if not openai_api_key_from_env:
|
||||
print(
|
||||
"WARNING (from config_init): OPENAI_API_KEY environment variable not set for default config!"
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
base_url="https://api.llama.com/v1", # <<<---- Llama API base URL
|
||||
api_key=llama_api_key,
|
||||
num_requests_for_eval=16,
|
||||
),
|
||||
model_name="gpt-3.5-turbo",
|
||||
api_key=openai_api_key_from_env,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
|
|
@ -71,27 +87,17 @@ class AccessibilityEnv(BaseEnv):
|
|||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.config.tokenizer_name, trust_remote_code=True
|
||||
) # tokenizer_name is 'gpt2'
|
||||
)
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
# Set a default chat template if it's not already set
|
||||
# This is crucial for tokenizers like 'gpt2' that don't have one by default.
|
||||
if self.tokenizer.chat_template is None:
|
||||
# A common, simple template. You might need to adjust based on how gpt-3.5-turbo expects chat.
|
||||
# For gpt-3.5-turbo, the actual formatting is handled by the API,
|
||||
# but for local tokenization for the trainer, we need *a* template.
|
||||
# A basic template for generic tokenization:
|
||||
self.tokenizer.chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{{ message['role'] + ': ' + message['content'] + '\\n' }}"
|
||||
"{% endfor %}"
|
||||
)
|
||||
# Alternatively, for many models, a more structured Jinja template like
|
||||
# the Llama or ChatML one might be used if you were training with such a format.
|
||||
# For just getting token IDs for a generic model for RL, the simple one above might suffice.
|
||||
# Or, if tokenize_for_trainer is smart, it might just concatenate.
|
||||
# Let's check if a simpler approach is needed for tokenize_for_trainer.
|
||||
|
||||
print(
|
||||
f"[{self.name}] Set a default chat_template for tokenizer '{self.config.tokenizer_name}'."
|
||||
)
|
||||
|
|
@ -105,83 +111,57 @@ class AccessibilityEnv(BaseEnv):
|
|||
)
|
||||
raise RuntimeError(f"Failed to load tokenizer: {e}") from e
|
||||
|
||||
self.dataset = [
|
||||
{
|
||||
"id": "ex001",
|
||||
"html": "<h1>Welcome</h1><img src='image.jpg'>",
|
||||
"issues_to_fix": ["missing_alt_text"],
|
||||
},
|
||||
{
|
||||
"id": "ex002",
|
||||
"html": "<label>Name</label><input type='text' name='username'>",
|
||||
"issues_to_fix": ["missing_for_attribute_on_label"],
|
||||
},
|
||||
]
|
||||
# Load dataset from file
|
||||
self.dataset: List[Dict] = []
|
||||
env_script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
full_dataset_path = os.path.join(env_script_dir, self.config.dataset_path)
|
||||
|
||||
print(f"[{self.name}] Attempting to load dataset from: {full_dataset_path}")
|
||||
try:
|
||||
with open(full_dataset_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip(): # Ensure line is not empty
|
||||
self.dataset.append(json.loads(line))
|
||||
if not self.dataset:
|
||||
raise FileNotFoundError(
|
||||
"Dataset file was empty or contained no valid JSON lines."
|
||||
)
|
||||
except FileNotFoundError:
|
||||
print(f"[{self.name}] ERROR: Dataset file not found at {full_dataset_path}")
|
||||
raise
|
||||
except json.JSONDecodeError as e:
|
||||
print(
|
||||
f"[{self.name}] ERROR: Failed to decode JSON from {full_dataset_path}. Error: {e}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[{self.name}] ERROR: An unexpected error occurred while loading dataset: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
self.iter = 0
|
||||
print(f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items.")
|
||||
print(
|
||||
f"""[{self.name}] Setup complete. Loaded {len(self.dataset)}
|
||||
items. Initialized {len(self.accessibility_rules)} accessibility rules."""
|
||||
)
|
||||
|
||||
async def get_next_item(self) -> Optional[Item]:
|
||||
if self.iter >= len(self.dataset):
|
||||
if (
|
||||
self.iter >= self.config.total_steps
|
||||
): # Stop after total_steps for 'process'
|
||||
if self.iter >= self.config.total_steps:
|
||||
return None
|
||||
# Potentially loop dataset or handle running out of unique items
|
||||
# For hackathon, just stopping might be fine if dataset is small
|
||||
# and total_steps is matched to dataset size.
|
||||
# self.iter = 0 # To loop
|
||||
|
||||
print(f"[{self.name}] Reached end of dataset or total_steps.")
|
||||
return None
|
||||
|
||||
item_data = self.dataset[self.iter]
|
||||
self.iter += 1
|
||||
# Format item_data into the 'Item' structure Atropos expects
|
||||
# Typically (prompt_messages_tuple, gold_answer_or_metadata_tuple)
|
||||
# Example:
|
||||
# user_prompt = {"role": "user", "content": f"Make this HTML accessible: {item_data['html_snippet']}"}
|
||||
# system_prompt_content = "You are an AI assistant specializing in web accessibility. Modify the given
|
||||
# HTML to meet WCAG AA standards. Output only the modified HTML."
|
||||
# system_prompt = {"role": "system", "content": system_prompt_content}
|
||||
# prompt_messages = (system_prompt, user_prompt) # This needs to be a tuple of dicts
|
||||
# messages_for_item = tuple(frozenset(p.items()) for p in prompt_messages) # Atropos often expects this format
|
||||
# return (messages_for_item, item_data.get('expected_outcome_or_id')) # Second part is for scoring reference
|
||||
|
||||
# Simpler start for prompt:
|
||||
# prompt = (
|
||||
# (
|
||||
# {
|
||||
# "role": "system",
|
||||
# "content": "You are an AI assistant. Given HTML, make it more accessible.",
|
||||
# },
|
||||
# ),
|
||||
# ({"role": "user", "content": f"Original HTML: {item_data['html']}"},),
|
||||
# )
|
||||
# This prompt structure might need adjustment based on how Atropos and the LLM API expect it.
|
||||
# The gsm8k example has:
|
||||
# user_message = {"role": "user", "content": item["question"]}
|
||||
# chat_completions = await self.server.chat_completion(
|
||||
# messages=[{"role": "system", "content": system_prompt}, user_message], ...
|
||||
# So a list of dicts is passed to chat_completion.
|
||||
# The 'Item' type for get_next_item is often a tuple: ( (message_part_1, message_part_2, ...),
|
||||
# metadata_for_scoring )
|
||||
# where each message_part is often a frozenset of items from a dict. This is a bit complex.
|
||||
# Let's start with a simple string prompt and adapt.
|
||||
# For now, let's assume item is (prompt_string, metadata_for_scoring)
|
||||
# The `collect_trajectories` in coding_server.py takes `item: Item`
|
||||
# and then accesses `item[0][0]` which implies item is nested.
|
||||
# `prompt = tuple([frozenset({"role": "user", "content": next_item["description"]}.items())])`
|
||||
# `return (prompt, answer)`
|
||||
# So, first element of item is a tuple of frozensets.
|
||||
|
||||
# Let's simplify for now and refine based on Atropos internals if needed.
|
||||
# We'll construct the messages list directly in collect_trajectories.
|
||||
# So get_next_item can return the raw data needed.
|
||||
return item_data # This will be like {"html": "...", "id": "..."}
|
||||
return item_data
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
# 'item_data' here is what get_next_item returned.
|
||||
original_html = item["html"]
|
||||
system_message_content = (
|
||||
"You are an expert web developer specializing in accessibility. "
|
||||
|
|
@ -198,262 +178,180 @@ class AccessibilityEnv(BaseEnv):
|
|||
{"role": "user", "content": user_message_content},
|
||||
]
|
||||
|
||||
try:
|
||||
chat_completions = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size, # Number of completions
|
||||
# `max_tokens` here is for the *completion* part, not the whole context.
|
||||
# Your Llama API example used 256. Adjust as needed for HTML output.
|
||||
max_tokens=1024, # Max tokens for the LLM's response
|
||||
# temperature=0.7, # Optional: adjust for creativity vs. determinism
|
||||
# model=self.server_configs[0].model_name # This should be picked up automatically from server_configs
|
||||
# by the self.server object.
|
||||
)
|
||||
except tenacity.RetryError as retry_err: # Specifically catch RetryError
|
||||
print(
|
||||
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
|
||||
)
|
||||
print(f"[{self.name}] TENACITY RETRY ERROR during chat_completion call:")
|
||||
print(f"[{self.name}] RetryError Details: {retry_err}")
|
||||
# ... and the response details if available on 'e' ...
|
||||
original_exception = None
|
||||
if retry_err.last_attempt:
|
||||
if retry_err.last_attempt.failed:
|
||||
original_exception = retry_err.last_attempt.exception()
|
||||
print(
|
||||
f"[{self.name}] Last attempt failed. Original exception that caused retries:"
|
||||
)
|
||||
print(f"[{self.name}] Type: {type(original_exception)}")
|
||||
print(
|
||||
f"[{self.name}] Args: {original_exception.args if original_exception else 'N/A'}"
|
||||
)
|
||||
print(
|
||||
f"[{self.name}] Full Str: {str(original_exception)}"
|
||||
) # More direct string representation
|
||||
else:
|
||||
# This case is unusual for a RetryError due to failure
|
||||
print(
|
||||
f"""[{self.name}] Last attempt recorded but did not 'fail'.
|
||||
Result: {retry_err.last_attempt.result()}"""
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"""[{self.name}] Could not get 'last_attempt' details from
|
||||
RetryError object. Raw RetryError: {retry_err}"""
|
||||
)
|
||||
|
||||
# Now, if we have the original_exception, try to get more details (like HTTP response)
|
||||
if original_exception:
|
||||
# Check if the original exception is an OpenAI/HTTPX style error
|
||||
# by looking for a 'response' attribute.
|
||||
if (
|
||||
hasattr(original_exception, "response")
|
||||
and original_exception.response is not None
|
||||
):
|
||||
response_obj = original_exception.response
|
||||
status_code_text = "Status code N/A"
|
||||
response_content_text = "Response content N/A"
|
||||
|
||||
if hasattr(response_obj, "status_code"):
|
||||
status_code_text = str(response_obj.status_code)
|
||||
|
||||
print(
|
||||
f"[{self.name}] Underlying API Response Status Code: {status_code_text}"
|
||||
)
|
||||
|
||||
# Try to get JSON content first (common for API errors)
|
||||
if hasattr(response_obj, "json") and callable(response_obj.json):
|
||||
try:
|
||||
response_json_parsed = (
|
||||
response_obj.json()
|
||||
) # Note: this might need to be awaited if response_obj.json is async
|
||||
# but typically in an exception, it's already processed.
|
||||
print(
|
||||
f"[{self.name}] Underlying API Response JSON: {response_json_parsed}"
|
||||
)
|
||||
except Exception as json_e_inner:
|
||||
print(
|
||||
f"[{self.name}] Could not parse underlying API response as JSON: {json_e_inner}"
|
||||
)
|
||||
# Fallback to text if JSON parsing fails
|
||||
if hasattr(response_obj, "text"):
|
||||
response_content_text = response_obj.text
|
||||
print(
|
||||
f"[{self.name}] Underlying API Response Text: {response_content_text}"
|
||||
)
|
||||
elif hasattr(response_obj, "content"): # often bytes
|
||||
try:
|
||||
response_content_text = (
|
||||
response_obj.content.decode()
|
||||
)
|
||||
print(
|
||||
f"""[{self.name}] Underlying API Response
|
||||
Content (decoded): {response_content_text}"""
|
||||
)
|
||||
except Exception:
|
||||
response_content_text = str(response_obj.content)
|
||||
print(
|
||||
f"""[{self.name}] Underlying API Response Content
|
||||
(raw bytes as str): {response_content_text}"""
|
||||
)
|
||||
# If no json() method, try .text or .content directly
|
||||
elif hasattr(response_obj, "text"):
|
||||
response_content_text = response_obj.text
|
||||
print(
|
||||
f"[{self.name}] Underlying API Response Text: {response_content_text}"
|
||||
)
|
||||
elif hasattr(response_obj, "content"):
|
||||
try:
|
||||
response_content_text = response_obj.content.decode()
|
||||
print(
|
||||
f"[{self.name}] Underlying API Response Content (decoded): {response_content_text}"
|
||||
)
|
||||
except Exception:
|
||||
response_content_text = str(response_obj.content)
|
||||
print(
|
||||
f"""[{self.name}] Underlying API Response Content
|
||||
(raw bytes as str): {response_content_text}"""
|
||||
)
|
||||
|
||||
print(
|
||||
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
|
||||
)
|
||||
print(
|
||||
f"[{self.name}] Messages that were sent during the attempt resulting in RetryError: {messages}"
|
||||
)
|
||||
return None, []
|
||||
chat_completions = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size,
|
||||
max_tokens=1024,
|
||||
)
|
||||
|
||||
to_score_inputs = []
|
||||
for choice in chat_completions.choices:
|
||||
llm_response_content = choice.message.content
|
||||
# The 'messages' to store for scoring/tokenization should represent the full exchange
|
||||
# that led to this specific llm_response_content.
|
||||
# This includes the original system and user messages, and the assistant's response.
|
||||
full_exchange_messages = messages + [
|
||||
{"role": "assistant", "content": llm_response_content}
|
||||
]
|
||||
to_score_inputs.append(
|
||||
{
|
||||
"full_exchange_messages": full_exchange_messages, # For tokenization
|
||||
"llm_modified_html": llm_response_content, # For direct scoring
|
||||
"original_html_info": item, # To know what to check against
|
||||
}
|
||||
)
|
||||
|
||||
# The `score` method in Atropos expects a list where each element typically is
|
||||
# (messages_tuple_for_tokenization, original_item_metadata_for_scoring_logic)
|
||||
# We need to adapt `to_score_inputs` to what `self.score` will expect.
|
||||
# Let's define that `self.score` will take this list of dicts directly.
|
||||
# The `collect_trajectories` from the blog post returns `to_postprocess, to_backlog`
|
||||
# where `to_postprocess` is the output of `self.score`.
|
||||
if chat_completions is not None:
|
||||
for choice in chat_completions.choices:
|
||||
llm_response_content = choice.message.content
|
||||
full_exchange_messages = messages + [
|
||||
{"role": "assistant", "content": llm_response_content}
|
||||
]
|
||||
to_score_inputs.append(
|
||||
{
|
||||
"full_exchange_messages": full_exchange_messages,
|
||||
"llm_modified_html": llm_response_content,
|
||||
"original_html_info": item,
|
||||
}
|
||||
)
|
||||
|
||||
scored_data_group = await self.score(to_score_inputs)
|
||||
return scored_data_group, [] # Assuming no backlog for now
|
||||
|
||||
async def score(
|
||||
self, rollout_group_data: List[dict]
|
||||
) -> Optional[ScoredDataGroup]: # Return type is still ScoredDataGroup
|
||||
async def score(self, rollout_group_data: List[dict]) -> Optional[ScoredDataGroup]:
|
||||
print(f"[{self.name}] Scoring {len(rollout_group_data)} rollouts...")
|
||||
|
||||
all_tokens: List[List[int]] = []
|
||||
all_masks: List[List[int]] = []
|
||||
all_scores: List[float] = []
|
||||
# For TypedDict, optional fields that are not provided will simply not be keys in the dictionary.
|
||||
# However, if we want to include them as None, we can. Let's prepare for that.
|
||||
all_advantages: Optional[List[List[float]]] = (
|
||||
None # Or initialize as [] if you might populate it
|
||||
)
|
||||
all_ref_logprobs: Optional[List[List[float]]] = None # Or initialize as []
|
||||
all_messages_for_trainer: Optional[List[List[Dict]]] = (
|
||||
None # Assuming Message is also a dict-like structure or TypedDict
|
||||
)
|
||||
# Initialize lists to store data for all successfully processed items in the batch
|
||||
final_tokens_batch: List[List[int]] = []
|
||||
final_masks_batch: List[List[int]] = []
|
||||
final_scores_batch: List[float] = []
|
||||
final_concatenated_dialogues_batch: List[str] = []
|
||||
|
||||
# Optional fields for ScoredDataGroup, will remain None for this basic setup
|
||||
all_advantages: Optional[List[List[float]]] = None
|
||||
all_ref_logprobs: Optional[List[List[float]]] = None
|
||||
|
||||
for data_item in rollout_group_data:
|
||||
llm_html = data_item["llm_modified_html"]
|
||||
llm_html_str = data_item["llm_modified_html"]
|
||||
original_info = data_item["original_html_info"]
|
||||
full_exchange_messages_list_of_dicts = data_item[
|
||||
"full_exchange_messages"
|
||||
] # This is List[Dict[str, str]]
|
||||
|
||||
current_score = -1.0
|
||||
if "<img" in original_info["html"] and "alt=" in llm_html:
|
||||
current_score = 1.0
|
||||
elif "<label>" in original_info["html"] and "for=" in llm_html:
|
||||
current_score = 1.0
|
||||
current_item_score = 0.0
|
||||
num_issues_actually_fixed = 0
|
||||
issues_expected_to_fix = original_info.get("issues_to_fix", [])
|
||||
num_issues_targeted = len(issues_expected_to_fix)
|
||||
|
||||
soup: Optional[BeautifulSoup] = None
|
||||
can_proceed_with_rule_checks = False
|
||||
try:
|
||||
# Ensure self.tokenizer is initialized in __init__ or setup
|
||||
if not hasattr(self, "tokenizer") or self.tokenizer is None:
|
||||
print(f"[{self.name}] Error: Tokenizer not initialized.")
|
||||
# Attempt to initialize it here if it makes sense, or ensure it's done in setup()
|
||||
# from transformers import AutoTokenizer
|
||||
# self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name, trust_remote_code=True)
|
||||
# This is a fallback, better to ensure it's in setup()
|
||||
# For now, let's assume it's there. If not, this will fail earlier or be caught by linter.
|
||||
pass # Assuming tokenizer is initialized
|
||||
soup = BeautifulSoup(llm_html_str, "lxml")
|
||||
can_proceed_with_rule_checks = True
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[{self.name}] Item {original_info.get('id', 'N/A')}: Could not parse LLM output as HTML: {e}"
|
||||
)
|
||||
|
||||
if can_proceed_with_rule_checks and soup is not None:
|
||||
for rule_instance in self.accessibility_rules:
|
||||
if rule_instance.issue_key in issues_expected_to_fix:
|
||||
try:
|
||||
if rule_instance.check(soup, original_info):
|
||||
num_issues_actually_fixed += 1
|
||||
print(
|
||||
f"""[{self.name}] Item {original_info.get('id', 'N/A')}:
|
||||
Rule '{rule_instance.issue_key}' PASSED."""
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"""[{self.name}] Item {original_info.get('id', 'N/A')}:
|
||||
Rule '{rule_instance.issue_key}' FAILED."""
|
||||
)
|
||||
except Exception as rule_e:
|
||||
print(
|
||||
f"""[{self.name}] Item {original_info.get('id', 'N/A')}:
|
||||
Error executing rule '{rule_instance.issue_key}': {rule_e}"""
|
||||
)
|
||||
|
||||
# Determine score based on fixes and parseability
|
||||
if num_issues_targeted > 0:
|
||||
if not can_proceed_with_rule_checks: # Parsing failed
|
||||
current_item_score = (
|
||||
-1.0 * num_issues_targeted
|
||||
) # Penalize per targeted issue if unparseable
|
||||
elif num_issues_actually_fixed == num_issues_targeted:
|
||||
current_item_score = 1.0 # All targeted issues fixed
|
||||
elif (
|
||||
num_issues_actually_fixed > 0
|
||||
): # Some, but not all, targeted issues fixed
|
||||
current_item_score = 0.8 * (
|
||||
num_issues_actually_fixed / num_issues_targeted
|
||||
)
|
||||
else: # Parseable, but no targeted issues fixed
|
||||
current_item_score = -0.5
|
||||
else: # No issues were targeted for this item (e.g., input was considered good by dataset design)
|
||||
if (
|
||||
not can_proceed_with_rule_checks
|
||||
): # LLM made a good input unparseable
|
||||
current_item_score = -1.0
|
||||
else: # Parseable, and no issues were targeted (good input remained good)
|
||||
current_item_score = 0.0 # Neutral score
|
||||
|
||||
# Tokenization
|
||||
try:
|
||||
if not self.tokenizer:
|
||||
raise ValueError("Tokenizer not initialized.")
|
||||
tokenized_output = tokenize_for_trainer(
|
||||
self.tokenizer, # Make sure self.tokenizer is loaded, e.g., in setup()
|
||||
data_item["full_exchange_messages"],
|
||||
self.tokenizer, full_exchange_messages_list_of_dicts
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during tokenization: {e}. Skipping item.")
|
||||
continue
|
||||
print(
|
||||
f"""[{self.name}] Error during tokenization for item
|
||||
{original_info.get('id', 'N/A')}: {e}. Skipping this item."""
|
||||
)
|
||||
continue # Skip to the next data_item in rollout_group_data
|
||||
|
||||
if "tokens" not in tokenized_output or "masks" not in tokenized_output:
|
||||
print(
|
||||
f"[{self.name}] Warning: Tokenization did not return tokens/masks for an item. Skipping."
|
||||
f"""[{self.name}] Tokenization did not produce 'tokens' or
|
||||
'masks' for item {original_info.get('id', 'N/A')}. Skipping this item."""
|
||||
)
|
||||
continue
|
||||
continue # Skip to the next data_item
|
||||
|
||||
all_tokens.append(tokenized_output["tokens"])
|
||||
all_masks.append(tokenized_output["masks"])
|
||||
all_scores.append(current_score)
|
||||
# If we reach here, scoring and tokenization for the current item were successful
|
||||
final_tokens_batch.append(tokenized_output["tokens"])
|
||||
final_masks_batch.append(tokenized_output["masks"])
|
||||
final_scores_batch.append(current_item_score)
|
||||
|
||||
# If you were to populate optional fields, you'd do it here. For example:
|
||||
# if "advantages" in tokenized_output: # Fictional example
|
||||
# if all_advantages is None: all_advantages = []
|
||||
# all_advantages.append(tokenized_output["advantages"])
|
||||
if self.config.include_messages:
|
||||
formatted_message_log = "".join(
|
||||
f"{msg_dict['role']}: {msg_dict['content']}\n"
|
||||
for msg_dict in full_exchange_messages_list_of_dicts
|
||||
)
|
||||
final_concatenated_dialogues_batch.append(formatted_message_log.strip())
|
||||
|
||||
if not all_scores:
|
||||
print(f"[{self.name}] No valid items to score, returning None.")
|
||||
# After processing all items in rollout_group_data
|
||||
if (
|
||||
not final_scores_batch
|
||||
): # If all items were skipped (e.g., due to tokenization errors)
|
||||
print(
|
||||
f"""[{self.name}] No valid items to include in ScoredDataGroup
|
||||
after processing all rollouts, returning None."""
|
||||
)
|
||||
return None
|
||||
|
||||
# print(f"[{self.name}] Scoring complete. Scores: {all_scores}") # Already printed if successful below
|
||||
|
||||
# Construct the dictionary that conforms to ScoredDataGroup TypedDict
|
||||
# Mandatory fields:
|
||||
data_to_return: ScoredDataGroup = {
|
||||
"tokens": all_tokens,
|
||||
"masks": all_masks,
|
||||
"scores": all_scores,
|
||||
"tokens": final_tokens_batch,
|
||||
"masks": final_masks_batch,
|
||||
"scores": final_scores_batch,
|
||||
"advantages": all_advantages,
|
||||
"ref_logprobs": all_ref_logprobs,
|
||||
"group_overrides": {},
|
||||
"messages": all_messages_for_trainer,
|
||||
"messages": (
|
||||
final_concatenated_dialogues_batch
|
||||
if self.config.include_messages and final_concatenated_dialogues_batch
|
||||
else None
|
||||
), # type: ignore[assignment]
|
||||
"overrides": None,
|
||||
}
|
||||
|
||||
print(
|
||||
f"""[{self.name}] Scoring complete. Data to return (first score):
|
||||
{data_to_return['scores'][0] if data_to_return['scores'] else 'N/A'}"""
|
||||
f"[{self.name}] Scoring batch complete. Final scores for batch: {data_to_return['scores']}"
|
||||
)
|
||||
return data_to_return
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
): # Optional, might not be needed for hackathon 'process' focus
|
||||
):
|
||||
print(f"[{self.name}] Evaluate method called (placeholder).")
|
||||
# Implement evaluation logic if you have a separate test set and metrics
|
||||
pass
|
||||
|
||||
# --- Helper methods for scoring ---
|
||||
# def check_wcag_fixes(self, modified_html: str, original_item_info: dict) -> bool:
|
||||
# # Placeholder for your actual WCAG checking logic
|
||||
# # e.g., using BeautifulSoup to parse modified_html
|
||||
# # and checking against `original_item_info['issues_to_fix']`
|
||||
# # from bs4 import BeautifulSoup
|
||||
# # soup = BeautifulSoup(modified_html, 'html.parser')
|
||||
# # ... logic ...
|
||||
# return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# This makes your environment runnable with `python accessibility_env.py process`
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue