diff --git a/environments/tool_use_interleaved_thinking.py b/environments/tool_use_interleaved_thinking.py
index 09657ae3..f4e61d8b 100644
--- a/environments/tool_use_interleaved_thinking.py
+++ b/environments/tool_use_interleaved_thinking.py
@@ -16,17 +16,27 @@ class is copied here so nothing breaks when you swap env names.
from __future__ import annotations
import asyncio
-import itertools
import json
import logging
import os
import re
from typing import Dict, List, Optional, Tuple, Union
import aiohttp
import httpx
+from datasets import Dataset, load_dataset
+import wandb
+from atroposlib.envs.base import (
+ APIServerConfig,
+ BaseEnv,
+ BaseEnvConfig,
+ EvalHandlingEnum,
+ ScoredDataGroup,
+)
+from atroposlib.type_definitions import Message
from atroposlib.utils.io import parse_http_response
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
logger = logging.getLogger(__name__)
@@ -42,19 +53,6 @@ MAX_GEN_PER_TURN = 512 # never request more than 512 new tokens from the model
MAX_ROLLOUT_TURNS = 3
-import wandb
-from datasets import Dataset, load_dataset
-
-from atroposlib.envs.base import (
- APIServerConfig,
- BaseEnv,
- BaseEnvConfig,
- EvalHandlingEnum,
- ScoredDataGroup,
-)
-from atroposlib.type_definitions import Message
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
-
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
@@ -63,64 +61,66 @@ system_prompt = (
)
-TOOL_SYSTEM_PROMPT = """
-You are a function calling & reasoning AI model. You are provided with function signatures within XML tags for internal reasoning tools. After calling & executing the functions, you will be provided with function results within XML tags. Here are the available tools:
-
-
-
-[
- {
- "type": "function",
- "function": {
- "name": "calculator",
- "description": "Evaluate a numeric Python expression and return the result.",
- "parameters": {
- "type": "object",
- "properties": {
- "expr": {
- "type": "string",
- "description": "A pure‑Python arithmetic expression, e.g. '3*(4+5)'"
- }
- },
- "required": ["expr"]
- }
- }
- },
- {
- "type": "function",
- "function": {
- "name": "python_interpreter",
- "description": "Run a short Python snippet and return stdout plus the last expression.",
- "parameters": {
- "type": "object",
- "properties": {
- "code": {
- "type": "string",
- "description": "Python source code to execute."
- }
- },
- "required": ["code"]
- }
- }
- }
-]
-
-
-
-You must use reasoning tools such as python_interpreter as a tool call when available for hard problems such as math before providing your final answer.
-Always provide your final numeric answer (or final result) in \\boxed{...} so it can be automatically graded right after closing tag.
-
-
-For reasoning tools, return interleaved tool calls within tags.
-
-\n{'name': , 'arguments': }\n
-
-\n{'result': }\n
-
-
-\n"
+ "{'result': }\n"
+ "\n"
+ "\n"
+ "\n"
+)
SYSTEM_PROMPT = system_prompt + TOOL_SYSTEM_PROMPT
@@ -164,7 +164,8 @@ class InterleavedInlineEnv(BaseEnv):
# Log what the server tried to set max_token_len to
if data["max_token_len"] != -1:
logger.info(
- f"Server tried to set max_token_len to {data['max_token_len']}, keeping our value of {self.max_token_len}"
+ f"Server tried to set max_token_len to {data['max_token_len']}\n"
+ f"keeping our value of {self.max_token_len}"
)
if self.config.batch_size == -1:
logging.warning("Batch size not set by config or server!")
@@ -211,12 +212,13 @@ class InterleavedInlineEnv(BaseEnv):
calculator / python_interpreter tools can verify them automatically.
+
The env‑var SUBSET_ROWS (default 1000) controls how many rows we keep.
"""
N = int(os.getenv("SUBSET_ROWS", "1000"))
stream_ds = load_dataset( # ≈50 k rows total → stream
# "NVIDIA/OpenMathReasoning",
# split="cot",
# "open-r1/OpenR1-Math-220k",
@@ -405,7 +409,8 @@ class InterleavedInlineEnv(BaseEnv):
) -> List[str]:
"""Handle identical prompts efficiently using n parameter."""
print(
- f" \033[93m→ TURN {turn_idx+1} prompt full:\033[0m \033[92m{prompt}\033[0m"
+ f" \033[93m→ TURN {turn_idx+1} prompt full:\033[0m "
+ f"\033[92m{prompt}\033[0m"
)
# Use the constant instead of config attribute
@@ -603,9 +608,7 @@ class InterleavedInlineEnv(BaseEnv):
print(
"[DEBUG] tool_call found outside ; setting reward = -1"
)
- print(
- "[DEBUG] tool_call found outside ; setting reward = -1"
- )
+
reward = -1.0
if DEBUG:
@@ -764,7 +767,8 @@ class InterleavedInlineEnv(BaseEnv):
continue
print(
- f"🔧 [ROLLOUT {rollout_idx}] Executing {call_json['name']} with args: {call_json['arguments']}"
+ f"🔧 [ROLLOUT {rollout_idx}] Executing {call_json['name']}\n"
+ f"with args: {call_json['arguments']}"
)
try:
result = await self._exec_tool(call_json)
@@ -939,14 +943,15 @@ class InterleavedInlineEnv(BaseEnv):
if len(self.dynamic_pool) > self.dynamic_pool_max:
self.dynamic_pool.pop(0)
- except Exception as e:
+ except Exception:
scored["tokens"].append([])
scored["masks"].append([])
scored["scores"].append(-1.0)
self.percent_correct_buffer.append(0.0)
print(
- f"\n🏁 [EXECUTION MODE] Completed all rollouts. Average reward: {sum(scored['scores'])/len(scored['scores']):.3f}"
+ "\n🏁 [EXECUTION MODE] Completed all rollouts. Average reward: \n"
+ f"{sum(scored['scores'])/len(scored['scores']):.3f}"
)
# -- Per-rollout score summary --
@@ -962,7 +967,7 @@ class InterleavedInlineEnv(BaseEnv):
f"⚠️ [WARNING] All {len(scored['scores'])} rollouts failed with negative rewards!"
)
print(
- f" This may indicate a problem with the model, prompt, or token budget."
+ " This may indicate a problem with the model, prompt, or token budget."
)
# Signal failure to the outer loop
return None, []
@@ -1070,7 +1075,8 @@ class InterleavedInlineEnv(BaseEnv):
real_user = {
"role": "user",
"content": (
- f"{prompt_text} \nThis is a math problem, you must use the python_interpreter or calculator tool call to solve it."
+ f"{prompt_text} \n"
+ "This is a math problem, you must use the python_interpreter or calculator tool call to solve it."
# "Before you call the tools, try to solve it step-by-step and then use the tool to verify"
),
}