mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
new env runs locally
This commit is contained in:
parent
54ae40840d
commit
d6f9d58606
4 changed files with 204 additions and 45 deletions
|
|
@ -3,31 +3,5 @@ Utility functions and classes for the atroposlib package.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .config_handler import ConfigHandler
|
from .config_handler import ConfigHandler
|
||||||
from .message_history_utils import (
|
|
||||||
strip_thinking,
|
|
||||||
truncate_thinking,
|
|
||||||
ensure_trajectory_token_limit,
|
|
||||||
)
|
|
||||||
from .tokenize_for_trainer import tokenize_for_trainer
|
|
||||||
from .tool_call_parser import parse_tool_call
|
|
||||||
from .advantages import (
|
|
||||||
allclose_to_first,
|
|
||||||
compute_stats,
|
|
||||||
compute_discounted_returns,
|
|
||||||
compute_grpo_process_supervision_advantages,
|
|
||||||
)
|
|
||||||
from .best_of_n_selection import select_best_index
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["ConfigHandler"]
|
||||||
"ConfigHandler",
|
|
||||||
"strip_thinking",
|
|
||||||
"truncate_thinking",
|
|
||||||
"tokenize_for_trainer",
|
|
||||||
"parse_tool_call",
|
|
||||||
"allclose_to_first",
|
|
||||||
"compute_stats",
|
|
||||||
"compute_discounted_returns",
|
|
||||||
"compute_grpo_process_supervision_advantages",
|
|
||||||
"ensure_trajectory_token_limit",
|
|
||||||
"select_best_index",
|
|
||||||
]
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
import json
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import random
|
import random
|
||||||
|
|
@ -7,6 +8,7 @@ import random
|
||||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
|
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
|
||||||
from atroposlib.type_definitions import Item, Message
|
from atroposlib.type_definitions import Item, Message
|
||||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||||
|
from atroposlib.utils.tool_call_parser import parse_tool_call
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -42,6 +44,39 @@ class BlackjackEnvNoThinking(BaseEnv):
|
||||||
self.episode_outcomes_buffer: List[float] = []
|
self.episode_outcomes_buffer: List[float] = []
|
||||||
self.eval_metrics_custom: List[Tuple[str, float]] = []
|
self.eval_metrics_custom: List[Tuple[str, float]] = []
|
||||||
|
|
||||||
|
# Define tools available to the LLM
|
||||||
|
self.tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "take_action",
|
||||||
|
"description": "Choose to 'hit' or 'stick' in Blackjack.",
|
||||||
|
"parameters": {
|
||||||
|
# Parameters are implicitly defined by the arguments of the function call
|
||||||
|
# For this simple case, let's assume the LLM will provide arguments.action
|
||||||
|
# based on the prompt. A more robust schema would define 'action' here.
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"action": {"type": "string", "enum": ["hit", "stick"]}
|
||||||
|
},
|
||||||
|
"required": ["action"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
tools_json = json.dumps(self.tools)
|
||||||
|
# Updated system prompt for tool calling
|
||||||
|
self.system_prompt = (
|
||||||
|
"You are an AI agent playing Blackjack. "
|
||||||
|
"You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
|
||||||
|
f"<tools>\n{tools_json}\n</tools>\n\n"
|
||||||
|
"For your function call, return a JSON object with function name and arguments "
|
||||||
|
"within <tool_call> </tool_call> tags with the following schema:\n"
|
||||||
|
'<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
|
||||||
|
"Your full answer format should be (NO THINKING BLOCK):\n"
|
||||||
|
'<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n'
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
|
def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
|
||||||
|
|
@ -76,12 +111,45 @@ class BlackjackEnvNoThinking(BaseEnv):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
|
def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
|
||||||
"""Parses 'hit' or 'stick' from the LLM response."""
|
"""Parses the action from the LLM's tool_call response."""
|
||||||
action_str = llm_response.strip().lower()
|
if not llm_response:
|
||||||
if action_str in ACTION_STR_TO_INT:
|
logger.warning(
|
||||||
return ACTION_STR_TO_INT[action_str]
|
"Attempted to parse an empty LLM response. Returning invalid action (None)."
|
||||||
logger.warning(f"Could not parse action from LLM response: '{llm_response}'")
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
parsed_name, parsed_args, is_error = parse_tool_call(
|
||||||
|
llm_response, self.tools, ["tool_call"] # Expecting <tool_call>
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_error:
|
||||||
|
error_detail = (
|
||||||
|
str(parsed_name) # Error message is in parsed_name if is_error
|
||||||
|
if parsed_name
|
||||||
|
else "Parser indicated error, but no specific message was returned."
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to parse tool call. Full response: '{llm_response}'. Error: {error_detail}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if parsed_name != "take_action":
|
||||||
|
logger.warning(
|
||||||
|
f"Expected tool call name 'take_action', but got '{parsed_name}'. Response: '{llm_response}'"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
action_str = parsed_args.get("action", "").lower()
|
||||||
|
if action_str == "hit":
|
||||||
|
return ACTION_HIT
|
||||||
|
elif action_str == "stick":
|
||||||
|
return ACTION_STICK
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. "
|
||||||
|
f"Full response: '{llm_response}'. Parsed args: {parsed_args}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
async def collect_trajectory(
|
async def collect_trajectory(
|
||||||
self, item: Item
|
self, item: Item
|
||||||
|
|
@ -109,10 +177,8 @@ class BlackjackEnvNoThinking(BaseEnv):
|
||||||
env.close()
|
env.close()
|
||||||
return None, []
|
return None, []
|
||||||
|
|
||||||
system_prompt = (
|
# Use the class system_prompt
|
||||||
"You are playing Blackjack. Respond with either 'hit' or 'stick'."
|
messages.append({"role": "system", "content": self.system_prompt})
|
||||||
)
|
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
|
||||||
|
|
||||||
current_obs_str = self._format_observation(obs)
|
current_obs_str = self._format_observation(obs)
|
||||||
messages.append({"role": "user", "content": current_obs_str})
|
messages.append({"role": "user", "content": current_obs_str})
|
||||||
|
|
@ -126,7 +192,7 @@ class BlackjackEnvNoThinking(BaseEnv):
|
||||||
logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.")
|
logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.")
|
||||||
break
|
break
|
||||||
|
|
||||||
max_tokens_for_action = 10
|
max_tokens_for_action = 512
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chat_completions = await server.chat_completion(
|
chat_completions = await server.chat_completion(
|
||||||
|
|
@ -136,6 +202,7 @@ class BlackjackEnvNoThinking(BaseEnv):
|
||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
)
|
)
|
||||||
llm_action_response = chat_completions.choices[0].message.content.strip()
|
llm_action_response = chat_completions.choices[0].message.content.strip()
|
||||||
|
logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'") # Log raw response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Seed: {seed}] LLM API error: {e}")
|
logger.error(f"[Seed: {seed}] LLM API error: {e}")
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -27,13 +27,10 @@ from atroposlib.envs.base import (
|
||||||
OpenaiConfig,
|
OpenaiConfig,
|
||||||
ScoredDataGroup,
|
ScoredDataGroup,
|
||||||
)
|
)
|
||||||
from atroposlib.utils import (
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||||
tokenize_for_trainer,
|
from atroposlib.utils.message_history_utils import truncate_thinking
|
||||||
parse_tool_call,
|
from atroposlib.utils.tool_call_parser import parse_tool_call
|
||||||
truncate_thinking,
|
from atroposlib.utils.best_of_n_selection import select_best_index
|
||||||
ensure_trajectory_token_limit,
|
|
||||||
select_best_index
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from atroposlib.envs.base import EvalHandlingEnum, OpenaiConfig, ScoredDataItem
|
||||||
|
from environments.game_environments.gymnasium.blackjack_env_no_thinking import (
|
||||||
|
BlackjackEnvNoThinking,
|
||||||
|
BlackjackEnvNoThinkingConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
logger.info(
|
||||||
|
"Starting Blackjack (No Thinking) environment local debug runner"
|
||||||
|
)
|
||||||
|
|
||||||
|
env_config = BlackjackEnvNoThinkingConfig(
|
||||||
|
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||||
|
group_size=1,
|
||||||
|
use_wandb=False,
|
||||||
|
wandb_name="blackjack_no_thinking_local_debug",
|
||||||
|
max_num_workers=1,
|
||||||
|
rollout_server_url="http://localhost:8000",
|
||||||
|
total_steps=1,
|
||||||
|
batch_size=1,
|
||||||
|
steps_per_eval=0,
|
||||||
|
max_token_length=1024,
|
||||||
|
inference_weight=1.0,
|
||||||
|
data_path_to_save_groups=None,
|
||||||
|
eval_handling=EvalHandlingEnum.NONE,
|
||||||
|
eval_limit_ratio=0.0,
|
||||||
|
env_name="Blackjack-v1",
|
||||||
|
max_episode_turns=10,
|
||||||
|
eval_episodes=0,
|
||||||
|
)
|
||||||
|
server_configs = [
|
||||||
|
OpenaiConfig(
|
||||||
|
model_name="gpt-4.1-nano",
|
||||||
|
base_url="https://api.openai.com/v1",
|
||||||
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
|
num_requests_for_eval=0,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
logger.info("Using hardcoded debug configuration for No Thinking Blackjack.")
|
||||||
|
logger.debug(f"Env Config: {env_config}")
|
||||||
|
logger.debug(f"Server Configs: {server_configs}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
env = BlackjackEnvNoThinking(
|
||||||
|
config=env_config,
|
||||||
|
server_configs=server_configs,
|
||||||
|
slurm=False,
|
||||||
|
testing=False,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed to initialize BlackjackEnvNoThinking: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Running a single trajectory directly using collect_trajectory")
|
||||||
|
try:
|
||||||
|
await env.setup()
|
||||||
|
seed = random.randint(0, 1000000)
|
||||||
|
item_for_env = {"seed": seed}
|
||||||
|
logger.info(f"Using seed: {seed} for item: {item_for_env}")
|
||||||
|
|
||||||
|
result_tuple = await env.collect_trajectory(item_for_env)
|
||||||
|
|
||||||
|
scored_data_item: Optional[ScoredDataItem] = None
|
||||||
|
if result_tuple and result_tuple[0]:
|
||||||
|
scored_data_item = result_tuple[0]
|
||||||
|
logger.info(
|
||||||
|
f"Trajectory collection complete. Score: {scored_data_item.get('scores')}"
|
||||||
|
)
|
||||||
|
if env_config.include_messages and scored_data_item.get('messages'):
|
||||||
|
logger.info("Collected Messages:")
|
||||||
|
for i, msg in enumerate(scored_data_item['messages']):
|
||||||
|
logger.info(f" {i}. Role: {msg['role']}, Content: '{str(msg['content'])[:150]}...'")
|
||||||
|
logger.info(f"Tokens ({len(scored_data_item.get('tokens', []))}): {str(scored_data_item.get('tokens'))[:100]}...")
|
||||||
|
logger.info(f"Masks ({len(scored_data_item.get('masks', []))}): {str(scored_data_item.get('masks'))[:100]}...")
|
||||||
|
else:
|
||||||
|
logger.error("Trajectory collection did not return a ScoredDataItem.")
|
||||||
|
|
||||||
|
episode_summary_reward = None
|
||||||
|
if env.episode_outcomes_buffer:
|
||||||
|
episode_summary_reward = env.episode_outcomes_buffer[-1]
|
||||||
|
|
||||||
|
if episode_summary_reward is not None:
|
||||||
|
logger.info("\n========== Episode Summary ==========")
|
||||||
|
logger.info(f"Seed: {seed}")
|
||||||
|
logger.info(
|
||||||
|
f"Final Environment reward (Score): {episode_summary_reward:.2f}"
|
||||||
|
)
|
||||||
|
outcome_str = "Draw"
|
||||||
|
if episode_summary_reward > 0:
|
||||||
|
outcome_str = "Win"
|
||||||
|
elif episode_summary_reward < 0:
|
||||||
|
outcome_str = "Loss"
|
||||||
|
logger.info(f"Game Outcome: {outcome_str}")
|
||||||
|
logger.info("=======================================")
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Could not get episode summary for seed {seed} from metrics buffer."
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f"An error occurred during trajectory collection or summary: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
Loading…
Add table
Add a link
Reference in a new issue