diff --git a/environments/tool_use_turnlevel_advantage_server.py b/environments/tool_use_turnlevel_advantage_server.py
index f54cee64..c875f965 100644
--- a/environments/tool_use_turnlevel_advantage_server.py
+++ b/environments/tool_use_turnlevel_advantage_server.py
@@ -1,14 +1,3 @@
-# Negative reward applied when the first mismatched tool-call causes early termination.
-WRONG_CALL_PENALTY = -0.2
-# Hard cap on how many new tokens the model may generate in a single turn.
-MAX_GEN_PER_TURN = 1024
-# Hard cap on how many tool-call turns we will actually roll out
-MAX_TOOL_CALL_TURNS = 2
-# Whether to validate that all GPT messages have <think> blocks [useful when non-tool call gpt messages are inserted]
-VALIDATE_THINK_BLOCKS = True
-# Turn-level advantage coefficient (λ in MT-GRPO paper)
-# Paper implementation uses 1.0, but we can experiment with different values
-TURN_LEVEL_ADVANTAGE_LAMBDA = 0.5  # Configurable: try 0.1, 0.5, 1.0
 """
 Multi-Turn Tool-Calling Environment with Turn-Level Advantages

@@ -40,10 +29,10 @@ import numpy as np
 from typing import Dict, List, Optional, Tuple, Union
 from collections import Counter
-
 import wandb
 from datasets import load_dataset
 from tqdm.asyncio import tqdm_asyncio
+from pydantic import Field

 from atroposlib.envs.base import (
     APIServerConfig,
@@ -55,6 +44,37 @@ from atroposlib.envs.base import (
 )
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

+# Easy-to-change constants for experimentation - modify these for quick testing
+WRONG_CALL_PENALTY = -0.2
+MAX_GEN_PER_TURN = 1024
+MAX_TOOL_CALL_TURNS = 3
+VALIDATE_THINK_BLOCKS = True
+TURN_LEVEL_ADVANTAGE_LAMBDA = 0.5  # Paper uses 1.0, experiment with 0.1, 0.5, 1.0
+
+
+class MTGRPOEnvConfig(BaseEnvConfig):
+    """Configuration for Multi-Turn Tool Calling with Turn-Level Advantages Environment."""
+    max_tool_call_turns: int = Field(
+        default=2,
+        description="Hard cap on how many tool-call turns we will actually roll out"
+    )
+    validate_think_blocks: bool = Field(
+        default=True,
+        description="Whether to validate that all GPT messages have <think> blocks [useful when non-tool call gpt messages are inserted]"
+    )
+    max_gen_per_turn: int = Field(
+        default=1024,
+        description="Hard cap on how many new tokens the model may generate in a single turn"
+    )
+    wrong_call_penalty: float = Field(
+        default=-0.2,
+        description="Negative reward applied when the first mismatched tool-call causes early termination"
+    )
+    turn_level_advantage_lambda: float = Field(
+        default=0.5,
+        description="Turn-level advantage coefficient (λ in MT-GRPO paper). Paper implementation uses 1.0, but we can experiment with different values like 0.1, 0.5, 1.0"
+    )
+
+
 system_prompt = (
     "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
     "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
@@ -118,7 +138,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
     def __init__(
         self,
-        config: BaseEnvConfig,
+        config: MTGRPOEnvConfig,
         server_configs: List[APIServerConfig],
         slurm: bool = True,
         testing: bool = False,
@@ -142,8 +162,8 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         self.iter = 0

     @classmethod
-    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
-        env_cfg = BaseEnvConfig(
+    def config_init(cls) -> Tuple[MTGRPOEnvConfig, List[APIServerConfig]]:
+        env_cfg = MTGRPOEnvConfig(
             tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
             group_size=16,
             use_wandb=True,
@@ -156,6 +176,12 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
             wandb_name="multiturn_tool_use_turnlevel_advantage",
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
             eval_limit_ratio=0.1,
+            # Override config defaults with experimental constants
+            wrong_call_penalty=WRONG_CALL_PENALTY,
+            max_gen_per_turn=MAX_GEN_PER_TURN,
+            max_tool_call_turns=MAX_TOOL_CALL_TURNS,
+            validate_think_blocks=VALIDATE_THINK_BLOCKS,
+            turn_level_advantage_lambda=TURN_LEVEL_ADVANTAGE_LAMBDA,
         )
         server_cfgs = [
             APIServerConfig(
@@ -204,7 +230,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         the answer is the list of function_call JSONs (canonical string).
         Each turn can have multiple tool calls.

-        We only keep those samples that contain = MAX_TOOL_CALL_TURNS separate messages with <tool_call>.
+        We only keep those samples that contain = config.max_tool_call_turns separate messages with <tool_call>.
""" target = self.train_items if is_train else self.test_items before_len = len(target) @@ -227,7 +253,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv): continue # Optional: Validate blocks in gpt messages if enabled - if VALIDATE_THINK_BLOCKS: + if self.config.validate_think_blocks: gpt_messages = [msg for msg in conv if msg["from"] in ("gpt", "assistant")] if not all("" in msg["value"].lower() for msg in gpt_messages): continue @@ -303,7 +329,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv): while len(inter_turns) < max(0, len(expected_calls_by_turn) - 1): inter_turns.append([]) - if tool_call_turns == MAX_TOOL_CALL_TURNS: + if tool_call_turns == self.config.max_tool_call_turns: target.append((tuple(running_msgs), expected_calls_by_turn, inter_turns)) print(f"[prep_items] {'train' if is_train else 'test'}: added {len(target)-before_len} items.") @@ -325,7 +351,13 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv): """ turn_rewards = [] - for turn_idx, (response, pred_turn, expected_turn) in enumerate(zip(responses_by_turn, pred_calls_by_turn, expected_calls_by_turn)): + # Only iterate over the turns that this rollout actually completed + num_actual_turns = min(len(responses_by_turn), len(pred_calls_by_turn), len(expected_calls_by_turn)) + + for turn_idx in range(num_actual_turns): + response = responses_by_turn[turn_idx] if turn_idx < len(responses_by_turn) else "" + pred_turn = pred_calls_by_turn[turn_idx] if turn_idx < len(pred_calls_by_turn) else [] + expected_turn = expected_calls_by_turn[turn_idx] # Turn-level reward components turn_reward = 0.0 @@ -364,7 +396,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv): # Apply mismatch penalty if needed if pred_turn and pred_turn[-1] == "__MISMATCH__": - turn_reward += WRONG_CALL_PENALTY # This is negative + turn_reward += self.config.wrong_call_penalty # This is negative turn_rewards.append(turn_reward) @@ -392,12 +424,22 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv): Returns: List of advantages for each rollout [num_rollouts x num_turns] """ - # Compute standardized turn advantages (A_T) - turn_advantages_batch = [] - num_turns = len(turn_rewards_batch[0]) if turn_rewards_batch else 0 + if not turn_rewards_batch: + return [] - for turn_idx in range(num_turns): - turn_rewards_for_this_turn = [rewards[turn_idx] for rewards in turn_rewards_batch] + # Find the maximum number of turns across all rollouts + max_turns = max(len(rewards) for rewards in turn_rewards_batch) + + # Pad shorter reward lists with 0.0 for terminated rollouts + padded_turn_rewards_batch = [] + for rewards in turn_rewards_batch: + padded_rewards = rewards + [0.0] * (max_turns - len(rewards)) + padded_turn_rewards_batch.append(padded_rewards) + + # Compute standardized turn advantages (A_T) for each turn + turn_advantages_batch = [] + for turn_idx in range(max_turns): + turn_rewards_for_this_turn = [rewards[turn_idx] for rewards in padded_turn_rewards_batch] mean_turn_reward = np.mean(turn_rewards_for_this_turn) std_turn_reward = np.std(turn_rewards_for_this_turn) if std_turn_reward == 0: @@ -414,14 +456,16 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv): outcome_advantages = [(r - mean_outcome_reward) / std_outcome_reward for r in outcome_rewards_batch] - # Combine according to MT-GRPO formula + # Combine according to MT-GRPO formula, but only for actual turns (not padded ones) mt_grpo_advantages = [] for rollout_idx in range(len(turn_rewards_batch)): rollout_advantages = [] - for turn_idx in 
-            for turn_idx in range(num_turns):
-                if turn_idx < num_turns - 1:  # Not the last turn
+            actual_num_turns = len(turn_rewards_batch[rollout_idx])  # Original length before padding
+
+            for turn_idx in range(actual_num_turns):
+                if turn_idx < actual_num_turns - 1:  # Not the last turn
                     # A_T_i + λ * A_O_i
-                    advantage = turn_advantages_batch[turn_idx][rollout_idx] + TURN_LEVEL_ADVANTAGE_LAMBDA * outcome_advantages[rollout_idx]
+                    advantage = turn_advantages_batch[turn_idx][rollout_idx] + self.config.turn_level_advantage_lambda * outcome_advantages[rollout_idx]
                 else:  # Last turn
                     # A_O_i only
                     advantage = outcome_advantages[rollout_idx]
@@ -524,7 +568,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         mismatch_penalty = 0.0
         if pred_calls and pred_calls[-1] == "__MISMATCH__":
             pred_calls = pred_calls[:-1]
-            mismatch_penalty = WRONG_CALL_PENALTY
+            mismatch_penalty = self.config.wrong_call_penalty
         correct = sum(
             1 for p, e in zip(pred_calls, exp_jsons) if _json_objects_match(p, e)
         )
@@ -744,12 +788,12 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         num_rollouts = self.config.group_size
         contexts: List[List[Dict[str, str]]] = [list(base_ctx) for _ in range(num_rollouts)]
         # Track predictions by turn
-        preds_by_turn: List[List[List]] = [[[] for _ in range(MAX_TOOL_CALL_TURNS)] for _ in range(num_rollouts)]
+        preds_by_turn: List[List[List]] = [[[] for _ in range(self.config.max_tool_call_turns)] for _ in range(num_rollouts)]
         # Track responses by turn for reward computation
         responses_by_turn: List[List[str]] = [[] for _ in range(num_rollouts)]
         active = [True] * num_rollouts

-        max_turns = min(len(expected_calls_by_turn), MAX_TOOL_CALL_TURNS)
+        max_turns = min(len(expected_calls_by_turn), self.config.max_tool_call_turns)

         for turn_idx in range(max_turns):
             print(f"[collect_trajectories] Beginning turn {turn_idx+1}/{max_turns} for this group")
@@ -762,7 +806,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
             max_prompt_len = max(len(p) for p in prompts)

             max_gen = min(
-                MAX_GEN_PER_TURN,
+                self.config.max_gen_per_turn,
                 max(1, self.config.max_token_length - max_prompt_len),
             )
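For readers skimming the patch, the combination rule that the updated compute_mt_grpo_advantages implements can be summarized in isolation: non-final turns receive the group-standardized turn reward plus λ times the group-standardized outcome reward, the final turn receives the outcome advantage alone, and rollouts that terminated early are padded with 0.0 only when computing per-turn statistics. The sketch below is illustrative and not part of the patch; the names mt_grpo_advantages, lam, and standardize are hypothetical, and the zero-variance fallback (divide by 1.0) is an assumption, since the hunk above is cut off before showing how the environment handles that case.

```python
import numpy as np


def mt_grpo_advantages(turn_rewards_batch, outcome_rewards_batch, lam=0.5):
    """Minimal sketch of the MT-GRPO advantage combination (hypothetical helper).

    turn_rewards_batch: per-turn reward lists, one per rollout; rollouts that
        terminated early have shorter lists.
    outcome_rewards_batch: one scalar outcome reward per rollout.
    lam: turn-level advantage coefficient (λ in the MT-GRPO paper).
    """
    if not turn_rewards_batch:
        return []

    max_turns = max(len(r) for r in turn_rewards_batch)
    # Pad early-terminated rollouts with 0.0 so per-turn statistics line up.
    padded = [r + [0.0] * (max_turns - len(r)) for r in turn_rewards_batch]

    def standardize(xs):
        # Assumed zero-variance handling: fall back to dividing by 1.0.
        mean, std = np.mean(xs), np.std(xs)
        return [(x - mean) / (std if std > 0 else 1.0) for x in xs]

    # A_T: standardize each turn's rewards across the group of rollouts.
    turn_adv = [standardize([r[t] for r in padded]) for t in range(max_turns)]
    # A_O: standardize the outcome rewards across the group.
    outcome_adv = standardize(outcome_rewards_batch)

    advantages = []
    for i, rewards in enumerate(turn_rewards_batch):
        per_turn = []
        for t in range(len(rewards)):  # only the turns this rollout completed
            if t < len(rewards) - 1:
                per_turn.append(turn_adv[t][i] + lam * outcome_adv[i])  # A_T + λ * A_O
            else:
                per_turn.append(outcome_adv[i])  # final turn: A_O only
        advantages.append(per_turn)
    return advantages


# Example call with a group of 3 rollouts, the last of which stopped after one turn:
# mt_grpo_advantages([[1.0, 0.5], [0.0, 0.5], [1.0]], [2.0, 1.0, 0.0], lam=0.5)
```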