mirror of
https://github.com/collinear-ai/yc-bench.git
synced 2026-04-19 12:58:03 +00:00
fix seeding
This commit is contained in:
parent
f76f5be652
commit
b043b690c3
5 changed files with 233 additions and 15 deletions
|
|
@ -1,7 +1,8 @@
|
|||
"""Streamlit dashboard — live-monitor a YC-Bench run.
|
||||
"""Streamlit dashboard — live-monitor YC-Bench runs.
|
||||
|
||||
Usage:
|
||||
uv run streamlit run scripts/watch_dashboard.py -- db/medium_1_gemini_gemini-3-flash-preview.db
|
||||
uv run streamlit run scripts/watch_dashboard.py # multi-model overview (auto-discovers db/)
|
||||
uv run streamlit run scripts/watch_dashboard.py -- db/medium_1_model.db # single-run detail
|
||||
|
||||
Automatically overlays the greedy bot baseline if a matching *_greedy_bot.db exists
|
||||
in the same directory (e.g. db/medium_1_greedy_bot.db).
|
||||
|
|
@ -16,6 +17,11 @@ from datetime import datetime, timedelta
|
|||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
|
|
@ -114,14 +120,21 @@ for a in args:
|
|||
db_path = a
|
||||
break
|
||||
|
||||
if db_path is None:
|
||||
st.error("Pass a .db path: `uv run streamlit run scripts/watch_dashboard.py -- path/to.db`")
|
||||
st.stop()
|
||||
MULTI_MODEL_MODE = db_path is None
|
||||
|
||||
db_file = Path(db_path)
|
||||
if not db_file.exists():
|
||||
st.error(f"DB not found: {db_file}")
|
||||
st.stop()
|
||||
if MULTI_MODEL_MODE:
|
||||
# Auto-discover all DBs — show multi-model overview
|
||||
_all_dbs = sorted(Path("db").glob("*.db"))
|
||||
if not _all_dbs:
|
||||
st.error("No DB files found in db/. Run some experiments first.")
|
||||
st.stop()
|
||||
# Use first DB to bootstrap config detection
|
||||
db_file = _all_dbs[0]
|
||||
else:
|
||||
db_file = Path(db_path)
|
||||
if not db_file.exists():
|
||||
st.error(f"DB not found: {db_file}")
|
||||
st.stop()
|
||||
|
||||
# Auto-detect config name from DB filename (e.g. "medium_1_model.db" -> "medium")
|
||||
_db_stem_parts = db_file.stem.split("_")
|
||||
|
|
@ -464,6 +477,170 @@ def _load_transcript(primary_db: Path) -> list[dict]:
|
|||
return []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-model overview (when no DB arg given)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MODEL_COLORS = {
|
||||
"gpt-5.4": "#4da6ff",
|
||||
"gpt-5.4-nano": "#ff8c42",
|
||||
"gpt-5.4-mini": "#ffd43b",
|
||||
"gemini-3.1-pro-preview": "#00d4aa",
|
||||
"gemini-3-flash-preview": "#b197fc",
|
||||
"claude-sonnet-4-6": "#e599f7",
|
||||
"greedy_bot": "#ff4b6e",
|
||||
}
|
||||
|
||||
|
||||
def _model_color(label: str) -> str:
|
||||
for key, color in MODEL_COLORS.items():
|
||||
if key in label:
|
||||
return color
|
||||
return "#8b8d93"
|
||||
|
||||
|
||||
def _parse_db_stem(stem: str) -> tuple[str, int, str]:
|
||||
"""Parse 'medium_1_openai_gpt-5.4' -> (config, seed, model_label)."""
|
||||
parts = stem.split("_", 2)
|
||||
if len(parts) < 3:
|
||||
return stem, 0, stem
|
||||
config = parts[0]
|
||||
try:
|
||||
seed = int(parts[1])
|
||||
except ValueError:
|
||||
return stem, 0, stem
|
||||
raw = parts[2]
|
||||
label = re.sub(r"^(openai|gemini|anthropic)_", "", raw)
|
||||
return config, seed, label
|
||||
|
||||
|
||||
def _read_db_summary(db_path: Path) -> dict | None:
|
||||
"""Read key metrics from a DB via raw sqlite3 (no ORM)."""
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
try:
|
||||
row = conn.execute("SELECT sim_time, horizon_end FROM sim_state LIMIT 1").fetchone()
|
||||
co = conn.execute("SELECT funds_cents FROM companies LIMIT 1").fetchone()
|
||||
if not row or not co:
|
||||
return None
|
||||
ok = conn.execute("SELECT COUNT(*) FROM tasks WHERE status='completed_success'").fetchone()[0]
|
||||
fail = conn.execute("SELECT COUNT(*) FROM tasks WHERE status='completed_fail'").fetchone()[0]
|
||||
ledger = conn.execute(
|
||||
"SELECT occurred_at, amount_cents FROM ledger_entries ORDER BY occurred_at"
|
||||
).fetchall()
|
||||
running = 20_000_000 # default starting funds
|
||||
funds_by_day = {}
|
||||
for occ, amt in ledger:
|
||||
running += amt
|
||||
funds_by_day[occ[:10]] = running
|
||||
sim_time, horizon_end = row
|
||||
funds = co[0]
|
||||
return {
|
||||
"sim_time": sim_time[:10],
|
||||
"funds": funds / 100,
|
||||
"ok": ok, "fail": fail,
|
||||
"done": sim_time >= horizon_end or funds < 0,
|
||||
"bankrupt": funds < 0,
|
||||
"funds_by_day": {k: v / 100 for k, v in funds_by_day.items()},
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _render_multi_model() -> None:
    """Render the multi-model overview dashboard.

    Auto-discovers every ``db/*.db`` file, summarizes each run via
    _read_db_summary, groups runs by model label, and renders two plotly
    charts: one funds curve per seed, then a per-model mean curve with a
    min/max band across seeds.
    """
    all_dbs = sorted(Path("db").glob("*.db"))
    if not all_dbs:
        st.warning("No DB files in db/")
        return

    # label -> list of (seed, db_path, summary dict); unreadable DBs are skipped.
    model_runs = defaultdict(list)
    for p in all_dbs:
        config, seed, label = _parse_db_stem(p.stem)
        data = _read_db_summary(p)
        if data:
            model_runs[label].append((seed, p, data))

    # --- Per-seed funds curves ---
    st.markdown('<div class="section-header">Funds Over Time (per seed)</div>', unsafe_allow_html=True)
    fig1 = go.Figure()
    for label in sorted(model_runs.keys()):
        color = _model_color(label)
        # One trace per seed; seeds > 1 are dotted to distinguish them.
        for seed, _, data in sorted(model_runs[label], key=lambda x: x[0]):
            fbd = data.get("funds_by_day", {})
            if not fbd:
                continue
            days = sorted(fbd.keys())
            vals = [fbd[d] for d in days]
            fig1.add_trace(go.Scatter(
                x=days, y=vals, mode="lines",
                name=f"{label} (s{seed})",
                line=dict(color=color, width=1.5, dash="dot" if seed > 1 else "solid"),
                opacity=0.7,
            ))
    fig1.add_hline(y=200_000, line_dash="dash", line_color="#555",
                   annotation_text="Starting $200K")
    fig1.update_layout(**_chart_layout(yaxis_title="Funds ($)", height=500))
    fig1.update_yaxes(tickprefix="$", tickformat=",")
    st.plotly_chart(fig1, use_container_width=True, config={"displayModeBar": False})

    # --- Averaged funds curves ---
    st.markdown('<div class="section-header">Funds Over Time (averaged across seeds)</div>', unsafe_allow_html=True)
    fig2 = go.Figure()
    for label in sorted(model_runs.keys()):
        color = _model_color(label)
        runs = model_runs[label]
        # Collect the union of calendar days across this model's seeds.
        all_days = set()
        series = []
        for seed, _, data in runs:
            fbd = data.get("funds_by_day", {})
            if fbd:
                all_days.update(fbd.keys())
                series.append(fbd)
        if not all_days or not series:
            continue
        common_days = sorted(all_days)
        # Forward-fill each seed's sparse day->funds mapping onto the common
        # day axis so the per-day mean compares like with like. Days before a
        # seed's first ledger entry fall back to the $200K starting balance.
        aligned = []
        for s in series:
            s_days = sorted(s.keys())
            if not s_days:
                continue
            vals, last_val, si = [], 200_000, 0
            for d in common_days:
                # Advance to the latest known balance at or before day d.
                while si < len(s_days) and s_days[si] <= d:
                    last_val = s[s_days[si]]
                    si += 1
                vals.append(last_val)
            aligned.append(vals)
        if not aligned:
            continue
        arr = np.array(aligned)
        mean = arr.mean(axis=0)
        fig2.add_trace(go.Scatter(
            x=common_days, y=mean, mode="lines",
            name=f"{label} (n={len(aligned)})",
            line=dict(color=color, width=3),
        ))
        # With multiple seeds, shade the min/max envelope behind the mean.
        if len(aligned) > 1:
            lo, hi = arr.min(axis=0), arr.max(axis=0)
            # Convert hex color to rgba for fill
            _r, _g, _b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)
            fig2.add_trace(go.Scatter(
                x=list(common_days) + list(common_days)[::-1],
                y=list(hi) + list(lo[::-1]),
                fill="toself", fillcolor=f"rgba({_r},{_g},{_b},0.1)",
                line=dict(color="rgba(0,0,0,0)"),
                showlegend=False, hoverinfo="skip",
            ))
    fig2.add_hline(y=200_000, line_dash="dash", line_color="#555",
                   annotation_text="Starting $200K")
    fig2.update_layout(**_chart_layout(yaxis_title="Funds ($) — averaged", height=500))
    fig2.update_yaxes(tickprefix="$", tickformat=",")
    st.plotly_chart(fig2, use_container_width=True, config={"displayModeBar": False})
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Header + metrics (always visible)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -475,6 +652,34 @@ st.markdown("""
|
|||
font-size: 0.75rem; font-weight: 600;">LIVE</span>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-model overview (when no DB arg given)
|
||||
# ---------------------------------------------------------------------------
|
||||
if MULTI_MODEL_MODE:
|
||||
_render_multi_model()
|
||||
st.markdown("---")
|
||||
|
||||
# Let user pick a run for detail view below
|
||||
all_dbs = sorted(Path("db").glob("*.db"))
|
||||
db_options = {p.stem: p for p in all_dbs}
|
||||
if db_options:
|
||||
selected = st.selectbox("Select run for detail view", list(db_options.keys()), index=0)
|
||||
db_file = db_options[selected]
|
||||
# Re-initialize factory and peers for selected DB
|
||||
factory = get_factory(str(db_file))
|
||||
peer_dbs = _find_all_peer_dbs(db_file)
|
||||
peer_factories = [(label, get_factory(str(p))) for label, p in peer_dbs]
|
||||
baseline_db = None
|
||||
baseline_factory = None
|
||||
for label, p in peer_dbs:
|
||||
if "greedy" in p.stem.lower():
|
||||
baseline_db = p
|
||||
baseline_factory = get_factory(str(p))
|
||||
break
|
||||
else:
|
||||
st.warning("No DB files found.")
|
||||
st.stop()
|
||||
|
||||
baseline_label = f' | baseline: <span style="color:#ff4b6e">{baseline_db.name}</span>' if baseline_db else ""
|
||||
st.markdown(f'<div class="db-path">{db_file}{baseline_label}</div>', unsafe_allow_html=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ def task_accept(
|
|||
if is_rat:
|
||||
intensity = abs(client_row.loyalty)
|
||||
inflation = _cfg.scope_creep_max * intensity
|
||||
inflation = max(2.0, inflation)
|
||||
inflation = max(3.0, inflation)
|
||||
for r in reqs:
|
||||
inflated = float(r.required_qty) * (1 + inflation)
|
||||
r.required_qty = int(min(25000, max(200, inflated)))
|
||||
|
|
@ -125,9 +125,14 @@ def task_accept(
|
|||
task.accepted_at = accepted_at
|
||||
task.deadline = deadline
|
||||
|
||||
# Generate replacement task (inherits same client for stable market distribution)
|
||||
counter = sim_state.replenish_counter
|
||||
sim_state.replenish_counter = counter + 1
|
||||
# Generate replacement task — keyed on market_slot so every model
|
||||
# sees the same replacement for the same task, regardless of accept order.
|
||||
slot = task.market_slot if task.market_slot is not None else 0
|
||||
# Generation = how many times this slot has been replaced before
|
||||
generation = db.query(Task).filter(
|
||||
Task.market_slot == slot,
|
||||
Task.company_id.isnot(None), # accepted tasks
|
||||
).count()
|
||||
|
||||
# Find the client index for the accepted task
|
||||
replaced_client_index = 0
|
||||
|
|
@ -147,7 +152,7 @@ def task_accept(
|
|||
|
||||
replacement = generate_replacement_task(
|
||||
run_seed=sim_state.run_seed,
|
||||
replenish_counter=counter,
|
||||
replenish_counter=slot * 1000 + generation, # deterministic per slot+generation
|
||||
replaced_prestige=task.required_prestige,
|
||||
replaced_client_index=replaced_client_index,
|
||||
cfg=_get_world_cfg(),
|
||||
|
|
@ -175,6 +180,7 @@ def task_accept(
|
|||
success=None,
|
||||
progress_milestone_pct=0,
|
||||
required_trust=replacement.required_trust,
|
||||
market_slot=slot,
|
||||
)
|
||||
db.add(replacement_row)
|
||||
|
||||
|
|
|
|||
|
|
@ -99,6 +99,10 @@ class Task(Base):
|
|||
BigInteger,
|
||||
nullable=True,
|
||||
)
|
||||
market_slot = mapped_column(
|
||||
Integer,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
class TaskRequirement(Base):
|
||||
__tablename__ = "task_requirements"
|
||||
|
|
|
|||
|
|
@ -72,6 +72,8 @@ def generate_clients(*, run_seed: int, count: int, cfg: WorldConfig) -> list[Gen
|
|||
if i < n_rats:
|
||||
# RAT: loyalty in [-1.0, -0.3]
|
||||
loyalty = round(rng.uniform(-1.0, -0.3), 3)
|
||||
# RATs offer competitive rewards (top 30% range) to attract greedy agents
|
||||
mult = max(mult, cfg.client_reward_mult_high * 0.75)
|
||||
else:
|
||||
# Non-RAT: loyalty in [-0.3, 1.0]
|
||||
loyalty = round(rng.triangular(-0.3, 1.0, cfg.loyalty_mode), 3)
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ def _seed_market_tasks(db, company, req, clients):
|
|||
generated = generate_tasks(run_seed=req.run_seed, count=req.market_task_count, cfg=req.cfg,
|
||||
client_specialties=client_specialties,
|
||||
client_reward_mults=client_reward_mults)
|
||||
for task in generated:
|
||||
for slot_idx, task in enumerate(generated):
|
||||
client = clients[task.client_index % len(clients)] if clients else None
|
||||
task_row = Task(
|
||||
id=uuid4(),
|
||||
|
|
@ -128,6 +128,7 @@ def _seed_market_tasks(db, company, req, clients):
|
|||
success=None,
|
||||
progress_milestone_pct=0,
|
||||
required_trust=task.required_trust,
|
||||
market_slot=slot_idx,
|
||||
)
|
||||
db.add(task_row)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue