mirror of
https://github.com/collinear-ai/yc-bench.git
synced 2026-04-19 12:58:03 +00:00
383 lines
16 KiB
Python
383 lines
16 KiB
Python
"""Smart bot: domain-focus prestige climbing with harvest phase.
|
|
|
|
Strategy:
|
|
1. Pick 2 domains with most accessible tasks
|
|
2. Only accept tasks whose domains are ALL within the focused pair
|
|
3. Sort by prestige delta (climb fast)
|
|
4. When no matching tasks exist, rotate to a new pair sharing 1 domain
|
|
5. Once both focused domains reach prestige 3, switch to max reward
|
|
|
|
Uses bot_runner's candidate construction, but keeps its own run loop so it
|
|
can persist strategy state without modifying bot_runner.py.
|
|
|
|
Usage:
|
|
uv run python scripts/smart_bot.py --config hard --seed 1
|
|
uv run python scripts/smart_bot.py --config hard --seed 42 --seed 19 --seed 64
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from collections import Counter
|
|
from itertools import combinations
|
|
from pathlib import Path
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
|
|
from bot_runner import _build_candidates
|
|
|
|
|
|
HARVEST_PRESTIGE_THRESHOLD = 3.0
|
|
MAX_CONCURRENT_TASKS = 1
|
|
|
|
|
|
def _pick_best_pair(candidates, must_include=None):
|
|
"""Pick the 2-domain pair covering the most candidates.
|
|
If must_include is set, one domain must be in that set."""
|
|
all_domains = sorted(set(str(d) for c in candidates for d in c.domains))
|
|
if len(all_domains) < 2:
|
|
return set(all_domains)
|
|
pair_counts = Counter()
|
|
for c in candidates:
|
|
task_domains = set(str(d) for d in c.domains)
|
|
for pair in combinations(all_domains, 2):
|
|
if task_domains.issubset(set(pair)):
|
|
if must_include is None or any(d in must_include for d in pair):
|
|
pair_counts[pair] += 1
|
|
if pair_counts:
|
|
return set(pair_counts.most_common(1)[0][0])
|
|
return set(all_domains[:2])
|
|
|
|
|
|
def strategy_smart(candidates, context):
|
|
"""Domain-focus + prestige climb + harvest."""
|
|
if not candidates:
|
|
return None
|
|
|
|
# Pick initial focused domains
|
|
if "focused_domains" not in context:
|
|
context["focused_domains"] = _pick_best_pair(candidates)
|
|
context["harvest_mode"] = False
|
|
|
|
focused = context["focused_domains"]
|
|
|
|
# Filter to tasks in focused domains
|
|
matching = [c for c in candidates if all(str(d) in focused for d in c.domains)]
|
|
|
|
# Rotate if empty: pick new pair sharing 1 domain with current
|
|
if not matching:
|
|
new_focus = _pick_best_pair(candidates, must_include=focused)
|
|
if new_focus != focused:
|
|
context["focused_domains"] = new_focus
|
|
focused = new_focus
|
|
matching = [c for c in candidates if all(str(d) in focused for d in c.domains)]
|
|
|
|
# If still empty after rotation, take any task
|
|
if not matching:
|
|
matching = candidates
|
|
|
|
# Once both focused domains hit prestige 3, always harvest by reward.
|
|
# Before that, alternate: odd turns = prestige delta, even turns = reward.
|
|
# This balances prestige growth with cash flow.
|
|
if not context["harvest_mode"] and "prestige_map" in context:
|
|
p_map = context["prestige_map"]
|
|
if all(
|
|
any(float(v) >= HARVEST_PRESTIGE_THRESHOLD for k, v in p_map.items() if str(k) == d)
|
|
for d in focused
|
|
):
|
|
context["harvest_mode"] = True
|
|
|
|
if context["harvest_mode"]:
|
|
return max(matching, key=lambda c: c.reward_cents)
|
|
else:
|
|
turn = context.get("turn", 0)
|
|
if turn % 3 == 0:
|
|
# Every 3rd turn: pick highest reward (cash flow)
|
|
return max(matching, key=lambda c: c.reward_cents)
|
|
else:
|
|
# Other turns: pick highest prestige delta (climb)
|
|
return max(matching, key=lambda c: (c.prestige_delta, -c.completion_hours))
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="Smart bot: domain-focus + prestige climb + harvest")
|
|
parser.add_argument("--config", default="hard")
|
|
parser.add_argument("--seed", type=int, action="append", required=True)
|
|
args = parser.parse_args()
|
|
from decimal import Decimal
|
|
from datetime import datetime, timezone
|
|
from uuid import uuid4
|
|
|
|
from yc_bench.config import load_config
|
|
from yc_bench.core.business_time import add_business_hours
|
|
from yc_bench.core.engine import advance_time
|
|
from yc_bench.core.eta import recalculate_etas
|
|
from yc_bench.core.events import fetch_next_event, insert_event
|
|
from yc_bench.db.models.company import Company, CompanyPrestige
|
|
from yc_bench.db.models.employee import Employee
|
|
from yc_bench.db.models.event import EventType
|
|
from yc_bench.db.models.sim_state import SimState
|
|
from yc_bench.db.models.task import Task, TaskAssignment, TaskRequirement, TaskStatus
|
|
from yc_bench.db.session import build_engine, build_session_factory, init_db, session_scope
|
|
from yc_bench.services.generate_tasks import generate_replacement_task
|
|
from yc_bench.services.seed_world import SeedWorldRequest, seed_world_transactional
|
|
|
|
def _compute_deadline(accepted_at, max_domain_qty, world_cfg):
|
|
work_hours = world_cfg.workday_end_hour - world_cfg.workday_start_hour
|
|
biz_days = max(world_cfg.deadline_min_biz_days, int(max_domain_qty / world_cfg.deadline_qty_per_day))
|
|
return add_business_hours(accepted_at, Decimal(str(biz_days)) * Decimal(str(work_hours)))
|
|
|
|
results = []
|
|
for seed in args.seed:
|
|
config_name = args.config
|
|
cfg = load_config(config_name)
|
|
world_cfg = cfg.world
|
|
bot_slug = "smart_bot"
|
|
|
|
db_dir = Path("db")
|
|
db_dir.mkdir(exist_ok=True)
|
|
db_path = db_dir / f"{config_name}_{seed}_{bot_slug}.db"
|
|
if db_path.exists():
|
|
db_path.unlink()
|
|
|
|
db_url = f"sqlite:///{db_path}"
|
|
os.environ["DATABASE_URL"] = db_url
|
|
os.environ["YC_BENCH_EXPERIMENT"] = config_name
|
|
|
|
engine = build_engine(db_url)
|
|
init_db(engine)
|
|
factory = build_session_factory(engine)
|
|
|
|
with session_scope(factory) as db:
|
|
start_dt = datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
|
horizon_end = start_dt.replace(year=start_dt.year + cfg.sim.horizon_years)
|
|
|
|
req = SeedWorldRequest(
|
|
run_seed=seed,
|
|
company_name="Smart Bot",
|
|
horizon_years=cfg.sim.horizon_years,
|
|
employee_count=world_cfg.num_employees,
|
|
market_task_count=world_cfg.num_market_tasks,
|
|
cfg=world_cfg,
|
|
start_date=start_dt,
|
|
)
|
|
result = seed_world_transactional(db, req)
|
|
company_id = result.company_id
|
|
|
|
insert_event(
|
|
db=db, company_id=company_id,
|
|
event_type=EventType.HORIZON_END,
|
|
scheduled_at=horizon_end,
|
|
payload={"reason": "horizon_end"},
|
|
dedupe_key="horizon_end",
|
|
)
|
|
db.add(SimState(
|
|
company_id=company_id, sim_time=start_dt,
|
|
run_seed=seed, horizon_end=horizon_end, replenish_counter=0,
|
|
))
|
|
db.flush()
|
|
|
|
tasks_completed = 0
|
|
tasks_failed = 0
|
|
turn = 0
|
|
strategy_context = {} # persistent across turns
|
|
|
|
while True:
|
|
turn += 1
|
|
|
|
with session_scope(factory) as db:
|
|
sim_state = db.query(SimState).first()
|
|
company = db.query(Company).filter(Company.id == company_id).one()
|
|
|
|
if company.funds_cents < 0:
|
|
break
|
|
if sim_state.sim_time >= sim_state.horizon_end:
|
|
break
|
|
|
|
active_count = db.query(Task).filter(
|
|
Task.company_id == company_id,
|
|
Task.status == TaskStatus.ACTIVE,
|
|
).count()
|
|
|
|
newly_accepted = []
|
|
while active_count + len(newly_accepted) < MAX_CONCURRENT_TASKS and len(newly_accepted) < 1:
|
|
employees = db.query(Employee).filter(Employee.company_id == company_id).all()
|
|
employee_tiers = [emp.tier for emp in employees]
|
|
employee_ids = [emp.id for emp in employees]
|
|
|
|
n_will_be_active = active_count + len(newly_accepted)
|
|
candidates, max_prestige = _build_candidates(
|
|
db, company_id, sim_state, world_cfg, employee_tiers,
|
|
n_active=n_will_be_active,
|
|
)
|
|
|
|
# Build prestige map for strategy
|
|
prestige_rows = db.query(CompanyPrestige).filter(
|
|
CompanyPrestige.company_id == company_id).all()
|
|
prestige_map = {p.domain: float(p.prestige_level) for p in prestige_rows}
|
|
|
|
strategy_context.update({
|
|
"seed": seed,
|
|
"turn": turn + len(newly_accepted),
|
|
"max_prestige": max_prestige,
|
|
"prestige_map": prestige_map,
|
|
})
|
|
|
|
# Add domain info to candidates (bot_runner doesn't include it)
|
|
for c in candidates:
|
|
reqs = db.query(TaskRequirement).filter(
|
|
TaskRequirement.task_id == c.task.id).all()
|
|
c.domains = [r.domain for r in reqs]
|
|
|
|
chosen = strategy_smart(candidates, strategy_context)
|
|
if chosen is None:
|
|
break
|
|
|
|
task = chosen.task
|
|
newly_accepted.append(task.id)
|
|
|
|
reqs = db.query(TaskRequirement).filter(
|
|
TaskRequirement.task_id == task.id).all()
|
|
|
|
# Apply trust work reduction
|
|
if task.client_id is not None:
|
|
from yc_bench.db.models.client import ClientTrust
|
|
ct = db.query(ClientTrust).filter(
|
|
ClientTrust.company_id == company_id,
|
|
ClientTrust.client_id == task.client_id,
|
|
).one_or_none()
|
|
trust_level = float(ct.trust_level) if ct else 0.0
|
|
work_reduction = world_cfg.trust_work_reduction_max * (trust_level / world_cfg.trust_max)
|
|
for r in reqs:
|
|
r.required_qty = int(float(r.required_qty) * (1 - work_reduction))
|
|
|
|
max_domain_qty = max(float(r.required_qty) for r in reqs)
|
|
|
|
task.status = TaskStatus.PLANNED
|
|
task.company_id = company_id
|
|
task.accepted_at = sim_state.sim_time
|
|
task.deadline = _compute_deadline(sim_state.sim_time, max_domain_qty, world_cfg)
|
|
|
|
# Generate replacement
|
|
counter = sim_state.replenish_counter
|
|
sim_state.replenish_counter = counter + 1
|
|
|
|
from yc_bench.db.models.client import Client as ClientModel
|
|
replaced_client_index = 0
|
|
if task.client_id is not None:
|
|
clients = db.query(ClientModel).order_by(ClientModel.name).all()
|
|
for i, c in enumerate(clients):
|
|
if c.id == task.client_id:
|
|
replaced_client_index = i
|
|
break
|
|
|
|
replacement_spec_domains = None
|
|
if task.client_id is not None:
|
|
orig_client = db.query(ClientModel).filter(ClientModel.id == task.client_id).one_or_none()
|
|
if orig_client:
|
|
replacement_spec_domains = orig_client.specialty_domains
|
|
|
|
replacement = generate_replacement_task(
|
|
run_seed=sim_state.run_seed,
|
|
replenish_counter=counter,
|
|
replaced_prestige=task.required_prestige,
|
|
replaced_client_index=replaced_client_index,
|
|
cfg=world_cfg,
|
|
specialty_domains=replacement_spec_domains,
|
|
)
|
|
|
|
clients = db.query(ClientModel).order_by(ClientModel.name).all()
|
|
replacement_client = clients[replacement.client_index % len(clients)] if clients else None
|
|
replacement_client_id = replacement_client.id if replacement_client else None
|
|
|
|
replacement_row = Task(
|
|
id=uuid4(),
|
|
company_id=None,
|
|
client_id=replacement_client_id,
|
|
status=TaskStatus.MARKET,
|
|
title=replacement.title,
|
|
required_prestige=replacement.required_prestige,
|
|
reward_funds_cents=replacement.reward_funds_cents,
|
|
reward_prestige_delta=replacement.reward_prestige_delta,
|
|
skill_boost_pct=replacement.skill_boost_pct,
|
|
accepted_at=None, deadline=None, completed_at=None,
|
|
success=None, progress_milestone_pct=0,
|
|
required_trust=replacement.required_trust,
|
|
)
|
|
db.add(replacement_row)
|
|
for domain, qty in replacement.requirements.items():
|
|
db.add(TaskRequirement(
|
|
task_id=replacement_row.id,
|
|
domain=domain,
|
|
required_qty=qty,
|
|
completed_qty=0,
|
|
))
|
|
|
|
for eid in employee_ids:
|
|
db.add(TaskAssignment(
|
|
task_id=task.id,
|
|
employee_id=eid,
|
|
assigned_at=sim_state.sim_time,
|
|
))
|
|
db.flush()
|
|
|
|
task.status = TaskStatus.ACTIVE
|
|
db.flush()
|
|
|
|
if newly_accepted:
|
|
recalculate_etas(db, company_id, sim_state.sim_time,
|
|
impacted_task_ids=set(newly_accepted),
|
|
milestones=world_cfg.task_progress_milestones)
|
|
|
|
total_active = active_count + len(newly_accepted)
|
|
if total_active == 0:
|
|
next_event = fetch_next_event(db, company_id, sim_state.horizon_end)
|
|
if next_event is None:
|
|
break
|
|
adv = advance_time(db, company_id, next_event.scheduled_at)
|
|
if adv.bankrupt or adv.horizon_reached:
|
|
break
|
|
continue
|
|
|
|
next_event = fetch_next_event(db, company_id, sim_state.horizon_end)
|
|
if next_event is None:
|
|
break
|
|
adv = advance_time(db, company_id, next_event.scheduled_at)
|
|
for we in adv.wake_events:
|
|
if we.get("type") == "task_completed":
|
|
if we.get("success"):
|
|
tasks_completed += 1
|
|
else:
|
|
tasks_failed += 1
|
|
if adv.bankrupt or adv.horizon_reached:
|
|
break
|
|
|
|
# Final state
|
|
with session_scope(factory) as db:
|
|
company = db.query(Company).filter(Company.id == company_id).one()
|
|
final_balance = company.funds_cents
|
|
bankrupt = final_balance < 0
|
|
prestige_rows = db.query(CompanyPrestige).filter(
|
|
CompanyPrestige.company_id == company_id).all()
|
|
max_p = max((float(p.prestige_level) for p in prestige_rows), default=1.0)
|
|
|
|
status = "BANKRUPT" if bankrupt else f"${final_balance / 100:,.0f}"
|
|
print(f" smart_bot | {config_name} seed={seed} ... {status} | {tasks_completed} OK, {tasks_failed} fail | prestige {max_p:.1f} | {turn} turns")
|
|
results.append({
|
|
"seed": seed, "final_balance": final_balance, "bankrupt": bankrupt,
|
|
"tasks_ok": tasks_completed, "tasks_fail": tasks_failed, "prestige": max_p,
|
|
})
|
|
|
|
print()
|
|
print(f"{'Bot':24s} {'Config':12s} {'Seed':>6s} {'Final Balance':>15s} {'OK':>4s} {'Fail':>5s} {'Prestige':>10s}")
|
|
print("-" * 80)
|
|
for r in results:
|
|
status = "BANKRUPT" if r["bankrupt"] else f"${r['final_balance'] / 100:,.0f}"
|
|
print(f"{'smart_bot':24s} {config_name:12s} {r['seed']:>6d} {status:>15s} {r['tasks_ok']:>4d} {r['tasks_fail']:>5d} {r['prestige']:>10.1f}")
|
|
bankruptcies = sum(1 for r in results if r["bankrupt"])
|
|
print(f"\nBankruptcies: {bankruptcies}/{len(results)}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|