atropos/environments/tool_use_multiturn_server.py
"""
Multi-Turn Tool-Calling Environment
==================================
Extends the single-turn tool-calling environment to conversations that
contain **multiple** function-call / observation pairs. Each training
item corresponds to *one* episode covering all tool calls in the conversation:
• We locate every message where `msg["from"] == "gpt"` and the value contains a `<tool_call>` block.
• For each such conversation, we create an item whose **context** is all
  conversation messages up to the next function-call turn.
• Rewards are *episodic*: a dense + sparse reward for matching all tool calls.
Dataset columns expected
------------------------
* `conversations` list[dict] with keys `from` and `value`
"""
import json
import random
import re
import asyncio
import ast
from typing import Dict, List, Optional, Tuple, Union
from collections import Counter
import wandb
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio
from pydantic import Field
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
Item,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
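# Illustrative sketch (hypothetical values, not drawn from the actual dataset) of the
# `conversations` row shape described in the module docstring: alternating "human",
# "gpt" and "tool" turns, where tool-calling gpt turns wrap canonical JSON in <tool_call> tags.
#
# {
#     "conversations": [
#         {"from": "system", "value": "You have access to the following tools: ..."},
#         {"from": "human", "value": "What's the weather in Paris?"},
#         {"from": "gpt", "value": "<think>...</think>"
#                                  "<tool_call>{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Paris\"}}</tool_call>"},
#         {"from": "tool", "value": "<tool_response>{\"temp_c\": 21}</tool_response>"},
#         {"from": "gpt", "value": "<think>...</think>It is 21 °C in Paris."},
#     ]
# }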
# ------------------------------------------------------------------
# Filter: skip tasks that were already processed in a finished SFT dataset
# ------------------------------------------------------------------
COMPLETED_DATASET_ID = "interstellarninja/hermes_reasoning_tool_use"
try:
_done_ds = load_dataset(COMPLETED_DATASET_ID, split="train")
COMPLETED_TASKS: set[str] = set(_done_ds["task"])
print(f"[filter] Loaded {len(COMPLETED_TASKS):,} completed tasks from {COMPLETED_DATASET_ID}")
except Exception as _exc:
COMPLETED_TASKS = set()
print(f"[filter] Could not load completed-task dataset: {_exc}. No skipping will occur.")
# Easy-to-change constants for experimentation - modify these for quick testing
WRONG_CALL_PENALTY = -0.2
MAX_GEN_PER_TURN = 1024
MAX_TOOL_CALL_TURNS = 2
VALIDATE_THINK_BLOCKS = False
GENERATE_ALL_GPT_TURNS = True
class MultiTurnEnvConfig(BaseEnvConfig):
"""Configuration for Multi-Turn Tool Calling Environment."""
max_tool_call_turns: int = Field(
default=2,
description="Hard cap on how many tool-call turns we will actually roll out"
)
validate_think_blocks: bool = Field(
default=True,
description="Whether to validate that all GPT messages have <think> blocks [useful when non-tool call gpt messages are inserted]"
)
generate_all_gpt_turns: bool = Field(
default=False,
        description=(
            "If True, the environment will emit a GPT turn *after each* <tool_response>. "
            "That reply **must begin** with one <think> … </think> block. "
            "If the dataset expects tool calls for that turn, the reply must also contain "
            "the same number of <tool_call> … </tool_call> blocks; otherwise it must contain "
            "no <tool_call> blocks."
        ),
)
max_gen_per_turn: int = Field(
default=1024,
description="Hard cap on how many new tokens the model may generate in a single turn"
)
wrong_call_penalty: float = Field(
default=-0.2,
description="Negative reward applied when the first mismatched tool-call causes early termination"
)
skip_completed: bool = Field(
default=True,
description="Skip any conversation whose first user prompt appears in COMPLETED_TASKS.",
)
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
"solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
"</think> tags, and then provide your solution or response to the problem."
)
# ------------------------------------------------------------------
# Helper: validate a GPT reply that *may* include tool calls
# ------------------------------------------------------------------
def _validate_think_only(txt: str) -> bool:
"""
A narration / summary turn must:
• start with exactly one <think> … </think> block
• contain **no** <tool_call> tags anywhere
    Anything after the </think> (the user-visible answer) is allowed.
"""
if not isinstance(txt, str):
return False
# Must begin with one think block
begins_with_think = re.match(
r"^\s*<think>[\s\S]*?</think>", txt, flags=re.IGNORECASE
)
if not begins_with_think:
return False
# Must not contain any <tool_call>
if re.search(r"<tool_call\s*>", txt, flags=re.IGNORECASE):
return False
return True
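# Example (sketch) of what _validate_think_only accepts and rejects:
#   _validate_think_only("<think>done</think>The weather is sunny.")      -> True
#   _validate_think_only("<think>x</think><tool_call>{...}</tool_call>")  -> False (tool call present)
#   _validate_think_only("No think block at all.")                        -> False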
def _validate_think_plus_calls(txt: str):
"""
Validate a GPT reply that should contain <think> … </think> followed by
one or more <tool_call> … </tool_call> blocks.
Returns:
        list[dict]  Parsed tool-call JSONs (≥1)     → valid
        None        Any structural / JSON error     → invalid
"""
pat = re.compile(
r"^\s*<think>[\s\S]*?</think>\s*((?:<tool_call>[\s\S]*?</tool_call>\s*)+)\s*$",
flags=re.IGNORECASE,
)
m = pat.match(txt)
if not m:
return None
calls_section = m.group(1)
tool_jsons = []
for raw in re.findall(
r"<tool_call>\s*(.*?)\s*</tool_call>", calls_section, flags=re.DOTALL | re.IGNORECASE
):
try:
tool_jsons.append(json.loads(raw))
except Exception:
return None
if not tool_jsons:
return None
return tool_jsons
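# Example (sketch, with a hypothetical tool) of a reply accepted by _validate_think_plus_calls:
# exactly one leading <think> block followed only by <tool_call> blocks containing valid JSON.
#   reply = (
#       "<think>I need the weather first.</think>\n"
#       '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
#   )
#   _validate_think_plus_calls(reply)
#   -> [{"name": "get_weather", "arguments": {"city": "Paris"}}]
# Any trailing prose after the last </tool_call>, or malformed JSON, yields None.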
def _json_objects_match(model_json, expected_json):
"""
True when every key/value in expected_json appears exactly in model_json.
Nested dicts handled recursively.
"""
if not isinstance(model_json, dict) or not isinstance(expected_json, dict):
return False
for k, v in expected_json.items():
if k not in model_json:
return False
if isinstance(v, dict):
if not _json_objects_match(model_json[k], v):
return False
else:
if model_json[k] != v:
return False
return True
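# Example (sketch): matching is a one-way subset check, applied recursively to nested
# dicts, so extra keys produced by the model are tolerated.
#   model    = {"name": "get_weather", "arguments": {"city": "Paris", "units": "c"}}
#   expected = {"name": "get_weather", "arguments": {"city": "Paris"}}
#   _json_objects_match(model, expected)  -> True
#   _json_objects_match(expected, model)  -> False ("units" missing from the first argument's "arguments")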
class MultiTurnToolCallingEnv(BaseEnv):
name = "multiturn_tool_use"
def __init__(
self,
config: MultiTurnEnvConfig,
server_configs: List[APIServerConfig],
slurm: bool = True,
testing: bool = False,
):
super().__init__(config, server_configs, slurm, testing)
# Load dataset once and cache on this instance
#self.ds = load_dataset("interstellarninja/salesforce_hermes_thinking", split="train")
self.ds = load_dataset("interstellarninja/toolace_hermes_sequential_tool_use", split="train")
self.percent_correct_buffer: List[float] = []
self.raw_score_buffer: List[float] = []
self.eval_metrics: List[Tuple[str, float]] = []
self.rollouts_for_wandb: List = []
# List of (messages_tuple, expected_calls_by_turn, inter_turns) triples
self.train_items: List[
Tuple[Tuple[frozenset, ...], List[List[str]], List[List[Dict[str, str]]]]
] = []
self.test_items: List[
Tuple[Tuple[frozenset, ...], List[List[str]], List[List[Dict[str, str]]]]
] = []
self.iter = 0
@classmethod
def config_init(cls) -> Tuple[MultiTurnEnvConfig, List[APIServerConfig]]:
env_cfg = MultiTurnEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=2000,
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 64,
inference_weight=1.0,
wandb_name="multiturn_tool_use",
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
# Override config defaults with experimental constants
wrong_call_penalty=WRONG_CALL_PENALTY,
max_gen_per_turn=MAX_GEN_PER_TURN,
max_tool_call_turns=MAX_TOOL_CALL_TURNS,
validate_think_blocks=VALIDATE_THINK_BLOCKS,
generate_all_gpt_turns=GENERATE_ALL_GPT_TURNS,
skip_completed=True,
)
server_cfgs = [
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_max_requests_at_once=32,
num_requests_for_eval=256,
)
]
return env_cfg, server_cfgs
async def setup(self):
ds = self.ds.shuffle()
counts = Counter()
for row in ds:
conv = row["conversations"]
num_turns = 0
for msg in conv:
if msg["from"] in ("gpt", "assistant") and re.search(
r"<tool_call>", msg["value"], re.IGNORECASE
):
num_turns += 1
counts[num_turns] += 1
print("Tool-call distribution (tool_calls_per_convo → examples):")
for k in sorted(counts):
            print(f"  {k:2d} → {counts[k]}")
split = ds.train_test_split(0.02)
split["train"] = split["train"].shuffle()
split["test"] = split["test"].shuffle()
self._prep_items(split["train"], is_train=True)
self._prep_items(split["test"], is_train=False)
random.shuffle(self.train_items)
random.shuffle(self.test_items)
if not self.train_items:
raise ValueError("No training items prepared: check dataset formatting.")
def _prep_items(self, dataset, *, is_train: bool):
"""
For each conversation, collect all function_calls as a single episode.
The context is all messages up to (but not including) the first function_call;
the answer is the list of function_call JSONs (canonical string).
Each turn can have multiple tool calls.
        We only keep samples that contain exactly config.max_tool_call_turns separate messages with <tool_call>.
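        Each prepared item is a triple (sketch of the structure appended below)::

            (
                (frozenset({("role", "system"), ("content", ...)}), ...),  # context up to the first tool-call turn
                [['{"name":"f","arguments":{...}}', ...], ...],            # expected calls, grouped per tool-call turn
                [[{"role": "tool", "content": ...}, ...], ...],            # messages between consecutive tool-call turns
            )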
"""
target = self.train_items if is_train else self.test_items
before_len = len(target)
for row in dataset:
running_msgs: List[frozenset] = []
conv = row["conversations"]
# --- Skip conversation if its task already completed ---
if self.config.skip_completed and COMPLETED_TASKS:
first_user_msg = next(
(m["value"].strip() for m in conv if m["from"] == "human"),
None,
)
if first_user_msg and first_user_msg in COMPLETED_TASKS:
continue
if len(conv) < 3:
continue
if conv[0]["from"] != "system" or conv[1]["from"] != "human":
continue
# Check if conversation has ANY tool calling turns
has_tool_calls = any(
msg["from"] in ("gpt", "assistant") and "<tool_call>" in msg["value"].lower()
for msg in conv
)
if not has_tool_calls:
continue
# Optional: Validate <think> blocks in gpt messages if enabled
if self.config.validate_think_blocks:
gpt_messages = [msg for msg in conv if msg["from"] in ("gpt", "assistant")]
if not all("<think>" in msg["value"].lower() for msg in gpt_messages):
continue
if conv and conv[0]["from"] == "system":
combined_system = system_prompt + "\n\n" + conv[0]["value"]
running_msgs.append(
frozenset({"role": "system", "content": combined_system}.items())
)
conv = conv[1:]
else:
running_msgs.append(
frozenset({"role": "system", "content": system_prompt}.items())
)
inter_turns: List[List[Dict[str, str]]] = []
expected_calls_by_turn: List[List[str]] = []
buffer: List[Dict[str, str]] = []
tool_call_turns = 0
for msg in conv:
m_from, m_val = msg["from"], msg["value"]
is_tool_call = (
m_from in ("gpt", "assistant")
and "<tool_call>" in m_val.lower()
)
if is_tool_call:
tool_call_turns += 1
if expected_calls_by_turn: # If we have previous turns, save the buffer
inter_turns.append(buffer)
buffer = []
matches = re.findall(
r"<tool_call>\s*(.*?)\s*</tool_call>",
m_val,
re.DOTALL | re.IGNORECASE,
)
if not matches:
continue
# Group all tool calls from this message as one turn
turn_calls = []
for raw in matches:
try:
obj = json.loads(raw)
turn_calls.append(json.dumps(obj, separators=(",", ":")))
except Exception:
turn_calls.append(raw)
expected_calls_by_turn.append(turn_calls)
continue
elif m_from in ("human", "gpt", "assistant"):
role = "user" if m_from == "human" else "assistant"
if not expected_calls_by_turn:
running_msgs.append(
frozenset({"role": role, "content": m_val}.items())
)
else:
buffer.append({"role": role, "content": m_val})
elif m_from == "tool":
if not expected_calls_by_turn:
running_msgs.append(
frozenset({"role": "tool", "content": m_val}.items())
)
else:
buffer.append({"role": "tool", "content": m_val})
if buffer and expected_calls_by_turn:
inter_turns.append(buffer)
while len(inter_turns) < max(0, len(expected_calls_by_turn) - 1):
inter_turns.append([])
if tool_call_turns == self.config.max_tool_call_turns:
target.append((tuple(running_msgs), expected_calls_by_turn, inter_turns))
print(f"[prep_items] {'train' if is_train else 'test'}: added {len(target)-before_len} items.")
@staticmethod
def _score_episode(pred_calls: list, exp_calls: list, lam: float = 0.5, wrong_call_penalty: float = -0.2) -> float:
"""
pred_calls : list of JSON objects (already parsed)
exp_calls : list of *canonical* JSON strings from dataset
Returns dense + sparse reward:
r = (#correct / N) + lam * 1{all correct} + penalty (if mismatch)
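        Worked example (sketch): with 3 expected calls, 2 matched, then a mismatch that
        ended the rollout early (lam = 0.5, wrong_call_penalty = -0.2):
        r = 2/3 + 0 - 0.2 ≈ 0.47; a fully correct episode scores 1.0 + 0.5 = 1.5.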
"""
exp_jsons: List[dict] = []
for raw in exp_calls:
try:
exp_jsons.append(json.loads(raw))
except json.JSONDecodeError:
exp_jsons.append(ast.literal_eval(raw))
mismatch_penalty = 0.0
if pred_calls and pred_calls[-1] == "__MISMATCH__":
pred_calls = pred_calls[:-1]
mismatch_penalty = wrong_call_penalty
correct = sum(
1 for p, e in zip(pred_calls, exp_jsons) if _json_objects_match(p, e)
)
dense = correct / max(1, len(exp_jsons))
bonus = 1.0 if correct == len(exp_jsons) else 0.0
return dense + lam * bonus + mismatch_penalty
async def rollout_and_score_eval(self, item) -> float:
messages_tuple, expected_calls_by_turn, inter_turns = item
base_ctx = [dict(m) for m in messages_tuple]
ctx = list(base_ctx)
preds = []
# Iterate through turns instead of individual calls
for turn_idx, expected_turn_calls in enumerate(expected_calls_by_turn):
if turn_idx > 0 and turn_idx - 1 < len(inter_turns):
ctx.extend(inter_turns[turn_idx - 1])
prompt = self.tokenizer.apply_chat_template(ctx, add_generation_prompt=True, tokenize=False)
            # Character-based approximation of the remaining generation budget
            max_toks = max(1, self.config.max_token_length - len(prompt))
            comp = await self.server.completion(
                prompt=prompt, n=1, max_tokens=max_toks, temperature=0.0, split="eval"
            )
reply = comp.choices[0].text
ctx.append({"role": "assistant", "content": reply})
            tool_jsons = _validate_think_plus_calls(reply)
if tool_jsons is None:
break
preds.extend(tool_jsons)
# Check if we've processed enough turns
if turn_idx >= len(expected_calls_by_turn) - 1:
break
# Flatten expected calls for scoring
expected_calls_flat = [call for turn_calls in expected_calls_by_turn for call in turn_calls]
score = self._score_episode(preds, expected_calls_flat, wrong_call_penalty=self.config.wrong_call_penalty)
return score
async def evaluate(self, *_, **__):
subset = self.test_items[: min(128, len(self.test_items))]
scores = await tqdm_asyncio.gather(*[self.rollout_and_score_eval(it) for it in subset])
avg_reward = sum(scores) / len(scores)
pct_exact = sum(1 for s in scores if s >= 1.0) / len(scores)
self.eval_metrics.append(("eval/avg_reward", avg_reward))
self.eval_metrics.append(("eval/percent_correct", pct_exact))
async def get_next_item(self):
"""
        Return the next training item in a strictly sequential (non-wrapping) order.
Once we've gone through all items, reshuffle and start over.
"""
if not self.train_items:
            raise ValueError("train_items is empty; dataset preparation failed.")
if self.iter >= len(self.train_items):
random.shuffle(self.train_items)
self.iter = 0
itm = self.train_items[self.iter]
self.iter += 1
return itm
async def _build_turn_contexts(self, turn_idx: int, contexts: List[List[Dict[str, str]]],
inter_turns: List[List[Dict[str, str]]], active: List[bool]) -> Tuple[List[str], List[int]]:
"""Build prompts for the current turn from active rollout contexts."""
# Add inter-turn context if not the first turn
if turn_idx > 0 and turn_idx - 1 < len(inter_turns):
filler = inter_turns[turn_idx - 1]
for r in range(len(contexts)):
if active[r]:
contexts[r].extend(filler)
# Build prompts for active rollouts
prompts, ridx_map = [], []
for r in range(len(contexts)):
if not active[r]:
continue
ptxt = self.tokenizer.apply_chat_template(
contexts[r],
add_generation_prompt=True,
tokenize=False,
)
prompts.append(ptxt)
ridx_map.append(r)
return prompts, ridx_map
    async def _execute_turn_inference(self, turn_idx: int, prompts: List[str], ridx_map: List[int],
                                      max_tokens: Optional[int] = None) -> List[str]:
        """Execute inference for a turn using the optimal batching strategy."""
        if turn_idx == 0:
            # Turn 1: all prompts are identical, so use the n parameter on a single request
            return await self._batch_identical_prompts(prompts[0], len(ridx_map), turn_idx, max_tokens)
        else:
            # Later turns: prompts are heterogeneous, so issue parallel requests
            return await self._batch_heterogeneous_prompts(prompts, turn_idx, max_tokens)
    async def _batch_identical_prompts(self, prompt: str, count: int, turn_idx: int,
                                       max_tokens: Optional[int] = None) -> List[str]:
"""Handle identical prompts efficiently using n parameter."""
print(f" \033[93m→ TURN {turn_idx+1} prompt full:\033[0m \033[92m{prompt}\033[0m")
resp = await self.server.completion(
prompt=prompt,
n=count,
            max_tokens=max_tokens or self.config.max_token_length,
temperature=0.8,
)
choices = [c.text for c in resp.choices]
# Debug: print each rollout
for i, raw in enumerate(choices):
print(f" \033[93m· turn {turn_idx+1} rollout raw [{i}]:\033[0m \033[94m{raw}\033[0m")
if not raw.strip():
print(f" → (empty or error string returned for rollout {i})")
print(" → All turn 1 rollouts printed; moving on.\n" + "-"*48)
return choices
    async def _batch_heterogeneous_prompts(self, prompts: List[str], turn_idx: int,
                                           max_tokens: Optional[int] = None) -> List[str]:
"""Handle heterogeneous prompts using parallel requests."""
if turn_idx == 1:
print("=== DEBUG: Now parallelizing Turn 2 prompts ===")
print(f" → Parallelizing {len(prompts)} prompts at turn {turn_idx+1}")
# Print each prompt
for idx_p, p_str in enumerate(prompts):
print(f" \033[93m→ TURN-{turn_idx+1} prompt[{idx_p}] full:\033[0m \033[92m{p_str}\033[0m")
async def _call_single(prompt_str: str) -> str:
try:
comp = await self.server.completion(
prompt=prompt_str,
n=1,
                    max_tokens=max_tokens or self.config.max_token_length,
temperature=0.8,
)
return comp.choices[0].text
except Exception as e:
print(f" → Turn {turn_idx+1} _call_single exception: {e}")
return ""
tasks = [_call_single(p) for p in prompts]
results = await asyncio.gather(*tasks)
# Debug: print results for all turns
choices = []
for i, rtext in enumerate(results):
raw = rtext or ""
print(f" \033[93m· rollout {i} (Turn {turn_idx+1}) full reply:\033[0m \033[94m{raw}\033[0m\n" + "-"*48)
if not raw:
print(f" → Rollout {i} returned empty or error string")
choices.append(raw)
return choices
async def _process_turn_responses(self, turn_idx: int, choices: List[str], ridx_map: List[int],
contexts: List[List[Dict[str, str]]], preds: List[List],
active: List[bool], expected_calls_by_turn: List[List[str]]) -> None:
"""Process and validate responses for a single turn."""
for txt, r in zip(choices, ridx_map):
txt = txt or ""
contexts[r].append({"role": "assistant", "content": txt})
expected_turn_calls = expected_calls_by_turn[turn_idx]
# ------------------------------------------------------------
            # Decide validation strategy: narration vs. tool-calling turn
# ------------------------------------------------------------
if expected_turn_calls: # Turn SHOULD have tool calls
calls = _validate_think_plus_calls(txt)
if calls is None:
preds[r].append("__MISMATCH__")
active[r] = False
continue
else: # Narration / summary turn
if not _validate_think_only(txt):
active[r] = False
# Narration turns produce no predictions to score
continue
# Check if number of calls matches
if len(calls) != len(expected_turn_calls):
preds[r].append("__MISMATCH__")
active[r] = False
continue
mismatch = False
for mdl, exp_raw in zip(calls, expected_turn_calls):
try:
exp_obj = json.loads(exp_raw)
except Exception:
exp_obj = ast.literal_eval(exp_raw)
if not _json_objects_match(mdl, exp_obj):
mismatch = True
break
if mismatch:
preds[r].append("__MISMATCH__")
active[r] = False
else:
preds[r].extend(calls)
async def collect_trajectories(
self,
item: Tuple[
Tuple[frozenset, ...],
List[List[str]],
List[List[Dict[str, str]]],
],
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
"""
Roll-out one *tool-call turn* for every rollout in the group.
─ Round 0 ────────────────────────────────────────────────────────────
All roll-outs share an *identical* prompt → send a **single** request
with `n = group_size`.
─ Later rounds ───────────────────────────────────────────────────────
Prompts are heterogeneous, so we always issue `group_size` independent
requests in parallel via ``asyncio.gather``.
"""
messages_tuple, expected_calls_by_turn, inter_turns = item
base_ctx = [dict(m) for m in messages_tuple]
num_rollouts = self.config.group_size
contexts: List[List[Dict[str, str]]] = [list(base_ctx) for _ in range(num_rollouts)]
preds: List[List] = [[] for _ in range(num_rollouts)]
active = [True] * num_rollouts
max_turns = min(len(expected_calls_by_turn), self.config.max_tool_call_turns)
for turn_idx in range(max_turns):
print(f"[collect_trajectories] Beginning turn {turn_idx+1}/{max_turns} for this group")
# Build contexts and prompts for this turn
prompts, ridx_map = await self._build_turn_contexts(turn_idx, contexts, inter_turns, active)
if not prompts:
break
            # Cap generation by max_gen_per_turn and the (character-approximated) remaining budget
            max_prompt_len = max(len(p) for p in prompts)
            max_gen = min(
                self.config.max_gen_per_turn,
                max(1, self.config.max_token_length - max_prompt_len),
            )
            # Execute inference for this turn
            choices = await self._execute_turn_inference(turn_idx, prompts, ridx_map, max_gen)
# Process and validate responses
await self._process_turn_responses(turn_idx, choices, ridx_map, contexts, preds, active, expected_calls_by_turn)
# ───────────────────────────────────────────────────────────────
# Optionally emit a GPT narration/summary turn after tool_response
# ───────────────────────────────────────────────────────────────
if self.config.generate_all_gpt_turns and any(active):
extra_prompts, extra_ridx = [], []
for r in range(len(contexts)):
if not active[r]:
continue
ptxt = self.tokenizer.apply_chat_template(
contexts[r],
add_generation_prompt=True,
tokenize=False,
)
extra_prompts.append(ptxt)
extra_ridx.append(r)
async def _infer_one(prompt_str: str) -> str:
try:
comp = await self.server.completion(
prompt=prompt_str,
n=1,
max_tokens=self.config.max_token_length,
temperature=0.7,
)
return comp.choices[0].text
except Exception as exc:
print(f" → extra GPT turn inference error: {exc}")
return ""
extra_replies = await asyncio.gather(*[_infer_one(p) for p in extra_prompts])
for txt, r in zip(extra_replies, extra_ridx):
txt = txt or ""
contexts[r].append({"role": "assistant", "content": txt})
                    # Narration turn MUST be think-only. If not, terminate rollout r.
if not _validate_think_only(txt):
active[r] = False
if not any(active):
print(" → All roll-outs terminated; stopping further turns.")
break
survivors = sum(1 for a in active if a)
print(f" → DEBUG: finished turn {turn_idx+1}; {survivors}/{num_rollouts} rollouts still active")
scored = ScoredDataGroup(tokens=[], masks=[], scores=[])
# Flatten expected calls for scoring (since _score_episode expects flat list)
expected_calls_flat = [call for turn_calls in expected_calls_by_turn for call in turn_calls]
for r in range(num_rollouts):
reward = self._score_episode(preds[r], expected_calls_flat, wrong_call_penalty=self.config.wrong_call_penalty)
out = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=contexts[r],
include_messages=self.config.include_messages,
)
scored["tokens"].append(out["tokens"])
scored["masks"].append(out["masks"])
scored["scores"].append(reward if reward > 0 else -1.0)
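        # Length penalty below (worked example, sketch): when every rollout scores > 0.99,
        # long generations are discounted. With max_token_length = 65536 (the value set in
        # config_init) the cutoff is 32768 tokens, so a 49152-token rollout loses
        # frac = (49152 - 32768) / 32768 = 0.5 from its score.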
if scored["scores"] and all(s > 0.99 for s in scored["scores"]):
cutoff = self.config.max_token_length * 0.5
for i, ln in enumerate([len(t) for t in scored["tokens"]]):
if ln > cutoff:
frac = min((ln - cutoff) / (self.config.max_token_length - cutoff), 1.0)
scored["scores"][i] = max(0.0, scored["scores"][i] - frac)
for s in scored["scores"]:
self.raw_score_buffer.append(s)
self.percent_correct_buffer.append(1.0 if s >= 1.0 else 0.0)
if len(scored["tokens"]) < self.config.group_size or scored["scores"].count(
scored["scores"][0]
) == len(scored["scores"]):
return None, []
await self.add_rollouts_for_wandb(scored, item)
return scored, []
async def create_rollout_table(self, wdict):
if self.rollouts_for_wandb:
table = wandb.Table(columns=["generation", "score", "expected_tool_call"])
for grp in self.rollouts_for_wandb:
for g, sc, exp in grp:
exp_str = json.dumps(exp, separators=(",", ":"))
table.add_data(g, sc, exp_str)
wdict["train/rollouts"] = table
self.rollouts_for_wandb = []
return wdict
async def add_rollouts_for_wandb(
self, scored: ScoredDataGroup, item: Item, *, num_keep: int = 1
):
num_keep = min(num_keep, len(scored["tokens"]))
# Flatten expected calls for wandb logging
expected_calls_flat = [call for turn_calls in item[1] for call in turn_calls]
self.rollouts_for_wandb.append(
[
(
self.tokenizer.decode(scored["tokens"][i]),
scored["scores"][i],
expected_calls_flat,
)
for i in range(num_keep)
]
)
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
async def wandb_log(self, metrics: Optional[Dict] = None):
metrics = metrics or {}
metrics = await self.create_rollout_table(metrics)
if self.raw_score_buffer:
avg_reward = sum(self.raw_score_buffer) / len(self.raw_score_buffer)
pct_correct = (
sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
)
metrics["train/avg_reward"] = avg_reward
metrics["train/percent_correct"] = pct_correct
self.raw_score_buffer.clear()
self.percent_correct_buffer.clear()
for k, v in self.eval_metrics:
metrics[k] = v
await super().wandb_log(metrics)
if __name__ == "__main__":
MultiTurnToolCallingEnv.cli()