diff --git a/environments/tool_use_multiturn_server.py b/environments/tool_use_multiturn_server.py
index edd41882..d3614514 100644
--- a/environments/tool_use_multiturn_server.py
+++ b/environments/tool_use_multiturn_server.py
@@ -37,13 +37,27 @@ from atroposlib.envs.base import (
Item,
ScoredDataGroup,
)
+
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+# ------------------------------------------------------------------
+# Filter: skip tasks that were already processed in a finished SFT dataset
+# ------------------------------------------------------------------
+COMPLETED_DATASET_ID = "interstellarninja/hermes_reasoning_tool_use"
+try:
+ _done_ds = load_dataset(COMPLETED_DATASET_ID, split="train")
+ COMPLETED_TASKS: set[str] = set(_done_ds["task"])
+ print(f"[filter] Loaded {len(COMPLETED_TASKS):,} completed tasks from {COMPLETED_DATASET_ID}")
+except Exception as _exc:
+ COMPLETED_TASKS = set()
+ print(f"[filter] Could not load completed-task dataset: {_exc}. No skipping will occur.")
+
# Easy-to-change constants for experimentation - modify these for quick testing
WRONG_CALL_PENALTY = -0.2
MAX_GEN_PER_TURN = 1024
MAX_TOOL_CALL_TURNS = 2
-VALIDATE_THINK_BLOCKS = True
+VALIDATE_THINK_BLOCKS = False
+GENERATE_ALL_GPT_TURNS = True
class MultiTurnEnvConfig(BaseEnvConfig):
"""Configuration for Multi-Turn Tool Calling Environment."""
@@ -55,6 +69,16 @@ class MultiTurnEnvConfig(BaseEnvConfig):
default=True,
description="Whether to validate that all GPT messages have blocks [useful when non-tool call gpt messages are inserted]"
)
+ generate_all_gpt_turns: bool = Field(
+ default=False,
+ description=(
+            "If True, the environment will emit a GPT turn *after each* <tool_response>. "
+            "That reply **must begin** with one <think>…</think> block. "
+            "If the dataset expects tool-calls for that turn, the reply must also contain "
+            "the same number of <tool_call>…</tool_call> blocks; otherwise it must contain "
+            "no <tool_call> blocks."
+ ),
+ )
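+    # Illustrative sketch of the two reply shapes this option expects (hypothetical tool
+    # name and content, shown here only as documentation):
+    #   tool-calling turn:  "<think>…</think>\n<tool_call>{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Paris\"}}</tool_call>"
+    #   narration turn:     "<think>…</think>\nThe forecast for Paris is sunny."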
max_gen_per_turn: int = Field(
default=1024,
description="Hard cap on how many new tokens the model may generate in a single turn"
@@ -63,6 +87,10 @@ class MultiTurnEnvConfig(BaseEnvConfig):
default=-0.2,
description="Negative reward applied when the first mismatched tool-call causes early termination"
)
+ skip_completed: bool = Field(
+ default=True,
+ description="Skip any conversation whose first user prompt appears in COMPLETED_TASKS.",
+ )
system_prompt = (
@@ -73,34 +101,63 @@ system_prompt = (
)
-def _validate_reply_and_extract(txt: str):
+# ------------------------------------------------------------------
+# Helpers: validate GPT replies (think-only narration vs. think + tool-call turns)
+# ------------------------------------------------------------------
+
+def _validate_think_only(txt: str) -> bool:
"""
-    Validates that the reply matches the allowed structure:
-      - exactly one mandatory <think>…</think> block at the top
-      - one or more <tool_call>…</tool_call> blocks
-      - nothing else except whitespace/newlines
-    Returns list of tool-call JSONs if valid, else None.
+    A narration / summary turn must:
+      • start with exactly one <think>…</think> block
+      • contain **no** <tool_call> tags anywhere
+    Anything after the </think> (user-visible answer) is allowed.
"""
- _allowed_re = re.compile(
- r"""^\s*
-        <think>[\s\S]*?</think>\s*
- (?:
-            <tool_call>[\s\S]*?</tool_call>\s*
- )+
- \s*$""",
- re.IGNORECASE | re.VERBOSE,
+ if not isinstance(txt, str):
+ return False
+
+    # Must begin with one <think>…</think> block
+    begins_with_think = re.match(
+        r"^\s*<think>[\s\S]*?</think>", txt, flags=re.IGNORECASE
)
- if not isinstance(txt, str) or not _allowed_re.match(txt):
+ if not begins_with_think:
+ return False
+
+    # Must not contain any <tool_call> tag
+    if re.search(r"<tool_call>", txt, flags=re.IGNORECASE):
+ return False
+
+ return True
+
+
+def _validate_think_plus_calls(txt: str):
+ """
+    Validate a GPT reply that should contain <think>…</think> followed by
+    one or more <tool_call>…</tool_call> blocks.
+
+ Returns:
+ list[dict] – Parsed tool‑call JSONs (≥1) → valid
+ None – Any structural / JSON error → invalid
+ """
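+    # Example of a reply this helper accepts (illustrative, hypothetical tool name):
+    #   "<think>I should look up the rate.</think><tool_call>{\"name\": \"get_rate\", \"arguments\": {}}</tool_call>"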
+ pat = re.compile(
+        r"^\s*<think>[\s\S]*?</think>\s*((?:<tool_call>[\s\S]*?</tool_call>\s*)+)\s*$",
+ flags=re.IGNORECASE,
+ )
+ m = pat.match(txt)
+ if not m:
return None
- # Extract tool_call JSONs
-    matches = re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", txt, re.DOTALL | re.IGNORECASE)
- jsons = []
- for m in matches:
+
+ calls_section = m.group(1)
+ tool_jsons = []
+ for raw in re.findall(
+        r"<tool_call>\s*(.*?)\s*</tool_call>", calls_section, flags=re.DOTALL | re.IGNORECASE
+ ):
try:
- jsons.append(json.loads(m))
+ tool_jsons.append(json.loads(raw))
except Exception:
- pass
- return jsons
+ return None
+ if not tool_jsons:
+ return None
+ return tool_jsons
def _json_objects_match(model_json, expected_json):
@@ -135,7 +192,8 @@ class MultiTurnToolCallingEnv(BaseEnv):
):
super().__init__(config, server_configs, slurm, testing)
# Load dataset once and cache on this instance
- self.ds = load_dataset("interstellarninja/salesforce_hermes_thinking", split="train")
+ #self.ds = load_dataset("interstellarninja/salesforce_hermes_thinking", split="train")
+ self.ds = load_dataset("interstellarninja/toolace_hermes_sequential_tool_use", split="train")
self.percent_correct_buffer: List[float] = []
self.raw_score_buffer: List[float] = []
@@ -155,7 +213,7 @@ class MultiTurnToolCallingEnv(BaseEnv):
def config_init(cls) -> Tuple[MultiTurnEnvConfig, List[APIServerConfig]]:
env_cfg = MultiTurnEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
- group_size=16,
+ group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=2000,
@@ -171,6 +229,8 @@ class MultiTurnToolCallingEnv(BaseEnv):
max_gen_per_turn=MAX_GEN_PER_TURN,
max_tool_call_turns=MAX_TOOL_CALL_TURNS,
validate_think_blocks=VALIDATE_THINK_BLOCKS,
+ generate_all_gpt_turns=GENERATE_ALL_GPT_TURNS,
+ skip_completed=True,
)
server_cfgs = [
APIServerConfig(
@@ -228,6 +288,14 @@ class MultiTurnToolCallingEnv(BaseEnv):
running_msgs: List[frozenset] = []
conv = row["conversations"]
+ # --- Skip conversation if its task already completed ---
+ if self.config.skip_completed and COMPLETED_TASKS:
+ first_user_msg = next(
+ (m["value"].strip() for m in conv if m["from"] == "human"),
+ None,
+ )
+ if first_user_msg and first_user_msg in COMPLETED_TASKS:
+ continue
if len(conv) < 3:
continue
if conv[0]["from"] != "system" or conv[1]["from"] != "human":
@@ -502,16 +570,23 @@ class MultiTurnToolCallingEnv(BaseEnv):
for txt, r in zip(choices, ridx_map):
txt = txt or ""
contexts[r].append({"role": "assistant", "content": txt})
- calls = _validate_reply_and_extract(txt)
-
- if calls is None:
- preds[r].append("__MISMATCH__")
- active[r] = False
+ expected_turn_calls = expected_calls_by_turn[turn_idx]
+
+ # ------------------------------------------------------------
+ # Decide validation strategy: narration vs. tool‑calling turn
+ # ------------------------------------------------------------
+ if expected_turn_calls: # Turn SHOULD have tool calls
+ calls = _validate_think_plus_calls(txt)
+ if calls is None:
+ preds[r].append("__MISMATCH__")
+ active[r] = False
+ continue
+ else: # Narration / summary turn
+ if not _validate_think_only(txt):
+ active[r] = False
+ # Narration turns produce no predictions to score
continue
- # Get expected calls for this specific turn
- expected_turn_calls = expected_calls_by_turn[turn_idx]
-
# Check if number of calls matches
if len(calls) != len(expected_turn_calls):
preds[r].append("__MISMATCH__")
@@ -584,6 +659,45 @@ class MultiTurnToolCallingEnv(BaseEnv):
# Process and validate responses
await self._process_turn_responses(turn_idx, choices, ridx_map, contexts, preds, active, expected_calls_by_turn)
+
+ # ───────────────────────────────────────────────────────────────
+ # Optionally emit a GPT narration/summary turn after tool_response
+ # ───────────────────────────────────────────────────────────────
+ if self.config.generate_all_gpt_turns and any(active):
+ extra_prompts, extra_ridx = [], []
+ for r in range(len(contexts)):
+ if not active[r]:
+ continue
+ ptxt = self.tokenizer.apply_chat_template(
+ contexts[r],
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+ extra_prompts.append(ptxt)
+ extra_ridx.append(r)
+
+ async def _infer_one(prompt_str: str) -> str:
+ try:
+ comp = await self.server.completion(
+ prompt=prompt_str,
+ n=1,
+ max_tokens=self.config.max_token_length,
+ temperature=0.7,
+ )
+ return comp.choices[0].text
+ except Exception as exc:
+ print(f" → extra GPT turn inference error: {exc}")
+ return ""
+
+ extra_replies = await asyncio.gather(*[_infer_one(p) for p in extra_prompts])
+
+ for txt, r in zip(extra_replies, extra_ridx):
+ txt = txt or ""
+ contexts[r].append({"role": "assistant", "content": txt})
+
+ # Narration turn MUST be think‑only. If not, terminate rollout r.
+ if not _validate_think_only(txt):
+ active[r] = False
if not any(active):
print(" → All roll-outs terminated; stopping further turns.")