mirror of
https://github.com/collinear-ai/yc-bench.git
synced 2026-04-26 17:13:13 +00:00
Initial commit
This commit is contained in:
commit
3a1c562827
59 changed files with 8418 additions and 0 deletions
49
src/yc_bench/runner/args.py
Normal file
49
src/yc_bench/runner/args.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunArgs:
|
||||
model: str
|
||||
seed: int
|
||||
horizon_years: int | None # None = defer to sim.horizon_years in config
|
||||
company_name: str
|
||||
start_date: str
|
||||
config_name: str = "default"
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="yc-bench run",
|
||||
)
|
||||
parser.add_argument("--model", required=True)
|
||||
parser.add_argument("--seed", type=int, required=True)
|
||||
parser.add_argument("--horizon-years", type=int, default=None,
|
||||
help="Simulation horizon in years (default: read from config's sim.horizon_years)")
|
||||
parser.add_argument("--company-name", default="BenchCo")
|
||||
parser.add_argument("--start-date", default="2025-01-01", help="Simulation start date (YYYY-MM-DD)")
|
||||
parser.add_argument(
|
||||
"--config", dest="config_name", default="default",
|
||||
help="Preset name ('default', 'fast_test', 'high_reward') or path to a .toml file",
|
||||
)
|
||||
return parser
|
||||
|
||||
def parse_run_args(argv):
|
||||
parser = build_parser()
|
||||
ns = parser.parse_args(argv)
|
||||
_validate(ns, parser)
|
||||
return RunArgs(
|
||||
model=ns.model,
|
||||
seed=ns.seed,
|
||||
horizon_years=ns.horizon_years,
|
||||
company_name=ns.company_name,
|
||||
start_date=ns.start_date,
|
||||
config_name=ns.config_name,
|
||||
)
|
||||
|
||||
def _validate(ns, parser):
|
||||
if ns.horizon_years is not None and ns.horizon_years <= 0:
|
||||
parser.error("--horizon-years must be int > 0")
|
||||
|
||||
__all__ = ["RunArgs", "build_parser", "parse_run_args"]
|
||||
226
src/yc_bench/runner/main.py
Normal file
226
src/yc_bench/runner/main.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""Benchmark entrypoint: provisions DB, seeds world, runs agent loop to completion."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from ..agent.commands.executor import run_command
|
||||
from ..agent.loop import run_agent_loop
|
||||
from ..agent.run_state import RunState
|
||||
from ..agent.runtime.factory import build_runtime
|
||||
from ..agent.runtime.schemas import RuntimeSettings
|
||||
from ..db.session import build_engine, build_session_factory, session_scope, init_db
|
||||
from .args import parse_run_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_date(date_str: str) -> datetime:
|
||||
"""Accept ISO (2025-01-01) or legacy MM/DD/YYYY format."""
|
||||
fmt = "%Y-%m-%d" if "-" in date_str else "%m/%d/%Y"
|
||||
dt = datetime.strptime(date_str, fmt)
|
||||
return dt.replace(hour=9, minute=0, second=0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _wipe_simulation(db) -> None:
|
||||
"""Delete all simulation rows so the DB can be reseeded cleanly."""
|
||||
from ..db.models.ledger import LedgerEntry
|
||||
from ..db.models.task import Task, TaskAssignment, TaskRequirement
|
||||
from ..db.models.employee import Employee, EmployeeSkillRate
|
||||
from ..db.models.company import Company, CompanyPrestige
|
||||
from ..db.models.sim_state import SimState
|
||||
from ..db.models.event import SimEvent
|
||||
from ..db.models.scratchpad import Scratchpad
|
||||
|
||||
db.query(Scratchpad).delete(synchronize_session=False)
|
||||
db.query(LedgerEntry).delete(synchronize_session=False)
|
||||
db.query(TaskAssignment).delete(synchronize_session=False)
|
||||
db.query(TaskRequirement).delete(synchronize_session=False)
|
||||
db.query(Task).delete(synchronize_session=False)
|
||||
db.query(SimEvent).delete(synchronize_session=False)
|
||||
db.query(EmployeeSkillRate).delete(synchronize_session=False)
|
||||
db.query(Employee).delete(synchronize_session=False)
|
||||
db.query(CompanyPrestige).delete(synchronize_session=False)
|
||||
db.query(Company).delete(synchronize_session=False)
|
||||
db.query(SimState).delete(synchronize_session=False)
|
||||
db.flush()
|
||||
|
||||
|
||||
def _init_simulation(db_factory, args, experiment_cfg, horizon_years):
|
||||
"""Seed world directly (no CLI round-trip) using the experiment WorldConfig."""
|
||||
from ..db.models.event import EventType
|
||||
from ..db.models.sim_state import SimState
|
||||
from ..db.models.company import Company
|
||||
from ..core.events import insert_event
|
||||
from ..services.seed_world import SeedWorldRequest, seed_world_transactional
|
||||
|
||||
with db_factory() as db:
|
||||
existing = db.query(SimState).first()
|
||||
if existing is not None:
|
||||
company = db.query(Company).filter(Company.id == existing.company_id).first()
|
||||
bankrupt = company is not None and company.funds_cents < 0
|
||||
horizon_reached = existing.sim_time >= existing.horizon_end
|
||||
if bankrupt or horizon_reached:
|
||||
logger.info(
|
||||
"Existing simulation is terminal (bankrupt=%s horizon_reached=%s) — reseeding.",
|
||||
bankrupt, horizon_reached,
|
||||
)
|
||||
_wipe_simulation(db)
|
||||
else:
|
||||
logger.info("Resuming non-terminal simulation (company_id=%s, sim_time=%s).",
|
||||
existing.company_id, existing.sim_time)
|
||||
return existing.company_id
|
||||
|
||||
start_dt = _parse_date(args.start_date)
|
||||
horizon_end = start_dt.replace(year=start_dt.year + horizon_years)
|
||||
world = experiment_cfg.world
|
||||
|
||||
req = SeedWorldRequest(
|
||||
run_seed=args.seed,
|
||||
company_name=args.company_name,
|
||||
horizon_years=horizon_years,
|
||||
employee_count=world.num_employees,
|
||||
market_task_count=world.num_market_tasks,
|
||||
start_date=start_dt,
|
||||
cfg=world,
|
||||
)
|
||||
logger.info(
|
||||
"Initializing simulation: seed=%d employees=%d tasks=%d horizon=%dy",
|
||||
args.seed, world.num_employees, world.num_market_tasks, horizon_years,
|
||||
)
|
||||
result = seed_world_transactional(db, req)
|
||||
|
||||
insert_event(
|
||||
db=db,
|
||||
company_id=result.company_id,
|
||||
event_type=EventType.HORIZON_END,
|
||||
scheduled_at=horizon_end,
|
||||
payload={"reason": "horizon_end"},
|
||||
dedupe_key="horizon_end",
|
||||
)
|
||||
db.add(SimState(
|
||||
company_id=result.company_id,
|
||||
sim_time=start_dt,
|
||||
run_seed=args.seed,
|
||||
horizon_end=horizon_end,
|
||||
replenish_counter=0,
|
||||
))
|
||||
db.flush()
|
||||
|
||||
logger.info("Simulation initialized: company_id=%s", result.company_id)
|
||||
return result.company_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def run_benchmark(args):
|
||||
"""Run a full benchmark: migrate, seed, loop until terminal."""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[%(asctime)s] %(name)s %(levelname)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
# Load experiment config — preset name or path to a .toml file
|
||||
from yc_bench.config import load_config
|
||||
experiment_cfg = load_config(args.config_name)
|
||||
# Propagate config name to CLI subprocess calls (e.g. task accept → generate_replacement_task)
|
||||
os.environ["YC_BENCH_EXPERIMENT"] = args.config_name
|
||||
# CLI --model always overrides the experiment's agent model
|
||||
experiment_cfg = experiment_cfg.model_copy(
|
||||
update={"agent": experiment_cfg.agent.model_copy(update={"model": args.model})}
|
||||
)
|
||||
# --horizon-years CLI flag overrides config; fall back to sim.horizon_years from config
|
||||
horizon_years = args.horizon_years if args.horizon_years is not None else experiment_cfg.sim.horizon_years
|
||||
|
||||
logger.info(
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years,
|
||||
)
|
||||
|
||||
# 1. Build engine and create all tables
|
||||
# If DATABASE_URL is not explicitly set, default to db/<seed>_<slug>.db
|
||||
if not os.environ.get("DATABASE_URL"):
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_dir}/{args.seed}_{slug}.db"
|
||||
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory():
|
||||
with session_scope(factory) as session:
|
||||
yield session
|
||||
|
||||
# 3. Init simulation using experiment world config
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 4. Build runtime settings from experiment AgentConfig
|
||||
agent_cfg = experiment_cfg.agent
|
||||
settings = RuntimeSettings(
|
||||
model=agent_cfg.model,
|
||||
temperature=agent_cfg.temperature,
|
||||
top_p=agent_cfg.top_p,
|
||||
request_timeout_seconds=agent_cfg.request_timeout_seconds,
|
||||
retry_max_attempts=agent_cfg.retry_max_attempts,
|
||||
retry_backoff_seconds=agent_cfg.retry_backoff_seconds,
|
||||
history_keep_rounds=agent_cfg.history_keep_rounds,
|
||||
system_prompt=agent_cfg.system_prompt,
|
||||
)
|
||||
runtime = build_runtime(settings, command_executor=run_command)
|
||||
|
||||
# 5. Build run state
|
||||
session_id = f"run-{args.seed}-{args.model}"
|
||||
run_state = RunState(
|
||||
session_id=session_id,
|
||||
seed=args.seed,
|
||||
model=args.model,
|
||||
horizon_years=horizon_years,
|
||||
)
|
||||
|
||||
# 6. Run agent loop
|
||||
loop_cfg = experiment_cfg.loop
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
)
|
||||
|
||||
# 7. Save full rollout (with transcript) and print summary
|
||||
rollout = final_state.full_rollout()
|
||||
summary = final_state.summary()
|
||||
logger.info("Run complete: %s", json.dumps(summary, indent=2))
|
||||
|
||||
# Write full rollout (includes transcript with commands)
|
||||
slug = args.model.replace("/", "_")
|
||||
results_dir = Path("results")
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
results_path = results_dir / f"yc_bench_result_{args.seed}_{slug}.json"
|
||||
results_path.write_text(json.dumps(rollout, indent=2))
|
||||
logger.info("Full rollout written to %s (%d turns)", results_path, len(rollout.get("transcript", [])))
|
||||
|
||||
return 0 if final_state.terminal_reason != "error" else 1
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
load_dotenv(find_dotenv(usecwd=True), override=False) # searches cwd upward for .env
|
||||
args = parse_run_args(argv)
|
||||
return run_benchmark(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
33
src/yc_bench/runner/session.py
Normal file
33
src/yc_bench/runner/session.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""Session utilities for the benchmark runner."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..db.models.event import EventType
|
||||
from ..db.models.session import Session as SessionModel
|
||||
|
||||
|
||||
def open_session(db: Session, company_id, wake_reason: EventType) -> SessionModel:
|
||||
"""Create and persist a new session record."""
|
||||
session = SessionModel(
|
||||
id=uuid4(),
|
||||
company_id=company_id,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=None,
|
||||
wake_reason=wake_reason,
|
||||
)
|
||||
db.add(session)
|
||||
db.flush()
|
||||
return session
|
||||
|
||||
|
||||
def close_session(db: Session, session: SessionModel) -> None:
|
||||
"""Close an open session record."""
|
||||
session.ended_at = datetime.now(timezone.utc)
|
||||
db.flush()
|
||||
|
||||
|
||||
__all__ = ["open_session", "close_session"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue