mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Linting and cleanup
This commit is contained in:
parent
6617d402b3
commit
220b92be47
2 changed files with 44 additions and 1711 deletions
@@ -5,7 +5,9 @@ BlackjackEnv: Trainer environment for Gymnasium Blackjack
 This wraps Gymnasium's Blackjack-v1 environment to train an LLM via a best-of-n pattern
 using function-call style actions. Extends BaseEnv.
 
-Uses Monte Carlo sampling to estimate the value of the current state, similar to VinePPO
+Sort of inspired by VinePPO, but uses a recursive exact value calculation
+instead of Monte Carlo sampling (because the action space is so small and
+the environment is deterministic, plus short episode lengths).
 """
 
 import copy
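The docstring rewrite above is worth unpacking: instead of estimating V(s) from sampled rollouts (the VinePPO-style approach the old text described), the env now computes values exactly. For contrast, a minimal sketch of what Monte Carlo estimation would look like here; `mc_value_estimate` and its parameters are illustrative, not part of the repo:

```python
import gymnasium


def mc_value_estimate(seed: int, actions_so_far: list, n_samples: int = 3) -> float:
    """VinePPO-style V(s) estimate: replay to the state, average random rollouts."""
    total = 0.0
    for _ in range(n_samples):
        env = gymnasium.make("Blackjack-v1")
        env.reset(seed=seed)  # seeding makes the card stream reproducible
        terminated = truncated = False
        for a in actions_so_far:  # replay the action prefix to reach state s
            _, _, terminated, truncated, _ = env.step(a)
        ret = 0.0
        while not (terminated or truncated):  # random rollout from s
            _, reward, terminated, truncated, _ = env.step(env.action_space.sample())
            ret += reward
        total += ret
        env.close()
    return total / n_samples
```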
@@ -16,7 +18,6 @@ import re
 from typing import Any, Dict, List, Optional, Tuple
 
 import gymnasium
 import numpy as np
-from tqdm.asyncio import tqdm_asyncio
 
 from atroposlib.envs.base import (
@@ -44,7 +45,6 @@ class BlackjackEnvConfig(BaseEnvConfig):
     max_trajectory_tokens: int = 24576
     debug_mode: bool = False
     group_size: int = 16
-    mc_samples: int = 3
     tiebreak_token_factor: float = 0.01
 
 
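With sampling gone, `mc_samples` has no consumer, so the config shrinks. A rough stand-in for the resulting config, assuming `BaseEnvConfig` behaves like a pydantic model and assuming a `"Blackjack-v1"` default for the `env_name` field that later hunks reference (both assumptions, not confirmed by the diff):

```python
from pydantic import BaseModel


class BlackjackEnvConfig(BaseModel):  # stand-in for atroposlib's BaseEnvConfig
    env_name: str = "Blackjack-v1"  # referenced via self.config.env_name below
    max_trajectory_tokens: int = 24576
    debug_mode: bool = False
    group_size: int = 16  # G: number of alternatives sampled per step
    tiebreak_token_factor: float = 0.01  # mc_samples is gone: no MC sampling left
```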
@@ -166,11 +166,10 @@ class BlackjackEnv(BaseEnv):
             current_env_reward -= 0.2
 
         # Calculate the number of tokens in the agent's response
-        num_tokens = len(
-            tokenize_for_trainer_multistep(
-                self.tokenizer, [{"role": "agent", "content": response_text}]
-            )["tokens"]
-        )
+        if response_text:
+            num_tokens = len(self.tokenizer.encode(response_text))
+        else:
+            num_tokens = 0
 
         # tiebreak & small length penalty
         if self.config.max_token_length > 0:
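The replacement above swaps a multistep-chat tokenization for a plain `encode` of the raw response, guarded against empty responses (the old call built a one-message chat just to count tokens). The same pattern in isolation, with `gpt2` as a placeholder Hugging Face tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder; any HF tokenizer

for response_text in ["I'll hit. [HIT]", ""]:
    # Mirrors the new logic: only count tokens when there is a response at all.
    if response_text:
        num_tokens = len(tokenizer.encode(response_text))
    else:
        num_tokens = 0
    print(repr(response_text), "->", num_tokens)
```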
@@ -232,7 +231,7 @@ class BlackjackEnv(BaseEnv):
         episode_seed_for_sim: int,
         env_actions_to_replay: List[int],
     ) -> float:
-        """Calculate exact state value V*(s) using recursive calls and memoization.
+        """Calculate exact state value V*(s)
 
         Args:
             episode_seed_for_sim: The seed of the original episode for deterministic env creation.
@@ -240,7 +239,9 @@ class BlackjackEnv(BaseEnv):
         """
         v_star_cache: Dict[Tuple[int, int, int], float] = {}
 
-        def _get_v_star_recursive(obs_tuple: Tuple[int, int, int], current_env: gymnasium.Env) -> float:
+        def _get_v_star_recursive(
+            obs_tuple: Tuple[int, int, int], current_env: gymnasium.Env
+        ) -> float:
             player_sum, dealer_card, usable_ace = obs_tuple
 
             # Base Case 1: Bust
@@ -259,13 +260,13 @@ class BlackjackEnv(BaseEnv):
             # Q-value for HIT (action 1)
             env_for_hit = copy.deepcopy(current_env)
             obs_hit, reward_hit, term_hit, trunc_hit, _ = env_for_hit.step(1)
 
-            if term_hit or trunc_hit: # Game ended after hitting
+            if term_hit or trunc_hit:  # Game ended after hitting
                 q_star_hit = reward_hit
-            else: # Game continues, recursively find V* of next state
+            else:  # Game continues, recursively find V* of next state
                 # reward_hit is typically 0 if the game didn't end
                 q_star_hit = reward_hit + _get_v_star_recursive(obs_hit, env_for_hit)
-
+
             v_star = max(q_star_stick, q_star_hit)
             v_star_cache[obs_tuple] = v_star
             return v_star
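Between them, the last three hunks contain the heart of the rewrite: branch a deep copy of the env per action, recurse on HIT, and memoize V* by observation tuple. A self-contained sketch of the same scheme (the STICK branch is not visible in the diff, so its reconstruction here is an assumption; in Blackjack-v1, action 0 = stick always ends the hand):

```python
import copy
from typing import Dict, Tuple

import gymnasium


def exact_state_value(env: gymnasium.Env, obs: Tuple[int, int, int]) -> float:
    """V*(s) = max(Q*(s, stick), Q*(s, hit)), computed by branching env copies."""
    cache: Dict[Tuple[int, int, int], float] = {}

    def v_star(obs: Tuple[int, int, int], env: gymnasium.Env) -> float:
        if obs in cache:
            return cache[obs]

        # STICK (action 0) always terminates the hand, so its Q is the reward.
        env_stick = copy.deepcopy(env)
        _, q_stick, _, _, _ = env_stick.step(0)

        # HIT (action 1): recurse into the successor state unless the hand ended.
        env_hit = copy.deepcopy(env)
        obs_hit, r_hit, term, trunc, _ = env_hit.step(1)
        q_hit = r_hit if (term or trunc) else r_hit + v_star(obs_hit, env_hit)

        cache[obs] = max(q_stick, q_hit)
        return cache[obs]

    return v_star(obs, env)


env = gymnasium.make("Blackjack-v1")
obs, _ = env.reset(seed=42)
print(f"V*(s0) = {exact_state_value(env, obs):.3f}")
env.close()
```

Recursion depth stays small because every HIT raises the player's sum, and the deepcopy trick is the same one the repo code uses (`copy.deepcopy(current_env)`).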
@@ -286,7 +287,7 @@ class BlackjackEnv(BaseEnv):
                     )
                     is_terminal_after_replay = True
                     break
-
+
             if is_terminal_after_replay:
                 return 0.0
             final_v_star = _get_v_star_recursive(current_obs, sim_env)
@@ -294,10 +295,12 @@ class BlackjackEnv(BaseEnv):
 
         except Exception as e:
             logger.error(
-                f"[_estimate_value] Error during exact value calculation for seed {episode_seed_for_sim}, actions {env_actions_to_replay}: {e}",
-                exc_info=True
+                f"[_estimate_value] Error during exact value"
+                f" calculation for seed {episode_seed_for_sim}, "
+                f"actions {env_actions_to_replay}: {e}",
+                exc_info=True,
             )
-            return 0.0 # Return a default value on error
+            return 0.0
         finally:
             if sim_env is not None:
                 sim_env.close()
@@ -305,14 +308,13 @@ class BlackjackEnv(BaseEnv):
     async def collect_trajectory(self, seed: int) -> List[BlackjackScoredDataGroup]:
         """Collect data for ONE trajectory, evaluating G alternatives per step using MC advantages."""
         G = self.config.group_size
-        K = self.config.mc_samples
         max_turns = self.config.max_turns or 5
 
         trajectory_data_for_trainer: List[BlackjackScoredDataGroup] = []
         episode_summary_metrics: Optional[Dict[str, Any]] = None
 
         logger.info(
-            f"[Collect Trajectory Seed: {seed}] Starting trajectory. Group size G={G}, MC samples K={K}."
+            f"[Collect Trajectory Seed: {seed}] Starting trajectory. Group size G={G}."
         )
 
         try:
@@ -333,8 +335,7 @@ class BlackjackEnv(BaseEnv):
 
                 try:
                     value_t = await self._estimate_value(
-                        episode_seed_for_sim=ep.seed,
-                        env_actions_to_replay=ep.actions
+                        episode_seed_for_sim=ep.seed, env_actions_to_replay=ep.actions
                     )
                     logger.debug(
                         f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] Estimated V(s_t) = {value_t:.4f}"
@@ -395,11 +396,9 @@ class BlackjackEnv(BaseEnv):
                     next_state_msgs_i = []
                     try:
                         sim_env = gymnasium.make(self.config.env_name)
-                        sim_obs, _ = sim_env.reset(seed=ep.seed)
+                        _, _ = sim_env.reset(seed=ep.seed)
                         for prev_action in ep.actions:
-                            sim_obs, _, term_replay, trunc_replay, _ = sim_env.step(
-                                prev_action
-                            )
+                            _, _, term_replay, trunc_replay, _ = sim_env.step(prev_action)
                             if term_replay or trunc_replay:
                                 logger.error(
                                     f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
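The reset/step one-liners above all serve the same replay idiom used throughout this file: recreate a state deterministically from `(seed, action list)` rather than carrying env snapshots around. Isolated as a sketch (`replay_to_state` is not a repo function):

```python
import gymnasium


def replay_to_state(env_name: str, seed: int, actions: list):
    """Rebuild the state after `actions`: reset(seed=...) fixes the card
    stream, so replaying the same actions always lands in the same state."""
    env = gymnasium.make(env_name)
    obs, _ = env.reset(seed=seed)
    for action in actions:
        obs, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            return env, obs, True  # episode already over during replay
    return env, obs, False


env, obs, done = replay_to_state("Blackjack-v1", seed=7, actions=[1])
print(obs, done)
env.close()
```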
@@ -468,7 +467,7 @@ class BlackjackEnv(BaseEnv):
                         actions_to_reach_s_prime = ep.actions + [alt_env_actions[i]]
                         value_next_i = await self._estimate_value(
                             episode_seed_for_sim=ep.seed,
-                            env_actions_to_replay=actions_to_reach_s_prime
+                            env_actions_to_replay=actions_to_reach_s_prime,
                         )
                         alt_value_next.append(value_next_i)
                     except Exception as e_vn:
@@ -483,6 +482,9 @@ class BlackjackEnv(BaseEnv):
 
                 for i in range(G):
                     advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
+                    # If we pass this then instead of raw scores, implicitly, we're
+                    # doing some credit assignment. Could maybe do bonus on a win too
+                    # and apply with a discount factor to alts in winning trajectories
                     alt_advantages.append(advantage_i)
                     logger.debug(
                         f"[Collect Trajectory Seed: {seed} Turn: {turn+1} Alt: {i}] "
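The comment lines added above annotate the one-step advantage A(s, a_i) = r_i + V*(s'_i) - V*(s), which becomes the training score for each alternative. With made-up numbers for G = 3 alternatives:

```python
# Illustrative values only: V*(s), shaped rewards, and V*(s') per alternative.
value_t = 0.25
alt_combined_rewards = [0.0, -0.2, 1.0]
alt_value_next = [0.4, 0.3, 0.0]

alt_advantages = [
    r + v_next - value_t  # A(s, a_i) = r_i + V*(s'_i) - V*(s)
    for r, v_next in zip(alt_combined_rewards, alt_value_next)
]
print(alt_advantages)  # approx. [0.15, -0.15, 0.75] -> alternative 2 looks best
```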
@@ -497,6 +499,7 @@ class BlackjackEnv(BaseEnv):
                     or len(alt_next_state_msgs) != G
                     or len(alt_parsed_actions) != G
                 ):
+                    # sanity check
                     logger.error(
                         f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
                         f"Mismatch in alternative list lengths after processing. "
@@ -509,38 +512,31 @@ class BlackjackEnv(BaseEnv):
                         seed=ep.seed,
                         tokens=alt_tokens,
                         masks=alt_masks,
-                        scores=alt_combined_rewards,
+                        scores=alt_advantages,
                         messages=alt_next_state_msgs,
                         parsed_actions=alt_parsed_actions,
                     )
                 )
 
-                # Determine the best alternative: highest advantage, tie-broken by shortest token length.
-                if G == 0: # Should ideally not occur with G = self.config.group_size > 0
-                    logger.error(
-                        f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
-                        f"No alternatives to choose from (G=0). Aborting turn."
-                    )
-                    break # Exit the current turn processing.
 
-                # Prepare items for sorting: (-advantage, token_length, original_index)
-                # Sorting this list will place the best alternative (highest advantage, then shortest tokens) first.
                 sortable_alternatives = []
                 for i in range(G):
                     # alt_tokens[i] is expected to be a list (possibly empty)
                     token_len = len(alt_tokens[i])
                     sortable_alternatives.append((-alt_advantages[i], token_len, i))
 
-                sortable_alternatives.sort() # Sorts in-place
+                sortable_alternatives.sort()
 
                 best_advantage_idx = sortable_alternatives[0][2]
 
                 # Log details of the selected alternative based on the sort
-                chosen_advantage_for_log = -sortable_alternatives[0][0] # Revert sign for logging
+                chosen_advantage_for_log = -sortable_alternatives[0][0]
                 chosen_token_length_for_log = sortable_alternatives[0][1]
                 logger.debug(
                     f"[Collect Trajectory Seed: {seed} Turn: {turn+1}] "
-                    f"Selected Alt {best_advantage_idx} (Adv: {chosen_advantage_for_log:.2f}, Tokens: {chosen_token_length_for_log}) "
+                    f"Selected Alt {best_advantage_idx} "
+                    f"(Adv: {chosen_advantage_for_log:.2f}, "
+                    f"Tokens: {chosen_token_length_for_log}) "
                     f"from {G} alternatives using sort."
                 )
 
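The selection logic in this hunk picks the alternative that the main trajectory continues with: sort on (-advantage, token_length, index), so the highest advantage wins and ties go to the shortest response. Stand-alone, with toy data:

```python
# Toy data: alternatives 0 and 1 tie on advantage; 1 is shorter and wins.
alt_advantages = [0.75, 0.75, -0.15]
alt_tokens = [[0] * 120, [0] * 80, [0] * 50]  # stand-ins for token-id lists

sortable_alternatives = [
    (-alt_advantages[i], len(alt_tokens[i]), i) for i in range(len(alt_advantages))
]
sortable_alternatives.sort()  # tuple order: advantage (negated) first, then length

best_advantage_idx = sortable_alternatives[0][2]
print(best_advantage_idx)  # 1
```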
@@ -570,7 +566,7 @@ class BlackjackEnv(BaseEnv):
                 )
 
                 try:
-                    main_obs, main_reward, main_term, main_trunc, main_info = ep.env.step(
+                    main_obs, main_reward, main_term, main_trunc, _ = ep.env.step(
                         chosen_env_action
                     )
                     if abs(main_reward - chosen_raw_env_reward) > 1e-6:
@@ -630,7 +626,7 @@ class BlackjackEnv(BaseEnv):
                 game_outcome = 1
             elif final_raw_reward < 0:
                 game_outcome = -1
-
+            # debugging
             episode_summary_metrics = {
                 "seed": ep.seed,
                 "total_reward": final_raw_reward,
@@ -658,7 +654,9 @@ class BlackjackEnv(BaseEnv):
     async def score(
         self, rollout_group_data: List[BlackjackScoredDataGroup]
    ) -> List[Optional[BlackjackScoredDataGroup]]:
-        """Return rollout data with advantages as scores."""
+        """Pass through rollout data. The 'scores' field in BlackjackScoredDataGroup
+        already contains the A*(s,a) advantages from the collection phase.
+        """
         logger.info(f"[Score] Processing {len(rollout_group_data)} steps.")
         return rollout_group_data
 