option to generate all gpt turns

interstellarninja 2025-06-24 08:14:14 -04:00
parent 60be1bbbe8
commit 45bc484931

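The change splits GPT replies into two shapes: tool-calling turns (one <think> block followed by <tool_call> blocks) and narration turns (a <think> block followed by plain prose, with no <tool_call> tags). The standalone sketch below replays the two checks with the same regexes this diff introduces; the sample replies and tool names are made up for illustration:

import json
import re

THINK_ONLY = "<think>The tool already answered.</think> It is sunny in Paris."
THINK_PLUS_CALL = (
    "<think>I need the forecast.</think>\n"
    '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
)

def is_think_only(txt: str) -> bool:
    # Narration turn: must open with a <think> block and contain no <tool_call> tag.
    return bool(re.match(r"^\s*<think>[\s\S]*?</think>", txt, re.IGNORECASE)) and not re.search(
        r"<tool_call\s*>", txt, re.IGNORECASE
    )

def think_plus_calls(txt: str):
    # Tool-calling turn: <think> block, then one or more parseable <tool_call> blocks.
    m = re.match(
        r"^\s*<think>[\s\S]*?</think>\s*((?:<tool_call>[\s\S]*?</tool_call>\s*)+)\s*$",
        txt,
        re.IGNORECASE,
    )
    if not m:
        return None
    raws = re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", m.group(1), re.DOTALL | re.IGNORECASE)
    return [json.loads(r) for r in raws] or None

print(is_think_only(THINK_ONLY))          # True
print(is_think_only(THINK_PLUS_CALL))     # False
print(think_plus_calls(THINK_PLUS_CALL))  # [{'name': 'get_weather', 'arguments': {'city': 'Paris'}}]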

@@ -37,13 +37,27 @@ from atroposlib.envs.base import (
     Item,
     ScoredDataGroup,
 )
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
+# ------------------------------------------------------------------
+# Filter: skip tasks that were already processed in a finished SFT dataset
+# ------------------------------------------------------------------
+COMPLETED_DATASET_ID = "interstellarninja/hermes_reasoning_tool_use"
+try:
+    _done_ds = load_dataset(COMPLETED_DATASET_ID, split="train")
+    COMPLETED_TASKS: set[str] = set(_done_ds["task"])
+    print(f"[filter] Loaded {len(COMPLETED_TASKS):,} completed tasks from {COMPLETED_DATASET_ID}")
+except Exception as _exc:
+    COMPLETED_TASKS = set()
+    print(f"[filter] Could not load completed-task dataset: {_exc}. No skipping will occur.")
+
 # Easy-to-change constants for experimentation - modify these for quick testing
 WRONG_CALL_PENALTY = -0.2
 MAX_GEN_PER_TURN = 1024
 MAX_TOOL_CALL_TURNS = 2
-VALIDATE_THINK_BLOCKS = True
+VALIDATE_THINK_BLOCKS = False
+GENERATE_ALL_GPT_TURNS = True
 
 class MultiTurnEnvConfig(BaseEnvConfig):
     """Configuration for Multi-Turn Tool Calling Environment."""
@@ -55,6 +69,16 @@ class MultiTurnEnvConfig(BaseEnvConfig):
         default=True,
         description="Whether to validate that all GPT messages have <think> blocks [useful when non-tool call gpt messages are inserted]"
     )
+    generate_all_gpt_turns: bool = Field(
+        default=False,
+        description=(
+            "If True, the environment will emit a GPT turn *after each* <tool_response>. "
+            "That reply **must begin** with one <think> … </think> block. "
+            "If the dataset expects tool calls for that turn, the reply must also contain "
+            "the same number of <tool_call> … </tool_call> blocks; otherwise it must contain "
+            "no <tool_call> blocks."
+        ),
+    )
     max_gen_per_turn: int = Field(
         default=1024,
         description="Hard cap on how many new tokens the model may generate in a single turn"
@@ -63,6 +87,10 @@ class MultiTurnEnvConfig(BaseEnvConfig):
         default=-0.2,
         description="Negative reward applied when the first mismatched tool-call causes early termination"
     )
+    skip_completed: bool = Field(
+        default=True,
+        description="Skip any conversation whose first user prompt appears in COMPLETED_TASKS.",
+    )
 
 system_prompt = (
@@ -73,34 +101,63 @@ system_prompt = (
 )
 
-def _validate_reply_and_extract(txt: str):
-    """
-    Validates that the reply matches the allowed structure:
-      - exactly one mandatory <think></think> block at the top
-      - one or more <tool_call></tool_call> blocks
-      - nothing else except whitespace/newlines
-    Returns list of tool-call JSONs if valid, else None.
-    """
-    _allowed_re = re.compile(
-        r"""^\s*
-        <think>[\s\S]*?</think>\s*
-        (?:
-            <tool_call>[\s\S]*?</tool_call>\s*
-        )+
-        \s*$""",
-        re.IGNORECASE | re.VERBOSE,
-    )
-    if not isinstance(txt, str) or not _allowed_re.match(txt):
-        return None
-    # Extract tool_call JSONs
-    matches = re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", txt, re.DOTALL | re.IGNORECASE)
-    jsons = []
-    for m in matches:
-        try:
-            jsons.append(json.loads(m))
-        except Exception:
-            pass
-    return jsons
+# ------------------------------------------------------------------
+# Helper: validate a GPT reply that *may* include tool calls
+# ------------------------------------------------------------------
+def _validate_think_only(txt: str) -> bool:
+    """
+    A narration / summary turn must:
+      - start with exactly one <think> … </think> block
+      - contain **no** <tool_call> tags anywhere
+    Anything after the </think> (user-visible answer) is allowed.
+    """
+    if not isinstance(txt, str):
+        return False
+    # Must begin with one think block
+    begins_with_think = re.match(
+        r"^\s*<think>[\s\S]*?</think>", txt, flags=re.IGNORECASE
+    )
+    if not begins_with_think:
+        return False
+    # Must not contain any <tool_call>
+    if re.search(r"<tool_call\s*>", txt, flags=re.IGNORECASE):
+        return False
+    return True
+
+
+def _validate_think_plus_calls(txt: str):
+    """
+    Validate a GPT reply that should contain a <think> … </think> block
+    followed by one or more <tool_call> … </tool_call> blocks.
+    Returns:
+        list[dict] … parsed tool-call JSONs (≥1) → valid
+        None       … any structural / JSON error → invalid
+    """
+    pat = re.compile(
+        r"^\s*<think>[\s\S]*?</think>\s*((?:<tool_call>[\s\S]*?</tool_call>\s*)+)\s*$",
+        flags=re.IGNORECASE,
+    )
+    m = pat.match(txt)
+    if not m:
+        return None
+    calls_section = m.group(1)
+    tool_jsons = []
+    for raw in re.findall(
+        r"<tool_call>\s*(.*?)\s*</tool_call>", calls_section, flags=re.DOTALL | re.IGNORECASE
+    ):
+        try:
+            tool_jsons.append(json.loads(raw))
+        except Exception:
+            pass
+    if not tool_jsons:
+        return None
+    return tool_jsons
 
 
 def _json_objects_match(model_json, expected_json):
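Expected behavior of the two helpers above on a few representative strings (inputs are illustrative):

_validate_think_only("<think>done</think> All set!")                 # True
_validate_think_only("no think block at all")                        # False
_validate_think_only("<think>x</think><tool_call>{}</tool_call>")    # False (tool_call present)

_validate_think_plus_calls(
    '<think>plan</think><tool_call>{"name": "f", "arguments": {}}</tool_call>'
)                                                                    # [{'name': 'f', 'arguments': {}}]
_validate_think_plus_calls("<think>plan</think> no calls here")      # None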
@@ -135,7 +192,8 @@ class MultiTurnToolCallingEnv(BaseEnv):
     ):
         super().__init__(config, server_configs, slurm, testing)
         # Load dataset once and cache on this instance
-        self.ds = load_dataset("interstellarninja/salesforce_hermes_thinking", split="train")
+        # self.ds = load_dataset("interstellarninja/salesforce_hermes_thinking", split="train")
+        self.ds = load_dataset("interstellarninja/toolace_hermes_sequential_tool_use", split="train")
 
         self.percent_correct_buffer: List[float] = []
         self.raw_score_buffer: List[float] = []
@@ -155,7 +213,7 @@ class MultiTurnToolCallingEnv(BaseEnv):
     def config_init(cls) -> Tuple[MultiTurnEnvConfig, List[APIServerConfig]]:
         env_cfg = MultiTurnEnvConfig(
             tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
-            group_size=16,
+            group_size=8,
             use_wandb=True,
             rollout_server_url="http://localhost:8000",
             total_steps=2000,
@@ -171,6 +229,8 @@ class MultiTurnToolCallingEnv(BaseEnv):
             max_gen_per_turn=MAX_GEN_PER_TURN,
             max_tool_call_turns=MAX_TOOL_CALL_TURNS,
             validate_think_blocks=VALIDATE_THINK_BLOCKS,
+            generate_all_gpt_turns=GENERATE_ALL_GPT_TURNS,
+            skip_completed=True,
         )
         server_cfgs = [
             APIServerConfig(
@@ -228,6 +288,14 @@ class MultiTurnToolCallingEnv(BaseEnv):
             running_msgs: List[frozenset] = []
             conv = row["conversations"]
 
+            # --- Skip conversation if its task already completed ---
+            if self.config.skip_completed and COMPLETED_TASKS:
+                first_user_msg = next(
+                    (m["value"].strip() for m in conv if m["from"] == "human"),
+                    None,
+                )
+                if first_user_msg and first_user_msg in COMPLETED_TASKS:
+                    continue
+
             if len(conv) < 3:
                 continue
             if conv[0]["from"] != "system" or conv[1]["from"] != "human":
@@ -502,16 +570,23 @@ class MultiTurnToolCallingEnv(BaseEnv):
         for txt, r in zip(choices, ridx_map):
             txt = txt or ""
             contexts[r].append({"role": "assistant", "content": txt})
-            calls = _validate_reply_and_extract(txt)
-            if calls is None:
-                preds[r].append("__MISMATCH__")
-                active[r] = False
+            expected_turn_calls = expected_calls_by_turn[turn_idx]
+            # ------------------------------------------------------------
+            # Decide validation strategy: narration vs. tool-calling turn
+            # ------------------------------------------------------------
+            if expected_turn_calls:  # Turn SHOULD have tool calls
+                calls = _validate_think_plus_calls(txt)
+                if calls is None:
+                    preds[r].append("__MISMATCH__")
+                    active[r] = False
+                    continue
+            else:  # Narration / summary turn
+                if not _validate_think_only(txt):
+                    active[r] = False
+                # Narration turns produce no predictions to score
                 continue
 
-            # Get expected calls for this specific turn
-            expected_turn_calls = expected_calls_by_turn[turn_idx]
             # Check if number of calls matches
             if len(calls) != len(expected_turn_calls):
                 preds[r].append("__MISMATCH__")
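The branch above keys entirely on whether expected_calls_by_turn has entries for the current turn. A plausible shape for that structure, inferred only from how it is indexed here and not from the dataset itself:

# One list of expected tool-call dicts per GPT turn; an empty list marks a
# narration/summary turn that must be <think>-only.
expected_calls_by_turn = [
    [{"name": "get_weather", "arguments": {"city": "Paris"}}],  # turn 0: tool-calling
    [],                                                         # turn 1: narration only
]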
@@ -584,6 +659,45 @@ class MultiTurnToolCallingEnv(BaseEnv):
             # Process and validate responses
             await self._process_turn_responses(turn_idx, choices, ridx_map, contexts, preds, active, expected_calls_by_turn)
 
+            # ───────────────────────────────────────────────────────────────
+            # Optionally emit a GPT narration/summary turn after tool_response
+            # ───────────────────────────────────────────────────────────────
+            if self.config.generate_all_gpt_turns and any(active):
+                extra_prompts, extra_ridx = [], []
+                for r in range(len(contexts)):
+                    if not active[r]:
+                        continue
+                    ptxt = self.tokenizer.apply_chat_template(
+                        contexts[r],
+                        add_generation_prompt=True,
+                        tokenize=False,
+                    )
+                    extra_prompts.append(ptxt)
+                    extra_ridx.append(r)
+
+                async def _infer_one(prompt_str: str) -> str:
+                    try:
+                        comp = await self.server.completion(
+                            prompt=prompt_str,
+                            n=1,
+                            max_tokens=self.config.max_token_length,
+                            temperature=0.7,
+                        )
+                        return comp.choices[0].text
+                    except Exception as exc:
+                        print(f" → extra GPT turn inference error: {exc}")
+                        return ""
+
+                extra_replies = await asyncio.gather(*[_infer_one(p) for p in extra_prompts])
+
+                for txt, r in zip(extra_replies, extra_ridx):
+                    txt = txt or ""
+                    contexts[r].append({"role": "assistant", "content": txt})
+                    # Narration turn MUST be think-only. If not, terminate rollout r.
+                    if not _validate_think_only(txt):
+                        active[r] = False
+
             if not any(active):
                 print(" → All roll-outs terminated; stopping further turns.")