diff --git a/environments/game_environments/gymnasium/blackjack_env_thinking.py b/environments/game_environments/gymnasium/blackjack_env_thinking.py
index 85396a41..bb49a64a 100644
--- a/environments/game_environments/gymnasium/blackjack_env_thinking.py
+++ b/environments/game_environments/gymnasium/blackjack_env_thinking.py
@@ -28,7 +28,7 @@ from atroposlib.envs.base import (
     ScoredDataGroup,
 )
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
-from atroposlib.utils.message_history_utils import truncate_thinking
+from atroposlib.utils.message_history_utils import truncate_thinking, ensure_trajectory_token_limit
 from atroposlib.utils.tool_call_parser import parse_tool_call
 from atroposlib.utils.best_of_n_selection import select_best_index
 
@@ -142,7 +142,6 @@ class BlackjackEnv(BaseEnv):
         env_reward: float,
         response_text: str,
         parsed_action: int,
-        episode_seed: int,
     ) -> float:
         """
         Calculates a score for a single agent response based purely on environment reward
@@ -419,7 +418,7 @@ class BlackjackEnv(BaseEnv):
                 alt_is_terminal.append(term_i or trunc_i)
 
                 combined_reward_i = self._score_response(
-                    raw_env_reward_i, full_agent_response, parsed_action, ep.seed
+                    raw_env_reward_i, full_agent_response, parsed_action
                 )
                 alt_combined_rewards.append(combined_reward_i)
 
diff --git a/environments/game_environments/gymnasium/blackjack_local_server.py b/environments/game_environments/gymnasium/blackjack_local_server.py
index 9a6b37a0..53f3606c 100644
--- a/environments/game_environments/gymnasium/blackjack_local_server.py
+++ b/environments/game_environments/gymnasium/blackjack_local_server.py
@@ -6,7 +6,7 @@ import random
 from dotenv import load_dotenv
 
 from atroposlib.envs.base import EvalHandlingEnum, OpenaiConfig
-from environments.game_environments.gymnasium.blackjack_env import (
+from environments.game_environments.gymnasium.blackjack_env_thinking import (
     BlackjackEnv,
     BlackjackEnvConfig,
 )
@@ -76,9 +76,11 @@ async def main():
         _ = env._get_or_create_episode(seed)
 
-        result_trajectory = await env.collect_trajectory(seed)
+        result_trajectories_tuple = await env.collect_trajectories((seed, 0))
+        result_trajectory = result_trajectories_tuple[0]
+
         logger.info(
-            f"Trajectory collection complete with {len(result_trajectory)} steps."
+            f"Trajectory collection complete with {len(result_trajectory)} groups/steps."
         )
 
         episode_summary = None
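
The final hunk above is the behavioral change: the local server migrates from the per-episode `collect_trajectory(seed)` entry point to the group-based `collect_trajectories` API, which takes an item tuple and returns a tuple of results. A minimal sketch of the new calling convention, assuming the item argument is a `(seed, group_index)`-style tuple and that element 0 of the returned tuple holds the scored groups; the `env` parameter and the `run_episode` helper are hypothetical scaffolding for illustration, not part of the diff:

```python
import logging

logger = logging.getLogger(__name__)


async def run_episode(env, seed: int) -> None:
    # Old API (removed in this diff):
    #   result_trajectory = await env.collect_trajectory(seed)
    #
    # New API: pass an item tuple, assumed here to be (seed, group_index),
    # and unpack the first element of the returned tuple.
    result_trajectories_tuple = await env.collect_trajectories((seed, 0))
    result_trajectory = result_trajectories_tuple[0]

    logger.info(
        f"Trajectory collection complete with {len(result_trajectory)} groups/steps."
    )
```

Relatedly, `_score_response` drops its unused `episode_seed` parameter, so the call site in the third hunk now passes only the raw environment reward, the full agent response, and the parsed action.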