[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-06-24 19:27:12 +00:00
parent 1570a8a106
commit 85f7a0b226

View file

@ -19,13 +19,12 @@ import os
import re
from typing import Dict, List, Optional, Tuple
# Set to True to always print debug information.
# NOTE: if this is ever driven by an env var, compare the string explicitly
# (e.g. os.getenv("DEBUG_INTERLEAVED", "1") == "1"); bool() of any non-empty
# string, including "0", is True.
DEBUG = True # or toggle via env var if you prefer (see NOTE above)
# Hard caps for generation length
MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤2048 tokens
MAX_GEN_PER_TURN = 1024 # cap on new tokens requested per turn (≤1024); passed to the env config as max_gen_per_turn
MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤2048 tokens
MAX_GEN_PER_TURN = 1024 # cap on new tokens requested per turn (≤1024); passed to the env config as max_gen_per_turn
import wandb
from datasets import Dataset, load_dataset
@ -136,7 +135,7 @@ class InterleavedInlineEnv(BaseEnv):
self.rng = random.Random()
# Dynamic few-shot pool: list of (user_msg, assistant_msg) tuples
self.dynamic_pool: List[Tuple[Dict, Dict]] = []
self.dynamic_pool_max = 4 # keep at most 4 real examples
self.dynamic_pool_max = 4 # keep at most 4 real examples
@classmethod
def config_init(cls):
@ -153,7 +152,7 @@ class InterleavedInlineEnv(BaseEnv):
wandb_name="toolcall_interleaved",
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
max_gen_per_turn = MAX_GEN_PER_TURN,
max_gen_per_turn=MAX_GEN_PER_TURN,
)
servers = [
APIServerConfig(
@ -175,15 +174,17 @@ class InterleavedInlineEnv(BaseEnv):
The envvar SUBSET_ROWS (default 1000) controls how many rows we keep.
"""
import re, os
import os
import re
N = int(os.getenv("SUBSET_ROWS", "1000"))
stream_ds = load_dataset( # ≈50k rows total → stream
stream_ds = load_dataset( # ≈50k rows total → stream
"NVIDIA/OpenMathReasoning",
split="cot",
#"open-r1/OpenR1-Math-220k",
#"nvidia/AceReason-Math",
#split="train",
# "open-r1/OpenR1-Math-220k",
# "nvidia/AceReason-Math",
# split="train",
streaming=True,
)
@ -210,9 +211,10 @@ class InterleavedInlineEnv(BaseEnv):
if DEBUG:
print(f"[DEBUG setup] kept {len(subset)} rows from Dataset")
split = full.train_test_split(test_size=0.02, seed=42)
split = full.train_test_split(test_size=0.02, seed=42)
self.train, self.test = split["train"], split["test"]
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
# --------------------- helper methods --------------------------------- #
async def _completion_until(
self, prompt: str, max_tokens: int, stop: Optional[str] = None
@ -422,7 +424,7 @@ class InterleavedInlineEnv(BaseEnv):
toks = self.tokenizer.encode(raw)
if len(toks) > MAX_REPLY_TOKENS:
toks = toks[:MAX_REPLY_TOKENS]
raw = self.tokenizer.decode(toks)
raw = self.tokenizer.decode(toks)
if DEBUG:
print(f"[DEBUG] truncated reply {idx} to {len(toks)} tokens")
assistant_msg = {"role": "assistant", "content": raw}
@ -430,18 +432,21 @@ class InterleavedInlineEnv(BaseEnv):
full_ctx = prompt_msgs + [assistant_msg]
# Outcome-based reward: compare boxed answer to expected expr
expr = expected["arguments"]["code"][6:-1] if (
isinstance(expected, dict)
and "arguments" in expected
and "code" in expected["arguments"]
and expected["arguments"]["code"].startswith("print(")
and expected["arguments"]["code"].endswith(")")
) else None
expr = (
expected["arguments"]["code"][6:-1]
if (
isinstance(expected, dict)
and "arguments" in expected
and "code" in expected["arguments"]
and expected["arguments"]["code"].startswith("print(")
and expected["arguments"]["code"].endswith(")")
)
else None
)
boxed = self._boxed_after_think(raw)
same = (
boxed == expr or
(boxed and expr and self._canon_num(boxed) == self._canon_num(expr))
same = boxed == expr or (
boxed and expr and self._canon_num(boxed) == self._canon_num(expr)
)
reward = 1.0 if same else -1.0
if "</think>" not in raw:
@ -449,9 +454,11 @@ class InterleavedInlineEnv(BaseEnv):
else:
# NEW RULE: no tool_call tags are allowed *outside* the think block
end_pos = raw.lower().find("</think>")
if "<tool_call" in raw[end_pos + len("</think>"):].lower():
if "<tool_call" in raw[end_pos + len("</think>") :].lower():
if DEBUG:
print("[DEBUG] tool_call found outside </think>; setting reward = -1")
print(
"[DEBUG] tool_call found outside </think>; setting reward = -1"
)
reward = -1.0
if DEBUG:
@ -469,15 +476,17 @@ class InterleavedInlineEnv(BaseEnv):
# --- harvest a success for dynamic few-shots --------------------
for idx, sc in enumerate(scored["scores"]):
reply_txt = completions.choices[idx].text
has_call = "<tool_call" in reply_txt.lower() # ensure an interleaved call exists
has_call = (
"<tool_call" in reply_txt.lower()
) # ensure an interleaved call exists
if sc >= 1.0 and has_call:
# Build (user, assistant) pair from this successful rollout
u = {"role": "user", "content": prompt_msgs[-1]["content"]}
a = {"role": "assistant", "content": reply_txt}
self.dynamic_pool.append((u, a))
if len(self.dynamic_pool) > self.dynamic_pool_max:
self.dynamic_pool.pop(0) # FIFO
break # only harvest one per group
self.dynamic_pool.pop(0) # FIFO
break # only harvest one per group
return scored, []
@ -543,17 +552,17 @@ class InterleavedInlineEnv(BaseEnv):
"<think>\n"
"Let's be sure of the definite integral ∫₀¹ x² dx. It's easy by hand "
"but I'll run SymPy to avoid mistakes.\n"
"<tool_call>{\"name\":\"python_interpreter\", "
"\"arguments\":{\"code\":"
"\"import sympy as sp\\n"
'<tool_call>{"name":"python_interpreter", '
'"arguments":{"code":'
'"import sympy as sp\\n'
"x=sp.symbols('x')\\n"
"print(sp.integrate(x**2,(x,0,1)))\"}}\n"
'print(sp.integrate(x**2,(x,0,1)))"}}\n'
"</tool_call>\n"
"<tool_response>{\"result\": 1/3}</tool_response>\n"
'<tool_response>{"result": 1/3}</tool_response>\n'
"The interpreter returns 1/3, so the value is 0.333̅.\n"
"</think>\n\n"
"The integral equals \\boxed{\\tfrac{1}{3}} \\approx 0.333."
)
),
}
# --- second tiny example: simple arithmetic with calculator ---- #
@ -564,13 +573,13 @@ class InterleavedInlineEnv(BaseEnv):
"<think>\n"
"I need (2+3)*4. Quick mental math gives 5*4 = 20, "
"but I'll confirm with the calculator tool.\n"
"<tool_call>{\"name\":\"calculator\", "
"\"arguments\":{\"expr\":\"(2+3)*4\"}}</tool_call>\n"
"<tool_response>{\"value\": 20}</tool_response>\n"
'<tool_call>{"name":"calculator", '
'"arguments":{"expr":"(2+3)*4"}}</tool_call>\n'
'<tool_response>{"value": 20}</tool_response>\n'
"The tool also says 20, matching my headmath.\n"
"</think>\n\n"
"Therefore the answer is \\boxed{20}."
)
),
}
# --------------- build final prompt messages ------------ #
@ -589,11 +598,15 @@ class InterleavedInlineEnv(BaseEnv):
dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else []
messages = (
[system_msg,
fewshot_user, fewshot_assistant,
fewshot_user2, fewshot_assistant2] +
dyn + # 0 or 2 msgs
[real_user]
[
system_msg,
fewshot_user,
fewshot_assistant,
fewshot_user2,
fewshot_assistant2,
]
+ dyn # 0 or 2 msgs
+ [real_user]
)
# Freeze for hashing