mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00

linting

This commit is contained in:
parent 1a7c0294fa
commit d8ab1a6758

1 changed file with 130 additions and 78 deletions
@@ -15,7 +15,7 @@ import json
 import logging
 import random
 import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

 import gymnasium
 from tqdm.asyncio import tqdm_asyncio
@@ -309,7 +309,9 @@ class BlackjackEnv(BaseEnv):
         if sim_env is not None:
             sim_env.close()

-    async def _next_step(self, ep: EpisodeState, current_turn: int, max_turns: int) -> Tuple[Optional[BlackjackScoredDataGroup], bool]:
+    async def _next_step(
+        self, ep: EpisodeState, current_turn: int, max_turns: int
+    ) -> Tuple[Optional[BlackjackScoredDataGroup], bool]:
         """Process one step/turn of an episode.

         This involves estimating current state value, sampling multiple (G) responses from the LLM,
@@ -348,7 +350,7 @@ class BlackjackEnv(BaseEnv):
                 f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Error estimating V(s_t): {e_vt}",
                 exc_info=True,
             )
-            return None, True # Indicate error and episode termination
+            return None, True  # Indicate error and episode termination

         messages_for_llm = current_state_messages.copy()
         agent_prompt_content = "<think>\n" if self.config.thinking_active else ""
@@ -362,13 +364,13 @@ class BlackjackEnv(BaseEnv):
                     f"Expected {G} responses, got {len(responses) if responses else 0}. "
                     f"Aborting step."
                 )
-                return None, True # Indicate error and episode termination
+                return None, True  # Indicate error and episode termination
         except Exception as e_sample:
             logger.error(
                 f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Error sampling responses: {e_sample}",
                 exc_info=True,
             )
-            return None, True # Indicate error and episode termination
+            return None, True  # Indicate error and episode termination

         alt_full_responses: List[str] = []
         alt_parsed_actions: List[int] = []
@@ -389,17 +391,20 @@ class BlackjackEnv(BaseEnv):
             parsed_action = self._parse_tool_call(full_agent_response)
             alt_parsed_actions.append(parsed_action)

-            env_action = parsed_action if parsed_action != -1 else 0 # Default to stick on parse error
+            env_action = (
+                parsed_action if parsed_action != -1 else 0
+            )  # Default to stick on parse error
             alt_env_actions.append(env_action)

             sim_env_i = None
             raw_env_reward_i = 0.0
             term_i, trunc_i = False, False
             next_state_msgs_i = []
-            sim_obs_next_i = None
+            sim_obs_next_i = None

             try:
                 sim_env_i = gymnasium.make(self.config.env_name)
                 # replay env to same state as current episode
                 _, _ = sim_env_i.reset(seed=ep.seed)
                 for prev_action_idx, prev_action in enumerate(ep.actions):
                     _, _, term_replay, trunc_replay, _ = sim_env_i.step(prev_action)
@@ -409,11 +414,11 @@ class BlackjackEnv(BaseEnv):
                         f"Sim env for alternative {i} terminated prematurely during history replay "
                         f"(action {prev_action_idx+1}/{len(ep.actions)}). State mismatch or unexpected termination."
                     )
-                    term_i, trunc_i = True, True
-                    raw_env_reward_i = 0.0
+                    term_i, trunc_i = True, True
+                    raw_env_reward_i = 0.0
                     break

-                if not (term_i or trunc_i):
-                    sim_obs_next_i, raw_env_reward_i, term_i, trunc_i, _ = sim_env_i.step(env_action)
+                if not (term_i or trunc_i):
+                    sim_obs_next_i, raw_env_reward_i, term_i, trunc_i, _ = (
+                        sim_env_i.step(env_action)
+                    )
@@ -429,20 +434,18 @@ class BlackjackEnv(BaseEnv):
                 current_state_plus_response_i = current_state_messages + [
                     {"role": "agent", "content": full_agent_response}
                 ]
-                if sim_obs_next_i is not None and not (term_i or trunc_i):
+                if sim_obs_next_i is not None and not (term_i or trunc_i):
                     next_state_msgs_i = current_state_plus_response_i + [
                         {
                             "role": "environment",
                             "content": self._format_observation(sim_obs_next_i),
                         }
                     ]
-                else:
+                else:
                     next_state_msgs_i = current_state_plus_response_i
                 alt_next_state_msgs.append(next_state_msgs_i)

-                tokenized_i = tokenize_for_trainer(
-                    self.tokenizer, next_state_msgs_i
-                )
+                tokenized_i = tokenize_for_trainer(self.tokenizer, next_state_msgs_i)
                 alt_tokens.append(tokenized_i["tokens"])
                 alt_masks.append(tokenized_i["masks"])

@@ -452,20 +455,20 @@ class BlackjackEnv(BaseEnv):
                     f"Error simulating action {env_action} for alternative: {e_sim}",
                     exc_info=True,
                 )
-                alt_raw_rewards.append(0.0)
-                alt_combined_rewards.append(-1.0)
+                alt_raw_rewards.append(0.0)
+                alt_combined_rewards.append(-1.0)
                 alt_next_state_msgs.append(
                     current_state_messages
-                    + [{"role": "agent", "content": full_agent_response}]
+                    + [{"role": "agent", "content": full_agent_response}]
                 )
-                alt_is_terminal.append(True)
-                alt_tokens.append([])
+                alt_is_terminal.append(True)
+                alt_tokens.append([])
                 alt_masks.append([])
             finally:
                 if sim_env_i:
                     sim_env_i.close()

-        alt_value_next: List[float] = []
+        alt_value_next: List[float] = []
         for i in range(G):
             if not alt_is_terminal[i]:
                 try:
@@ -481,12 +484,16 @@ class BlackjackEnv(BaseEnv):
                         f"Error estimating V(s') for alternative: {e_vn}",
                         exc_info=True,
                     )
-                    alt_value_next.append(0.0)
+                    alt_value_next.append(0.0)
             else:
-                alt_value_next.append(0.0)
+                alt_value_next.append(0.0)

         for i in range(G):
-            if i < len(alt_combined_rewards) and i < len(alt_value_next) and value_t is not None:
+            if (
+                i < len(alt_combined_rewards)
+                and i < len(alt_value_next)
+                and value_t is not None
+            ):
                 advantage_i = alt_combined_rewards[i] + alt_value_next[i] - value_t
                 alt_advantages.append(advantage_i)
                 logger.debug(
@@ -500,12 +507,14 @@ class BlackjackEnv(BaseEnv):
                     f"Skipping advantage calculation due to missing data or value_t. "
                     f"len(alt_combined_rewards)={len(alt_combined_rewards)}, len(alt_value_next)={len(alt_value_next)}"
                 )
-                alt_advantages.append(-float('inf'))
+                alt_advantages.append(-float("inf"))

         if not (
-            len(alt_tokens) == G and len(alt_masks) == G and
-            len(alt_advantages) == G and len(alt_next_state_msgs) == G and
-            len(alt_parsed_actions) == G
+            len(alt_tokens) == G
+            and len(alt_masks) == G
+            and len(alt_advantages) == G
+            and len(alt_next_state_msgs) == G
+            and len(alt_parsed_actions) == G
         ):
             logger.error(
                 f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
@@ -520,7 +529,7 @@ class BlackjackEnv(BaseEnv):
             seed=ep.seed,
             tokens=alt_tokens,
             masks=alt_masks,
-            scores=alt_advantages,
+            scores=alt_advantages,
             messages=alt_next_state_msgs,
             parsed_actions=alt_parsed_actions,
         )
@@ -534,8 +543,16 @@ class BlackjackEnv(BaseEnv):
             secondary_lower_is_better=True,
         )

-        chosen_advantage_for_log = alt_advantages[best_advantage_idx] if best_advantage_idx < len(alt_advantages) else "N/A"
-        chosen_token_length_for_log = alt_token_lengths[best_advantage_idx] if best_advantage_idx < len(alt_token_lengths) else "N/A"
+        chosen_advantage_for_log = (
+            alt_advantages[best_advantage_idx]
+            if best_advantage_idx < len(alt_advantages)
+            else "N/A"
+        )
+        chosen_token_length_for_log = (
+            alt_token_lengths[best_advantage_idx]
+            if best_advantage_idx < len(alt_token_lengths)
+            else "N/A"
+        )
         logger.debug(
             f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] "
             f"Selected Alt {best_advantage_idx} "
@@ -544,10 +561,21 @@ class BlackjackEnv(BaseEnv):
             f"from {G} alternatives."
         )

-        chosen_env_action = alt_env_actions[best_advantage_idx] if best_advantage_idx < len(alt_env_actions) else 0
-        chosen_full_response = alt_full_responses[best_advantage_idx] if best_advantage_idx < len(alt_full_responses) else ""
-        chosen_parsed_action = alt_parsed_actions[best_advantage_idx] if best_advantage_idx < len(alt_parsed_actions) else -1
+        chosen_env_action = (
+            alt_env_actions[best_advantage_idx]
+            if best_advantage_idx < len(alt_env_actions)
+            else 0
+        )
+        chosen_full_response = (
+            alt_full_responses[best_advantage_idx]
+            if best_advantage_idx < len(alt_full_responses)
+            else ""
+        )
+        chosen_parsed_action = (
+            alt_parsed_actions[best_advantage_idx]
+            if best_advantage_idx < len(alt_parsed_actions)
+            else -1
+        )

         logger.info(
             f"[Next Step Seed: {ep.seed} Turn: {current_turn + 1}] Chosen action to step env: "
@@ -556,29 +584,36 @@ class BlackjackEnv(BaseEnv):
         )

         ep.num_total_actions += 1
-        if chosen_parsed_action != -1:
+        if chosen_parsed_action != -1:
             ep.num_correct_actions += 1

         response_for_history = truncate_thinking(
             chosen_full_response,
             self.tokenizer,
             self.config.max_think_chars_history,
         )
-        ep.message_history.append(
-            {"role": "agent", "content": response_for_history}
-        )
+        ep.message_history.append({"role": "agent", "content": response_for_history})

-        main_obs_next, main_reward_this_step, main_term_this_step, main_trunc_this_step = None, 0.0, False, False
+        (
+            main_obs_next,
+            main_reward_this_step,
+            main_term_this_step,
+            main_trunc_this_step,
+        ) = (None, 0.0, False, False)
         try:
-            main_obs_next, main_reward_this_step, main_term_this_step, main_trunc_this_step, _ = ep.env.step(
-                chosen_env_action
-            )
+            (
+                main_obs_next,
+                main_reward_this_step,
+                main_term_this_step,
+                main_trunc_this_step,
+                _,
+            ) = ep.env.step(chosen_env_action)

             ep.actions.append(chosen_env_action)
             ep.step_rewards.append(main_reward_this_step)
             ep.num_steps += 1

-            if main_obs_next:
+            if main_obs_next:
                 ep.message_history.append(
                     {
                         "role": "environment",
@@ -591,10 +626,10 @@ class BlackjackEnv(BaseEnv):
                 f"Error stepping MAIN environment with chosen action {chosen_env_action}: {e_main_step}",
                 exc_info=True,
             )
-            main_term_this_step, main_trunc_this_step = True, True
+            main_term_this_step, main_trunc_this_step = True, True

         is_episode_terminal_this_step = main_term_this_step or main_trunc_this_step

         return current_step_data, is_episode_terminal_this_step

     async def score(
@@ -632,14 +667,15 @@ class BlackjackEnv(BaseEnv):
            - List of BlackjackScoredDataGroup objects: Contains the collected data for each step of the trajectory.
            - List of Tuple[int, int]: Backlog items (always empty in this implementation).
        """
-        seed, _ = item
-        G_config = self.config.group_size
+        seed, _ = item
+        G_config = self.config.group_size
         max_turns = self.config.max_turns or 5

         trajectory_data_for_trainer: List[BlackjackScoredDataGroup] = []

         logger.info(
-            f"[Collect Trajectories Seed: {seed}] Starting new trajectory. Group size G={G_config}, Max turns={max_turns}."
+            f"[Collect Trajectories Seed: {seed}] Starting new trajectory. "
+            f"Group size G={G_config}, Max turns={max_turns}."
         )

         try:
@@ -649,31 +685,38 @@ class BlackjackEnv(BaseEnv):
                 f"[Collect Trajectories Seed: {seed}] Fatal error creating/getting episode: {e}",
                 exc_info=True,
             )
-            return [], []
+            return [], []

         for turn_idx in range(max_turns):
-            logger.debug(f"[Collect Trajectories Seed: {seed}] Attempting turn {turn_idx + 1}/{max_turns}.")
+            logger.debug(
+                f"[Collect Trajectories Seed: {seed}] Attempting turn {turn_idx + 1}/{max_turns}."
+            )

-            step_data, is_terminal_this_step = await self._next_step(ep, turn_idx, max_turns)
+            step_data, is_terminal_this_step = await self._next_step(
+                ep, turn_idx, max_turns
+            )

             if step_data:
                 trajectory_data_for_trainer.append(step_data)
             else:
                 logger.error(
-                    f"[Collect Trajectories Seed: {seed}] Turn {turn_idx + 1} failed to produce data. Terminating episode."
+                    f"[Collect Trajectories Seed: {seed}] Turn {turn_idx + 1} failed to produce data."
+                    " Terminating episode."
                 )
-                is_terminal_this_step = True
+                is_terminal_this_step = True

             if is_terminal_this_step:
-                final_reward_at_termination = sum(ep.step_rewards) if ep.step_rewards else 0.0
+                final_reward_at_termination = (
+                    sum(ep.step_rewards) if ep.step_rewards else 0.0
+                )
                 logger.info(
                     f"[Collect Trajectories Seed: {seed}] Episode ended at turn {turn_idx + 1}. "
                     f"Reason: step reported terminal. Total raw env reward: {final_reward_at_termination:.2f}"
                 )
-                break
-        else:
+                break
+        else:
             logger.info(
-                f"[Collect Trajectories Seed: {seed}] Episode reached max_turns ({max_turns})."
+                f"[Collect Trajectories Seed: {seed}] Episode reached max_turns ({max_turns})."
             )

         final_raw_reward = sum(ep.step_rewards) if ep.step_rewards else 0.0
@@ -684,17 +727,19 @@ class BlackjackEnv(BaseEnv):
             f"Final raw reward: {final_raw_reward:.2f}"
         )

-        if ep:
+        if ep:
             game_outcome = 0
-            if final_raw_reward > 0: game_outcome = 1
-            elif final_raw_reward < 0: game_outcome = -1
+            if final_raw_reward > 0:
+                game_outcome = 1
+            elif final_raw_reward < 0:
+                game_outcome = -1

             episode_summary_metrics = {
                 "seed": ep.seed,
                 "total_reward": final_raw_reward,
-                "num_steps": ep.num_steps,
+                "num_steps": ep.num_steps,
                 "num_correct_actions": ep.num_correct_actions,
-                "num_total_actions": ep.num_total_actions,
+                "num_total_actions": ep.num_total_actions,
                 "game_outcome": game_outcome,
             }
             self.completed_episode_metrics_buffer.append(episode_summary_metrics)
@@ -704,30 +749,37 @@ class BlackjackEnv(BaseEnv):

         if seed in self.episodes:
             try:
-                if hasattr(self.episodes[seed], 'env') and self.episodes[seed].env is not None:
+                if (
+                    hasattr(self.episodes[seed], "env")
+                    and self.episodes[seed].env is not None
+                ):
                     self.episodes[seed].env.close()
             except Exception as e_close:
                 logger.warning(
                     f"[Collect Trajectories Seed: {seed}] Exception closing environment for episode: {e_close}",
-                    exc_info=True
+                    exc_info=True,
                 )
-            del self.episodes[seed]
+            del self.episodes[seed]

         if not trajectory_data_for_trainer:
-            logger.warning(f"[Collect Trajectories Seed: {seed}] Collected an empty trajectory (no valid steps).")
-            return [], []
+            logger.warning(
+                f"[Collect Trajectories Seed: {seed}] Collected an empty trajectory (no valid steps)."
+            )
+            return [], []

         limited_trajectory_data = ensure_trajectory_token_limit(
             trajectory_data_for_trainer,
             self.tokenizer,
             self.config.max_trajectory_tokens,
         )

-        if not limited_trajectory_data:
-            logger.warning(f"[Collect Trajectories Seed: {seed}] Trajectory became empty after token limiting.")
-            return [], []
-
-        return limited_trajectory_data, []
+        if not limited_trajectory_data:
+            logger.warning(
+                f"[Collect Trajectories Seed: {seed}] Trajectory became empty after token limiting."
+            )
+            return [], []
+
+        return limited_trajectory_data, []

     async def setup(self):
         pass
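Note on what the reformatted code does: each `_next_step` call estimates a value baseline V(s_t) for the current state, samples G candidate responses from the LLM, replays the episode's action history in a fresh simulator per candidate, and scores each candidate with the advantage A_i = alt_combined_rewards[i] + alt_value_next[i] - value_t; the candidate with the highest advantage is then stepped in the main environment, with shorter responses winning ties (secondary_lower_is_better=True). Below is a minimal, self-contained sketch of that selection rule only; the names rewards, values_next, value_t, and token_lengths are illustrative stand-ins rather than the file's actual API, and the real code additionally tokenizes each alternative and packages everything into a BlackjackScoredDataGroup.

# Minimal sketch of the advantage-based selection in _next_step.
# Names and signatures are illustrative stand-ins, not the atropos API.
from typing import List, Tuple


def select_best_alternative(
    rewards: List[float],      # combined reward r_i for each of G alternatives
    values_next: List[float],  # V(s'_i) estimate per alternative (0.0 if terminal)
    value_t: float,            # V(s_t) baseline for the current state
    token_lengths: List[int],  # tie-breaker: prefer shorter responses
) -> Tuple[int, List[float]]:
    # A_i = r_i + V(s'_i) - V(s_t), matching advantage_i in the diff
    advantages = [r + v - value_t for r, v in zip(rewards, values_next)]
    # Highest advantage wins; among ties, the shortest response wins.
    best_idx = min(
        range(len(advantages)),
        key=lambda i: (-advantages[i], token_lengths[i]),
    )
    return best_idx, advantages


if __name__ == "__main__":
    idx, advs = select_best_alternative(
        rewards=[1.0, -1.0, 1.0],
        values_next=[0.0, 0.0, 0.0],
        value_t=0.2,
        token_lengths=[120, 80, 60],
    )
    print(idx, advs)  # -> 2 [0.8, -1.2, 0.8]: tied advantage, shorter response wins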