Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-25 17:10:42 +00:00
resolving conflicts
This commit is contained in: commit 1570a8a106
1 changed file with 41 additions and 34 deletions
@@ -13,10 +13,11 @@ class is copied here so nothing breaks when you swap env names.
from __future__ import annotations

import json, re
from typing import Dict, List, Optional, Tuple
import os
import itertools
import json
import os
import re
from typing import Dict, List, Optional, Tuple


# Set to True to always print debug information.
@@ -27,18 +28,17 @@ MAX_REPLY_TOKENS = 2048 # truncate any single assistant reply to ≤1024 tok
MAX_GEN_PER_TURN = 1024 # never request more than 512 new tokens from the model

import wandb
from datasets import load_dataset, Dataset
from datasets import Dataset, load_dataset

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    APIServerConfig,
    EvalHandlingEnum,
    ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


# -------------------------------------------------------------------------- #
# Constants
# -------------------------------------------------------------------------- #
@@ -103,10 +103,9 @@ For reasoning tools, return interleaved tool calls within <think> </think> tags.
</think>
"""

SYSTEM_PROMPT = (
    system_prompt
    + TOOL_SYSTEM_PROMPT
)
SYSTEM_PROMPT = system_prompt + TOOL_SYSTEM_PROMPT


# -------------------------------------------------------------------------- #
# Environment
# -------------------------------------------------------------------------- #
@@ -133,6 +132,7 @@ class InterleavedInlineEnv(BaseEnv):
        self.rollouts_for_wandb = []
        self.iter = 0
        import random

        self.rng = random.Random()
        # Dynamic few‑shot pool: list of (user_msg, assistant_msg) tuples
        self.dynamic_pool: List[Tuple[Dict, Dict]] = []
@@ -214,7 +214,9 @@ class InterleavedInlineEnv(BaseEnv):
        self.train, self.test = split["train"], split["test"]
        self.train = self.train.shuffle(seed=int.from_bytes(os.urandom(2), "big"))

    # --------------------- helper methods --------------------------------- #
    async def _completion_until(self, prompt: str, max_tokens: int, stop: Optional[str] = None) -> str:
    async def _completion_until(
        self, prompt: str, max_tokens: int, stop: Optional[str] = None
    ) -> str:
        comp = await self.server.completion(
            prompt=prompt,
            stop=stop,
@@ -268,16 +270,23 @@ class InterleavedInlineEnv(BaseEnv):
        args = call_json["arguments"]

        if name == "python_interpreter":
            import httpx, asyncio
            import asyncio

            import httpx

            async with httpx.AsyncClient(timeout=10.0) as client:
                payload = {"code": args["code"], "input": ""}
                resp = await client.post("http://localhost:5002/execute", json=payload)
                data = resp.json()
                if DEBUG:
                    print(f"[DEBUG _exec_tool] {name} result → {data}")
                return {"stdout": data.get("output", ""), "result": data.get("output", "").strip()}
                return {
                    "stdout": data.get("output", ""),
                    "result": data.get("output", "").strip(),
                }
        elif name == "calculator":
            import math

            expr = args["expr"]
            val = eval(expr, {"__builtins__": {}}, {"math": math})
            if DEBUG:
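The python_interpreter branch above posts the generated code to a local execution service and wraps its output for the model; the conflict resolution only splits the combined import and reflows the return dict. A minimal standalone sketch of that round trip, assuming the same localhost:5002/execute endpoint and "output" response field shown in the diff (the function name and the demo call are illustrative):

    # Sketch only: mirrors the diff's call to the local code-execution service.
    # The endpoint URL and the "output" response field come from the hunk above;
    # the function name and the __main__ demo are illustrative.
    import asyncio

    import httpx


    async def run_python_tool(code: str) -> dict:
        payload = {"code": code, "input": ""}
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post("http://localhost:5002/execute", json=payload)
            data = resp.json()
        output = data.get("output", "")
        return {"stdout": output, "result": output.strip()}


    if __name__ == "__main__":
        print(asyncio.run(run_python_tool("print(2 + 2)")))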
@@ -287,9 +296,7 @@ class InterleavedInlineEnv(BaseEnv):
        raise ValueError(f"Unknown tool name {name}")

    # --------------------- rollout logic (interleaved) ------------------- #
    async def _run_one_episode(
        self, ctx: List[Dict]
    ) -> Tuple[List[Dict], List[Dict]]:
    async def _run_one_episode(self, ctx: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
        """
        Generate–execute–resume loop:
@@ -346,9 +353,9 @@ class InterleavedInlineEnv(BaseEnv):
                executed.append(call_json)

                # Append tool_response inline
                assistant_msg["content"] += (
                    f"\n<tool_response>{json.dumps(result)}</tool_response>\n"
                )
                assistant_msg[
                    "content"
                ] += f"\n<tool_response>{json.dumps(result)}</tool_response>\n"
                # continue loop (model will keep thinking)
                continue
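The reflowed append above is the core of the generate-execute-resume loop: once a tool call has been executed, its JSON result is wrapped in <tool_response> tags and appended to the in-progress assistant message, and the loop continues so the model keeps thinking with the result in context. A rough sketch of that loop shape follows; _completion_until and _exec_tool are taken from this diff, while the <tool_call> tag format, the parser, and the turn limit are assumptions:

    # Sketch of the generate-execute-resume loop implied by this hunk; the
    # <tool_call> tag format, extract_tool_call, and max_turns are assumptions,
    # while _completion_until and _exec_tool are the methods shown in the diff.
    import json
    import re


    def extract_tool_call(text: str):
        """Tiny parser: pull the first <tool_call>{...}</tool_call> block, if any."""
        m = re.search(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
        return json.loads(m.group(1)) if m else None


    async def interleaved_episode(env, prompt: str, max_turns: int = 8) -> str:
        content = ""
        for _ in range(max_turns):
            # Generate until the model finishes or emits a tool call.
            chunk = await env._completion_until(
                prompt + content, max_tokens=1024, stop="</tool_call>"
            )
            if "<tool_call>" in chunk and "</tool_call>" not in chunk:
                chunk += "</tool_call>"  # stop sequences are typically stripped; restore the tag
            content += chunk
            call_json = extract_tool_call(chunk)
            if call_json is None:
                break  # no tool call: the reply is complete
            result = await env._exec_tool(call_json)
            # Append the tool output inline so the model keeps reasoning over it.
            content += f"\n<tool_response>{json.dumps(result)}</tool_response>\n"
        return content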
@@ -375,16 +382,16 @@ class InterleavedInlineEnv(BaseEnv):
        except Exception:
            return False

    async def collect_trajectories(
        self, item
    ) -> Tuple[ScoredDataGroup, List]:
    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
        """
        One prompt → `n = group_size` sampled assistant completions in
        parallel (single OpenAI request with n completions). Mirrors the
        logic in SingleToolCallingEnv.
        """
        messages_tuple, expected_raw = item
        expected = json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw
        expected = (
            json.loads(expected_raw) if isinstance(expected_raw, str) else expected_raw
        )

        # Re‑inflate frozensets to normal dicts
        prompt_msgs = [dict(r) for r in messages_tuple]
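A note on the last two lines of this hunk: each item carries its chat messages as a tuple of frozensets (a hashable form), and dict(r) turns each one back into a plain role/content dict before the prompt is rebuilt. A tiny illustration of that round trip (the freezing direction is assumed to happen where the item is created, outside this hunk):

    # Round-tripping a chat message through a hashable frozenset form.
    msg = {"role": "user", "content": "What is (2 + 3) * 4 ?"}
    frozen = frozenset(msg.items())  # hashable, so it can live inside a tuple-based item
    restored = dict(frozen)          # what `[dict(r) for r in messages_tuple]` does per message
    assert restored == msg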
@@ -486,7 +493,7 @@ class InterleavedInlineEnv(BaseEnv):
        total, correct = 0, 0
        for sample in self.test:
            # Build prompt exactly like get_next_item but without mutating self.iter
            prompt_text = sample["problem"]
            prompt_text = sample["problem"]
            expr = sample["expected_answer"].strip()
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
@@ -519,16 +526,16 @@ class InterleavedInlineEnv(BaseEnv):
        sample = self.train[idx]

        prompt_text = sample["problem"]
        expr = sample["expected_answer"].strip()
        expr = sample["expected_answer"].strip()
        answer_call = {
            "name": "python_interpreter",
            "arguments": {"code": f"print({expr})"}
            "arguments": {"code": f"print({expr})"},
        }

        # ---------------- few‑shot demonstration ---------------- #
        fewshot_user = {
            "role": "user",
            "content": "Compute the integral of x^2 from 0 to 1."
            "content": "Compute the integral of x^2 from 0 to 1.",
        }
        fewshot_assistant = {
            "role": "assistant",
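To make the reformatted answer_call concrete: for a training sample whose expected_answer is, say, the expression 2 + 3, the target tool call built here is just a print of that expression. The sample value below is made up for illustration:

    expr = "2 + 3"  # illustrative value for sample["expected_answer"].strip()
    answer_call = {
        "name": "python_interpreter",
        "arguments": {"code": f"print({expr})"},
    }
    # answer_call == {"name": "python_interpreter", "arguments": {"code": "print(2 + 3)"}}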
@@ -550,10 +557,7 @@ class InterleavedInlineEnv(BaseEnv):
        }

        # --- second tiny example: simple arithmetic with calculator ---- #
        fewshot_user2 = {
            "role": "user",
            "content": "What is (2 + 3) * 4 ?"
        }
        fewshot_user2 = {"role": "user", "content": "What is (2 + 3) * 4 ?"}
        fewshot_assistant2 = {
            "role": "assistant",
            "content": (
@@ -571,16 +575,19 @@ class InterleavedInlineEnv(BaseEnv):

        # --------------- build final prompt messages ------------ #
        system_msg = {"role": "system", "content": SYSTEM_PROMPT}
        real_user = {
        real_user = {
            "role": "user",
            "content": (
                f"{prompt_text} \nThis is a math problem, you must use the python_interpreter or calculator tool call to solve it."
            )
            ),
        }

        # Optionally insert one real demo from dynamic_pool
        dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else []

        # Optionally insert one real demo from dynamic_pool
        dyn = list(self.dynamic_pool[-1]) if self.dynamic_pool else []

        messages = (
            [system_msg,
             fewshot_user, fewshot_assistant,
@@ -620,4 +627,4 @@ class InterleavedInlineEnv(BaseEnv):

# -------------------------------------------------------------------------- #
if __name__ == "__main__":
    InterleavedInlineEnv.cli()
    InterleavedInlineEnv.cli()
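The duplicated closing line is simply the old and new version of the entry point; the module stays runnable as a script through BaseEnv.cli(). Invocation would look roughly like the lines below, but the filename is hypothetical and the available subcommands and flags come from the installed atroposlib, so check --help rather than trusting this sketch:

    # Hypothetical filename; the subcommands (e.g. serve / process) are supplied
    # by BaseEnv.cli() from atroposlib.
    python interleaved_inline_env.py --help
    python interleaved_inline_env.py process --help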