fix seeding

This commit is contained in:
alckasoc 2026-03-20 18:43:19 -07:00
parent f76f5be652
commit b043b690c3
5 changed files with 233 additions and 15 deletions

View file

@ -1,7 +1,8 @@
"""Streamlit dashboard — live-monitor a YC-Bench run.
"""Streamlit dashboard — live-monitor YC-Bench runs.
Usage:
uv run streamlit run scripts/watch_dashboard.py -- db/medium_1_gemini_gemini-3-flash-preview.db
uv run streamlit run scripts/watch_dashboard.py # multi-model overview (auto-discovers db/)
uv run streamlit run scripts/watch_dashboard.py -- db/medium_1_model.db # single-run detail
Automatically overlays the greedy bot baseline if a matching *_greedy_bot.db exists
in the same directory (e.g. db/medium_1_greedy_bot.db).
@ -16,6 +17,11 @@ from datetime import datetime, timedelta
from decimal import Decimal
from pathlib import Path
import re
import sqlite3
from collections import defaultdict
import numpy as np
import streamlit as st
import pandas as pd
import plotly.graph_objects as go
@ -114,12 +120,19 @@ for a in args:
db_path = a
break
if db_path is None:
st.error("Pass a .db path: `uv run streamlit run scripts/watch_dashboard.py -- path/to.db`")
st.stop()
MULTI_MODEL_MODE = db_path is None
db_file = Path(db_path)
if not db_file.exists():
if MULTI_MODEL_MODE:
# Auto-discover all DBs — show multi-model overview
_all_dbs = sorted(Path("db").glob("*.db"))
if not _all_dbs:
st.error("No DB files found in db/. Run some experiments first.")
st.stop()
# Use first DB to bootstrap config detection
db_file = _all_dbs[0]
else:
db_file = Path(db_path)
if not db_file.exists():
st.error(f"DB not found: {db_file}")
st.stop()
@ -464,6 +477,170 @@ def _load_transcript(primary_db: Path) -> list[dict]:
return []
# ---------------------------------------------------------------------------
# Multi-model overview (when no DB arg given)
# ---------------------------------------------------------------------------
MODEL_COLORS = {
"gpt-5.4": "#4da6ff",
"gpt-5.4-nano": "#ff8c42",
"gpt-5.4-mini": "#ffd43b",
"gemini-3.1-pro-preview": "#00d4aa",
"gemini-3-flash-preview": "#b197fc",
"claude-sonnet-4-6": "#e599f7",
"greedy_bot": "#ff4b6e",
}
def _model_color(label: str) -> str:
for key, color in MODEL_COLORS.items():
if key in label:
return color
return "#8b8d93"
def _parse_db_stem(stem: str) -> tuple[str, int, str]:
"""Parse 'medium_1_openai_gpt-5.4' -> (config, seed, model_label)."""
parts = stem.split("_", 2)
if len(parts) < 3:
return stem, 0, stem
config = parts[0]
try:
seed = int(parts[1])
except ValueError:
return stem, 0, stem
raw = parts[2]
label = re.sub(r"^(openai|gemini|anthropic)_", "", raw)
return config, seed, label
def _read_db_summary(db_path: Path) -> dict | None:
"""Read key metrics from a DB via raw sqlite3 (no ORM)."""
conn = sqlite3.connect(str(db_path))
try:
row = conn.execute("SELECT sim_time, horizon_end FROM sim_state LIMIT 1").fetchone()
co = conn.execute("SELECT funds_cents FROM companies LIMIT 1").fetchone()
if not row or not co:
return None
ok = conn.execute("SELECT COUNT(*) FROM tasks WHERE status='completed_success'").fetchone()[0]
fail = conn.execute("SELECT COUNT(*) FROM tasks WHERE status='completed_fail'").fetchone()[0]
ledger = conn.execute(
"SELECT occurred_at, amount_cents FROM ledger_entries ORDER BY occurred_at"
).fetchall()
running = 20_000_000 # default starting funds
funds_by_day = {}
for occ, amt in ledger:
running += amt
funds_by_day[occ[:10]] = running
sim_time, horizon_end = row
funds = co[0]
return {
"sim_time": sim_time[:10],
"funds": funds / 100,
"ok": ok, "fail": fail,
"done": sim_time >= horizon_end or funds < 0,
"bankrupt": funds < 0,
"funds_by_day": {k: v / 100 for k, v in funds_by_day.items()},
}
except Exception:
return None
finally:
conn.close()
def _render_multi_model():
"""Render the multi-model overview dashboard."""
all_dbs = sorted(Path("db").glob("*.db"))
if not all_dbs:
st.warning("No DB files in db/")
return
model_runs = defaultdict(list)
for p in all_dbs:
config, seed, label = _parse_db_stem(p.stem)
data = _read_db_summary(p)
if data:
model_runs[label].append((seed, p, data))
# --- Per-seed funds curves ---
st.markdown('<div class="section-header">Funds Over Time (per seed)</div>', unsafe_allow_html=True)
fig1 = go.Figure()
for label in sorted(model_runs.keys()):
color = _model_color(label)
for seed, _, data in sorted(model_runs[label], key=lambda x: x[0]):
fbd = data.get("funds_by_day", {})
if not fbd:
continue
days = sorted(fbd.keys())
vals = [fbd[d] for d in days]
fig1.add_trace(go.Scatter(
x=days, y=vals, mode="lines",
name=f"{label} (s{seed})",
line=dict(color=color, width=1.5, dash="dot" if seed > 1 else "solid"),
opacity=0.7,
))
fig1.add_hline(y=200_000, line_dash="dash", line_color="#555",
annotation_text="Starting $200K")
fig1.update_layout(**_chart_layout(yaxis_title="Funds ($)", height=500))
fig1.update_yaxes(tickprefix="$", tickformat=",")
st.plotly_chart(fig1, use_container_width=True, config={"displayModeBar": False})
# --- Averaged funds curves ---
st.markdown('<div class="section-header">Funds Over Time (averaged across seeds)</div>', unsafe_allow_html=True)
fig2 = go.Figure()
for label in sorted(model_runs.keys()):
color = _model_color(label)
runs = model_runs[label]
all_days = set()
series = []
for seed, _, data in runs:
fbd = data.get("funds_by_day", {})
if fbd:
all_days.update(fbd.keys())
series.append(fbd)
if not all_days or not series:
continue
common_days = sorted(all_days)
aligned = []
for s in series:
s_days = sorted(s.keys())
if not s_days:
continue
vals, last_val, si = [], 200_000, 0
for d in common_days:
while si < len(s_days) and s_days[si] <= d:
last_val = s[s_days[si]]
si += 1
vals.append(last_val)
aligned.append(vals)
if not aligned:
continue
arr = np.array(aligned)
mean = arr.mean(axis=0)
fig2.add_trace(go.Scatter(
x=common_days, y=mean, mode="lines",
name=f"{label} (n={len(aligned)})",
line=dict(color=color, width=3),
))
if len(aligned) > 1:
lo, hi = arr.min(axis=0), arr.max(axis=0)
# Convert hex color to rgba for fill
_r, _g, _b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)
fig2.add_trace(go.Scatter(
x=list(common_days) + list(common_days)[::-1],
y=list(hi) + list(lo[::-1]),
fill="toself", fillcolor=f"rgba({_r},{_g},{_b},0.1)",
line=dict(color="rgba(0,0,0,0)"),
showlegend=False, hoverinfo="skip",
))
fig2.add_hline(y=200_000, line_dash="dash", line_color="#555",
annotation_text="Starting $200K")
fig2.update_layout(**_chart_layout(yaxis_title="Funds ($) — averaged", height=500))
fig2.update_yaxes(tickprefix="$", tickformat=",")
st.plotly_chart(fig2, use_container_width=True, config={"displayModeBar": False})
# ---------------------------------------------------------------------------
# Header + metrics (always visible)
# ---------------------------------------------------------------------------
@ -475,6 +652,34 @@ st.markdown("""
font-size: 0.75rem; font-weight: 600;">LIVE</span>
</div>
""", unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Multi-model overview (when no DB arg given)
# ---------------------------------------------------------------------------
if MULTI_MODEL_MODE:
_render_multi_model()
st.markdown("---")
# Let user pick a run for detail view below
all_dbs = sorted(Path("db").glob("*.db"))
db_options = {p.stem: p for p in all_dbs}
if db_options:
selected = st.selectbox("Select run for detail view", list(db_options.keys()), index=0)
db_file = db_options[selected]
# Re-initialize factory and peers for selected DB
factory = get_factory(str(db_file))
peer_dbs = _find_all_peer_dbs(db_file)
peer_factories = [(label, get_factory(str(p))) for label, p in peer_dbs]
baseline_db = None
baseline_factory = None
for label, p in peer_dbs:
if "greedy" in p.stem.lower():
baseline_db = p
baseline_factory = get_factory(str(p))
break
else:
st.warning("No DB files found.")
st.stop()
baseline_label = f' &nbsp;|&nbsp; baseline: <span style="color:#ff4b6e">{baseline_db.name}</span>' if baseline_db else ""
st.markdown(f'<div class="db-path">{db_file}{baseline_label}</div>', unsafe_allow_html=True)

View file

@ -114,7 +114,7 @@ def task_accept(
if is_rat:
intensity = abs(client_row.loyalty)
inflation = _cfg.scope_creep_max * intensity
inflation = max(2.0, inflation)
inflation = max(3.0, inflation)
for r in reqs:
inflated = float(r.required_qty) * (1 + inflation)
r.required_qty = int(min(25000, max(200, inflated)))
@ -125,9 +125,14 @@ def task_accept(
task.accepted_at = accepted_at
task.deadline = deadline
# Generate replacement task (inherits same client for stable market distribution)
counter = sim_state.replenish_counter
sim_state.replenish_counter = counter + 1
# Generate replacement task — keyed on market_slot so every model
# sees the same replacement for the same task, regardless of accept order.
slot = task.market_slot if task.market_slot is not None else 0
# Generation = how many times this slot has been replaced before
generation = db.query(Task).filter(
Task.market_slot == slot,
Task.company_id.isnot(None), # accepted tasks
).count()
# Find the client index for the accepted task
replaced_client_index = 0
@ -147,7 +152,7 @@ def task_accept(
replacement = generate_replacement_task(
run_seed=sim_state.run_seed,
replenish_counter=counter,
replenish_counter=slot * 1000 + generation, # deterministic per slot+generation
replaced_prestige=task.required_prestige,
replaced_client_index=replaced_client_index,
cfg=_get_world_cfg(),
@ -175,6 +180,7 @@ def task_accept(
success=None,
progress_milestone_pct=0,
required_trust=replacement.required_trust,
market_slot=slot,
)
db.add(replacement_row)

View file

@ -99,6 +99,10 @@ class Task(Base):
BigInteger,
nullable=True,
)
market_slot = mapped_column(
Integer,
nullable=True,
)
class TaskRequirement(Base):
__tablename__ = "task_requirements"

View file

@ -72,6 +72,8 @@ def generate_clients(*, run_seed: int, count: int, cfg: WorldConfig) -> list[Gen
if i < n_rats:
# RAT: loyalty in [-1.0, -0.3]
loyalty = round(rng.uniform(-1.0, -0.3), 3)
# RATs offer competitive rewards (top 30% range) to attract greedy agents
mult = max(mult, cfg.client_reward_mult_high * 0.75)
else:
# Non-RAT: loyalty in [-0.3, 1.0]
loyalty = round(rng.triangular(-0.3, 1.0, cfg.loyalty_mode), 3)

View file

@ -110,7 +110,7 @@ def _seed_market_tasks(db, company, req, clients):
generated = generate_tasks(run_seed=req.run_seed, count=req.market_task_count, cfg=req.cfg,
client_specialties=client_specialties,
client_reward_mults=client_reward_mults)
for task in generated:
for slot_idx, task in enumerate(generated):
client = clients[task.client_index % len(clients)] if clients else None
task_row = Task(
id=uuid4(),
@ -128,6 +128,7 @@ def _seed_market_tasks(db, company, req, clients):
success=None,
progress_milestone_pct=0,
required_trust=task.required_trust,
market_slot=slot_idx,
)
db.add(task_row)