diff --git a/environments/tool_use_interleaved_thinking.py b/environments/tool_use_interleaved_thinking.py
index deacc928..a0e1790f 100644
--- a/environments/tool_use_interleaved_thinking.py
+++ b/environments/tool_use_interleaved_thinking.py
@@ -13,10 +13,11 @@ class is copied here so nothing breaks when you swap env names.
from __future__ import annotations
-import json, re
-from typing import Dict, List, Optional, Tuple
-import os
import itertools
+import json
+import os
+import re
+from typing import Dict, List, Optional, Tuple
# Set to True to always print debug information.
@@ -27,18 +28,17 @@ MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤1024 tok
MAX_GEN_PER_TURN = 1024 # never request more than 512 new tokens from the model
import wandb
-from datasets import load_dataset, Dataset
+from datasets import Dataset, load_dataset
from atroposlib.envs.base import (
+ APIServerConfig,
BaseEnv,
BaseEnvConfig,
- APIServerConfig,
EvalHandlingEnum,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
-
# -------------------------------------------------------------------------- #
# Constants
# -------------------------------------------------------------------------- #
@@ -103,10 +103,9 @@ For reasoning tools, return interleaved tool calls within tags.
"""
-SYSTEM_PROMPT = (
- system_prompt
- + TOOL_SYSTEM_PROMPT
-)
+SYSTEM_PROMPT = system_prompt + TOOL_SYSTEM_PROMPT
+
+
# -------------------------------------------------------------------------- #
# Environment
# -------------------------------------------------------------------------- #
@@ -133,6 +132,7 @@ class InterleavedInlineEnv(BaseEnv):
self.rollouts_for_wandb = []
self.iter = 0
import random
+
self.rng = random.Random()
# Dynamic few‑shot pool: list of (user_msg, assistant_msg) tuples
self.dynamic_pool: List[Tuple[Dict, Dict]] = []
@@ -214,7 +214,9 @@ class InterleavedInlineEnv(BaseEnv):
self.train, self.test = split["train"], split["test"]
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
# --------------------- helper methods --------------------------------- #
- async def _completion_until(self, prompt: str, max_tokens: int, stop: Optional[str] = None) -> str:
+ async def _completion_until(
+ self, prompt: str, max_tokens: int, stop: Optional[str] = None
+ ) -> str:
comp = await self.server.completion(
prompt=prompt,
stop=stop,
@@ -268,16 +270,23 @@ class InterleavedInlineEnv(BaseEnv):
args = call_json["arguments"]
if name == "python_interpreter":
- import httpx, asyncio
+ import asyncio
+
+ import httpx
+
async with httpx.AsyncClient(timeout=10.0) as client:
payload = {"code": args["code"], "input": ""}
resp = await client.post("http://localhost:5002/execute", json=payload)
data = resp.json()
if DEBUG:
print(f"[DEBUG _exec_tool] {name} result → {data}")
- return {"stdout": data.get("output", ""), "result": data.get("output", "").strip()}
+ return {
+ "stdout": data.get("output", ""),
+ "result": data.get("output", "").strip(),
+ }
elif name == "calculator":
import math
+
expr = args["expr"]
val = eval(expr, {"__builtins__": {}}, {"math": math})
if DEBUG:
@@ -287,9 +296,7 @@ class InterleavedInlineEnv(BaseEnv):
raise ValueError(f"Unknown tool name {name}")
# --------------------- rollout logic (interleaved) ------------------- #
- async def _run_one_episode(
- self, ctx: List[Dict]
- ) -> Tuple[List[Dict], List[Dict]]:
+ async def _run_one_episode(self, ctx: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
"""
Generate–execute–resume loop:
@@ -346,9 +353,9 @@ class InterleavedInlineEnv(BaseEnv):
executed.append(call_json)
# Append tool_response inline
- assistant_msg["content"] += (
- f"\n{json.dumps(result)}\n"
- )
+ assistant_msg[
+ "content"
+ ] += f"\n{json.dumps(result)}\n"
# continue loop (model will keep thinking)
continue
@@ -375,16 +382,16 @@ class InterleavedInlineEnv(BaseEnv):
except Exception:
return False
- async def collect_trajectories(
- self, item
- ) -> Tuple[ScoredDataGroup, List]:
+ async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
"""
One prompt → `n = group_size` sampled assistant completions in
parallel (single OpenAI request with n completions). Mirrors the
logic in SingleToolCallingEnv.
"""
messages_tuple, expected_raw = item
- expected = json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw
+ expected = (
+ json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw
+ )
# Re‑inflate frozensets to normal dicts
prompt_msgs = [dict(r) for r in messages_tuple]
@@ -486,7 +493,7 @@ class InterleavedInlineEnv(BaseEnv):
total, correct = 0, 0
for sample in self.test:
# Build prompt exactly like get_next_item but without mutating self.iter
- prompt_text = sample["problem"]
+ prompt_text = sample["problem"]
expr = sample["expected_answer"].strip()
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
@@ -519,16 +526,16 @@ class InterleavedInlineEnv(BaseEnv):
sample = self.train[idx]
prompt_text = sample["problem"]
- expr = sample["expected_answer"].strip()
+ expr = sample["expected_answer"].strip()
answer_call = {
"name": "python_interpreter",
- "arguments": {"code": f"print({expr})"}
+ "arguments": {"code": f"print({expr})"},
}
# ---------------- few‑shot demonstration ---------------- #
fewshot_user = {
"role": "user",
- "content": "Compute the integral of x^2 from 0 to 1."
+ "content": "Compute the integral of x^2 from 0 to 1.",
}
fewshot_assistant = {
"role": "assistant",
@@ -550,10 +557,7 @@ class InterleavedInlineEnv(BaseEnv):
}
# --- second tiny example: simple arithmetic with calculator ---- #
- fewshot_user2 = {
- "role": "user",
- "content": "What is (2 + 3) * 4 ?"
- }
+ fewshot_user2 = {"role": "user", "content": "What is (2 + 3) * 4 ?"}
fewshot_assistant2 = {
"role": "assistant",
"content": (
@@ -571,16 +575,17 @@
# --------------- build final prompt messages ------------ #
system_msg = {"role": "system", "content": SYSTEM_PROMPT}
-    real_user = {
+    real_user = {
"role": "user",
"content": (
f"{prompt_text} \nThis is a math problem, you must use the python_interpreter or calculator tool call to solve it."
-    )
+    ),
}
# Optionally insert one real demo from dynamic_pool
dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else []
+
messages = (
[system_msg,
fewshot_user, fewshot_assistant,
@@ -620,4 +627,4 @@ class InterleavedInlineEnv(BaseEnv):
# -------------------------------------------------------------------------- #
if __name__ == "__main__":
- InterleavedInlineEnv.cli()
\ No newline at end of file
+ InterleavedInlineEnv.cli()