resolving conflicts

This commit is contained in:
interstellarninja 2025-06-24 15:25:55 -04:00
commit 1570a8a106

View file

@ -13,10 +13,11 @@ class is copied here so nothing breaks when you swap env names.
from __future__ import annotations
import json, re
from typing import Dict, List, Optional, Tuple
import os
import itertools
import json
import os
import re
from typing import Dict, List, Optional, Tuple
# Set to True to always print debug information.
@ -27,18 +28,17 @@ MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤1024 tok
MAX_GEN_PER_TURN = 1024 # never request more than 512 new tokens from the model
import wandb
from datasets import load_dataset, Dataset
from datasets import Dataset, load_dataset
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
APIServerConfig,
EvalHandlingEnum,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# -------------------------------------------------------------------------- #
# Constants
# -------------------------------------------------------------------------- #
@ -103,10 +103,9 @@ For reasoning tools, return interleaved tool calls within <think> </think> tags.
</think>
"""
SYSTEM_PROMPT = (
system_prompt
+ TOOL_SYSTEM_PROMPT
)
SYSTEM_PROMPT = system_prompt + TOOL_SYSTEM_PROMPT
# -------------------------------------------------------------------------- #
# Environment
# -------------------------------------------------------------------------- #
@ -133,6 +132,7 @@ class InterleavedInlineEnv(BaseEnv):
self.rollouts_for_wandb = []
self.iter = 0
import random
self.rng = random.Random()
# Dynamic fewshot pool: list of (user_msg, assistant_msg) tuples
self.dynamic_pool: List[Tuple[Dict, Dict]] = []
@ -214,7 +214,9 @@ class InterleavedInlineEnv(BaseEnv):
self.train, self.test = split["train"], split["test"]
self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))
# --------------------- helper methods --------------------------------- #
async def _completion_until(self, prompt: str, max_tokens: int, stop: Optional[str] = None) -> str:
async def _completion_until(
self, prompt: str, max_tokens: int, stop: Optional[str] = None
) -> str:
comp = await self.server.completion(
prompt=prompt,
stop=stop,
@ -268,16 +270,23 @@ class InterleavedInlineEnv(BaseEnv):
args = call_json["arguments"]
if name == "python_interpreter":
import httpx, asyncio
import asyncio
import httpx
async with httpx.AsyncClient(timeout=10.0) as client:
payload = {"code": args["code"], "input": ""}
resp = await client.post("http://localhost:5002/execute", json=payload)
data = resp.json()
if DEBUG:
print(f"[DEBUG _exec_tool] {name} result → {data}")
return {"stdout": data.get("output", ""), "result": data.get("output", "").strip()}
return {
"stdout": data.get("output", ""),
"result": data.get("output", "").strip(),
}
elif name == "calculator":
import math
expr = args["expr"]
val = eval(expr, {"__builtins__": {}}, {"math": math})
if DEBUG:
@ -287,9 +296,7 @@ class InterleavedInlineEnv(BaseEnv):
raise ValueError(f"Unknown tool name {name}")
# --------------------- rollout logic (interleaved) ------------------- #
async def _run_one_episode(
self, ctx: List[Dict]
) -> Tuple[List[Dict], List[Dict]]:
async def _run_one_episode(self, ctx: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
"""
Generate -> execute -> resume loop:
@ -346,9 +353,9 @@ class InterleavedInlineEnv(BaseEnv):
executed.append(call_json)
# Append tool_response inline
assistant_msg["content"] += (
f"\n<tool_response>{json.dumps(result)}</tool_response>\n"
)
assistant_msg[
"content"
] += f"\n<tool_response>{json.dumps(result)}</tool_response>\n"
# continue loop (model will keep thinking)
continue
@ -375,16 +382,16 @@ class InterleavedInlineEnv(BaseEnv):
except Exception:
return False
async def collect_trajectories(
self, item
) -> Tuple[ScoredDataGroup, List]:
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
"""
One prompt -> `n = group_size` sampled assistant completions in
parallel (single OpenAI request with n completions). Mirrors the
logic in SingleToolCallingEnv.
"""
messages_tuple, expected_raw = item
expected = json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw
expected = (
json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw
)
# Reinflate frozensets to normal dicts
prompt_msgs = [dict(r) for r in messages_tuple]
@ -486,7 +493,7 @@ class InterleavedInlineEnv(BaseEnv):
total, correct = 0, 0
for sample in self.test:
# Build prompt exactly like get_next_item but without mutating self.iter
prompt_text = sample["problem"]
prompt_text = sample["problem"]
expr = sample["expected_answer"].strip()
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
@ -519,16 +526,16 @@ class InterleavedInlineEnv(BaseEnv):
sample = self.train[idx]
prompt_text = sample["problem"]
expr = sample["expected_answer"].strip()
expr = sample["expected_answer"].strip()
answer_call = {
"name": "python_interpreter",
"arguments": {"code": f"print({expr})"}
"arguments": {"code": f"print({expr})"},
}
# ---------------- fewshot demonstration ---------------- #
fewshot_user = {
"role": "user",
"content": "Compute the integral of x^2 from 0 to 1."
"content": "Compute the integral of x^2 from 0 to 1.",
}
fewshot_assistant = {
"role": "assistant",
@ -550,10 +557,7 @@ class InterleavedInlineEnv(BaseEnv):
}
# --- second tiny example: simple arithmetic with calculator ---- #
fewshot_user2 = {
"role": "user",
"content": "What is (2 + 3) * 4 ?"
}
fewshot_user2 = {"role": "user", "content": "What is (2 + 3) * 4 ?"}
fewshot_assistant2 = {
"role": "assistant",
"content": (
@ -571,16 +575,19 @@ class InterleavedInlineEnv(BaseEnv):
# --------------- build final prompt messages ------------ #
system_msg = {"role": "system", "content": SYSTEM_PROMPT}
real_user = {
real_user = {
"role": "user",
"content": (
f"{prompt_text} \nThis is a math problem, you must use the python_interpreter or calculator tool call to solve it."
)
),
}
# Optionally insert one real demo from dynamic_pool
dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else []
# Optionally insert one real demo from dynamic_pool
dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else []
messages = (
[system_msg,
fewshot_user, fewshot_assistant,
@ -620,4 +627,4 @@ class InterleavedInlineEnv(BaseEnv):
# -------------------------------------------------------------------------- #
if __name__ == "__main__":
InterleavedInlineEnv.cli()
InterleavedInlineEnv.cli()