diff --git a/environments/tool_use_interleaved_thinking.py b/environments/tool_use_interleaved_thinking.py index a0e1790f..62fc8b5b 100644 --- a/environments/tool_use_interleaved_thinking.py +++ b/environments/tool_use_interleaved_thinking.py @@ -19,13 +19,12 @@ import os import re from typing import Dict, List, Optional, Tuple - # Set to True to always print debug information. DEBUG = True # or toggle via env var if you prefer: bool(os.getenv("DEBUG_INTERLEAVED", "1")) # Hard caps for generation length -MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤1024 tokens -MAX_GEN_PER_TURN = 1024 # never request more than 512 new tokens from the model +MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤1024 tokens +MAX_GEN_PER_TURN = 1024 # never request more than 512 new tokens from the model import wandb from datasets import Dataset, load_dataset @@ -136,7 +135,7 @@ class InterleavedInlineEnv(BaseEnv): self.rng = random.Random() # Dynamic few‑shot pool: list of (user_msg, assistant_msg) tuples self.dynamic_pool: List[Tuple[Dict, Dict]] = [] - self.dynamic_pool_max = 4 # keep at most 4 real examples + self.dynamic_pool_max = 4 # keep at most 4 real examples @classmethod def config_init(cls): @@ -153,7 +152,7 @@ class InterleavedInlineEnv(BaseEnv): wandb_name="toolcall_interleaved", eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, - max_gen_per_turn = MAX_GEN_PER_TURN, + max_gen_per_turn=MAX_GEN_PER_TURN, ) servers = [ APIServerConfig( @@ -175,15 +174,17 @@ class InterleavedInlineEnv(BaseEnv): The env‑var SUBSET_ROWS (default 1000) controls how many rows we keep. """ - import re, os + import os + import re + N = int(os.getenv("SUBSET_ROWS", "1000")) - stream_ds = load_dataset( # ≈50 k rows total → stream + stream_ds = load_dataset( # ≈50 k rows total → stream "NVIDIA/OpenMathReasoning", split="cot", - #"open-r1/OpenR1-Math-220k", - #"nvidia/AceReason-Math", - #split="train", + # "open-r1/OpenR1-Math-220k", + # "nvidia/AceReason-Math", + # split="train", streaming=True, ) @@ -210,9 +211,10 @@ class InterleavedInlineEnv(BaseEnv): if DEBUG: print(f"[DEBUG setup] kept {len(subset)} rows from Dataset") - split = full.train_test_split(test_size=0.02, seed=42) + split = full.train_test_split(test_size=0.02, seed=42) self.train, self.test = split["train"], split["test"] - self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big")) + self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big")) + # --------------------- helper methods --------------------------------- # async def _completion_until( self, prompt: str, max_tokens: int, stop: Optional[str] = None @@ -422,7 +424,7 @@ class InterleavedInlineEnv(BaseEnv): toks = self.tokenizer.encode(raw) if len(toks) > MAX_REPLY_TOKENS: toks = toks[:MAX_REPLY_TOKENS] - raw = self.tokenizer.decode(toks) + raw = self.tokenizer.decode(toks) if DEBUG: print(f"[DEBUG] truncated reply {idx} to {len(toks)} tokens") assistant_msg = {"role": "assistant", "content": raw} @@ -430,18 +432,21 @@ class InterleavedInlineEnv(BaseEnv): full_ctx = prompt_msgs + [assistant_msg] # Outcome‑based reward: compare boxed answer to expected expr - expr = expected["arguments"]["code"][6:-1] if ( - isinstance(expected, dict) - and "arguments" in expected - and "code" in expected["arguments"] - and expected["arguments"]["code"].startswith("print(") - and expected["arguments"]["code"].endswith(")") - ) else None + expr = ( + expected["arguments"]["code"][6:-1] + if ( + isinstance(expected, dict) + and "arguments" in expected + and "code" in expected["arguments"] + and expected["arguments"]["code"].startswith("print(") + and expected["arguments"]["code"].endswith(")") + ) + else None + ) boxed = self._boxed_after_think(raw) - same = ( - boxed == expr or - (boxed and expr and self._canon_num(boxed) == self._canon_num(expr)) + same = boxed == expr or ( + boxed and expr and self._canon_num(boxed) == self._canon_num(expr) ) reward = 1.0 if same else -1.0 if "" not in raw: @@ -449,9 +454,11 @@ class InterleavedInlineEnv(BaseEnv): else: # NEW RULE: no tool_call tags are allowed *outside* the think block end_pos = raw.lower().find("") - if ""):].lower(): + if "") :].lower(): if DEBUG: - print("[DEBUG] tool_call found outside ; setting reward = -1") + print( + "[DEBUG] tool_call found outside ; setting reward = -1" + ) reward = -1.0 if DEBUG: @@ -469,15 +476,17 @@ class InterleavedInlineEnv(BaseEnv): # --- harvest a success for dynamic few‑shots -------------------- for idx, sc in enumerate(scored["scores"]): reply_txt = completions.choices[idx].text - has_call = "= 1.0 and has_call: # Build (user, assistant) pair from this successful rollout u = {"role": "user", "content": prompt_msgs[-1]["content"]} a = {"role": "assistant", "content": reply_txt} self.dynamic_pool.append((u, a)) if len(self.dynamic_pool) > self.dynamic_pool_max: - self.dynamic_pool.pop(0) # FIFO - break # only harvest one per group + self.dynamic_pool.pop(0) # FIFO + break # only harvest one per group return scored, [] @@ -543,17 +552,17 @@ class InterleavedInlineEnv(BaseEnv): "\n" "Let's be sure of the definite integral ∫₀¹ x² dx. It's easy by hand " "but I'll run SymPy to avoid mistakes.\n" - "{\"name\":\"python_interpreter\", " - "\"arguments\":{\"code\":" - "\"import sympy as sp\\n" + '{"name":"python_interpreter", ' + '"arguments":{"code":' + '"import sympy as sp\\n' "x=sp.symbols('x')\\n" - "print(sp.integrate(x**2,(x,0,1)))\"}}\n" + 'print(sp.integrate(x**2,(x,0,1)))"}}\n' "\n" - "{\"result\": 1/3}\n" + '{"result": 1/3}\n' "The interpreter returns 1/3, so the value is 0.333̅.\n" "\n\n" "The integral equals \\boxed{\\tfrac{1}{3}} \\approx 0.333." - ) + ), } # --- second tiny example: simple arithmetic with calculator ---- # @@ -564,13 +573,13 @@ class InterleavedInlineEnv(BaseEnv): "\n" "I need (2+3)*4. Quick mental math gives 5*4 = 20, " "but I'll confirm with the calculator tool.\n" - "{\"name\":\"calculator\", " - "\"arguments\":{\"expr\":\"(2+3)*4\"}}\n" - "{\"value\": 20}\n" + '{"name":"calculator", ' + '"arguments":{"expr":"(2+3)*4"}}\n' + '{"value": 20}\n' "The tool also says 20, matching my head‑math.\n" "\n\n" "Therefore the answer is \\boxed{20}." - ) + ), } # --------------- build final prompt messages ------------ # @@ -589,11 +598,15 @@ class InterleavedInlineEnv(BaseEnv): dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else [] messages = ( - [system_msg, - fewshot_user, fewshot_assistant, - fewshot_user2, fewshot_assistant2] + - dyn + # 0 or 2 msgs - [real_user] + [ + system_msg, + fewshot_user, + fewshot_assistant, + fewshot_user2, + fewshot_assistant2, + ] + + dyn # 0 or 2 msgs + + [real_user] ) # Freeze for hashing