mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
new env runs locally
This commit is contained in:
parent
54ae40840d
commit
d6f9d58606
4 changed files with 204 additions and 45 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import json
|
||||
|
||||
import gymnasium as gym
|
||||
import random
|
||||
|
|
@ -7,6 +8,7 @@ import random
|
|||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
|
||||
from atroposlib.type_definitions import Item, Message
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
from atroposlib.utils.tool_call_parser import parse_tool_call
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -42,6 +44,39 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
self.episode_outcomes_buffer: List[float] = []
|
||||
self.eval_metrics_custom: List[Tuple[str, float]] = []
|
||||
|
||||
# Define tools available to the LLM
|
||||
self.tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "take_action",
|
||||
"description": "Choose to 'hit' or 'stick' in Blackjack.",
|
||||
"parameters": {
|
||||
# Parameters are implicitly defined by the arguments of the function call
|
||||
# For this simple case, let's assume the LLM will provide arguments.action
|
||||
# based on the prompt. A more robust schema would define 'action' here.
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {"type": "string", "enum": ["hit", "stick"]}
|
||||
},
|
||||
"required": ["action"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
tools_json = json.dumps(self.tools)
|
||||
# Updated system prompt for tool calling
|
||||
self.system_prompt = (
|
||||
"You are an AI agent playing Blackjack. "
|
||||
"You need to decide whether to hit or stick based on your current hand and the dealer's showing card.\n\n"
|
||||
f"<tools>\n{tools_json}\n</tools>\n\n"
|
||||
"For your function call, return a JSON object with function name and arguments "
|
||||
"within <tool_call> </tool_call> tags with the following schema:\n"
|
||||
'<tool_call>\n{"arguments": {"action": "hit"}, "name": "take_action"}\n</tool_call>\n\n'
|
||||
"Your full answer format should be (NO THINKING BLOCK):\n"
|
||||
'<tool_call>\n{"arguments": {"action": "stick"}, "name": "take_action"}\n</tool_call>\n'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
|
||||
|
|
@ -76,12 +111,45 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
)
|
||||
|
||||
def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
|
||||
"""Parses 'hit' or 'stick' from the LLM response."""
|
||||
action_str = llm_response.strip().lower()
|
||||
if action_str in ACTION_STR_TO_INT:
|
||||
return ACTION_STR_TO_INT[action_str]
|
||||
logger.warning(f"Could not parse action from LLM response: '{llm_response}'")
|
||||
return None
|
||||
"""Parses the action from the LLM's tool_call response."""
|
||||
if not llm_response:
|
||||
logger.warning(
|
||||
"Attempted to parse an empty LLM response. Returning invalid action (None)."
|
||||
)
|
||||
return None
|
||||
|
||||
parsed_name, parsed_args, is_error = parse_tool_call(
|
||||
llm_response, self.tools, ["tool_call"] # Expecting <tool_call>
|
||||
)
|
||||
|
||||
if is_error:
|
||||
error_detail = (
|
||||
str(parsed_name) # Error message is in parsed_name if is_error
|
||||
if parsed_name
|
||||
else "Parser indicated error, but no specific message was returned."
|
||||
)
|
||||
logger.warning(
|
||||
f"Failed to parse tool call. Full response: '{llm_response}'. Error: {error_detail}"
|
||||
)
|
||||
return None
|
||||
|
||||
if parsed_name != "take_action":
|
||||
logger.warning(
|
||||
f"Expected tool call name 'take_action', but got '{parsed_name}'. Response: '{llm_response}'"
|
||||
)
|
||||
return None
|
||||
|
||||
action_str = parsed_args.get("action", "").lower()
|
||||
if action_str == "hit":
|
||||
return ACTION_HIT
|
||||
elif action_str == "stick":
|
||||
return ACTION_STICK
|
||||
else:
|
||||
logger.warning(
|
||||
f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. "
|
||||
f"Full response: '{llm_response}'. Parsed args: {parsed_args}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
|
|
@ -109,10 +177,8 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
env.close()
|
||||
return None, []
|
||||
|
||||
system_prompt = (
|
||||
"You are playing Blackjack. Respond with either 'hit' or 'stick'."
|
||||
)
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
# Use the class system_prompt
|
||||
messages.append({"role": "system", "content": self.system_prompt})
|
||||
|
||||
current_obs_str = self._format_observation(obs)
|
||||
messages.append({"role": "user", "content": current_obs_str})
|
||||
|
|
@ -126,7 +192,7 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.")
|
||||
break
|
||||
|
||||
max_tokens_for_action = 10
|
||||
max_tokens_for_action = 512
|
||||
|
||||
try:
|
||||
chat_completions = await server.chat_completion(
|
||||
|
|
@ -136,6 +202,7 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
temperature=0.5,
|
||||
)
|
||||
llm_action_response = chat_completions.choices[0].message.content.strip()
|
||||
logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'") # Log raw response
|
||||
except Exception as e:
|
||||
logger.error(f"[Seed: {seed}] LLM API error: {e}")
|
||||
break
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue