Fix environment issues. Safely ran python3 accessibility_env.py --help

This commit is contained in:
Josh 2025-05-18 13:07:08 -07:00
parent 737139994a
commit 659247fc00

View file

@@ -1,11 +1,17 @@
# environments/hack0/accessibility_env/accessibility_env.py
from typing import List, Optional, Tuple # Common type hints
import os # For API keys, etc.
from typing import List, Optional, Tuple # Common type hints, added Dict
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig
from atroposlib.type_definitions import ( # Assuming you'll need these
Item,
# Corrected imports for Atropos types
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import ( # GameHistory might not be needed yet, Item is common
Item,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
@@ -30,23 +36,24 @@ class AccessibilityEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[AccessibilityEnvConfig, List[APIServerConfig]]:
env_config = AccessibilityEnvConfig(
tokenizer_name="NousResearch/Llama-3-8B-Instruct- যেভাবে-তুমি-বাংলা-বলো", # Placeholder, change later
group_size=4, # Example, adjust as needed
use_wandb=True, # Recommended for hackathon
rollout_server_url="http://localhost:8000", # Standard Atropos default
total_steps=100, # For process mode, this is more like num_items_to_process
batch_size=8, # Example
steps_per_eval=20, # Less relevant for process-only
max_token_length=2048, # LLM context window
wandb_name="accessibility_env_hackathon", # Your Wandb run name
# data_path_to_save_groups="accessibility_rollouts.jsonl" # Often set via CLI for process
tokenizer_name="NousResearch/Llama-3-8B-Instruct- যেভাবে-তুমি-বাংলা-বলো", # Placeholder
group_size=2, # Smaller for faster testing initially
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=10, # For process mode, number of items to generate
batch_size=4, # Max items in a single call to score (related to group_size)
steps_per_eval=5,
max_token_length=2048,
wandb_name="accessibility_env_hackathon_dev", # Dev run name
)
server_configs = [
APIServerConfig(
model_name="gpt-3.5-turbo", # Placeholder, use your desired model
# base_url="YOUR_LLM_PROVIDER_BASE_URL_IF_NOT_OPENAI_DEFAULT", # e.g., for vLLM
# api_key="YOUR_API_KEY_HERE_OR_USE_ENV_VAR", # Best to use os.environ.get("OPENAI_API_KEY")
num_requests_for_eval=32, # Example
model_name="gpt-3.5-turbo", # Or your preferred model
# base_url=None, # Defaults to OpenAI if None
api_key=os.environ.get(
"OPENAI_API_KEY", "YOUR_API_KEY_PLACEHOLDER_IF_NOT_SET"
), # Important!
num_requests_for_eval=16,
),
]
return env_config, server_configs