mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
option to generate all gpt turns
This commit is contained in:
parent
60be1bbbe8
commit
45bc484931
1 changed files with 147 additions and 33 deletions
|
|
@ -37,13 +37,27 @@ from atroposlib.envs.base import (
|
|||
Item,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Filter: skip tasks that were already processed in a finished SFT dataset
|
||||
# ------------------------------------------------------------------
|
||||
COMPLETED_DATASET_ID = "interstellarninja/hermes_reasoning_tool_use"
|
||||
try:
|
||||
_done_ds = load_dataset(COMPLETED_DATASET_ID, split="train")
|
||||
COMPLETED_TASKS: set[str] = set(_done_ds["task"])
|
||||
print(f"[filter] Loaded {len(COMPLETED_TASKS):,} completed tasks from {COMPLETED_DATASET_ID}")
|
||||
except Exception as _exc:
|
||||
COMPLETED_TASKS = set()
|
||||
print(f"[filter] Could not load completed-task dataset: {_exc}. No skipping will occur.")
|
||||
|
||||
# Easy-to-change constants for experimentation - modify these for quick testing
|
||||
WRONG_CALL_PENALTY = -0.2
|
||||
MAX_GEN_PER_TURN = 1024
|
||||
MAX_TOOL_CALL_TURNS = 2
|
||||
VALIDATE_THINK_BLOCKS = True
|
||||
VALIDATE_THINK_BLOCKS = False
|
||||
GENERATE_ALL_GPT_TURNS = True
|
||||
|
||||
class MultiTurnEnvConfig(BaseEnvConfig):
|
||||
"""Configuration for Multi-Turn Tool Calling Environment."""
|
||||
|
|
@ -55,6 +69,16 @@ class MultiTurnEnvConfig(BaseEnvConfig):
|
|||
default=True,
|
||||
description="Whether to validate that all GPT messages have <think> blocks [useful when non-tool call gpt messages are inserted]"
|
||||
)
|
||||
generate_all_gpt_turns: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If True, the environment will emit a GPT turn *after each* <tool_response>. "
|
||||
"That reply **must begin** with one <think> … </think> block. "
|
||||
"If the dataset expects tool‑calls for that turn, the reply must also contain "
|
||||
"the same number of <tool_call> … </tool_call> blocks; otherwise it must contain "
|
||||
"no <tool_call> blocks."
|
||||
),
|
||||
)
|
||||
max_gen_per_turn: int = Field(
|
||||
default=1024,
|
||||
description="Hard cap on how many new tokens the model may generate in a single turn"
|
||||
|
|
@ -63,6 +87,10 @@ class MultiTurnEnvConfig(BaseEnvConfig):
|
|||
default=-0.2,
|
||||
description="Negative reward applied when the first mismatched tool-call causes early termination"
|
||||
)
|
||||
skip_completed: bool = Field(
|
||||
default=True,
|
||||
description="Skip any conversation whose first user prompt appears in COMPLETED_TASKS.",
|
||||
)
|
||||
|
||||
|
||||
system_prompt = (
|
||||
|
|
@ -73,34 +101,63 @@ system_prompt = (
|
|||
)
|
||||
|
||||
|
||||
def _validate_reply_and_extract(txt: str):
|
||||
# ------------------------------------------------------------------
|
||||
# Helper: validate a GPT reply that *may* include tool calls
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_think_only(txt: str) -> bool:
|
||||
"""
|
||||
Validates that the reply matches the allowed structure:
|
||||
- exactly one mandatory <think>…</think> block at the top
|
||||
- one or more <tool_call>…</tool_call> blocks
|
||||
- nothing else except whitespace/newlines
|
||||
Returns list of tool-call JSONs if valid, else None.
|
||||
A narration / summary turn must:
|
||||
• start with exactly one <think> … </think> block
|
||||
• contain **no** <tool_call> tags anywhere
|
||||
Anything after the </think> (user‑visible answer) is allowed.
|
||||
"""
|
||||
_allowed_re = re.compile(
|
||||
r"""^\s*
|
||||
<think>[\s\S]*?</think>\s*
|
||||
(?:
|
||||
<tool_call>[\s\S]*?</tool_call>\s*
|
||||
)+
|
||||
\s*$""",
|
||||
re.IGNORECASE | re.VERBOSE,
|
||||
if not isinstance(txt, str):
|
||||
return False
|
||||
|
||||
# Must begin with one think block
|
||||
begins_with_think = re.match(
|
||||
r"^\s*<think>[\s\S]*?</think>", txt, flags=re.IGNORECASE
|
||||
)
|
||||
if not isinstance(txt, str) or not _allowed_re.match(txt):
|
||||
if not begins_with_think:
|
||||
return False
|
||||
|
||||
# Must not contain any <tool_call>
|
||||
if re.search(r"<tool_call\s*>", txt, flags=re.IGNORECASE):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _validate_think_plus_calls(txt: str):
|
||||
"""
|
||||
Validate a GPT reply that should contain <think> … </think> followed by
|
||||
one or more <tool_call> … </tool_call> blocks.
|
||||
|
||||
Returns:
|
||||
list[dict] – Parsed tool‑call JSONs (≥1) → valid
|
||||
None – Any structural / JSON error → invalid
|
||||
"""
|
||||
pat = re.compile(
|
||||
r"^\s*<think>[\s\S]*?</think>\s*((?:<tool_call>[\s\S]*?</tool_call>\s*)+)\s*$",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
m = pat.match(txt)
|
||||
if not m:
|
||||
return None
|
||||
# Extract tool_call JSONs
|
||||
matches = re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", txt, re.DOTALL | re.IGNORECASE)
|
||||
jsons = []
|
||||
for m in matches:
|
||||
|
||||
calls_section = m.group(1)
|
||||
tool_jsons = []
|
||||
for raw in re.findall(
|
||||
r"<tool_call>\s*(.*?)\s*</tool_call>", calls_section, flags=re.DOTALL | re.IGNORECASE
|
||||
):
|
||||
try:
|
||||
jsons.append(json.loads(m))
|
||||
tool_jsons.append(json.loads(raw))
|
||||
except Exception:
|
||||
pass
|
||||
return jsons
|
||||
return None
|
||||
if not tool_jsons:
|
||||
return None
|
||||
return tool_jsons
|
||||
|
||||
|
||||
def _json_objects_match(model_json, expected_json):
|
||||
|
|
@ -135,7 +192,8 @@ class MultiTurnToolCallingEnv(BaseEnv):
|
|||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
# Load dataset once and cache on this instance
|
||||
self.ds = load_dataset("interstellarninja/salesforce_hermes_thinking", split="train")
|
||||
#self.ds = load_dataset("interstellarninja/salesforce_hermes_thinking", split="train")
|
||||
self.ds = load_dataset("interstellarninja/toolace_hermes_sequential_tool_use", split="train")
|
||||
|
||||
self.percent_correct_buffer: List[float] = []
|
||||
self.raw_score_buffer: List[float] = []
|
||||
|
|
@ -155,7 +213,7 @@ class MultiTurnToolCallingEnv(BaseEnv):
|
|||
def config_init(cls) -> Tuple[MultiTurnEnvConfig, List[APIServerConfig]]:
|
||||
env_cfg = MultiTurnEnvConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
group_size=16,
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=2000,
|
||||
|
|
@ -171,6 +229,8 @@ class MultiTurnToolCallingEnv(BaseEnv):
|
|||
max_gen_per_turn=MAX_GEN_PER_TURN,
|
||||
max_tool_call_turns=MAX_TOOL_CALL_TURNS,
|
||||
validate_think_blocks=VALIDATE_THINK_BLOCKS,
|
||||
generate_all_gpt_turns=GENERATE_ALL_GPT_TURNS,
|
||||
skip_completed=True,
|
||||
)
|
||||
server_cfgs = [
|
||||
APIServerConfig(
|
||||
|
|
@ -228,6 +288,14 @@ class MultiTurnToolCallingEnv(BaseEnv):
|
|||
running_msgs: List[frozenset] = []
|
||||
|
||||
conv = row["conversations"]
|
||||
# --- Skip conversation if its task already completed ---
|
||||
if self.config.skip_completed and COMPLETED_TASKS:
|
||||
first_user_msg = next(
|
||||
(m["value"].strip() for m in conv if m["from"] == "human"),
|
||||
None,
|
||||
)
|
||||
if first_user_msg and first_user_msg in COMPLETED_TASKS:
|
||||
continue
|
||||
if len(conv) < 3:
|
||||
continue
|
||||
if conv[0]["from"] != "system" or conv[1]["from"] != "human":
|
||||
|
|
@ -502,16 +570,23 @@ class MultiTurnToolCallingEnv(BaseEnv):
|
|||
for txt, r in zip(choices, ridx_map):
|
||||
txt = txt or ""
|
||||
contexts[r].append({"role": "assistant", "content": txt})
|
||||
calls = _validate_reply_and_extract(txt)
|
||||
|
||||
if calls is None:
|
||||
preds[r].append("__MISMATCH__")
|
||||
active[r] = False
|
||||
expected_turn_calls = expected_calls_by_turn[turn_idx]
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Decide validation strategy: narration vs. tool‑calling turn
|
||||
# ------------------------------------------------------------
|
||||
if expected_turn_calls: # Turn SHOULD have tool calls
|
||||
calls = _validate_think_plus_calls(txt)
|
||||
if calls is None:
|
||||
preds[r].append("__MISMATCH__")
|
||||
active[r] = False
|
||||
continue
|
||||
else: # Narration / summary turn
|
||||
if not _validate_think_only(txt):
|
||||
active[r] = False
|
||||
# Narration turns produce no predictions to score
|
||||
continue
|
||||
|
||||
# Get expected calls for this specific turn
|
||||
expected_turn_calls = expected_calls_by_turn[turn_idx]
|
||||
|
||||
# Check if number of calls matches
|
||||
if len(calls) != len(expected_turn_calls):
|
||||
preds[r].append("__MISMATCH__")
|
||||
|
|
@ -584,6 +659,45 @@ class MultiTurnToolCallingEnv(BaseEnv):
|
|||
|
||||
# Process and validate responses
|
||||
await self._process_turn_responses(turn_idx, choices, ridx_map, contexts, preds, active, expected_calls_by_turn)
|
||||
|
||||
# ───────────────────────────────────────────────────────────────
|
||||
# Optionally emit a GPT narration/summary turn after tool_response
|
||||
# ───────────────────────────────────────────────────────────────
|
||||
if self.config.generate_all_gpt_turns and any(active):
|
||||
extra_prompts, extra_ridx = [], []
|
||||
for r in range(len(contexts)):
|
||||
if not active[r]:
|
||||
continue
|
||||
ptxt = self.tokenizer.apply_chat_template(
|
||||
contexts[r],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
extra_prompts.append(ptxt)
|
||||
extra_ridx.append(r)
|
||||
|
||||
async def _infer_one(prompt_str: str) -> str:
|
||||
try:
|
||||
comp = await self.server.completion(
|
||||
prompt=prompt_str,
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.7,
|
||||
)
|
||||
return comp.choices[0].text
|
||||
except Exception as exc:
|
||||
print(f" → extra GPT turn inference error: {exc}")
|
||||
return ""
|
||||
|
||||
extra_replies = await asyncio.gather(*[_infer_one(p) for p in extra_prompts])
|
||||
|
||||
for txt, r in zip(extra_replies, extra_ridx):
|
||||
txt = txt or ""
|
||||
contexts[r].append({"role": "assistant", "content": txt})
|
||||
|
||||
# Narration turn MUST be think‑only. If not, terminate rollout r.
|
||||
if not _validate_think_only(txt):
|
||||
active[r] = False
|
||||
|
||||
if not any(active):
|
||||
print(" → All roll-outs terminated; stopping further turns.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue