mirror of
https://github.com/collinear-ai/yc-bench.git
synced 2026-04-19 12:58:03 +00:00
Add multi-episode setting with scratchpad carryover between bankruptcies
When an agent goes bankrupt, the simulation can now restart for another episode while preserving the scratchpad from the previous attempt. This lets us measure whether LLMs can learn from failure via persistent notes. Each episode gets its own SQLite DB (*.ep1.db, *.ep2.db, ...) so plotting scripts and post-hoc analysis work unchanged. The rollout JSON aggregates per-episode transcripts, turns, and costs. Key changes: - --max-episodes CLI flag (default 1, fully backward compatible) - Per-episode DB files when max_episodes > 1 - Scratchpad read from old DB, written into fresh DB between episodes - RunState tracks episode results with finish_episode/reset_for_new_episode - Agent prompt tells it about the episode number and to read its scratchpad - Plotting script for multi-episode fund curves + scratchpad evolution Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ecd3d9e415
commit
01535c2042
9 changed files with 490 additions and 82 deletions
BIN
plots/multi_episode_haiku.png
Normal file
BIN
plots/multi_episode_haiku.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 326 KiB |
|
|
@ -0,0 +1,43 @@
|
|||
{
|
||||
"session_id": "run-1-openrouter/anthropic/claude-haiku-4-5",
|
||||
"model": "openrouter/anthropic/claude-haiku-4-5",
|
||||
"seed": 1,
|
||||
"horizon_years": 1,
|
||||
"total_episodes": 3,
|
||||
"terminal": true,
|
||||
"terminal_reason": "bankruptcy",
|
||||
"terminal_detail": "bankruptcy",
|
||||
"started_at": "2026-03-12T01:47:36.650022+00:00",
|
||||
"ended_at": "2026-03-12T01:57:25.802024+00:00",
|
||||
"turns_completed": 193,
|
||||
"total_cost_usd": 6.014085,
|
||||
"episodes": [
|
||||
{
|
||||
"episode": 1,
|
||||
"turns_completed": 29,
|
||||
"terminal_reason": "bankruptcy",
|
||||
"terminal_detail": "bankruptcy",
|
||||
"cost_usd": 0.737144,
|
||||
"started_at": "2026-03-12T01:40:17.341900+00:00",
|
||||
"ended_at": "2026-03-12T01:43:04.213804+00:00"
|
||||
},
|
||||
{
|
||||
"episode": 2,
|
||||
"turns_completed": 43,
|
||||
"terminal_reason": "bankruptcy",
|
||||
"terminal_detail": "bankruptcy",
|
||||
"cost_usd": 1.054148,
|
||||
"started_at": "2026-03-12T01:43:04.602496+00:00",
|
||||
"ended_at": "2026-03-12T01:47:36.518577+00:00"
|
||||
},
|
||||
{
|
||||
"episode": 3,
|
||||
"turns_completed": 121,
|
||||
"terminal_reason": "bankruptcy",
|
||||
"terminal_detail": "bankruptcy",
|
||||
"cost_usd": 4.222793,
|
||||
"started_at": "2026-03-12T01:47:36.650022+00:00",
|
||||
"ended_at": "2026-03-12T01:57:25.802024+00:00"
|
||||
}
|
||||
]
|
||||
}
|
||||
229
scripts/plot_multi_episode.py
Normal file
229
scripts/plot_multi_episode.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""Plot multi-episode benchmark: funds over time across episodes + scratchpad evolution."""
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
import matplotlib.ticker as mticker
|
||||
import matplotlib.gridspec as gridspec
|
||||
from matplotlib.patches import FancyBboxPatch
|
||||
import textwrap
|
||||
|
||||
ROOT = Path(__file__).parent.parent
|
||||
INITIAL_FUNDS_CENTS = 15_000_000
|
||||
|
||||
# ── Collinear brand palette ──────────────────────────────────────────────────
|
||||
NAVY = "#13234D"
|
||||
ORANGE = "#F26125"
|
||||
BLUE = "#4D65FF"
|
||||
BG_COLOR = "#FAFBFD"
|
||||
GRID_CLR = "#E8ECF2"
|
||||
TEXT_CLR = "#2A2F3D"
|
||||
MUTED = "#6B7694"
|
||||
CARD_BG = "#FFFFFF"
|
||||
|
||||
EP_COLORS = [BLUE, ORANGE, "#22C55E"]
|
||||
EP_LABELS = ["Episode 1", "Episode 2", "Episode 3"]
|
||||
|
||||
|
||||
def load_episode(db_path, initial_funds_cents=None):
    """Load funds curve, task stats, and scratchpad from an episode DB.

    Parameters
    ----------
    db_path : str | Path
        Path to the episode's SQLite database file.
    initial_funds_cents : int | None
        Starting balance in cents. Defaults to the module-level
        ``INITIAL_FUNDS_CENTS`` when None (backward compatible).

    Returns
    -------
    dict | None
        Funds curve, task win/loss counts, scratchpad text, duration and
        bankruptcy flag — or None when the DB has no ledger entries.
    """
    if initial_funds_cents is None:
        initial_funds_cents = INITIAL_FUNDS_CENTS

    con = sqlite3.connect(str(db_path))
    try:
        rows = con.execute(
            "SELECT occurred_at, amount_cents, category FROM ledger_entries ORDER BY occurred_at ASC"
        ).fetchall()

        tasks = con.execute(
            "SELECT status, count(*) FROM tasks WHERE completed_at IS NOT NULL GROUP BY status"
        ).fetchall()
        task_stats = dict(tasks)

        scratchpad = con.execute("SELECT content FROM scratchpads LIMIT 1").fetchone()
        # Guard both "no row" and "row present but content is NULL".
        scratchpad_text = scratchpad[0] if scratchpad and scratchpad[0] is not None else ""
    finally:
        # Close even when a query raises (e.g. schema mismatch in an old DB);
        # the original leaked the connection on that path.
        con.close()

    if not rows:
        return None

    times, balances = [], []
    running = initial_funds_cents
    # Anchor the curve at a synthetic "Jan 1, 09:00" origin so every episode
    # plots from the same calendar start.
    start = datetime.fromisoformat(rows[0][0]).replace(
        month=1, day=1, hour=9, minute=0, second=0, microsecond=0
    )
    times.append(start)
    balances.append(running / 100)  # cents -> dollars
    for occurred_at, amount_cents, _category in rows:
        running += int(amount_cents)
        times.append(datetime.fromisoformat(occurred_at))
        balances.append(running / 100)

    return {
        "times": times,
        "balances": balances,
        "final_balance": balances[-1],
        "task_success": task_stats.get("completed_success", 0),
        "task_fail": task_stats.get("completed_fail", 0),
        "scratchpad": scratchpad_text,
        # Rough month count; 30-day months are fine for plot labels.
        "duration_months": (times[-1] - times[0]).days / 30.0,
        "bankrupt": balances[-1] <= 0,
    }
|
||||
|
||||
|
||||
def make_plot(episodes, model_label, seed, config):
    """Render the multi-episode dashboard and save it as a PNG.

    Layout: a 2x3 GridSpec — top row (full width) is the funds-over-time
    curve for every episode; bottom row is one scratchpad panel per episode.
    Writes to plots/multi_episode_haiku.png under ROOT.

    Args:
        episodes: list of dicts as returned by load_episode (times, balances,
            task_success/task_fail, scratchpad, duration_months, bankrupt).
        model_label: human-readable model name for the header.
        seed: run seed, shown in the subtitle.
        config: config name, shown in the subtitle.
    """
    fig = plt.figure(figsize=(20, 12), facecolor=BG_COLOR)
    # top=0.82 leaves room for the navy header band drawn below.
    gs = gridspec.GridSpec(2, 3, figure=fig, height_ratios=[2.2, 1],
                           hspace=0.35, wspace=0.3,
                           left=0.07, right=0.97, top=0.82, bottom=0.06)

    # ── Header band ──────────────────────────────────────────────────────
    # Navy banner across the top 12% of the figure, with a thin orange
    # accent strip just beneath it.
    header_rect = plt.Rectangle((0, 0.88), 1, 0.12,
                                transform=fig.transFigure, facecolor=NAVY,
                                edgecolor="none", zorder=0)
    fig.patches.append(header_rect)
    accent_rect = plt.Rectangle((0, 0.875), 1, 0.006,
                                transform=fig.transFigure, facecolor=ORANGE,
                                edgecolor="none", zorder=1)
    fig.patches.append(accent_rect)

    # Title and subtitle text, drawn above the banner (zorder=2).
    fig.text(0.5, 0.94,
             "YC-Bench | Multi-Episode Learning",
             ha="center", va="center",
             fontsize=32, fontweight="700", color="white",
             fontfamily="Helvetica Neue", zorder=2)
    fig.text(0.5, 0.895,
             f"{model_label} | {config} config | seed {seed} | {len(episodes)} episodes",
             ha="center", va="center",
             fontsize=16, fontweight="400", color="#AABBDD", zorder=2)

    # ── Top row: funds over time (full width) ────────────────────────────
    ax_funds = fig.add_subplot(gs[0, :])
    ax_funds.set_facecolor(CARD_BG)
    for spine in ax_funds.spines.values():
        spine.set_edgecolor(GRID_CLR)
        spine.set_linewidth(1.2)

    for i, ep in enumerate(episodes):
        # Cycle colors if there are more episodes than palette entries.
        color = EP_COLORS[i % len(EP_COLORS)]
        survived = f"{ep['duration_months']:.0f}mo"
        label = f"Ep {i+1}: {survived}, {ep['task_success']}W/{ep['task_fail']}L"

        # zorder=3+i stacks later episodes on top of earlier ones.
        ax_funds.plot(ep["times"], ep["balances"],
                      color=color, linewidth=2.8, alpha=0.9,
                      label=label, zorder=3 + i)
        ax_funds.fill_between(ep["times"], 0, ep["balances"],
                              color=color, alpha=0.06, zorder=1)

        if ep["bankrupt"]:
            # Mark the bankruptcy point; clamp y to 500 so the X marker
            # stays visible even when the final balance is at/below zero.
            ax_funds.scatter([ep["times"][-1]], [max(ep["balances"][-1], 500)],
                             color=color, marker="X", s=200,
                             linewidths=2, edgecolors="white",
                             alpha=0.9, zorder=5 + i)

    ax_funds.axhline(0, color="#DC2626", linewidth=1.2, linestyle="--", alpha=0.5, zorder=2,
                     label="Bankruptcy line")
    ax_funds.set_ylabel("Company Funds ($)", fontsize=14, color=TEXT_CLR, fontweight="500")
    # Compact dollar ticks: $1.2M / $850K / $42 depending on magnitude.
    ax_funds.yaxis.set_major_formatter(
        mticker.FuncFormatter(
            lambda x, _: f"${x/1e6:.1f}M" if x >= 1e6
            else f"${x/1e3:.0f}K" if x >= 1e3
            else f"${x:.0f}"
        )
    )
    ax_funds.xaxis.set_major_formatter(mdates.DateFormatter("%b '%y"))
    ax_funds.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
    plt.setp(ax_funds.xaxis.get_majorticklabels(), rotation=30, ha="right")
    ax_funds.tick_params(colors=MUTED, labelsize=12)
    ax_funds.grid(axis="y", color=GRID_CLR, linewidth=0.7, alpha=0.8)
    ax_funds.grid(axis="x", color=GRID_CLR, linewidth=0.4, alpha=0.4)
    ax_funds.legend(fontsize=12, facecolor=CARD_BG, edgecolor=GRID_CLR,
                    labelcolor=TEXT_CLR, loc="upper right",
                    framealpha=0.95, borderpad=1)
    ax_funds.set_title("Funds Over Time — Each Episode Starts Fresh",
                       fontsize=16, fontweight="600", color=TEXT_CLR, pad=12)

    # ── Bottom row: 3 scratchpad panels ──────────────────────────────────
    for i, ep in enumerate(episodes):
        ax = fig.add_subplot(gs[1, i])
        ax.set_facecolor("#F8F9FC")
        # Border each card in its episode color.
        for spine in ax.spines.values():
            spine.set_edgecolor(EP_COLORS[i % len(EP_COLORS)])
            spine.set_linewidth(2)

        # Blank canvas in axes coordinates; text is placed manually below.
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])

        # Panel title in the episode color.
        color = EP_COLORS[i % len(EP_COLORS)]
        ax.set_title(f"Episode {i+1} Scratchpad",
                     fontsize=13, fontweight="600", color=color, pad=8)

        # Scratchpad content (truncated so it fits the card).
        text = ep["scratchpad"].strip()
        if not text:
            text = "(empty)"
        # Keep at most 10 source lines; hard-wrap any line longer than
        # 60 chars at width 58.
        lines = text.split("\n")[:10]
        wrapped = []
        for line in lines:
            if len(line) > 60:
                wrapped.extend(textwrap.wrap(line, 58))
            else:
                wrapped.append(line)
        display = "\n".join(wrapped[:12])
        # Append an ellipsis when either the wrap overflowed 12 display
        # lines or the 10-line cut dropped source lines.
        if len(wrapped) > 12 or len(lines) < text.count("\n") + 1:
            display += "\n..."

        ax.text(0.05, 0.92, display,
                transform=ax.transAxes,
                fontsize=7.5, fontfamily="monospace",
                color=TEXT_CLR, verticalalignment="top",
                linespacing=1.4)

        # Stats badge centered at the bottom of the card.
        stats = f"{ep['task_success']}W / {ep['task_fail']}L | {ep['duration_months']:.0f} months"
        ax.text(0.5, 0.02, stats,
                transform=ax.transAxes, ha="center",
                fontsize=9, fontweight="600", color=MUTED)

    # ── Footer ───────────────────────────────────────────────────────────
    fig.text(0.5, 0.01,
             "collinear.ai | Multi-Episode YC-Bench: Scratchpad carries over between bankruptcies",
             ha="center", va="bottom",
             fontsize=12, fontweight="400", color=MUTED, fontstyle="italic")

    # NOTE(review): output filename is hard-coded to the Haiku run;
    # pad_inches has no effect here unless bbox_inches="tight" is also
    # passed — confirm intent against matplotlib savefig docs.
    out = ROOT / "plots" / "multi_episode_haiku.png"
    out.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out, dpi=150, facecolor=BG_COLOR, pad_inches=0)
    print(f"Saved: {out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Identifiers for the published Haiku multi-episode run; DB files were
    # written by the benchmark as <config>_<seed>_<slug>.ep<N>.db.
    config = "hard"
    seed = 1
    slug = "openrouter_anthropic_claude-haiku-4-5"
    episode_dir = ROOT / "db"

    loaded = []
    for ep_num in (1, 2, 3):
        db_path = episode_dir / f"{config}_{seed}_{slug}.ep{ep_num}.db"
        if db_path.exists():
            data = load_episode(db_path)
            if data:
                loaded.append(data)
                print(f"Episode {ep_num}: {data['task_success']}W/{data['task_fail']}L, "
                      f"survived {data['duration_months']:.1f}mo, "
                      f"final ${data['final_balance']:,.0f}")
        else:
            print(f"Skipping {db_path} (not found)")

    if not loaded:
        print("No episode data found.")
    else:
        make_plot(loaded, "Claude Haiku 4.5", seed, config)
|
||||
|
|
@ -118,6 +118,7 @@ def run_agent_loop(
|
|||
max_turns: int | None = None,
|
||||
on_turn_start=None,
|
||||
on_turn=None,
|
||||
episode: int = 1,
|
||||
) -> RunState:
|
||||
run_state.start()
|
||||
turns_since_resume = 0 # consecutive turns without sim resume
|
||||
|
|
@ -136,7 +137,7 @@ def run_agent_loop(
|
|||
if run_state.turn_count == 0:
|
||||
with db_factory() as db:
|
||||
snapshot = _snapshot_state(db, company_id)
|
||||
user_input = build_initial_user_prompt(**snapshot)
|
||||
user_input = build_initial_user_prompt(**snapshot, episode=episode)
|
||||
elif run_state.next_user_input is not None:
|
||||
user_input = run_state.next_user_input
|
||||
run_state.next_user_input = None
|
||||
|
|
|
|||
|
|
@ -141,6 +141,7 @@ def build_initial_user_prompt(
|
|||
employee_count: int,
|
||||
monthly_payroll_cents: int,
|
||||
bankrupt: bool,
|
||||
episode: int = 1,
|
||||
) -> str:
|
||||
"""Build the one-time initial user message at run start."""
|
||||
runway_months = (
|
||||
|
|
@ -156,7 +157,18 @@ def build_initial_user_prompt(
|
|||
)
|
||||
runway_str = f"~{runway_months} months" if runway_months is not None else "∞"
|
||||
|
||||
lines = [
|
||||
lines = []
|
||||
if episode > 1:
|
||||
lines.extend([
|
||||
f"## Episode {episode} — Restarting After Bankruptcy",
|
||||
"",
|
||||
f"You went bankrupt in episode {episode - 1}. The simulation has been reset,",
|
||||
"but your **scratchpad notes from the previous episode are preserved**.",
|
||||
"Read your scratchpad (`yc-bench scratchpad read`) to review your notes",
|
||||
"and learn from past mistakes before taking action.",
|
||||
"",
|
||||
])
|
||||
lines.extend([
|
||||
"## Simulation Start — Take Immediate Action",
|
||||
f"- current_time: {sim_time}",
|
||||
f"- horizon_end: {horizon_end}",
|
||||
|
|
@ -177,7 +189,7 @@ def build_initial_user_prompt(
|
|||
"6. `yc-bench sim resume` — advance time to collect the first task completion event",
|
||||
"",
|
||||
"Do not spend multiple turns just browsing. Accept and dispatch tasks immediately.",
|
||||
]
|
||||
])
|
||||
if bankrupt:
|
||||
lines.append("WARNING: company is already bankrupt at initialization.")
|
||||
return "\n".join(lines)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ class RunState:
|
|||
next_user_input: Optional[str] = None
|
||||
total_cost_usd: float = 0.0
|
||||
|
||||
# Multi-episode tracking
|
||||
current_episode: int = 1
|
||||
episode_results: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def start(self) -> None:
|
||||
self.started_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
|
@ -66,18 +70,14 @@ class RunState:
|
|||
return True
|
||||
return False
|
||||
|
||||
def full_rollout(self) -> Dict[str, Any]:
|
||||
"""Full results including transcript for saving to disk."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
"seed": self.seed,
|
||||
"horizon_years": self.horizon_years,
|
||||
def finish_episode(self) -> Dict[str, Any]:
|
||||
"""Snapshot current episode state into episode_results and return it."""
|
||||
episode_data = {
|
||||
"episode": self.current_episode,
|
||||
"turns_completed": self.turn_count,
|
||||
"terminal": self.terminal,
|
||||
"terminal_reason": self.terminal_reason.value if self.terminal_reason else None,
|
||||
"terminal_detail": self.terminal_detail,
|
||||
"total_cost_usd": round(self.total_cost_usd, 6),
|
||||
"cost_usd": round(self.total_cost_usd, 6),
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
"transcript": [
|
||||
|
|
@ -91,11 +91,66 @@ class RunState:
|
|||
for t in self.transcript
|
||||
],
|
||||
}
|
||||
self.episode_results.append(episode_data)
|
||||
return episode_data
|
||||
|
||||
def reset_for_new_episode(self) -> None:
|
||||
"""Reset mutable state for a new episode, preserving episode_results."""
|
||||
self.current_episode += 1
|
||||
self.turn_count = 0
|
||||
self.terminal = False
|
||||
self.terminal_reason = None
|
||||
self.terminal_detail = None
|
||||
self.started_at = None
|
||||
self.ended_at = None
|
||||
self.transcript = []
|
||||
self.next_user_input = None
|
||||
self.total_cost_usd = 0.0
|
||||
|
||||
def full_rollout(self) -> Dict[str, Any]:
|
||||
"""Full results including transcript for saving to disk."""
|
||||
base = {
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
"seed": self.seed,
|
||||
"horizon_years": self.horizon_years,
|
||||
"total_episodes": self.current_episode,
|
||||
"terminal": self.terminal,
|
||||
"terminal_reason": self.terminal_reason.value if self.terminal_reason else None,
|
||||
"terminal_detail": self.terminal_detail,
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
}
|
||||
if self.episode_results:
|
||||
# Multi-episode: include all episode data
|
||||
total_turns = sum(ep["turns_completed"] for ep in self.episode_results)
|
||||
total_cost = sum(ep["cost_usd"] for ep in self.episode_results)
|
||||
base["turns_completed"] = total_turns
|
||||
base["total_cost_usd"] = round(total_cost, 6)
|
||||
base["episodes"] = self.episode_results
|
||||
else:
|
||||
# Single-episode (backward compat): flat transcript
|
||||
base["turns_completed"] = self.turn_count
|
||||
base["total_cost_usd"] = round(self.total_cost_usd, 6)
|
||||
base["transcript"] = [
|
||||
{
|
||||
"turn": t.turn,
|
||||
"timestamp": t.timestamp,
|
||||
"user_input": t.user_input,
|
||||
"agent_output": t.agent_output,
|
||||
"commands_executed": t.commands_executed,
|
||||
}
|
||||
for t in self.transcript
|
||||
]
|
||||
return base
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
"""Summary without transcript for logging."""
|
||||
rollout = self.full_rollout()
|
||||
rollout.pop("transcript", None)
|
||||
if "episodes" in rollout:
|
||||
for ep in rollout["episodes"]:
|
||||
ep.pop("transcript", None)
|
||||
return rollout
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -97,6 +97,10 @@ def run_command_cli(
|
|||
help="Preset name ('default', 'fast_test', 'high_reward') or path to a .toml file",
|
||||
),
|
||||
no_live: bool = typer.Option(False, "--no-live", help="Disable the live terminal dashboard"),
|
||||
max_episodes: int = typer.Option(
|
||||
1, "--max-episodes",
|
||||
help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1",
|
||||
),
|
||||
):
|
||||
"""Run a full benchmark: migrate DB, seed world, run agent loop to completion."""
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
|
@ -112,6 +116,7 @@ def run_command_cli(
|
|||
start_date=start_date,
|
||||
config_name=config_name,
|
||||
no_live=no_live,
|
||||
max_episodes=max_episodes,
|
||||
)
|
||||
raise SystemExit(run_benchmark(args))
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class RunArgs:
|
|||
start_date: str
|
||||
config_name: str = "default"
|
||||
no_live: bool = False
|
||||
max_episodes: int = 1
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -32,6 +33,10 @@ def build_parser():
|
|||
"--no-live", action="store_true", default=False,
|
||||
help="Disable the live terminal dashboard (show raw log output instead)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-episodes", type=int, default=1,
|
||||
help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1",
|
||||
)
|
||||
return parser
|
||||
|
||||
def parse_run_args(argv):
|
||||
|
|
@ -46,10 +51,13 @@ def parse_run_args(argv):
|
|||
start_date=ns.start_date,
|
||||
config_name=ns.config_name,
|
||||
no_live=ns.no_live,
|
||||
max_episodes=ns.max_episodes,
|
||||
)
|
||||
|
||||
def _validate(ns, parser):
|
||||
if ns.horizon_years is not None and ns.horizon_years <= 0:
|
||||
parser.error("--horizon-years must be int > 0")
|
||||
if ns.max_episodes < 1:
|
||||
parser.error("--max-episodes must be int >= 1")
|
||||
|
||||
__all__ = ["RunArgs", "build_parser", "parse_run_args"]
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from pathlib import Path
|
|||
|
||||
from ..agent.commands.executor import run_command
|
||||
from ..agent.loop import run_agent_loop
|
||||
from ..agent.run_state import RunState
|
||||
from ..agent.run_state import RunState, TerminalReason
|
||||
from ..agent.runtime.factory import build_runtime
|
||||
from ..agent.runtime.schemas import RuntimeSettings
|
||||
from ..db.session import build_engine, build_session_factory, session_scope, init_db
|
||||
|
|
@ -150,6 +150,37 @@ def _redirect_all_logging_to_file(log_file: Path) -> None:
|
|||
lg.addHandler(file_handler)
|
||||
|
||||
|
||||
def _build_db_url(args, episode: int, max_episodes: int) -> str:
|
||||
"""Build SQLite DATABASE_URL, adding episode suffix when multi-episode."""
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
base = f"{args.config_name}_{args.seed}_{slug}"
|
||||
if max_episodes > 1:
|
||||
return f"sqlite:///{db_dir}/{base}.ep{episode}.db"
|
||||
return f"sqlite:///{db_dir}/{base}.db"
|
||||
|
||||
|
||||
def _read_scratchpad(db_factory, company_id) -> str:
    """Read scratchpad content from the current DB."""
    from ..db.models.scratchpad import Scratchpad

    with db_factory() as db:
        record = (
            db.query(Scratchpad)
            .filter(Scratchpad.company_id == company_id)
            .first()
        )
        if record is None:
            # No scratchpad row yet for this company.
            return ""
        return record.content
|
||||
|
||||
|
||||
def _write_scratchpad(db_factory, company_id, content: str) -> None:
    """Write scratchpad content into the current DB (upsert)."""
    from ..db.models.scratchpad import Scratchpad

    with db_factory() as db:
        existing = (
            db.query(Scratchpad)
            .filter(Scratchpad.company_id == company_id)
            .first()
        )
        if existing is not None:
            existing.content = content
        else:
            db.add(Scratchpad(company_id=company_id, content=content))
        # Flush inside the session scope so the row is visible immediately.
        db.flush()
|
||||
|
||||
|
||||
def run_benchmark(args):
|
||||
"""Run a full benchmark: migrate, seed, loop until terminal."""
|
||||
logging.basicConfig(
|
||||
|
|
@ -180,31 +211,11 @@ def run_benchmark(args):
|
|||
_redirect_all_logging_to_file(log_file)
|
||||
|
||||
logger.info(
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years,
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy max_episodes=%d",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years, args.max_episodes,
|
||||
)
|
||||
|
||||
# 1. Build engine and create all tables
|
||||
# If DATABASE_URL is not explicitly set, default to db/<config>_<seed>_<slug>.db
|
||||
if not os.environ.get("DATABASE_URL"):
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_dir}/{args.config_name}_{args.seed}_{slug}.db"
|
||||
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory():
|
||||
with session_scope(factory) as session:
|
||||
yield session
|
||||
|
||||
# 3. Init simulation using experiment world config
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 4. Build runtime settings from experiment AgentConfig
|
||||
# Build runtime settings (shared across episodes)
|
||||
agent_cfg = experiment_cfg.agent
|
||||
settings = RuntimeSettings(
|
||||
model=agent_cfg.model,
|
||||
|
|
@ -218,7 +229,7 @@ def run_benchmark(args):
|
|||
)
|
||||
runtime = build_runtime(settings, command_executor=run_command)
|
||||
|
||||
# 5. Build run state
|
||||
# Build run state (persists across episodes)
|
||||
session_id = f"run-{args.seed}-{args.model}"
|
||||
run_state = RunState(
|
||||
session_id=session_id,
|
||||
|
|
@ -227,65 +238,109 @@ def run_benchmark(args):
|
|||
horizon_years=horizon_years,
|
||||
)
|
||||
|
||||
# 6. Set up live dashboard (or not)
|
||||
dashboard = None
|
||||
on_turn_start = None
|
||||
on_turn = None
|
||||
|
||||
if use_live:
|
||||
from .dashboard import BenchmarkDashboard
|
||||
|
||||
dashboard = BenchmarkDashboard(
|
||||
model=args.model,
|
||||
seed=args.seed,
|
||||
config_name=args.config_name,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
def on_turn_start(turn_num):
|
||||
dashboard.mark_turn_start(turn_num)
|
||||
|
||||
def on_turn(snapshot, rs, commands):
|
||||
dashboard.update(snapshot, rs, commands)
|
||||
|
||||
# 7. Run agent loop
|
||||
loop_cfg = experiment_cfg.loop
|
||||
try:
|
||||
max_episodes = args.max_episodes
|
||||
carried_scratchpad = ""
|
||||
|
||||
for episode in range(1, max_episodes + 1):
|
||||
logger.info("=== Episode %d / %d ===", episode, max_episodes)
|
||||
|
||||
# 1. Build engine for this episode's DB
|
||||
db_url = _build_db_url(args, episode, max_episodes)
|
||||
os.environ["DATABASE_URL"] = db_url
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory(_factory=factory):
|
||||
with session_scope(_factory) as session:
|
||||
yield session
|
||||
|
||||
# 2. Init simulation
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 3. Restore scratchpad from previous episode
|
||||
if episode > 1 and carried_scratchpad:
|
||||
_write_scratchpad(db_factory, company_id, carried_scratchpad)
|
||||
logger.info("Restored scratchpad from episode %d (%d chars).", episode - 1, len(carried_scratchpad))
|
||||
|
||||
# 4. Set up live dashboard
|
||||
dashboard = None
|
||||
on_turn_start = None
|
||||
on_turn = None
|
||||
|
||||
if use_live:
|
||||
from .dashboard import BenchmarkDashboard
|
||||
|
||||
dashboard = BenchmarkDashboard(
|
||||
model=args.model,
|
||||
seed=args.seed,
|
||||
config_name=args.config_name,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
def on_turn_start(turn_num):
|
||||
dashboard.mark_turn_start(turn_num)
|
||||
|
||||
def on_turn(snapshot, rs, commands):
|
||||
dashboard.update(snapshot, rs, commands)
|
||||
|
||||
# 5. Run agent loop for this episode
|
||||
try:
|
||||
if dashboard is not None:
|
||||
dashboard.start()
|
||||
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
on_turn_start=on_turn_start,
|
||||
on_turn=on_turn,
|
||||
episode=episode,
|
||||
)
|
||||
finally:
|
||||
if dashboard is not None:
|
||||
dashboard.stop()
|
||||
|
||||
if dashboard is not None:
|
||||
dashboard.start()
|
||||
dashboard.print_final_summary(final_state)
|
||||
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
on_turn_start=on_turn_start,
|
||||
on_turn=on_turn,
|
||||
)
|
||||
finally:
|
||||
if dashboard is not None:
|
||||
dashboard.stop()
|
||||
# 6. For multi-episode runs, snapshot this episode's data
|
||||
if max_episodes > 1:
|
||||
run_state.finish_episode()
|
||||
|
||||
# 8. Print final summary
|
||||
if dashboard is not None:
|
||||
dashboard.print_final_summary(final_state)
|
||||
logger.info("Episode %d finished: reason=%s", episode, run_state.terminal_reason)
|
||||
|
||||
# 9. Save full rollout (with transcript) and print summary
|
||||
# 7. If not bankruptcy, or last episode, stop
|
||||
if run_state.terminal_reason != TerminalReason.BANKRUPTCY or episode == max_episodes:
|
||||
break
|
||||
|
||||
# 8. Save scratchpad for next episode, then reset
|
||||
carried_scratchpad = _read_scratchpad(db_factory, company_id)
|
||||
logger.info("Carrying scratchpad to episode %d (%d chars).", episode + 1, len(carried_scratchpad))
|
||||
|
||||
# Clear runtime session (fresh conversation history)
|
||||
runtime.clear_session(session_id)
|
||||
run_state.reset_for_new_episode()
|
||||
engine.dispose()
|
||||
|
||||
# 9. Save full rollout and print summary
|
||||
rollout = final_state.full_rollout()
|
||||
summary = final_state.summary()
|
||||
logger.info("Run complete: %s", json.dumps(summary, indent=2))
|
||||
|
||||
# Write full rollout (includes transcript with commands)
|
||||
slug = args.model.replace("/", "_")
|
||||
results_dir = Path("results")
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
results_path = results_dir / f"yc_bench_result_{args.config_name}_{args.seed}_{slug}.json"
|
||||
results_path.write_text(json.dumps(rollout, indent=2))
|
||||
logger.info("Full rollout written to %s (%d turns)", results_path, len(rollout.get("transcript", [])))
|
||||
logger.info("Full rollout written to %s", results_path)
|
||||
|
||||
return 0 if final_state.terminal_reason != "error" else 1
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue