This commit is contained in:
Shannon Sands 2025-05-10 08:44:35 +10:00
parent 06c4a9e65c
commit 9efd8c1529
2 changed files with 40 additions and 86 deletions

View file

@ -11,7 +11,7 @@ Uses Monte Carlo sampling to estimate the value of the current state, similar to
import json
import logging
import random
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple
import gymnasium
import numpy as np
@ -1025,7 +1025,8 @@ class BlackjackEnv(BaseEnv):
if num_alternatives == 0:
logger.warning(
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives after copying. Skipping."
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives"
" after copying. Skipping."
)
continue
@ -1109,7 +1110,8 @@ class BlackjackEnv(BaseEnv):
working_masks = temp_new_alt_masks
max_current_tokens = max_tokens_after_this_trunc
logger.debug(
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, after uniform pop of {min_pop_this_round}, "
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, "
f"after uniform pop of {min_pop_this_round}, "
f"max tokens: {max_current_tokens}"
)

View file

@ -1,4 +1,3 @@
#!/usr/bin/env python3
"""
BlackjackEnv: Trainer environment for Gymnasium Blackjack
@ -13,7 +12,7 @@ but may not be as effective at learning correct strategy (it's effectively a ser
import json
import logging
import random
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple
import gymnasium
from tqdm.asyncio import tqdm_asyncio
@ -48,7 +47,6 @@ class BlackjackEnvConfig(BaseEnvConfig):
batch_size: int = 1024
max_think_chars_history: int = 3000
# Should be higher than the max tokens to allow for multiple turns
max_trajectory_tokens: int = 24576
debug_mode: bool = False
@ -259,7 +257,6 @@ class BlackjackEnv(BaseEnv):
continue
if action == -1:
# Penalty for parsing error is applied within _score_response.
env_reward_sim = 0.0
else:
_obs_sim, env_reward_sim, term_sim, trunc_sim, _info_sim = (
@ -720,7 +717,6 @@ class BlackjackEnv(BaseEnv):
)
return []
# Ensure all elements are at least dictionaries before proceeding
if not all(
isinstance(rgd, dict) for rgd in rollout_group_data if rgd is not None
):
@ -728,10 +724,8 @@ class BlackjackEnv(BaseEnv):
"score: rollout_group_data contains non-dictionary elements. "
"Cannot proceed."
)
# Return a list of Nones matching input length or handle as error
return [None] * len(rollout_group_data)
# 1. Determine Overall Game Outcome
final_env_reward_for_outcome = 0.0
is_win = False
seed_for_outcome = None
@ -776,7 +770,7 @@ class BlackjackEnv(BaseEnv):
f"Invalid action ({action_for_outcome_step}) found at step {step_idx_outcome} "
f"during game outcome replay. Assuming non-win outcome for scoring."
)
final_env_reward_for_outcome = 0.0 # Treat as non-win
final_env_reward_for_outcome = 0.0
break
logger.debug(
@ -800,7 +794,7 @@ class BlackjackEnv(BaseEnv):
f"Final env reward for outcome: {final_env_reward_for_outcome}"
)
break
else: # Loop completed without break
else:
logger.info(
f"score [Seed: {seed_for_outcome}]: "
f"Game outcome replay completed all steps. "
@ -840,25 +834,19 @@ class BlackjackEnv(BaseEnv):
processed_rollout_data.append(None)
continue
# Make a copy to modify scores
current_step_group: BlackjackScoredDataGroup = (
original_step_group_untyped.copy()
)
step_seed = current_step_group.get(
"seed", "N/A"
) # Use N/A if seed is somehow missing
step_seed = current_step_group.get("seed", "N/A")
if current_step_group.get("scores") is None:
logger.warning(
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Scores are missing. Cannot apply win bonus or tie-breaking."
)
processed_rollout_data.append(
current_step_group
) # Append original or a copy
processed_rollout_data.append(current_step_group)
continue
# Ensure scores is a list of numbers
original_scores = current_step_group["scores"]
if not isinstance(original_scores, list) or not all(
isinstance(s, (int, float)) for s in original_scores
@ -871,13 +859,11 @@ class BlackjackEnv(BaseEnv):
processed_rollout_data.append(current_step_group)
continue
modified_scores = original_scores.copy() # Work on a copy
modified_scores = original_scores.copy()
# 2. Apply Win Bonus (if applicable)
if is_win and modified_scores: # Ensure scores list is not empty
if is_win and modified_scores:
try:
max_score_in_step = -float("inf")
# Find max score correctly, even with Nones, though previous check should handle Nones in list
valid_scores_for_max = [
s for s in modified_scores if isinstance(s, (int, float))
]
@ -888,7 +874,6 @@ class BlackjackEnv(BaseEnv):
)
else:
max_score_in_step = max(valid_scores_for_max)
# Find first index of max_score_in_step. If multiple, bonus applies to first.
best_alternative_idx_this_step = -1
for idx, score_val in enumerate(modified_scores):
if score_val == max_score_in_step:
@ -913,10 +898,7 @@ class BlackjackEnv(BaseEnv):
f"Could not find index of max score {max_score_in_step} for win bonus. "
f"This should not happen if scores exist."
)
except (
ValueError
): # Should be caught by empty list check or valid_scores_for_max
# Split into two lines to avoid line length issues
except ValueError:
score_msg = (
f"score [Seed: {step_seed}, Step: {step_idx}]: "
f"Error finding max score for win bonus."
@ -929,7 +911,6 @@ class BlackjackEnv(BaseEnv):
f"Unexpected error applying win bonus: {e_bonus}"
)
# 3. Apply Tie-Breaking Logic (to potentially bonus-adjusted scores)
step_messages = current_step_group.get("messages")
if not isinstance(step_messages, list) or len(modified_scores) != len(
step_messages
@ -940,7 +921,7 @@ class BlackjackEnv(BaseEnv):
f"({len(step_messages) if isinstance(step_messages, list) else 'not a list'}) "
f"lengths, or messages missing. Skipping tie-breaking for this step."
)
elif modified_scores: # Ensure scores list is not empty for tie-breaking
elif modified_scores:
token_lengths = []
valid_messages_for_tiebreak = True
for alt_msg_list_idx, alt_msg_list in enumerate(step_messages):
@ -966,16 +947,12 @@ class BlackjackEnv(BaseEnv):
f"Tokenization error for tie-breaking on alt {alt_msg_list_idx}: {e_tok}. "
f"Defaulting token length to large value."
)
token_lengths.append(
float("inf")
) # Penalize if tokenization fails
token_lengths.append(float("inf"))
if valid_messages_for_tiebreak:
score_groups = {} # Maps score_value to list of indices
score_groups = {}
for idx, score_val in enumerate(modified_scores):
if not isinstance(
score_val, (int, float)
): # Skip non-numeric scores
if not isinstance(score_val, (int, float)):
continue
if score_val not in score_groups:
score_groups[score_val] = []
@ -983,8 +960,7 @@ class BlackjackEnv(BaseEnv):
scores_after_tiebreak = modified_scores.copy()
for score_val, indices_with_this_score in score_groups.items():
if len(indices_with_this_score) > 1: # A tie is found
# Check if token_lengths are available for all tied indices
if len(indices_with_this_score) > 1:
if not all(
idx < len(token_lengths)
for idx in indices_with_this_score
@ -999,16 +975,14 @@ class BlackjackEnv(BaseEnv):
continue
try:
# Sort tied indices by their token_lengths
sorted_tied_indices = sorted(
indices_with_this_score,
key=lambda i: token_lengths[i],
)
# Apply penalty to all but the first (shortest token length)
for rank, tied_idx in enumerate(
sorted_tied_indices[1:], 1
): # Start rank from 1 for 2nd shortest
):
penalty = 0.0001 * rank
scores_after_tiebreak[tied_idx] -= penalty
logger.debug(
@ -1375,39 +1349,36 @@ class BlackjackEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
env_config = BlackjackEnvConfig(
# Fields from fundamental_prediction_environment.py's BaseEnvConfig init:
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=16, # From Base, as not in BJ no_mc config's direct definition
group_size=16,
use_wandb=True,
max_num_workers=128,
rollout_server_url="http://localhost:8000",
total_steps=2000,
batch_size=1024, # Matches BlackjackEnvConfig (no_mc) default as well
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 16,
inference_weight=1.0,
wandb_name="fundamental_metric_prediction", # Strict: Use value from fundamental_prediction
wandb_name="fundamental_metric_prediction",
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
# BlackjackEnvConfig (no_mc version) specific fields (those NOT in BaseEnvConfig from fundamental_prediction)
# using their defined defaults from BlackjackEnvConfig (no_mc):
env_name="Blackjack-v1", # Default from BlackjackEnvConfig (no_mc)
temperature=0.7, # Default from BlackjackEnvConfig (no_mc)
top_p=0.9, # Default from BlackjackEnvConfig (no_mc)
max_turns=5, # Default from BlackjackEnvConfig (no_mc)
thinking_active=True, # Default from BlackjackEnvConfig (no_mc)
eval_episodes=100, # Default from BlackjackEnvConfig (no_mc)
max_think_chars_history=3000, # Default from BlackjackEnvConfig (no_mc)
max_trajectory_tokens=24576, # Default from BlackjackEnvConfig (no_mc)
debug_mode=False, # Default from BlackjackEnvConfig (no_mc)
env_name="Blackjack-v1",
temperature=0.7,
top_p=0.9,
max_turns=5,
thinking_active=True,
eval_episodes=100,
max_think_chars_history=3000,
max_trajectory_tokens=24576,
debug_mode=False,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256, # From fundamental_prediction_environment.py
num_requests_for_eval=256,
)
]
return env_config, server_configs
@ -1514,8 +1485,7 @@ class BlackjackEnv(BaseEnv):
original_step_data.get("messages")
and original_step_data.get("tokens")
and original_step_data.get("masks")
and original_step_data.get("seed")
is not None # seed is mandatory for new group
and original_step_data.get("seed") is not None
):
logger.warning(
f"[_ensure_trajectory_token_limit] Step {step_idx} "
@ -1523,8 +1493,6 @@ class BlackjackEnv(BaseEnv):
)
continue
# Initial token calculation from original data to see if truncation is needed
# Ensure tokens are lists of integers before calling len
max_initial_tokens = 0
if original_step_data["tokens"]:
max_initial_tokens = (
@ -1553,8 +1521,6 @@ class BlackjackEnv(BaseEnv):
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
)
# Prepare working copies for modification
# Ensure deep copies for lists of dicts if dicts are modified, but here we pop from list of dicts.
working_messages = [
msgs_list.copy() for msgs_list in original_step_data["messages"] or []
]
@ -1567,7 +1533,7 @@ class BlackjackEnv(BaseEnv):
max_current_tokens = max_initial_tokens
num_alternatives = len(working_messages)
if num_alternatives == 0: # Should not happen if initial checks passed
if num_alternatives == 0:
logger.warning(
f"[_ensure_trajectory_token_limit] Step {step_idx} has no alternatives after copying. Skipping."
)
@ -1579,60 +1545,47 @@ class BlackjackEnv(BaseEnv):
for alt_idx in range(num_alternatives):
alt_msg_list = working_messages[alt_idx]
# Calculate how many initial messages (after system prompt) can be popped.
# Preserving: system prompt (index 0), last agent response, and its preceding env observation.
num_preserved_at_end = 0
if (
len(alt_msg_list) > 1
and alt_msg_list[-1]["role"] in UNMASKED_ROLES
):
num_preserved_at_end = 1 # Last agent response
num_preserved_at_end = 1
if (
len(alt_msg_list) > 2
and alt_msg_list[-2]["role"] == "environment"
):
num_preserved_at_end = (
2 # Agent response + preceding env observation
)
num_preserved_at_end = 2
# Number of messages available for popping (between system prompt and preserved end messages)
# Subtract 1 for the system prompt itself (which is never popped from index 0).
available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
if available_to_pop <= 0:
target_pop_counts_per_alt.append(0)
else:
# Try to pop a pair (environment, agent) if they are at list[1] and list[2]
can_pop_pair = (
available_to_pop >= 2
and len(alt_msg_list) > 2
and alt_msg_list[ # Ensure messages at index 1 and 2 exist
1
]["role"]
== "environment"
and alt_msg_list[1]["role"] == "environment"
and alt_msg_list[2]["role"] in UNMASKED_ROLES
)
if can_pop_pair:
target_pop_counts_per_alt.append(2)
else: # Can pop at least 1 since available_to_pop > 0
else:
target_pop_counts_per_alt.append(1)
positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
if not positive_pop_counts:
break # No alternative can be truncated further
break
min_pop_this_round = min(positive_pop_counts)
# Pop messages and re-tokenize
temp_new_alt_tokens = []
temp_new_alt_masks = []
max_tokens_after_this_trunc = 0
for alt_idx in range(num_alternatives):
for _ in range(min_pop_this_round):
if (
len(working_messages[alt_idx]) > 1
): # Ensure there's something to pop after system
if len(working_messages[alt_idx]) > 1:
working_messages[alt_idx].pop(1)
else:
logger.error(
@ -1671,7 +1624,6 @@ class BlackjackEnv(BaseEnv):
f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_this_round}, "
f"max tokens: {max_current_tokens}"
)
# End of while loop for truncation attempts
if (
not retokenization_error_this_step