diff --git a/environments/tool_use_interleaved_thinking.py b/environments/tool_use_interleaved_thinking.py index deacc928..a0e1790f 100644 --- a/environments/tool_use_interleaved_thinking.py +++ b/environments/tool_use_interleaved_thinking.py @@ -13,10 +13,11 @@ class is copied here so nothing breaks when you swap env names. from __future__ import annotations -import json, re -from typing import Dict, List, Optional, Tuple -import os import itertools +import json +import os +import re +from typing import Dict, List, Optional, Tuple # Set to True to always print debug information. @@ -27,18 +28,17 @@ MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤1024 tok MAX_GEN_PER_TURN = 1024 # never request more than 512 new tokens from the model import wandb -from datasets import load_dataset, Dataset +from datasets import Dataset, load_dataset from atroposlib.envs.base import ( + APIServerConfig, BaseEnv, BaseEnvConfig, - APIServerConfig, EvalHandlingEnum, ScoredDataGroup, ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer - # -------------------------------------------------------------------------- # # Constants # -------------------------------------------------------------------------- # @@ -103,10 +103,9 @@ For reasoning tools, return interleaved tool calls within tags. 
""" -SYSTEM_PROMPT = ( - system_prompt - + TOOL_SYSTEM_PROMPT -) +SYSTEM_PROMPT = system_prompt + TOOL_SYSTEM_PROMPT + + # -------------------------------------------------------------------------- # # Environment # -------------------------------------------------------------------------- # @@ -133,6 +132,7 @@ class InterleavedInlineEnv(BaseEnv): self.rollouts_for_wandb = [] self.iter = 0 import random + self.rng = random.Random() # Dynamic few‑shot pool: list of (user_msg, assistant_msg) tuples self.dynamic_pool: List[Tuple[Dict, Dict]] = [] @@ -214,7 +214,9 @@ class InterleavedInlineEnv(BaseEnv): self.train, self.test = split["train"], split["test"] self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big")) # --------------------- helper methods --------------------------------- # - async def _completion_until(self, prompt: str, max_tokens: int, stop: Optional[str] = None) -> str: + async def _completion_until( + self, prompt: str, max_tokens: int, stop: Optional[str] = None + ) -> str: comp = await self.server.completion( prompt=prompt, stop=stop, @@ -268,16 +270,23 @@ class InterleavedInlineEnv(BaseEnv): args = call_json["arguments"] if name == "python_interpreter": - import httpx, asyncio + import asyncio + + import httpx + async with httpx.AsyncClient(timeout=10.0) as client: payload = {"code": args["code"], "input": ""} resp = await client.post("http://localhost:5002/execute", json=payload) data = resp.json() if DEBUG: print(f"[DEBUG _exec_tool] {name} result → {data}") - return {"stdout": data.get("output", ""), "result": data.get("output", "").strip()} + return { + "stdout": data.get("output", ""), + "result": data.get("output", "").strip(), + } elif name == "calculator": import math + expr = args["expr"] val = eval(expr, {"__builtins__": {}}, {"math": math}) if DEBUG: @@ -287,9 +296,7 @@ class InterleavedInlineEnv(BaseEnv): raise ValueError(f"Unknown tool name {name}") # --------------------- rollout logic (interleaved) 
------------------- # - async def _run_one_episode( - self, ctx: List[Dict] - ) -> Tuple[List[Dict], List[Dict]]: + async def _run_one_episode(self, ctx: List[Dict]) -> Tuple[List[Dict], List[Dict]]: """ Generate–execute–resume loop: @@ -346,9 +353,9 @@ class InterleavedInlineEnv(BaseEnv): executed.append(call_json) # Append tool_response inline - assistant_msg["content"] += ( - f"\n{json.dumps(result)}\n" - ) + assistant_msg[ + "content" + ] += f"\n{json.dumps(result)}\n" # continue loop (model will keep thinking) continue @@ -375,16 +382,16 @@ class InterleavedInlineEnv(BaseEnv): except Exception: return False - async def collect_trajectories( - self, item - ) -> Tuple[ScoredDataGroup, List]: + async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]: """ One prompt → `n = group_size` sampled assistant completions in parallel (single OpenAI request with n completions). Mirrors the logic in SingleToolCallingEnv. """ messages_tuple, expected_raw = item - expected = json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw + expected = ( + json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw + ) # Re‑inflate frozensets to normal dicts prompt_msgs = [dict(r) for r in messages_tuple] @@ -486,7 +493,7 @@ class InterleavedInlineEnv(BaseEnv): total, correct = 0, 0 for sample in self.test: # Build prompt exactly like get_next_item but without mutating self.iter - prompt_text = sample["problem"] + prompt_text = sample["problem"] expr = sample["expected_answer"].strip() messages = [ {"role": "system", "content": SYSTEM_PROMPT}, @@ -519,16 +526,16 @@ class InterleavedInlineEnv(BaseEnv): sample = self.train[idx] prompt_text = sample["problem"] - expr = sample["expected_answer"].strip() + expr = sample["expected_answer"].strip() answer_call = { "name": "python_interpreter", - "arguments": {"code": f"print({expr})"} + "arguments": {"code": f"print({expr})"}, } # ---------------- few‑shot demonstration ---------------- # 
fewshot_user = { "role": "user", - "content": "Compute the integral of x^2 from 0 to 1." + "content": "Compute the integral of x^2 from 0 to 1.", } fewshot_assistant = { "role": "assistant", "content": ( @@ -550,10 +557,7 @@ class InterleavedInlineEnv(BaseEnv): } # --- second tiny example: simple arithmetic with calculator ---- # - fewshot_user2 = { - "role": "user", - "content": "What is (2 + 3) * 4 ?" - } + fewshot_user2 = {"role": "user", "content": "What is (2 + 3) * 4 ?"} fewshot_assistant2 = { "role": "assistant", "content": ( @@ -571,16 +575,16 @@ class InterleavedInlineEnv(BaseEnv): # --------------- build final prompt messages ------------ # system_msg = {"role": "system", "content": SYSTEM_PROMPT} - real_user = { + real_user = { "role": "user", "content": ( f"{prompt_text} \nThis is a math problem, you must use the python_interpreter or calculator tool call to solve it." - ) + ), } # Optionally insert one real demo from dynamic_pool dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else [] messages = ( [system_msg, fewshot_user, fewshot_assistant, @@ -620,4 +624,4 @@ class InterleavedInlineEnv(BaseEnv): # -------------------------------------------------------------------------- # if __name__ == "__main__": - InterleavedInlineEnv.cli() \ No newline at end of file + InterleavedInlineEnv.cli()