"""Multi-model comparison plot: funds curves + cost vs budget. Run from the repo root: uv run python scripts/plot_multi_model.py [--seed 1] [--config hard] [--budget 30] """ import argparse import json import sqlite3 from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.dates as mdates from datetime import datetime ROOT = Path(__file__).parent.parent # --------------------------------------------------------------------------- # Default runs — edit or pass via CLI # --------------------------------------------------------------------------- DEFAULT_RUNS = [ { "label": "gemini-flash", "model_slug": "openrouter_google_gemini-3-flash-preview", "color": "#4fc3f7", }, { "label": "minimax-m2.5", "model_slug": "openrouter_minimax_minimax-m2.5", "color": "#f39c12", }, { "label": "kimi-k2.5", "model_slug": "openrouter_moonshotai_kimi-k2.5", "color": "#2ecc71", }, ] INITIAL_FUNDS_CENTS = 15_000_000 # $150K (default; presets may override) def parse_args(): p = argparse.ArgumentParser() p.add_argument("--seed", type=int, default=1) p.add_argument("--config", default="hard") p.add_argument("--budget", type=float, default=30.0) p.add_argument( "--out", default=None, help="Output PNG path (default: plots/funds_curves.png)" ) return p.parse_args() def load_funds_curve(db_path: Path): """Reconstruct running balance from ledger entries.""" con = sqlite3.connect(str(db_path)) rows = con.execute( "SELECT occurred_at, amount_cents FROM ledger_entries ORDER BY occurred_at ASC" ).fetchall() con.close() if not rows: return [], [] times, balances = [], [] running = INITIAL_FUNDS_CENTS # Prepend the sim start (day before first ledger event, pegged to Jan 1) start = datetime.fromisoformat(rows[0][0]).replace( month=1, day=1, hour=9, minute=0, second=0, microsecond=0 ) times.append(start) balances.append(INITIAL_FUNDS_CENTS / 100) for occurred_at, amount_cents in rows: running += int(amount_cents) times.append(datetime.fromisoformat(occurred_at)) balances.append(running / 100) return times, balances def load_meta(result_path: Path): with open(result_path) as f: d = json.load(f) return { "turns": d.get("turns_completed", 0), "terminal_reason": d.get("terminal_reason", "unknown"), "cost_usd": d.get("total_cost_usd", 0.0), "horizon_years": d.get("horizon_years", 3), } def load_run_data(runs, seed): run_data = [] for run in runs: slug = run["model_slug"] db_path = ROOT / "db" / f"{seed}_{slug}.db" result_path = ROOT / "results" / f"yc_bench_result_{seed}_{slug}.json" if not db_path.exists() or not result_path.exists(): print( f" Skipping {run['label']}: missing {db_path.name} or {result_path.name}" ) continue times, balances = load_funds_curve(db_path) meta = load_meta(result_path) run_data.append({**run, "times": times, "balances": balances, **meta}) print( f" Loaded {run['label']}: {meta['turns']} turns, {meta['terminal_reason']}, ${meta['cost_usd']:.4f}" ) return run_data # --------------------------------------------------------------------------- # Plot # --------------------------------------------------------------------------- def make_plot(run_data, seed, config_name, budget_usd, out_path: Path): total_cost = sum(r["cost_usd"] for r in run_data) budget_pct = (total_cost / budget_usd) * 100 if budget_usd else 0 fig, (ax_funds, ax_cost) = plt.subplots( 1, 2, figsize=(17, 6.5), facecolor="#0f1117", gridspec_kw={"width_ratios": [3, 1]}, ) for ax in [ax_funds, ax_cost]: ax.set_facecolor("#1a1d27") ax.tick_params(colors="#aaaaaa", labelsize=9) for spine in ax.spines.values(): spine.set_edgecolor("#333344") # Horizon annotation (approx end date) horizon_years = run_data[0]["horizon_years"] if run_data else 3 horizon_label = f"{horizon_years}-year horizon" fig.suptitle( f"YC-Bench · {len(run_data)}-Model Comparison · seed={seed} · {config_name} ({horizon_label}, 5 employees)\n" f"Total API spend: ${total_cost:.2f} / ${budget_usd:.0f} budget ({budget_pct:.1f}%)", color="white", fontsize=13, y=1.02, ) # ── Funds curves ───────────────────────────────────────────────────────── ax_funds.axhline( 0, color="#e74c3c", linewidth=0.9, linestyle="--", alpha=0.4, zorder=1 ) ax_funds.axhline( INITIAL_FUNDS_CENTS / 100, color="#555577", linewidth=0.7, linestyle=":", alpha=0.6, zorder=1, ) for r in run_data: if not r["times"]: continue reason = r["terminal_reason"] turns = r["turns"] reason_short = { "bankruptcy": "bankrupt", "horizon_end": "survived!", "error": "error", }.get(reason, reason) label = f"{r['label']} ({turns}t · {reason_short})" lw = 2.2 if reason == "horizon_end" else 1.8 alpha = 1.0 if reason == "horizon_end" else 0.85 ax_funds.plot( r["times"], r["balances"], color=r["color"], linewidth=lw, alpha=alpha, label=label, zorder=3, ) # Mark terminal point marker = "★" if reason == "horizon_end" else "x" msize = 100 if reason == "horizon_end" else 70 ax_funds.scatter( [r["times"][-1]], [r["balances"][-1]], color=r["color"], s=msize, zorder=5, marker="*" if reason == "horizon_end" else "x", linewidths=2, ) ax_funds.yaxis.set_major_formatter( plt.FuncFormatter( lambda x, _: ( f"${x/1000:.0f}K" if abs(x) < 1_000_000 else f"${x/1_000_000:.1f}M" ) ) ) ax_funds.xaxis.set_major_formatter(mdates.DateFormatter("%b '%y")) ax_funds.xaxis.set_major_locator(mdates.MonthLocator(interval=3)) plt.setp(ax_funds.xaxis.get_majorticklabels(), rotation=30, ha="right") ax_funds.set_title("Company Funds Over Sim Time", color="white", fontsize=12, pad=8) ax_funds.set_ylabel("Balance (USD)", color="#aaaaaa", fontsize=9) ax_funds.legend( fontsize=9, facecolor="#1a1d27", edgecolor="#333344", labelcolor="white", loc="upper right", ) ax_funds.grid(axis="y", color="#333344", linewidth=0.5, linestyle="--") ax_funds.text( 0.005, 0.03, "← bankruptcy", transform=ax_funds.transAxes, color="#e74c3c", fontsize=7.5, alpha=0.6, ) # ── Cost vs budget bars ────────────────────────────────────────────────── labels = [r["label"] for r in run_data] costs = [r["cost_usd"] for r in run_data] colors = [r["color"] for r in run_data] y_pos = list(range(len(labels))) bars = ax_cost.barh(y_pos, costs, color=colors, alpha=0.85, height=0.55, zorder=3) # Total bar total_y = len(labels) + 0.6 ax_cost.barh( total_y, total_cost, color="#f1c40f", alpha=0.45, height=0.45, zorder=3 ) # Budget line ax_cost.axvline( budget_usd, color="#f1c40f", linewidth=1.6, linestyle="--", zorder=4, label=f"${budget_usd:.0f} budget", ) # Value labels for i, cost in enumerate(costs): ax_cost.text( cost + budget_usd * 0.01, i, f"${cost:.3f}", va="center", color="#dddddd", fontsize=8, ) ax_cost.text( total_cost + budget_usd * 0.01, total_y, f"${total_cost:.2f}", va="center", color="#f1c40f", fontsize=8.5, fontweight="bold", ) ax_cost.set_yticks(y_pos + [total_y]) ax_cost.set_yticklabels(labels + ["TOTAL"], color="#aaaaaa", fontsize=8) ax_cost.set_xlabel("API Cost (USD)", color="#aaaaaa", fontsize=9) ax_cost.set_title( f"Cost vs ${budget_usd:.0f} Budget", color="white", fontsize=12, pad=8 ) ax_cost.legend( fontsize=8, facecolor="#1a1d27", edgecolor="#333344", labelcolor="white" ) ax_cost.set_xlim(0, max(budget_usd * 1.15, max(costs) * 1.4 if costs else 1)) ax_cost.grid(axis="x", color="#333344", linewidth=0.5, linestyle="--") plt.tight_layout() out_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(out_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor()) print(f"Saved: {out_path}") # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- if __name__ == "__main__": args = parse_args() out = Path(args.out) if args.out else ROOT / "plots" / "funds_curves.png" print(f"Loading runs: seed={args.seed}, config={args.config}") run_data = load_run_data(DEFAULT_RUNS, args.seed) if not run_data: print("No data found. Run benchmarks first.") else: make_plot(run_data, args.seed, args.config, args.budget, out)