Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00

Working POC

This commit is contained in:
parent 94038876f4
commit 8ff2b02ce0

2 changed files with 170 additions and 39 deletions
@@ -2,6 +2,9 @@
import os  # For API keys, etc.
from typing import Dict, List, Optional, Tuple  # Common type hints, added Dict

import tenacity

# from bs4 import BeautifulSoup
from transformers.models.auto.tokenization_auto import AutoTokenizer

# Corrected imports for Atropos types
@@ -38,23 +41,26 @@ class AccessibilityEnv(BaseEnv):
    @classmethod
    def config_init(cls) -> Tuple[AccessibilityEnvConfig, List[APIServerConfig]]:
        env_config = AccessibilityEnvConfig(
            tokenizer_name="NousResearch/Llama-3-8B-Instruct- যেভাবে-তুমি-বাংলা-বলো",  # Placeholder
            group_size=2,  # Smaller for faster testing initially
            tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
            group_size=1,  # Smaller for faster testing initially
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=10,  # For process mode, number of items to generate
            batch_size=4,  # Max items in a single call to score (related to group_size)
            total_steps=3,  # For process mode, number of items to generate
            batch_size=1,  # Max items in a single call to score (related to group_size)
            steps_per_eval=5,
            max_token_length=2048,
            wandb_name="accessibility_env_hackathon_dev",  # Dev run name
            wandb_name="accessibility_llama_dev",  # Dev run name
        )

        llama_api_key = os.environ.get("LLAMA_API_KEY")
        if not llama_api_key:
            print("WARNING: LLAMA_API_KEY environment variable not set!")

        server_configs = [
            APIServerConfig(
                model_name="gpt-3.5-turbo",  # Or your preferred model
                # base_url=None,  # Defaults to OpenAI if None
                api_key=os.environ.get(
                    "OPENAI_API_KEY", "YOUR_API_KEY_PLACEHOLDER_IF_NOT_SET"
                ),  # Important!
                model_name="Llama-4-Maverick-17B-128E-Instruct-FP8",
                base_url="https://api.llama.com/v1",  # <<<---- Llama API base URL
                api_key=llama_api_key,
                num_requests_for_eval=16,
            ),
        ]
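The hunk above swaps the default OpenAI endpoint for the Llama API by overriding base_url. Assuming that endpoint is OpenAI-compatible (which the base_url override implies), the key can be smoke-tested outside Atropos with the stock openai client; this is an illustrative sketch, not part of the commit:

import os

from openai import OpenAI

# Same endpoint and model name as in config_init above
client = OpenAI(
    base_url="https://api.llama.com/v1",
    api_key=os.environ["LLAMA_API_KEY"],
)
resp = client.chat.completions.create(
    model="Llama-4-Maverick-17B-128E-Instruct-FP8",
    messages=[{"role": "user", "content": "ping"}],
    max_tokens=8,
)
print(resp.choices[0].message.content)

An authentication failure here points back to the LLAMA_API_KEY warning printed by config_init.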
@@ -62,14 +68,34 @@ class AccessibilityEnv(BaseEnv):

    async def setup(self):
        print(f"[{self.name}] Setting up environment...")
        # Load dataset and initialize tools (e.g., the HTML parser) here
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.tokenizer_name, trust_remote_code=True
            )
            # It's good practice to set pad_token if it isn't already set; common for GPT-like models
            )  # tokenizer_name is 'gpt2'
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Set a default chat template if one isn't already set.
            # This is crucial for tokenizers like 'gpt2' that don't ship one.
            if self.tokenizer.chat_template is None:
                # For gpt-3.5-turbo, the actual chat formatting is handled by the API,
                # but local tokenization for the trainer still needs *a* template.
                # A basic template for generic tokenization:
                self.tokenizer.chat_template = (
                    "{% for message in messages %}"
                    "{{ message['role'] + ': ' + message['content'] + '\\n' }}"
                    "{% endfor %}"
                )
                # Alternatively, a more structured Jinja template (Llama- or ChatML-style)
                # could be used if training expects that format. For simply producing
                # token IDs for RL, the simple template above should suffice, and
                # tokenize_for_trainer may just concatenate.
                print(
                    f"[{self.name}] Set a default chat_template for tokenizer '{self.config.tokenizer_name}'."
                )

            print(
                f"[{self.name}] Tokenizer '{self.config.tokenizer_name}' loaded successfully."
            )
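To see what the fallback template above actually renders, here is a small standalone sketch; the 'gpt2' checkpoint stands in for any tokenizer that ships without a chat template:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
# Same minimal Jinja template as in setup() above
tok.chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] + ': ' + message['content'] + '\\n' }}"
    "{% endfor %}"
)
text = tok.apply_chat_template(
    [{"role": "user", "content": "Fix this HTML."}], tokenize=False
)
print(repr(text))  # 'user: Fix this HTML.\n'

With tokenize=True (the default), apply_chat_template returns token IDs rather than the rendered string, which is what the trainer ultimately consumes.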
@@ -77,10 +103,7 @@ class AccessibilityEnv(BaseEnv):
            print(
                f"[{self.name}] Error loading tokenizer '{self.config.tokenizer_name}': {e}"
            )
            # Decide how to handle this: raise an error, or try to proceed without tokenization (not ideal for Atropos)
            # For now, let's allow it to proceed, but tokenization will fail later if tokenizer is None
            # A better approach might be to raise an exception here if the tokenizer is critical.
            # raise RuntimeError(f"Failed to load tokenizer: {e}") from e
            raise RuntimeError(f"Failed to load tokenizer: {e}") from e

        self.dataset = [
            {
@@ -93,7 +116,7 @@ class AccessibilityEnv(BaseEnv):
                "html": "<label>Name</label><input type='text' name='username'>",
                "issues_to_fix": ["missing_for_attribute_on_label"],
            },
        ]  # Placeholder for your HTML snippets
        ]
        self.iter = 0
        print(f"[{self.name}] Setup complete. Loaded {len(self.dataset)} items.")
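The commit also adds beautifulsoup4 to the requirements (the new file at the end of this diff), presumably for checks over snippets like the dataset item above. As a hedged sketch, not the env's actual scorer, the sample issue could be detected like this:

from bs4 import BeautifulSoup


def has_missing_for_attribute_on_label(html: str) -> bool:
    # A <label> without a 'for' attribute is not programmatically
    # associated with its input, which is the accessibility issue
    # flagged as "missing_for_attribute_on_label" above.
    soup = BeautifulSoup(html, "html.parser")
    return any(label.get("for") is None for label in soup.find_all("label"))


print(has_missing_for_attribute_on_label(
    "<label>Name</label><input type='text' name='username'>"
))  # True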
@@ -175,12 +198,130 @@ class AccessibilityEnv(BaseEnv):
            {"role": "user", "content": user_message_content},
        ]

        chat_completions = await self.server.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
            # temperature=0.7,  # Optional: adjust for creativity vs. determinism
        )
        try:
            chat_completions = await self.server.chat_completion(
                messages=messages,
                n=self.config.group_size,  # Number of completions
                # `max_tokens` here is for the *completion* part, not the whole context.
                # The Llama API example used 256. Adjust as needed for HTML output.
                max_tokens=1024,  # Max tokens for the LLM's response
                # temperature=0.7,  # Optional: adjust for creativity vs. determinism
                # model=self.server_configs[0].model_name  # Picked up automatically
                # from server_configs by the self.server object.
            )
        except tenacity.RetryError as retry_err:  # Specifically catch RetryError
            print(
                "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
            )
            print(f"[{self.name}] TENACITY RETRY ERROR during chat_completion call:")
            print(f"[{self.name}] RetryError Details: {retry_err}")
            # Dig the original exception (and its HTTP response, if any) out of retry_err.
            original_exception = None
            if retry_err.last_attempt:
                if retry_err.last_attempt.failed:
                    original_exception = retry_err.last_attempt.exception()
                    print(
                        f"[{self.name}] Last attempt failed. Original exception that caused retries:"
                    )
                    print(f"[{self.name}] Type: {type(original_exception)}")
                    print(
                        f"[{self.name}] Args: {original_exception.args if original_exception else 'N/A'}"
                    )
                    print(
                        f"[{self.name}] Full Str: {str(original_exception)}"
                    )  # More direct string representation
                else:
                    # This case is unusual for a RetryError caused by failure
                    print(
                        f"[{self.name}] Last attempt recorded but did not 'fail'. "
                        f"Result: {retry_err.last_attempt.result()}"
                    )
            else:
                print(
                    f"[{self.name}] Could not get 'last_attempt' details from the "
                    f"RetryError object. Raw RetryError: {retry_err}"
                )

            # If we have the original exception, try to get more details (like the HTTP response)
            if original_exception:
                # Check whether the original exception is an OpenAI/HTTPX-style error
                # by looking for a 'response' attribute.
                if (
                    hasattr(original_exception, "response")
                    and original_exception.response is not None
                ):
                    response_obj = original_exception.response
                    status_code_text = "Status code N/A"
                    response_content_text = "Response content N/A"

                    if hasattr(response_obj, "status_code"):
                        status_code_text = str(response_obj.status_code)

                    print(
                        f"[{self.name}] Underlying API Response Status Code: {status_code_text}"
                    )

                    # Try to get JSON content first (common for API errors)
                    if hasattr(response_obj, "json") and callable(response_obj.json):
                        try:
                            # Note: this might need to be awaited if response_obj.json is
                            # async, but inside an exception it is typically already processed.
                            response_json_parsed = response_obj.json()
                            print(
                                f"[{self.name}] Underlying API Response JSON: {response_json_parsed}"
                            )
                        except Exception as json_e_inner:
                            print(
                                f"[{self.name}] Could not parse underlying API response as JSON: {json_e_inner}"
                            )
                            # Fall back to text if JSON parsing fails
                            if hasattr(response_obj, "text"):
                                response_content_text = response_obj.text
                                print(
                                    f"[{self.name}] Underlying API Response Text: {response_content_text}"
                                )
                            elif hasattr(response_obj, "content"):  # often bytes
                                try:
                                    response_content_text = response_obj.content.decode()
                                    print(
                                        f"[{self.name}] Underlying API Response "
                                        f"Content (decoded): {response_content_text}"
                                    )
                                except Exception:
                                    response_content_text = str(response_obj.content)
                                    print(
                                        f"[{self.name}] Underlying API Response Content "
                                        f"(raw bytes as str): {response_content_text}"
                                    )
                    # If there is no json() method, try .text or .content directly
                    elif hasattr(response_obj, "text"):
                        response_content_text = response_obj.text
                        print(
                            f"[{self.name}] Underlying API Response Text: {response_content_text}"
                        )
                    elif hasattr(response_obj, "content"):
                        try:
                            response_content_text = response_obj.content.decode()
                            print(
                                f"[{self.name}] Underlying API Response Content (decoded): {response_content_text}"
                            )
                        except Exception:
                            response_content_text = str(response_obj.content)
                            print(
                                f"[{self.name}] Underlying API Response Content "
                                f"(raw bytes as str): {response_content_text}"
                            )

            print(
                "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
            )
            print(
                f"[{self.name}] Messages that were sent during the attempt resulting in RetryError: {messages}"
            )
            return None, []

        to_score_inputs = []
        for choice in chat_completions.choices:
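The diagnostic block above leans on two tenacity facts: a RetryError wraps the final attempt as a future, and last_attempt.exception() recovers the error that exhausted the retries. A minimal standalone illustration (the flaky() function is invented for the demo):

import tenacity


@tenacity.retry(stop=tenacity.stop_after_attempt(3))
def flaky():
    # Always fails, so the decorator gives up after 3 attempts and
    # raises tenacity.RetryError instead of the original ValueError.
    raise ValueError("boom")


try:
    flaky()
except tenacity.RetryError as retry_err:
    original = retry_err.last_attempt.exception()
    print(type(original), original.args)  # <class 'ValueError'> ('boom',)

An OpenAI/HTTPX-style error recovered this way carries a .response attribute, which is exactly what the hasattr checks above probe for status codes and bodies.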
@@ -283,24 +424,13 @@ class AccessibilityEnv(BaseEnv):
            "tokens": all_tokens,
            "masks": all_masks,
            "scores": all_scores,
            "advantages": None,
            "ref_logprobs": None,
            "group_overrides": None,
            "messages": None,
            "advantages": all_advantages,
            "ref_logprobs": all_ref_logprobs,
            "group_overrides": {},
            "messages": all_messages_for_trainer,
            "overrides": None,
        }

        # Add optional fields if they have values (or if you explicitly want them as None).
        # The TypedDict definition uses Optional[], so a missing key is fine.
        # To explicitly include them as None when not populated:
        if all_advantages is not None:
            data_to_return["advantages"] = all_advantages
        if all_ref_logprobs is not None:
            data_to_return["ref_logprobs"] = all_ref_logprobs
        if all_messages_for_trainer is not None:
            data_to_return["messages"] = all_messages_for_trainer
        # group_overrides and overrides are also optional

        print(
            f"""[{self.name}] Scoring complete. Data to return (first score):
            {data_to_return['scores'][0] if data_to_return['scores'] else 'N/A'}"""
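The conditional assignments above follow the usual pattern for a TypedDict whose optional keys may simply be absent. A hedged sketch of that idea; the field names mirror the dict above, while the real ScoredDataGroup definition lives in Atropos and may differ:

from typing import List, Optional, TypedDict


class ScoredDataGroup(TypedDict, total=False):
    # total=False makes every key optional: omitting a key is valid,
    # which is why the code above only assigns fields that have values.
    tokens: List[List[int]]
    masks: List[List[int]]
    scores: List[float]
    advantages: Optional[List[float]]
    ref_logprobs: Optional[List[List[float]]]
    messages: Optional[List[List[dict]]]


data: ScoredDataGroup = {"tokens": [[1, 2]], "masks": [[1, 1]], "scores": [0.5]}
advantages = None
if advantages is not None:
    data["advantages"] = advantages  # only set when populated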
@@ -0,0 +1 @@
beautifulsoup4