Linting and cleanup

This commit is contained in:
Shannon Sands 2025-05-10 21:15:00 +10:00
parent 6617d402b3
commit 220b92be47
2 changed files with 44 additions and 1711 deletions

View file

@ -5,7 +5,9 @@ BlackjackEnv: Trainer environment for Gymnasium Blackjack
This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
using function-call style actions. Extends BaseEnv.
Uses Monte Carlo sampling to estimate the value of the current state, similar to VinePPO
Sort of inspired by VinePPO, but uses a recursive exact value calculation
instead of Monte Carlo sampling (because the action space is so small and
the environment is deterministic, plus short episode lengths).
"""
import copy
@ -16,7 +18,6 @@ import re
from typing import Any, Dict, List, Optional, Tuple
import gymnasium
import numpy as np
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
@ -44,7 +45,6 @@ class BlackjackEnvConfig(BaseEnvConfig):
max_trajectory_tokens: int = 24576
debug_mode: bool = False
group_size: int = 16
mc_samples: int = 3
tiebreak_token_factor: float = 0.01
@ -166,11 +166,10 @@ class BlackjackEnv(BaseEnv):
current_env_reward -= 0.2
# Calculate the number of tokens in the agent's response
num_tokens = len(
tokenize_for_trainer_multistep(
self.tokenizer, [{"role": "agent", "content": response_text}]
)["tokens"]
)
if response_text:
num_tokens = len(self.tokenizer.encode(response_text))
else:
num_tokens = 0
# tiebreak & small length penalty
if self.config.max_token_length > 0:
@ -232,7 +231,7 @@ class BlackjackEnv(BaseEnv):
episode_seed_for_sim: int,
env_actions_to_replay: List[int],
) -> float:
"""Calculate exact state value V*(s) using recursive calls and memoization.
"""Calculate exact state value V*(s)
Args:
episode_seed_for_sim: The seed of the original episode for deterministic env creation.
@ -240,7 +239,9 @@ class BlackjackEnv(BaseEnv):
"""
v_star_cache: Dict[Tuple[int, int, int], float] = {}
def _get_v_star_recursive(obs_tuple: Tuple[int, int, int], current_env: gymnasium.Env) -> float:
def _get_v_star_recursive(
obs_tuple: Tuple[int, int, int], current_env: gymnasium.Env
) -> float:
player_sum, dealer_card, usable_ace = obs_tuple
# Base Case 1: Bust
@ -259,13 +260,13 @@ class BlackjackEnv(BaseEnv):
# Q-value for HIT (action 1)
env_for_hit = copy.deepcopy(current_env)
obs_hit, reward_hit, term_hit, trunc_hit, _ = env_for_hit.step(1)
if term_hit or trunc_hit: # Game ended after hitting
if term_hit or trunc_hit: # Game ended after hitting
q_star_hit = reward_hit
else: # Game continues, recursively find V* of next state
else: # Game continues, recursively find V* of next state
# reward_hit is typically 0 if the game didn't end
q_star_hit = reward_hit + _get_v_star_recursive(obs_hit, env_for_hit)
v_star = max(q_star_stick, q_star_hit)
v_star_cache[obs_tuple] = v_star
return v_star
@ -286,7 +287,7 @@ class BlackjackEnv(BaseEnv):
)
is_terminal_after_replay = True
break
if is_terminal_after_replay:
return 0.0
final_v_star = _get_v_star_recursive(current_obs, sim_env)
@ -294,10 +295,12 @@ class BlackjackEnv(BaseEnv):
except Exception as e:
logger.error(
f"[_estimate_value] Error during exact value calculation for seed {episode_seed_for_sim}, actions {env_actions_to_replay}: {e}",
exc_info=True
f"[_estimate_value] Error during exact value"
f" calculation for seed {episode_seed_for_sim}, "
f"actions {env_actions_to_replay}: {e}",
exc_info=True,
)
return 0.0 # Return a default value on error
return 0.0
finally:
if sim_env is not None:
sim_env.close()
@ -305,14 +308,13 @@ class BlackjackEnv(BaseEnv):
async def collect_trajectory(self, seed: int) -> List[BlackjackScoredDataGroup]:
"""Collect data for ONE trajectory, evaluating G alternatives per step using MC advantages."""
G = self.config.group_size
K = self.config.mc_samples
max_turns = self.config.max_turns or 5
trajectory_data_for_trainer: List[BlackjackScoredDataGroup] = []
episode_summary_metrics: Optional[Dict[str, Any]] = None
logger.info(
f"[Collect Trajectory Seed: {seed}] Starting trajectory. Group size G={G}, MC samples K={K}."
f"[Collect Trajectory Seed: {seed}] Starting trajectory. Group size G={G}."
)
try:
@ -333,8 +335,7 @@ class BlackjackEnv(BaseEnv):
try:
value_t = await self._estimate_value(
episode_seed_for_sim=ep.seed,
env_actions_to_replay=ep.actions
episode_seed_for_sim=ep.seed, env_actions_to_replay=ep.actions
)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Estimated V(s_t) = {value_t:.4f}"
@ -395,11 +396,9 @@ class BlackjackEnv(BaseEnv):
next_state_msgs_i = []
try:
sim_env = gymnasium.make(self.config.env_name)
sim_obs, _ = sim_env.reset(seed=ep.seed)
_, _ = sim_env.reset(seed=ep.seed)
for prev_action in ep.actions:
sim_obs, _, term_replay, trunc_replay, _ = sim_env.step(
prev_action
)
_, _, term_replay, trunc_replay, _ = sim_env.step(prev_action)
if term_replay or trunc_replay:
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
@ -468,7 +467,7 @@ class BlackjackEnv(BaseEnv):
actions_to_reach_s_prime = ep.actions + [alt_env_actions[i]]
value_next_i = await self._estimate_value(
episode_seed_for_sim=ep.seed,
env_actions_to_replay=actions_to_reach_s_prime
env_actions_to_replay=actions_to_reach_s_prime,
)
alt_value_next.append(value_next_i)
except Exception as e_vn:
@ -483,6 +482,9 @@ class BlackjackEnv(BaseEnv):
for i in range(G):
advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
# If we pass this then instead of raw scores, implicitly, we're
# doing some credit assignment. Could maybe do bonus on a win too
# and apply with a discount factor to alts in winning trajectories
alt_advantages.append(advantage_i)
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
@ -497,6 +499,7 @@ class BlackjackEnv(BaseEnv):
or len(alt_next_state_msgs) != G
or len(alt_parsed_actions) != G
):
# sanity check
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Mismatch in alternative list lengths after processing. "
@ -509,38 +512,31 @@ class BlackjackEnv(BaseEnv):
seed=ep.seed,
tokens=alt_tokens,
masks=alt_masks,
scores=alt_combined_rewards,
scores=alt_advantages,
messages=alt_next_state_msgs,
parsed_actions=alt_parsed_actions,
)
)
# Determine the best alternative: highest advantage, tie-broken by shortest token length.
if G == 0: # Should ideally not occur with G = self.config.group_size > 0
logger.error(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"No alternatives to choose from (G=0). Aborting turn."
)
break # Exit the current turn processing.
# Prepare items for sorting: (-advantage, token_length, original_index)
# Sorting this list will place the best alternative (highest advantage, then shortest tokens) first.
sortable_alternatives = []
for i in range(G):
# alt_tokens[i] is expected to be a list (possibly empty)
token_len = len(alt_tokens[i])
sortable_alternatives.append((-alt_advantages[i], token_len, i))
sortable_alternatives.sort() # Sorts in-place
sortable_alternatives.sort()
best_advantage_idx = sortable_alternatives[0][2]
# Log details of the selected alternative based on the sort
chosen_advantage_for_log = -sortable_alternatives[0][0] # Revert sign for logging
chosen_advantage_for_log = -sortable_alternatives[0][0]
chosen_token_length_for_log = sortable_alternatives[0][1]
logger.debug(
f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
f"Selected Alt {best_advantage_idx} (Adv: {chosen_advantage_for_log:.2f}, Tokens: {chosen_token_length_for_log}) "
f"Selected Alt {best_advantage_idx} "
f"(Adv: {chosen_advantage_for_log:.2f}, "
f"Tokens: {chosen_token_length_for_log}) "
f"from {G} alternatives using sort."
)
@ -570,7 +566,7 @@ class BlackjackEnv(BaseEnv):
)
try:
main_obs, main_reward, main_term, main_trunc, main_info = ep.env.step(
main_obs, main_reward, main_term, main_trunc, _ = ep.env.step(
chosen_env_action
)
if abs(main_reward - chosen_raw_env_reward) > 1e-6:
@ -630,7 +626,7 @@ class BlackjackEnv(BaseEnv):
game_outcome = 1
elif final_raw_reward < 0:
game_outcome = -1
# debugging
episode_summary_metrics = {
"seed": ep.seed,
"total_reward": final_raw_reward,
@ -658,7 +654,9 @@ class BlackjackEnv(BaseEnv):
async def score(
self, rollout_group_data: List[BlackjackScoredDataGroup]
) -> List[Optional[BlackjackScoredDataGroup]]:
"""Return rollout data with advantages as scores."""
"""Pass through rollout data. The 'scores' field in BlackjackScoredDataGroup
already contains the A*(s,a) advantages from the collection phase.
"""
logger.info(f"[Score] Processing {len(rollout_group_data)} steps.")
return rollout_group_data