Shannon Sands 2025-05-14 14:01:31 -07:00
parent 826de9e283
commit 67cfd961c5
6 changed files with 111 additions and 85 deletions

@@ -1,9 +1,9 @@
-import logging
-from typing import Dict, List, Optional, Tuple
 import json
+import logging
+import random
+from typing import Dict, List, Optional, Tuple
 
 import gymnasium as gym
-import random
 
 from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
 from atroposlib.type_definitions import Item, Message
@@ -119,13 +119,13 @@ class BlackjackEnvNoThinking(BaseEnv):
             return None
         parsed_name, parsed_args, is_error = parse_tool_call(
-            llm_response, self.tools, ["tool_call"] # Expecting <tool_call>
+            llm_response, self.tools, ["tool_call"]  # Expecting <tool_call>
         )
         if is_error:
             error_detail = (
-                str(parsed_name) # Error message is in parsed_name if is_error
-                if parsed_name
+                str(parsed_name)  # Error message is in parsed_name if is_error
+                if parsed_name
                 else "Parser indicated error, but no specific message was returned."
             )
             logger.warning(
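`parse_tool_call` itself is not part of this diff; from the call site above it returns a `(parsed_name, parsed_args, is_error)` tuple and reuses the first slot for the error message when `is_error` is set. A minimal sketch of a parser matching that contract — only the signature and return shape are taken from the hunk, the body below is hypothetical:

import json
import re
from typing import Any, Dict, List, Optional, Tuple


def parse_tool_call(
    response: str, tools: List[Dict[str, Any]], tags: List[str]
) -> Tuple[Optional[str], Optional[Dict[str, Any]], bool]:
    # tools is accepted for validation hooks but unused in this sketch.
    # Look for the first <tag>...</tag> block, e.g. <tool_call>{...}</tool_call>.
    for tag in tags:
        match = re.search(rf"<{tag}>(.*?)</{tag}>", response, re.DOTALL)
        if match is None:
            continue
        try:
            payload = json.loads(match.group(1))
        except json.JSONDecodeError as e:
            return f"Malformed JSON inside <{tag}>: {e}", None, True
        return payload.get("name"), payload.get("arguments", {}), False
    return "No tool call tag found in response.", None, True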
@@ -146,7 +146,8 @@ class BlackjackEnvNoThinking(BaseEnv):
            return ACTION_STICK
        else:
            logger.warning(
-                f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. "
+                f"Successfully parsed tool call '{parsed_name}', "
+                f"but action argument is invalid. Action: '{action_str}'. "
                f"Full response: '{llm_response}'. Parsed args: {parsed_args}"
            )
            return None
@@ -162,14 +163,13 @@ class BlackjackEnvNoThinking(BaseEnv):
         seed = item["seed"]
         messages: List[Message] = []
         game_reward = 0.0
-        num_turns = 0
 
         try:
             env = gym.make(self.config.env_name)
         except Exception as e:
             logger.error(f"Failed to make environment {self.config.env_name}: {e}")
             return None, []
 
         try:
             obs, info = env.reset(seed=seed)
         except Exception as e:
@@ -189,7 +189,9 @@ class BlackjackEnvNoThinking(BaseEnv):
                 len(self.tokenizer.apply_chat_template(messages, tokenize=False))
                 > self.config.max_token_length - 50
             ):
-                logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.")
+                logger.warning(
+                    f"[Seed: {seed}] Max token length reached, truncating episode."
+                )
                 break
 
             max_tokens_for_action = 512
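One note on the guard above: `apply_chat_template(..., tokenize=False)` returns the rendered prompt as a string, so the `len(...)` counts characters, not tokens, before comparing against `max_token_length - 50`. A token-accurate variant of the same guard would tokenize instead — a sketch, not part of the commit:

def approaching_token_limit(tokenizer, messages, max_token_length, margin=50):
    # tokenize=True returns a list of token ids, so this counts tokens
    # rather than characters in the rendered template.
    token_ids = tokenizer.apply_chat_template(messages, tokenize=True)
    return len(token_ids) > max_token_length - margin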
@@ -201,19 +203,25 @@ class BlackjackEnvNoThinking(BaseEnv):
                     max_tokens=max_tokens_for_action,
                     temperature=0.5,
                 )
-                llm_action_response = chat_completions.choices[0].message.content.strip()
-                logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'") # Log raw response
+                llm_action_response = chat_completions.choices[
+                    0
+                ].message.content.strip()
+                logger.info(
+                    f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'"
+                )  # Log raw response
             except Exception as e:
                 logger.error(f"[Seed: {seed}] LLM API error: {e}")
                 break
 
             messages.append({"role": "assistant", "content": llm_action_response})
             action = self._parse_action_from_llm(llm_action_response)
 
             if action is None:
-                logger.warning(f"[Seed: {seed}] Invalid action parsed. Ending episode.")
-                game_reward = -1.0
-                break
+                logger.warning(
+                    f"[Seed: {seed}] Invalid action parsed. Ending episode."
+                )
+                game_reward = -1.0
+                break
 
             try:
                 obs, reward, terminated, truncated, _ = env.step(action)
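The loop these hunks reformat follows the standard Gymnasium episode protocol: `Blackjack-v1` takes discrete actions 0 (stick) and 1 (hit), and `step()` returns the usual five-tuple. A self-contained sketch of the same control flow, with a trivial stand-in policy in place of the LLM:

import gymnasium as gym

env = gym.make("Blackjack-v1")
obs, info = env.reset(seed=42)  # obs is (player_sum, dealer_card, usable_ace)
done = False
while not done:
    action = 1 if obs[0] < 17 else 0  # hit below 17, otherwise stick
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
env.close()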
@@ -224,19 +232,17 @@ class BlackjackEnvNoThinking(BaseEnv):
 
             if terminated or truncated:
                 break
 
             current_obs_str = self._format_observation(obs)
             messages.append({"role": "user", "content": current_obs_str})
 
         env.close()
         self.episode_outcomes_buffer.append(game_reward)
 
         tokenization_result = tokenize_for_trainer(
-            tokenizer=self.tokenizer,
-            chat=messages,
-            train_on_all_assistant_turns=True
+            tokenizer=self.tokenizer, chat=messages, train_on_all_assistant_turns=True
         )
         tokens = tokenization_result["tokens"]
         masks = tokenization_result["masks"]
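`tokenize_for_trainer` comes from atroposlib and its body is not shown in this diff; the call above implies it returns a dict with aligned "tokens" and "masks" lists, with `train_on_all_assistant_turns=True` keeping loss on every assistant turn. A hypothetical sketch of that contract, using prefix tokenization to locate assistant spans and -100 as the conventional ignore index:

def tokenize_for_trainer(tokenizer, chat, train_on_all_assistant_turns=True):
    # Tokenize the whole conversation once.
    tokens = tokenizer.apply_chat_template(chat, tokenize=True)
    masks = [-100] * len(tokens)  # -100 = ignored by the loss
    turns = [i for i, m in enumerate(chat) if m["role"] == "assistant"]
    if not train_on_all_assistant_turns:
        turns = turns[-1:]  # assumption: train only on the final assistant turn
    # Unmask the token span each assistant message contributes by
    # diffing the lengths of successive tokenized prefixes.
    for i in turns:
        start = len(tokenizer.apply_chat_template(chat[:i], tokenize=True)) if i else 0
        end = len(tokenizer.apply_chat_template(chat[: i + 1], tokenize=True))
        masks[start:end] = tokens[start:end]
    return {"tokens": tokens, "masks": masks}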
@@ -256,24 +262,27 @@ class BlackjackEnvNoThinking(BaseEnv):
         logger.info(f"Setting up {self.name} environment.")
 
     async def evaluate(self, *args, **kwargs):
-        logger.info(f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes.")
+        logger.info(
+            f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes."
+        )
         wins = 0
         losses = 0
         draws = 0
         eval_outcomes: List[float] = []
 
         for i in range(self.config.eval_episodes):
-            seed = random.randint(1_000_001, 2_000_000)
+            seed = random.randint(1_000_001, 2_000_000)
             item = {"seed": seed}
             scored_item_tuple = await self.collect_trajectory(item)
             if scored_item_tuple and scored_item_tuple[0]:
                 outcome = scored_item_tuple[0]["scores"]
                 eval_outcomes.append(outcome)
             else:
-                logger.warning(f"Evaluation episode {i+1} (seed {seed}) failed to produce data.")
+                logger.warning(
+                    f"Evaluation episode {i+1} (seed {seed}) failed to produce data."
+                )
 
         if not eval_outcomes:
             logger.warning("No evaluation episodes completed successfully.")
@@ -287,7 +296,7 @@ class BlackjackEnvNoThinking(BaseEnv):
                 losses += 1
             else:
                 draws += 1
 
         num_completed = len(eval_outcomes)
         win_rate = wins / num_completed if num_completed > 0 else 0
         loss_rate = losses / num_completed if num_completed > 0 else 0
@@ -301,15 +310,18 @@ class BlackjackEnvNoThinking(BaseEnv):
             (f"{self.name}_eval/avg_reward", avg_reward),
             (f"{self.name}_eval/num_completed_episodes", num_completed),
         ]
-        logger.info(f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}")
+        logger.info(
+            f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}"
+        )
 
     async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
         if wandb_metrics is None:
             wandb_metrics = {}
 
         if self.episode_outcomes_buffer:
-            avg_training_reward = sum(self.episode_outcomes_buffer) / len(self.episode_outcomes_buffer)
+            avg_training_reward = sum(self.episode_outcomes_buffer) / len(
+                self.episode_outcomes_buffer
+            )
             wandb_metrics[f"{self.name}_train/avg_episode_reward"] = avg_training_reward
             train_wins = sum(1 for r in self.episode_outcomes_buffer if r > 0)
             train_losses = sum(1 for r in self.episode_outcomes_buffer if r < 0)
@@ -317,7 +329,9 @@ class BlackjackEnvNoThinking(BaseEnv):
             wandb_metrics[f"{self.name}_train/win_count"] = train_wins
             wandb_metrics[f"{self.name}_train/loss_count"] = train_losses
             wandb_metrics[f"{self.name}_train/draw_count"] = train_draws
-            wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(self.episode_outcomes_buffer)
+            wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(
+                self.episode_outcomes_buffer
+            )
             self.episode_outcomes_buffer = []
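Both evaluate() and this wandb_log() classify an episode by the sign of its reward (Blackjack pays +1 for a win, -1 for a loss, 0 for a draw, with +1.5 naturals possible depending on env settings). The shared bookkeeping, pulled out as a standalone sketch — the function name here is illustrative, not part of the codebase:

from typing import Dict, List


def summarize_outcomes(outcomes: List[float]) -> Dict[str, float]:
    # Sign of the reward alone classifies a game.
    wins = sum(1 for r in outcomes if r > 0)
    losses = sum(1 for r in outcomes if r < 0)
    draws = len(outcomes) - wins - losses
    n = len(outcomes)
    return {
        "win_rate": wins / n if n else 0.0,
        "loss_rate": losses / n if n else 0.0,
        "draw_rate": draws / n if n else 0.0,
        "avg_reward": sum(outcomes) / n if n else 0.0,
    }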