Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-28 17:29:30 +00:00)
BaseEnvConfig subclass for experimental vars
This commit is contained in:
parent 2aa950a5a8
commit a4cdf80e4a

1 changed file with 77 additions and 33 deletions
@@ -1,14 +1,3 @@
-# Negative reward applied when the first mismatched tool-call causes early termination.
-WRONG_CALL_PENALTY = -0.2
-# Hard cap on how many new tokens the model may generate in a single turn.
-MAX_GEN_PER_TURN = 1024
-# Hard cap on how many tool-call turns we will actually roll out
-MAX_TOOL_CALL_TURNS = 2
-# Whether to validate that all GPT messages have <think> blocks [useful when non-tool call gpt messages are inserted]
-VALIDATE_THINK_BLOCKS = True
-# Turn-level advantage coefficient (λ in MT-GRPO paper)
-# Paper implementation uses 1.0, but we can experiment with different values
-TURN_LEVEL_ADVANTAGE_LAMBDA = 0.5 # Configurable: try 0.1, 0.5, 1.0
 
 """
 Multi-Turn Tool-Calling Environment with Turn-Level Advantages
@@ -40,10 +29,10 @@ import numpy as np
 from typing import Dict, List, Optional, Tuple, Union
 from collections import Counter
 
-
 import wandb
 from datasets import load_dataset
 from tqdm.asyncio import tqdm_asyncio
+from pydantic import Field
 
 from atroposlib.envs.base import (
     APIServerConfig,
@@ -55,6 +44,37 @@ from atroposlib.envs.base import (
 )
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
+# Easy-to-change constants for experimentation - modify these for quick testing
+WRONG_CALL_PENALTY = -0.2
+MAX_GEN_PER_TURN = 1024
+MAX_TOOL_CALL_TURNS = 3
+VALIDATE_THINK_BLOCKS = True
+TURN_LEVEL_ADVANTAGE_LAMBDA = 0.5 # Paper uses 1.0, experiment with 0.1, 0.5, 1.0
+
+
+class MTGRPOEnvConfig(BaseEnvConfig):
+    """Configuration for Multi-Turn Tool Calling with Turn-Level Advantages Environment."""
+    max_tool_call_turns: int = Field(
+        default=2,
+        description="Hard cap on how many tool-call turns we will actually roll out"
+    )
+    validate_think_blocks: bool = Field(
+        default=True,
+        description="Whether to validate that all GPT messages have <think> blocks [useful when non-tool call gpt messages are inserted]"
+    )
+    max_gen_per_turn: int = Field(
+        default=1024,
+        description="Hard cap on how many new tokens the model may generate in a single turn"
+    )
+    wrong_call_penalty: float = Field(
+        default=-0.2,
+        description="Negative reward applied when the first mismatched tool-call causes early termination"
+    )
+    turn_level_advantage_lambda: float = Field(
+        default=0.5,
+        description="Turn-level advantage coefficient (λ in MT-GRPO paper). Paper implementation uses 1.0, but we can experiment with different values like 0.1, 0.5, 1.0"
+    )
+
 
 system_prompt = (
     "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
     "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
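A quick sketch of how the new config surface might be exercised on its own (not part of this commit): it assumes MTGRPOEnvConfig from the hunk above is in scope and that any remaining required BaseEnvConfig fields have defaults; the override values are illustrative only.

```python
# Hypothetical usage sketch -- assumes MTGRPOEnvConfig (defined above) is importable
# and that the remaining BaseEnvConfig fields have usable defaults.
cfg = MTGRPOEnvConfig(
    tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
    group_size=16,
    max_tool_call_turns=3,            # override the Field default of 2
    turn_level_advantage_lambda=1.0,  # match the MT-GRPO paper instead of 0.5
)

# Pydantic validates and exposes the experimental knobs as typed attributes.
print(cfg.max_tool_call_turns)   # 3
print(cfg.wrong_call_penalty)    # -0.2 (field default)
```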
@@ -118,7 +138,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
 
     def __init__(
         self,
-        config: BaseEnvConfig,
+        config: MTGRPOEnvConfig,
         server_configs: List[APIServerConfig],
        slurm: bool = True,
        testing: bool = False,
@@ -142,8 +162,8 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         self.iter = 0
 
     @classmethod
-    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
-        env_cfg = BaseEnvConfig(
+    def config_init(cls) -> Tuple[MTGRPOEnvConfig, List[APIServerConfig]]:
+        env_cfg = MTGRPOEnvConfig(
             tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
             group_size=16,
             use_wandb=True,
@@ -156,6 +176,12 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
             wandb_name="multiturn_tool_use_turnlevel_advantage",
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
             eval_limit_ratio=0.1,
+            # Override config defaults with experimental constants
+            wrong_call_penalty=WRONG_CALL_PENALTY,
+            max_gen_per_turn=MAX_GEN_PER_TURN,
+            max_tool_call_turns=MAX_TOOL_CALL_TURNS,
+            validate_think_blocks=VALIDATE_THINK_BLOCKS,
+            turn_level_advantage_lambda=TURN_LEVEL_ADVANTAGE_LAMBDA,
         )
         server_cfgs = [
             APIServerConfig(
@@ -204,7 +230,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         the answer is the list of function_call JSONs (canonical string).
         Each turn can have multiple tool calls.
 
-        We only keep those samples that contain = MAX_TOOL_CALL_TURNS separate messages with <tool_call>.
+        We only keep those samples that contain = config.max_tool_call_turns separate messages with <tool_call>.
         """
         target = self.train_items if is_train else self.test_items
         before_len = len(target)
@@ -227,7 +253,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
                 continue
 
             # Optional: Validate <think> blocks in gpt messages if enabled
-            if VALIDATE_THINK_BLOCKS:
+            if self.config.validate_think_blocks:
                 gpt_messages = [msg for msg in conv if msg["from"] in ("gpt", "assistant")]
                 if not all("<think>" in msg["value"].lower() for msg in gpt_messages):
                     continue
@@ -303,7 +329,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
             while len(inter_turns) < max(0, len(expected_calls_by_turn) - 1):
                 inter_turns.append([])
 
-            if tool_call_turns == MAX_TOOL_CALL_TURNS:
+            if tool_call_turns == self.config.max_tool_call_turns:
                 target.append((tuple(running_msgs), expected_calls_by_turn, inter_turns))
 
         print(f"[prep_items] {'train' if is_train else 'test'}: added {len(target)-before_len} items.")
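To make the filter above concrete, a small self-contained sketch (not from the diff) of the counting rule it implies: a sample is kept only when the number of assistant messages containing <tool_call> equals max_tool_call_turns. The helper name and the toy conversation are hypothetical; the "from"/"value" message layout matches the dataset format used elsewhere in this file.

```python
# Hypothetical helper mirroring the filter above.
def count_tool_call_turns(conv):
    return sum(
        1
        for msg in conv
        if msg["from"] in ("gpt", "assistant") and "<tool_call>" in msg["value"]
    )

toy_conv = [
    {"from": "human", "value": "Book a table, then text me the confirmation."},
    {"from": "gpt", "value": "<think>...</think>\n<tool_call>{...}</tool_call>"},
    {"from": "tool", "value": '{"status": "booked"}'},
    {"from": "gpt", "value": "<think>...</think>\n<tool_call>{...}</tool_call>"},
]
keep = count_tool_call_turns(toy_conv) == 2  # kept only when max_tool_call_turns == 2
```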
@@ -325,7 +351,13 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         """
         turn_rewards = []
 
-        for turn_idx, (response, pred_turn, expected_turn) in enumerate(zip(responses_by_turn, pred_calls_by_turn, expected_calls_by_turn)):
+        # Only iterate over the turns that this rollout actually completed
+        num_actual_turns = min(len(responses_by_turn), len(pred_calls_by_turn), len(expected_calls_by_turn))
+
+        for turn_idx in range(num_actual_turns):
+            response = responses_by_turn[turn_idx] if turn_idx < len(responses_by_turn) else ""
+            pred_turn = pred_calls_by_turn[turn_idx] if turn_idx < len(pred_calls_by_turn) else []
+            expected_turn = expected_calls_by_turn[turn_idx]
             # Turn-level reward components
             turn_reward = 0.0
 
@@ -364,7 +396,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
 
             # Apply mismatch penalty if needed
             if pred_turn and pred_turn[-1] == "__MISMATCH__":
-                turn_reward += WRONG_CALL_PENALTY # This is negative
+                turn_reward += self.config.wrong_call_penalty # This is negative
 
             turn_rewards.append(turn_reward)
 
@@ -392,12 +424,22 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         Returns:
             List of advantages for each rollout [num_rollouts x num_turns]
         """
-        # Compute standardized turn advantages (A_T)
-        turn_advantages_batch = []
-        num_turns = len(turn_rewards_batch[0]) if turn_rewards_batch else 0
+        if not turn_rewards_batch:
+            return []
 
-        for turn_idx in range(num_turns):
-            turn_rewards_for_this_turn = [rewards[turn_idx] for rewards in turn_rewards_batch]
+        # Find the maximum number of turns across all rollouts
+        max_turns = max(len(rewards) for rewards in turn_rewards_batch)
+
+        # Pad shorter reward lists with 0.0 for terminated rollouts
+        padded_turn_rewards_batch = []
+        for rewards in turn_rewards_batch:
+            padded_rewards = rewards + [0.0] * (max_turns - len(rewards))
+            padded_turn_rewards_batch.append(padded_rewards)
+
+        # Compute standardized turn advantages (A_T) for each turn
+        turn_advantages_batch = []
+        for turn_idx in range(max_turns):
+            turn_rewards_for_this_turn = [rewards[turn_idx] for rewards in padded_turn_rewards_batch]
             mean_turn_reward = np.mean(turn_rewards_for_this_turn)
             std_turn_reward = np.std(turn_rewards_for_this_turn)
             if std_turn_reward == 0:
@@ -414,14 +456,16 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
 
         outcome_advantages = [(r - mean_outcome_reward) / std_outcome_reward for r in outcome_rewards_batch]
 
-        # Combine according to MT-GRPO formula
+        # Combine according to MT-GRPO formula, but only for actual turns (not padded ones)
         mt_grpo_advantages = []
         for rollout_idx in range(len(turn_rewards_batch)):
             rollout_advantages = []
-            for turn_idx in range(num_turns):
-                if turn_idx < num_turns - 1: # Not the last turn
+            actual_num_turns = len(turn_rewards_batch[rollout_idx]) # Original length before padding
+
+            for turn_idx in range(actual_num_turns):
+                if turn_idx < actual_num_turns - 1: # Not the last turn
                     # A_T_i + λ * A_O_i
-                    advantage = turn_advantages_batch[turn_idx][rollout_idx] + TURN_LEVEL_ADVANTAGE_LAMBDA * outcome_advantages[rollout_idx]
+                    advantage = turn_advantages_batch[turn_idx][rollout_idx] + self.config.turn_level_advantage_lambda * outcome_advantages[rollout_idx]
                 else: # Last turn
                     # A_O_i only
                     advantage = outcome_advantages[rollout_idx]
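For reference, a compact numpy sketch (not from the diff) of the padding, per-turn standardization, and the combination rule shown in the two hunks above: A_T_i + λ·A_O_i for every completed turn except the last, A_O_i alone for the final turn. All names are local to this example, and an epsilon stands in where the real code branches on std == 0.

```python
import numpy as np

def sketch_mt_grpo_advantages(turn_rewards_batch, outcome_rewards, lam=0.5, eps=1e-8):
    # turn_rewards_batch: per-turn rewards per rollout (ragged when a rollout ended early)
    # outcome_rewards: one scalar outcome reward per rollout; lam ~ turn_level_advantage_lambda
    max_turns = max(len(r) for r in turn_rewards_batch)
    padded = np.array([r + [0.0] * (max_turns - len(r)) for r in turn_rewards_batch])

    # Standardize per turn across the group (A_T) and per rollout outcome (A_O).
    a_t = (padded - padded.mean(axis=0)) / (padded.std(axis=0) + eps)
    out = np.array(outcome_rewards, dtype=float)
    a_o = (out - out.mean()) / (out.std() + eps)

    advantages = []
    for i, rewards in enumerate(turn_rewards_batch):
        n = len(rewards)  # only the turns this rollout actually completed
        adv = [a_t[i, t] + lam * a_o[i] if t < n - 1 else a_o[i] for t in range(n)]
        advantages.append(adv)
    return advantages

# Group of two rollouts; the second terminated after one turn.
print(sketch_mt_grpo_advantages([[0.5, 1.0], [0.2]], [1.0, 0.0], lam=0.5))
```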
@@ -524,7 +568,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         mismatch_penalty = 0.0
         if pred_calls and pred_calls[-1] == "__MISMATCH__":
             pred_calls = pred_calls[:-1]
-            mismatch_penalty = WRONG_CALL_PENALTY
+            mismatch_penalty = self.config.wrong_call_penalty
         correct = sum(
             1 for p, e in zip(pred_calls, exp_jsons) if _json_objects_match(p, e)
         )
@@ -744,12 +788,12 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
         num_rollouts = self.config.group_size
         contexts: List[List[Dict[str, str]]] = [list(base_ctx) for _ in range(num_rollouts)]
         # Track predictions by turn
-        preds_by_turn: List[List[List]] = [[[] for _ in range(MAX_TOOL_CALL_TURNS)] for _ in range(num_rollouts)]
+        preds_by_turn: List[List[List]] = [[[] for _ in range(self.config.max_tool_call_turns)] for _ in range(num_rollouts)]
         # Track responses by turn for reward computation
         responses_by_turn: List[List[str]] = [[] for _ in range(num_rollouts)]
         active = [True] * num_rollouts
 
-        max_turns = min(len(expected_calls_by_turn), MAX_TOOL_CALL_TURNS)
+        max_turns = min(len(expected_calls_by_turn), self.config.max_tool_call_turns)
 
         for turn_idx in range(max_turns):
             print(f"[collect_trajectories] Beginning turn {turn_idx+1}/{max_turns} for this group")
@@ -762,7 +806,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
 
             max_prompt_len = max(len(p) for p in prompts)
             max_gen = min(
-                MAX_GEN_PER_TURN,
+                self.config.max_gen_per_turn,
                 max(1, self.config.max_token_length - max_prompt_len),
             )
 
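A worked instance of the per-turn generation budget in the last hunk, using an assumed max_token_length of 4096 and an assumed prompt length (the real values come from BaseEnvConfig and the live rollout, respectively):

```python
# Illustrative numbers only: a long prompt makes the context window the binding
# limit; the per-turn cap of 1024 binds only when more room is left.
max_gen_per_turn = 1024
max_token_length = 4096   # assumed value, set on the config elsewhere
max_prompt_len = 3500     # assumed current prompt length in tokens
max_gen = min(max_gen_per_turn, max(1, max_token_length - max_prompt_len))
assert max_gen == 596
```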