# Mirror of https://github.com/collinear-ai/yc-bench.git
# Synced 2026-04-22 16:49:02 +00:00 (543 lines, 22 KiB, Python)
"""Bot runner: plays YC-Bench using direct DB access with pluggable strategies.

Strategies:
  greedy      — pick highest reward among accessible tasks
  random      — pick randomly among accessible tasks (deterministic via RngStreams)
  throughput  — pick highest reward/hour among accessible tasks
  prestige    — phase 1: climb prestige fast, phase 2: throughput

The bot operates under the same constraints as the LLM agent:
  - Same prestige/trust gating on market tasks (the bot reads the whole
    market straight from the DB, so the CLI browse limit does not apply)
  - Same economic rules (trust multiplier, work reduction, payroll, salary bumps)
  - Runs at most MAX_CONCURRENT_TASKS tasks at once (currently 1, i.e. one
    task at a time)
  - Must have active tasks before time advances (same as LLM sim resume block)

Usage:
  uv run python scripts/bot_runner.py                                  # all bots, all configs, all seeds
  uv run python scripts/bot_runner.py --bot greedy                     # just greedy
  uv run python scripts/bot_runner.py --bot random --seed 1 --config medium
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from decimal import Decimal
|
|
from pathlib import Path
|
|
from typing import Callable, Optional
|
|
from uuid import uuid4
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
|
|
from yc_bench.config import load_config
|
|
from yc_bench.core.business_time import add_business_hours
|
|
from yc_bench.core.engine import advance_time
|
|
from yc_bench.core.eta import recalculate_etas
|
|
from yc_bench.core.events import fetch_next_event, insert_event
|
|
from yc_bench.db.models.company import Company, CompanyPrestige
|
|
from yc_bench.db.models.employee import Employee, EmployeeSkillRate
|
|
from yc_bench.db.models.event import EventType
|
|
from yc_bench.db.models.sim_state import SimState
|
|
from yc_bench.db.models.task import Task, TaskAssignment, TaskRequirement, TaskStatus
|
|
from yc_bench.db.session import build_engine, build_session_factory, init_db, session_scope
|
|
from yc_bench.services.generate_tasks import generate_replacement_task
|
|
from yc_bench.services.rng import RngStreams
|
|
from yc_bench.services.seed_world import SeedWorldRequest, seed_world_transactional
|
|
|
|
# Experiment configs and seeds that are swept when the CLI does not
# narrow the selection.
CONFIGS = [
    "medium",
    "hard",
    "nightmare",
]
SEEDS = [
    1,
    2,
    3,
]

# The baseline deliberately works on a single task at a time: plain
# sequential selection with no workload management. It is the
# "zero strategy" floor that any competent LLM agent should beat.
MAX_CONCURRENT_TASKS = 1
|
|
|
|
|
|
@dataclass
class CandidateTask:
    """A market task the bot could accept, plus the figures strategies rank by."""

    # ORM Task row backing this candidate.
    task: object
    # Funds reward of the task, in cents.
    reward_cents: int
    # Prestige change granted by the task.
    prestige_delta: float
    # Estimated hours to finish, derived from tier-average rates.
    completion_hours: Decimal
    # Whether the task counts as completable for the bot.
    is_completable: bool
|
|
|
|
|
|
# Expected per-domain rate for each tier: E[uniform(0, max_rate)] = max_rate / 2.
# The LLM agent only sees tier + salary, never the real per-domain rates,
# so the bot estimates with the same limited knowledge.
_TIER_AVG_RATE = {
    "junior": Decimal("2.0"),  # mean of uniform(0, 4)
    "mid": Decimal("3.5"),  # mean of uniform(0, 7)
    "senior": Decimal("5.0"),  # mean of uniform(0, 10)
}


def estimate_completion_hours(task_reqs, employee_tiers, n_concurrent_tasks=1):
    """Estimate the hours a task needs, using only tier-average rates.

    Args:
        task_reqs: Iterable of mappings that each carry a ``"required_qty"``.
        employee_tiers: Tier names (keys of ``_TIER_AVG_RATE``), one per
            employee. Every employee is assumed to contribute their tier's
            average rate to every domain.
        n_concurrent_tasks: Number of tasks the pooled rate is split across.

    Returns:
        The largest per-requirement hour figure as a ``Decimal`` (``0`` when
        there are no requirements), or ``None`` when the effective rate is
        not positive (for example, no employees).
    """
    pooled_rate = sum(_TIER_AVG_RATE[tier] for tier in employee_tiers)
    per_task_rate = pooled_rate / Decimal(n_concurrent_tasks)
    if per_task_rate <= 0:
        return None

    # The slowest domain determines when the whole task is finished.
    hours_by_requirement = [
        Decimal(str(req["required_qty"])) / per_task_rate for req in task_reqs
    ]
    return max([Decimal("0"), *hours_by_requirement])
|
|
|
|
|
|
def _compute_deadline(accepted_at, max_domain_qty, cfg):
    """Return the deadline for a task accepted at ``accepted_at``.

    The allowance in business days is the larger of ``cfg.deadline_min_biz_days``
    and ``max_domain_qty / cfg.deadline_qty_per_day`` truncated to an int. It is
    converted to hours using the configured workday length and added to
    ``accepted_at`` via ``add_business_hours``.
    """
    hours_per_workday = cfg.workday_end_hour - cfg.workday_start_hour
    allotted_days = max(
        cfg.deadline_min_biz_days,
        int(max_domain_qty / cfg.deadline_qty_per_day),
    )
    hour_budget = Decimal(str(allotted_days)) * Decimal(str(hours_per_workday))
    return add_business_hours(accepted_at, hour_budget)
|
|
|
|
|
|
def _build_candidates(db, company_id, sim_state, world_cfg, employee_tiers, n_active=0):
    """Build the CandidateTask list for every market task the company may accept.

    The bot queries the database directly, so it sees the *full* market
    (ordered by reward, highest first) rather than the limited browse window
    the LLM agent gets through the CLI. The remaining constraints are applied
    here:

    - Prestige gating: every domain required by the task must have a company
      prestige level of at least ``task.required_prestige`` (domains with no
      prestige row count as 1.0).
    - Trust gating: a task with ``required_trust > 0`` and a client is skipped
      when the company's trust with that client is below the requirement
      (no trust row counts as 0.0).
    - Completion time is estimated from tier-average rates only (blind to
      actual per-domain skills).

    Args:
        db: Open ORM session.
        company_id: Company whose prestige and trust are checked.
        sim_state: Unused; kept for call compatibility.
        world_cfg: Unused; kept for call compatibility.
        employee_tiers: Tier names of the company's employees.
        n_active: Number of tasks already active; the estimate assumes the
            workforce is split across ``n_active + 1`` tasks.

    Returns:
        ``(candidates, max_prestige)`` where ``max_prestige`` is the highest
        prestige level across the company's domains (1.0 if it has none).
    """
    from yc_bench.db.models.client import ClientTrust

    prestige_rows = db.query(CompanyPrestige).filter(
        CompanyPrestige.company_id == company_id
    ).all()
    prestige_map = {p.domain: float(p.prestige_level) for p in prestige_rows}
    max_prestige = max(prestige_map.values()) if prestige_map else 1.0

    # Build trust map for trust requirement checks
    trust_rows = db.query(ClientTrust).filter(
        ClientTrust.company_id == company_id
    ).all()
    trust_map = {str(ct.client_id): float(ct.trust_level) for ct in trust_rows}

    # Browse full market — bot has direct DB access, no CLI browse limit.
    # The LLM agent has its own browse limit via the CLI.
    market_tasks = (
        db.query(Task)
        .filter(Task.status == TaskStatus.MARKET)
        .order_by(Task.reward_funds_cents.desc())
        .all()
    )

    # Both values are the same for every task, so compute them once.
    concurrent = max(1, n_active + 1)
    unknown_hours = Decimal("999999")  # sentinel when no estimate is possible

    candidates = []
    for task in market_tasks:
        reqs = db.query(TaskRequirement).filter(
            TaskRequirement.task_id == task.id
        ).all()

        # Per-domain prestige check: all required domains must meet threshold
        meets_prestige = all(
            prestige_map.get(r.domain, 1.0) >= task.required_prestige
            for r in reqs
        )
        if not meets_prestige:
            continue

        # Trust requirement check (same validation as CLI task accept)
        if task.required_trust > 0 and task.client_id is not None:
            client_trust = trust_map.get(str(task.client_id), 0.0)
            if client_trust < task.required_trust:
                continue

        task_reqs = [{"domain": r.domain, "required_qty": float(r.required_qty)} for r in reqs]
        # Estimate hours accounting for concurrent task split
        completion_hours = estimate_completion_hours(
            task_reqs, employee_tiers, n_concurrent_tasks=concurrent
        )

        candidates.append(CandidateTask(
            task=task,
            reward_cents=task.reward_funds_cents,
            prestige_delta=float(task.reward_prestige_delta),
            completion_hours=completion_hours if completion_hours is not None else unknown_hours,
            is_completable=True,  # Always accessible = always a candidate
        ))

    return candidates, max_prestige
|
|
|
|
|
|
# ── Strategy functions ──────────────────────────────────────────────────────
|
|
|
|
# Signature shared by every strategy function:
#   (candidates: list[CandidateTask], context: dict) -> Optional[CandidateTask]
# The runner fills ``context`` with the keys "seed", "turn" and "max_prestige".
StrategyFn = Callable
|
|
|
|
|
|
def strategy_greedy(candidates: list[CandidateTask], context: dict) -> Optional[CandidateTask]:
    """Return the candidate with the largest reward, or None if there are none.

    When several candidates share the top reward, the earliest one wins.
    ``context`` is accepted for signature compatibility and is not used.
    """
    if not candidates:
        return None

    best = None
    for cand in candidates:
        # Strict comparison keeps the first of any equally rewarded tasks.
        if best is None or cand.reward_cents > best.reward_cents:
            best = cand
    return best
|
|
|
|
|
|
def strategy_random(candidates: list[CandidateTask], context: dict) -> Optional[CandidateTask]:
    """Return one accessible task chosen by a seeded RNG stream.

    The stream is derived from ``context["seed"]`` and named after
    ``context["turn"]``, so the same seed and turn always use the same stream.
    Returns None when there are no candidates.
    """
    if candidates:
        seed, turn = context["seed"], context["turn"]
        picker = RngStreams(seed).stream(f"bot_random_select:{turn}")
        return picker.choice(candidates)
    return None
|
|
|
|
|
|
def strategy_throughput(candidates: list[CandidateTask], context: dict) -> Optional[CandidateTask]:
    """Pick the task with the highest reward per estimated hour.

    Args:
        candidates: Accessible tasks to choose from.
        context: Unused; accepted for signature compatibility.

    Returns:
        The best candidate (the earliest one on ties), or None when
        ``candidates`` is empty.
    """
    if not candidates:
        return None

    def reward_per_hour(cand):
        # An estimate of zero hours (e.g. a task with no remaining work)
        # would make the Decimal division raise DivisionByZero. Such a task
        # is effectively free, so rank it above any finite rate instead.
        if cand.completion_hours <= 0:
            return Decimal("Infinity")
        return Decimal(cand.reward_cents) / cand.completion_hours

    return max(candidates, key=reward_per_hour)
|
|
|
|
|
|
def strategy_prestige(candidates: list[CandidateTask], context: dict) -> Optional[CandidateTask]:
    """Phase 1 (prestige < 5): climb prestige fast. Phase 2: throughput.

    While ``context["max_prestige"]`` is below 5, tasks with a positive
    prestige delta are ranked by prestige gained per estimated hour. If no
    such task exists, or once prestige reaches 5, tasks are ranked by reward
    per estimated hour.

    Args:
        candidates: Accessible tasks to choose from.
        context: Must contain ``"max_prestige"``.

    Returns:
        The best candidate (the earliest one on ties), or None when
        ``candidates`` is empty.
    """
    if not candidates:
        return None

    def per_hour(amount, cand):
        # A zero-hour estimate would make the Decimal division raise
        # DivisionByZero; rank such a task above any finite rate instead.
        if cand.completion_hours <= 0:
            return Decimal("Infinity")
        return amount / cand.completion_hours

    current_prestige = context["max_prestige"]
    if current_prestige < 5:
        prestige_tasks = [c for c in candidates if c.prestige_delta > 0]
        if prestige_tasks:
            return max(
                prestige_tasks,
                key=lambda c: per_hour(Decimal(str(c.prestige_delta)), c),
            )
    return max(candidates, key=lambda c: per_hour(Decimal(c.reward_cents), c))
|
|
|
|
|
|
# One row per bot: CLI name, slug used in DB/result file names, strategy function.
_STRATEGY_TABLE = (
    ("greedy", "greedy_bot", strategy_greedy),
    ("random", "random_bot", strategy_random),
    ("throughput", "throughput_bot", strategy_throughput),
    ("prestige", "prestige_bot", strategy_prestige),
)

# Maps CLI bot name -> (slug, strategy function), in the order listed above.
STRATEGIES = {
    cli_name: (slug, strategy)
    for cli_name, slug, strategy in _STRATEGY_TABLE
}
|
|
|
|
|
|
# ── Shared simulation runner ───────────────────────────────────────────────
|
|
|
|
def _accept_task(db, task, company_id, sim_state, world_cfg):
    """Claim a market task for the company: reduce work by trust, mark PLANNED, set deadline."""
    from yc_bench.db.models.client import ClientTrust

    reqs = db.query(TaskRequirement).filter(
        TaskRequirement.task_id == task.id
    ).all()

    # Apply trust work reduction (no reward multiplier)
    if task.client_id is not None:
        ct = db.query(ClientTrust).filter(
            ClientTrust.company_id == company_id,
            ClientTrust.client_id == task.client_id,
        ).one_or_none()
        trust_level = float(ct.trust_level) if ct else 0.0
        work_reduction = world_cfg.trust_work_reduction_max * (trust_level / world_cfg.trust_max)
        for r in reqs:
            r.required_qty = int(float(r.required_qty) * (1 - work_reduction))

    # The deadline is derived from the largest (already reduced) requirement.
    max_domain_qty = max(float(r.required_qty) for r in reqs)

    task.status = TaskStatus.PLANNED
    task.company_id = company_id
    task.accepted_at = sim_state.sim_time
    task.deadline = _compute_deadline(sim_state.sim_time, max_domain_qty, world_cfg)


def _add_replacement_task(db, task, sim_state, world_cfg):
    """Insert a new MARKET task (with requirements) to replace the accepted ``task``."""
    from yc_bench.db.models.client import Client as ClientModel

    counter = sim_state.replenish_counter
    sim_state.replenish_counter = counter + 1

    # A single name-ordered query serves the index lookup, the specialty
    # lookup and the choice of the replacement's client.
    clients = db.query(ClientModel).order_by(ClientModel.name).all()

    replaced_client_index = 0
    replacement_spec_domains = None
    if task.client_id is not None:
        for i, client in enumerate(clients):
            if client.id == task.client_id:
                replaced_client_index = i
                replacement_spec_domains = client.specialty_domains
                break

    replacement = generate_replacement_task(
        run_seed=sim_state.run_seed,
        replenish_counter=counter,
        replaced_prestige=task.required_prestige,
        replaced_client_index=replaced_client_index,
        cfg=world_cfg,
        specialty_domains=replacement_spec_domains,
    )

    replacement_client = clients[replacement.client_index % len(clients)] if clients else None
    replacement_client_id = replacement_client.id if replacement_client else None

    replacement_row = Task(
        id=uuid4(),
        company_id=None,
        client_id=replacement_client_id,
        status=TaskStatus.MARKET,
        title=replacement.title,
        required_prestige=replacement.required_prestige,
        reward_funds_cents=replacement.reward_funds_cents,
        reward_prestige_delta=replacement.reward_prestige_delta,
        skill_boost_pct=replacement.skill_boost_pct,
        accepted_at=None, deadline=None, completed_at=None,
        success=None, progress_milestone_pct=0,
        required_trust=replacement.required_trust,
    )
    db.add(replacement_row)
    for domain, qty in replacement.requirements.items():
        db.add(TaskRequirement(
            task_id=replacement_row.id,
            domain=domain,
            required_qty=qty,
            completed_qty=0,
        ))


def _assign_all_and_activate(db, task, employee_ids, sim_state):
    """Assign every employee to ``task`` and switch it to ACTIVE."""
    for eid in employee_ids:
        db.add(TaskAssignment(
            task_id=task.id,
            employee_id=eid,
            assigned_at=sim_state.sim_time,
        ))
    db.flush()

    task.status = TaskStatus.ACTIVE
    db.flush()


def _tally_task_outcomes(wake_events):
    """Return ``(succeeded, failed)`` counts of ``task_completed`` wake events."""
    succeeded = 0
    failed = 0
    for we in wake_events:
        if we.get("type") == "task_completed":
            if we.get("success"):
                succeeded += 1
            else:
                failed += 1
    return succeeded, failed


def run_bot(config_name: str, seed: int, bot_slug: str, strategy_fn: StrategyFn):
    """Run a bot strategy on one (config, seed) pair. Returns result dict.

    Creates a fresh SQLite database under ``db/``, seeds the world, then loops:
    accept at most one task per turn (while fewer than MAX_CONCURRENT_TASKS are
    active), assign all employees to it, and advance time to the next event.
    The loop ends on negative funds, on reaching the horizon, or when no
    further event exists. Finally a result JSON is written under ``results/``.

    Side effects: sets the ``DATABASE_URL`` and ``YC_BENCH_EXPERIMENT``
    environment variables, deletes any existing DB file for this run, and
    writes the DB and result files.

    Args:
        config_name: Name passed to ``load_config``.
        seed: Run seed for world generation and seeded strategies.
        bot_slug: Identifier used in file names and as the reported model name.
        strategy_fn: Called as ``strategy_fn(candidates, context)``; returns the
            chosen CandidateTask or None to accept nothing this turn.

    Returns:
        Dict with keys ``config``, ``seed``, ``bot``, ``turns``,
        ``final_balance_cents``, ``bankrupt``, ``tasks_completed``,
        ``tasks_failed``, ``max_prestige`` and ``result_path``.
    """
    cfg = load_config(config_name)
    world_cfg = cfg.world

    db_dir = Path("db")
    db_dir.mkdir(exist_ok=True)
    db_path = db_dir / f"{config_name}_{seed}_{bot_slug}.db"

    # Every run starts from an empty database file.
    if db_path.exists():
        db_path.unlink()

    db_url = f"sqlite:///{db_path}"
    os.environ["DATABASE_URL"] = db_url
    os.environ["YC_BENCH_EXPERIMENT"] = config_name

    engine = build_engine(db_url)
    init_db(engine)
    factory = build_session_factory(engine)

    with session_scope(factory) as db:
        start_dt = datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
        horizon_end = start_dt.replace(year=start_dt.year + cfg.sim.horizon_years)

        req = SeedWorldRequest(
            run_seed=seed,
            company_name=bot_slug.replace("_", " ").title(),
            horizon_years=cfg.sim.horizon_years,
            employee_count=world_cfg.num_employees,
            market_task_count=world_cfg.num_market_tasks,
            cfg=world_cfg,
            start_date=start_dt,
        )
        result = seed_world_transactional(db, req)
        company_id = result.company_id

        insert_event(
            db=db,
            company_id=company_id,
            event_type=EventType.HORIZON_END,
            scheduled_at=horizon_end,
            payload={"reason": "horizon_end"},
            dedupe_key="horizon_end",
        )

        sim_state = SimState(
            company_id=company_id,
            sim_time=start_dt,
            run_seed=seed,
            horizon_end=horizon_end,
            replenish_counter=0,
        )
        db.add(sim_state)
        db.flush()

    tasks_completed = 0
    tasks_failed = 0
    turn = 0

    while True:
        turn += 1

        with session_scope(factory) as db:
            sim_state = db.query(SimState).first()
            company = db.query(Company).filter(Company.id == company_id).one()

            if company.funds_cents < 0:
                break
            if sim_state.sim_time >= sim_state.horizon_end:
                break

            active_count = db.query(Task).filter(
                Task.company_id == company_id,
                Task.status == TaskStatus.ACTIVE,
            ).count()

            # Accept up to 1 new task per turn (same pace as LLM agent).
            # The LLM spends multiple tool calls to browse/accept/assign/dispatch
            # one task, so it effectively accepts ~1 per turn.
            newly_accepted = []
            while active_count + len(newly_accepted) < MAX_CONCURRENT_TASKS and len(newly_accepted) < 1:
                employees = db.query(Employee).filter(Employee.company_id == company_id).all()
                employee_tiers = [emp.tier for emp in employees]
                employee_ids = [emp.id for emp in employees]

                n_will_be_active = active_count + len(newly_accepted)
                candidates, max_prestige = _build_candidates(
                    db, company_id, sim_state, world_cfg, employee_tiers,
                    n_active=n_will_be_active,
                )

                context = {
                    "seed": seed,
                    "turn": turn + len(newly_accepted),  # vary context per pick
                    "max_prestige": max_prestige,
                }
                chosen = strategy_fn(candidates, context)
                if chosen is None:
                    break

                task = chosen.task
                newly_accepted.append(task.id)

                # Accept the task — same logic as CLI task accept
                _accept_task(db, task, company_id, sim_state, world_cfg)
                # Generate replacement
                _add_replacement_task(db, task, sim_state, world_cfg)
                # Assign ALL employees to this task
                _assign_all_and_activate(db, task, employee_ids, sim_state)

            # Recalculate ETAs for all newly accepted tasks
            if newly_accepted:
                recalculate_etas(db, company_id, sim_state.sim_time,
                                 impacted_task_ids=set(newly_accepted),
                                 milestones=world_cfg.task_progress_milestones)

            # Advance to the next event. With no active task this simply
            # lets prestige/trust change before the next attempt.
            total_active = active_count + len(newly_accepted)
            next_event = fetch_next_event(db, company_id, sim_state.horizon_end)
            if next_event is None:
                break
            adv = advance_time(db, company_id, next_event.scheduled_at)

            # Outcomes are only tallied on turns that had active work.
            if total_active > 0:
                succeeded, failed = _tally_task_outcomes(adv.wake_events)
                tasks_completed += succeeded
                tasks_failed += failed

            if adv.bankrupt or adv.horizon_reached:
                break

    # Final state + extract time series for plotting
    from yc_bench.runner.extract import extract_time_series
    import json

    with session_scope(factory) as db:
        company = db.query(Company).filter(Company.id == company_id).one()

        final_balance = company.funds_cents
        bankrupt = final_balance < 0

        prestige_rows = db.query(CompanyPrestige).filter(
            CompanyPrestige.company_id == company_id
        ).all()
        max_p = max((float(p.prestige_level) for p in prestige_rows), default=1.0)

    time_series = extract_time_series(lambda: session_scope(factory), company_id)

    # Write result JSON (same format as LLM runner for plot compatibility)
    result_json = {
        "session_id": f"bot-{seed}-{bot_slug}",
        "model": bot_slug,
        "seed": seed,
        "horizon_years": cfg.sim.horizon_years,
        "turns_completed": turn,
        "terminal": True,
        "terminal_reason": "bankrupt" if bankrupt else "horizon_end",
        "terminal_detail": "bankrupt" if bankrupt else "horizon_end",
        "total_cost_usd": 0,
        "time_series": time_series,
    }
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)
    result_path = results_dir / f"yc_bench_result_{config_name}_{seed}_{bot_slug}.json"
    with open(result_path, "w", encoding="utf-8") as f:
        json.dump(result_json, f, indent=2)

    return {
        "config": config_name,
        "seed": seed,
        "bot": bot_slug,
        "turns": turn,
        "final_balance_cents": final_balance,
        "bankrupt": bankrupt,
        "tasks_completed": tasks_completed,
        "tasks_failed": tasks_failed,
        "max_prestige": max_p,
        "result_path": str(result_path),
    }
|
|
|
|
|
|
def main():
    """Parse CLI filters, run every selected (bot, config, seed) combination, print a summary.

    Each of ``--bot``, ``--config`` and ``--seed`` narrows the sweep to one
    value; any that is omitted runs all of STRATEGIES / CONFIGS / SEEDS.
    """
    parser = argparse.ArgumentParser(description="Run YC-Bench bot strategies")
    parser.add_argument("--bot", choices=list(STRATEGIES.keys()), default=None,
                        help="Run only this bot (default: all)")
    parser.add_argument("--config", choices=CONFIGS, default=None,
                        help="Run only this config (default: all)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Run only this seed (default: all)")
    args = parser.parse_args()

    # Compare against None rather than relying on truthiness: ``--seed 0`` is
    # a valid request and must not fall back to the full seed list.
    bots = [args.bot] if args.bot is not None else list(STRATEGIES.keys())
    configs = [args.config] if args.config is not None else CONFIGS
    seeds = [args.seed] if args.seed is not None else SEEDS

    def format_balance(r):
        # Shared by the per-run line and the summary table.
        return "BANKRUPT" if r["bankrupt"] else f"${r['final_balance_cents']/100:,.0f}"

    results = []
    total = len(bots) * len(configs) * len(seeds)
    print(f"Running {total} bot simulations...\n")

    for bot_name in bots:
        slug, strategy_fn = STRATEGIES[bot_name]
        for config_name in configs:
            for seed in seeds:
                print(f"  {slug} | {config_name} seed={seed} ...", end=" ", flush=True)
                r = run_bot(config_name, seed, slug, strategy_fn)
                results.append(r)

                tag = format_balance(r)
                print(f"{tag} | {r['tasks_completed']} OK, {r['tasks_failed']} fail | prestige {r['max_prestige']:.1f} | {r['turns']} turns")

    print(f"\n{'Bot':<16} {'Config':<12} {'Seed':<5} {'Final Balance':>14} {'OK':>4} {'Fail':>5} {'Prestige':>9}")
    print("-" * 70)
    for r in results:
        fb = format_balance(r)
        print(f"{r['bot']:<16} {r['config']:<12} {r['seed']:<5} {fb:>14} {r['tasks_completed']:>4} {r['tasks_failed']:>5} {r['max_prestige']:>8.1f}")

    bankrupt_count = sum(1 for r in results if r["bankrupt"])
    print(f"\nBankruptcies: {bankrupt_count}/{len(results)}")


if __name__ == "__main__":
    main()
|