diff --git a/environments/hack0/accessibility_env/accessibility_env.py b/environments/hack0/accessibility_env/accessibility_env.py
index b3095680..d46aebbb 100644
--- a/environments/hack0/accessibility_env/accessibility_env.py
+++ b/environments/hack0/accessibility_env/accessibility_env.py
@@ -2,6 +2,9 @@
 import os  # For API keys, etc.
 from typing import Dict, List, Optional, Tuple  # Common type hints, added Dict
+import tenacity
+
+# from bs4 import BeautifulSoup
 from transformers.models.auto.tokenization_auto import AutoTokenizer

 # Corrected imports for Atropos types
@@ -38,23 +41,26 @@ class AccessibilityEnv(BaseEnv):
     @classmethod
     def config_init(cls) -> Tuple[AccessibilityEnvConfig, List[APIServerConfig]]:
         env_config = AccessibilityEnvConfig(
-            tokenizer_name="NousResearch/Llama-3-8B-Instruct- যেভাবে-তুমি-বাংলা-বলো",  # Placeholder
-            group_size=2,  # Smaller for faster testing initially
+            tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
+            group_size=1,  # Smaller for faster testing initially
             use_wandb=True,
             rollout_server_url="http://localhost:8000",
-            total_steps=10,  # For process mode, number of items to generate
-            batch_size=4,  # Max items in a single call to score (related to group_size)
+            total_steps=3,  # For process mode, number of items to generate
+            batch_size=1,  # Max items in a single call to score (related to group_size)
             steps_per_eval=5,
             max_token_length=2048,
-            wandb_name="accessibility_env_hackathon_dev",  # Dev run name
+            wandb_name="accessibility_llama_dev",  # Dev run name
         )
+
+        llama_api_key = os.environ.get("LLAMA_API_KEY")
+        if not llama_api_key:
+            print("WARNING: LLAMA_API_KEY environment variable not set!")
+
         server_configs = [
             APIServerConfig(
-                model_name="gpt-3.5-turbo",  # Or your preferred model
-                # base_url=None,  # Defaults to OpenAI if None
-                api_key=os.environ.get(
-                    "OPENAI_API_KEY", "YOUR_API_KEY_PLACEHOLDER_IF_NOT_SET"
-                ),  # Important!
+                model_name="Llama-4-Maverick-17B-128E-Instruct-FP8",
+                base_url="https://api.llama.com/v1",  # Llama API base URL
+                api_key=llama_api_key,
                 num_requests_for_eval=16,
             ),
         ]
@@ -62,14 +68,34 @@ class AccessibilityEnv(BaseEnv):

     async def setup(self):
         print(f"[{self.name}] Setting up environment...")
-        # Load dataset, initialize tools (e.g., HTML parser) here
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.config.tokenizer_name, trust_remote_code=True
-            )
-            # It's good practice to set pad_token if it's not already set, common for GPT-like models
+            )
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            # Set a default chat template if one is not already defined. This is
+            # crucial for tokenizers like 'gpt2' that don't ship with one.
+            if self.tokenizer.chat_template is None:
+                # The serving API formats chat messages on its side; this template
+                # only matters for local tokenization for the trainer, which needs
+                # *a* template to turn messages into token IDs. A basic
+                # concatenating template is enough for that:
+                self.tokenizer.chat_template = (
+                    "{% for message in messages %}"
+                    "{{ message['role'] + ': ' + message['content'] + '\\n' }}"
+                    "{% endfor %}"
+                )
+                # A structured Jinja template (Llama- or ChatML-style) would only
+                # be needed if training had to match a specific chat format.
+                # TODO: check whether tokenize_for_trainer actually needs a chat
+                # template, or whether it simply concatenates messages.
+                print(
+                    f"[{self.name}] Set a default chat_template for tokenizer '{self.config.tokenizer_name}'."
+                )
+
             print(
                 f"[{self.name}] Tokenizer '{self.config.tokenizer_name}' loaded successfully."
             )
@@ -77,10 +103,7 @@ class AccessibilityEnv(BaseEnv):
             print(
                 f"[{self.name}] Error loading tokenizer '{self.config.tokenizer_name}': {e}"
             )
-            # Decide how to handle this - raise error, or try to proceed without tokenization (not ideal for Atropos)
-            # For now, let's allow it to proceed, but tokenization will fail later if tokenizer is None
-            # A better approach might be to raise an exception here if tokenizer is critical.
-            # raise RuntimeError(f"Failed to load tokenizer: {e}") from e
+            raise RuntimeError(f"Failed to load tokenizer: {e}") from e

         self.dataset = [
             {
@@ -93,7 +116,7 @@ class AccessibilityEnv(BaseEnv):
                 "html": "",
                 "issues_to_fix": ["missing_for_attribute_on_label"],
             },
-        ]  # Placeholder for your HTML snippets
+        ]
         self.iter = 0
         print(f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items.")

@@ -175,12 +198,130 @@ class AccessibilityEnv(BaseEnv):
             {"role": "user", "content": user_message_content},
         ]

-        chat_completions = await self.server.chat_completion(
-            messages=messages,
-            n=self.config.group_size,
-            max_tokens=self.config.max_token_length,
-            # temperature=0.7,  # Optional: adjust for creativity vs. determinism
-        )
+        try:
+            chat_completions = await self.server.chat_completion(
+                messages=messages,
+                n=self.config.group_size,  # Number of completions
+                # `max_tokens` limits the *completion*, not the whole context;
+                # adjust as needed for HTML output (the Llama API example used 256).
+                max_tokens=1024,  # Max tokens for the LLM's response
+                # temperature=0.7,  # Optional: adjust for creativity vs. determinism
+                # The model name is picked up automatically by self.server from
+                # server_configs, so it is not passed here.
+            )
+        except tenacity.RetryError as retry_err:  # Specifically catch RetryError
+            print(
+                "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
+            )
+            print(f"[{self.name}] TENACITY RETRY ERROR during chat_completion call:")
+            print(f"[{self.name}] RetryError Details: {retry_err}")
+            # Unwrap the exception that exhausted the retries so we can inspect
+            # any HTTP response attached to it.
+            original_exception = None
+            if retry_err.last_attempt:
+                if retry_err.last_attempt.failed:
+                    original_exception = retry_err.last_attempt.exception()
+                    print(
+                        f"[{self.name}] Last attempt failed. Original exception that caused retries:"
+                    )
+                    print(f"[{self.name}] Type: {type(original_exception)}")
+                    print(
+                        f"[{self.name}] Args: {original_exception.args if original_exception else 'N/A'}"
+                    )
+                    print(f"[{self.name}] Full Str: {str(original_exception)}")
+                else:
+                    # Unusual for a RetryError: the last attempt did not fail.
+                    print(
+                        f"[{self.name}] Last attempt recorded but did not fail. "
+                        f"Result: {retry_err.last_attempt.result()}"
+                    )
+            else:
+                print(
+                    f"[{self.name}] Could not get 'last_attempt' details from "
+                    f"RetryError object. Raw RetryError: {retry_err}"
+                )
+
+            # If we have the original exception, try to get more details (like the
+            # HTTP response). OpenAI/HTTPX-style errors expose a 'response' attribute.
+            if (
+                hasattr(original_exception, "response")
+                and original_exception.response is not None
+            ):
+                response_obj = original_exception.response
+                status_code_text = "Status code N/A"
+                response_content_text = "Response content N/A"
+
+                if hasattr(response_obj, "status_code"):
+                    status_code_text = str(response_obj.status_code)
+
+                print(
+                    f"[{self.name}] Underlying API Response Status Code: {status_code_text}"
+                )
+
+                # Try to get JSON content first (common for API errors).
+                if hasattr(response_obj, "json") and callable(response_obj.json):
+                    try:
+                        # Note: .json() might need to be awaited if it is async, but
+                        # on an exception's response it is typically already processed.
+                        response_json_parsed = response_obj.json()
+                        print(
+                            f"[{self.name}] Underlying API Response JSON: {response_json_parsed}"
+                        )
+                    except Exception as json_e_inner:
+                        print(
+                            f"[{self.name}] Could not parse underlying API response as JSON: {json_e_inner}"
+                        )
+                        # Fall back to text if JSON parsing fails.
+                        if hasattr(response_obj, "text"):
+                            response_content_text = response_obj.text
+                            print(
+                                f"[{self.name}] Underlying API Response Text: {response_content_text}"
+                            )
+                        elif hasattr(response_obj, "content"):  # often bytes
+                            try:
+                                response_content_text = response_obj.content.decode()
+                                print(
+                                    f"[{self.name}] Underlying API Response Content (decoded): {response_content_text}"
+                                )
+                            except Exception:
+                                response_content_text = str(response_obj.content)
+                                print(
+                                    f"[{self.name}] Underlying API Response Content (raw bytes as str): {response_content_text}"
+                                )
+                # If there is no json() method, try .text or .content directly.
+                elif hasattr(response_obj, "text"):
+                    response_content_text = response_obj.text
+                    print(
+                        f"[{self.name}] Underlying API Response Text: {response_content_text}"
+                    )
+                elif hasattr(response_obj, "content"):
+                    try:
+                        response_content_text = response_obj.content.decode()
+                        print(
+                            f"[{self.name}] Underlying API Response Content (decoded): {response_content_text}"
+                        )
+                    except Exception:
+                        response_content_text = str(response_obj.content)
+                        print(
+                            f"[{self.name}] Underlying API Response Content (raw bytes as str): {response_content_text}"
+                        )
+
+            print(
+                "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
+            )
+            print(
+                f"[{self.name}] Messages that were sent during the attempt resulting in RetryError: {messages}"
+            )
+            return None, []

         to_score_inputs = []
         for choice in chat_completions.choices:
@@ -283,24 +424,13 @@ class AccessibilityEnv(BaseEnv):
             "tokens": all_tokens,
             "masks": all_masks,
             "scores": all_scores,
-            "advantages": None,
-            "ref_logprobs": None,
-            "group_overrides": None,
-            "messages": None,
+            "advantages": all_advantages,
+            "ref_logprobs": all_ref_logprobs,
+            "group_overrides": {},
+            "messages": all_messages_for_trainer,
             "overrides": None,
         }

-        # Add optional fields if they have values (or if you explicitly want them as None)
-        # The TypedDict definition uses Optional[], so if a key is missing, it's fine.
-        # If you want to explicitly include them as None if not populated:
-        if all_advantages is not None:
-            data_to_return["advantages"] = all_advantages
-        if all_ref_logprobs is not None:
-            data_to_return["ref_logprobs"] = all_ref_logprobs
-        if all_messages_for_trainer is not None:
-            data_to_return["messages"] = all_messages_for_trainer
-        # group_overrides and overrides are also optional
-
         print(
             f"""[{self.name}] Scoring complete.
             Data to return (first score): {data_to_return['scores'][0] if data_to_return['scores'] else 'N/A'}"""
         )
diff --git a/environments/hack0/accessibility_env/requirements.txt b/environments/hack0/accessibility_env/requirements.txt
index e69de29b..c1f5f713 100644
--- a/environments/hack0/accessibility_env/requirements.txt
+++ b/environments/hack0/accessibility_env/requirements.txt
@@ -0,0 +1 @@
+beautifulsoup4
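
Two hedged sketches of the ideas this diff leans on. First, the long `RetryError` handler in the rollout path reduces to one pattern: unwrap the exception that exhausted tenacity's retries, then inspect any HTTP response attached to it. A condensed sketch of that pattern, using only tenacity's documented `last_attempt` API; the function name `unwrap_retry_error` is illustrative, not part of the diff:

```python
from typing import Optional

import tenacity


def unwrap_retry_error(retry_err: tenacity.RetryError) -> Optional[BaseException]:
    """Return the exception that exhausted the retries, if one was recorded."""
    attempt = retry_err.last_attempt
    if attempt is not None and attempt.failed:
        return attempt.exception()
    return None
```

Second, the new `beautifulsoup4` requirement (and the commented-out `bs4` import at the top of the file) suggests the scorer will parse the model's HTML output. A minimal sketch of what that check could look like: the helper name `count_remaining_issues` is hypothetical, the `missing_for_attribute_on_label` code comes from the placeholder dataset, and the `missing_alt_attribute_on_img` rule is an assumed example of a second check.

```python
from typing import List

from bs4 import BeautifulSoup


def count_remaining_issues(html: str, issues_to_fix: List[str]) -> int:
    """Count how many dataset-listed issues are still present in the HTML."""
    soup = BeautifulSoup(html, "html.parser")
    remaining = 0
    for issue in issues_to_fix:
        if issue == "missing_for_attribute_on_label":
            # A <label> without a 'for' attribute is still unfixed.
            remaining += sum(1 for lbl in soup.find_all("label") if not lbl.get("for"))
        elif issue == "missing_alt_attribute_on_img":
            # An <img> without alt text is still unfixed (assumed rule).
            remaining += sum(1 for img in soup.find_all("img") if not img.get("alt"))
    return remaining


# Example: reward 1.0 only when every listed issue is gone.
fixed_html = '<label for="email">Email</label><input id="email">'
score = 1.0 if count_remaining_issues(fixed_html, ["missing_for_attribute_on_label"]) == 0 else 0.0
```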