fixing pre-commit errors

interstellarninja 2025-07-08 00:25:46 -04:00
parent ab06a1ed52
commit 90c1b703e6


@@ -16,17 +16,28 @@ class is copied here so nothing breaks when you swap env names.
from __future__ import annotations
import asyncio
import itertools
import json
import logging
import os
import re
from typing import Dict, List, Optional, Tuple, Union
import aiohttp
import httpx
from datasets import Dataset, load_dataset
import wandb
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
ScoredDataGroup,
)
from atroposlib.type_definitions import Message
from atroposlib.utils.io import parse_http_response
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
logger = logging.getLogger(__name__)
@@ -42,19 +53,6 @@ MAX_GEN_PER_TURN = 512 # never request more than 512 new tokens from the model
MAX_ROLLOUT_TURNS = 3
import wandb
from datasets import Dataset, load_dataset
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
ScoredDataGroup,
)
from atroposlib.type_definitions import Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
@@ -63,64 +61,66 @@ system_prompt = (
)
TOOL_SYSTEM_PROMPT = """
You are a function calling & reasoning AI model. You are provided with function signatures within <reasoning_tools> </reasoning_tools> XML tags for internal reasoning tools. After calling & executing the functions, you will be provided with function results within <tool_response> </tool_response> XML tags. Here are the available tools:
<reasoning_tools>
[
{
"type": "function",
"function": {
"name": "calculator",
"description": "Evaluate a numeric Python expression and return the result.",
"parameters": {
"type": "object",
"properties": {
"expr": {
"type": "string",
"description": "A purePython arithmetic expression, e.g. '3*(4+5)'"
}
},
"required": ["expr"]
}
}
},
{
"type": "function",
"function": {
"name": "python_interpreter",
"description": "Run a short Python snippet and return stdout plus the last expression.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Python source code to execute."
}
},
"required": ["code"]
}
}
}
]
</reasoning_tools>
You must use reasoning tools such as python_interpreter as a tool call when available for hard problems such as math before providing your final answer.
Always provide your final numeric answer (or final result) in \\boxed{...} so it can be automatically graded right after closing </think> tag.
For reasoning tools, return interleaved tool calls within <think> </think> tags.
<think>
<tool_call>\n{'name': <function-name>, 'arguments': <args-dict>}\n</tool_call>
<!-- system pauses runtime for execution -->
<tool_response>\n{'result': <result>}\n</tool_response>
<!-- assistant resumes within same think -->
</think>
<!-- plain text answer with \\boxed{...}
"""
TOOL_SYSTEM_PROMPT = (
"You are a function-calling & reasoning AI model. You are provided with "
"function signatures inside <reasoning_tools> … XML tags. After calling & "
"executing the functions, you will get results inside <tool_response> … "
"Here are the available tools:\n\n"
"<reasoning_tools>\n"
"[\n"
" {\n"
' "type": "function",\n'
' "function": {\n'
' "name": "calculator",\n'
' "description": '
' "Evaluate a numeric Python expression and return the result.",\n'
' "parameters": {\n'
' "type": "object",\n'
' "properties": {\n'
' "expr": {\n'
' "type": "string",\n'
' "description": '
' "A pure-Python arithmetic expression\'"\n'
" }\n"
" },\n"
' "required": ["expr"]\n'
" }\n"
" }\n"
" },\n"
" {\n"
' "type": "function",\n'
' "function": {\n'
' "name": "python_interpreter",\n'
' "description": '
' "Run a short Python snippet and return stdout plus the last '
'expression.",\n'
' "parameters": {\n'
' "type": "object",\n'
' "properties": {\n'
' "code": {\n'
' "type": "string",\n'
' "description": "Python source code to execute."\n'
" }\n"
" },\n"
' "required": ["code"]\n'
" }\n"
" }\n"
" }\n"
"]\n"
"</reasoning_tools>\n\n"
"You must use reasoning tools such as python_interpreter as a tool call when available "
"for hard problems such as math before providing your final answer.\n"
"Always provide your final numeric answer (or final result) in \\\\boxed{...} so it "
"can be automatically graded right after closing </think> tag.\n\n"
"For reasoning tools, return interleaved tool calls within <think> </think> tags.\n"
"<think>\n"
"<tool_call>{'name': <function-name>, 'arguments': <args-dict>}</tool_call>\n"
"<!-- system pauses runtime for execution -->\n"
"<tool_response>{'result': <result>}</tool_response>\n"
"<!-- assistant resumes within same think -->\n"
"</think>\n"
"<!-- plain text answer with \\\\boxed{...} -->\n"
)
SYSTEM_PROMPT = system_prompt + TOOL_SYSTEM_PROMPT
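
The prompt above fixes a wire format for tool calls: a Python-dict-style payload inside <tool_call> tags, interleaved within <think>. As a rough sketch of how such payloads could be pulled out of a completion (illustrative only, not part of this commit; the payload shown in the prompt is not strict JSON, so ast.literal_eval is tried before json.loads, and parse_tool_calls is a hypothetical name):

import ast
import json
import re
from typing import Dict, List

TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)

def parse_tool_calls(text: str) -> List[Dict]:
    """Extract {'name': ..., 'arguments': ...} payloads from a completion."""
    calls: List[Dict] = []
    for raw in TOOL_CALL_RE.findall(text):
        try:
            calls.append(ast.literal_eval(raw))  # dict-style payload as shown in the prompt
        except (ValueError, SyntaxError):
            try:
                calls.append(json.loads(raw))  # fall back to strict JSON
            except json.JSONDecodeError:
                continue  # malformed call: skip rather than crash
    return calls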
@@ -164,7 +164,8 @@ class InterleavedInlineEnv(BaseEnv):
# Log what the server tried to set max_token_len to
if data["max_token_len"] != -1:
logger.info(
f"Server tried to set max_token_len to {data['max_token_len']}, keeping our value of {self.max_token_len}"
f"Server tried to set max_token_len to {data['max_token_len']}\n"
f"keeping our value of {self.max_token_len}"
)
if self.config.batch_size == -1:
logging.warning("Batch size not set by config or server!")
@@ -211,12 +212,15 @@ class InterleavedInlineEnv(BaseEnv):
calculator / python_interpreter tools can verify them automatically.
The envvar SUBSET_ROWS (default 1000) controls how many rows we keep.
"""
N = int(os.getenv("SUBSET_ROWS", "1000"))
stream_ds = load_dataset( # ≈50k rows total → stream
# "NVIDIA/OpenMathReasoning",
# split="cot",
# "NVIDIA/OpenMathReasoning",
# split="cot",
# "open-r1/OpenR1-Math-220k",
@@ -405,7 +409,8 @@ class InterleavedInlineEnv(BaseEnv):
) -> List[str]:
"""Handle identical prompts efficiently using n parameter."""
print(
f" \033[93m→ TURN {turn_idx+1} prompt full:\033[0m \033[92m{prompt}\033[0m"
f" \033[93m→ TURN {turn_idx+1} prompt full:\033[0m "
f"\033[92m{prompt}\033[0m"
)
# Use the constant instead of config attribute
@@ -603,9 +608,7 @@ class InterleavedInlineEnv(BaseEnv):
print(
"[DEBUG] tool_call found outside </think>; setting reward = -1"
)
reward = -1.0
if DEBUG:
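
The penalty above fires when a tool call appears outside the think block. One way such a predicate might be written (a sketch under that assumption, not the code this hunk actually calls):

def tool_call_outside_think(completion: str) -> bool:
    # A <tool_call> after the closing </think> tag (or with no think
    # block at all) violates the interleaved format described in the prompt.
    think_end = completion.find("</think>")
    if think_end == -1:
        return "<tool_call>" in completion
    return "<tool_call>" in completion[think_end:]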
@@ -764,7 +767,8 @@ class InterleavedInlineEnv(BaseEnv):
continue
print(
f"🔧 [ROLLOUT {rollout_idx}] Executing {call_json['name']} with args: {call_json['arguments']}"
f"🔧 [ROLLOUT {rollout_idx}] Executing {call_json['name']}\n"
f"with args: {call_json['arguments']}"
)
try:
result = await self._exec_tool(call_json)
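
The body of _exec_tool is outside this diff. A hypothetical dispatcher consistent with the two tools declared in the system prompt could look like this (assumed implementation; real sandboxing would be required before running model-generated code):

import contextlib
import io

async def _exec_tool(self, call: dict) -> str:
    name, args = call["name"], call.get("arguments", {})
    if name == "calculator":
        # eval of model-produced text is unsafe outside a sandbox.
        return str(eval(args["expr"], {"__builtins__": {}}, {}))
    if name == "python_interpreter":
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):
            exec(args["code"], {})  # same sandboxing caveat applies
        return buf.getvalue()
    raise ValueError(f"Unknown tool: {name}")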
@@ -939,14 +943,15 @@ class InterleavedInlineEnv(BaseEnv):
if len(self.dynamic_pool) > self.dynamic_pool_max:
self.dynamic_pool.pop(0)
except Exception as e:
except Exception:
scored["tokens"].append([])
scored["masks"].append([])
scored["scores"].append(-1.0)
self.percent_correct_buffer.append(0.0)
print(
f"\n🏁 [EXECUTION MODE] Completed all rollouts. Average reward: {sum(scored['scores'])/len(scored['scores']):.3f}"
"\n🏁 [EXECUTION MODE] Completed all rollouts. Average reward: \n"
f"{sum(scored['scores'])/len(scored['scores']):.3f}"
)
# -- Per-rollout score summary --
@@ -962,7 +967,7 @@ class InterleavedInlineEnv(BaseEnv):
f"⚠️ [WARNING] All {len(scored['scores'])} rollouts failed with negative rewards!"
)
print(
f" This may indicate a problem with the model, prompt, or token budget."
" This may indicate a problem with the model, prompt, or token budget."
)
# Signal failure to the outer loop
return None, []
@@ -1070,7 +1075,8 @@ class InterleavedInlineEnv(BaseEnv):
real_user = {
"role": "user",
"content": (
f"{prompt_text} \nThis is a math problem, you must use the python_interpreter or calculator tool call to solve it."
f"{prompt_text} \n"
"This is a math problem, you must use the python_interpreter or calculator tool call to solve it."
# "Before you call the tools, try to solve it step-by-step and then use the tool to verify"
),
}