Working POC

This commit is contained in:
Josh 2025-05-18 14:44:16 -07:00
parent 94038876f4
commit 8ff2b02ce0
2 changed files with 170 additions and 39 deletions

View file

@@ -2,6 +2,9 @@
import os # For API keys, etc.
from typing import Dict, List, Optional, Tuple # Common type hints, added Dict
import tenacity
# from bs4 import BeautifulSoup
from transformers.models.auto.tokenization_auto import AutoTokenizer
# Corrected imports for Atropos types
@@ -38,23 +41,26 @@ class AccessibilityEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[AccessibilityEnvConfig, List[APIServerConfig]]:
env_config = AccessibilityEnvConfig(
tokenizer_name="NousResearch/Llama-3-8B-Instruct- যেভাবে-তুমি-বাংলা-বলো", # Placeholder
group_size=2, # Smaller for faster testing initially
tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
group_size=1, # Smaller for faster testing initially
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=10, # For process mode, number of items to generate
batch_size=4, # Max items in a single call to score (related to group_size)
total_steps=3, # For process mode, number of items to generate
batch_size=1, # Max items in a single call to score (related to group_size)
steps_per_eval=5,
max_token_length=2048,
wandb_name="accessibility_env_hackathon_dev", # Dev run name
wandb_name="accessibility_llama_dev", # Dev run name
)
llama_api_key = os.environ.get("LLAMA_API_KEY")
if not llama_api_key:
print("WARNING: LLAMA_API_KEY environment variable not set!")
server_configs = [
APIServerConfig(
model_name="gpt-3.5-turbo", # Or your preferred model
# base_url=None, # Defaults to OpenAI if None
api_key=os.environ.get(
"OPENAI_API_KEY", "YOUR_API_KEY_PLACEHOLDER_IF_NOT_SET"
), # Important!
model_name="Llama-4-Maverick-17B-128E-Instruct-FP8",
base_url="https://api.llama.com/v1", # <<<---- Llama API base URL
api_key=llama_api_key,
num_requests_for_eval=16,
),
]
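
Since the environment now targets the Llama API rather than OpenAI, it can be worth smoke-testing the endpoint standalone before running the env. A minimal sketch, assuming the endpoint is OpenAI-compatible (which pointing APIServerConfig's base_url at it implies) and that LLAMA_API_KEY is exported:

    import os
    from openai import OpenAI

    # Point the standard OpenAI client at the Llama API base URL from the config.
    client = OpenAI(
        base_url="https://api.llama.com/v1",
        api_key=os.environ["LLAMA_API_KEY"],
    )
    resp = client.chat.completions.create(
        model="Llama-4-Maverick-17B-128E-Instruct-FP8",
        messages=[{"role": "user", "content": "Reply with OK."}],
        max_tokens=8,
    )
    print(resp.choices[0].message.content)
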
@@ -62,14 +68,34 @@ class AccessibilityEnv(BaseEnv):
async def setup(self):
print(f"[{self.name}] Setting up environment...")
# Load dataset, initialize tools (e.g., HTML parser) here
try:
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.tokenizer_name, trust_remote_code=True
)
# It's good practice to set pad_token if it's not already set, common for GPT-like models
) # tokenizer_name comes from AccessibilityEnvConfig
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Set a default chat template if one is not already defined. This matters
# for tokenizers like 'gpt2' that ship without one: the remote API formats
# chat turns itself, but local tokenization for the trainer still needs *a*
# template to render messages into token IDs.
if self.tokenizer.chat_template is None:
# A basic role-prefixed template for generic tokenization:
self.tokenizer.chat_template = (
"{% for message in messages %}"
"{{ message['role'] + ': ' + message['content'] + '\\n' }}"
"{% endfor %}"
)
# A more structured Jinja template (Llama- or ChatML-style) would be
# appropriate when training against such a format; for simply producing
# token IDs for RL, this minimal template suffices (see the rendered
# example after this hunk).
print(
f"[{self.name}] Set a default chat_template for tokenizer '{self.config.tokenizer_name}'."
)
print(
f"[{self.name}] Tokenizer '{self.config.tokenizer_name}' loaded successfully."
)
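
For reference, here is what that fallback template produces. A quick sketch assuming the 'gpt2' tokenizer, which ships without a chat template of its own:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    # Same minimal template as in setup(): one "role: content" line per turn.
    tokenizer.chat_template = (
        "{% for message in messages %}"
        "{{ message['role'] + ': ' + message['content'] + '\\n' }}"
        "{% endfor %}"
    )
    messages = [
        {"role": "system", "content": "You fix accessibility issues in HTML."},
        {"role": "user", "content": "<img src='cat.png'>"},
    ]
    print(tokenizer.apply_chat_template(messages, tokenize=False))
    # system: You fix accessibility issues in HTML.
    # user: <img src='cat.png'>
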
@@ -77,10 +103,7 @@ class AccessibilityEnv(BaseEnv):
print(
f"[{self.name}] Error loading tokenizer '{self.config.tokenizer_name}': {e}"
)
# Decide how to handle this - raise error, or try to proceed without tokenization (not ideal for Atropos)
# For now, let's allow it to proceed, but tokenization will fail later if tokenizer is None
# A better approach might be to raise an exception here if tokenizer is critical.
# raise RuntimeError(f"Failed to load tokenizer: {e}") from e
raise RuntimeError(f"Failed to load tokenizer: {e}") from e
self.dataset = [
{
@@ -93,7 +116,7 @@ class AccessibilityEnv(BaseEnv):
"html": "<label>Name</label><input type='text' name='username'>",
"issues_to_fix": ["missing_for_attribute_on_label"],
},
] # Placeholder for your HTML snippets
]
self.iter = 0
print(f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items.")
@@ -175,12 +198,130 @@ class AccessibilityEnv(BaseEnv):
{"role": "user", "content": user_message_content},
]
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
# temperature=0.7, # Optional: adjust for creativity vs. determinism
)
try:
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size, # Number of completions
# `max_tokens` here is for the *completion* part, not the whole context.
# Your Llama API example used 256. Adjust as needed for HTML output.
max_tokens=1024, # Max tokens for the LLM's response
# temperature=0.7, # Optional: adjust for creativity vs. determinism
# model=self.server_configs[0].model_name # This should be picked up automatically from server_configs
# by the self.server object.
)
except tenacity.RetryError as retry_err: # Specifically catch RetryError
print(
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
)
print(f"[{self.name}] TENACITY RETRY ERROR during chat_completion call:")
print(f"[{self.name}] RetryError Details: {retry_err}")
# ... plus the response details if the underlying exception exposes them ...
original_exception = None
if retry_err.last_attempt:
if retry_err.last_attempt.failed:
original_exception = retry_err.last_attempt.exception()
print(
f"[{self.name}] Last attempt failed. Original exception that caused retries:"
)
print(f"[{self.name}] Type: {type(original_exception)}")
print(
f"[{self.name}] Args: {original_exception.args if original_exception else 'N/A'}"
)
print(
f"[{self.name}] Full Str: {str(original_exception)}"
) # More direct string representation
else:
# Unusual: a RetryError normally implies the last attempt failed.
print(
f"[{self.name}] Last attempt recorded but did not 'fail'. "
f"Result: {retry_err.last_attempt.result()}"
)
else:
print(
f"[{self.name}] Could not get 'last_attempt' details from "
f"RetryError object. Raw RetryError: {retry_err}"
)
# Now, if we have the original_exception, try to get more details (like HTTP response)
if original_exception:
# Check if the original exception is an OpenAI/HTTPX style error
# by looking for a 'response' attribute.
if (
hasattr(original_exception, "response")
and original_exception.response is not None
):
response_obj = original_exception.response
status_code_text = "Status code N/A"
response_content_text = "Response content N/A"
if hasattr(response_obj, "status_code"):
status_code_text = str(response_obj.status_code)
print(
f"[{self.name}] Underlying API Response Status Code: {status_code_text}"
)
# Try to get JSON content first (common for API errors)
if hasattr(response_obj, "json") and callable(response_obj.json):
try:
# json() may need awaiting on an async client, but on an
# exception's response it is typically already materialized.
response_json_parsed = response_obj.json()
print(
f"[{self.name}] Underlying API Response JSON: {response_json_parsed}"
)
except Exception as json_e_inner:
print(
f"[{self.name}] Could not parse underlying API response as JSON: {json_e_inner}"
)
# Fallback to text if JSON parsing fails
if hasattr(response_obj, "text"):
response_content_text = response_obj.text
print(
f"[{self.name}] Underlying API Response Text: {response_content_text}"
)
elif hasattr(response_obj, "content"): # often bytes
try:
response_content_text = response_obj.content.decode()
print(
f"[{self.name}] Underlying API Response "
f"Content (decoded): {response_content_text}"
)
except Exception:
response_content_text = str(response_obj.content)
print(
f"[{self.name}] Underlying API Response Content "
f"(raw bytes as str): {response_content_text}"
)
# If no json() method, try .text or .content directly
elif hasattr(response_obj, "text"):
response_content_text = response_obj.text
print(
f"[{self.name}] Underlying API Response Text: {response_content_text}"
)
elif hasattr(response_obj, "content"):
try:
response_content_text = response_obj.content.decode()
print(
f"[{self.name}] Underlying API Response Content (decoded): {response_content_text}"
)
except Exception:
response_content_text = str(response_obj.content)
print(
f"[{self.name}] Underlying API Response Content "
f"(raw bytes as str): {response_content_text}"
)
print(
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
)
print(
f"[{self.name}] Messages that were sent during the attempt resulting in RetryError: {messages}"
)
return None, []
to_score_inputs = []
for choice in chat_completions.choices:
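
As an aside, the RetryError unwrapping used in the handler above can be exercised in isolation. A minimal sketch using tenacity's documented surface (last_attempt is a Future whose .exception() holds the error that exhausted the retries):

    import tenacity

    @tenacity.retry(stop=tenacity.stop_after_attempt(3))
    def flaky_call():
        raise ConnectionError("simulated API failure")

    try:
        flaky_call()
    except tenacity.RetryError as retry_err:
        if retry_err.last_attempt and retry_err.last_attempt.failed:
            original = retry_err.last_attempt.exception()
            # <class 'ConnectionError'> ('simulated API failure',)
            print(type(original), original.args)
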
@@ -283,24 +424,13 @@ class AccessibilityEnv(BaseEnv):
"tokens": all_tokens,
"masks": all_masks,
"scores": all_scores,
"advantages": None,
"ref_logprobs": None,
"group_overrides": None,
"messages": None,
"advantages": all_advantages,
"ref_logprobs": all_ref_logprobs,
"group_overrides": {},
"messages": all_messages_for_trainer,
"overrides": None,
}
# Add optional fields if they have values (or if you explicitly want them as None)
# The TypedDict definition uses Optional[], so if a key is missing, it's fine.
# If you want to explicitly include them as None if not populated:
if all_advantages is not None:
data_to_return["advantages"] = all_advantages
if all_ref_logprobs is not None:
data_to_return["ref_logprobs"] = all_ref_logprobs
if all_messages_for_trainer is not None:
data_to_return["messages"] = all_messages_for_trainer
# group_overrides and overrides are also optional
print(
f"[{self.name}] Scoring complete. Data to return (first score): "
f"{data_to_return['scores'][0] if data_to_return['scores'] else 'N/A'}"
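
For readability, the shape returned above can be written out as a TypedDict. This is a sketch inferred from the keys used here, not Atropos's actual ScoredDataGroup definition:

    from typing import List, Optional, TypedDict

    class ScoredDataGroupSketch(TypedDict):
        tokens: List[List[int]]            # token IDs, one list per completion
        masks: List[List[int]]             # trainer loss masks, aligned with tokens
        scores: List[float]                # one reward per completion
        advantages: Optional[List[float]]
        ref_logprobs: Optional[List[List[float]]]
        group_overrides: dict
        messages: Optional[list]
        overrides: Optional[dict]
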

View file

@@ -0,0 +1 @@
beautifulsoup4