new env runs locally

This commit is contained in:
Shannon Sands 2025-05-14 11:57:45 -07:00
parent 54ae40840d
commit d6f9d58606
4 changed files with 204 additions and 45 deletions

View file

@ -3,31 +3,5 @@ Utility functions and classes for the atroposlib package.
""" """
from .config_handler import ConfigHandler from .config_handler import ConfigHandler
from .message_history_utils import (
strip_thinking,
truncate_thinking,
ensure_trajectory_token_limit,
)
from .tokenize_for_trainer import tokenize_for_trainer
from .tool_call_parser import parse_tool_call
from .advantages import (
allclose_to_first,
compute_stats,
compute_discounted_returns,
compute_grpo_process_supervision_advantages,
)
from .best_of_n_selection import select_best_index
__all__ = [ __all__ = ["ConfigHandler"]
"ConfigHandler",
"strip_thinking",
"truncate_thinking",
"tokenize_for_trainer",
"parse_tool_call",
"allclose_to_first",
"compute_stats",
"compute_discounted_returns",
"compute_grpo_process_supervision_advantages",
"ensure_trajectory_token_limit",
"select_best_index",
]

View file

@ -1,5 +1,6 @@
import logging import logging
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import json
import gymnasium as gym import gymnasium as gym
import random import random
@ -7,6 +8,7 @@ import random
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
from atroposlib.type_definitions import Item, Message from atroposlib.type_definitions import Item, Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from atroposlib.utils.tool_call_parser import parse_tool_call
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,6 +44,39 @@ class BlackjackEnvNoThinking(BaseEnv):
self.episode_outcomes_buffer: List[float] = [] self.episode_outcomes_buffer: List[float] = []
self.eval_metrics_custom: List[Tuple[str, float]] = [] self.eval_metrics_custom: List[Tuple[str, float]] = []
# Define tools available to the LLM
self.tools = [
{
"type": "function",
"function": {
"name": "take_action",
"description": "Choose to 'hit' or 'stick' in Blackjack.",
"parameters": {
# JSON Schema for the tool-call arguments: a single required string
# field, "action", restricted to the values "hit" and "stick".
"type": "object",
"properties": {
"action": {"type": "string", "enum": ["hit", "stick"]}
},
"required": ["action"],
},
},
}
]
tools_json = json.dumps(self.tools)
# Updated system prompt for tool calling
self.system_prompt = (
"You are an AI agent playing Blackjack. "
"You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
f"<tools>\n{tools_json}\n</tools>\n\n"
"For your function call, return a JSON object with function name and arguments "
"within <tool_call> </tool_call> tags with the following schema:\n"
'<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
"Your full answer format should be (NO THINKING BLOCK):\n"
'<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n'
)
@classmethod @classmethod
def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]: def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
@ -76,12 +111,45 @@ class BlackjackEnvNoThinking(BaseEnv):
) )
def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
    """Parse the Blackjack action from the LLM's ``<tool_call>`` response.

    Args:
        llm_response: Raw text returned by the model.

    Returns:
        ``ACTION_HIT`` or ``ACTION_STICK`` when the response contains a
        ``take_action`` tool call whose ``action`` argument is "hit" or
        "stick" (case-insensitive). ``None`` for an empty response, a
        parser error, an unexpected tool name, or a missing/invalid
        ``action`` argument; each of those cases is logged as a warning.
    """
    if not llm_response:
        logger.warning(
            "Attempted to parse an empty LLM response. Returning invalid action (None)."
        )
        return None

    parsed_name, parsed_args, is_error = parse_tool_call(
        llm_response, self.tools, ["tool_call"]  # Expecting <tool_call>
    )

    if is_error:
        # On error the parser puts its message in the name slot.
        error_detail = (
            str(parsed_name)
            if parsed_name
            else "Parser indicated error, but no specific message was returned."
        )
        logger.warning(
            f"Failed to parse tool call. Full response: '{llm_response}'. Error: {error_detail}"
        )
        return None

    if parsed_name != "take_action":
        logger.warning(
            f"Expected tool call name 'take_action', but got '{parsed_name}'. Response: '{llm_response}'"
        )
        return None

    # The arguments are model output and therefore untrusted: they may not be
    # a dict, and "action" may be a number, null, list, etc. Anything that is
    # not a string is treated as an invalid action instead of raising.
    raw_action = parsed_args.get("action") if isinstance(parsed_args, dict) else None
    action_str = raw_action.lower() if isinstance(raw_action, str) else ""

    if action_str == "hit":
        return ACTION_HIT
    if action_str == "stick":
        return ACTION_STICK

    logger.warning(
        f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. "
        f"Full response: '{llm_response}'. Parsed args: {parsed_args}"
    )
    return None
async def collect_trajectory( async def collect_trajectory(
self, item: Item self, item: Item
@ -109,10 +177,8 @@ class BlackjackEnvNoThinking(BaseEnv):
env.close() env.close()
return None, [] return None, []
system_prompt = ( # Use the class system_prompt
"You are playing Blackjack. Respond with either 'hit' or 'stick'." messages.append({"role": "system", "content": self.system_prompt})
)
messages.append({"role": "system", "content": system_prompt})
current_obs_str = self._format_observation(obs) current_obs_str = self._format_observation(obs)
messages.append({"role": "user", "content": current_obs_str}) messages.append({"role": "user", "content": current_obs_str})
@ -126,7 +192,7 @@ class BlackjackEnvNoThinking(BaseEnv):
logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.") logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.")
break break
max_tokens_for_action = 10 max_tokens_for_action = 512
try: try:
chat_completions = await server.chat_completion( chat_completions = await server.chat_completion(
@ -136,6 +202,7 @@ class BlackjackEnvNoThinking(BaseEnv):
temperature=0.5, temperature=0.5,
) )
llm_action_response = chat_completions.choices[0].message.content.strip() llm_action_response = chat_completions.choices[0].message.content.strip()
logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'") # Log raw response
except Exception as e: except Exception as e:
logger.error(f"[Seed: {seed}] LLM API error: {e}") logger.error(f"[Seed: {seed}] LLM API error: {e}")
break break

View file

@ -27,13 +27,10 @@ from atroposlib.envs.base import (
OpenaiConfig, OpenaiConfig,
ScoredDataGroup, ScoredDataGroup,
) )
from atroposlib.utils import ( from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
tokenize_for_trainer, from atroposlib.utils.message_history_utils import truncate_thinking
parse_tool_call, from atroposlib.utils.tool_call_parser import parse_tool_call
truncate_thinking, from atroposlib.utils.best_of_n_selection import select_best_index
ensure_trajectory_token_limit,
select_best_index
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -0,0 +1,121 @@
import asyncio
import logging
import os
import random
from typing import Optional
from dotenv import load_dotenv
from atroposlib.envs.base import EvalHandlingEnum, OpenaiConfig, ScoredDataItem
from environments.game_environments.gymnasium.blackjack_env_no_thinking import (
BlackjackEnvNoThinking,
BlackjackEnvNoThinkingConfig,
)
# Load variables from a local .env file into the process environment so that
# later os.getenv lookups (e.g. OPENAI_API_KEY in main()) can see them.
load_dotenv()
# Show INFO-level (and above) log records when running this debug script.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by everything in this runner.
logger = logging.getLogger(__name__)
def _outcome_label(reward: float) -> str:
    """Map a final episode reward to "Win", "Loss" or "Draw"."""
    if reward > 0:
        return "Win"
    if reward < 0:
        return "Loss"
    return "Draw"


def _log_scored_item(scored_data_item: "ScoredDataItem", include_messages: bool) -> None:
    """Log the score, optionally the messages, and token/mask previews."""
    logger.info(
        "Trajectory collection complete. Score: %s", scored_data_item.get("scores")
    )

    messages = scored_data_item.get("messages")
    if include_messages and messages:
        logger.info("Collected Messages:")
        for i, msg in enumerate(messages):
            logger.info(
                " %s. Role: %s, Content: '%s...'",
                i,
                msg["role"],
                str(msg["content"])[:150],
            )

    for label, key in (("Tokens", "tokens"), ("Masks", "masks")):
        values = scored_data_item.get(key)
        # A key that is present but None must not crash the preview.
        count = len(values) if values is not None else 0
        logger.info("%s (%s): %s...", label, count, str(values)[:100])


def _log_episode_summary(env, seed: int) -> None:
    """Log the outcome of the most recent episode recorded by ``env``."""
    outcomes = env.episode_outcomes_buffer
    if not outcomes:
        logger.error(
            "Could not get episode summary for seed %s from metrics buffer.", seed
        )
        return

    final_reward = outcomes[-1]
    logger.info("\n========== Episode Summary ==========")
    logger.info("Seed: %s", seed)
    logger.info("Final Environment reward (Score): %.2f", final_reward)
    logger.info("Game Outcome: %s", _outcome_label(final_reward))
    logger.info("=======================================")


async def main():
    """Run one Blackjack (no-thinking) episode locally and log the result.

    Builds a hardcoded debug configuration, instantiates
    ``BlackjackEnvNoThinking``, collects a single trajectory for a random
    seed via ``collect_trajectory`` and logs the scored item plus an episode
    summary. Errors are logged rather than raised.
    """
    logger.info("Starting Blackjack (No Thinking) environment local debug runner")

    env_config = BlackjackEnvNoThinkingConfig(
        tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        group_size=1,
        use_wandb=False,
        wandb_name="blackjack_no_thinking_local_debug",
        max_num_workers=1,
        rollout_server_url="http://localhost:8000",
        total_steps=1,
        batch_size=1,
        steps_per_eval=0,
        max_token_length=1024,
        inference_weight=1.0,
        data_path_to_save_groups=None,
        eval_handling=EvalHandlingEnum.NONE,
        eval_limit_ratio=0.0,
        env_name="Blackjack-v1",
        max_episode_turns=10,
        eval_episodes=0,
    )

    model_name = "gpt-4.1-nano"
    base_url = "https://api.openai.com/v1"
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # Still passed through unchanged; this only makes the cause of a
        # later authentication failure obvious.
        logger.warning(
            "OPENAI_API_KEY is not set; requests to the OpenAI API will likely fail."
        )

    server_configs = [
        OpenaiConfig(
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            num_requests_for_eval=0,
        )
    ]

    logger.info("Using hardcoded debug configuration for No Thinking Blackjack.")
    logger.debug("Env Config: %s", env_config)
    # Log only non-secret fields so the API key never reaches the logs.
    logger.debug(
        "Server Configs: model_name=%s, base_url=%s (api_key redacted)",
        model_name,
        base_url,
    )

    try:
        env = BlackjackEnvNoThinking(
            config=env_config,
            server_configs=server_configs,
            slurm=False,
            testing=False,
        )
    except Exception as e:
        logger.exception("Failed to initialize BlackjackEnvNoThinking: %s", e)
        return

    logger.info("Running a single trajectory directly using collect_trajectory")
    try:
        await env.setup()
        seed = random.randint(0, 1000000)
        item_for_env = {"seed": seed}
        logger.info("Using seed: %s for item: %s", seed, item_for_env)

        result_tuple = await env.collect_trajectory(item_for_env)

        if result_tuple and result_tuple[0]:
            _log_scored_item(result_tuple[0], env_config.include_messages)
        else:
            logger.error("Trajectory collection did not return a ScoredDataItem.")

        _log_episode_summary(env, seed)
    except Exception as e:
        # Top-level boundary of a debug script: log with traceback and exit.
        logger.exception(
            "An error occurred during trajectory collection or summary: %s", e
        )


if __name__ == "__main__":
    asyncio.run(main())