[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-06-24 12:26:55 +00:00
parent 569a8303f3
commit 5ee01a7911

View file

@ -13,27 +13,27 @@ class is copied here so nothing breaks when you swap env names.
from __future__ import annotations
import json, re
from typing import Dict, List, Optional, Tuple
import os
import itertools
import json
import os
import re
from typing import Dict, List, Optional, Tuple
# Set to True to always print debug information.
DEBUG = True # or toggle via env var if you prefer: bool(os.getenv("DEBUG_INTERLEAVED", "1"))
import wandb
from datasets import load_dataset, Dataset
from datasets import Dataset, load_dataset
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
APIServerConfig,
EvalHandlingEnum,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# -------------------------------------------------------------------------- #
# Constants
# -------------------------------------------------------------------------- #
@ -43,11 +43,11 @@ system_prompt = (
"solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
"</think> tags, and then provide your solution or response to the problem."
)
#TOOL_SYSTEM_PROMPT = """
#You are a function calling & reasoning AI model. You are provided with function signatures within <tools> </tools> XML tags for user facing tools and <reasoning_tools> </reasoning_tools> XML tags for internal reasoning tools. You may call one or more functions to assist with the user query. If available tools are not relevant in assisting with user query, just respond in natural conversational language. Don't make assumptions about what values to plug into functions. After calling & executing the functions, you will be provided with function results within <tool_response> </tool_response> XML tags. Here are the available tools:
# TOOL_SYSTEM_PROMPT = """
# You are a function calling & reasoning AI model. You are provided with function signatures within <tools> </tools> XML tags for user facing tools and <reasoning_tools> </reasoning_tools> XML tags for internal reasoning tools. You may call one or more functions to assist with the user query. If available tools are not relevant in assisting with user query, just respond in natural conversational language. Don't make assumptions about what values to plug into functions. After calling & executing the functions, you will be provided with function results within <tool_response> </tool_response> XML tags. Here are the available tools:
#
#<reasoning_tools>
#[
# <reasoning_tools>
# [
# {
# "type": "function",
# "function": {
@ -82,24 +82,24 @@ system_prompt = (
# }
# }
# }
#]
#</reasoning_tools>
# ]
# </reasoning_tools>
#
#For each function call return a JSON object, with the following pydantic model json schema:\n{'title': 'FunctionCall', 'type': 'object', 'properties': {'arguments': {'title': 'Arguments', 'type': 'object'}, 'name': {'title': 'Name', 'type': 'string'}}, 'required': ['arguments', 'name']}
# For each function call return a JSON object, with the following pydantic model json schema:\n{'title': 'FunctionCall', 'type': 'object', 'properties': {'arguments': {'title': 'Arguments', 'type': 'object'}, 'name': {'title': 'Name', 'type': 'string'}}, 'required': ['arguments', 'name']}
#
#Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n<tool_call>\n{'name': <function-name>, 'arguments': <args-dict>}\n</tool_call>
# Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n<tool_call>\n{'name': <function-name>, 'arguments': <args-dict>}\n</tool_call>
#
#For reasoning tools, return interleaved tool calls within <think> </think> tags.
#<think>
#<tool_call>\n{'name': <function-name>, 'arguments': <args-dict>}\n</tool_call>
#<!-- system pauses runtime for execution -->
#<tool_response>\n{'result': <result>}\n</tool_response>
#<!-- assistant resumes within same think -->
#</think>
# For reasoning tools, return interleaved tool calls within <think> </think> tags.
# <think>
# <tool_call>\n{'name': <function-name>, 'arguments': <args-dict>}\n</tool_call>
# <!-- system pauses runtime for execution -->
# <tool_response>\n{'result': <result>}\n</tool_response>
# <!-- assistant resumes within same think -->
# </think>
#
#You must use reasoning tools such as python_interpreter when available for hard problems such as math before providing your final answer.
#Always wrap your final numeric answer (or final result) in \\boxed{...} so it can be automatically graded.
#"""
# You must use reasoning tools such as python_interpreter when available for hard problems such as math before providing your final answer.
# Always wrap your final numeric answer (or final result) in \\boxed{...} so it can be automatically graded.
# """
TOOL_SYSTEM_PROMPT = """
You are a function calling & reasoning AI model. You are provided with function signatures within <reasoning_tools> </reasoning_tools> XML tags for internal reasoning tools. After calling & executing the functions, you will be provided with function results within <tool_response> </tool_response> XML tags. Here are the available tools:
@ -155,10 +155,9 @@ For reasoning tools, return interleaved tool calls within <think> </think> tags.
</think>
"""
SYSTEM_PROMPT = (
system_prompt
+ TOOL_SYSTEM_PROMPT
)
SYSTEM_PROMPT = system_prompt + TOOL_SYSTEM_PROMPT
# -------------------------------------------------------------------------- #
# Environment
# -------------------------------------------------------------------------- #
@ -185,6 +184,7 @@ class InterleavedInlineEnv(BaseEnv):
self.rollouts_for_wandb = []
self.iter = 0
import random
self.rng = random.Random()
@classmethod
@ -230,9 +230,7 @@ class InterleavedInlineEnv(BaseEnv):
# streaming=True => HF downloads shard-by-shard and stops after N rows
stream_ds = load_dataset(
"NVIDIA/OpenMathReasoning",
split="cot",
streaming=True
"NVIDIA/OpenMathReasoning", split="cot", streaming=True
)
# take first N rows
@ -253,8 +251,11 @@ class InterleavedInlineEnv(BaseEnv):
self.train, self.test = split["train"], split["test"]
# Shuffle training samples so each run starts at a random order
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
# --------------------- helper methods --------------------------------- #
async def _completion_until(self, prompt: str, max_tokens: int, stop: Optional[str] = None) -> str:
async def _completion_until(
self, prompt: str, max_tokens: int, stop: Optional[str] = None
) -> str:
comp = await self.server.completion(
prompt=prompt,
stop=stop,
@ -302,16 +303,23 @@ class InterleavedInlineEnv(BaseEnv):
args = call_json["arguments"]
if name == "python_interpreter":
import httpx, asyncio
import asyncio
import httpx
async with httpx.AsyncClient(timeout=10.0) as client:
payload = {"code": args["code"], "input": ""}
resp = await client.post("http://localhost:5002/execute", json=payload)
data = resp.json()
if DEBUG:
print(f"[DEBUG _exec_tool] {name} result → {data}")
return {"stdout": data.get("output", ""), "result": data.get("output", "").strip()}
return {
"stdout": data.get("output", ""),
"result": data.get("output", "").strip(),
}
elif name == "calculator":
import math
expr = args["expr"]
val = eval(expr, {"__builtins__": {}}, {"math": math})
if DEBUG:
@ -321,9 +329,7 @@ class InterleavedInlineEnv(BaseEnv):
raise ValueError(f"Unknown tool name {name}")
# --------------------- rollout logic (interleaved) ------------------- #
async def _run_one_episode(
self, ctx: List[Dict]
) -> Tuple[List[Dict], List[Dict]]:
async def _run_one_episode(self, ctx: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
"""
Generate-execute-resume loop:
@ -380,9 +386,9 @@ class InterleavedInlineEnv(BaseEnv):
executed.append(call_json)
# Append tool_response inline
assistant_msg["content"] += (
f"\n<tool_response>{json.dumps(result)}</tool_response>\n"
)
assistant_msg[
"content"
] += f"\n<tool_response>{json.dumps(result)}</tool_response>\n"
# continue loop (model will keep thinking)
continue
@ -409,16 +415,16 @@ class InterleavedInlineEnv(BaseEnv):
except Exception:
return False
async def collect_trajectories(
self, item
) -> Tuple[ScoredDataGroup, List]:
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
"""
One prompt -> `n = group_size` sampled assistant completions in
parallel (single OpenAI request with n completions). Mirrors the
logic in SingleToolCallingEnv.
"""
messages_tuple, expected_raw = item
expected = json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw
expected = (
json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw
)
# Re-inflate frozensets to normal dicts
prompt_msgs = [dict(r) for r in messages_tuple]
@ -450,13 +456,17 @@ class InterleavedInlineEnv(BaseEnv):
full_ctx = prompt_msgs + [assistant_msg]
# Outcome-based reward: compare boxed answer to expected expr
expr = expected["arguments"]["code"][6:-1] if (
isinstance(expected, dict)
and "arguments" in expected
and "code" in expected["arguments"]
and expected["arguments"]["code"].startswith("print(")
and expected["arguments"]["code"].endswith(")")
) else None
expr = (
expected["arguments"]["code"][6:-1]
if (
isinstance(expected, dict)
and "arguments" in expected
and "code" in expected["arguments"]
and expected["arguments"]["code"].startswith("print(")
and expected["arguments"]["code"].endswith(")")
)
else None
)
boxed = self._boxed_after_think(choice.text)
reward = 1.0 if (boxed and boxed == expr) else -1.0
if "</think>" not in choice.text:
@ -488,7 +498,7 @@ class InterleavedInlineEnv(BaseEnv):
total, correct = 0, 0
for sample in self.test:
# Build prompt exactly like get_next_item but without mutating self.iter
prompt_text = sample["problem"]
prompt_text = sample["problem"]
expr = sample["expected_answer"].strip()
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
@ -521,67 +531,66 @@ class InterleavedInlineEnv(BaseEnv):
sample = self.train[idx]
prompt_text = sample["problem"]
expr = sample["expected_answer"].strip()
expr = sample["expected_answer"].strip()
answer_call = {
"name": "python_interpreter",
"arguments": {"code": f"print({expr})"}
"arguments": {"code": f"print({expr})"},
}
# ---------------- few-shot demonstration ---------------- #
fewshot_user = {
"role": "user",
"content": "Compute the integral of x^2 from 0 to 1."
"content": "Compute the integral of x^2 from 0 to 1.",
}
fewshot_assistant = {
"role": "assistant",
"content": (
"<think>\n"
"Goal: integrate x² from 0 to 1 with python_interpreter.\n"
"<tool_call>{\"name\":\"python_interpreter\", "
"\"arguments\":{\"code\":\"import sympy as sp\\n"
'<tool_call>{"name":"python_interpreter", '
'"arguments":{"code":"import sympy as sp\\n'
"x=sp.symbols('x'); print(sp.integrate(x**2,(x,0,1)))\"}}\n"
"</tool_call>\n"
"<tool_response>{\"result\": 1/3}</tool_response>\n"
'<tool_response>{"result": 1/3}</tool_response>\n'
"Observation: result 0.333333.\n"
"Reflection: ready.\n"
"</think>\n\n"
"The integral ≈ 0.333333."
)
),
}
# --- second tiny example: simple arithmetic with calculator ---- #
fewshot_user2 = {
"role": "user",
"content": "What is (2 + 3) * 4 ?"
}
fewshot_user2 = {"role": "user", "content": "What is (2 + 3) * 4 ?"}
fewshot_assistant2 = {
"role": "assistant",
"content": (
"<think>\n"
"Goal: evaluate (2+3)*4 with the calculator tool.\n"
"<tool_call>{\"name\":\"calculator\", "
"\"arguments\":{\"expr\":\"(2+3)*4\"}}</tool_call>\n"
"<tool_response>{\"value\": 20}</tool_response>\n"
'<tool_call>{"name":"calculator", '
'"arguments":{"expr":"(2+3)*4"}}</tool_call>\n'
'<tool_response>{"value": 20}</tool_response>\n'
"Observation: result 20.\n"
"Reflection: ready.\n"
"</think>\n\n"
"The answer is \\boxed{20}."
)
),
}
# --------------- build final prompt messages ------------ #
system_msg = {"role": "system", "content": SYSTEM_PROMPT}
real_user = {
real_user = {
"role": "user",
"content": (
f"{prompt_text} \nThis is a math problem, you must use the python_interpreter or calculator tool call to solve it."
)
),
}
messages = [
system_msg,
fewshot_user, fewshot_assistant,
fewshot_user2, fewshot_assistant2,
fewshot_user,
fewshot_assistant,
fewshot_user2,
fewshot_assistant2,
real_user,
]
@ -616,4 +625,4 @@ class InterleavedInlineEnv(BaseEnv):
# -------------------------------------------------------------------------- #
if __name__ == "__main__":
InterleavedInlineEnv.cli()
InterleavedInlineEnv.cli()