mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
no-thinking env added
This commit is contained in:
parent
21cc528b85
commit
54ae40840d
2 changed files with 285 additions and 1 deletions
|
|
@ -0,0 +1,265 @@
|
|||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import gymnasium as gym
|
||||
import random
|
||||
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
|
||||
from atroposlib.type_definitions import Item, Message
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ACTION_HIT = 1
|
||||
ACTION_STICK = 0
|
||||
ACTION_MAP_TO_STR = {ACTION_HIT: "hit", ACTION_STICK: "stick"}
|
||||
ACTION_STR_TO_INT = {v: k for k, v in ACTION_MAP_TO_STR.items()}
|
||||
|
||||
|
||||
class BlackjackEnvNoThinkingConfig(BaseEnvConfig):
|
||||
"""
|
||||
Configuration for the BlackjackEnvNoThinking environment.
|
||||
"""
|
||||
|
||||
env_name: str = "Blackjack-v1"
|
||||
max_episode_turns: int = 10
|
||||
eval_episodes: int = 100
|
||||
|
||||
|
||||
class BlackjackEnvNoThinking(BaseEnv):
|
||||
name = "blackjack_no_thinking"
|
||||
env_config_cls = BlackjackEnvNoThinkingConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BlackjackEnvNoThinkingConfig,
|
||||
server_configs: List[OpenaiConfig],
|
||||
slurm: bool = True,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: BlackjackEnvNoThinkingConfig = config
|
||||
self.episode_outcomes_buffer: List[float] = []
|
||||
self.eval_metrics_custom: List[Tuple[str, float]] = []
|
||||
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
|
||||
env_config = BlackjackEnvNoThinkingConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
group_size=16,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
max_token_length=2048,
|
||||
wandb_name=cls.name,
|
||||
steps_per_eval=50,
|
||||
max_episode_turns=10,
|
||||
eval_episodes=100,
|
||||
)
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=128,
|
||||
),
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
def _format_observation(self, obs: Tuple[int, int, int]) -> str:
|
||||
"""Converts a Blackjack observation to a human-readable string."""
|
||||
player_sum, dealer_card, usable_ace = obs
|
||||
return (
|
||||
f"Your current hand sum is {player_sum}. "
|
||||
f"The dealer is showing a {dealer_card}. "
|
||||
f"You have a usable ace: {'yes' if usable_ace else 'no'}."
|
||||
)
|
||||
|
||||
def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
|
||||
"""Parses 'hit' or 'stick' from the LLM response."""
|
||||
action_str = llm_response.strip().lower()
|
||||
if action_str in ACTION_STR_TO_INT:
|
||||
return ACTION_STR_TO_INT[action_str]
|
||||
logger.warning(f"Could not parse action from LLM response: '{llm_response}'")
|
||||
return None
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
|
||||
"""
|
||||
Collects a single trajectory (episode) for the Blackjack environment.
|
||||
The LLM directly outputs 'hit' or 'stick'.
|
||||
The 'score' in ScoredDataItem is the final game outcome (+1, 0, -1).
|
||||
"""
|
||||
seed = item["seed"]
|
||||
messages: List[Message] = []
|
||||
game_reward = 0.0
|
||||
num_turns = 0
|
||||
|
||||
try:
|
||||
env = gym.make(self.config.env_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to make environment {self.config.env_name}: {e}")
|
||||
return None, []
|
||||
|
||||
try:
|
||||
obs, info = env.reset(seed=seed)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset environment with seed {seed}: {e}")
|
||||
env.close()
|
||||
return None, []
|
||||
|
||||
system_prompt = (
|
||||
"You are playing Blackjack. Respond with either 'hit' or 'stick'."
|
||||
)
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
current_obs_str = self._format_observation(obs)
|
||||
messages.append({"role": "user", "content": current_obs_str})
|
||||
|
||||
async with self.server.dedicated_server() as server:
|
||||
for _ in range(self.config.max_episode_turns):
|
||||
if (
|
||||
len(self.tokenizer.apply_chat_template(messages, tokenize=False))
|
||||
> self.config.max_token_length - 50
|
||||
):
|
||||
logger.warning(f"[Seed: {seed}] Max token length reached, truncating episode.")
|
||||
break
|
||||
|
||||
max_tokens_for_action = 10
|
||||
|
||||
try:
|
||||
chat_completions = await server.chat_completion(
|
||||
messages=messages,
|
||||
n=1,
|
||||
max_tokens=max_tokens_for_action,
|
||||
temperature=0.5,
|
||||
)
|
||||
llm_action_response = chat_completions.choices[0].message.content.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"[Seed: {seed}] LLM API error: {e}")
|
||||
break
|
||||
|
||||
messages.append({"role": "assistant", "content": llm_action_response})
|
||||
|
||||
action = self._parse_action_from_llm(llm_action_response)
|
||||
if action is None:
|
||||
logger.warning(f"[Seed: {seed}] Invalid action parsed. Ending episode.")
|
||||
game_reward = -1.0
|
||||
break
|
||||
|
||||
try:
|
||||
obs, reward, terminated, truncated, _ = env.step(action)
|
||||
game_reward = float(reward)
|
||||
except Exception as e:
|
||||
logger.error(f"[Seed: {seed}] Error stepping env: {e}")
|
||||
break
|
||||
|
||||
if terminated or truncated:
|
||||
break
|
||||
|
||||
current_obs_str = self._format_observation(obs)
|
||||
messages.append({"role": "user", "content": current_obs_str})
|
||||
|
||||
env.close()
|
||||
self.episode_outcomes_buffer.append(game_reward)
|
||||
|
||||
tokenization_result = tokenize_for_trainer(
|
||||
tokenizer=self.tokenizer,
|
||||
chat=messages,
|
||||
train_on_all_assistant_turns=True
|
||||
)
|
||||
|
||||
tokens = tokenization_result["tokens"]
|
||||
masks = tokenization_result["masks"]
|
||||
|
||||
scored_data_item = ScoredDataItem(
|
||||
messages=messages if self.config.include_messages else None,
|
||||
tokens=tokens,
|
||||
masks=masks,
|
||||
scores=game_reward,
|
||||
)
|
||||
return scored_data_item, []
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
next_seed = random.randint(0, 1_000_000)
|
||||
return {"seed": next_seed}
|
||||
|
||||
async def setup(self):
|
||||
logger.info(f"Setting up {self.name} environment.")
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
logger.info(f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes.")
|
||||
|
||||
wins = 0
|
||||
losses = 0
|
||||
draws = 0
|
||||
|
||||
eval_outcomes: List[float] = []
|
||||
|
||||
for i in range(self.config.eval_episodes):
|
||||
seed = random.randint(1_000_001, 2_000_000)
|
||||
item = {"seed": seed}
|
||||
scored_item_tuple = await self.collect_trajectory(item)
|
||||
if scored_item_tuple and scored_item_tuple[0]:
|
||||
outcome = scored_item_tuple[0]["scores"]
|
||||
eval_outcomes.append(outcome)
|
||||
else:
|
||||
logger.warning(f"Evaluation episode {i+1} (seed {seed}) failed to produce data.")
|
||||
|
||||
|
||||
if not eval_outcomes:
|
||||
logger.warning("No evaluation episodes completed successfully.")
|
||||
self.eval_metrics_custom = []
|
||||
return
|
||||
|
||||
for outcome in eval_outcomes:
|
||||
if outcome > 0:
|
||||
wins += 1
|
||||
elif outcome < 0:
|
||||
losses += 1
|
||||
else:
|
||||
draws += 1
|
||||
|
||||
num_completed = len(eval_outcomes)
|
||||
win_rate = wins / num_completed if num_completed > 0 else 0
|
||||
loss_rate = losses / num_completed if num_completed > 0 else 0
|
||||
draw_rate = draws / num_completed if num_completed > 0 else 0
|
||||
avg_reward = sum(eval_outcomes) / num_completed if num_completed > 0 else 0
|
||||
|
||||
self.eval_metrics_custom = [
|
||||
(f"{self.name}_eval/win_rate", win_rate),
|
||||
(f"{self.name}_eval/loss_rate", loss_rate),
|
||||
(f"{self.name}_eval/draw_rate", draw_rate),
|
||||
(f"{self.name}_eval/avg_reward", avg_reward),
|
||||
(f"{self.name}_eval/num_completed_episodes", num_completed),
|
||||
]
|
||||
logger.info(f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}")
|
||||
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.episode_outcomes_buffer:
|
||||
avg_training_reward = sum(self.episode_outcomes_buffer) / len(self.episode_outcomes_buffer)
|
||||
wandb_metrics[f"{self.name}_train/avg_episode_reward"] = avg_training_reward
|
||||
train_wins = sum(1 for r in self.episode_outcomes_buffer if r > 0)
|
||||
train_losses = sum(1 for r in self.episode_outcomes_buffer if r < 0)
|
||||
train_draws = sum(1 for r in self.episode_outcomes_buffer if r == 0)
|
||||
wandb_metrics[f"{self.name}_train/win_count"] = train_wins
|
||||
wandb_metrics[f"{self.name}_train/loss_count"] = train_losses
|
||||
wandb_metrics[f"{self.name}_train/draw_count"] = train_draws
|
||||
wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(self.episode_outcomes_buffer)
|
||||
|
||||
self.episode_outcomes_buffer = []
|
||||
|
||||
for key, value in self.eval_metrics_custom:
|
||||
wandb_metrics[key] = value
|
||||
self.eval_metrics_custom = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
BlackjackEnvNoThinking.cli()
|
||||
|
|
@ -489,7 +489,7 @@ class BlackjackEnv(BaseEnv):
|
|||
advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
|
||||
# If we pass this then instead of raw scores, implicitly, we're
|
||||
# doing some credit assignment. Could maybe do bonus on a win too
|
||||
# and apply with a discount factor to alts in winning trajectories
|
||||
# and/or apply with a discount factor to alts in winning trajectories
|
||||
alt_advantages.append(advantage_i)
|
||||
logger.debug(
|
||||
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
|
||||
|
|
@ -663,6 +663,15 @@ class BlackjackEnv(BaseEnv):
|
|||
) -> List[Optional[BlackjackScoredDataGroup]]:
|
||||
"""Pass through rollout data. The 'scores' field in BlackjackScoredDataGroup
|
||||
already contains the A*(s,a) advantages from the collection phase.
|
||||
|
||||
If you wanted to play around with additional scoring metrics, you could do so here.
|
||||
Eg, bonuses for the specific winning action trajectory
|
||||
|
||||
Args:
|
||||
rollout_group_data: List of BlackjackScoredDataGroup objects containing the collected rollout data.
|
||||
|
||||
Returns:
|
||||
List of BlackjackScoredDataGroup objects with the scores field updated.
|
||||
"""
|
||||
logger.info(f"[Score] Processing {len(rollout_group_data)} steps.")
|
||||
return rollout_group_data
|
||||
|
|
@ -670,6 +679,16 @@ class BlackjackEnv(BaseEnv):
|
|||
async def collect_trajectories(
|
||||
self, item: Tuple[int, int]
|
||||
) -> Tuple[List[BlackjackScoredDataGroup], List[Tuple[int, int]]]:
|
||||
"""Collect trajectories for training.
|
||||
|
||||
Args:
|
||||
item: Tuple containing the seed and the group index.
|
||||
|
||||
Returns:
|
||||
Tuple of two lists:
|
||||
- List of BlackjackScoredDataGroup objects containing the collected rollout data.
|
||||
- List of Tuple[int, int] objects for the backlog
|
||||
"""
|
||||
seed, _ = item
|
||||
traj = await self._collect_trajectory(seed)
|
||||
if not traj:
|
||||
Loading…
Add table
Add a link
Reference in a new issue