diff --git a/environments/tool_use_multiturn_server.py b/environments/tool_use_multiturn_server.py index edd41882..d3614514 100644 --- a/environments/tool_use_multiturn_server.py +++ b/environments/tool_use_multiturn_server.py @@ -37,13 +37,27 @@ from atroposlib.envs.base import ( Item, ScoredDataGroup, ) + from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer +# ------------------------------------------------------------------ +# Filter: skip tasks that were already processed in a finished SFT dataset +# ------------------------------------------------------------------ +COMPLETED_DATASET_ID = "interstellarninja/hermes_reasoning_tool_use" +try: + _done_ds = load_dataset(COMPLETED_DATASET_ID, split="train") + COMPLETED_TASKS: set[str] = set(_done_ds["task"]) + print(f"[filter] Loaded {len(COMPLETED_TASKS):,} completed tasks from {COMPLETED_DATASET_ID}") +except Exception as _exc: + COMPLETED_TASKS = set() + print(f"[filter] Could not load completed-task dataset: {_exc}. No skipping will occur.") + # Easy-to-change constants for experimentation - modify these for quick testing WRONG_CALL_PENALTY = -0.2 MAX_GEN_PER_TURN = 1024 MAX_TOOL_CALL_TURNS = 2 -VALIDATE_THINK_BLOCKS = True +VALIDATE_THINK_BLOCKS = False +GENERATE_ALL_GPT_TURNS = True class MultiTurnEnvConfig(BaseEnvConfig): """Configuration for Multi-Turn Tool Calling Environment.""" @@ -55,6 +69,16 @@ class MultiTurnEnvConfig(BaseEnvConfig): default=True, description="Whether to validate that all GPT messages have blocks [useful when non-tool call gpt messages are inserted]" ) + generate_all_gpt_turns: bool = Field( + default=False, + description=( + "If True, the environment will emit a GPT turn *after each* . " + "That reply **must begin** with one block. " + "If the dataset expects tool‑calls for that turn, the reply must also contain " + "the same number of blocks; otherwise it must contain " + "no blocks." + ), + ) max_gen_per_turn: int = Field( default=1024, description="Hard cap on how many new tokens the model may generate in a single turn" @@ -63,6 +87,10 @@ class MultiTurnEnvConfig(BaseEnvConfig): default=-0.2, description="Negative reward applied when the first mismatched tool-call causes early termination" ) + skip_completed: bool = Field( + default=True, + description="Skip any conversation whose first user prompt appears in COMPLETED_TASKS.", + ) system_prompt = ( @@ -73,34 +101,63 @@ system_prompt = ( ) -def _validate_reply_and_extract(txt: str): +# ------------------------------------------------------------------ +# Helper: validate a GPT reply that *may* include tool calls +# ------------------------------------------------------------------ + +def _validate_think_only(txt: str) -> bool: """ - Validates that the reply matches the allowed structure: - - exactly one mandatory block at the top - - one or more blocks - - nothing else except whitespace/newlines - Returns list of tool-call JSONs if valid, else None. + A narration / summary turn must: + • start with exactly one block + • contain **no** tags anywhere + Anything after the (user‑visible answer) is allowed. """ - _allowed_re = re.compile( - r"""^\s* - [\s\S]*?\s* - (?: - [\s\S]*?\s* - )+ - \s*$""", - re.IGNORECASE | re.VERBOSE, + if not isinstance(txt, str): + return False + + # Must begin with one think block + begins_with_think = re.match( + r"^\s*[\s\S]*?", txt, flags=re.IGNORECASE ) - if not isinstance(txt, str) or not _allowed_re.match(txt): + if not begins_with_think: + return False + + # Must not contain any + if re.search(r"", txt, flags=re.IGNORECASE): + return False + + return True + + +def _validate_think_plus_calls(txt: str): + """ + Validate a GPT reply that should contain followed by + one or more blocks. + + Returns: + list[dict] – Parsed tool‑call JSONs (≥1) → valid + None – Any structural / JSON error → invalid + """ + pat = re.compile( + r"^\s*[\s\S]*?\s*((?:[\s\S]*?\s*)+)\s*$", + flags=re.IGNORECASE, + ) + m = pat.match(txt) + if not m: return None - # Extract tool_call JSONs - matches = re.findall(r"\s*(.*?)\s*", txt, re.DOTALL | re.IGNORECASE) - jsons = [] - for m in matches: + + calls_section = m.group(1) + tool_jsons = [] + for raw in re.findall( + r"\s*(.*?)\s*", calls_section, flags=re.DOTALL | re.IGNORECASE + ): try: - jsons.append(json.loads(m)) + tool_jsons.append(json.loads(raw)) except Exception: - pass - return jsons + return None + if not tool_jsons: + return None + return tool_jsons def _json_objects_match(model_json, expected_json): @@ -135,7 +192,8 @@ class MultiTurnToolCallingEnv(BaseEnv): ): super().__init__(config, server_configs, slurm, testing) # Load dataset once and cache on this instance - self.ds = load_dataset("interstellarninja/salesforce_hermes_thinking", split="train") + #self.ds = load_dataset("interstellarninja/salesforce_hermes_thinking", split="train") + self.ds = load_dataset("interstellarninja/toolace_hermes_sequential_tool_use", split="train") self.percent_correct_buffer: List[float] = [] self.raw_score_buffer: List[float] = [] @@ -155,7 +213,7 @@ class MultiTurnToolCallingEnv(BaseEnv): def config_init(cls) -> Tuple[MultiTurnEnvConfig, List[APIServerConfig]]: env_cfg = MultiTurnEnvConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", - group_size=16, + group_size=8, use_wandb=True, rollout_server_url="http://localhost:8000", total_steps=2000, @@ -171,6 +229,8 @@ class MultiTurnToolCallingEnv(BaseEnv): max_gen_per_turn=MAX_GEN_PER_TURN, max_tool_call_turns=MAX_TOOL_CALL_TURNS, validate_think_blocks=VALIDATE_THINK_BLOCKS, + generate_all_gpt_turns=GENERATE_ALL_GPT_TURNS, + skip_completed=True, ) server_cfgs = [ APIServerConfig( @@ -228,6 +288,14 @@ class MultiTurnToolCallingEnv(BaseEnv): running_msgs: List[frozenset] = [] conv = row["conversations"] + # --- Skip conversation if its task already completed --- + if self.config.skip_completed and COMPLETED_TASKS: + first_user_msg = next( + (m["value"].strip() for m in conv if m["from"] == "human"), + None, + ) + if first_user_msg and first_user_msg in COMPLETED_TASKS: + continue if len(conv) < 3: continue if conv[0]["from"] != "system" or conv[1]["from"] != "human": @@ -502,16 +570,23 @@ class MultiTurnToolCallingEnv(BaseEnv): for txt, r in zip(choices, ridx_map): txt = txt or "" contexts[r].append({"role": "assistant", "content": txt}) - calls = _validate_reply_and_extract(txt) - - if calls is None: - preds[r].append("__MISMATCH__") - active[r] = False + expected_turn_calls = expected_calls_by_turn[turn_idx] + + # ------------------------------------------------------------ + # Decide validation strategy: narration vs. tool‑calling turn + # ------------------------------------------------------------ + if expected_turn_calls: # Turn SHOULD have tool calls + calls = _validate_think_plus_calls(txt) + if calls is None: + preds[r].append("__MISMATCH__") + active[r] = False + continue + else: # Narration / summary turn + if not _validate_think_only(txt): + active[r] = False + # Narration turns produce no predictions to score continue - # Get expected calls for this specific turn - expected_turn_calls = expected_calls_by_turn[turn_idx] - # Check if number of calls matches if len(calls) != len(expected_turn_calls): preds[r].append("__MISMATCH__") @@ -584,6 +659,45 @@ class MultiTurnToolCallingEnv(BaseEnv): # Process and validate responses await self._process_turn_responses(turn_idx, choices, ridx_map, contexts, preds, active, expected_calls_by_turn) + + # ─────────────────────────────────────────────────────────────── + # Optionally emit a GPT narration/summary turn after tool_response + # ─────────────────────────────────────────────────────────────── + if self.config.generate_all_gpt_turns and any(active): + extra_prompts, extra_ridx = [], [] + for r in range(len(contexts)): + if not active[r]: + continue + ptxt = self.tokenizer.apply_chat_template( + contexts[r], + add_generation_prompt=True, + tokenize=False, + ) + extra_prompts.append(ptxt) + extra_ridx.append(r) + + async def _infer_one(prompt_str: str) -> str: + try: + comp = await self.server.completion( + prompt=prompt_str, + n=1, + max_tokens=self.config.max_token_length, + temperature=0.7, + ) + return comp.choices[0].text + except Exception as exc: + print(f" → extra GPT turn inference error: {exc}") + return "" + + extra_replies = await asyncio.gather(*[_infer_one(p) for p in extra_prompts]) + + for txt, r in zip(extra_replies, extra_ridx): + txt = txt or "" + contexts[r].append({"role": "assistant", "content": txt}) + + # Narration turn MUST be think‑only. If not, terminate rollout r. + if not _validate_think_only(txt): + active[r] = False if not any(active): print(" → All roll-outs terminated; stopping further turns.")