BaseEnvConfig subclass for experimental vars

interstellarninja 2025-06-06 04:22:07 -04:00
parent 2aa950a5a8
commit a4cdf80e4a


@@ -1,14 +1,3 @@
# Negative reward applied when the first mismatched tool-call causes early termination.
WRONG_CALL_PENALTY = -0.2
# Hard cap on how many new tokens the model may generate in a single turn.
MAX_GEN_PER_TURN = 1024
# Hard cap on how many tool-call turns we will actually roll out
MAX_TOOL_CALL_TURNS = 2
# Whether to validate that all GPT messages have <think> blocks [useful when non-tool call gpt messages are inserted]
VALIDATE_THINK_BLOCKS = True
# Turn-level advantage coefficient (λ in MT-GRPO paper)
# Paper implementation uses 1.0, but we can experiment with different values
TURN_LEVEL_ADVANTAGE_LAMBDA = 0.5 # Configurable: try 0.1, 0.5, 1.0
"""
Multi-Turn Tool-Calling Environment with Turn-Level Advantages
@@ -40,10 +29,10 @@ import numpy as np
from typing import Dict, List, Optional, Tuple, Union
from collections import Counter
import wandb
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio
from pydantic import Field
from atroposlib.envs.base import (
APIServerConfig,
@@ -55,6 +44,37 @@ from atroposlib.envs.base import (
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# Easy-to-change constants for experimentation - modify these for quick testing
WRONG_CALL_PENALTY = -0.2
MAX_GEN_PER_TURN = 1024
MAX_TOOL_CALL_TURNS = 3
VALIDATE_THINK_BLOCKS = True
TURN_LEVEL_ADVANTAGE_LAMBDA = 0.5 # Paper uses 1.0, experiment with 0.1, 0.5, 1.0
class MTGRPOEnvConfig(BaseEnvConfig):
"""Configuration for Multi-Turn Tool Calling with Turn-Level Advantages Environment."""
max_tool_call_turns: int = Field(
default=2,
description="Hard cap on how many tool-call turns we will actually roll out"
)
validate_think_blocks: bool = Field(
default=True,
description="Whether to validate that all GPT messages have <think> blocks [useful when non-tool call gpt messages are inserted]"
)
max_gen_per_turn: int = Field(
default=1024,
description="Hard cap on how many new tokens the model may generate in a single turn"
)
wrong_call_penalty: float = Field(
default=-0.2,
description="Negative reward applied when the first mismatched tool-call causes early termination"
)
turn_level_advantage_lambda: float = Field(
default=0.5,
description="Turn-level advantage coefficient (λ in MT-GRPO paper). Paper implementation uses 1.0, but we can experiment with different values like 0.1, 0.5, 1.0"
)
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
@@ -118,7 +138,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
def __init__(
self,
config: BaseEnvConfig,
config: MTGRPOEnvConfig,
server_configs: List[APIServerConfig],
slurm: bool = True,
testing: bool = False,
@@ -142,8 +162,8 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
self.iter = 0
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
env_cfg = BaseEnvConfig(
def config_init(cls) -> Tuple[MTGRPOEnvConfig, List[APIServerConfig]]:
env_cfg = MTGRPOEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=16,
use_wandb=True,
@@ -156,6 +176,12 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
wandb_name="multiturn_tool_use_turnlevel_advantage",
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
# Override config defaults with experimental constants
wrong_call_penalty=WRONG_CALL_PENALTY,
max_gen_per_turn=MAX_GEN_PER_TURN,
max_tool_call_turns=MAX_TOOL_CALL_TURNS,
validate_think_blocks=VALIDATE_THINK_BLOCKS,
turn_level_advantage_lambda=TURN_LEVEL_ADVANTAGE_LAMBDA,
)
server_cfgs = [
APIServerConfig(
@@ -204,7 +230,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
the answer is the list of function_call JSONs (canonical string).
Each turn can have multiple tool calls.
We only keep those samples that contain exactly MAX_TOOL_CALL_TURNS separate messages with <tool_call>.
We only keep those samples that contain exactly config.max_tool_call_turns separate messages with <tool_call>.
"""
target = self.train_items if is_train else self.test_items
before_len = len(target)
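The filter described in this docstring amounts to counting assistant messages that emit a <tool_call> block and keeping only conversations whose count hits the configured turn budget. A rough, self-contained sketch of that check (the helper name is illustrative, not the one used in the file):

def has_required_tool_call_turns(conv, max_tool_call_turns):
    # Count assistant ("gpt") messages that contain a <tool_call> block
    tool_call_turns = sum(
        1
        for msg in conv
        if msg["from"] in ("gpt", "assistant") and "<tool_call>" in msg["value"]
    )
    # Keep the sample only if it matches the configured turn budget exactly
    return tool_call_turns == max_tool_call_turns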
@@ -227,7 +253,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
continue
# Optional: Validate <think> blocks in gpt messages if enabled
if VALIDATE_THINK_BLOCKS:
if self.config.validate_think_blocks:
gpt_messages = [msg for msg in conv if msg["from"] in ("gpt", "assistant")]
if not all("<think>" in msg["value"].lower() for msg in gpt_messages):
continue
@@ -303,7 +329,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
while len(inter_turns) < max(0, len(expected_calls_by_turn) - 1):
inter_turns.append([])
if tool_call_turns == MAX_TOOL_CALL_TURNS:
if tool_call_turns == self.config.max_tool_call_turns:
target.append((tuple(running_msgs), expected_calls_by_turn, inter_turns))
print(f"[prep_items] {'train' if is_train else 'test'}: added {len(target)-before_len} items.")
@@ -325,7 +351,13 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
"""
turn_rewards = []
for turn_idx, (response, pred_turn, expected_turn) in enumerate(zip(responses_by_turn, pred_calls_by_turn, expected_calls_by_turn)):
# Only iterate over the turns that this rollout actually completed
num_actual_turns = min(len(responses_by_turn), len(pred_calls_by_turn), len(expected_calls_by_turn))
for turn_idx in range(num_actual_turns):
response = responses_by_turn[turn_idx] if turn_idx < len(responses_by_turn) else ""
pred_turn = pred_calls_by_turn[turn_idx] if turn_idx < len(pred_calls_by_turn) else []
expected_turn = expected_calls_by_turn[turn_idx]
# Turn-level reward components
turn_reward = 0.0
@@ -364,7 +396,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
# Apply mismatch penalty if needed
if pred_turn and pred_turn[-1] == "__MISMATCH__":
turn_reward += WRONG_CALL_PENALTY # This is negative
turn_reward += self.config.wrong_call_penalty # This is negative
turn_rewards.append(turn_reward)
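Because a rollout that triggers the mismatch penalty terminates early, different rollouts in the same group can end up with turn-reward lists of different lengths; that ragged shape is what the padding introduced in the next hunk absorbs. A toy illustration (numbers are made up):

# Rollout 0 completed both turns; rollout 1 hit a mismatched call on turn 1 and stopped
turn_rewards_batch = [
    [1.3, 1.1],  # full trajectory
    [0.2],       # turn-1 reward already includes wrong_call_penalty (-0.2)
]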
@@ -392,12 +424,22 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
Returns:
List of advantages for each rollout [num_rollouts x num_turns]
"""
# Compute standardized turn advantages (A_T)
turn_advantages_batch = []
num_turns = len(turn_rewards_batch[0]) if turn_rewards_batch else 0
if not turn_rewards_batch:
return []
for turn_idx in range(num_turns):
turn_rewards_for_this_turn = [rewards[turn_idx] for rewards in turn_rewards_batch]
# Find the maximum number of turns across all rollouts
max_turns = max(len(rewards) for rewards in turn_rewards_batch)
# Pad shorter reward lists with 0.0 for terminated rollouts
padded_turn_rewards_batch = []
for rewards in turn_rewards_batch:
padded_rewards = rewards + [0.0] * (max_turns - len(rewards))
padded_turn_rewards_batch.append(padded_rewards)
# Compute standardized turn advantages (A_T) for each turn
turn_advantages_batch = []
for turn_idx in range(max_turns):
turn_rewards_for_this_turn = [rewards[turn_idx] for rewards in padded_turn_rewards_batch]
mean_turn_reward = np.mean(turn_rewards_for_this_turn)
std_turn_reward = np.std(turn_rewards_for_this_turn)
if std_turn_reward == 0:
@@ -414,14 +456,16 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
outcome_advantages = [(r - mean_outcome_reward) / std_outcome_reward for r in outcome_rewards_batch]
# Combine according to MT-GRPO formula
# Combine according to MT-GRPO formula, but only for actual turns (not padded ones)
mt_grpo_advantages = []
for rollout_idx in range(len(turn_rewards_batch)):
rollout_advantages = []
for turn_idx in range(num_turns):
if turn_idx < num_turns - 1: # Not the last turn
actual_num_turns = len(turn_rewards_batch[rollout_idx]) # Original length before padding
for turn_idx in range(actual_num_turns):
if turn_idx < actual_num_turns - 1: # Not the last turn
# A_T_i + λ * A_O_i
advantage = turn_advantages_batch[turn_idx][rollout_idx] + TURN_LEVEL_ADVANTAGE_LAMBDA * outcome_advantages[rollout_idx]
advantage = turn_advantages_batch[turn_idx][rollout_idx] + self.config.turn_level_advantage_lambda * outcome_advantages[rollout_idx]
else: # Last turn
# A_O_i only
advantage = outcome_advantages[rollout_idx]
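Put together, the rule is A(i, t) = A_T(i, t) + λ * A_O(i) for every turn except the last one a rollout actually completed, and A_O(i) alone on that last turn, with A_T standardized per turn column over the padded group and A_O standardized over the group's outcome rewards. A simplified, self-contained sketch of that computation (it uses an epsilon instead of the zero-std branch above, and toy inputs):

import numpy as np

def mt_grpo_advantages(turn_rewards_batch, outcome_rewards_batch, lam=0.5, eps=1e-8):
    # Pad ragged per-turn rewards so every rollout has max_turns entries
    max_turns = max(len(r) for r in turn_rewards_batch)
    padded = np.array([r + [0.0] * (max_turns - len(r)) for r in turn_rewards_batch])
    # Standardize each turn column across the group (A_T)
    turn_adv = (padded - padded.mean(axis=0)) / (padded.std(axis=0) + eps)
    # Standardize the final outcome rewards across the group (A_O)
    outcome = np.array(outcome_rewards_batch, dtype=float)
    outcome_adv = (outcome - outcome.mean()) / (outcome.std() + eps)
    advantages = []
    for i, rewards in enumerate(turn_rewards_batch):
        n = len(rewards)  # only the turns this rollout actually completed
        advantages.append([
            turn_adv[i, t] + lam * outcome_adv[i] if t < n - 1 else outcome_adv[i]
            for t in range(n)
        ])
    return advantages

# Toy group: rollout 0 finished both turns, rollout 1 terminated after one
print(mt_grpo_advantages([[1.3, 1.1], [0.2]], [2.0, 0.5], lam=0.5))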
@@ -524,7 +568,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
mismatch_penalty = 0.0
if pred_calls and pred_calls[-1] == "__MISMATCH__":
pred_calls = pred_calls[:-1]
mismatch_penalty = WRONG_CALL_PENALTY
mismatch_penalty = self.config.wrong_call_penalty
correct = sum(
1 for p, e in zip(pred_calls, exp_jsons) if _json_objects_match(p, e)
)
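_json_objects_match is not shown in this diff; judging only by its name and how it is used here, it presumably compares each predicted call against the expected one as parsed JSON so that key order and whitespace do not matter. A guess at its shape, purely for illustration:

import json

def _json_objects_match_sketch(pred, expected):
    # Illustrative only: parse both strings and compare the resulting objects
    try:
        return json.loads(pred) == json.loads(expected)
    except (TypeError, ValueError):
        return False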
@@ -744,12 +788,12 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
num_rollouts = self.config.group_size
contexts: List[List[Dict[str, str]]] = [list(base_ctx) for _ in range(num_rollouts)]
# Track predictions by turn
preds_by_turn: List[List[List]] = [[[] for _ in range(MAX_TOOL_CALL_TURNS)] for _ in range(num_rollouts)]
preds_by_turn: List[List[List]] = [[[] for _ in range(self.config.max_tool_call_turns)] for _ in range(num_rollouts)]
# Track responses by turn for reward computation
responses_by_turn: List[List[str]] = [[] for _ in range(num_rollouts)]
active = [True] * num_rollouts
max_turns = min(len(expected_calls_by_turn), MAX_TOOL_CALL_TURNS)
max_turns = min(len(expected_calls_by_turn), self.config.max_tool_call_turns)
for turn_idx in range(max_turns):
print(f"[collect_trajectories] Beginning turn {turn_idx+1}/{max_turns} for this group")
@@ -762,7 +806,7 @@ class MultiTurnToolCallingTurnLevelAdvantageEnv(BaseEnv):
max_prompt_len = max(len(p) for p in prompts)
max_gen = min(
MAX_GEN_PER_TURN,
self.config.max_gen_per_turn,
max(1, self.config.max_token_length - max_prompt_len),
)
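The per-turn generation budget is therefore the smaller of the configured cap and whatever room is left in the context window after the longest prompt in the group. With illustrative numbers (a 4096-token max_token_length and a 3300-token longest prompt):

max_gen = min(1024, max(1, 4096 - 3300))  # -> 796 new tokens allowed this turn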