mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-02 17:45:50 +00:00
linting
This commit is contained in:
parent
826de9e283
commit
67cfd961c5
6 changed files with 111 additions and 85 deletions
|
|
@ -1,9 +1,9 @@
|
|||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import gymnasium as gym
|
||||
import random
|
||||
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
|
||||
from atroposlib.type_definitions import Item, Message
|
||||
|
|
@ -119,13 +119,13 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
return None
|
||||
|
||||
parsed_name, parsed_args, is_error = parse_tool_call(
|
||||
llm_response, self.tools, ["tool_call"] # Expecting <tool_call>
|
||||
llm_response, self.tools, ["tool_call"] # Expecting <tool_call>
|
||||
)
|
||||
|
||||
if is_error:
|
||||
error_detail = (
|
||||
str(parsed_name) # Error message is in parsed_name if is_error
|
||||
if parsed_name
|
||||
str(parsed_name) # Error message is in parsed_name if is_error
|
||||
if parsed_name
|
||||
else "Parser indicated error, but no specific message was returned."
|
||||
)
|
||||
logger.warning(
|
||||
|
|
@ -146,7 +146,8 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
return ACTION_STICK
|
||||
else:
|
||||
logger.warning(
|
||||
f"Successfully parsed tool call '{parsed_name}', but action argument is invalid. Action: '{action_str}'. "
|
||||
f"Successfully parsed tool call '{parsed_name}', "
|
||||
f"but action argument is invalid. Action: '{action_str}'. "
|
||||
f"Full response: '{llm_response}'. Parsed args: {parsed_args}"
|
||||
)
|
||||
return None
|
||||
|
|
@ -162,14 +163,13 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
seed = item["seed"]
|
||||
messages: List[Message] = []
|
||||
game_reward = 0.0
|
||||
num_turns = 0
|
||||
|
||||
try:
|
||||
env = gym.make(self.config.env_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to make environment {self.config.env_name}: {e}")
|
||||
return None, []
|
||||
|
||||
|
||||
try:
|
||||
obs, info = env.reset(seed=seed)
|
||||
except Exception as e:
|
||||
|
|
@ -189,7 +189,9 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
len(self.tokenizer.apply_chat_template(messages, tokenize=False))
|
||||
> self.config.max_token_length - 50
|
||||
):
|
||||
logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.")
|
||||
logger.warning(
|
||||
f"[Seed: {seed}] Max token length reached, truncating episode."
|
||||
)
|
||||
break
|
||||
|
||||
max_tokens_for_action = 512
|
||||
|
|
@ -201,19 +203,25 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
max_tokens=max_tokens_for_action,
|
||||
temperature=0.5,
|
||||
)
|
||||
llm_action_response = chat_completions.choices[0].message.content.strip()
|
||||
logger.info(f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'") # Log raw response
|
||||
llm_action_response = chat_completions.choices[
|
||||
0
|
||||
].message.content.strip()
|
||||
logger.info(
|
||||
f"[Seed: {seed}] LLM Raw Response: '{llm_action_response}'"
|
||||
) # Log raw response
|
||||
except Exception as e:
|
||||
logger.error(f"[Seed: {seed}] LLM API error: {e}")
|
||||
break
|
||||
|
||||
messages.append({"role": "assistant", "content": llm_action_response})
|
||||
|
||||
|
||||
action = self._parse_action_from_llm(llm_action_response)
|
||||
if action is None:
|
||||
logger.warning(f"[Seed: {seed}] Invalid action parsed. Ending episode.")
|
||||
game_reward = -1.0
|
||||
break
|
||||
logger.warning(
|
||||
f"[Seed: {seed}] Invalid action parsed. Ending episode."
|
||||
)
|
||||
game_reward = -1.0
|
||||
break
|
||||
|
||||
try:
|
||||
obs, reward, terminated, truncated, _ = env.step(action)
|
||||
|
|
@ -224,19 +232,17 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
|
||||
if terminated or truncated:
|
||||
break
|
||||
|
||||
|
||||
current_obs_str = self._format_observation(obs)
|
||||
messages.append({"role": "user", "content": current_obs_str})
|
||||
|
||||
|
||||
env.close()
|
||||
self.episode_outcomes_buffer.append(game_reward)
|
||||
|
||||
tokenization_result = tokenize_for_trainer(
|
||||
tokenizer=self.tokenizer,
|
||||
chat=messages,
|
||||
train_on_all_assistant_turns=True
|
||||
tokenizer=self.tokenizer, chat=messages, train_on_all_assistant_turns=True
|
||||
)
|
||||
|
||||
|
||||
tokens = tokenization_result["tokens"]
|
||||
masks = tokenization_result["masks"]
|
||||
|
||||
|
|
@ -256,24 +262,27 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
logger.info(f"Setting up {self.name} environment.")
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
logger.info(f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes.")
|
||||
|
||||
logger.info(
|
||||
f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes."
|
||||
)
|
||||
|
||||
wins = 0
|
||||
losses = 0
|
||||
draws = 0
|
||||
|
||||
|
||||
eval_outcomes: List[float] = []
|
||||
|
||||
for i in range(self.config.eval_episodes):
|
||||
seed = random.randint(1_000_001, 2_000_000)
|
||||
seed = random.randint(1_000_001, 2_000_000)
|
||||
item = {"seed": seed}
|
||||
scored_item_tuple = await self.collect_trajectory(item)
|
||||
if scored_item_tuple and scored_item_tuple[0]:
|
||||
outcome = scored_item_tuple[0]["scores"]
|
||||
eval_outcomes.append(outcome)
|
||||
else:
|
||||
logger.warning(f"Evaluation episode {i+1} (seed {seed}) failed to produce data.")
|
||||
|
||||
logger.warning(
|
||||
f"Evaluation episode {i+1} (seed {seed}) failed to produce data."
|
||||
)
|
||||
|
||||
if not eval_outcomes:
|
||||
logger.warning("No evaluation episodes completed successfully.")
|
||||
|
|
@ -287,7 +296,7 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
losses += 1
|
||||
else:
|
||||
draws += 1
|
||||
|
||||
|
||||
num_completed = len(eval_outcomes)
|
||||
win_rate = wins / num_completed if num_completed > 0 else 0
|
||||
loss_rate = losses / num_completed if num_completed > 0 else 0
|
||||
|
|
@ -301,15 +310,18 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
(f"{self.name}_eval/avg_reward", avg_reward),
|
||||
(f"{self.name}_eval/num_completed_episodes", num_completed),
|
||||
]
|
||||
logger.info(f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}")
|
||||
|
||||
logger.info(
|
||||
f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}"
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.episode_outcomes_buffer:
|
||||
avg_training_reward = sum(self.episode_outcomes_buffer) / len(self.episode_outcomes_buffer)
|
||||
avg_training_reward = sum(self.episode_outcomes_buffer) / len(
|
||||
self.episode_outcomes_buffer
|
||||
)
|
||||
wandb_metrics[f"{self.name}_train/avg_episode_reward"] = avg_training_reward
|
||||
train_wins = sum(1 for r in self.episode_outcomes_buffer if r > 0)
|
||||
train_losses = sum(1 for r in self.episode_outcomes_buffer if r < 0)
|
||||
|
|
@ -317,7 +329,9 @@ class BlackjackEnvNoThinking(BaseEnv):
|
|||
wandb_metrics[f"{self.name}_train/win_count"] = train_wins
|
||||
wandb_metrics[f"{self.name}_train/loss_count"] = train_losses
|
||||
wandb_metrics[f"{self.name}_train/draw_count"] = train_draws
|
||||
wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(self.episode_outcomes_buffer)
|
||||
wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(
|
||||
self.episode_outcomes_buffer
|
||||
)
|
||||
|
||||
self.episode_outcomes_buffer = []
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue