mirror of
https://github.com/collinear-ai/yc-bench.git
synced 2026-04-19 12:58:03 +00:00
187 lines
8.9 KiB
Python
187 lines
8.9 KiB
Python
"""Plot a benchmark run: funds over time, prestige evolution, task outcomes."""
|
|
import os
|
|
import sys
|
|
from decimal import Decimal
|
|
from pathlib import Path
|
|
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.gridspec as gridspec
|
|
import matplotlib.dates as mdates
|
|
import numpy as np
|
|
|
|
# Fall back to a local Postgres URL only when DATABASE_URL is not already set;
# setdefault never overrides a value supplied by the environment.
os.environ.setdefault(
    "DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/yc_bench",
)
# Put this script's directory first on sys.path so the `src.bench...` imports
# that follow are resolved relative to it.
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
from src.bench.db.session import build_engine, build_session_factory, session_scope
|
|
from src.bench.db.models.ledger import LedgerEntry, LedgerCategory
|
|
from src.bench.db.models.task import Task, TaskRequirement, TaskStatus
|
|
from src.bench.db.models.company import CompanyPrestige
|
|
|
|
# Create the database engine and the session factory once, at import time.
# Both helpers come from src.bench.db.session; presumably build_engine() reads
# the DATABASE_URL environment variable set above — TODO confirm.
engine = build_engine()
factory = build_session_factory(engine)
|
|
|
|
# Hex line colour used for each domain in the prestige panel, keyed by the
# domain's string value. Domains missing here fall back to grey at plot time.
DOMAIN_COLORS = dict(
    research="#3498db",
    inference="#9b59b6",
    data_environment="#1abc9c",
    training="#e67e22",
)
|
|
|
|
with session_scope(factory) as db:
    # --- Ledger: reconstruct running balance ---
    # Replay every ledger entry in chronological order on top of the starting
    # funds. Amounts are accumulated in cents; the plotted balance is cents/100.
    entries = (
        db.query(LedgerEntry)
        .order_by(LedgerEntry.occurred_at)
        .all()
    )
    initial_funds = 25_000_000  # cents
    times, balances, categories = [], [], []
    running = initial_funds
    for e in entries:
        running += int(e.amount_cents)
        times.append(e.occurred_at)
        balances.append(running / 100)
        categories.append(e.category)

    # --- Tasks ---
    # Only tasks that have a completion timestamp, oldest first.
    tasks = (
        db.query(Task)
        .filter(Task.completed_at.isnot(None))
        .order_by(Task.completed_at)
        .all()
    )
    task_times, task_rewards, task_success, task_prestige = [], [], [], []
    for t in tasks:
        task_times.append(t.completed_at)
        task_rewards.append(int(t.reward_funds_cents) / 100)
        task_success.append(t.status == TaskStatus.COMPLETED_SUCCESS)
        task_prestige.append(t.required_prestige)

    # --- Prestige per domain (sampled from task completions) ---
    # Build prestige history by replaying prestige deltas.
    from collections import defaultdict

    from src.bench.db.models.company import Domain

    # Bounds and failure multiplier used by the replay below. They are the
    # values hard-coded in this script — TODO confirm they match the
    # simulator's actual configuration for the run being plotted.
    PRESTIGE_FLOOR = 1.0
    PRESTIGE_CAP = 10.0
    FAILURE_PENALTY_FACTOR = 1.4

    # Load the requirements of every completed task with a single query and
    # group them in memory, instead of issuing one query per task.
    requirements_by_task = defaultdict(list)
    if tasks:
        completed_ids = [t.id for t in tasks]
        requirement_rows = (
            db.query(TaskRequirement)
            .filter(TaskRequirement.task_id.in_(completed_ids))
            .all()
        )
        for req in requirement_rows:
            requirements_by_task[req.task_id].append(req)

    # Every domain starts at the floor, anchored at the first ledger timestamp
    # (None when the ledger is empty; such points are skipped when plotting).
    history_start = times[0] if times else None
    prestige_history = {d.value: [(history_start, PRESTIGE_FLOOR)] for d in Domain}
    current_prestige = {d.value: PRESTIGE_FLOOR for d in Domain}

    # `tasks` is already filtered to completed tasks and ordered by
    # completed_at by the query above, so it can be replayed directly.
    for t in tasks:
        for req in requirements_by_task.get(t.id, ()):
            d = req.domain.value
            if t.status == TaskStatus.COMPLETED_SUCCESS:
                current_prestige[d] = min(
                    PRESTIGE_CAP,
                    current_prestige[d] + float(t.reward_prestige_delta),
                )
            else:
                penalty = FAILURE_PENALTY_FACTOR * float(t.reward_prestige_delta)
                current_prestige[d] = max(PRESTIGE_FLOOR, current_prestige[d] - penalty)
            prestige_history[d].append((t.completed_at, current_prestige[d]))

    # Final prestige from DB (not referenced by the plotting code visible in
    # this script; kept so the authoritative values can be inspected).
    final_prestige = {
        row.domain.value: float(row.prestige_level)
        for row in db.query(CompanyPrestige).all()
    }
|
|
|
|
# ── Plot ────────────────────────────────────────────────────────────────
# Dark-themed figure on a 2x2 grid: the funds panel spans the whole top row,
# the prestige and task-outcome panels share the bottom row.
fig = plt.figure(figsize=(16, 10), facecolor="#0f1117")
gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.45, wspace=0.35)

ax_funds, ax_prestige, ax_tasks = (
    fig.add_subplot(cell) for cell in (gs[0, :], gs[1, 0], gs[1, 1])
)

# Apply the same background, tick and border colours to every panel.
for panel in (ax_funds, ax_prestige, ax_tasks):
    panel.set_facecolor("#1a1d27")
    panel.tick_params(colors="#aaaaaa", labelsize=9)
    for border in panel.spines.values():
        border.set_edgecolor("#333344")
|
|
|
|
# ── Funds over time ──────────────────────────────────────────────────────
def _format_dollars(value, _pos=None):
    """Format an axis value as dollars: K below one million, M from there up.

    The threshold is applied to the magnitude so that large negative balances
    are shown in millions too (e.g. -$2.5M instead of $-2500K).
    """
    sign = "-" if value < 0 else ""
    magnitude = abs(value)
    if magnitude < 1_000_000:
        return f"{sign}${magnitude / 1000:.0f}K"
    return f"{sign}${magnitude / 1_000_000:.1f}M"


# Balance points at which a payroll or task-reward ledger entry was booked.
payroll_times, payroll_vals = [], []
reward_times, reward_vals = [], []
for when, balance, category in zip(times, balances, categories):
    if category == LedgerCategory.MONTHLY_PAYROLL:
        payroll_times.append(when)
        payroll_vals.append(balance)
    elif category == LedgerCategory.TASK_REWARD:
        reward_times.append(when)
        reward_vals.append(balance)

# Lower edge of the shaded band: the balance curve rescaled by min/max * 0.5.
# `default=` keeps an empty ledger from raising, and the `if peak_balance`
# guard avoids dividing by zero when the highest balance is exactly 0.
peak_balance = max(balances, default=0.0)
lowest_balance = min(balances, default=0.0)
shade_floor = [
    b / peak_balance * lowest_balance * 0.5 if peak_balance else 0.0
    for b in balances
]

ax_funds.plot(times, balances, color="#4fc3f7", linewidth=1.8, zorder=3, label="Balance")
ax_funds.fill_between(times, shade_floor, balances, alpha=0.08, color="#4fc3f7", zorder=2)
ax_funds.scatter(reward_times, reward_vals, color="#2ecc71", s=30, zorder=5,
                 label="Task reward", marker="^")
ax_funds.scatter(payroll_times, payroll_vals, color="#e74c3c", s=20, zorder=5,
                 label="Payroll", marker="v", alpha=0.7)

ax_funds.yaxis.set_major_formatter(plt.FuncFormatter(_format_dollars))
ax_funds.xaxis.set_major_formatter(mdates.DateFormatter("%b '%y"))
ax_funds.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
plt.setp(ax_funds.xaxis.get_majorticklabels(), rotation=30, ha="right")
ax_funds.set_title("Company Funds Over Time", color="white", fontsize=12, pad=8)
ax_funds.set_ylabel("Balance", color="#aaaaaa", fontsize=9)
ax_funds.legend(fontsize=8, facecolor="#1a1d27", edgecolor="#333344",
                labelcolor="white", loc="upper left")
ax_funds.grid(axis="y", color="#333344", linewidth=0.5, linestyle="--")
|
|
|
|
# ── Prestige evolution ───────────────────────────────────────────────────
# One step line per domain. History points without a timestamp are dropped,
# and a domain needs at least two remaining points to be drawn.
for domain, history in prestige_history.items():
    dated_points = [(stamp, level) for stamp, level in history if stamp is not None]
    if len(dated_points) < 2:
        continue
    hist_times, hist_vals = zip(*dated_points)
    ax_prestige.step(
        hist_times,
        hist_vals,
        where="post",
        color=DOMAIN_COLORS.get(domain, "#aaaaaa"),
        linewidth=1.6,
        label=domain,
    )

# Dotted reference line at prestige level 1.0.
ax_prestige.axhline(y=1.0, color="#555566", linewidth=0.8, linestyle=":")
ax_prestige.set_ylim(0.8, 10.5)

ax_prestige.xaxis.set_major_formatter(mdates.DateFormatter("%b '%y"))
ax_prestige.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
plt.setp(ax_prestige.xaxis.get_majorticklabels(), rotation=30, ha="right")

ax_prestige.set_title("Prestige by Domain", color="white", fontsize=12, pad=8)
ax_prestige.set_ylabel("Prestige Level", color="#aaaaaa", fontsize=9)
ax_prestige.legend(
    fontsize=7.5,
    facecolor="#1a1d27",
    edgecolor="#333344",
    labelcolor="white",
    ncol=2,
    loc="upper left",
)
ax_prestige.grid(axis="y", color="#333344", linewidth=0.5, linestyle="--")
|
|
|
|
# ── Task outcomes scatter ────────────────────────────────────────────────
# One point per completed task: x = completion time, y = reward, marker area
# grows with the task's required prestige.
if task_times:
    # Successes and failures are drawn separately so that each group gets the
    # marker shape the panel title advertises (triangle = success, circle = fail).
    outcome_styles = (
        (True, "^", "#2ecc71"),
        (False, "o", "#e74c3c"),
    )
    for outcome, marker, color in outcome_styles:
        picked = [i for i, ok in enumerate(task_success) if bool(ok) == outcome]
        if not picked:
            continue
        ax_tasks.scatter(
            [task_times[i] for i in picked],
            [task_rewards[i] for i in picked],
            color=color,
            s=[40 + task_prestige[i] * 12 for i in picked],
            marker=marker,
            alpha=0.85,
            zorder=4,
            edgecolors="none",
        )

    # Annotate prestige on each dot
    for when, reward, prestige in zip(task_times, task_rewards, task_prestige):
        ax_tasks.annotate(f"p{prestige}", (when, reward), fontsize=6, color="#cccccc",
                          xytext=(3, 3), textcoords="offset points")

ax_tasks.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"${x/1000:.0f}K" if x < 1_000_000 else f"${x/1_000_000:.1f}M"))
ax_tasks.xaxis.set_major_formatter(mdates.DateFormatter("%b '%y"))
ax_tasks.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
plt.setp(ax_tasks.xaxis.get_majorticklabels(), rotation=30, ha="right")
ax_tasks.set_title("Task Outcomes (▲ success ● fail, size = prestige req)", color="white", fontsize=10, pad=8)
ax_tasks.set_ylabel("Reward Value", color="#aaaaaa", fontsize=9)
ax_tasks.grid(axis="y", color="#333344", linewidth=0.5, linestyle="--")

# Legend patches: colour swatches with the success / failure counts.
from matplotlib.patches import Patch

success_count = sum(task_success)
failure_count = len(task_success) - success_count
ax_tasks.legend(handles=[
    Patch(color="#2ecc71", label=f"Success ({success_count})"),
    Patch(color="#e74c3c", label=f"Fail ({failure_count})"),
], fontsize=8, facecolor="#1a1d27", edgecolor="#333344", labelcolor="white")
|
|
|
|
# ── Summary annotation ───────────────────────────────────────────────────
# One-line run summary centred above the panels. The model / seed / config /
# turn-count / sim-time parts are hard-coded text; only the final balance and
# the success ratio are computed from the loaded data.
final_bal = 0
if balances:
    final_bal = balances[-1]

summary_line = (
    f"minimax-m2.5 | seed=42 | harder config | "
    f"150 turns | Aug 2025 sim time | "
    f"final balance ${final_bal/1_000_000:.2f}M | "
    f"{sum(task_success)}/{len(task_success)} tasks succeeded"
)
fig.text(
    0.5,
    0.97,
    summary_line,
    ha="center",
    va="top",
    color="#aaaaaa",
    fontsize=9,
)

# Write the figure as a PNG (path relative to the current working directory),
# keeping the figure's dark background colour.
out = Path("plot_run_hard.png")
plt.savefig(
    out,
    dpi=150,
    bbox_inches="tight",
    facecolor=fig.get_facecolor(),
)
print(f"Saved: {out}")
|