diff --git a/environments/hack0/accessibility_env/accessibility_env.py b/environments/hack0/accessibility_env/accessibility_env.py
index b3095680..d46aebbb 100644
--- a/environments/hack0/accessibility_env/accessibility_env.py
+++ b/environments/hack0/accessibility_env/accessibility_env.py
@@ -2,6 +2,9 @@
import os # For API keys, etc.
from typing import Dict, List, Optional, Tuple # Common type hints, added Dict
+import tenacity
+
+# from bs4 import BeautifulSoup
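+# (beautifulsoup4 is added to requirements.txt for the planned HTML parsing;
+#  a likely use is BeautifulSoup(html, "html.parser").find_all("img", alt=False)
+#  to flag images missing alt text. The import stays commented out until then.)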
from transformers.models.auto.tokenization_auto import AutoTokenizer
# Corrected imports for Atropos types
@@ -38,23 +41,26 @@ class AccessibilityEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[AccessibilityEnvConfig, List[APIServerConfig]]:
env_config = AccessibilityEnvConfig(
- tokenizer_name="NousResearch/Llama-3-8B-Instruct- যেভাবে-তুমি-বাংলা-বলো", # Placeholder
- group_size=2, # Smaller for faster testing initially
+ tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
+ group_size=1, # Smaller for faster testing initially
use_wandb=True,
rollout_server_url="http://localhost:8000",
- total_steps=10, # For process mode, number of items to generate
- batch_size=4, # Max items in a single call to score (related to group_size)
+ total_steps=3, # For process mode, number of items to generate
+ batch_size=1, # Max items in a single call to score (related to group_size)
steps_per_eval=5,
max_token_length=2048,
- wandb_name="accessibility_env_hackathon_dev", # Dev run name
+ wandb_name="accessibility_llama_dev", # Dev run name
)
+
+ llama_api_key = os.environ.get("LLAMA_API_KEY")
+ if not llama_api_key:
+ print("WARNING: LLAMA_API_KEY environment variable not set!")
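+        # e.g. run `export LLAMA_API_KEY=<your key>` in the shell before
+        # launching; without it the chat_completion calls below will be rejected.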
+
server_configs = [
APIServerConfig(
- model_name="gpt-3.5-turbo", # Or your preferred model
- # base_url=None, # Defaults to OpenAI if None
- api_key=os.environ.get(
- "OPENAI_API_KEY", "YOUR_API_KEY_PLACEHOLDER_IF_NOT_SET"
- ), # Important!
+ model_name="Llama-4-Maverick-17B-128E-Instruct-FP8",
+                base_url="https://api.llama.com/v1",  # Llama API base URL
+ api_key=llama_api_key,
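+                # api.llama.com exposes an OpenAI-compatible chat completions
+                # endpoint, so the standard APIServerConfig is presumably
+                # sufficient with just the base_url override above.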
num_requests_for_eval=16,
),
]
@@ -62,14 +68,34 @@ class AccessibilityEnv(BaseEnv):
async def setup(self):
print(f"[{self.name}] Setting up environment...")
- # Load dataset, initialize tools (e.g., HTML parser) here
try:
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.tokenizer_name, trust_remote_code=True
- )
- # It's good practice to set pad_token if it's not already set, common for GPT-like models
+            )  # tokenizer_name comes from config, e.g. 'meta-llama/Llama-2-7b-chat-hf'
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
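+                # Llama-style tokenizers often ship without a pad token;
+                # reusing EOS keeps batched tokenization from raising.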
+
+            # Set a default chat template if one is not already defined.
+            # Tokenizers like 'gpt2' ship without a chat template, and
+            # tokenize_for_trainer needs *a* template to render messages.
+            if self.tokenizer.chat_template is None:
+                # The remote API does its own chat formatting; this template
+                # only needs to produce deterministic token IDs for the trainer.
+                self.tokenizer.chat_template = (
+                    "{% for message in messages %}"
+                    "{{ message['role'] + ': ' + message['content'] + '\\n' }}"
+                    "{% endfor %}"
+                )
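+                # Hypothetical example: [{"role": "user", "content": "Add alt text"}]
+                # renders as the plain string "user: Add alt text\n".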
+                # A structured Llama- or ChatML-style Jinja template could be
+                # swapped in instead if training expects that exact format.
+ print(
+ f"[{self.name}] Set a default chat_template for tokenizer '{self.config.tokenizer_name}'."
+ )
+
print(
f"[{self.name}] Tokenizer '{self.config.tokenizer_name}' loaded successfully."
)
@@ -77,10 +103,7 @@ class AccessibilityEnv(BaseEnv):
print(
f"[{self.name}] Error loading tokenizer '{self.config.tokenizer_name}': {e}"
)
- # Decide how to handle this - raise error, or try to proceed without tokenization (not ideal for Atropos)
- # For now, let's allow it to proceed, but tokenization will fail later if tokenizer is None
- # A better approach might be to raise an exception here if tokenizer is critical.
- # raise RuntimeError(f"Failed to load tokenizer: {e}") from e
+ raise RuntimeError(f"Failed to load tokenizer: {e}") from e
self.dataset = [
{
@@ -93,7 +116,7 @@ class AccessibilityEnv(BaseEnv):
"html": "",
"issues_to_fix": ["missing_for_attribute_on_label"],
},
- ] # Placeholder for your HTML snippets
+ ]
self.iter = 0
print(f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items.")
@@ -175,12 +198,130 @@ class AccessibilityEnv(BaseEnv):
{"role": "user", "content": user_message_content},
]
- chat_completions = await self.server.chat_completion(
- messages=messages,
- n=self.config.group_size,
- max_tokens=self.config.max_token_length,
- # temperature=0.7, # Optional: adjust for creativity vs. determinism
- )
+        try:
+            chat_completions = await self.server.chat_completion(
+                messages=messages,
+                n=self.config.group_size,  # number of completions per prompt
+                # `max_tokens` bounds only the completion, not the whole
+                # context; 1024 leaves room for a full rewritten HTML snippet.
+                max_tokens=1024,
+                # temperature=0.7,  # optional: trade determinism for diversity
+                # The model name is picked up from server_configs by
+                # self.server, so it is not passed here.
+            )
+        except tenacity.RetryError as retry_err:  # raised once all retries are exhausted
+            print("!" * 68)
+            print(f"[{self.name}] TENACITY RETRY ERROR during chat_completion call:")
+            print(f"[{self.name}] RetryError details: {retry_err}")
+
+            # Unwrap the exception from the final failed attempt.
+            original_exception = None
+            if retry_err.last_attempt:
+                if retry_err.last_attempt.failed:
+                    original_exception = retry_err.last_attempt.exception()
+                    print(f"[{self.name}] Original exception that caused the retries:")
+                    print(f"[{self.name}]   Type: {type(original_exception)}")
+                    print(f"[{self.name}]   Args: {original_exception.args if original_exception else 'N/A'}")
+                    print(f"[{self.name}]   Str:  {original_exception}")
+                else:
+                    # Unusual for a RetryError: the last attempt did not fail.
+                    print(f"[{self.name}] Last attempt did not fail. Result: {retry_err.last_attempt.result()}")
+            else:
+                print(f"[{self.name}] No 'last_attempt' details on RetryError. Raw RetryError: {retry_err}")
+
+            # OpenAI/HTTPX-style exceptions carry the HTTP response; surface
+            # its status code and body. Prefer JSON (the usual shape of API
+            # errors), then fall back to .text, then to raw .content.
+            response_obj = getattr(original_exception, "response", None)
+            if response_obj is not None:
+                if hasattr(response_obj, "status_code"):
+                    print(f"[{self.name}] Underlying API response status code: {response_obj.status_code}")
+                body_printed = False
+                if callable(getattr(response_obj, "json", None)):
+                    try:
+                        # By the time the exception is raised the body has
+                        # normally been read, so .json() is synchronous here.
+                        print(f"[{self.name}] Underlying API response JSON: {response_obj.json()}")
+                        body_printed = True
+                    except Exception as json_err:
+                        print(f"[{self.name}] Could not parse underlying API response as JSON: {json_err}")
+                if not body_printed and hasattr(response_obj, "text"):
+                    print(f"[{self.name}] Underlying API response text: {response_obj.text}")
+                    body_printed = True
+                if not body_printed and hasattr(response_obj, "content"):
+                    try:
+                        print(f"[{self.name}] Underlying API response content (decoded): {response_obj.content.decode()}")
+                    except Exception:
+                        print(f"[{self.name}] Underlying API response content (raw bytes): {response_obj.content!r}")
+
+            print("!" * 68)
+            print(f"[{self.name}] Messages sent in the attempt that raised RetryError: {messages}")
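+            # (None, []) returns no scored group and no backlog items, which
+            # should let the run skip this item rather than abort on the
+            # API failure.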
+ return None, []
to_score_inputs = []
for choice in chat_completions.choices:
@@ -283,24 +424,13 @@ class AccessibilityEnv(BaseEnv):
"tokens": all_tokens,
"masks": all_masks,
"scores": all_scores,
- "advantages": None,
- "ref_logprobs": None,
- "group_overrides": None,
- "messages": None,
+ "advantages": all_advantages,
+ "ref_logprobs": all_ref_logprobs,
+ "group_overrides": {},
+ "messages": all_messages_for_trainer,
"overrides": None,
}
- # Add optional fields if they have values (or if you explicitly want them as None)
- # The TypedDict definition uses Optional[], so if a key is missing, it's fine.
- # If you want to explicitly include them as None if not populated:
- if all_advantages is not None:
- data_to_return["advantages"] = all_advantages
- if all_ref_logprobs is not None:
- data_to_return["ref_logprobs"] = all_ref_logprobs
- if all_messages_for_trainer is not None:
- data_to_return["messages"] = all_messages_for_trainer
- # group_overrides and overrides are also optional
-
print(
f"""[{self.name}] Scoring complete. Data to return (first score):
{data_to_return['scores'][0] if data_to_return['scores'] else 'N/A'}"""
diff --git a/environments/hack0/accessibility_env/requirements.txt b/environments/hack0/accessibility_env/requirements.txt
index e69de29b..c1f5f713 100644
--- a/environments/hack0/accessibility_env/requirements.txt
+++ b/environments/hack0/accessibility_env/requirements.txt
@@ -0,0 +1,2 @@
+beautifulsoup4
+tenacity