mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
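"""Local debug runner for the Blackjack (No Thinking) environment.

Collects a single scored trajectory against a live OpenAI-compatible
endpoint and logs the result plus an episode summary. Intended for quick
manual testing rather than training; expects OPENAI_API_KEY to be set in
the environment or in a local .env file.
"""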
import asyncio
import logging
import os
import random
from typing import Optional

from dotenv import load_dotenv

from atroposlib.envs.base import EvalHandlingEnum, OpenaiConfig, ScoredDataItem
from environments.game_environments.gymnasium.blackjack_env_no_thinking import (
    BlackjackEnvNoThinking,
    BlackjackEnvNoThinkingConfig,
)

# Pull OPENAI_API_KEY (and any other settings) from a local .env file.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def main():
    logger.info("Starting Blackjack (No Thinking) environment local debug runner")

    # Hardcoded single-rollout debug config: one worker, one step, one group,
    # with wandb and evaluation disabled.
    env_config = BlackjackEnvNoThinkingConfig(
        tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        group_size=1,
        use_wandb=False,
        wandb_name="blackjack_no_thinking_local_debug",
        max_num_workers=1,
        rollout_server_url="http://localhost:8000",
        total_steps=1,
        batch_size=1,
        steps_per_eval=0,
        max_token_length=1024,
        inference_weight=1.0,
        data_path_to_save_groups=None,
        eval_handling=EvalHandlingEnum.NONE,
        eval_limit_ratio=0.0,
        env_name="Blackjack-v1",
        max_episode_turns=10,
        eval_episodes=0,
    )
    # Single OpenAI-compatible inference endpoint; the API key is read from
    # the environment rather than hardcoded.
    server_configs = [
        OpenaiConfig(
            model_name="gpt-4.1-nano",
            base_url="https://api.openai.com/v1",
            api_key=os.getenv("OPENAI_API_KEY"),
            num_requests_for_eval=0,
        )
    ]
    logger.info("Using hardcoded debug configuration for No Thinking Blackjack.")
    logger.debug(f"Env Config: {env_config}")
    logger.debug(f"Server Configs: {server_configs}")

    # Instantiate the environment locally (no SLURM, not in test mode).
    try:
        env = BlackjackEnvNoThinking(
            config=env_config,
            server_configs=server_configs,
            slurm=False,
            testing=False,
        )
    except Exception as e:
        logger.exception(f"Failed to initialize BlackjackEnvNoThinking: {e}")
        return
logger.info("Running a single trajectory directly using collect_trajectory")
|
|
try:
|
|
await env.setup()
|
|
seed = random.randint(0, 1000000)
|
|
item_for_env = {"seed": seed}
|
|
logger.info(f"Using seed: {seed} for item: {item_for_env}")
|
|
|
|
result_tuple = await env.collect_trajectory(item_for_env)
|
|
|
|
scored_data_item: Optional[ScoredDataItem] = None
|
|
if result_tuple and result_tuple[0]:
|
|
scored_data_item = result_tuple[0]
|
|
logger.info(
|
|
f"Trajectory collection complete. Score: {scored_data_item.get('scores')}"
|
|
)
|
|
if env_config.include_messages and scored_data_item.get('messages'):
|
|
logger.info("Collected Messages:")
|
|
for i, msg in enumerate(scored_data_item['messages']):
|
|
logger.info(f" {i}. Role: {msg['role']}, Content: '{str(msg['content'])[:150]}...'")
|
|
logger.info(f"Tokens ({len(scored_data_item.get('tokens', []))}): {str(scored_data_item.get('tokens'))[:100]}...")
|
|
logger.info(f"Masks ({len(scored_data_item.get('masks', []))}): {str(scored_data_item.get('masks'))[:100]}...")
|
|
else:
|
|
logger.error("Trajectory collection did not return a ScoredDataItem.")
|
|
|
|
        # The environment appends each episode's final reward to this buffer,
        # so the last entry is the episode we just collected.
        episode_summary_reward = None
        if env.episode_outcomes_buffer:
            episode_summary_reward = env.episode_outcomes_buffer[-1]

        if episode_summary_reward is not None:
            logger.info("\n========== Episode Summary ==========")
            logger.info(f"Seed: {seed}")
            logger.info(
                f"Final Environment reward (Score): {episode_summary_reward:.2f}"
            )
            # Positive reward is a win, negative a loss, zero a draw.
            outcome_str = "Draw"
            if episode_summary_reward > 0:
                outcome_str = "Win"
            elif episode_summary_reward < 0:
                outcome_str = "Loss"
            logger.info(f"Game Outcome: {outcome_str}")
            logger.info("=======================================")
        else:
            logger.error(
                f"Could not get episode summary for seed {seed} from metrics buffer."
            )

    except Exception as e:
        logger.exception(
            f"An error occurred during trajectory collection or summary: {e}"
        )


if __name__ == "__main__":
    asyncio.run(main())
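# Example invocation (the script path is illustrative; run from the repo root
# so the `atroposlib` and `environments` packages are importable):
#   python path/to/blackjack_env_no_thinking_local_debug.py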