diff --git a/plots/multi_episode_haiku.png b/plots/multi_episode_haiku.png new file mode 100644 index 0000000..f806772 Binary files /dev/null and b/plots/multi_episode_haiku.png differ diff --git a/results/yc_bench_result_hard_1_openrouter_anthropic_claude-haiku-4-5.json b/results/yc_bench_result_hard_1_openrouter_anthropic_claude-haiku-4-5.json new file mode 100644 index 0000000..7a3ba3a --- /dev/null +++ b/results/yc_bench_result_hard_1_openrouter_anthropic_claude-haiku-4-5.json @@ -0,0 +1,43 @@ +{ + "session_id": "run-1-openrouter/anthropic/claude-haiku-4-5", + "model": "openrouter/anthropic/claude-haiku-4-5", + "seed": 1, + "horizon_years": 1, + "total_episodes": 3, + "terminal": true, + "terminal_reason": "bankruptcy", + "terminal_detail": "bankruptcy", + "started_at": "2026-03-12T01:47:36.650022+00:00", + "ended_at": "2026-03-12T01:57:25.802024+00:00", + "turns_completed": 193, + "total_cost_usd": 6.014085, + "episodes": [ + { + "episode": 1, + "turns_completed": 29, + "terminal_reason": "bankruptcy", + "terminal_detail": "bankruptcy", + "cost_usd": 0.737144, + "started_at": "2026-03-12T01:40:17.341900+00:00", + "ended_at": "2026-03-12T01:43:04.213804+00:00" + }, + { + "episode": 2, + "turns_completed": 43, + "terminal_reason": "bankruptcy", + "terminal_detail": "bankruptcy", + "cost_usd": 1.054148, + "started_at": "2026-03-12T01:43:04.602496+00:00", + "ended_at": "2026-03-12T01:47:36.518577+00:00" + }, + { + "episode": 3, + "turns_completed": 121, + "terminal_reason": "bankruptcy", + "terminal_detail": "bankruptcy", + "cost_usd": 4.222793, + "started_at": "2026-03-12T01:47:36.650022+00:00", + "ended_at": "2026-03-12T01:57:25.802024+00:00" + } + ] +} \ No newline at end of file diff --git a/scripts/plot_multi_episode.py b/scripts/plot_multi_episode.py new file mode 100644 index 0000000..daf1458 --- /dev/null +++ b/scripts/plot_multi_episode.py @@ -0,0 +1,229 @@ +"""Plot multi-episode benchmark: funds over time across episodes + scratchpad evolution.""" +import sqlite3 +from pathlib import Path +from datetime import datetime, timedelta + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.dates as mdates +import matplotlib.ticker as mticker +import matplotlib.gridspec as gridspec +from matplotlib.patches import FancyBboxPatch +import textwrap + +ROOT = Path(__file__).parent.parent +INITIAL_FUNDS_CENTS = 15_000_000 + +# ── Collinear brand palette ────────────────────────────────────────────────── +NAVY = "#13234D" +ORANGE = "#F26125" +BLUE = "#4D65FF" +BG_COLOR = "#FAFBFD" +GRID_CLR = "#E8ECF2" +TEXT_CLR = "#2A2F3D" +MUTED = "#6B7694" +CARD_BG = "#FFFFFF" + +EP_COLORS = [BLUE, ORANGE, "#22C55E"] +EP_LABELS = ["Episode 1", "Episode 2", "Episode 3"] + + +def load_episode(db_path): + """Load funds curve, task stats, and scratchpad from an episode DB.""" + con = sqlite3.connect(str(db_path)) + rows = con.execute( + "SELECT occurred_at, amount_cents, category FROM ledger_entries ORDER BY occurred_at ASC" + ).fetchall() + + tasks = con.execute( + "SELECT status, count(*) FROM tasks WHERE completed_at IS NOT NULL GROUP BY status" + ).fetchall() + task_stats = dict(tasks) + + scratchpad = con.execute("SELECT content FROM scratchpads LIMIT 1").fetchone() + scratchpad_text = scratchpad[0] if scratchpad else "" + + con.close() + + if not rows: + return None + + times, balances = [], [] + running = INITIAL_FUNDS_CENTS + start = datetime.fromisoformat(rows[0][0]).replace( + month=1, day=1, hour=9, minute=0, second=0, microsecond=0 + ) + times.append(start) + balances.append(running / 100) + for occurred_at, amount_cents, category in rows: + running += int(amount_cents) + t = datetime.fromisoformat(occurred_at) + times.append(t) + balances.append(running / 100) + + return { + "times": times, + "balances": balances, + "final_balance": balances[-1], + "task_success": task_stats.get("completed_success", 0), + "task_fail": task_stats.get("completed_fail", 0), + "scratchpad": scratchpad_text, + "duration_months": (times[-1] - times[0]).days / 30.0, + "bankrupt": balances[-1] <= 0, + } + + +def make_plot(episodes, model_label, seed, config): + fig = plt.figure(figsize=(20, 12), facecolor=BG_COLOR) + gs = gridspec.GridSpec(2, 3, figure=fig, height_ratios=[2.2, 1], + hspace=0.35, wspace=0.3, + left=0.07, right=0.97, top=0.82, bottom=0.06) + + # ── Header band ────────────────────────────────────────────────────── + header_rect = plt.Rectangle((0, 0.88), 1, 0.12, + transform=fig.transFigure, facecolor=NAVY, + edgecolor="none", zorder=0) + fig.patches.append(header_rect) + accent_rect = plt.Rectangle((0, 0.875), 1, 0.006, + transform=fig.transFigure, facecolor=ORANGE, + edgecolor="none", zorder=1) + fig.patches.append(accent_rect) + + fig.text(0.5, 0.94, + "YC-Bench | Multi-Episode Learning", + ha="center", va="center", + fontsize=32, fontweight="700", color="white", + fontfamily="Helvetica Neue", zorder=2) + fig.text(0.5, 0.895, + f"{model_label} | {config} config | seed {seed} | {len(episodes)} episodes", + ha="center", va="center", + fontsize=16, fontweight="400", color="#AABBDD", zorder=2) + + # ── Top row: funds over time (full width) ──────────────────────────── + ax_funds = fig.add_subplot(gs[0, :]) + ax_funds.set_facecolor(CARD_BG) + for spine in ax_funds.spines.values(): + spine.set_edgecolor(GRID_CLR) + spine.set_linewidth(1.2) + + for i, ep in enumerate(episodes): + color = EP_COLORS[i % len(EP_COLORS)] + survived = f"{ep['duration_months']:.0f}mo" + label = f"Ep {i+1}: {survived}, {ep['task_success']}W/{ep['task_fail']}L" + + ax_funds.plot(ep["times"], ep["balances"], + color=color, linewidth=2.8, alpha=0.9, + label=label, zorder=3 + i) + ax_funds.fill_between(ep["times"], 0, ep["balances"], + color=color, alpha=0.06, zorder=1) + + if ep["bankrupt"]: + ax_funds.scatter([ep["times"][-1]], [max(ep["balances"][-1], 500)], + color=color, marker="X", s=200, + linewidths=2, edgecolors="white", + alpha=0.9, zorder=5 + i) + + ax_funds.axhline(0, color="#DC2626", linewidth=1.2, linestyle="--", alpha=0.5, zorder=2, + label="Bankruptcy line") + ax_funds.set_ylabel("Company Funds ($)", fontsize=14, color=TEXT_CLR, fontweight="500") + ax_funds.yaxis.set_major_formatter( + mticker.FuncFormatter( + lambda x, _: f"${x/1e6:.1f}M" if x >= 1e6 + else f"${x/1e3:.0f}K" if x >= 1e3 + else f"${x:.0f}" + ) + ) + ax_funds.xaxis.set_major_formatter(mdates.DateFormatter("%b '%y")) + ax_funds.xaxis.set_major_locator(mdates.MonthLocator(interval=1)) + plt.setp(ax_funds.xaxis.get_majorticklabels(), rotation=30, ha="right") + ax_funds.tick_params(colors=MUTED, labelsize=12) + ax_funds.grid(axis="y", color=GRID_CLR, linewidth=0.7, alpha=0.8) + ax_funds.grid(axis="x", color=GRID_CLR, linewidth=0.4, alpha=0.4) + ax_funds.legend(fontsize=12, facecolor=CARD_BG, edgecolor=GRID_CLR, + labelcolor=TEXT_CLR, loc="upper right", + framealpha=0.95, borderpad=1) + ax_funds.set_title("Funds Over Time — Each Episode Starts Fresh", + fontsize=16, fontweight="600", color=TEXT_CLR, pad=12) + + # ── Bottom row: 3 scratchpad panels ────────────────────────────────── + for i, ep in enumerate(episodes): + ax = fig.add_subplot(gs[1, i]) + ax.set_facecolor("#F8F9FC") + for spine in ax.spines.values(): + spine.set_edgecolor(EP_COLORS[i % len(EP_COLORS)]) + spine.set_linewidth(2) + + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xticks([]) + ax.set_yticks([]) + + # Title + color = EP_COLORS[i % len(EP_COLORS)] + ax.set_title(f"Episode {i+1} Scratchpad", + fontsize=13, fontweight="600", color=color, pad=8) + + # Scratchpad content (truncated) + text = ep["scratchpad"].strip() + if not text: + text = "(empty)" + # Take first ~8 lines, wrap to ~55 chars + lines = text.split("\n")[:10] + wrapped = [] + for line in lines: + if len(line) > 60: + wrapped.extend(textwrap.wrap(line, 58)) + else: + wrapped.append(line) + display = "\n".join(wrapped[:12]) + if len(wrapped) > 12 or len(lines) < text.count("\n") + 1: + display += "\n..." + + ax.text(0.05, 0.92, display, + transform=ax.transAxes, + fontsize=7.5, fontfamily="monospace", + color=TEXT_CLR, verticalalignment="top", + linespacing=1.4) + + # Stats badge + stats = f"{ep['task_success']}W / {ep['task_fail']}L | {ep['duration_months']:.0f} months" + ax.text(0.5, 0.02, stats, + transform=ax.transAxes, ha="center", + fontsize=9, fontweight="600", color=MUTED) + + # ── Footer ─────────────────────────────────────────────────────────── + fig.text(0.5, 0.01, + "collinear.ai | Multi-Episode YC-Bench: Scratchpad carries over between bankruptcies", + ha="center", va="bottom", + fontsize=12, fontweight="400", color=MUTED, fontstyle="italic") + + out = ROOT / "plots" / "multi_episode_haiku.png" + out.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out, dpi=150, facecolor=BG_COLOR, pad_inches=0) + print(f"Saved: {out}") + + +if __name__ == "__main__": + db_dir = ROOT / "db" + slug = "openrouter_anthropic_claude-haiku-4-5" + config = "hard" + seed = 1 + + episodes = [] + for ep_num in [1, 2, 3]: + db_path = db_dir / f"{config}_{seed}_{slug}.ep{ep_num}.db" + if not db_path.exists(): + print(f"Skipping {db_path} (not found)") + continue + data = load_episode(db_path) + if data: + episodes.append(data) + print(f"Episode {ep_num}: {data['task_success']}W/{data['task_fail']}L, " + f"survived {data['duration_months']:.1f}mo, " + f"final ${data['final_balance']:,.0f}") + + if episodes: + make_plot(episodes, "Claude Haiku 4.5", seed, config) + else: + print("No episode data found.") diff --git a/src/yc_bench/agent/loop.py b/src/yc_bench/agent/loop.py index 7c192e9..fee8f98 100644 --- a/src/yc_bench/agent/loop.py +++ b/src/yc_bench/agent/loop.py @@ -118,6 +118,7 @@ def run_agent_loop( max_turns: int | None = None, on_turn_start=None, on_turn=None, + episode: int = 1, ) -> RunState: run_state.start() turns_since_resume = 0 # consecutive turns without sim resume @@ -136,7 +137,7 @@ def run_agent_loop( if run_state.turn_count == 0: with db_factory() as db: snapshot = _snapshot_state(db, company_id) - user_input = build_initial_user_prompt(**snapshot) + user_input = build_initial_user_prompt(**snapshot, episode=episode) elif run_state.next_user_input is not None: user_input = run_state.next_user_input run_state.next_user_input = None diff --git a/src/yc_bench/agent/prompt.py b/src/yc_bench/agent/prompt.py index c2ae796..b8da261 100644 --- a/src/yc_bench/agent/prompt.py +++ b/src/yc_bench/agent/prompt.py @@ -173,6 +173,7 @@ def build_initial_user_prompt( employee_count: int, monthly_payroll_cents: int, bankrupt: bool, + episode: int = 1, ) -> str: """Build the one-time initial user message at run start.""" runway_months = ( @@ -188,7 +189,18 @@ def build_initial_user_prompt( ) runway_str = f"~{runway_months} months" if runway_months is not None else "∞" - lines = [ + lines = [] + if episode > 1: + lines.extend([ + f"## Episode {episode} — Restarting After Bankruptcy", + "", + f"You went bankrupt in episode {episode - 1}. The simulation has been reset,", + "but your **scratchpad notes from the previous episode are preserved**.", + "Read your scratchpad (`yc-bench scratchpad read`) to review your notes", + "and learn from past mistakes before taking action.", + "", + ]) + lines.extend([ "## Simulation Start — Take Immediate Action", f"- current_time: {sim_time}", f"- horizon_end: {horizon_end}", @@ -210,7 +222,7 @@ def build_initial_user_prompt( "7. `yc-bench sim resume` — advance time to collect the first task completion event", "", "Do not spend multiple turns just browsing. Accept and dispatch tasks immediately.", - ] + ]) if bankrupt: lines.append("WARNING: company is already bankrupt at initialization.") return "\n".join(lines) diff --git a/src/yc_bench/agent/run_state.py b/src/yc_bench/agent/run_state.py index e4245e1..1ccea78 100644 --- a/src/yc_bench/agent/run_state.py +++ b/src/yc_bench/agent/run_state.py @@ -41,6 +41,10 @@ class RunState: next_user_input: Optional[str] = None total_cost_usd: float = 0.0 + # Multi-episode tracking + current_episode: int = 1 + episode_results: List[Dict[str, Any]] = field(default_factory=list) + def start(self) -> None: self.started_at = datetime.now(timezone.utc).isoformat() @@ -66,18 +70,14 @@ class RunState: return True return False - def full_rollout(self) -> Dict[str, Any]: - """Full results including transcript for saving to disk.""" - return { - "session_id": self.session_id, - "model": self.model, - "seed": self.seed, - "horizon_years": self.horizon_years, + def finish_episode(self) -> Dict[str, Any]: + """Snapshot current episode state into episode_results and return it.""" + episode_data = { + "episode": self.current_episode, "turns_completed": self.turn_count, - "terminal": self.terminal, "terminal_reason": self.terminal_reason.value if self.terminal_reason else None, "terminal_detail": self.terminal_detail, - "total_cost_usd": round(self.total_cost_usd, 6), + "cost_usd": round(self.total_cost_usd, 6), "started_at": self.started_at, "ended_at": self.ended_at, "transcript": [ @@ -91,11 +91,66 @@ class RunState: for t in self.transcript ], } + self.episode_results.append(episode_data) + return episode_data + + def reset_for_new_episode(self) -> None: + """Reset mutable state for a new episode, preserving episode_results.""" + self.current_episode += 1 + self.turn_count = 0 + self.terminal = False + self.terminal_reason = None + self.terminal_detail = None + self.started_at = None + self.ended_at = None + self.transcript = [] + self.next_user_input = None + self.total_cost_usd = 0.0 + + def full_rollout(self) -> Dict[str, Any]: + """Full results including transcript for saving to disk.""" + base = { + "session_id": self.session_id, + "model": self.model, + "seed": self.seed, + "horizon_years": self.horizon_years, + "total_episodes": self.current_episode, + "terminal": self.terminal, + "terminal_reason": self.terminal_reason.value if self.terminal_reason else None, + "terminal_detail": self.terminal_detail, + "started_at": self.started_at, + "ended_at": self.ended_at, + } + if self.episode_results: + # Multi-episode: include all episode data + total_turns = sum(ep["turns_completed"] for ep in self.episode_results) + total_cost = sum(ep["cost_usd"] for ep in self.episode_results) + base["turns_completed"] = total_turns + base["total_cost_usd"] = round(total_cost, 6) + base["episodes"] = self.episode_results + else: + # Single-episode (backward compat): flat transcript + base["turns_completed"] = self.turn_count + base["total_cost_usd"] = round(self.total_cost_usd, 6) + base["transcript"] = [ + { + "turn": t.turn, + "timestamp": t.timestamp, + "user_input": t.user_input, + "agent_output": t.agent_output, + "commands_executed": t.commands_executed, + } + for t in self.transcript + ] + return base def summary(self) -> Dict[str, Any]: """Summary without transcript for logging.""" rollout = self.full_rollout() rollout.pop("transcript", None) + if "episodes" in rollout: + for ep in rollout["episodes"]: + ep.pop("transcript", None) return rollout diff --git a/src/yc_bench/cli/__init__.py b/src/yc_bench/cli/__init__.py index 4af50ec..f33e1b0 100644 --- a/src/yc_bench/cli/__init__.py +++ b/src/yc_bench/cli/__init__.py @@ -99,6 +99,10 @@ def run_command_cli( help="Preset name ('default', 'fast_test', 'high_reward') or path to a .toml file", ), no_live: bool = typer.Option(False, "--no-live", help="Disable the live terminal dashboard"), + max_episodes: int = typer.Option( + 1, "--max-episodes", + help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1", + ), ): """Run a full benchmark: migrate DB, seed world, run agent loop to completion.""" from dotenv import find_dotenv, load_dotenv @@ -114,6 +118,7 @@ def run_command_cli( start_date=start_date, config_name=config_name, no_live=no_live, + max_episodes=max_episodes, ) raise SystemExit(run_benchmark(args)) diff --git a/src/yc_bench/runner/args.py b/src/yc_bench/runner/args.py index d8bd705..83c5719 100644 --- a/src/yc_bench/runner/args.py +++ b/src/yc_bench/runner/args.py @@ -13,6 +13,7 @@ class RunArgs: start_date: str config_name: str = "default" no_live: bool = False + max_episodes: int = 1 def build_parser(): parser = argparse.ArgumentParser( @@ -32,6 +33,10 @@ def build_parser(): "--no-live", action="store_true", default=False, help="Disable the live terminal dashboard (show raw log output instead)", ) + parser.add_argument( + "--max-episodes", type=int, default=1, + help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1", + ) return parser def parse_run_args(argv): @@ -46,10 +51,13 @@ def parse_run_args(argv): start_date=ns.start_date, config_name=ns.config_name, no_live=ns.no_live, + max_episodes=ns.max_episodes, ) def _validate(ns, parser): if ns.horizon_years is not None and ns.horizon_years <= 0: parser.error("--horizon-years must be int > 0") + if ns.max_episodes < 1: + parser.error("--max-episodes must be int >= 1") __all__ = ["RunArgs", "build_parser", "parse_run_args"] diff --git a/src/yc_bench/runner/main.py b/src/yc_bench/runner/main.py index 8efecc9..e4f659a 100644 --- a/src/yc_bench/runner/main.py +++ b/src/yc_bench/runner/main.py @@ -11,7 +11,7 @@ from pathlib import Path from ..agent.commands.executor import run_command from ..agent.loop import run_agent_loop -from ..agent.run_state import RunState +from ..agent.run_state import RunState, TerminalReason from ..agent.runtime.factory import build_runtime from ..agent.runtime.schemas import RuntimeSettings from ..db.session import build_engine, build_session_factory, session_scope, init_db @@ -154,6 +154,37 @@ def _redirect_all_logging_to_file(log_file: Path) -> None: lg.addHandler(file_handler) +def _build_db_url(args, episode: int, max_episodes: int) -> str: + """Build SQLite DATABASE_URL, adding episode suffix when multi-episode.""" + slug = args.model.replace("/", "_") + db_dir = Path("db") + db_dir.mkdir(exist_ok=True) + base = f"{args.config_name}_{args.seed}_{slug}" + if max_episodes > 1: + return f"sqlite:///{db_dir}/{base}.ep{episode}.db" + return f"sqlite:///{db_dir}/{base}.db" + + +def _read_scratchpad(db_factory, company_id) -> str: + """Read scratchpad content from the current DB.""" + from ..db.models.scratchpad import Scratchpad + with db_factory() as db: + row = db.query(Scratchpad).filter(Scratchpad.company_id == company_id).first() + return row.content if row else "" + + +def _write_scratchpad(db_factory, company_id, content: str) -> None: + """Write scratchpad content into the current DB (upsert).""" + from ..db.models.scratchpad import Scratchpad + with db_factory() as db: + row = db.query(Scratchpad).filter(Scratchpad.company_id == company_id).first() + if row is None: + db.add(Scratchpad(company_id=company_id, content=content)) + else: + row.content = content + db.flush() + + def run_benchmark(args): """Run a full benchmark: migrate, seed, loop until terminal.""" logging.basicConfig( @@ -184,31 +215,11 @@ def run_benchmark(args): _redirect_all_logging_to_file(log_file) logger.info( - "YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy", - experiment_cfg.name, args.model, args.seed, horizon_years, + "YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy max_episodes=%d", + experiment_cfg.name, args.model, args.seed, horizon_years, args.max_episodes, ) - # 1. Build engine and create all tables - # If DATABASE_URL is not explicitly set, default to db/__.db - if not os.environ.get("DATABASE_URL"): - slug = args.model.replace("/", "_") - db_dir = Path("db") - db_dir.mkdir(exist_ok=True) - os.environ["DATABASE_URL"] = f"sqlite:///{db_dir}/{args.config_name}_{args.seed}_{slug}.db" - - engine = build_engine() - init_db(engine) - factory = build_session_factory(engine) - - @contextmanager - def db_factory(): - with session_scope(factory) as session: - yield session - - # 3. Init simulation using experiment world config - company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years) - - # 4. Build runtime settings from experiment AgentConfig + # Build runtime settings (shared across episodes) agent_cfg = experiment_cfg.agent settings = RuntimeSettings( model=agent_cfg.model, @@ -222,7 +233,7 @@ def run_benchmark(args): ) runtime = build_runtime(settings, command_executor=run_command) - # 5. Build run state + # Build run state (persists across episodes) session_id = f"run-{args.seed}-{args.model}" run_state = RunState( session_id=session_id, @@ -231,66 +242,110 @@ def run_benchmark(args): horizon_years=horizon_years, ) - # 6. Set up live dashboard (or not) - dashboard = None - on_turn_start = None - on_turn = None - - if use_live: - from .dashboard import BenchmarkDashboard - - dashboard = BenchmarkDashboard( - model=args.model, - seed=args.seed, - config_name=args.config_name, - db_factory=db_factory, - company_id=company_id, - ) - - def on_turn_start(turn_num): - dashboard.mark_turn_start(turn_num) - - def on_turn(snapshot, rs, commands): - dashboard.update(snapshot, rs, commands) - - # 7. Run agent loop loop_cfg = experiment_cfg.loop - try: + max_episodes = args.max_episodes + carried_scratchpad = "" + + for episode in range(1, max_episodes + 1): + logger.info("=== Episode %d / %d ===", episode, max_episodes) + + # 1. Build engine for this episode's DB + db_url = _build_db_url(args, episode, max_episodes) + os.environ["DATABASE_URL"] = db_url + engine = build_engine() + init_db(engine) + factory = build_session_factory(engine) + + @contextmanager + def db_factory(_factory=factory): + with session_scope(_factory) as session: + yield session + + # 2. Init simulation + company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years) + + # 3. Restore scratchpad from previous episode + if episode > 1 and carried_scratchpad: + _write_scratchpad(db_factory, company_id, carried_scratchpad) + logger.info("Restored scratchpad from episode %d (%d chars).", episode - 1, len(carried_scratchpad)) + + # 4. Set up live dashboard + dashboard = None + on_turn_start = None + on_turn = None + + if use_live: + from .dashboard import BenchmarkDashboard + + dashboard = BenchmarkDashboard( + model=args.model, + seed=args.seed, + config_name=args.config_name, + db_factory=db_factory, + company_id=company_id, + ) + + def on_turn_start(turn_num): + dashboard.mark_turn_start(turn_num) + + def on_turn(snapshot, rs, commands): + dashboard.update(snapshot, rs, commands) + + # 5. Run agent loop for this episode + try: + if dashboard is not None: + dashboard.start() + + final_state = run_agent_loop( + runtime=runtime, + db_factory=db_factory, + company_id=company_id, + run_state=run_state, + command_executor=run_command, + auto_advance_after_turns=loop_cfg.auto_advance_after_turns, + max_turns=loop_cfg.max_turns, + on_turn_start=on_turn_start, + on_turn=on_turn, + episode=episode, + ) + finally: + if dashboard is not None: + dashboard.stop() + if dashboard is not None: - dashboard.start() + dashboard.print_final_summary(final_state) - final_state = run_agent_loop( - runtime=runtime, - db_factory=db_factory, - company_id=company_id, - run_state=run_state, - command_executor=run_command, - auto_advance_after_turns=loop_cfg.auto_advance_after_turns, - max_turns=loop_cfg.max_turns, - on_turn_start=on_turn_start, - on_turn=on_turn, - ) - finally: - if dashboard is not None: - dashboard.stop() + # 6. For multi-episode runs, snapshot this episode's data + if max_episodes > 1: + run_state.finish_episode() - # 8. Print final summary - if dashboard is not None: - dashboard.print_final_summary(final_state) + logger.info("Episode %d finished: reason=%s", episode, run_state.terminal_reason) - # 9. Save full rollout (with transcript) and print summary + # 7. If not bankruptcy, or last episode, stop + if run_state.terminal_reason != TerminalReason.BANKRUPTCY or episode == max_episodes: + break + + # 8. Save scratchpad for next episode, then reset + carried_scratchpad = _read_scratchpad(db_factory, company_id) + logger.info("Carrying scratchpad to episode %d (%d chars).", episode + 1, len(carried_scratchpad)) + + # Clear runtime session (fresh conversation history) + runtime.clear_session(session_id) + run_state.reset_for_new_episode() + engine.dispose() + + # 9. Save full rollout and print summary rollout = final_state.full_rollout() rollout["time_series"] = extract_time_series(db_factory, company_id) summary = final_state.summary() logger.info("Run complete: %s", json.dumps(summary, indent=2)) - # Write full rollout (includes transcript with commands) slug = args.model.replace("/", "_") results_dir = Path("results") results_dir.mkdir(exist_ok=True) results_path = results_dir / f"yc_bench_result_{args.config_name}_{args.seed}_{slug}.json" results_path.write_text(json.dumps(rollout, indent=2)) - logger.info("Full rollout written to %s (%d turns)", results_path, len(rollout.get("transcript", []))) + logger.info("Full rollout written to %s", results_path) return 0 if final_state.terminal_reason != "error" else 1