diff --git a/environments/tool_use_interleaved_thinking.py b/environments/tool_use_interleaved_thinking.py index 09657ae3..f4e61d8b 100644 --- a/environments/tool_use_interleaved_thinking.py +++ b/environments/tool_use_interleaved_thinking.py @@ -16,17 +16,28 @@ class is copied here so nothing breaks when you swap env names. from __future__ import annotations import asyncio -import itertools import json import logging +import logging import os import re from typing import Dict, List, Optional, Tuple, Union import aiohttp import httpx +from datasets import Dataset, load_dataset +import wandb +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + EvalHandlingEnum, + ScoredDataGroup, +) +from atroposlib.type_definitions import Message from atroposlib.utils.io import parse_http_response +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer logger = logging.getLogger(__name__) @@ -42,19 +53,6 @@ MAX_GEN_PER_TURN = 512 # never request more than 512 new tokens from the model MAX_ROLLOUT_TURNS = 3 -import wandb -from datasets import Dataset, load_dataset - -from atroposlib.envs.base import ( - APIServerConfig, - BaseEnv, - BaseEnvConfig, - EvalHandlingEnum, - ScoredDataGroup, -) -from atroposlib.type_definitions import Message -from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer - system_prompt = ( "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " "problem and deliberate with yourself via systematic reasoning processes to help come to a correct " @@ -63,64 +61,66 @@ system_prompt = ( ) -TOOL_SYSTEM_PROMPT = """ -You are a function calling & reasoning AI model. You are provided with function signatures within XML tags for internal reasoning tools. After calling & executing the functions, you will be provided with function results within XML tags. Here are the available tools: - - - -[ - { - "type": "function", - "function": { - "name": "calculator", - "description": "Evaluate a numeric Python expression and return the result.", - "parameters": { - "type": "object", - "properties": { - "expr": { - "type": "string", - "description": "A pure‑Python arithmetic expression, e.g. '3*(4+5)'" - } - }, - "required": ["expr"] - } - } - }, - { - "type": "function", - "function": { - "name": "python_interpreter", - "description": "Run a short Python snippet and return stdout plus the last expression.", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "Python source code to execute." - } - }, - "required": ["code"] - } - } - } -] - - - -You must use reasoning tools such as python_interpreter as a tool call when available for hard problems such as math before providing your final answer. -Always provide your final numeric answer (or final result) in \\boxed{...} so it can be automatically graded right after closing tag. - - -For reasoning tools, return interleaved tool calls within tags. - -\n{'name': , 'arguments': }\n - -\n{'result': }\n - - -\n" + "{'result': }\n" + "\n" + "\n" + "\n" +) SYSTEM_PROMPT = system_prompt + TOOL_SYSTEM_PROMPT @@ -164,7 +164,8 @@ class InterleavedInlineEnv(BaseEnv): # Log what the server tried to set max_token_len to if data["max_token_len"] != -1: logger.info( - f"Server tried to set max_token_len to {data['max_token_len']}, keeping our value of {self.max_token_len}" + f"Server tried to set max_token_len to {data['max_token_len']}\n" + f"keeping our value of {self.max_token_len}" ) if self.config.batch_size == -1: logging.warning("Batch size not set by config or server!") @@ -211,12 +212,15 @@ class InterleavedInlineEnv(BaseEnv): calculator / python_interpreter tools can verify them automatically. + The env‑var SUBSET_ROWS (default 1000) controls how many rows we keep. """ N = int(os.getenv("SUBSET_ROWS", "1000")) stream_ds = load_dataset( # ≈50 k rows total → stream + # "NVIDIA/OpenMathReasoning", + # split="cot", # "NVIDIA/OpenMathReasoning", # split="cot", # "open-r1/OpenR1-Math-220k", @@ -405,7 +409,8 @@ class InterleavedInlineEnv(BaseEnv): ) -> List[str]: """Handle identical prompts efficiently using n parameter.""" print( - f" \033[93m→ TURN {turn_idx+1} prompt full:\033[0m \033[92m{prompt}\033[0m" + f" \033[93m→ TURN {turn_idx+1} prompt full:\033[0m " + f"\033[92m{prompt}\033[0m" ) # Use the constant instead of config attribute @@ -603,9 +608,7 @@ class InterleavedInlineEnv(BaseEnv): print( "[DEBUG] tool_call found outside ; setting reward = -1" ) - print( - "[DEBUG] tool_call found outside ; setting reward = -1" - ) + reward = -1.0 if DEBUG: @@ -764,7 +767,8 @@ class InterleavedInlineEnv(BaseEnv): continue print( - f"🔧 [ROLLOUT {rollout_idx}] Executing {call_json['name']} with args: {call_json['arguments']}" + f"🔧 [ROLLOUT {rollout_idx}] Executing {call_json['name']}\n" + f"with args: {call_json['arguments']}" ) try: result = await self._exec_tool(call_json) @@ -939,14 +943,15 @@ class InterleavedInlineEnv(BaseEnv): if len(self.dynamic_pool) > self.dynamic_pool_max: self.dynamic_pool.pop(0) - except Exception as e: + except Exception: scored["tokens"].append([]) scored["masks"].append([]) scored["scores"].append(-1.0) self.percent_correct_buffer.append(0.0) print( - f"\n🏁 [EXECUTION MODE] Completed all rollouts. Average reward: {sum(scored['scores'])/len(scored['scores']):.3f}" + "\n🏁 [EXECUTION MODE] Completed all rollouts. Average reward: \n" + f"{sum(scored['scores'])/len(scored['scores']):.3f}" ) # -- Per-rollout score summary -- @@ -962,7 +967,7 @@ class InterleavedInlineEnv(BaseEnv): f"⚠️ [WARNING] All {len(scored['scores'])} rollouts failed with negative rewards!" ) print( - f" This may indicate a problem with the model, prompt, or token budget." + " This may indicate a problem with the model, prompt, or token budget." ) # Signal failure to the outer loop return None, [] @@ -1070,7 +1075,8 @@ class InterleavedInlineEnv(BaseEnv): real_user = { "role": "user", "content": ( - f"{prompt_text} \nThis is a math problem, you must use the python_interpreter or calculator tool call to solve it." + f"{prompt_text} \n" + "This is a math problem, you must use the python_interpreter or calculator tool call to solve it." # "Before you call the tools, try to solve it step-by-step and then use the tool to verify" ), }