no-thinking env added

This commit is contained in:
Shannon Sands 2025-05-14 11:28:39 -07:00
parent 21cc528b85
commit 54ae40840d
2 changed files with 285 additions and 1 deletions

View file

@@ -0,0 +1,265 @@
import logging
from typing import Dict, List, Optional, Tuple
import gymnasium as gym
import random
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
from atroposlib.type_definitions import Item, Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
logger = logging.getLogger(__name__)
# Discrete action codes used by gymnasium's Blackjack-v1 action space.
ACTION_HIT = 1
ACTION_STICK = 0
# Keyword -> action code expected from the model, and the inverse map for
# rendering an action code back to its keyword.
ACTION_STR_TO_INT = {"hit": ACTION_HIT, "stick": ACTION_STICK}
ACTION_MAP_TO_STR = {code: word for word, code in ACTION_STR_TO_INT.items()}
class BlackjackEnvNoThinkingConfig(BaseEnvConfig):
    """
    Configuration for the BlackjackEnvNoThinking environment.
    """

    # Gymnasium environment id passed to gym.make().
    env_name: str = "Blackjack-v1"
    # Hard cap on hit/stick turns per episode before the rollout is cut off.
    max_episode_turns: int = 10
    # Number of held-out episodes played by evaluate().
    eval_episodes: int = 100
class BlackjackEnvNoThinking(BaseEnv):
    """Blackjack RL environment where the model answers with a bare action word.

    Each user turn presents the formatted game state; the assistant must reply
    with exactly 'hit' or 'stick' (no chain-of-thought). The final Gymnasium
    reward for the episode (+1 win, 0 draw, -1 loss) is used as the score for
    the entire trajectory; an unparseable reply scores -1.0.
    """

    name = "blackjack_no_thinking"
    env_config_cls = BlackjackEnvNoThinkingConfig

    def __init__(
        self,
        config: BlackjackEnvNoThinkingConfig,
        server_configs: List[OpenaiConfig],
        slurm: bool = True,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config: BlackjackEnvNoThinkingConfig = config
        # Final rewards of episodes collected since the last wandb_log() flush.
        self.episode_outcomes_buffer: List[float] = []
        # (metric_name, value) pairs from the most recent evaluate() run;
        # drained by wandb_log().
        self.eval_metrics_custom: List[Tuple[str, float]] = []

    @classmethod
    def config_init(cls) -> Tuple[BlackjackEnvNoThinkingConfig, List[OpenaiConfig]]:
        """Return the default environment config and inference-server configs."""
        env_config = BlackjackEnvNoThinkingConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=16,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            max_token_length=2048,
            wandb_name=cls.name,
            steps_per_eval=50,
            max_episode_turns=10,
            eval_episodes=100,
        )
        server_configs = [
            OpenaiConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=128,
            ),
        ]
        return env_config, server_configs

    def _format_observation(self, obs: Tuple[int, int, int]) -> str:
        """Converts a Blackjack observation to a human-readable string.

        Args:
            obs: (player_sum, dealer_showing_card, usable_ace) tuple as
                returned by Blackjack-v1.
        """
        player_sum, dealer_card, usable_ace = obs
        return (
            f"Your current hand sum is {player_sum}. "
            f"The dealer is showing a {dealer_card}. "
            f"You have a usable ace: {'yes' if usable_ace else 'no'}."
        )

    def _parse_action_from_llm(self, llm_response: str) -> Optional[int]:
        """Parses 'hit' or 'stick' from the LLM response.

        Returns the integer action code, or None when the (case-insensitive,
        stripped) response is not exactly one of the two keywords.
        """
        action_str = llm_response.strip().lower()
        if action_str in ACTION_STR_TO_INT:
            return ACTION_STR_TO_INT[action_str]
        logger.warning(f"Could not parse action from LLM response: '{llm_response}'")
        return None

    async def collect_trajectory(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataItem], List[Item]]:
        """
        Collects a single trajectory (episode) for the Blackjack environment.

        The LLM directly outputs 'hit' or 'stick'. The 'score' in
        ScoredDataItem is the final game outcome (+1, 0, -1), forced to -1.0
        when the model's reply cannot be parsed into an action.

        Args:
            item: Dict with a "seed" key used to reset the gym environment.

        Returns:
            (ScoredDataItem, empty backlog), or (None, []) when the gym
            environment could not be created or reset.
        """
        seed = item["seed"]
        messages: List[Message] = []
        game_reward = 0.0
        try:
            env = gym.make(self.config.env_name)
        except Exception as e:
            logger.error(f"Failed to make environment {self.config.env_name}: {e}")
            return None, []
        try:
            obs, _ = env.reset(seed=seed)
        except Exception as e:
            logger.error(f"Failed to reset environment with seed {seed}: {e}")
            env.close()
            return None, []
        system_prompt = (
            "You are playing Blackjack. Respond with either 'hit' or 'stick'."
        )
        messages.append({"role": "system", "content": system_prompt})
        current_obs_str = self._format_observation(obs)
        messages.append({"role": "user", "content": current_obs_str})
        # A single action word is all we need; cap generation tightly.
        # (Hoisted out of the loop: this value never changes between turns.)
        max_tokens_for_action = 10
        async with self.server.dedicated_server() as server:
            for _ in range(self.config.max_episode_turns):
                # NOTE(review): with tokenize=False this measures characters of
                # the rendered template, not tokens, yet is compared against
                # max_token_length. Presumably a deliberately conservative
                # bound — confirm.
                if (
                    len(self.tokenizer.apply_chat_template(messages, tokenize=False))
                    > self.config.max_token_length - 50
                ):
                    logger.warning(
                        f"[Seed: {seed}] Max token length reached, truncating episode."
                    )
                    break
                try:
                    chat_completions = await server.chat_completion(
                        messages=messages,
                        n=1,
                        max_tokens=max_tokens_for_action,
                        temperature=0.5,
                    )
                    llm_action_response = chat_completions.choices[
                        0
                    ].message.content.strip()
                except Exception as e:
                    logger.error(f"[Seed: {seed}] LLM API error: {e}")
                    break
                messages.append({"role": "assistant", "content": llm_action_response})
                action = self._parse_action_from_llm(llm_action_response)
                if action is None:
                    # Treat an unparseable reply as a loss to discourage it.
                    logger.warning(
                        f"[Seed: {seed}] Invalid action parsed. Ending episode."
                    )
                    game_reward = -1.0
                    break
                try:
                    obs, reward, terminated, truncated, _ = env.step(action)
                    # Blackjack only pays out at episode end, so the last
                    # assignment holds the final outcome.
                    game_reward = float(reward)
                except Exception as e:
                    logger.error(f"[Seed: {seed}] Error stepping env: {e}")
                    break
                if terminated or truncated:
                    break
                current_obs_str = self._format_observation(obs)
                messages.append({"role": "user", "content": current_obs_str})
        env.close()
        # NOTE(review): evaluate() also routes through this method, so eval
        # episodes are appended to this training-metrics buffer — confirm
        # that mixing is intended.
        self.episode_outcomes_buffer.append(game_reward)
        tokenization_result = tokenize_for_trainer(
            tokenizer=self.tokenizer,
            chat=messages,
            train_on_all_assistant_turns=True,
        )
        tokens = tokenization_result["tokens"]
        masks = tokenization_result["masks"]
        scored_data_item = ScoredDataItem(
            messages=messages if self.config.include_messages else None,
            tokens=tokens,
            masks=masks,
            scores=game_reward,
        )
        return scored_data_item, []

    async def get_next_item(self) -> Item:
        """Sample a fresh random seed for the next training episode."""
        next_seed = random.randint(0, 1_000_000)
        return {"seed": next_seed}

    async def setup(self):
        """No one-time setup needed; gym envs are created per trajectory."""
        logger.info(f"Setting up {self.name} environment.")

    async def evaluate(self, *args, **kwargs):
        """Play `eval_episodes` held-out episodes and record aggregate metrics.

        Results (win/loss/draw rates, average reward, completed count) are
        stashed in `self.eval_metrics_custom` for the next wandb_log() call.
        """
        logger.info(
            f"Starting evaluation for {self.name} with {self.config.eval_episodes} episodes."
        )
        wins = 0
        losses = 0
        draws = 0
        eval_outcomes: List[float] = []
        for i in range(self.config.eval_episodes):
            # Eval seeds are drawn from a range disjoint from training seeds.
            seed = random.randint(1_000_001, 2_000_000)
            item = {"seed": seed}
            scored_item_tuple = await self.collect_trajectory(item)
            if scored_item_tuple and scored_item_tuple[0]:
                outcome = scored_item_tuple[0]["scores"]
                eval_outcomes.append(outcome)
            else:
                logger.warning(
                    f"Evaluation episode {i+1} (seed {seed}) failed to produce data."
                )
        if not eval_outcomes:
            logger.warning("No evaluation episodes completed successfully.")
            self.eval_metrics_custom = []
            return
        for outcome in eval_outcomes:
            if outcome > 0:
                wins += 1
            elif outcome < 0:
                losses += 1
            else:
                draws += 1
        # num_completed > 0 is guaranteed by the early return above, so the
        # former `if num_completed > 0 else 0` guards were dead code.
        num_completed = len(eval_outcomes)
        win_rate = wins / num_completed
        loss_rate = losses / num_completed
        draw_rate = draws / num_completed
        avg_reward = sum(eval_outcomes) / num_completed
        self.eval_metrics_custom = [
            (f"{self.name}_eval/win_rate", win_rate),
            (f"{self.name}_eval/loss_rate", loss_rate),
            (f"{self.name}_eval/draw_rate", draw_rate),
            (f"{self.name}_eval/avg_reward", avg_reward),
            (f"{self.name}_eval/num_completed_episodes", num_completed),
        ]
        logger.info(
            f"Evaluation completed for {self.name}. Metrics: {self.eval_metrics_custom}"
        )

    async def wandb_log(self, wandb_metrics: Optional[Dict[str, float]] = None):
        """Flush buffered training outcomes and eval metrics, then delegate up.

        Both `episode_outcomes_buffer` and `eval_metrics_custom` are emptied
        after being folded into `wandb_metrics`.
        """
        if wandb_metrics is None:
            wandb_metrics = {}
        if self.episode_outcomes_buffer:
            avg_training_reward = sum(self.episode_outcomes_buffer) / len(
                self.episode_outcomes_buffer
            )
            wandb_metrics[f"{self.name}_train/avg_episode_reward"] = avg_training_reward
            train_wins = sum(1 for r in self.episode_outcomes_buffer if r > 0)
            train_losses = sum(1 for r in self.episode_outcomes_buffer if r < 0)
            train_draws = sum(1 for r in self.episode_outcomes_buffer if r == 0)
            wandb_metrics[f"{self.name}_train/win_count"] = train_wins
            wandb_metrics[f"{self.name}_train/loss_count"] = train_losses
            wandb_metrics[f"{self.name}_train/draw_count"] = train_draws
            wandb_metrics[f"{self.name}_train/num_episodes_in_batch"] = len(
                self.episode_outcomes_buffer
            )
        self.episode_outcomes_buffer = []
        for key, value in self.eval_metrics_custom:
            wandb_metrics[key] = value
        self.eval_metrics_custom = []
        await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
    # Run the environment via the CLI entry point inherited from BaseEnv
    # (exact subcommands are defined by atroposlib).
    BlackjackEnvNoThinking.cli()

View file

@@ -489,7 +489,7 @@ class BlackjackEnv(BaseEnv):
advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
# If we pass this then instead of raw scores, implicitly, we're
# doing some credit assignment. Could maybe do bonus on a win too
# and apply with a discount factor to alts in winning trajectories
# and/or apply with a discount factor to alts in winning trajectories
alt_advantages.append(advantage_i)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
@@ -663,6 +663,15 @@ class BlackjackEnv(BaseEnv):
) -> List[Optional[BlackjackScoredDataGroup]]:
"""Pass through rollout data. The 'scores' field in BlackjackScoredDataGroup
already contains the A*(s,a) advantages from the collection phase.
If you wanted to play around with additional scoring metrics, you could do so here.
Eg, bonuses for the specific winning action trajectory
Args:
rollout_group_data: List of BlackjackScoredDataGroup objects containing the collected rollout data.
Returns:
List of BlackjackScoredDataGroup objects with the scores field updated.
"""
logger.info(f"[Score] Processing {len(rollout_group_data)} steps.")
return rollout_group_data
@@ -670,6 +679,16 @@ class BlackjackEnv(BaseEnv):
async def collect_trajectories(
self, item: Tuple[int, int]
) -> Tuple[List[BlackjackScoredDataGroup], List[Tuple[int, int]]]:
"""Collect trajectories for training.
Args:
item: Tuple containing the seed and the group index.
Returns:
Tuple of two lists:
- List of BlackjackScoredDataGroup objects containing the collected rollout data.
- List of Tuple[int, int] objects for the backlog
"""
seed, _ = item
traj = await self._collect_trajectory(seed)
if not traj: