""" Interleaved‑Thinking Single‑Block Environment ============================================ This environment lets a model emit *multiple* / pairs **inside one still‑open block**, then close and write the final answer – all within a single assistant turn. Unlike the first draft, this version is **stand‑alone**: it does **NOT** inherit from SingleToolCallingEnv. All required boiler‑plate from that class is copied here so nothing breaks when you swap env names. """ from __future__ import annotations import asyncio import json import logging import os import re from typing import Dict, List, Optional, Tuple, Union import aiohttp import httpx import wandb from datasets import Dataset, load_dataset from atroposlib.envs.base import ( APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, ScoredDataGroup, ) from atroposlib.type_definitions import Message from atroposlib.utils.io import parse_http_response from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer logger = logging.getLogger(__name__) # Set to True to always print debug information. DEBUG = True EXECUTION_FEEDBACK = True TOOL_USAGE_BONUS = 0.2 # Hard caps for generation length MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤1024 tokens MAX_GEN_PER_TURN = 512 # never request more than 512 new tokens from the model # Maximum number of thinking/tool-use turns per rollout MAX_ROLLOUT_TURNS = 3 system_prompt = ( "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " "problem and deliberate with yourself via systematic reasoning processes to help come to a correct " "solution prior to answering. You should enclose your thoughts and internal monologue inside " " tags, and then provide your solution or response to the problem." ) TOOL_SYSTEM_PROMPT = ( "You are a function-calling & reasoning AI model. You are provided with " "function signatures inside … XML tags. After calling & " "executing the functions, you will get results inside … " "Here are the available tools:\n\n" "\n" "[\n" " {\n" ' "type": "function",\n' ' "function": {\n' ' "name": "calculator",\n' ' "description": ' ' "Evaluate a numeric Python expression and return the result.",\n' ' "parameters": {\n' ' "type": "object",\n' ' "properties": {\n' ' "expr": {\n' ' "type": "string",\n' ' "description": ' ' "A pure-Python arithmetic expression\'"\n' " }\n" " },\n" ' "required": ["expr"]\n' " }\n" " }\n" " },\n" " {\n" ' "type": "function",\n' ' "function": {\n' ' "name": "python_interpreter",\n' ' "description": ' ' "Run a short Python snippet and return stdout plus the last ' 'expression.",\n' ' "parameters": {\n' ' "type": "object",\n' ' "properties": {\n' ' "code": {\n' ' "type": "string",\n' ' "description": "Python source code to execute."\n' " }\n" " },\n" ' "required": ["code"]\n' " }\n" " }\n" " }\n" "]\n" "\n\n" "You must use reasoning tools such as python_interpreter as a tool call when available " "for hard problems such as math before providing your final answer.\n" "Always provide your final numeric answer (or final result) in \\\\boxed{...} so it " "can be automatically graded right after closing tag.\n\n" "For reasoning tools, return interleaved tool calls within tags.\n" "\n" "{'name': , 'arguments': }\n" "\n" "{'result': }\n" "\n" "\n" "\n" ) SYSTEM_PROMPT = system_prompt + TOOL_SYSTEM_PROMPT class InterleavedInlineEnv(BaseEnv): """ One episode = user prompt → single assistant message with inline tool calls inside a still‑open block. """ name = "interleaved_inline" _re_last_call = re.compile(r"\s*(.*?)\s*\s*$", re.S) def __init__( self, config: BaseEnvConfig, server_configs: List[APIServerConfig], slurm: bool = True, testing: bool = False, ): super().__init__(config, server_configs, slurm, testing) self.percent_correct_buffer: List[float] = [] self.eval_metrics: List[Tuple[str, float]] = [] self.rollouts_for_wandb = [] self.iter = 0 import random self.rng = random.Random() # Dynamic few‑shot pool: list of (user_msg, assistant_msg) tuples self.dynamic_pool: List[Tuple[Dict, Dict]] = [] self.dynamic_pool_max = 0 # keep at most 4 real examples self.max_token_len = 8192 async def get_server_info(self): """Override to prevent server from overwriting our max_token_len""" async with aiohttp.ClientSession() as session: async with session.get(f"{self.config.rollout_server_url}/info") as resp: data = await parse_http_response(resp, logger) if data["batch_size"] != -1: self.config.batch_size = data["batch_size"] # Log what the server tried to set max_token_len to if data["max_token_len"] != -1: logger.info( f"Server tried to set max_token_len to {data['max_token_len']}\n" f"keeping our value of {self.max_token_len}" ) if self.config.batch_size == -1: logging.warning("Batch size not set by config or server!") if self.config.group_size > self.config.batch_size: raise ValueError( f"group_size ({self.config.group_size}) " f"must be less than batch_size ({self.config.batch_size})" ) @classmethod def config_init(cls): cfg = BaseEnvConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=16, use_wandb=True, rollout_server_url="http://localhost:8000", total_steps=2000, batch_size=1024, steps_per_eval=20, max_token_length=16 * 8192, inference_weight=1.0, wandb_name="toolcall_interleaved", eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, max_gen_per_turn=MAX_GEN_PER_TURN, ) servers = [ APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", api_key="x", num_max_requests_at_once=32, num_requests_for_eval=256, ) ] return cfg, servers async def setup(self): """ Load a streamed subset of **nvidia/AceReason-Math**. We keep only rows whose *answer* looks purely numeric so the calculator / python_interpreter tools can verify them automatically. The env‑var SUBSET_ROWS (default 1000) controls how many rows we keep. """ N = int(os.getenv("SUBSET_ROWS", "1000")) stream_ds = load_dataset( # ≈50 k rows total → stream # "NVIDIA/OpenMathReasoning", # split="cot", # "NVIDIA/OpenMathReasoning", # split="cot", # "open-r1/OpenR1-Math-220k", "nvidia/AceReason-Math", split="train", streaming=True, ) _numeric = re.compile(r"^[0-9+\-*/(). %√\\\\sqrt{}]+$").fullmatch subset = [] for ex in stream_ds: if len(subset) >= N: break # some datasets use "answer", others "expected_answer" ans_raw = ex.get("answer", ex.get("expected_answer")) if ans_raw is None: continue ans = str(ans_raw).strip() if _numeric(ans): subset.append( { "problem": ex["problem"], "expected_answer": ans, } ) full = Dataset.from_list(subset) if DEBUG: print(f"[DEBUG setup] kept {len(subset)} rows from Dataset") split = full.train_test_split(test_size=0.02, seed=42) split = full.train_test_split(test_size=0.02, seed=42) self.train, self.test = split["train"], split["test"] self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big")) async def _completion_until( self, prompt: str, max_tokens: int, stop: Optional[Union[str, List[str]]] = None ) -> str: comp = await self.server.completion( prompt=prompt, stop=stop, max_tokens=max_tokens, temperature=0.8, ) return comp.choices[0].text def _extract_last_call(self, chunk: str): """ Return the JSON dict for the *last* block in `chunk`, or **None** if no such block exists. Also handles incomplete tool calls (missing tag). """ # First try to find complete tool calls matches = self._re_last_call.findall(chunk) if matches: try: return json.loads(matches[-1]) except Exception: pass # If no complete tool calls, look for incomplete ones (missing ) last_tool_call_pos = chunk.rfind("") if last_tool_call_pos != -1: json_start = last_tool_call_pos + len("") json_text = chunk[json_start:].strip() try: return json.loads(json_text) except json.JSONDecodeError: # Try partial JSON extraction brace_count = 0 json_end = 0 for i, char in enumerate(json_text): if char == "{": brace_count += 1 elif char == "}": brace_count -= 1 if brace_count == 0: json_end = i + 1 break if json_end > 0: try: return json.loads(json_text[:json_end]) except json.JSONDecodeError: pass return None def _is_new_tool_call(self, raw: str) -> bool: """ Return True if there's an unresponded in raw (i.e., open call without matching ). """ pos = raw.rfind("") if pos == -1: return False return "" not in raw[pos:] @staticmethod def _canon_num(txt: str) -> str: """Return number string without commas / spaces; keep leading sign.""" return txt.strip().replace(",", "").replace(" ", "") # boxed{answer} pattern for final numeric result _re_box = re.compile(r"\\boxed\{([^}]*)\}") def _boxed_after_think(self, text: str) -> Optional[str]: """ Return the first \\boxed{…} that appears *after* the closing tag. Returns None if is missing or no boxed answer exists. """ think_pos = text.find("") if think_pos == -1: return None m = self._re_box.search(text, pos=think_pos) return m.group(1).strip() if m else None async def _exec_tool(self, call_json: Dict): """ Execute reasoning‑time tools. • python_interpreter → POST code to the local coding server running at localhost:5002/execute and return {"stdout":..., "result":...} • calculator → eval(expr) in a math‑only sandbox and return the number. """ name = call_json["name"] args = call_json["arguments"] if name == "python_interpreter": try: async with httpx.AsyncClient(timeout=10.0) as client: payload = {"code": args["code"], "input": ""} resp = await client.post( "http://localhost:5002/execute", json=payload ) data = resp.json() except httpx.ConnectError: print( "❌ [CRITICAL] Python interpreter server not available at localhost:5002" ) print("Please ensure the code_exec_server Docker container is running") raise RuntimeError( "Python interpreter server not available - cannot continue without verification" ) if DEBUG: print(f"[DEBUG _exec_tool] {name} result → {data}") return { "stdout": data.get("output", ""), "result": data.get("output", "").strip(), } elif name == "calculator": import math expr = args["expr"] val = eval(expr, {"__builtins__": {}}, {"math": math}) if DEBUG: print(f"[DEBUG _exec_tool] {name} result → {val}") return {"value": val} else: raise ValueError(f"Unknown tool name {name}") async def _execute_turn_inference( self, turn_idx: int, prompts: List[str], ridx_map: List[int], expected_calls_by_turn: List[List[str]], ) -> List[str]: """Execute inference for a turn using optimal batching strategy.""" print(f"\n\033[95m=== Expected Tool Calls for Turn {turn_idx+1} ===\033[0m") print(f"\033[95m{expected_calls_by_turn[turn_idx]}\033[0m\n") # Always use batched identical prompts for turn 0, heterogeneous for others if turn_idx == 0: choices = await self._batch_identical_prompts( prompts[0], len(ridx_map), turn_idx ) else: choices = await self._batch_heterogeneous_prompts(prompts, turn_idx) return choices async def _batch_identical_prompts( self, prompt: str, count: int, turn_idx: int ) -> List[str]: """Handle identical prompts efficiently using n parameter.""" print( f" \033[93m→ TURN {turn_idx+1} prompt full:\033[0m " f"\033[92m{prompt}\033[0m" ) # Use the constant instead of config attribute resp = await self.server.completion( prompt=prompt, n=count, max_tokens=MAX_GEN_PER_TURN, temperature=0.8, stop="", ) choices = [c.text for c in resp.choices] # Debug: print each rollout for i, raw in enumerate(choices): print( f" \033[93m· turn {turn_idx+1} rollout raw [{i}]:\033[0m \033[94m{raw}\033[0m" ) if not raw.strip(): print(f" → (empty or error string returned for rollout {i})") print(" → All turn 1 rollouts printed; moving on.\n" + "-" * 48) return choices async def _batch_heterogeneous_prompts( self, prompts: List[str], turn_idx: int ) -> List[str]: """Handle heterogeneous prompts using parallel requests.""" if turn_idx == 1: print("=== DEBUG: Now parallelizing Turn 2 prompts ===") print(f" → Parallelizing {len(prompts)} prompts at turn {turn_idx+1}") # Print each prompt for idx_p, p_str in enumerate(prompts): print( f" \033[93m→ TURN-{turn_idx+1} prompt[{idx_p}] full:\033[0m \033[92m{p_str}\033[0m" ) async def _call_single(prompt_str: str) -> str: try: # Use the constant instead of config attribute comp = await self.server.completion( prompt=prompt_str, n=1, max_tokens=MAX_GEN_PER_TURN, temperature=0.8, stop="", ) return comp.choices[0].text except Exception as e: print(f" → Turn {turn_idx+1} _call_single exception: {e}") return "" tasks = [_call_single(p) for p in prompts] results = await asyncio.gather(*tasks) # Debug: print results for all turns choices = [] for i, rtext in enumerate(results): raw = rtext or "" print( f" \033[93m· rollout {i} (Turn {turn_idx+1}) full reply:\033[0m \033[94m{raw}\033[0m\n" + "-" * 48 ) if not raw: print(f" → Rollout {i} returned empty or error string") choices.append(raw) return choices def _json_objects_match(self, j1, j2): try: for k in j2: if k not in j1: return False if isinstance(j2[k], dict): if not self._json_objects_match(j1[k], j2[k]): return False elif j1[k] != j2[k]: return False return True except Exception: return False async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]: """ One prompt → `n = group_size` sampled assistant completions in parallel (single OpenAI request with n completions). Mirrors the logic in SingleToolCallingEnv. """ messages_tuple, expected_raw = item expected = ( json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw ) # Re‑inflate frozensets to normal dicts prompt_msgs = [dict(r) for r in messages_tuple] if EXECUTION_FEEDBACK: # MODE: Real interleaved tool execution return await self._collect_trajectories_with_execution( prompt_msgs, expected ) else: # MODE: Static generation for data collection (current behavior) return await self._collect_trajectories_static(prompt_msgs, expected) async def _collect_trajectories_static( self, prompt_msgs: List[Dict], expected ) -> Tuple[ScoredDataGroup, List]: """ Original static generation mode - no tool execution, just data collection. """ # Convert to text prompt prompt_txt = self.tokenizer.apply_chat_template( prompt_msgs, add_generation_prompt=True, tokenize=False ) if DEBUG: clean_prompt = prompt_txt.replace("<|eot_id|>", "") print( f"\n\033[93m▶ BATCH PROMPT (tokens {len(prompt_txt)}):\033[0m " f"\033[92m{clean_prompt}\033[0m\n{'-'*60}" ) # One API call → many completions completions = await self.server.completion( prompt=prompt_txt, n=self.config.group_size, max_tokens=MAX_GEN_PER_TURN, temperature=0.8, ) scored: ScoredDataGroup = { "tokens": [], "masks": [], "scores": [], "advantages": None, "ref_logprobs": None, "messages": None, "group_overrides": {}, "overrides": None, "images": None, } for idx, choice in enumerate(completions.choices): raw = choice.text or "" toks = self.tokenizer.encode(raw) if len(toks) > MAX_REPLY_TOKENS: toks = toks[:MAX_REPLY_TOKENS] raw = self.tokenizer.decode(toks) raw = self.tokenizer.decode(toks) if DEBUG: print(f"[DEBUG] truncated reply {idx} to {len(toks)} tokens") assistant_msg = {"role": "assistant", "content": raw} # Create the full context for tokenization - cast to Message type full_ctx: List[Message] = prompt_msgs + [assistant_msg] # Outcome‑based reward: compare boxed answer to expected expr expr = ( expected["arguments"]["code"][6:-1] if ( isinstance(expected, dict) and "arguments" in expected and "code" in expected["arguments"] and expected["arguments"]["code"].startswith("print(") and expected["arguments"]["code"].endswith(")") ) else None ) expr = ( expected["arguments"]["code"][6:-1] if ( isinstance(expected, dict) and "arguments" in expected and "code" in expected["arguments"] and expected["arguments"]["code"].startswith("print(") and expected["arguments"]["code"].endswith(")") ) else None ) boxed = self._boxed_after_think(raw) same = boxed == expr or ( boxed and expr and self._canon_num(boxed) == self._canon_num(expr) ) reward = 1.0 if same else -1.0 if "" not in raw: reward = -1.0 # invalid – did not close think block else: # no tool_call tags are allowed *outside* the think block end_pos = raw.lower().find("") if "") :].lower(): if DEBUG: print( "[DEBUG] tool_call found outside ; setting reward = -1" ) reward = -1.0 if DEBUG: print( f"\033[95m--- COMPLETION {idx+1}/{self.config.group_size} ---\033[0m\n" f"\033[94m{raw}\033[0m\nreward={reward}\n{'='*60}" ) tok = tokenize_for_trainer(self.tokenizer, full_ctx) scored["tokens"].append(tok["tokens"]) scored["masks"].append(tok["masks"]) scored["scores"].append(reward) self.percent_correct_buffer.append(max(reward, 0)) # --- harvest a success for dynamic few‑shots -------------------- for idx, sc in enumerate(scored["scores"]): reply_txt = completions.choices[idx].text has_call = "= 1.0 and has_call: # Build (user, assistant) pair from this successful rollout u = {"role": "user", "content": prompt_msgs[-1]["content"]} a = {"role": "assistant", "content": reply_txt} self.dynamic_pool.append((u, a)) if len(self.dynamic_pool) > self.dynamic_pool_max: self.dynamic_pool.pop(0) break return scored, [] async def _collect_trajectories_with_execution( self, prompt_msgs: List[Dict], expected ) -> Tuple[ScoredDataGroup, List]: """ Real interleaved tool execution mode - stops at tool calls, executes them, and continues. Uses turn-based parallel execution for maximum efficiency. """ print( f"\n🚀 [EXECUTION MODE] Running {self.config.group_size} rollouts with parallel turn-based execution" ) scored: ScoredDataGroup = { "tokens": [], "masks": [], "scores": [], "advantages": None, "ref_logprobs": None, "messages": None, "group_overrides": {}, "overrides": None, "images": None, } # Initialize per-rollout state num_rollouts = self.config.group_size rollout_ctxs = [prompt_msgs.copy() for _ in range(num_rollouts)] assistant_msgs = [ {"role": "assistant", "content": ""} for _ in range(num_rollouts) ] done = [False] * num_rollouts final_results = [None] * num_rollouts executed_tools = [[] for _ in range(num_rollouts)] # Track the most recent generation chunk for each rollout last_turns: List[str] = [""] * num_rollouts turn_idx = 0 max_turns = MAX_ROLLOUT_TURNS while not all(done) and turn_idx < max_turns: print( f"\n[TURN {turn_idx + 1}] Processing {sum(1 for d in done if not d)} active rollouts" ) # Build prompts for active rollouts only active_prompts = [] active_indices = [] for i in range(num_rollouts): if not done[i]: prompt_txt = self.tokenizer.apply_chat_template( rollout_ctxs[i], add_generation_prompt=True, tokenize=False, ) prompt_txt += assistant_msgs[i]["content"] active_prompts.append(prompt_txt) active_indices.append(i) if not active_prompts: break # Execute inference for this turn if turn_idx == 0: # First turn: all prompts are identical, use batched inference print( f"[TURN {turn_idx + 1}] Batching {len(active_prompts)} identical prompts" ) replies = await self._batch_identical_prompts( active_prompts[0], len(active_prompts), turn_idx ) else: # Subsequent turns: prompts may be heterogeneous, use parallel inference print( f"⚡ [TURN {turn_idx + 1}] Parallelizing {len(active_prompts)} heterogeneous prompts" ) replies = await self._batch_heterogeneous_prompts( active_prompts, turn_idx ) # Process each active rollout's reply for prompt_idx, rollout_idx in enumerate(active_indices): if done[rollout_idx]: continue reply = replies[prompt_idx] # Save this turn's delta for summary last_turns[rollout_idx] = reply assistant_msgs[rollout_idx]["content"] += reply raw = assistant_msgs[rollout_idx]["content"] if "" in raw: # Think block closed boxed = self._boxed_after_think(raw) if boxed: # Boxed answer found after print( f"🎯 [ROLLOUT {rollout_idx}] Found boxed answer after - marking complete" ) done[rollout_idx] = True rollout_ctxs[rollout_idx].append(assistant_msgs[rollout_idx]) final_results[rollout_idx] = raw continue else: # Think block closed but no boxed answer print( f"❌ [ROLLOUT {rollout_idx}] closed but no boxed answer - marking failed" ) done[rollout_idx] = True final_results[rollout_idx] = raw continue else: # Think block not closed if self._is_new_tool_call(raw): # Tool call present, continue to next turn after executing tool print( f"🔧 [ROLLOUT {rollout_idx}] Tool call detected - extracting and executing" ) call_json = self._extract_last_call(raw) if call_json is None: print( f"❌ [ROLLOUT {rollout_idx}] Failed to parse tool call JSON - marking inactive" ) done[rollout_idx] = True final_results[rollout_idx] = raw continue print( f"🔧 [ROLLOUT {rollout_idx}] Executing {call_json['name']}\n" f"with args: {call_json['arguments']}" ) try: result = await self._exec_tool(call_json) executed_tools[rollout_idx].append(call_json) print(f"✅ [ROLLOUT {rollout_idx}] Tool result: {result}") # Clean up any malformed/partial closing tags before appending content = assistant_msgs[rollout_idx]["content"] content = re.sub( r"{json.dumps(result)}\n" ) print( f"📝 [ROLLOUT {rollout_idx}] Added tool response to context" ) continue except Exception as e: print( f"❌ [ROLLOUT {rollout_idx}] Tool execution failed: {e}" ) done[rollout_idx] = True final_results[rollout_idx] = raw continue else: # No new tool call or boxed answer yet if turn_idx + 1 < max_turns: print( f"🔄 [ROLLOUT {rollout_idx}] Still thinking—continuing to next turn" ) continue # max turns reached, fail print( f"⚠️ [ROLLOUT {rollout_idx}] Max turns reached without completion—marking failed" ) done[rollout_idx] = True final_results[rollout_idx] = raw continue turn_idx += 1 # Process final results and score print(f"\n🏁 [EXECUTION COMPLETE] Processed {turn_idx} turns") # -- Summary of all rollouts before scoring -- expr = None if ( isinstance(expected, dict) and "arguments" in expected and "code" in expected["arguments"] ): code_str = expected["arguments"]["code"] if code_str.startswith("print(") and code_str.endswith(")"): expr = code_str[6:-1] print("\n\033[96m🔎 Final rollout results:\033[0m") any_success = False for i in range(num_rollouts): # Get full text and boxed value raw_full = ( final_results[i] if final_results[i] is not None else assistant_msgs[i]["content"] ) boxed_val = self._boxed_after_think(raw_full) # Determine correctness against expected is_correct = False if expr is not None and boxed_val is not None: is_correct = boxed_val == expr or self._canon_num( boxed_val ) == self._canon_num(expr) # Choose color and label if is_correct: label = "CORRECT" lbl_color = "\033[92m" any_success = True elif boxed_val is not None: label = "WRONG" lbl_color = "\033[93m" else: label = "NO_BOX" lbl_color = "\033[91m" reset = "\033[0m" last = last_turns[i] print() print(f"\033[93m--- ROLLOUT {i} ---\033[0m") print(f"Result: {lbl_color}{label}{reset}") # Last turn content print("Last turn output:") print(f"\033[96m{last}\033[0m") # Boxed vs expected print(f"Boxed answer: {boxed_val!r}") print(f"Expected answer: {expr!r}") if not any_success: print( f"⚠️ All {num_rollouts} rollouts failed to produce a boxed answer. Invalidating group." ) return None, [] # -- End summary -- for rollout_idx in range(num_rollouts): try: raw = ( final_results[rollout_idx] if final_results[rollout_idx] is not None else assistant_msgs[rollout_idx]["content"] ) toks = self.tokenizer.encode(raw) if len(toks) > MAX_REPLY_TOKENS: toks = toks[:MAX_REPLY_TOKENS] raw = self.tokenizer.decode(toks) final_assistant_msg = {"role": "assistant", "content": raw} full_ctx: List[Message] = prompt_msgs + [final_assistant_msg] expr = ( expected["arguments"]["code"][6:-1] if ( isinstance(expected, dict) and "arguments" in expected and "code" in expected["arguments"] and expected["arguments"]["code"].startswith("print(") and expected["arguments"]["code"].endswith(")") ) else None ) boxed = self._boxed_after_think(raw) same = boxed == expr or ( boxed and expr and self._canon_num(boxed) == self._canon_num(expr) ) reward = 1.0 if same else -1.0 if "" not in raw: reward = -1.0 else: end_pos = raw.lower().find("") # Check for tool calls or responses after if ( "") :].lower() or "") :].lower() ): reward = -1.0 # Add bonus for tool usage if the completion was successful elif reward > 0 and len(executed_tools[rollout_idx]) > 0: print( f"🌟 [ROLLOUT {rollout_idx}] Adding tool usage bonus (+{TOOL_USAGE_BONUS})" ) reward += TOOL_USAGE_BONUS tok = tokenize_for_trainer(self.tokenizer, full_ctx) scored["tokens"].append(tok["tokens"]) scored["masks"].append(tok["masks"]) scored["scores"].append(reward) self.percent_correct_buffer.append(max(reward, 0)) # Add successful completions to dynamic pool regardless of number of turns if ( reward >= 1.0 ): # This will now include both 1.0 and 1.0 + TOOL_USAGE_BONUS u = {"role": "user", "content": prompt_msgs[-1]["content"]} a = {"role": "assistant", "content": raw} self.dynamic_pool.append((u, a)) if len(self.dynamic_pool) > self.dynamic_pool_max: self.dynamic_pool.pop(0) except Exception: scored["tokens"].append([]) scored["masks"].append([]) scored["scores"].append(-1.0) self.percent_correct_buffer.append(0.0) print( "\n🏁 [EXECUTION MODE] Completed all rollouts. Average reward: \n" f"{sum(scored['scores'])/len(scored['scores']):.3f}" ) # -- Per-rollout score summary -- print("\n\033[96m📊 Rollout score summary:\033[0m") reset = "\033[0m" for i, score in enumerate(scored["scores"]): color = "\033[92m" if score > 0 else "\033[91m" print(f" \033[93m[ROLLOUT {i}]\033[0m Score: {color}{score}{reset}") # Add warning if all rollouts failed if all(score < 0 for score in scored["scores"]): print( f"⚠️ [WARNING] All {len(scored['scores'])} rollouts failed with negative rewards!" ) print( " This may indicate a problem with the model, prompt, or token budget." ) # Signal failure to the outer loop return None, [] return scored, [] # --------------------- evaluation loop -------------------------------- # async def evaluate(self, *_, **__): """ Simple eval: run one rollout per test item, compute binary correctness based on the boxed answer. Adds a metric 'eval/percent_correct'. """ if not hasattr(self, "test"): return total, correct = 0, 0 for sample in self.test: # Build prompt exactly like get_next_item but without mutating self.iter prompt_text = sample["problem"] expr = sample["expected_answer"].strip() messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt_text}, ] prompt = self.tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=False ) comp = await self.server.completion( prompt=prompt, n=1, max_tokens=1024, temperature=0.0, split="eval", ) model_reply = comp.choices[0].text boxed = self._boxed_after_think(model_reply) if boxed and boxed == expr: correct += 1 total += 1 accuracy = correct / max(total, 1) self.eval_metrics.append(("eval/percent_correct", accuracy)) # --------------------- dataset iterator ------------------------------- # async def get_next_item(self): idx = self.rng.randint(0, len(self.train) - 1) sample = self.train[idx] prompt_text = sample["problem"] expr = sample["expected_answer"].strip() answer_call = { "name": "python_interpreter", "arguments": {"code": f"print({expr})"}, } # ---------------- few‑shot demonstration ---------------- # fewshot_user = { "role": "user", "content": "Compute the integral of x^2 from 0 to 1.", } fewshot_assistant = { "role": "assistant", "content": ( "\n" "Let's evaluate the definite integral ∫₀¹ x² dx. This is a basic power rule integral.\n" "We know:\n" "∫ xⁿ dx from a to b = [xⁿ⁺¹ / (n+1)] from a to b.\n" "So for x²:\n" "= [x³ / 3] from 0 to 1\n" "= (1³ / 3) - (0³ / 3) = 1/3 - 0 = 1/3\n" "That checks out, but let's confirm with SymPy just to be sure.\n" '{"name":"python_interpreter", ' '"arguments":{"code":' '"import sympy as sp\\n' "x=sp.symbols('x')\\n" 'print(sp.integrate(x**2,(x,0,1)))"}}\n' "\n" '{"result": 1/3}\n' "The interpreter returns 1/3, so the value is 0.333̅.\n" "\n\n" "The integral equals \\boxed{\\tfrac{1}{3}} \\approx 0.333." ), } # --- second tiny example: simple arithmetic with calculator ---- # fewshot_user2 = {"role": "user", "content": "What is (2 + 3) * 4 ?"} fewshot_assistant2 = { "role": "assistant", "content": ( "\n" "I need (2+3)*4. Quick mental math gives 5*4 = 20, " "but I'll confirm with the calculator tool.\n" '{"name":"calculator", ' '"arguments":{"expr":"(2+3)*4"}}\n' '{"value": 20}\n' "The tool also says 20, matching my head‑math.\n" "\n\n" "Therefore the answer is \\boxed{20}." ), } # --------------- build final prompt messages ------------ # system_msg = {"role": "system", "content": SYSTEM_PROMPT} real_user = { "role": "user", "content": ( f"{prompt_text} \n" "This is a math problem, you must use the python_interpreter or calculator tool call to solve it." # "Before you call the tools, try to solve it step-by-step and then use the tool to verify" ), } # Optionally insert one real demo from dynamic_pool dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else [] messages = ( [ system_msg, fewshot_user, fewshot_assistant, fewshot_user2, fewshot_assistant2, ] + dyn + [real_user] ) # Freeze for hashing frozen = tuple(frozenset(m.items()) for m in messages) return (frozen, answer_call) # --------------------- wandb logging ---------------------------------- # async def create_rollout_table(self, metrics): if self.rollouts_for_wandb: table = wandb.Table(columns=["text", "score"]) for grp in self.rollouts_for_wandb: for txt, sc in grp: table.add_data(txt, sc) metrics["train/rollouts"] = table self.rollouts_for_wandb = [] return metrics async def wandb_log(self, metrics: Dict = None): metrics = metrics or {} if self.percent_correct_buffer: metrics["train/percent_correct"] = sum(self.percent_correct_buffer) / len( self.percent_correct_buffer ) self.percent_correct_buffer = [] for k, v in self.eval_metrics: metrics[k] = v self.eval_metrics = [] await super().wandb_log(metrics) if __name__ == "__main__": InterleavedInlineEnv.cli()