commit d8ab1a6758
parent 1a7c0294fa
Author: Shannon Sands
Date:   2025-05-14 14:20:54 -07:00


@@ -15,7 +15,7 @@ import json
 import logging
 import random
 import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import gymnasium
 from tqdm.asyncio import tqdm_asyncio
@@ -309,7 +309,9 @@ class BlackjackEnv(BaseEnv):
 if sim_env is not None:
 sim_env.close()
 
-async def _next_step(self, ep: EpisodeState, current_turn: int, max_turns: int) -> Tuple[Optional[BlackjackScoredDataGroup], bool]:
+async def _next_step(
+self, ep: EpisodeState, current_turn: int, max_turns: int
+) -> Tuple[Optional[BlackjackScoredDataGroup], bool]:
 """Process one step/turn of an episode.
 
 This involves estimating current state value, sampling multiple (G) responses from the LLM,
@@ -348,7 +350,7 @@ class BlackjackEnv(BaseEnv):
 f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Error estimating V(s_t): {e_vt}",
 exc_info=True,
 )
-return None, True # Indicate error and episode termination
+return None, True  # Indicate error and episode termination
 
 messages_for_llm = current_state_messages.copy()
 agent_prompt_content = "<think>\n" if self.config.thinking_active else ""
@@ -362,13 +364,13 @@ class BlackjackEnv(BaseEnv):
 f"Expected {G} responses, got {len(responses) if responses else 0}. "
 f"Aborting step."
 )
-return None, True # Indicate error and episode termination
+return None, True  # Indicate error and episode termination
 except Exception as e_sample:
 logger.error(
 f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Error sampling responses: {e_sample}",
 exc_info=True,
 )
-return None, True # Indicate error and episode termination
+return None, True  # Indicate error and episode termination
 
 alt_full_responses: List[str] = []
 alt_parsed_actions: List[int] = []
@@ -389,17 +391,20 @@ class BlackjackEnv(BaseEnv):
 parsed_action = self._parse_tool_call(full_agent_response)
 alt_parsed_actions.append(parsed_action)
-env_action = parsed_action if parsed_action != -1 else 0 # Default to stick on parse error
+env_action = (
+parsed_action if parsed_action != -1 else 0
+) # Default to stick on parse error
 alt_env_actions.append(env_action)
 
 sim_env_i = None
 raw_env_reward_i = 0.0
 term_i, trunc_i = False, False
 next_state_msgs_i = []
-sim_obs_next_i = None
+sim_obs_next_i = None
 
 try:
 sim_env_i = gymnasium.make(self.config.env_name)
 # replay env to same state as current episode
 _, _ = sim_env_i.reset(seed=ep.seed)
 for prev_action_idx, prev_action in enumerate(ep.actions):
 _, _, term_replay, trunc_replay, _ = sim_env_i.step(prev_action)
@@ -409,11 +414,11 @@ class BlackjackEnv(BaseEnv):
 f"Sim env for alternative {i} terminated prematurely during history replay "
 f"(action {prev_action_idx+1}/{len(ep.actions)}). State mismatch or unexpected termination."
 )
-term_i, trunc_i = True, True
-raw_env_reward_i = 0.0
+term_i, trunc_i = True, True
+raw_env_reward_i = 0.0
 break
 
-if not (term_i or trunc_i):
+if not (term_i or trunc_i):
 sim_obs_next_i, raw_env_reward_i, term_i, trunc_i, _ = (
 sim_env_i.step(env_action)
 )
@@ -429,20 +434,18 @@ class BlackjackEnv(BaseEnv):
 current_state_plus_response_i = current_state_messages + [
 {"role": "agent", "content": full_agent_response}
 ]
 
-if sim_obs_next_i is not None and not (term_i or trunc_i):
+if sim_obs_next_i is not None and not (term_i or trunc_i):
 next_state_msgs_i = current_state_plus_response_i + [
 {
 "role": "environment",
 "content": self._format_observation(sim_obs_next_i),
 }
 ]
-else:
+else:
 next_state_msgs_i = current_state_plus_response_i
 alt_next_state_msgs.append(next_state_msgs_i)
 
-tokenized_i = tokenize_for_trainer(
-self.tokenizer, next_state_msgs_i
-)
+tokenized_i = tokenize_for_trainer(self.tokenizer, next_state_msgs_i)
 alt_tokens.append(tokenized_i["tokens"])
 alt_masks.append(tokenized_i["masks"])
@@ -452,20 +455,20 @@ class BlackjackEnv(BaseEnv):
 f"Error simulating action {env_action} for alternative: {e_sim}",
 exc_info=True,
 )
-alt_raw_rewards.append(0.0)
-alt_combined_rewards.append(-1.0)
+alt_raw_rewards.append(0.0)
+alt_combined_rewards.append(-1.0)
 alt_next_state_msgs.append(
 current_state_messages
-+ [{"role": "agent", "content": full_agent_response}]
++ [{"role": "agent", "content": full_agent_response}]
 )
-alt_is_terminal.append(True)
-alt_tokens.append([])
+alt_is_terminal.append(True)
+alt_tokens.append([])
 alt_masks.append([])
 finally:
 if sim_env_i:
 sim_env_i.close()
 
-alt_value_next: List[float] = []
+alt_value_next: List[float] = []
 for i in range(G):
 if not alt_is_terminal[i]:
 try:
@@ -481,12 +484,16 @@ class BlackjackEnv(BaseEnv):
 f"Error estimating V(s') for alternative: {e_vn}",
 exc_info=True,
 )
-alt_value_next.append(0.0)
+alt_value_next.append(0.0)
 else:
-alt_value_next.append(0.0)
+alt_value_next.append(0.0)
 
 for i in range(G):
-if i < len(alt_combined_rewards) and i < len(alt_value_next) and value_t is not None:
+if (
+i < len(alt_combined_rewards)
+and i < len(alt_value_next)
+and value_t is not None
+):
 advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
 alt_advantages.append(advantage_i)
 logger.debug(
@@ -500,12 +507,14 @@ class BlackjackEnv(BaseEnv):
 f"Skipping advantage calculation due to missing data or value_t. "
 f"len(alt_combined_rewards)={len(alt_combined_rewards)}, len(alt_value_next)={len(alt_value_next)}"
 )
-alt_advantages.append(-float('inf'))
+alt_advantages.append(-float("inf"))
 
 if not (
-len(alt_tokens) == G and len(alt_masks) == G and
-len(alt_advantages) == G and len(alt_next_state_msgs) == G and
-len(alt_parsed_actions) == G
+len(alt_tokens) == G
+and len(alt_masks) == G
+and len(alt_advantages) == G
+and len(alt_next_state_msgs) == G
+and len(alt_parsed_actions) == G
 ):
 logger.error(
 f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
@@ -520,7 +529,7 @@ class BlackjackEnv(BaseEnv):
 seed=ep.seed,
 tokens=alt_tokens,
 masks=alt_masks,
-scores=alt_advantages,
+scores=alt_advantages,
 messages=alt_next_state_msgs,
 parsed_actions=alt_parsed_actions,
 )
@@ -534,8 +543,16 @@ class BlackjackEnv(BaseEnv):
 secondary_lower_is_better=True,
 )
 
-chosen_advantage_for_log = alt_advantages[best_advantage_idx] if best_advantage_idx < len(alt_advantages) else "N/A"
-chosen_token_length_for_log = alt_token_lengths[best_advantage_idx] if best_advantage_idx < len(alt_token_lengths) else "N/A"
+chosen_advantage_for_log = (
+alt_advantages[best_advantage_idx]
+if best_advantage_idx < len(alt_advantages)
+else "N/A"
+)
+chosen_token_length_for_log = (
+alt_token_lengths[best_advantage_idx]
+if best_advantage_idx < len(alt_token_lengths)
+else "N/A"
+)
 logger.debug(
 f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
 f"Selected Alt {best_advantage_idx} "
@@ -544,10 +561,21 @@ class BlackjackEnv(BaseEnv):
 f"from {G} alternatives."
 )
 
-chosen_env_action = alt_env_actions[best_advantage_idx] if best_advantage_idx < len(alt_env_actions) else 0
-chosen_full_response = alt_full_responses[best_advantage_idx] if best_advantage_idx < len(alt_full_responses) else ""
-chosen_parsed_action = alt_parsed_actions[best_advantage_idx] if best_advantage_idx < len(alt_parsed_actions) else -1
+chosen_env_action = (
+alt_env_actions[best_advantage_idx]
+if best_advantage_idx < len(alt_env_actions)
+else 0
+)
+chosen_full_response = (
+alt_full_responses[best_advantage_idx]
+if best_advantage_idx < len(alt_full_responses)
+else ""
+)
+chosen_parsed_action = (
+alt_parsed_actions[best_advantage_idx]
+if best_advantage_idx < len(alt_parsed_actions)
+else -1
+)
 
 logger.info(
 f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Chosen action to step env: "
@@ -556,29 +584,36 @@ class BlackjackEnv(BaseEnv):
 )
 
 ep.num_total_actions += 1
-if chosen_parsed_action != -1:
+if chosen_parsed_action != -1:
 ep.num_correct_actions += 1
 
 response_for_history = truncate_thinking(
 chosen_full_response,
 self.tokenizer,
 self.config.max_think_chars_history,
 )
-ep.message_history.append(
-{"role": "agent", "content": response_for_history}
-)
+ep.message_history.append({"role": "agent", "content": response_for_history})
 
-main_obs_next, main_reward_this_step, main_term_this_step, main_trunc_this_step = None, 0.0, False, False
+(
+main_obs_next,
+main_reward_this_step,
+main_term_this_step,
+main_trunc_this_step,
+) = (None, 0.0, False, False)
 try:
-main_obs_next, main_reward_this_step, main_term_this_step, main_trunc_this_step, _ = ep.env.step(
-chosen_env_action
-)
+(
+main_obs_next,
+main_reward_this_step,
+main_term_this_step,
+main_trunc_this_step,
+_,
+) = ep.env.step(chosen_env_action)
 ep.actions.append(chosen_env_action)
 ep.step_rewards.append(main_reward_this_step)
 ep.num_steps += 1
 
-if main_obs_next:
+if main_obs_next:
 ep.message_history.append(
 {
 "role": "environment",
@@ -591,10 +626,10 @@ class BlackjackEnv(BaseEnv):
 f"Error stepping MAIN environment with chosen action {chosen_env_action}: {e_main_step}",
 exc_info=True,
 )
-main_term_this_step, main_trunc_this_step = True, True
+main_term_this_step, main_trunc_this_step = True, True
 
 is_episode_terminal_this_step = main_term_this_step or main_trunc_this_step
 return current_step_data, is_episode_terminal_this_step
 
 async def score(
@@ -632,14 +667,15 @@ class BlackjackEnv(BaseEnv):
 - List of BlackjackScoredDataGroup objects: Contains the collected data for each step of the trajectory.
 - List of Tuple[int, int]: Backlog items (always empty in this implementation).
 """
-seed, _ = item
-G_config = self.config.group_size
+seed, _ = item
+G_config = self.config.group_size
 max_turns = self.config.max_turns or 5
 trajectory_data_for_trainer: List[BlackjackScoredDataGroup] = []
 
 logger.info(
-f"[Collect Trajectories Seed: {seed}] Starting new trajectory. Group size G={G_config}, Max turns={max_turns}."
+f"[Collect Trajectories Seed: {seed}] Starting new trajectory. "
+f"Group size G={G_config}, Max turns={max_turns}."
 )
 
 try:
@@ -649,31 +685,38 @@ class BlackjackEnv(BaseEnv):
 f"[Collect Trajectories Seed: {seed}] Fatal error creating/getting episode: {e}",
 exc_info=True,
 )
-return [], []
+return [], []
 
 for turn_idx in range(max_turns):
-logger.debug(f"[Collect Trajectories Seed: {seed}] Attempting turn {turn_idx + 1}/{max_turns}.")
-step_data, is_terminal_this_step = await self._next_step(ep, turn_idx, max_turns)
+logger.debug(
+f"[Collect Trajectories Seed: {seed}] Attempting turn {turn_idx + 1}/{max_turns}."
+)
+step_data, is_terminal_this_step = await self._next_step(
+ep, turn_idx, max_turns
+)
 
 if step_data:
 trajectory_data_for_trainer.append(step_data)
 else:
 logger.error(
-f"[Collect Trajectories Seed: {seed}] Turn {turn_idx + 1} failed to produce data. Terminating episode."
+f"[Collect Trajectories Seed: {seed}] Turn {turn_idx + 1} failed to produce data."
+" Terminating episode."
 )
-is_terminal_this_step = True
+is_terminal_this_step = True
 
 if is_terminal_this_step:
-final_reward_at_termination = sum(ep.step_rewards) if ep.step_rewards else 0.0
+final_reward_at_termination = (
+sum(ep.step_rewards) if ep.step_rewards else 0.0
+)
 logger.info(
 f"[Collect Trajectories Seed: {seed}] Episode ended at turn {turn_idx + 1}. "
 f"Reason: step reported terminal. Total raw env reward: {final_reward_at_termination:.2f}"
 )
-break
-else:
+break
+else:
 logger.info(
-f"[Collect Trajectories Seed: {seed}] Episode reached max_turns ({max_turns})."
+f"[Collect Trajectories Seed: {seed}] Episode reached max_turns ({max_turns})."
 )
 
 final_raw_reward = sum(ep.step_rewards) if ep.step_rewards else 0.0
@@ -684,17 +727,19 @@ class BlackjackEnv(BaseEnv):
 f"Final raw reward: {final_raw_reward:.2f}"
 )
 
-if ep:
+if ep:
 game_outcome = 0
-if final_raw_reward > 0: game_outcome = 1
-elif final_raw_reward < 0: game_outcome = -1
+if final_raw_reward > 0:
+game_outcome = 1
+elif final_raw_reward < 0:
+game_outcome = -1
 episode_summary_metrics = {
 "seed": ep.seed,
 "total_reward": final_raw_reward,
-"num_steps": ep.num_steps,
+"num_steps": ep.num_steps,
 "num_correct_actions": ep.num_correct_actions,
-"num_total_actions": ep.num_total_actions,
+"num_total_actions": ep.num_total_actions,
 "game_outcome": game_outcome,
 }
 self.completed_episode_metrics_buffer.append(episode_summary_metrics)
@@ -704,30 +749,37 @@ class BlackjackEnv(BaseEnv):
 if seed in self.episodes:
 try:
-if hasattr(self.episodes[seed], 'env') and self.episodes[seed].env is not None:
+if (
+hasattr(self.episodes[seed], "env")
+and self.episodes[seed].env is not None
+):
 self.episodes[seed].env.close()
 except Exception as e_close:
 logger.warning(
 f"[Collect Trajectories Seed: {seed}] Exception closing environment for episode: {e_close}",
-exc_info=True
+exc_info=True,
 )
-del self.episodes[seed]
+del self.episodes[seed]
 
 if not trajectory_data_for_trainer:
-logger.warning(f"[Collect Trajectories Seed: {seed}] Collected an empty trajectory (no valid steps).")
-return [], []
+logger.warning(
+f"[Collect Trajectories Seed: {seed}] Collected an empty trajectory (no valid steps)."
+)
+return [], []
 
 limited_trajectory_data = ensure_trajectory_token_limit(
 trajectory_data_for_trainer,
 self.tokenizer,
 self.config.max_trajectory_tokens,
 )
 
-if not limited_trajectory_data:
-logger.warning(f"[Collect Trajectories Seed: {seed}] Trajectory became empty after token limiting.")
-return [], []
-return limited_trajectory_data, []
+if not limited_trajectory_data:
+logger.warning(
+f"[Collect Trajectories Seed: {seed}] Trajectory became empty after token limiting."
+)
+return [], []
+return limited_trajectory_data, []
 
 async def setup(self):
 pass
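
For context, the logic this commit reformats does two things per turn: it recreates the episode's state in a throwaway env for each of the G sampled alternatives (gymnasium envs aren't safely copyable, so the code resets with the episode seed and replays the action history), and it scores each alternative with a one-step advantage, r_i + V(s'_i) - V(s_t), stepping the real env with the best one (shorter responses preferred on ties, per secondary_lower_is_better=True). The standalone sketch below approximates that logic under stated assumptions; replay_to_state and select_best_alternative are illustrative names, not this repository's API.

# Illustrative sketch only -- not code from this commit. Approximates the
# simulate-by-replay pattern and the advantage-based selection in _next_step.
from typing import List, Optional, Tuple

import gymnasium


def replay_to_state(env_name: str, seed: int, actions: List[int]):
    """Fresh env, same state: reset with the episode seed, replay the history.

    Determinism flows from the seeded RNG, so replaying the same actions
    reproduces the same trajectory (the pattern _next_step relies on).
    """
    env = gymnasium.make(env_name)
    obs, _ = env.reset(seed=seed)
    done = False
    for action in actions:
        obs, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            done = True  # premature end during replay: state mismatch
            break
    return env, obs, done


def select_best_alternative(
    combined_rewards: List[float],
    value_next: List[float],
    token_lengths: List[int],
    value_t: Optional[float],
) -> Tuple[int, float]:
    """Advantage of alternative i is r_i + V(s'_i) - V(s_t).

    A missing baseline maps to -inf (mirroring the sentinel above); ties on
    advantage are broken by preferring the shorter response.
    """
    advantages = [
        (r + v_next - value_t) if value_t is not None else -float("inf")
        for r, v_next in zip(combined_rewards, value_next)
    ]
    # Highest advantage first; lower token length wins ties.
    best_idx = min(
        range(len(advantages)),
        key=lambda i: (-advantages[i], token_lengths[i]),
    )
    return best_idx, advantages[best_idx]


# Example: alternative 1 wins on advantage (0.0 + 0.4 - 0.2 = 0.2).
idx, adv = select_best_alternative(
    combined_rewards=[0.0, 0.0, -1.0],
    value_next=[0.1, 0.4, 0.0],
    token_lengths=[120, 95, 80],
    value_t=0.2,
)
assert idx == 1 and abs(adv - 0.2) < 1e-9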