mirror of
https://github.com/collinear-ai/yc-bench.git
synced 2026-04-19 12:58:03 +00:00
Merge pull request #11 from collinear-ai/feat/multi-episode
Add multi-episode setting with scratchpad carryover
This commit is contained in:
commit
d976b9cbb4
9 changed files with 490 additions and 82 deletions
BIN
plots/multi_episode_haiku.png
Normal file
BIN
plots/multi_episode_haiku.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 326 KiB |
|
|
@ -0,0 +1,43 @@
|
|||
{
|
||||
"session_id": "run-1-openrouter/anthropic/claude-haiku-4-5",
|
||||
"model": "openrouter/anthropic/claude-haiku-4-5",
|
||||
"seed": 1,
|
||||
"horizon_years": 1,
|
||||
"total_episodes": 3,
|
||||
"terminal": true,
|
||||
"terminal_reason": "bankruptcy",
|
||||
"terminal_detail": "bankruptcy",
|
||||
"started_at": "2026-03-12T01:47:36.650022+00:00",
|
||||
"ended_at": "2026-03-12T01:57:25.802024+00:00",
|
||||
"turns_completed": 193,
|
||||
"total_cost_usd": 6.014085,
|
||||
"episodes": [
|
||||
{
|
||||
"episode": 1,
|
||||
"turns_completed": 29,
|
||||
"terminal_reason": "bankruptcy",
|
||||
"terminal_detail": "bankruptcy",
|
||||
"cost_usd": 0.737144,
|
||||
"started_at": "2026-03-12T01:40:17.341900+00:00",
|
||||
"ended_at": "2026-03-12T01:43:04.213804+00:00"
|
||||
},
|
||||
{
|
||||
"episode": 2,
|
||||
"turns_completed": 43,
|
||||
"terminal_reason": "bankruptcy",
|
||||
"terminal_detail": "bankruptcy",
|
||||
"cost_usd": 1.054148,
|
||||
"started_at": "2026-03-12T01:43:04.602496+00:00",
|
||||
"ended_at": "2026-03-12T01:47:36.518577+00:00"
|
||||
},
|
||||
{
|
||||
"episode": 3,
|
||||
"turns_completed": 121,
|
||||
"terminal_reason": "bankruptcy",
|
||||
"terminal_detail": "bankruptcy",
|
||||
"cost_usd": 4.222793,
|
||||
"started_at": "2026-03-12T01:47:36.650022+00:00",
|
||||
"ended_at": "2026-03-12T01:57:25.802024+00:00"
|
||||
}
|
||||
]
|
||||
}
|
||||
229
scripts/plot_multi_episode.py
Normal file
229
scripts/plot_multi_episode.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""Plot multi-episode benchmark: funds over time across episodes + scratchpad evolution."""
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
import matplotlib.ticker as mticker
|
||||
import matplotlib.gridspec as gridspec
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import textwrap
|
||||
|
||||
ROOT = Path(__file__).parent.parent
|
||||
INITIAL_FUNDS_CENTS = 15_000_000
|
||||
|
||||
# ── Collinear brand palette ──────────────────────────────────────────────────
|
||||
NAVY = "#13234D"
|
||||
ORANGE = "#F26125"
|
||||
BLUE = "#4D65FF"
|
||||
BG_COLOR = "#FAFBFD"
|
||||
GRID_CLR = "#E8ECF2"
|
||||
TEXT_CLR = "#2A2F3D"
|
||||
MUTED = "#6B7694"
|
||||
CARD_BG = "#FFFFFF"
|
||||
|
||||
EP_COLORS = [BLUE, ORANGE, "#22C55E"]
|
||||
EP_LABELS = ["Episode 1", "Episode 2", "Episode 3"]
|
||||
|
||||
|
||||
def load_episode(db_path):
|
||||
"""Load funds curve, task stats, and scratchpad from an episode DB."""
|
||||
con = sqlite3.connect(str(db_path))
|
||||
rows = con.execute(
|
||||
"SELECT occurred_at, amount_cents, category FROM ledger_entries ORDER BY occurred_at ASC"
|
||||
).fetchall()
|
||||
|
||||
tasks = con.execute(
|
||||
"SELECT status, count(*) FROM tasks WHERE completed_at IS NOT NULL GROUP BY status"
|
||||
).fetchall()
|
||||
task_stats = dict(tasks)
|
||||
|
||||
scratchpad = con.execute("SELECT content FROM scratchpads LIMIT 1").fetchone()
|
||||
scratchpad_text = scratchpad[0] if scratchpad else ""
|
||||
|
||||
con.close()
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
times, balances = [], []
|
||||
running = INITIAL_FUNDS_CENTS
|
||||
start = datetime.fromisoformat(rows[0][0]).replace(
|
||||
month=1, day=1, hour=9, minute=0, second=0, microsecond=0
|
||||
)
|
||||
times.append(start)
|
||||
balances.append(running / 100)
|
||||
for occurred_at, amount_cents, category in rows:
|
||||
running += int(amount_cents)
|
||||
t = datetime.fromisoformat(occurred_at)
|
||||
times.append(t)
|
||||
balances.append(running / 100)
|
||||
|
||||
return {
|
||||
"times": times,
|
||||
"balances": balances,
|
||||
"final_balance": balances[-1],
|
||||
"task_success": task_stats.get("completed_success", 0),
|
||||
"task_fail": task_stats.get("completed_fail", 0),
|
||||
"scratchpad": scratchpad_text,
|
||||
"duration_months": (times[-1] - times[0]).days / 30.0,
|
||||
"bankrupt": balances[-1] <= 0,
|
||||
}
|
||||
|
||||
|
||||
def make_plot(episodes, model_label, seed, config):
|
||||
fig = plt.figure(figsize=(20, 12), facecolor=BG_COLOR)
|
||||
gs = gridspec.GridSpec(2, 3, figure=fig, height_ratios=[2.2, 1],
|
||||
hspace=0.35, wspace=0.3,
|
||||
left=0.07, right=0.97, top=0.82, bottom=0.06)
|
||||
|
||||
# ── Header band ──────────────────────────────────────────────────────
|
||||
header_rect = plt.Rectangle((0, 0.88), 1, 0.12,
|
||||
transform=fig.transFigure, facecolor=NAVY,
|
||||
edgecolor="none", zorder=0)
|
||||
fig.patches.append(header_rect)
|
||||
accent_rect = plt.Rectangle((0, 0.875), 1, 0.006,
|
||||
transform=fig.transFigure, facecolor=ORANGE,
|
||||
edgecolor="none", zorder=1)
|
||||
fig.patches.append(accent_rect)
|
||||
|
||||
fig.text(0.5, 0.94,
|
||||
"YC-Bench | Multi-Episode Learning",
|
||||
ha="center", va="center",
|
||||
fontsize=32, fontweight="700", color="white",
|
||||
fontfamily="Helvetica Neue", zorder=2)
|
||||
fig.text(0.5, 0.895,
|
||||
f"{model_label} | {config} config | seed {seed} | {len(episodes)} episodes",
|
||||
ha="center", va="center",
|
||||
fontsize=16, fontweight="400", color="#AABBDD", zorder=2)
|
||||
|
||||
# ── Top row: funds over time (full width) ────────────────────────────
|
||||
ax_funds = fig.add_subplot(gs[0, :])
|
||||
ax_funds.set_facecolor(CARD_BG)
|
||||
for spine in ax_funds.spines.values():
|
||||
spine.set_edgecolor(GRID_CLR)
|
||||
spine.set_linewidth(1.2)
|
||||
|
||||
for i, ep in enumerate(episodes):
|
||||
color = EP_COLORS[i % len(EP_COLORS)]
|
||||
survived = f"{ep['duration_months']:.0f}mo"
|
||||
label = f"Ep {i+1}: {survived}, {ep['task_success']}W/{ep['task_fail']}L"
|
||||
|
||||
ax_funds.plot(ep["times"], ep["balances"],
|
||||
color=color, linewidth=2.8, alpha=0.9,
|
||||
label=label, zorder=3 + i)
|
||||
ax_funds.fill_between(ep["times"], 0, ep["balances"],
|
||||
color=color, alpha=0.06, zorder=1)
|
||||
|
||||
if ep["bankrupt"]:
|
||||
ax_funds.scatter([ep["times"][-1]], [max(ep["balances"][-1], 500)],
|
||||
color=color, marker="X", s=200,
|
||||
linewidths=2, edgecolors="white",
|
||||
alpha=0.9, zorder=5 + i)
|
||||
|
||||
ax_funds.axhline(0, color="#DC2626", linewidth=1.2, linestyle="--", alpha=0.5, zorder=2,
|
||||
label="Bankruptcy line")
|
||||
ax_funds.set_ylabel("Company Funds ($)", fontsize=14, color=TEXT_CLR, fontweight="500")
|
||||
ax_funds.yaxis.set_major_formatter(
|
||||
mticker.FuncFormatter(
|
||||
lambda x, _: f"${x/1e6:.1f}M" if x >= 1e6
|
||||
else f"${x/1e3:.0f}K" if x >= 1e3
|
||||
else f"${x:.0f}"
|
||||
)
|
||||
)
|
||||
ax_funds.xaxis.set_major_formatter(mdates.DateFormatter("%b '%y"))
|
||||
ax_funds.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
|
||||
plt.setp(ax_funds.xaxis.get_majorticklabels(), rotation=30, ha="right")
|
||||
ax_funds.tick_params(colors=MUTED, labelsize=12)
|
||||
ax_funds.grid(axis="y", color=GRID_CLR, linewidth=0.7, alpha=0.8)
|
||||
ax_funds.grid(axis="x", color=GRID_CLR, linewidth=0.4, alpha=0.4)
|
||||
ax_funds.legend(fontsize=12, facecolor=CARD_BG, edgecolor=GRID_CLR,
|
||||
labelcolor=TEXT_CLR, loc="upper right",
|
||||
framealpha=0.95, borderpad=1)
|
||||
ax_funds.set_title("Funds Over Time — Each Episode Starts Fresh",
|
||||
fontsize=16, fontweight="600", color=TEXT_CLR, pad=12)
|
||||
|
||||
# ── Bottom row: 3 scratchpad panels ──────────────────────────────────
|
||||
for i, ep in enumerate(episodes):
|
||||
ax = fig.add_subplot(gs[1, i])
|
||||
ax.set_facecolor("#F8F9FC")
|
||||
for spine in ax.spines.values():
|
||||
spine.set_edgecolor(EP_COLORS[i % len(EP_COLORS)])
|
||||
spine.set_linewidth(2)
|
||||
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_ylim(0, 1)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# Title
|
||||
color = EP_COLORS[i % len(EP_COLORS)]
|
||||
ax.set_title(f"Episode {i+1} Scratchpad",
|
||||
fontsize=13, fontweight="600", color=color, pad=8)
|
||||
|
||||
# Scratchpad content (truncated)
|
||||
text = ep["scratchpad"].strip()
|
||||
if not text:
|
||||
text = "(empty)"
|
||||
# Take first ~8 lines, wrap to ~55 chars
|
||||
lines = text.split("\n")[:10]
|
||||
wrapped = []
|
||||
for line in lines:
|
||||
if len(line) > 60:
|
||||
wrapped.extend(textwrap.wrap(line, 58))
|
||||
else:
|
||||
wrapped.append(line)
|
||||
display = "\n".join(wrapped[:12])
|
||||
if len(wrapped) > 12 or len(lines) < text.count("\n") + 1:
|
||||
display += "\n..."
|
||||
|
||||
ax.text(0.05, 0.92, display,
|
||||
transform=ax.transAxes,
|
||||
fontsize=7.5, fontfamily="monospace",
|
||||
color=TEXT_CLR, verticalalignment="top",
|
||||
linespacing=1.4)
|
||||
|
||||
# Stats badge
|
||||
stats = f"{ep['task_success']}W / {ep['task_fail']}L | {ep['duration_months']:.0f} months"
|
||||
ax.text(0.5, 0.02, stats,
|
||||
transform=ax.transAxes, ha="center",
|
||||
fontsize=9, fontweight="600", color=MUTED)
|
||||
|
||||
# ── Footer ───────────────────────────────────────────────────────────
|
||||
fig.text(0.5, 0.01,
|
||||
"collinear.ai | Multi-Episode YC-Bench: Scratchpad carries over between bankruptcies",
|
||||
ha="center", va="bottom",
|
||||
fontsize=12, fontweight="400", color=MUTED, fontstyle="italic")
|
||||
|
||||
out = ROOT / "plots" / "multi_episode_haiku.png"
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
plt.savefig(out, dpi=150, facecolor=BG_COLOR, pad_inches=0)
|
||||
print(f"Saved: {out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
db_dir = ROOT / "db"
|
||||
slug = "openrouter_anthropic_claude-haiku-4-5"
|
||||
config = "hard"
|
||||
seed = 1
|
||||
|
||||
episodes = []
|
||||
for ep_num in [1, 2, 3]:
|
||||
db_path = db_dir / f"{config}_{seed}_{slug}.ep{ep_num}.db"
|
||||
if not db_path.exists():
|
||||
print(f"Skipping {db_path} (not found)")
|
||||
continue
|
||||
data = load_episode(db_path)
|
||||
if data:
|
||||
episodes.append(data)
|
||||
print(f"Episode {ep_num}: {data['task_success']}W/{data['task_fail']}L, "
|
||||
f"survived {data['duration_months']:.1f}mo, "
|
||||
f"final ${data['final_balance']:,.0f}")
|
||||
|
||||
if episodes:
|
||||
make_plot(episodes, "Claude Haiku 4.5", seed, config)
|
||||
else:
|
||||
print("No episode data found.")
|
||||
|
|
@ -118,6 +118,7 @@ def run_agent_loop(
|
|||
max_turns: int | None = None,
|
||||
on_turn_start=None,
|
||||
on_turn=None,
|
||||
episode: int = 1,
|
||||
) -> RunState:
|
||||
run_state.start()
|
||||
turns_since_resume = 0 # consecutive turns without sim resume
|
||||
|
|
@ -136,7 +137,7 @@ def run_agent_loop(
|
|||
if run_state.turn_count == 0:
|
||||
with db_factory() as db:
|
||||
snapshot = _snapshot_state(db, company_id)
|
||||
user_input = build_initial_user_prompt(**snapshot)
|
||||
user_input = build_initial_user_prompt(**snapshot, episode=episode)
|
||||
elif run_state.next_user_input is not None:
|
||||
user_input = run_state.next_user_input
|
||||
run_state.next_user_input = None
|
||||
|
|
|
|||
|
|
@ -173,6 +173,7 @@ def build_initial_user_prompt(
|
|||
employee_count: int,
|
||||
monthly_payroll_cents: int,
|
||||
bankrupt: bool,
|
||||
episode: int = 1,
|
||||
) -> str:
|
||||
"""Build the one-time initial user message at run start."""
|
||||
runway_months = (
|
||||
|
|
@ -188,7 +189,18 @@ def build_initial_user_prompt(
|
|||
)
|
||||
runway_str = f"~{runway_months} months" if runway_months is not None else "∞"
|
||||
|
||||
lines = [
|
||||
lines = []
|
||||
if episode > 1:
|
||||
lines.extend([
|
||||
f"## Episode {episode} — Restarting After Bankruptcy",
|
||||
"",
|
||||
f"You went bankrupt in episode {episode - 1}. The simulation has been reset,",
|
||||
"but your **scratchpad notes from the previous episode are preserved**.",
|
||||
"Read your scratchpad (`yc-bench scratchpad read`) to review your notes",
|
||||
"and learn from past mistakes before taking action.",
|
||||
"",
|
||||
])
|
||||
lines.extend([
|
||||
"## Simulation Start — Take Immediate Action",
|
||||
f"- current_time: {sim_time}",
|
||||
f"- horizon_end: {horizon_end}",
|
||||
|
|
@ -210,7 +222,7 @@ def build_initial_user_prompt(
|
|||
"7. `yc-bench sim resume` — advance time to collect the first task completion event",
|
||||
"",
|
||||
"Do not spend multiple turns just browsing. Accept and dispatch tasks immediately.",
|
||||
]
|
||||
])
|
||||
if bankrupt:
|
||||
lines.append("WARNING: company is already bankrupt at initialization.")
|
||||
return "\n".join(lines)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ class RunState:
|
|||
next_user_input: Optional[str] = None
|
||||
total_cost_usd: float = 0.0
|
||||
|
||||
# Multi-episode tracking
|
||||
current_episode: int = 1
|
||||
episode_results: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def start(self) -> None:
|
||||
self.started_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
|
@ -66,18 +70,14 @@ class RunState:
|
|||
return True
|
||||
return False
|
||||
|
||||
def full_rollout(self) -> Dict[str, Any]:
|
||||
"""Full results including transcript for saving to disk."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
"seed": self.seed,
|
||||
"horizon_years": self.horizon_years,
|
||||
def finish_episode(self) -> Dict[str, Any]:
|
||||
"""Snapshot current episode state into episode_results and return it."""
|
||||
episode_data = {
|
||||
"episode": self.current_episode,
|
||||
"turns_completed": self.turn_count,
|
||||
"terminal": self.terminal,
|
||||
"terminal_reason": self.terminal_reason.value if self.terminal_reason else None,
|
||||
"terminal_detail": self.terminal_detail,
|
||||
"total_cost_usd": round(self.total_cost_usd, 6),
|
||||
"cost_usd": round(self.total_cost_usd, 6),
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
"transcript": [
|
||||
|
|
@ -91,11 +91,66 @@ class RunState:
|
|||
for t in self.transcript
|
||||
],
|
||||
}
|
||||
self.episode_results.append(episode_data)
|
||||
return episode_data
|
||||
|
||||
def reset_for_new_episode(self) -> None:
|
||||
"""Reset mutable state for a new episode, preserving episode_results."""
|
||||
self.current_episode += 1
|
||||
self.turn_count = 0
|
||||
self.terminal = False
|
||||
self.terminal_reason = None
|
||||
self.terminal_detail = None
|
||||
self.started_at = None
|
||||
self.ended_at = None
|
||||
self.transcript = []
|
||||
self.next_user_input = None
|
||||
self.total_cost_usd = 0.0
|
||||
|
||||
def full_rollout(self) -> Dict[str, Any]:
|
||||
"""Full results including transcript for saving to disk."""
|
||||
base = {
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
"seed": self.seed,
|
||||
"horizon_years": self.horizon_years,
|
||||
"total_episodes": self.current_episode,
|
||||
"terminal": self.terminal,
|
||||
"terminal_reason": self.terminal_reason.value if self.terminal_reason else None,
|
||||
"terminal_detail": self.terminal_detail,
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
}
|
||||
if self.episode_results:
|
||||
# Multi-episode: include all episode data
|
||||
total_turns = sum(ep["turns_completed"] for ep in self.episode_results)
|
||||
total_cost = sum(ep["cost_usd"] for ep in self.episode_results)
|
||||
base["turns_completed"] = total_turns
|
||||
base["total_cost_usd"] = round(total_cost, 6)
|
||||
base["episodes"] = self.episode_results
|
||||
else:
|
||||
# Single-episode (backward compat): flat transcript
|
||||
base["turns_completed"] = self.turn_count
|
||||
base["total_cost_usd"] = round(self.total_cost_usd, 6)
|
||||
base["transcript"] = [
|
||||
{
|
||||
"turn": t.turn,
|
||||
"timestamp": t.timestamp,
|
||||
"user_input": t.user_input,
|
||||
"agent_output": t.agent_output,
|
||||
"commands_executed": t.commands_executed,
|
||||
}
|
||||
for t in self.transcript
|
||||
]
|
||||
return base
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
"""Summary without transcript for logging."""
|
||||
rollout = self.full_rollout()
|
||||
rollout.pop("transcript", None)
|
||||
if "episodes" in rollout:
|
||||
for ep in rollout["episodes"]:
|
||||
ep.pop("transcript", None)
|
||||
return rollout
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -99,6 +99,10 @@ def run_command_cli(
|
|||
help="Preset name ('default', 'fast_test', 'high_reward') or path to a .toml file",
|
||||
),
|
||||
no_live: bool = typer.Option(False, "--no-live", help="Disable the live terminal dashboard"),
|
||||
max_episodes: int = typer.Option(
|
||||
1, "--max-episodes",
|
||||
help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1",
|
||||
),
|
||||
):
|
||||
"""Run a full benchmark: migrate DB, seed world, run agent loop to completion."""
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
|
@ -114,6 +118,7 @@ def run_command_cli(
|
|||
start_date=start_date,
|
||||
config_name=config_name,
|
||||
no_live=no_live,
|
||||
max_episodes=max_episodes,
|
||||
)
|
||||
raise SystemExit(run_benchmark(args))
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class RunArgs:
|
|||
start_date: str
|
||||
config_name: str = "default"
|
||||
no_live: bool = False
|
||||
max_episodes: int = 1
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -32,6 +33,10 @@ def build_parser():
|
|||
"--no-live", action="store_true", default=False,
|
||||
help="Disable the live terminal dashboard (show raw log output instead)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-episodes", type=int, default=1,
|
||||
help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1",
|
||||
)
|
||||
return parser
|
||||
|
||||
def parse_run_args(argv):
|
||||
|
|
@ -46,10 +51,13 @@ def parse_run_args(argv):
|
|||
start_date=ns.start_date,
|
||||
config_name=ns.config_name,
|
||||
no_live=ns.no_live,
|
||||
max_episodes=ns.max_episodes,
|
||||
)
|
||||
|
||||
def _validate(ns, parser):
|
||||
if ns.horizon_years is not None and ns.horizon_years <= 0:
|
||||
parser.error("--horizon-years must be int > 0")
|
||||
if ns.max_episodes < 1:
|
||||
parser.error("--max-episodes must be int >= 1")
|
||||
|
||||
__all__ = ["RunArgs", "build_parser", "parse_run_args"]
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from pathlib import Path
|
|||
|
||||
from ..agent.commands.executor import run_command
|
||||
from ..agent.loop import run_agent_loop
|
||||
from ..agent.run_state import RunState
|
||||
from ..agent.run_state import RunState, TerminalReason
|
||||
from ..agent.runtime.factory import build_runtime
|
||||
from ..agent.runtime.schemas import RuntimeSettings
|
||||
from ..db.session import build_engine, build_session_factory, session_scope, init_db
|
||||
|
|
@ -154,6 +154,37 @@ def _redirect_all_logging_to_file(log_file: Path) -> None:
|
|||
lg.addHandler(file_handler)
|
||||
|
||||
|
||||
def _build_db_url(args, episode: int, max_episodes: int) -> str:
|
||||
"""Build SQLite DATABASE_URL, adding episode suffix when multi-episode."""
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
base = f"{args.config_name}_{args.seed}_{slug}"
|
||||
if max_episodes > 1:
|
||||
return f"sqlite:///{db_dir}/{base}.ep{episode}.db"
|
||||
return f"sqlite:///{db_dir}/{base}.db"
|
||||
|
||||
|
||||
def _read_scratchpad(db_factory, company_id) -> str:
|
||||
"""Read scratchpad content from the current DB."""
|
||||
from ..db.models.scratchpad import Scratchpad
|
||||
with db_factory() as db:
|
||||
row = db.query(Scratchpad).filter(Scratchpad.company_id == company_id).first()
|
||||
return row.content if row else ""
|
||||
|
||||
|
||||
def _write_scratchpad(db_factory, company_id, content: str) -> None:
|
||||
"""Write scratchpad content into the current DB (upsert)."""
|
||||
from ..db.models.scratchpad import Scratchpad
|
||||
with db_factory() as db:
|
||||
row = db.query(Scratchpad).filter(Scratchpad.company_id == company_id).first()
|
||||
if row is None:
|
||||
db.add(Scratchpad(company_id=company_id, content=content))
|
||||
else:
|
||||
row.content = content
|
||||
db.flush()
|
||||
|
||||
|
||||
def run_benchmark(args):
|
||||
"""Run a full benchmark: migrate, seed, loop until terminal."""
|
||||
logging.basicConfig(
|
||||
|
|
@ -184,31 +215,11 @@ def run_benchmark(args):
|
|||
_redirect_all_logging_to_file(log_file)
|
||||
|
||||
logger.info(
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years,
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy max_episodes=%d",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years, args.max_episodes,
|
||||
)
|
||||
|
||||
# 1. Build engine and create all tables
|
||||
# If DATABASE_URL is not explicitly set, default to db/<config>_<seed>_<slug>.db
|
||||
if not os.environ.get("DATABASE_URL"):
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_dir}/{args.config_name}_{args.seed}_{slug}.db"
|
||||
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory():
|
||||
with session_scope(factory) as session:
|
||||
yield session
|
||||
|
||||
# 3. Init simulation using experiment world config
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 4. Build runtime settings from experiment AgentConfig
|
||||
# Build runtime settings (shared across episodes)
|
||||
agent_cfg = experiment_cfg.agent
|
||||
settings = RuntimeSettings(
|
||||
model=agent_cfg.model,
|
||||
|
|
@ -222,7 +233,7 @@ def run_benchmark(args):
|
|||
)
|
||||
runtime = build_runtime(settings, command_executor=run_command)
|
||||
|
||||
# 5. Build run state
|
||||
# Build run state (persists across episodes)
|
||||
session_id = f"run-{args.seed}-{args.model}"
|
||||
run_state = RunState(
|
||||
session_id=session_id,
|
||||
|
|
@ -231,66 +242,110 @@ def run_benchmark(args):
|
|||
horizon_years=horizon_years,
|
||||
)
|
||||
|
||||
# 6. Set up live dashboard (or not)
|
||||
dashboard = None
|
||||
on_turn_start = None
|
||||
on_turn = None
|
||||
|
||||
if use_live:
|
||||
from .dashboard import BenchmarkDashboard
|
||||
|
||||
dashboard = BenchmarkDashboard(
|
||||
model=args.model,
|
||||
seed=args.seed,
|
||||
config_name=args.config_name,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
def on_turn_start(turn_num):
|
||||
dashboard.mark_turn_start(turn_num)
|
||||
|
||||
def on_turn(snapshot, rs, commands):
|
||||
dashboard.update(snapshot, rs, commands)
|
||||
|
||||
# 7. Run agent loop
|
||||
loop_cfg = experiment_cfg.loop
|
||||
try:
|
||||
max_episodes = args.max_episodes
|
||||
carried_scratchpad = ""
|
||||
|
||||
for episode in range(1, max_episodes + 1):
|
||||
logger.info("=== Episode %d / %d ===", episode, max_episodes)
|
||||
|
||||
# 1. Build engine for this episode's DB
|
||||
db_url = _build_db_url(args, episode, max_episodes)
|
||||
os.environ["DATABASE_URL"] = db_url
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory(_factory=factory):
|
||||
with session_scope(_factory) as session:
|
||||
yield session
|
||||
|
||||
# 2. Init simulation
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 3. Restore scratchpad from previous episode
|
||||
if episode > 1 and carried_scratchpad:
|
||||
_write_scratchpad(db_factory, company_id, carried_scratchpad)
|
||||
logger.info("Restored scratchpad from episode %d (%d chars).", episode - 1, len(carried_scratchpad))
|
||||
|
||||
# 4. Set up live dashboard
|
||||
dashboard = None
|
||||
on_turn_start = None
|
||||
on_turn = None
|
||||
|
||||
if use_live:
|
||||
from .dashboard import BenchmarkDashboard
|
||||
|
||||
dashboard = BenchmarkDashboard(
|
||||
model=args.model,
|
||||
seed=args.seed,
|
||||
config_name=args.config_name,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
def on_turn_start(turn_num):
|
||||
dashboard.mark_turn_start(turn_num)
|
||||
|
||||
def on_turn(snapshot, rs, commands):
|
||||
dashboard.update(snapshot, rs, commands)
|
||||
|
||||
# 5. Run agent loop for this episode
|
||||
try:
|
||||
if dashboard is not None:
|
||||
dashboard.start()
|
||||
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
on_turn_start=on_turn_start,
|
||||
on_turn=on_turn,
|
||||
episode=episode,
|
||||
)
|
||||
finally:
|
||||
if dashboard is not None:
|
||||
dashboard.stop()
|
||||
|
||||
if dashboard is not None:
|
||||
dashboard.start()
|
||||
dashboard.print_final_summary(final_state)
|
||||
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
on_turn_start=on_turn_start,
|
||||
on_turn=on_turn,
|
||||
)
|
||||
finally:
|
||||
if dashboard is not None:
|
||||
dashboard.stop()
|
||||
# 6. For multi-episode runs, snapshot this episode's data
|
||||
if max_episodes > 1:
|
||||
run_state.finish_episode()
|
||||
|
||||
# 8. Print final summary
|
||||
if dashboard is not None:
|
||||
dashboard.print_final_summary(final_state)
|
||||
logger.info("Episode %d finished: reason=%s", episode, run_state.terminal_reason)
|
||||
|
||||
# 9. Save full rollout (with transcript) and print summary
|
||||
# 7. If not bankruptcy, or last episode, stop
|
||||
if run_state.terminal_reason != TerminalReason.BANKRUPTCY or episode == max_episodes:
|
||||
break
|
||||
|
||||
# 8. Save scratchpad for next episode, then reset
|
||||
carried_scratchpad = _read_scratchpad(db_factory, company_id)
|
||||
logger.info("Carrying scratchpad to episode %d (%d chars).", episode + 1, len(carried_scratchpad))
|
||||
|
||||
# Clear runtime session (fresh conversation history)
|
||||
runtime.clear_session(session_id)
|
||||
run_state.reset_for_new_episode()
|
||||
engine.dispose()
|
||||
|
||||
# 9. Save full rollout and print summary
|
||||
rollout = final_state.full_rollout()
|
||||
rollout["time_series"] = extract_time_series(db_factory, company_id)
|
||||
summary = final_state.summary()
|
||||
logger.info("Run complete: %s", json.dumps(summary, indent=2))
|
||||
|
||||
# Write full rollout (includes transcript with commands)
|
||||
slug = args.model.replace("/", "_")
|
||||
results_dir = Path("results")
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
results_path = results_dir / f"yc_bench_result_{args.config_name}_{args.seed}_{slug}.json"
|
||||
results_path.write_text(json.dumps(rollout, indent=2))
|
||||
logger.info("Full rollout written to %s (%d turns)", results_path, len(rollout.get("transcript", [])))
|
||||
logger.info("Full rollout written to %s", results_path)
|
||||
|
||||
return 0 if final_state.terminal_reason != "error" else 1
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue