mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
1570a8a106
commit
85f7a0b226
1 changed file with 56 additions and 43 deletions
|
|
@ -19,13 +19,12 @@ import os
|
|||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
# Set to True to always print debug information.
|
||||
DEBUG = True # or toggle via env var if you prefer: os.getenv("DEBUG_INTERLEAVED", "1") != "0"
|
||||
|
||||
# Hard caps for generation length
|
||||
MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤2048 tokens
|
||||
MAX_GEN_PER_TURN = 1024 # never request more than 1024 new tokens from the model
|
||||
MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤2048 tokens
|
||||
MAX_GEN_PER_TURN = 1024 # never request more than 1024 new tokens from the model
|
||||
|
||||
import wandb
|
||||
from datasets import Dataset, load_dataset
|
||||
|
|
@ -136,7 +135,7 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
self.rng = random.Random()
|
||||
# Dynamic few‑shot pool: list of (user_msg, assistant_msg) tuples
|
||||
self.dynamic_pool: List[Tuple[Dict, Dict]] = []
|
||||
self.dynamic_pool_max = 4 # keep at most 4 real examples
|
||||
self.dynamic_pool_max = 4 # keep at most 4 real examples
|
||||
|
||||
@classmethod
|
||||
def config_init(cls):
|
||||
|
|
@ -153,7 +152,7 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
wandb_name="toolcall_interleaved",
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
max_gen_per_turn = MAX_GEN_PER_TURN,
|
||||
max_gen_per_turn=MAX_GEN_PER_TURN,
|
||||
)
|
||||
servers = [
|
||||
APIServerConfig(
|
||||
|
|
@ -175,15 +174,17 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
|
||||
The env‑var SUBSET_ROWS (default 1000) controls how many rows we keep.
|
||||
"""
|
||||
import re, os
|
||||
import os
|
||||
import re
|
||||
|
||||
N = int(os.getenv("SUBSET_ROWS", "1000"))
|
||||
|
||||
stream_ds = load_dataset( # ≈50 k rows total → stream
|
||||
stream_ds = load_dataset( # ≈50 k rows total → stream
|
||||
"NVIDIA/OpenMathReasoning",
|
||||
split="cot",
|
||||
#"open-r1/OpenR1-Math-220k",
|
||||
#"nvidia/AceReason-Math",
|
||||
#split="train",
|
||||
# "open-r1/OpenR1-Math-220k",
|
||||
# "nvidia/AceReason-Math",
|
||||
# split="train",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
|
|
@ -210,9 +211,10 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
if DEBUG:
|
||||
print(f"[DEBUG setup] kept {len(subset)} rows from Dataset")
|
||||
|
||||
split = full.train_test_split(test_size=0.02, seed=42)
|
||||
split = full.train_test_split(test_size=0.02, seed=42)
|
||||
self.train, self.test = split["train"], split["test"]
|
||||
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
|
||||
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
|
||||
|
||||
# --------------------- helper methods --------------------------------- #
|
||||
async def _completion_until(
|
||||
self, prompt: str, max_tokens: int, stop: Optional[str] = None
|
||||
|
|
@ -422,7 +424,7 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
toks = self.tokenizer.encode(raw)
|
||||
if len(toks) > MAX_REPLY_TOKENS:
|
||||
toks = toks[:MAX_REPLY_TOKENS]
|
||||
raw = self.tokenizer.decode(toks)
|
||||
raw = self.tokenizer.decode(toks)
|
||||
if DEBUG:
|
||||
print(f"[DEBUG] truncated reply {idx} to {len(toks)} tokens")
|
||||
assistant_msg = {"role": "assistant", "content": raw}
|
||||
|
|
@ -430,18 +432,21 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
full_ctx = prompt_msgs + [assistant_msg]
|
||||
|
||||
# Outcome‑based reward: compare boxed answer to expected expr
|
||||
expr = expected["arguments"]["code"][6:-1] if (
|
||||
isinstance(expected, dict)
|
||||
and "arguments" in expected
|
||||
and "code" in expected["arguments"]
|
||||
and expected["arguments"]["code"].startswith("print(")
|
||||
and expected["arguments"]["code"].endswith(")")
|
||||
) else None
|
||||
expr = (
|
||||
expected["arguments"]["code"][6:-1]
|
||||
if (
|
||||
isinstance(expected, dict)
|
||||
and "arguments" in expected
|
||||
and "code" in expected["arguments"]
|
||||
and expected["arguments"]["code"].startswith("print(")
|
||||
and expected["arguments"]["code"].endswith(")")
|
||||
)
|
||||
else None
|
||||
)
|
||||
boxed = self._boxed_after_think(raw)
|
||||
|
||||
same = (
|
||||
boxed == expr or
|
||||
(boxed and expr and self._canon_num(boxed) == self._canon_num(expr))
|
||||
same = boxed == expr or (
|
||||
boxed and expr and self._canon_num(boxed) == self._canon_num(expr)
|
||||
)
|
||||
reward = 1.0 if same else -1.0
|
||||
if "</think>" not in raw:
|
||||
|
|
@ -449,9 +454,11 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
else:
|
||||
# NEW RULE: no tool_call tags are allowed *outside* the think block
|
||||
end_pos = raw.lower().find("</think>")
|
||||
if "<tool_call" in raw[end_pos + len("</think>"):].lower():
|
||||
if "<tool_call" in raw[end_pos + len("</think>") :].lower():
|
||||
if DEBUG:
|
||||
print("[DEBUG] tool_call found outside </think>; setting reward = -1")
|
||||
print(
|
||||
"[DEBUG] tool_call found outside </think>; setting reward = -1"
|
||||
)
|
||||
reward = -1.0
|
||||
|
||||
if DEBUG:
|
||||
|
|
@ -469,15 +476,17 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
# --- harvest a success for dynamic few‑shots --------------------
|
||||
for idx, sc in enumerate(scored["scores"]):
|
||||
reply_txt = completions.choices[idx].text
|
||||
has_call = "<tool_call" in reply_txt.lower() # ensure an interleaved call exists
|
||||
has_call = (
|
||||
"<tool_call" in reply_txt.lower()
|
||||
) # ensure an interleaved call exists
|
||||
if sc >= 1.0 and has_call:
|
||||
# Build (user, assistant) pair from this successful rollout
|
||||
u = {"role": "user", "content": prompt_msgs[-1]["content"]}
|
||||
a = {"role": "assistant", "content": reply_txt}
|
||||
self.dynamic_pool.append((u, a))
|
||||
if len(self.dynamic_pool) > self.dynamic_pool_max:
|
||||
self.dynamic_pool.pop(0) # FIFO
|
||||
break # only harvest one per group
|
||||
self.dynamic_pool.pop(0) # FIFO
|
||||
break # only harvest one per group
|
||||
|
||||
return scored, []
|
||||
|
||||
|
|
@ -543,17 +552,17 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
"<think>\n"
|
||||
"Let's be sure of the definite integral ∫₀¹ x² dx. It's easy by hand "
|
||||
"but I'll run SymPy to avoid mistakes.\n"
|
||||
"<tool_call>{\"name\":\"python_interpreter\", "
|
||||
"\"arguments\":{\"code\":"
|
||||
"\"import sympy as sp\\n"
|
||||
'<tool_call>{"name":"python_interpreter", '
|
||||
'"arguments":{"code":'
|
||||
'"import sympy as sp\\n'
|
||||
"x=sp.symbols('x')\\n"
|
||||
"print(sp.integrate(x**2,(x,0,1)))\"}}\n"
|
||||
'print(sp.integrate(x**2,(x,0,1)))"}}\n'
|
||||
"</tool_call>\n"
|
||||
"<tool_response>{\"result\": 1/3}</tool_response>\n"
|
||||
'<tool_response>{"result": 1/3}</tool_response>\n'
|
||||
"The interpreter returns 1/3, so the value is 0.333̅.\n"
|
||||
"</think>\n\n"
|
||||
"The integral equals \\boxed{\\tfrac{1}{3}} \\approx 0.333."
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
# --- second tiny example: simple arithmetic with calculator ---- #
|
||||
|
|
@ -564,13 +573,13 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
"<think>\n"
|
||||
"I need (2+3)*4. Quick mental math gives 5*4 = 20, "
|
||||
"but I'll confirm with the calculator tool.\n"
|
||||
"<tool_call>{\"name\":\"calculator\", "
|
||||
"\"arguments\":{\"expr\":\"(2+3)*4\"}}</tool_call>\n"
|
||||
"<tool_response>{\"value\": 20}</tool_response>\n"
|
||||
'<tool_call>{"name":"calculator", '
|
||||
'"arguments":{"expr":"(2+3)*4"}}</tool_call>\n'
|
||||
'<tool_response>{"value": 20}</tool_response>\n'
|
||||
"The tool also says 20, matching my head‑math.\n"
|
||||
"</think>\n\n"
|
||||
"Therefore the answer is \\boxed{20}."
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
# --------------- build final prompt messages ------------ #
|
||||
|
|
@ -589,11 +598,15 @@ class InterleavedInlineEnv(BaseEnv):
|
|||
dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else []
|
||||
|
||||
messages = (
|
||||
[system_msg,
|
||||
fewshot_user, fewshot_assistant,
|
||||
fewshot_user2, fewshot_assistant2] +
|
||||
dyn + # 0 or 2 msgs
|
||||
[real_user]
|
||||
[
|
||||
system_msg,
|
||||
fewshot_user,
|
||||
fewshot_assistant,
|
||||
fewshot_user2,
|
||||
fewshot_assistant2,
|
||||
]
|
||||
+ dyn # 0 or 2 msgs
|
||||
+ [real_user]
|
||||
)
|
||||
|
||||
# Freeze for hashing
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue