mirror of
https://github.com/collinear-ai/yc-bench.git
synced 2026-04-19 12:58:03 +00:00
- Hide exact reward_multiplier from agent; show tier (Standard/Premium/Enterprise) and specialty domains instead - Add client domain specialization with 70% bias on task generation toward client specialties - Remove qty_scale by multiplier (leaked info and doubly punished high-mult clients) - Rewrite agent prompt to describe tiers/specialties without exact formulas - Fix critical loop.py bug: provide full state context after sim resume (prevents idle multi-month skips) - Add Streamlit dashboard, watch scripts, and updated plotting/extraction Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
358 lines
13 KiB
Python
358 lines
13 KiB
Python
"""Plot benchmark results from JSON result files.

Usage:
    uv run python scripts/plot_results.py results/yc_bench_result_hard_1_gemini*.json
    uv run python scripts/plot_results.py results/*.json --out plots/all_runs.png
    uv run python scripts/plot_results.py results/run.json --plot prestige
"""
|
|
import argparse
|
|
import json
|
|
import re
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.dates as mdates
|
|
|
|
# --- Transcript patterns ----------------------------------------------------
# Turn 1 (the initial prompt) reports state as plain "key: value" pairs.
RE_INIT_TIME = re.compile(r"current_time:\s*(\S+)")
RE_INIT_FUNDS = re.compile(r"funds:\s*\$([\d,.]+)")

# Subsequent turns use bold markdown labels; the funds line also carries the
# balance as an integer number of cents, which is the group captured here.
RE_TURN_TIME = re.compile(r"\*\*Current time\*\*:\s*(\S+)")
RE_TURN_FUNDS = re.compile(r"\*\*Funds\*\*:\s*\$[\d,.]+\s*\((\d+)\s*cents\)")

# Sim-resume turns report the new clock value and a signed balance change.
RE_NEW_SIM_TIME = re.compile(r"new_sim_time:\s*(\S+)")
RE_BALANCE_DELTA = re.compile(r"balance_delta:\s*(-?\d+)")

# --- Palette ----------------------------------------------------------------
# Colours cycled through for successive series.
LINE_COLORS = [
    "#4fc3f7",
    "#2ecc71",
    "#e67e22",
    "#e74c3c",
    "#9b59b6",
    "#1abc9c",
    "#f1c40f",
    "#e91e63",
]
# Fixed colour per prestige domain.
DOMAIN_COLORS = {
    "research": "#4fc3f7",
    "inference": "#2ecc71",
    "data_environment": "#e67e22",
    "training": "#9b59b6",
}

# Dark-theme colours: figure background, axes face, grid/spines, tick text.
BG_COLOR = "#0f1117"
FACE_COLOR = "#1a1d27"
GRID_COLOR = "#333344"
TEXT_COLOR = "#aaaaaa"
|
|
|
|
|
|
def _style_ax(ax):
    """Give *ax* the shared dark-theme look (face, tick and spine colours)."""
    # The three settings are independent of each other.
    for name in ax.spines:
        ax.spines[name].set_edgecolor(GRID_COLOR)
    ax.tick_params(labelsize=9, colors=TEXT_COLOR)
    ax.set_facecolor(FACE_COLOR)
|
|
|
|
|
|
def _format_time_axis(ax):
    """Use one "Mon 'YY" tick per month on the x-axis, labels rotated 30°."""
    x_axis = ax.xaxis
    x_axis.set_major_formatter(mdates.DateFormatter("%b '%y"))
    x_axis.set_major_locator(mdates.MonthLocator(interval=1))

    tick_labels = x_axis.get_majorticklabels()
    plt.setp(tick_labels, rotation=30, ha="right")
|
|
|
|
|
|
def _funds_formatter():
    """Return a tick formatter that renders dollar amounts compactly.

    Values with magnitude below one million are shown in whole thousands
    ("$250K"); anything larger is shown in millions with one decimal
    ("$1.5M").
    """

    def _fmt(x, _):
        if abs(x) < 1_000_000:
            return f"${x/1000:.0f}K"
        return f"${x/1_000_000:.1f}M"

    return plt.FuncFormatter(_fmt)
|
|
|
|
|
|
def _smooth_funds(times, funds, window_days=3):
|
|
"""Resample funds to daily frequency and apply rolling average to smooth payroll staircases."""
|
|
if len(times) < 3:
|
|
return times, funds
|
|
|
|
from datetime import timedelta
|
|
|
|
# Create daily time series via forward-fill
|
|
start, end = times[0], times[-1]
|
|
n_days = (end - start).days
|
|
if n_days < 2:
|
|
return times, funds
|
|
|
|
daily_times = [start + timedelta(days=d) for d in range(n_days + 1)]
|
|
daily_funds = []
|
|
src_idx = 0
|
|
for dt in daily_times:
|
|
while src_idx < len(times) - 1 and times[src_idx + 1] <= dt:
|
|
src_idx += 1
|
|
daily_funds.append(funds[src_idx])
|
|
|
|
# Rolling average
|
|
window = min(window_days, len(daily_funds))
|
|
if window < 2:
|
|
return daily_times, daily_funds
|
|
|
|
smoothed = []
|
|
for i in range(len(daily_funds)):
|
|
lo = max(0, i - window // 2)
|
|
hi = min(len(daily_funds), i + window // 2 + 1)
|
|
smoothed.append(sum(daily_funds[lo:hi]) / (hi - lo))
|
|
|
|
return daily_times, smoothed
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Data extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def parse_funds_curve(result):
    """Return ``(times, funds_dollars)`` for one run.

    Uses the structured ``time_series.funds`` points when the result has any
    (each point supplies an ISO ``time`` and an integer ``funds_cents``);
    otherwise falls back to scraping the transcript with regexes.
    """
    series = result.get("time_series", {}).get("funds")
    if not series:
        return _parse_funds_curve_regex(result)

    times = list(map(datetime.fromisoformat, (point["time"] for point in series)))
    funds = [point["funds_cents"] / 100 for point in series]
    return times, funds
|
|
|
|
|
|
def _parse_funds_curve_regex(result):
    """Legacy: rebuild the funds curve by regex-scanning the transcript.

    Each transcript entry's ``user_input`` is checked, in order, for:

    1. a regular turn: time plus funds with an exact cent count;
    2. the initial prompt: time plus funds given only in dollars;
    3. a sim-resume message: ``new_sim_time`` plus an optional
       ``balance_delta`` (in cents) applied to the last known balance.

    If the run's ``terminal_reason`` mentions "bankrupt" and the last
    recorded balance is still positive, a final zero point is appended at the
    last timestamp.

    Args:
        result: Parsed result dict; only ``transcript`` and
            ``terminal_reason`` are read.

    Returns:
        ``(times, funds)``: datetimes and balances in dollars. Both lists are
        empty when nothing could be parsed.
    """
    times = []
    funds = []
    running_cents = None  # last known balance; None until a turn states funds

    for entry in result.get("transcript", []):
        # A null user_input is treated like a missing one.
        text = entry.get("user_input") or ""

        t_match = RE_TURN_TIME.search(text) or RE_INIT_TIME.search(text)
        f_match = RE_TURN_FUNDS.search(text)

        if f_match and t_match:
            running_cents = int(f_match.group(1))
            times.append(datetime.fromisoformat(t_match.group(1)))
            funds.append(running_cents / 100)
            continue

        if not f_match:
            f_init = RE_INIT_FUNDS.search(text)
            if f_init and t_match:
                dollar_str = f_init.group(1).replace(",", "")
                # round(), not int(): binary floats make e.g. 0.29 * 100 come
                # out as 28.999..., and truncating would drop a cent.
                running_cents = round(float(dollar_str) * 100)
                times.append(datetime.fromisoformat(t_match.group(1)))
                funds.append(running_cents / 100)
                continue

        sim_match = RE_NEW_SIM_TIME.search(text)
        delta_match = RE_BALANCE_DELTA.search(text)
        # A delta is only meaningful once a starting balance is known.
        if sim_match and running_cents is not None:
            new_time = datetime.fromisoformat(sim_match.group(1))
            delta = int(delta_match.group(1)) if delta_match else 0
            running_cents += delta
            times.append(new_time)
            funds.append(running_cents / 100)

    terminal = result.get("terminal_reason", "")
    if "bankrupt" in str(terminal) and funds and funds[-1] > 0:
        # Close the curve at zero so a bankrupt run visibly ends there.
        times.append(times[-1])
        funds.append(0)

    return times, funds
|
|
|
|
|
|
def parse_prestige_curves(result):
    """Group ``time_series.prestige`` samples into one curve per domain.

    Returns dict[domain] -> (times, levels), with samples kept in the order
    they appear in the result. An empty dict is returned when the result has
    no prestige samples.
    """
    samples = result.get("time_series", {}).get("prestige", [])
    if not samples:
        return {}

    curves = {}
    for sample in samples:
        key = sample["domain"]
        if key not in curves:
            curves[key] = ([], [])
        stamps, levels = curves[key]
        stamps.append(datetime.fromisoformat(sample["time"]))
        levels.append(sample["level"])
    return curves
|
|
|
|
|
|
def parse_trust_curves(result):
    """Group ``time_series.client_trust`` samples into one curve per client.

    Returns dict[client_name] -> (times, levels), with samples kept in the
    order they appear in the result. An empty dict is returned when the
    result has no client-trust samples.
    """
    samples = result.get("time_series", {}).get("client_trust", [])
    if not samples:
        return {}

    curves = {}
    for sample in samples:
        client = sample["client_name"]
        if client not in curves:
            curves[client] = ([], [])
        stamps, levels = curves[client]
        stamps.append(datetime.fromisoformat(sample["time"]))
        levels.append(sample["trust_level"])
    return curves
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Plot functions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def make_label(result, override=None):
    """Return the legend label for a run.

    A truthy ``override`` is returned as-is. Otherwise the label is the last
    "/"-separated segment of the result's ``model`` followed by its ``seed``,
    with "unknown" and "?" standing in for missing values.
    """
    if override:
        return override

    model_name = result.get("model", "unknown")
    # Keep only what follows the last "/" (the whole name if there is none).
    short = model_name.rpartition("/")[2]
    seed = result.get("seed", "?")
    return "{} (seed {})".format(short, seed)
|
|
|
|
|
|
def plot_funds(ax, results_data, labels=None, smooth=True):
    """Draw each run's net worth over time on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        results_data: Sequence of ``(file_path, result_dict)`` pairs.
        labels: Optional legend labels, matched to runs by position; runs
            without one get an auto-generated label.
        smooth: When true, each curve is passed through ``_smooth_funds``
            before plotting.

    Runs with no funds data are reported on stdout and skipped. Each curve
    ends in a marker ("x" if the terminal reason mentions bankruptcy, "*"
    otherwise), annotated with the terminal reason when there is one.
    """
    # Dashed reference line at a zero balance.
    ax.axhline(0, color="#e74c3c", linewidth=0.9, linestyle="--", alpha=0.4)

    for idx, (fpath, result) in enumerate(results_data):
        times, funds = parse_funds_curve(result)
        if not times:
            print(f"No funds data in {fpath}")
            continue

        if smooth:
            times, funds = _smooth_funds(times, funds)

        series_color = LINE_COLORS[idx % len(LINE_COLORS)]
        custom = labels[idx] if labels and idx < len(labels) else None
        ax.plot(
            times,
            funds,
            color=series_color,
            linewidth=2,
            alpha=0.95,
            label=make_label(result, custom),
        )

        # Mark where the (possibly smoothed) curve ends.
        end_time, end_funds = times[-1], funds[-1]
        terminal = result.get("terminal_reason", "")
        end_marker = "x" if "bankrupt" in str(terminal) else "*"
        ax.scatter([end_time], [end_funds], color=series_color, s=80,
                   marker=end_marker, linewidths=2, zorder=5)

        if not terminal:
            continue
        ax.annotate(
            terminal,
            (end_time, end_funds),
            textcoords="offset points",
            xytext=(8, -5),
            fontsize=7,
            color=series_color,
            alpha=0.8,
        )

    ax.yaxis.set_major_formatter(_funds_formatter())
    _format_time_axis(ax)
    ax.set_title("Net Worth Over Time", color="white", fontsize=12, pad=10)
    ax.set_ylabel("Funds (USD)", color=TEXT_COLOR, fontsize=9)
    ax.grid(axis="y", color=GRID_COLOR, linewidth=0.5, linestyle="--")
    ax.text(0.005, 0.03, "← bankruptcy", transform=ax.transAxes,
            color="#e74c3c", fontsize=7.5, alpha=0.6)

    # A legend is only drawn when more than one run was passed in.
    if len(results_data) > 1:
        legend_style = dict(fontsize=8, facecolor=FACE_COLOR, edgecolor=GRID_COLOR,
                            labelcolor="white", loc="best")
        ax.legend(**legend_style)
|
|
|
|
|
|
def plot_prestige(ax, results_data, labels=None):
    """Draw one prestige curve per domain, for each run, on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        results_data: Sequence of ``(file_path, result_dict)`` pairs.
        labels: Optional run labels, matched by position. They are only used
            (as a legend prefix) when more than one run is plotted.

    Runs with no prestige data are reported on stdout and skipped. Domains
    listed in ``DOMAIN_COLORS`` use their fixed colour; other domains fall
    back to the run's colour from ``LINE_COLORS``.
    """
    several_runs = len(results_data) > 1

    for run_idx, (fpath, result) in enumerate(results_data):
        curves = parse_prestige_curves(result)
        if not curves:
            print(f"No prestige data in {fpath}")
            continue

        if several_runs:
            custom = labels[run_idx] if labels and run_idx < len(labels) else None
            label_prefix = make_label(result, custom) + " "
        else:
            label_prefix = ""

        for domain in sorted(curves):
            times, levels = curves[domain]
            line_color = DOMAIN_COLORS.get(domain, LINE_COLORS[run_idx % len(LINE_COLORS)])
            ax.plot(times, levels, color=line_color, linewidth=2, alpha=0.9,
                    label=f"{label_prefix}{domain}", marker="o", markersize=4)

    _format_time_axis(ax)
    ax.set_title("Prestige Over Time", color="white", fontsize=12, pad=10)
    ax.set_ylabel("Prestige Level", color=TEXT_COLOR, fontsize=9)
    ax.grid(axis="y", color=GRID_COLOR, linewidth=0.5, linestyle="--")
    legend_style = dict(fontsize=8, facecolor=FACE_COLOR, edgecolor=GRID_COLOR,
                        labelcolor="white", loc="best")
    ax.legend(**legend_style)
|
|
|
|
|
|
def plot_trust(ax, results_data, labels=None):
    """Draw one trust curve per client, for each run, on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        results_data: Sequence of ``(file_path, result_dict)`` pairs.
        labels: Optional run labels, matched by position. They are only used
            (as a legend prefix) when more than one run is plotted.

    Runs with no client-trust data are reported on stdout and skipped.
    Clients are drawn in sorted name order, cycling through ``LINE_COLORS``;
    the y-axis is fixed to the range -0.1 .. 5.1.
    """
    several_runs = len(results_data) > 1

    for run_idx, (fpath, result) in enumerate(results_data):
        curves = parse_trust_curves(result)
        if not curves:
            print(f"No client trust data in {fpath}")
            continue

        if several_runs:
            custom = labels[run_idx] if labels and run_idx < len(labels) else None
            label_prefix = make_label(result, custom) + " "
        else:
            label_prefix = ""

        for slot, client_name in enumerate(sorted(curves)):
            times, levels = curves[client_name]
            ax.plot(times, levels, color=LINE_COLORS[slot % len(LINE_COLORS)],
                    linewidth=2, alpha=0.9,
                    label=f"{label_prefix}{client_name}", marker="o", markersize=3)

    ax.set_ylim(-0.1, 5.1)
    _format_time_axis(ax)
    ax.set_title("Client Trust Over Time", color="white", fontsize=12, pad=10)
    ax.set_ylabel("Trust Level", color=TEXT_COLOR, fontsize=9)
    ax.grid(axis="y", color=GRID_COLOR, linewidth=0.5, linestyle="--")
    legend_style = dict(fontsize=7, facecolor=FACE_COLOR, edgecolor=GRID_COLOR,
                        labelcolor="white", loc="best", ncol=2)
    ax.legend(**legend_style)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def parse_args():
    """Parse the command line and return the resulting namespace.

    Attributes on the namespace: ``files`` (one or more paths), ``out``,
    ``plot`` ("funds", "prestige" or "trust"), ``labels`` and ``smooth``.
    """
    parser = argparse.ArgumentParser(description="Plot benchmark results from JSON files")
    parser.add_argument("files", nargs="+", help="JSON result file paths")
    parser.add_argument(
        "--out",
        default=None,
        help="Output PNG path (default: auto-generated in plots/)",
    )
    parser.add_argument(
        "--plot",
        default="funds",
        choices=["funds", "prestige", "trust"],
        help="Plot mode (default: funds)",
    )
    parser.add_argument(
        "--labels",
        nargs="+",
        default=None,
        help="Custom legend labels (one per file, in order)",
    )
    # store_true already defaults to False.
    parser.add_argument(
        "--smooth",
        action="store_true",
        help="Enable 3-day rolling average smoothing on funds plot",
    )
    return parser.parse_args()
|
|
|
|
|
|
def main():
    """CLI entry point: load result files, draw the chosen plot, save a PNG.

    With ``--plot prestige`` and more than one input file, each run gets its
    own side-by-side subplot sharing the y-axis; every other combination is
    drawn on a single axes. The output path is ``--out`` if given, otherwise
    ``plots/<input stem><suffix>.png`` for one input or
    ``plots/results_comparison<suffix>.png`` for several, where the suffix is
    ``_<plot>`` for non-funds plots. Parent directories are created as needed.
    """
    args = parse_args()

    # Load all result files. JSON is UTF-8 by spec, so do not rely on the
    # platform's default text encoding.
    results_data = []
    for fpath in args.files:
        with open(fpath, encoding="utf-8") as f:
            results_data.append((fpath, json.load(f)))

    # Prestige with multiple runs: use side-by-side subplots
    if args.plot == "prestige" and len(results_data) > 1:
        # At least two columns here, so `axes` is always a sequence of axes.
        fig, axes = plt.subplots(1, len(results_data), figsize=(6 * len(results_data), 5),
                                 facecolor=BG_COLOR, sharey=True)
        for idx, (ax_i, (fpath, result)) in enumerate(zip(axes, results_data)):
            _style_ax(ax_i)
            lbl = args.labels[idx] if args.labels and idx < len(args.labels) else make_label(result)
            plot_prestige(ax_i, [(fpath, result)], labels=[lbl])
            # Override the generic title with one naming the run.
            ax_i.set_title(f"Prestige — {lbl}", color="white", fontsize=11, pad=10)
    else:
        fig, ax = plt.subplots(figsize=(12, 5), facecolor=BG_COLOR)
        _style_ax(ax)

        if args.plot == "funds":
            plot_funds(ax, results_data, labels=args.labels, smooth=args.smooth)
        elif args.plot == "prestige":
            plot_prestige(ax, results_data, labels=args.labels)
        elif args.plot == "trust":
            plot_trust(ax, results_data, labels=args.labels)

    suffix = f"_{args.plot}" if args.plot != "funds" else ""

    plt.tight_layout()

    if args.out:
        out = Path(args.out)
    elif len(args.files) == 1:
        out = Path("plots") / f"{Path(args.files[0]).stem}{suffix}.png"
    else:
        out = Path("plots") / f"results_comparison{suffix}.png"

    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    print(f"Saved: {out}")

    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|