mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
linting
This commit is contained in:
parent
06c4a9e65c
commit
9efd8c1529
2 changed files with 40 additions and 86 deletions
|
|
@ -11,7 +11,7 @@ Uses Monte Carlo sampling to estimate the value of the current state, similar to
|
|||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import gymnasium
|
||||
import numpy as np
|
||||
|
|
@ -1025,7 +1025,8 @@ class BlackjackEnv(BaseEnv):
|
|||
|
||||
if num_alternatives == 0:
|
||||
logger.warning(
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives after copying. Skipping."
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} in MC env has no alternatives"
|
||||
" after copying. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
@ -1109,7 +1110,8 @@ class BlackjackEnv(BaseEnv):
|
|||
working_masks = temp_new_alt_masks
|
||||
max_current_tokens = max_tokens_after_this_trunc
|
||||
logger.debug(
|
||||
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, after uniform pop of {min_pop_this_round}, "
|
||||
f"[_ensure_trajectory_token_limit] MC env: Step {step_idx}, "
|
||||
f"after uniform pop of {min_pop_this_round}, "
|
||||
f"max tokens: {max_current_tokens}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
BlackjackEnv: Trainer environment for Gymnasium Blackjack
|
||||
|
||||
|
|
@ -13,7 +12,7 @@ but may not be as effective at learning correct strategy (it's effectively a ser
|
|||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import gymnasium
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
|
@ -48,7 +47,6 @@ class BlackjackEnvConfig(BaseEnvConfig):
|
|||
|
||||
batch_size: int = 1024
|
||||
max_think_chars_history: int = 3000
|
||||
# Should be higher than the max tokens to allow for multiple turns
|
||||
max_trajectory_tokens: int = 24576
|
||||
debug_mode: bool = False
|
||||
|
||||
|
|
@ -259,7 +257,6 @@ class BlackjackEnv(BaseEnv):
|
|||
continue
|
||||
|
||||
if action == -1:
|
||||
# Penalty for parsing error is applied within _score_response.
|
||||
env_reward_sim = 0.0
|
||||
else:
|
||||
_obs_sim, env_reward_sim, term_sim, trunc_sim, _info_sim = (
|
||||
|
|
@ -720,7 +717,6 @@ class BlackjackEnv(BaseEnv):
|
|||
)
|
||||
return []
|
||||
|
||||
# Ensure all elements are at least dictionaries before proceeding
|
||||
if not all(
|
||||
isinstance(rgd, dict) for rgd in rollout_group_data if rgd is not None
|
||||
):
|
||||
|
|
@ -728,10 +724,8 @@ class BlackjackEnv(BaseEnv):
|
|||
"score: rollout_group_data contains non-dictionary elements. "
|
||||
"Cannot proceed."
|
||||
)
|
||||
# Return a list of Nones matching input length or handle as error
|
||||
return [None] * len(rollout_group_data)
|
||||
|
||||
# 1. Determine Overall Game Outcome
|
||||
final_env_reward_for_outcome = 0.0
|
||||
is_win = False
|
||||
seed_for_outcome = None
|
||||
|
|
@ -776,7 +770,7 @@ class BlackjackEnv(BaseEnv):
|
|||
f"Invalid action ({action_for_outcome_step}) found at step {step_idx_outcome} "
|
||||
f"during game outcome replay. Assuming non-win outcome for scoring."
|
||||
)
|
||||
final_env_reward_for_outcome = 0.0 # Treat as non-win
|
||||
final_env_reward_for_outcome = 0.0
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
|
|
@ -800,7 +794,7 @@ class BlackjackEnv(BaseEnv):
|
|||
f"Final env reward for outcome: {final_env_reward_for_outcome}"
|
||||
)
|
||||
break
|
||||
else: # Loop completed without break
|
||||
else:
|
||||
logger.info(
|
||||
f"score [Seed: {seed_for_outcome}]: "
|
||||
f"Game outcome replay completed all steps. "
|
||||
|
|
@ -840,25 +834,19 @@ class BlackjackEnv(BaseEnv):
|
|||
processed_rollout_data.append(None)
|
||||
continue
|
||||
|
||||
# Make a copy to modify scores
|
||||
current_step_group: BlackjackScoredDataGroup = (
|
||||
original_step_group_untyped.copy()
|
||||
)
|
||||
step_seed = current_step_group.get(
|
||||
"seed", "N/A"
|
||||
) # Use N/A if seed is somehow missing
|
||||
step_seed = current_step_group.get("seed", "N/A")
|
||||
|
||||
if current_step_group.get("scores") is None:
|
||||
logger.warning(
|
||||
f"score [Seed: {step_seed}, Step: {step_idx}]: "
|
||||
f"Scores are missing. Cannot apply win bonus or tie-breaking."
|
||||
)
|
||||
processed_rollout_data.append(
|
||||
current_step_group
|
||||
) # Append original or a copy
|
||||
processed_rollout_data.append(current_step_group)
|
||||
continue
|
||||
|
||||
# Ensure scores is a list of numbers
|
||||
original_scores = current_step_group["scores"]
|
||||
if not isinstance(original_scores, list) or not all(
|
||||
isinstance(s, (int, float)) for s in original_scores
|
||||
|
|
@ -871,13 +859,11 @@ class BlackjackEnv(BaseEnv):
|
|||
processed_rollout_data.append(current_step_group)
|
||||
continue
|
||||
|
||||
modified_scores = original_scores.copy() # Work on a copy
|
||||
modified_scores = original_scores.copy()
|
||||
|
||||
# 2. Apply Win Bonus (if applicable)
|
||||
if is_win and modified_scores: # Ensure scores list is not empty
|
||||
if is_win and modified_scores:
|
||||
try:
|
||||
max_score_in_step = -float("inf")
|
||||
# Find max score correctly, even with Nones, though previous check should handle Nones in list
|
||||
valid_scores_for_max = [
|
||||
s for s in modified_scores if isinstance(s, (int, float))
|
||||
]
|
||||
|
|
@ -888,7 +874,6 @@ class BlackjackEnv(BaseEnv):
|
|||
)
|
||||
else:
|
||||
max_score_in_step = max(valid_scores_for_max)
|
||||
# Find first index of max_score_in_step. If multiple, bonus applies to first.
|
||||
best_alternative_idx_this_step = -1
|
||||
for idx, score_val in enumerate(modified_scores):
|
||||
if score_val == max_score_in_step:
|
||||
|
|
@ -913,10 +898,7 @@ class BlackjackEnv(BaseEnv):
|
|||
f"Could not find index of max score {max_score_in_step} for win bonus. "
|
||||
f"This should not happen if scores exist."
|
||||
)
|
||||
except (
|
||||
ValueError
|
||||
): # Should be caught by empty list check or valid_scores_for_max
|
||||
# Split into two lines to avoid line length issues
|
||||
except ValueError:
|
||||
score_msg = (
|
||||
f"score [Seed: {step_seed}, Step: {step_idx}]: "
|
||||
f"Error finding max score for win bonus."
|
||||
|
|
@ -929,7 +911,6 @@ class BlackjackEnv(BaseEnv):
|
|||
f"Unexpected error applying win bonus: {e_bonus}"
|
||||
)
|
||||
|
||||
# 3. Apply Tie-Breaking Logic (to potentially bonus-adjusted scores)
|
||||
step_messages = current_step_group.get("messages")
|
||||
if not isinstance(step_messages, list) or len(modified_scores) != len(
|
||||
step_messages
|
||||
|
|
@ -940,7 +921,7 @@ class BlackjackEnv(BaseEnv):
|
|||
f"({len(step_messages) if isinstance(step_messages, list) else 'not a list'}) "
|
||||
f"lengths, or messages missing. Skipping tie-breaking for this step."
|
||||
)
|
||||
elif modified_scores: # Ensure scores list is not empty for tie-breaking
|
||||
elif modified_scores:
|
||||
token_lengths = []
|
||||
valid_messages_for_tiebreak = True
|
||||
for alt_msg_list_idx, alt_msg_list in enumerate(step_messages):
|
||||
|
|
@ -966,16 +947,12 @@ class BlackjackEnv(BaseEnv):
|
|||
f"Tokenization error for tie-breaking on alt {alt_msg_list_idx}: {e_tok}. "
|
||||
f"Defaulting token length to large value."
|
||||
)
|
||||
token_lengths.append(
|
||||
float("inf")
|
||||
) # Penalize if tokenization fails
|
||||
token_lengths.append(float("inf"))
|
||||
|
||||
if valid_messages_for_tiebreak:
|
||||
score_groups = {} # Maps score_value to list of indices
|
||||
score_groups = {}
|
||||
for idx, score_val in enumerate(modified_scores):
|
||||
if not isinstance(
|
||||
score_val, (int, float)
|
||||
): # Skip non-numeric scores
|
||||
if not isinstance(score_val, (int, float)):
|
||||
continue
|
||||
if score_val not in score_groups:
|
||||
score_groups[score_val] = []
|
||||
|
|
@ -983,8 +960,7 @@ class BlackjackEnv(BaseEnv):
|
|||
|
||||
scores_after_tiebreak = modified_scores.copy()
|
||||
for score_val, indices_with_this_score in score_groups.items():
|
||||
if len(indices_with_this_score) > 1: # A tie is found
|
||||
# Check if token_lengths are available for all tied indices
|
||||
if len(indices_with_this_score) > 1:
|
||||
if not all(
|
||||
idx < len(token_lengths)
|
||||
for idx in indices_with_this_score
|
||||
|
|
@ -999,16 +975,14 @@ class BlackjackEnv(BaseEnv):
|
|||
continue
|
||||
|
||||
try:
|
||||
# Sort tied indices by their token_lengths
|
||||
sorted_tied_indices = sorted(
|
||||
indices_with_this_score,
|
||||
key=lambda i: token_lengths[i],
|
||||
)
|
||||
|
||||
# Apply penalty to all but the first (shortest token length)
|
||||
for rank, tied_idx in enumerate(
|
||||
sorted_tied_indices[1:], 1
|
||||
): # Start rank from 1 for 2nd shortest
|
||||
):
|
||||
penalty = 0.0001 * rank
|
||||
scores_after_tiebreak[tied_idx] -= penalty
|
||||
logger.debug(
|
||||
|
|
@ -1375,39 +1349,36 @@ class BlackjackEnv(BaseEnv):
|
|||
@classmethod
|
||||
def config_init(cls) -> Tuple[BlackjackEnvConfig, List[OpenaiConfig]]:
|
||||
env_config = BlackjackEnvConfig(
|
||||
# Fields from fundamental_prediction_environment.py's BaseEnvConfig init:
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
group_size=16, # From Base, as not in BJ no_mc config's direct definition
|
||||
group_size=16,
|
||||
use_wandb=True,
|
||||
max_num_workers=128,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=2000,
|
||||
batch_size=1024, # Matches BlackjackEnvConfig (no_mc) default as well
|
||||
batch_size=1024,
|
||||
steps_per_eval=20,
|
||||
max_token_length=1024 * 16,
|
||||
inference_weight=1.0,
|
||||
wandb_name="fundamental_metric_prediction", # Strict: Use value from fundamental_prediction
|
||||
wandb_name="fundamental_metric_prediction",
|
||||
data_path_to_save_groups=None,
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
# BlackjackEnvConfig (no_mc version) specific fields (those NOT in BaseEnvConfig from fundamental_prediction)
|
||||
# using their defined defaults from BlackjackEnvConfig (no_mc):
|
||||
env_name="Blackjack-v1", # Default from BlackjackEnvConfig (no_mc)
|
||||
temperature=0.7, # Default from BlackjackEnvConfig (no_mc)
|
||||
top_p=0.9, # Default from BlackjackEnvConfig (no_mc)
|
||||
max_turns=5, # Default from BlackjackEnvConfig (no_mc)
|
||||
thinking_active=True, # Default from BlackjackEnvConfig (no_mc)
|
||||
eval_episodes=100, # Default from BlackjackEnvConfig (no_mc)
|
||||
max_think_chars_history=3000, # Default from BlackjackEnvConfig (no_mc)
|
||||
max_trajectory_tokens=24576, # Default from BlackjackEnvConfig (no_mc)
|
||||
debug_mode=False, # Default from BlackjackEnvConfig (no_mc)
|
||||
env_name="Blackjack-v1",
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
max_turns=5,
|
||||
thinking_active=True,
|
||||
eval_episodes=100,
|
||||
max_think_chars_history=3000,
|
||||
max_trajectory_tokens=24576,
|
||||
debug_mode=False,
|
||||
)
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
base_url="http://localhost:9004/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256, # From fundamental_prediction_environment.py
|
||||
num_requests_for_eval=256,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
|
@ -1514,8 +1485,7 @@ class BlackjackEnv(BaseEnv):
|
|||
original_step_data.get("messages")
|
||||
and original_step_data.get("tokens")
|
||||
and original_step_data.get("masks")
|
||||
and original_step_data.get("seed")
|
||||
is not None # seed is mandatory for new group
|
||||
and original_step_data.get("seed") is not None
|
||||
):
|
||||
logger.warning(
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} "
|
||||
|
|
@ -1523,8 +1493,6 @@ class BlackjackEnv(BaseEnv):
|
|||
)
|
||||
continue
|
||||
|
||||
# Initial token calculation from original data to see if truncation is needed
|
||||
# Ensure tokens are lists of integers before calling len
|
||||
max_initial_tokens = 0
|
||||
if original_step_data["tokens"]:
|
||||
max_initial_tokens = (
|
||||
|
|
@ -1553,8 +1521,6 @@ class BlackjackEnv(BaseEnv):
|
|||
f"exceeds limit ({self.config.max_trajectory_tokens}). Attempting truncation."
|
||||
)
|
||||
|
||||
# Prepare working copies for modification
|
||||
# Ensure deep copies for lists of dicts if dicts are modified, but here we pop from list of dicts.
|
||||
working_messages = [
|
||||
msgs_list.copy() for msgs_list in original_step_data["messages"] or []
|
||||
]
|
||||
|
|
@ -1567,7 +1533,7 @@ class BlackjackEnv(BaseEnv):
|
|||
max_current_tokens = max_initial_tokens
|
||||
num_alternatives = len(working_messages)
|
||||
|
||||
if num_alternatives == 0: # Should not happen if initial checks passed
|
||||
if num_alternatives == 0:
|
||||
logger.warning(
|
||||
f"[_ensure_trajectory_token_limit] Step {step_idx} has no alternatives after copying. Skipping."
|
||||
)
|
||||
|
|
@ -1579,60 +1545,47 @@ class BlackjackEnv(BaseEnv):
|
|||
for alt_idx in range(num_alternatives):
|
||||
alt_msg_list = working_messages[alt_idx]
|
||||
|
||||
# Calculate how many initial messages (after system prompt) can be popped.
|
||||
# Preserving: system prompt (index 0), last agent response, and its preceding env observation.
|
||||
num_preserved_at_end = 0
|
||||
if (
|
||||
len(alt_msg_list) > 1
|
||||
and alt_msg_list[-1]["role"] in UNMASKED_ROLES
|
||||
):
|
||||
num_preserved_at_end = 1 # Last agent response
|
||||
num_preserved_at_end = 1
|
||||
if (
|
||||
len(alt_msg_list) > 2
|
||||
and alt_msg_list[-2]["role"] == "environment"
|
||||
):
|
||||
num_preserved_at_end = (
|
||||
2 # Agent response + preceding env observation
|
||||
)
|
||||
num_preserved_at_end = 2
|
||||
|
||||
# Number of messages available for popping (between system prompt and preserved end messages)
|
||||
# Subtract 1 for the system prompt itself (which is never popped from index 0).
|
||||
available_to_pop = len(alt_msg_list) - 1 - num_preserved_at_end
|
||||
|
||||
if available_to_pop <= 0:
|
||||
target_pop_counts_per_alt.append(0)
|
||||
else:
|
||||
# Try to pop a pair (environment, agent) if they are at list[1] and list[2]
|
||||
can_pop_pair = (
|
||||
available_to_pop >= 2
|
||||
and len(alt_msg_list) > 2
|
||||
and alt_msg_list[ # Ensure messages at index 1 and 2 exist
|
||||
1
|
||||
]["role"]
|
||||
== "environment"
|
||||
and alt_msg_list[1]["role"] == "environment"
|
||||
and alt_msg_list[2]["role"] in UNMASKED_ROLES
|
||||
)
|
||||
if can_pop_pair:
|
||||
target_pop_counts_per_alt.append(2)
|
||||
else: # Can pop at least 1 since available_to_pop > 0
|
||||
else:
|
||||
target_pop_counts_per_alt.append(1)
|
||||
|
||||
positive_pop_counts = [c for c in target_pop_counts_per_alt if c > 0]
|
||||
if not positive_pop_counts:
|
||||
break # No alternative can be truncated further
|
||||
break
|
||||
|
||||
min_pop_this_round = min(positive_pop_counts)
|
||||
|
||||
# Pop messages and re-tokenize
|
||||
temp_new_alt_tokens = []
|
||||
temp_new_alt_masks = []
|
||||
max_tokens_after_this_trunc = 0
|
||||
|
||||
for alt_idx in range(num_alternatives):
|
||||
for _ in range(min_pop_this_round):
|
||||
if (
|
||||
len(working_messages[alt_idx]) > 1
|
||||
): # Ensure there's something to pop after system
|
||||
if len(working_messages[alt_idx]) > 1:
|
||||
working_messages[alt_idx].pop(1)
|
||||
else:
|
||||
logger.error(
|
||||
|
|
@ -1671,7 +1624,6 @@ class BlackjackEnv(BaseEnv):
|
|||
f"[_ensure_trajectory_token_limit] Step {step_idx}, after uniform pop of {min_pop_this_round}, "
|
||||
f"max tokens: {max_current_tokens}"
|
||||
)
|
||||
# End of while loop for truncation attempts
|
||||
|
||||
if (
|
||||
not retokenization_error_this_step
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue