Initial commit

This commit is contained in:
adit jain 2026-02-25 02:16:35 -08:00
commit 3a1c562827
59 changed files with 8418 additions and 0 deletions

3
src/yc_bench/__init__.py Normal file
View file

@@ -0,0 +1,3 @@
"""YC-Bench: long-horizon deterministic benchmark for LLM agents."""
__version__ = "0.1.0"

3
src/yc_bench/__main__.py Normal file
View file

@@ -0,0 +1,3 @@
"""Module entry point: `python -m yc_bench` dispatches to the Typer CLI."""
from yc_bench.cli import app_main

# Guard so importing yc_bench.__main__ does not start the CLI as a side
# effect; `python -m yc_bench` still runs because __name__ is "__main__".
if __name__ == "__main__":
    app_main()

View file

@@ -0,0 +1,86 @@
from __future__ import annotations
import os
import subprocess
import sys
from pathlib import Path
from .policy import parse_bench_command
def _resolve_yc_bench() -> str:
"""Find the yc-bench script in the same directory as the running Python."""
venv_bin = Path(sys.executable).parent
candidate = venv_bin / "yc-bench"
if candidate.exists():
return str(candidate)
return "yc-bench" # fallback to PATH lookup
def run_command(
    command,
    *,
    cwd=None,
    timeout_seconds=60.0,
    command_exists=None,
):
    """Execute a `yc-bench` CLI command in a subprocess.

    Args:
        command: full command string, e.g. "yc-bench sim resume".
        cwd: working directory for the subprocess (None = inherit).
        timeout_seconds: wall-clock limit for the subprocess.
        command_exists: optional predicate(argv) -> bool; when it returns
            False the command is rejected with exit code 127.

    Returns a dict with keys: ok, exit_code, stdout, stderr, sim_time,
    command. Never raises — every failure mode is reported in the dict.
    """
    ok, err, argv = parse_bench_command(command)
    if not ok:
        return {
            "ok": False,
            "exit_code": 2,  # usage error, per common CLI convention
            "stdout": "",
            "stderr": err or "invalid command",
            "sim_time": None,
            "command": command if isinstance(command, str) else str(command),
        }
    if command_exists is not None and not command_exists(argv):
        return {
            "ok": False,
            "exit_code": 127,  # shell convention: command not found
            "stdout": "",
            "stderr": f"command not found: {' '.join(argv)}",
            "sim_time": None,
            "command": command,
        }
    # Resolve yc-bench to the venv-local script
    if argv[0] == "yc-bench":
        argv[0] = _resolve_yc_bench()
    try:
        proc = subprocess.run(
            argv,
            shell=False,  # argv list, no shell interpolation of agent input
            cwd=cwd,
            text=True,
            capture_output=True,
            timeout=timeout_seconds,
        )
        return {
            "ok": proc.returncode == 0,
            "exit_code": proc.returncode,
            "stdout": proc.stdout,
            "stderr": proc.stderr,
            "sim_time": None,
            "command": command,
        }
    except subprocess.TimeoutExpired as exc:
        return {
            "ok": False,
            "exit_code": 124,  # GNU `timeout` convention
            "stdout": exc.stdout or "",
            "stderr": f"command timed out after {timeout_seconds} seconds",
            "sim_time": None,
            "command": command,
        }
    except FileNotFoundError:
        # The resolved binary does not exist on disk; report it the same way
        # as the command_exists pre-check (exit 127) instead of a generic 1.
        return {
            "ok": False,
            "exit_code": 127,
            "stdout": "",
            "stderr": f"command not found: {argv[0]}",
            "sim_time": None,
            "command": command,
        }
    except Exception as exc:
        # Last-resort guard so callers always receive a structured result.
        return {
            "ok": False,
            "exit_code": 1,
            "stdout": "",
            "stderr": str(exc),
            "sim_time": None,
            "command": command,
        }
__all__ = ["run_command"]

View file

@@ -0,0 +1,22 @@
from __future__ import annotations
import shlex
def parse_bench_command(command):
    """Validate and tokenize a top-level `yc-bench` command string.

    Returns a (ok, error_message, argv) triple: on success argv is the
    shlex-split token list and error_message is None; on failure argv is
    None and error_message explains the rejection.
    """
    if not isinstance(command, str) or not command.strip():
        return False, "command must be a non-empty string", None
    try:
        argv = shlex.split(command)
    except ValueError:
        # shlex raises on unbalanced quotes; echo the offending command back.
        return False, f"invalid command syntax: {command}", None
    if not argv:
        return False, "command must be a non-empty string", None
    if argv[0] != "yc-bench":
        return False, "only top-level `yc-bench` commands are allowed", None
    return True, None, argv
__all__ = ["parse_bench_command"]

213
src/yc_bench/agent/loop.py Normal file
View file

@@ -0,0 +1,213 @@
from __future__ import annotations
import json
import logging
from sqlalchemy import func
from sqlalchemy.orm import Session
from ..db.models.company import Company
from ..db.models.employee import Employee
from ..db.models.sim_state import SimState
from ..db.models.task import Task, TaskStatus
from .prompt import build_initial_user_prompt, build_turn_context
from .run_state import RunState, TerminalReason
from .runtime.base import AgentRuntime
from .runtime.schemas import RuntimeTurnRequest
logger = logging.getLogger(__name__)
def _snapshot_state(db: Session, company_id):
    """Read current simulation state from DB for context building."""
    def _task_count(status):
        # Count this company's tasks currently in the given status.
        return db.query(func.count(Task.id)).filter(
            Task.company_id == company_id,
            Task.status == status,
        ).scalar() or 0

    sim_state = db.query(SimState).filter(SimState.company_id == company_id).one()
    company = db.query(Company).filter(Company.id == company_id).one()
    headcount = db.query(func.count(Employee.id)).filter(
        Employee.company_id == company_id,
    ).scalar() or 0
    payroll_total = db.query(func.sum(Employee.salary_cents)).filter(
        Employee.company_id == company_id,
    ).scalar() or 0
    return {
        "sim_time": sim_state.sim_time.isoformat(),
        "horizon_end": sim_state.horizon_end.isoformat(),
        "funds_cents": company.funds_cents,
        "active_tasks": _task_count(TaskStatus.ACTIVE),
        "planned_tasks": _task_count(TaskStatus.PLANNED),
        "employee_count": headcount,
        "monthly_payroll_cents": int(payroll_total),
        "bankrupt": company.funds_cents < 0,
    }
def _extract_commands(raw_result) -> list[str]:
"""Extract CLI commands the agent executed from the raw_result dict.
raw_result is {"tool_calls": [{"command": "yc-bench ...", "result": "..."}, ...]}.
"""
commands = []
try:
for tc in (raw_result or {}).get("tool_calls", []):
cmd = tc.get("command", "")
result = tc.get("result", "")
if cmd:
commands.append(f"{cmd} -> {result[:500]}" if result else cmd)
except Exception as exc:
logger.debug("Could not extract commands from raw_result: %s", exc)
return commands
def _auto_resume(command_executor) -> tuple[dict | None, str | None]:
"""Force-advance simulation time via sim resume. Returns (payload, error_msg)."""
try:
raw = command_executor("yc-bench sim resume")
stdout = raw.get("stdout", "")
if isinstance(stdout, str) and stdout.strip():
payload = json.loads(stdout)
if isinstance(payload, dict):
return payload, None
return None, raw.get("stderr", "sim resume returned no payload")
except Exception as exc:
return None, str(exc)
def _build_resume_handoff_user_input(payload: dict) -> str:
"""Build next-turn user message from sim resume payload."""
wake_events = payload.get("wake_events") or []
lines = [
"Simulation advanced via `yc-bench sim resume`.",
f"- old_sim_time: {payload.get('old_sim_time')}",
f"- new_sim_time: {payload.get('new_sim_time')}",
f"- checkpoint_event_type: {payload.get('checkpoint_event_type')}",
f"- events_processed: {payload.get('events_processed')}",
f"- payrolls_applied: {payload.get('payrolls_applied')}",
f"- balance_delta: {payload.get('balance_delta')}",
f"- bankrupt: {payload.get('bankrupt')}",
f"- horizon_reached: {payload.get('horizon_reached')}",
f"- terminal_reason: {payload.get('terminal_reason')}",
f"- wake_events_count: {len(wake_events)}",
]
for ev in wake_events:
lines.append(f"- wake_event: {ev}")
lines.append(
"You are now at the new checkpoint. Query live state via yc-bench commands and decide next actions."
)
return "\n".join(lines)
def run_agent_loop(
    runtime: AgentRuntime,
    db_factory,
    company_id,
    run_state: RunState,
    command_executor=None,
    auto_advance_after_turns: int = 10,
    max_turns: int | None = None,
) -> RunState:
    """Drive the agent turn-by-turn until the run reaches a terminal state.

    Args:
        runtime: AgentRuntime that executes one LLM turn per call.
        db_factory: zero-arg context-manager factory yielding a DB session.
        company_id: company whose state is snapshotted for prompt building.
        run_state: mutable RunState; started, updated per turn, and returned.
        command_executor: optional callable used to force `yc-bench sim resume`
            when the agent stalls; auto-advance is disabled when None.
        auto_advance_after_turns: consecutive turns without a resume before
            forcing one via command_executor.
        max_turns: hard cap on turns; exceeding it terminates with ERROR.

    Returns the same run_state object, marked terminal.
    """
    run_state.start()
    turns_since_resume = 0  # consecutive turns without sim resume
    logger.info(
        "Starting agent loop: model=%s seed=%d auto_advance_after=%d turns max_turns=%s",
        run_state.model, run_state.seed, auto_advance_after_turns, max_turns or "unlimited",
    )
    while not run_state.terminal:
        # Safety cap: a run that never terminates naturally is marked ERROR.
        if max_turns is not None and run_state.turn_count >= max_turns:
            logger.info("Reached max_turns=%d, stopping.", max_turns)
            run_state.mark_terminal(TerminalReason.ERROR, f"max_turns={max_turns} reached")
            break
        turn_num = run_state.turn_count + 1
        # Pick this turn's user message: initial prompt on turn 1, a queued
        # resume handoff when present, else a fresh snapshot of live state.
        if run_state.turn_count == 0:
            with db_factory() as db:
                snapshot = _snapshot_state(db, company_id)
            user_input = build_initial_user_prompt(**snapshot)
        elif run_state.next_user_input is not None:
            user_input = run_state.next_user_input
            run_state.next_user_input = None
        else:
            with db_factory() as db:
                snapshot = _snapshot_state(db, company_id)
            user_input = build_turn_context(
                turn_number=turn_num,
                **snapshot,
            )
        try:
            result = runtime.run_turn(
                RuntimeTurnRequest(
                    session_id=run_state.session_id,
                    user_input=user_input,
                )
            )
            agent_output = result.final_output
        except Exception as exc:
            # The runtime retries internally; a failure here is treated as fatal.
            logger.error("Runtime error on turn %d: %s", turn_num, exc)
            run_state.mark_terminal(TerminalReason.ERROR, str(exc))
            break
        commands_executed = _extract_commands(result.raw_result)
        resume_payload = result.resume_payload
        if result.checkpoint_advanced and resume_payload is not None:
            logger.info("Turn %d: agent called sim resume.", turn_num)
            turns_since_resume = 0
        else:
            turns_since_resume += 1
            # Anti-stall: force time forward if the agent keeps taking turns
            # without ever advancing the simulation clock.
            if command_executor is not None and turns_since_resume >= auto_advance_after_turns:
                logger.info(
                    "Turn %d: %d consecutive turns without sim resume; auto-advancing.",
                    turn_num, turns_since_resume,
                )
                resume_payload, err = _auto_resume(command_executor)
                if err:
                    logger.warning("Auto-resume failed on turn %d: %s", turn_num, err)
                else:
                    turns_since_resume = 0
        if resume_payload is not None:
            # Queue a handoff message describing the time jump for the next turn.
            run_state.next_user_input = _build_resume_handoff_user_input(resume_payload)
            reason = resume_payload.get("terminal_reason")
            if reason == "bankruptcy":
                run_state.mark_terminal(TerminalReason.BANKRUPTCY, reason)
            elif reason == "horizon_end":
                run_state.mark_terminal(TerminalReason.HORIZON_END, reason)
            if run_state.terminal:
                logger.info("Terminal after turn %d: %s", turn_num, reason)
        run_state.record_turn(
            user_input=user_input,
            agent_output=agent_output,
            commands_executed=commands_executed,
            turn_cost_usd=getattr(result, "turn_cost_usd", 0.0),
        )
        logger.info(
            "Turn %d complete. Agent output length: %d, commands: %d",
            turn_num, len(agent_output), len(commands_executed),
        )
    logger.info(
        "Agent loop finished: turns=%d terminal=%s reason=%s",
        run_state.turn_count,
        run_state.terminal,
        run_state.terminal_reason,
    )
    return run_state
__all__ = ["run_agent_loop"]

View file

@@ -0,0 +1,185 @@
"""System prompt and user-message builders for the YC-Bench agent."""
from __future__ import annotations
SYSTEM_PROMPT = """\
You are the autonomous CEO of an AI startup in a deterministic business simulation. \
Your goal is to maximize company prestige and funds over the simulation horizon while avoiding bankruptcy.
## How It Works
- All actions are performed via the `run_command` tool, which executes `yc-bench` CLI commands.
- All commands return JSON. Parse the output to make decisions.
- Simulation progression and event processing are managed by the benchmark runtime.
- Business hours are weekdays 09:00-18:00. Nights, weekends, and Feb 29 are skipped.
- Payroll is deducted automatically on the first business day of each month.
- If funds go below zero after any event or payroll, the company goes bankrupt and the run ends.
## Available Commands
### Observe
- `yc-bench company status` funds, prestige, employee count, payroll, bankruptcy risk
- `yc-bench employee list` list all employees with IDs, salaries, skill rates, and current assignments
- `yc-bench market browse [--domain X] [--required-prestige-lte N] [--reward-min-cents N] [--limit N] [--offset N]` browse available tasks (default limit 50; the response includes a `total` field if total > 50, paginate with --offset to see more)
- `yc-bench task list [--status X]` list your tasks (planned, active, completed, cancelled)
- `yc-bench task inspect --task-id <UUID>` detailed task info (requirements, assignments, progress)
- `yc-bench finance ledger [--from MM/DD/YYYY] [--to MM/DD/YYYY] [--category X]` financial history
- `yc-bench report monthly [--from-month YYYY-MM] [--to-month YYYY-MM]` monthly P&L
- `yc-bench scratchpad read` read your persistent notes
### Memory (scratchpad)
- `yc-bench scratchpad write --content "text"` overwrite scratchpad with new notes
- `yc-bench scratchpad append --content "text"` append a line to existing notes
- `yc-bench scratchpad clear` erase all notes
- Use the scratchpad to store key decisions, task deadlines, employee assignments, and strategy notes. Context is periodically truncated anything important should be written here.
### Act
- `yc-bench task accept --task-id <UUID>` accept a market task (sets deadline, generates replacement)
- `yc-bench task assign --task-id <UUID> --employee-id <UUID>` assign employee to task
- `yc-bench task dispatch --task-id <UUID>` start work on a planned task (must have assignments)
- `yc-bench task cancel --task-id <UUID> --reason "text"` cancel a task (prestige penalty: 1.2x reward delta)
- `yc-bench sim resume` advance simulation to the next checkpoint event and return wake events
## Strategy Guidelines
1. **Check company status first** to understand your financial position and runway.
2. **Browse the market** for tasks you can accept (check prestige requirements).
3. **Accept tasks** that match your capabilities and offer good reward-to-risk ratio.
4. **Assign employees strategically** employees split throughput across active tasks. Focus employees on fewer tasks for faster completion.
5. **Dispatch tasks** once assigned, then continue monitoring progress/events via status and reports.
6. **Monitor deadlines** completing after deadline causes failure (0.8x prestige penalty). Cancel hopeless tasks early (1.2x penalty, but stops bleeding time).
7. **Watch payroll** monthly salaries are deducted automatically. Don't let runway drop to zero.
8. **Use status checks** to track critical milestones and risks.
9. **Successful tasks** award funds + prestige + employee skill boosts. Build momentum.
## Key Rules
- Task completion at or before deadline = success (reward funds + prestige + skill boost)
- Task completion after deadline = failure (0.8x prestige penalty, no reward)
- Task cancellation = 1.2x prestige penalty per domain
- Employee throughput = base_rate / number_of_active_tasks_assigned
- Time advances only when you run `yc-bench sim resume`
- Prestige is clamped [1, 10]. Funds are in cents.
"""
def build_turn_context(
    turn_number: int,
    sim_time: str,
    horizon_end: str,
    funds_cents: int,
    active_tasks: int,
    planned_tasks: int,
    employee_count: int,
    monthly_payroll_cents: int,
    bankrupt: bool,
    last_wake_events: list | None = None,
) -> str:
    """Build per-turn context message injected as user input."""
    # Runway in months at current burn; infinite when there is no payroll.
    if monthly_payroll_cents > 0:
        runway_str = f"~{round(funds_cents / monthly_payroll_cents, 1)} months"
    else:
        runway_str = "∞ (no payroll)"
    sections = [
        f"## Turn {turn_number} — Simulation State",
        f"- **Current time**: {sim_time}",
        f"- **Horizon end**: {horizon_end}",
        f"- **Funds**: ${funds_cents / 100:,.2f} ({funds_cents} cents)",
        f"- **Monthly payroll**: ${monthly_payroll_cents / 100:,.2f}",
        f"- **Runway**: {runway_str}",
        f"- **Employees**: {employee_count}",
        f"- **Active tasks**: {active_tasks}",
        f"- **Planned tasks**: {planned_tasks}",
    ]
    if bankrupt:
        sections.append("\n**WARNING: Company is bankrupt. Run will terminate.**")
    if last_wake_events:
        sections.append("\n### Events since last turn:")
        for ev in last_wake_events:
            ev_type = ev.get("type", "unknown")
            if ev_type == "task_completed":
                outcome = "SUCCESS" if ev.get("success", False) else "FAILED"
                sections.append(f"- Task {ev.get('task_id', '?')}: {outcome}")
            elif ev_type == "task_half":
                sections.append(f"- Task {ev.get('task_id', '?')}: 50% progress reached")
            elif ev_type == "horizon_end":
                sections.append("- **Horizon end reached. Simulation complete.**")
            elif ev_type == "bankruptcy":
                sections.append("- **BANKRUPTCY. Simulation terminated.**")
            else:
                sections.append(f"- Event: {ev_type}")
    # Closing instruction depends on whether any work is queued or running.
    if active_tasks == 0 and planned_tasks == 0:
        sections.append(
            "\n**ACTION REQUIRED**: No tasks are running. "
            "Accept a task, assign employees to it, dispatch it, then call `yc-bench sim resume`. "
            "Do this now — every turn without active tasks burns runway."
        )
    elif planned_tasks > 0 and active_tasks == 0:
        sections.append(
            "\n**ACTION REQUIRED**: You have planned tasks but none are dispatched. "
            "Assign employees and dispatch now, then call `yc-bench sim resume`."
        )
    else:
        sections.append("\nDecide your next actions. Use `run_command` to execute CLI commands.")
    return "\n".join(sections)
def build_initial_user_prompt(
sim_time: str,
horizon_end: str,
funds_cents: int,
active_tasks: int,
planned_tasks: int,
employee_count: int,
monthly_payroll_cents: int,
bankrupt: bool,
) -> str:
"""Build the one-time initial user message at run start."""
runway_months = (
round(funds_cents / monthly_payroll_cents, 1)
if monthly_payroll_cents > 0
else float("inf")
)
runway_months = (
round(funds_cents / monthly_payroll_cents, 1)
if monthly_payroll_cents > 0
else None
)
runway_str = f"~{runway_months} months" if runway_months is not None else ""
lines = [
"## Simulation Start — Take Immediate Action",
f"- current_time: {sim_time}",
f"- horizon_end: {horizon_end}",
f"- funds: ${funds_cents / 100:,.2f}",
f"- monthly_payroll: ${monthly_payroll_cents / 100:,.2f}",
f"- runway: {runway_str}",
f"- employees: {employee_count}",
f"- active_tasks: {active_tasks}",
f"- planned_tasks: {planned_tasks}",
"",
"**Your immediate priority**: generate revenue before payroll drains your runway.",
"You MUST complete these steps now (multiple commands per turn are fine):",
"1. `yc-bench market browse --required-prestige-lte 1` — find tasks you can accept",
"2. `yc-bench task accept --task-id <UUID>` — accept 2-3 suitable tasks",
"3. `yc-bench employee list` — get employee IDs",
"4. `yc-bench task assign --task-id <UUID> --employee-id <UUID>` — assign employees",
"5. `yc-bench task dispatch --task-id <UUID>` — start work on each assigned task",
"6. `yc-bench sim resume` — advance time to collect the first task completion event",
"",
"Do not spend multiple turns just browsing. Accept and dispatch tasks immediately.",
]
if bankrupt:
lines.append("WARNING: company is already bankrupt at initialization.")
return "\n".join(lines)
__all__ = ["SYSTEM_PROMPT", "build_turn_context", "build_initial_user_prompt"]

View file

@@ -0,0 +1,102 @@
"""Run state: tracks the progress and terminal status of a benchmark run."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional
class TerminalReason(str, Enum):
    """Why a benchmark run ended (string-valued for easy serialization)."""
    BANKRUPTCY = "bankruptcy"    # company funds went negative
    HORIZON_END = "horizon_end"  # simulation horizon reached
    ERROR = "error"              # runtime failure or max_turns cap hit
@dataclass
class TranscriptEntry:
    """One completed agent turn recorded in the run transcript."""
    turn: int        # 1-based turn number
    timestamp: str   # ISO-8601 UTC time the turn was recorded
    user_input: str
    agent_output: str
    commands_executed: List[str] = field(default_factory=list)  # CLI commands run this turn
@dataclass
class RunState:
"""Mutable state for a single benchmark run."""
session_id: str
seed: int
model: str
horizon_years: int
turn_count: int = 0
terminal: bool = False
terminal_reason: Optional[TerminalReason] = None
terminal_detail: Optional[str] = None
started_at: Optional[str] = None
ended_at: Optional[str] = None
transcript: List[TranscriptEntry] = field(default_factory=list)
next_user_input: Optional[str] = None
total_cost_usd: float = 0.0
def start(self) -> None:
self.started_at = datetime.now(timezone.utc).isoformat()
def record_turn(self, user_input: str, agent_output: str, commands_executed: List[str] | None = None, turn_cost_usd: float = 0.0) -> None:
self.turn_count += 1
self.total_cost_usd += turn_cost_usd
self.transcript.append(TranscriptEntry(
turn=self.turn_count,
timestamp=datetime.now(timezone.utc).isoformat(),
user_input=user_input,
agent_output=agent_output,
commands_executed=commands_executed or [],
))
def mark_terminal(self, reason: TerminalReason, detail: str = "") -> None:
self.terminal = True
self.terminal_reason = reason
self.terminal_detail = detail
self.ended_at = datetime.now(timezone.utc).isoformat()
def should_stop(self) -> bool:
if self.terminal:
return True
return False
def full_rollout(self) -> Dict[str, Any]:
"""Full results including transcript for saving to disk."""
return {
"session_id": self.session_id,
"model": self.model,
"seed": self.seed,
"horizon_years": self.horizon_years,
"turns_completed": self.turn_count,
"terminal": self.terminal,
"terminal_reason": self.terminal_reason.value if self.terminal_reason else None,
"terminal_detail": self.terminal_detail,
"total_cost_usd": round(self.total_cost_usd, 6),
"started_at": self.started_at,
"ended_at": self.ended_at,
"transcript": [
{
"turn": t.turn,
"timestamp": t.timestamp,
"user_input": t.user_input,
"agent_output": t.agent_output,
"commands_executed": t.commands_executed,
}
for t in self.transcript
],
}
def summary(self) -> Dict[str, Any]:
"""Summary without transcript for logging."""
rollout = self.full_rollout()
rollout.pop("transcript", None)
return rollout
__all__ = ["TerminalReason", "TranscriptEntry", "RunState"]

View file

@@ -0,0 +1,14 @@
from __future__ import annotations
from abc import ABC, abstractmethod
class AgentRuntime(ABC):
    """Abstract interface for agent runtimes that execute one turn at a time."""
    @abstractmethod
    def run_turn(self, request):
        """Run a single agent turn for `request` and return its result."""
        raise NotImplementedError
    @abstractmethod
    def clear_session(self, session_id):
        """Discard any stored conversation state for `session_id`."""
        raise NotImplementedError
__all__ = ["AgentRuntime"]

View file

@@ -0,0 +1,13 @@
from __future__ import annotations
from .base import AgentRuntime
from .litellm_runtime import LiteLLMRuntime
from .schemas import RuntimeSettings
def build_runtime(settings, command_executor):
    """Construct the default LiteLLM-backed agent runtime.

    `settings` is a RuntimeSettings instance; `command_executor` is a callable
    that takes a yc-bench command string and returns a result mapping.
    """
    return LiteLLMRuntime(
        settings=settings,
        command_executor=command_executor,
    )
__all__ = ["build_runtime"]

View file

@@ -0,0 +1,294 @@
from __future__ import annotations
import json
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from dataclasses import dataclass, field
import litellm
from .base import AgentRuntime
from .schemas import RuntimeTurnResult
from ..prompt import SYSTEM_PROMPT
from ..tools.run_command_schema import normalize_result
logger = logging.getLogger(__name__)
litellm.suppress_debug_info = True
# Tool schema passed to the LLM on every call: a single OpenAI-style
# function tool, `run_command`, taking one required string argument.
_RUN_COMMAND_TOOL = {
    "type": "function",
    "function": {
        "name": "run_command",
        "description": (
            "Execute one benchmark CLI command inside the sandbox "
            "and return structured execution output."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "The full yc-bench CLI command to execute.",
                }
            },
            "required": ["command"],
        },
    },
}
@dataclass
class _Session:
    """Per-session conversation state: the running OpenAI-format message list."""
    messages: list = field(default_factory=list)  # user/assistant/tool messages
class LiteLLMRuntime(AgentRuntime):
    """AgentRuntime backed by `litellm.completion` with one `run_command` tool.

    Keeps an in-memory message history per session, proactively truncates old
    rounds to bound context size, retries failed LLM calls with exponential
    backoff, and executes the model's tool calls through the injected
    `command_executor` callable.
    """

    def __init__(self, settings, command_executor):
        self._settings = settings
        self._command_executor = command_executor
        self._sessions: dict[str, _Session] = {}
        self._request_timeout_seconds = settings.request_timeout_seconds
        self._retry_max_attempts = settings.retry_max_attempts
        self._retry_backoff_seconds = settings.retry_backoff_seconds
        # Fail fast on nonsensical configuration.
        if self._request_timeout_seconds <= 0:
            raise ValueError("request_timeout_seconds must be > 0")
        if self._retry_max_attempts <= 0:
            raise ValueError("retry_max_attempts must be > 0")
        if self._retry_backoff_seconds <= 0:
            raise ValueError("retry_backoff_seconds must be > 0")
        # API key: prefer OPENAI_API_KEY, fall back to OPENROUTER_API_KEY.
        # For openrouter/ prefixed models LiteLLM also reads OPENROUTER_API_KEY
        # automatically, so either approach works.
        self._api_key = (
            os.environ.get("OPENAI_API_KEY")
            or os.environ.get("OPENROUTER_API_KEY")
            or None
        )
        # Base URL: only needed for raw OpenAI-compatible endpoints.
        # openrouter/ model prefix is handled natively by LiteLLM without this.
        self._api_base = os.environ.get("OPENAI_BASE_URL") or None
        self._history_keep_rounds = settings.history_keep_rounds
        logger.info(
            "LiteLLMRuntime configured: model=%s api_base=%s history_keep_rounds=%d",
            self._settings.model,
            self._api_base or "default",
            self._history_keep_rounds,
        )

    # ------------------------------------------------------------------
    # AgentRuntime interface
    # ------------------------------------------------------------------
    def run_turn(self, request):
        """Execute one agent turn, retrying transient LLM failures.

        Returns a RuntimeTurnResult. Raises RuntimeError once
        retry_max_attempts is exhausted.
        """
        session = self._get_or_create_session(request.session_id)
        # Proactively drop old rounds before appending new input.
        self._proactive_truncate(session)
        session.messages.append({"role": "user", "content": request.user_input})
        result = None
        last_err = None
        for attempt in range(1, self._retry_max_attempts + 1):
            try:
                result = self._run_with_timeout(session)
                break
            except Exception as e:
                last_err = e
                if self._is_context_length_error(e):
                    # Truncation already ran this turn, so this is most likely
                    # a configuration problem rather than a transient failure;
                    # retry immediately without sleeping.
                    logger.warning(
                        "Context-length error on attempt %d despite proactive truncation "
                        "(history_keep_rounds=%d). Consider reducing YC_BENCH_HISTORY_KEEP_ROUNDS.",
                        attempt,
                        self._history_keep_rounds,
                    )
                    continue
                logger.warning("Turn attempt %d failed: %s", attempt, e)
                if attempt >= self._retry_max_attempts:
                    raise RuntimeError(
                        f"Failed to run turn after {self._retry_max_attempts} attempts"
                    ) from last_err
                # Exponential backoff: base * 2^(attempt-1).
                time.sleep(self._retry_backoff_seconds * (2 ** (attempt - 1)))
        if result is None:
            # Reached when every attempt hit the context-length `continue` path.
            raise RuntimeError("run_turn failed without result") from last_err
        final_output, tool_calls_made, resume_payload, turn_cost = result
        return RuntimeTurnResult(
            final_output=final_output,
            raw_result={"tool_calls": tool_calls_made},
            checkpoint_advanced=resume_payload is not None,
            resume_payload=resume_payload,
            turn_cost_usd=turn_cost,
        )

    def clear_session(self, session_id):
        """Drop all stored conversation history for `session_id`."""
        self._sessions.pop(session_id, None)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _run_with_timeout(self, session):
        """Run one turn in a worker thread, bounding wall-clock time.

        The worker thread cannot be killed; on timeout we abandon it instead
        of waiting for it to finish.
        """
        executor = ThreadPoolExecutor(max_workers=1)
        future = executor.submit(self._do_turn, session)
        try:
            return future.result(timeout=self._request_timeout_seconds)
        except FuturesTimeoutError as exc:
            future.cancel()
            raise TimeoutError(
                f"LiteLLM call timed out after {self._request_timeout_seconds}s"
            ) from exc
        finally:
            # wait=False: do NOT block on a possibly-hung LLM call. The previous
            # `with ThreadPoolExecutor(...)` form called shutdown(wait=True) on
            # exit, which waited for the stuck worker and defeated the timeout.
            executor.shutdown(wait=False)

    def _do_turn(self, session):
        """One LLM call + tool execution. Returns (final_output, tool_calls_made, resume_payload, cost_usd)."""
        system_prompt = self._settings.system_prompt or SYSTEM_PROMPT
        messages = [{"role": "system", "content": system_prompt}] + session.messages
        kwargs = dict(
            model=self._settings.model,
            messages=messages,
            tools=[_RUN_COMMAND_TOOL],
            tool_choice="auto",
            temperature=self._settings.temperature,
            top_p=self._settings.top_p,
            timeout=self._request_timeout_seconds,
        )
        if self._api_base:
            kwargs["api_base"] = self._api_base
        if self._api_key:
            kwargs["api_key"] = self._api_key
        response = litellm.completion(**kwargs)
        # Log token usage and cost per call.
        turn_cost = 0.0
        usage = getattr(response, "usage", None)
        if usage:
            # LiteLLM stashes its computed cost in _hidden_params; may be absent.
            cost = getattr(response, "_hidden_params", {}).get("response_cost") or 0
            turn_cost = float(cost)
            logger.info(
                "LLM call: prompt_tokens=%s completion_tokens=%s cost=$%.6f",
                getattr(usage, "prompt_tokens", "?"),
                getattr(usage, "completion_tokens", "?"),
                turn_cost,
            )
        message = response.choices[0].message
        tool_calls = getattr(message, "tool_calls", None) or []
        tool_calls_made = []
        resume_payload = None
        if tool_calls:
            # Persist the assistant message, including its tool-call stubs, so
            # the provider sees a well-formed tool exchange next call.
            session.messages.append({
                "role": "assistant",
                "content": message.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            })
            for tc in tool_calls:
                try:
                    args = json.loads(tc.function.arguments)
                    command = args.get("command", "")
                except Exception:
                    # Malformed arguments: run the empty command so the model
                    # receives a structured parse error instead of a crash.
                    command = ""
                raw = self._command_executor(command)
                normalized = normalize_result(raw)
                tool_result_str = json.dumps(normalized.__dict__)
                tool_calls_made.append({"command": command, "result": tool_result_str})
                # Extract resume payload when the agent advances simulation time.
                if command.startswith("yc-bench sim resume"):
                    try:
                        stdout = normalized.__dict__.get("stdout", "")
                        if isinstance(stdout, str) and stdout.strip():
                            payload = json.loads(stdout)
                            if isinstance(payload, dict):
                                resume_payload = payload
                    except Exception:
                        # Best-effort: a malformed payload simply means no
                        # checkpoint advance is reported for this turn.
                        pass
                session.messages.append({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": tool_result_str,
                })
            cmds = [tc["command"] for tc in tool_calls_made]
            final_output = f"Executed {len(tool_calls)} tool call(s): {', '.join(cmds)}"
        else:
            content = message.content or ""
            session.messages.append({"role": "assistant", "content": content})
            final_output = content
        return final_output, tool_calls_made, resume_payload, turn_cost

    def _get_or_create_session(self, session_id: str) -> _Session:
        """Return the stored session state, creating it on first use."""
        if session_id not in self._sessions:
            self._sessions[session_id] = _Session()
        return self._sessions[session_id]

    def _is_context_length_error(self, err: Exception) -> bool:
        """Heuristically detect provider context-window errors via message text."""
        text = str(err).lower()
        patterns = (
            "context length",
            "maximum context",
            "max context",
            "too many tokens",
            "prompt is too long",
            "token limit",
            "context window",
        )
        return any(p in text for p in patterns)

    def _round_start_indices(self, messages: list) -> list[int]:
        """Return indices of user messages — each marks the start of a round."""
        return [i for i, m in enumerate(messages) if m.get("role") == "user"]

    def _proactive_truncate(self, session: _Session) -> None:
        """Drop oldest rounds before each turn so at most history_keep_rounds remain."""
        messages = session.messages
        user_indices = self._round_start_indices(messages)
        if len(user_indices) <= self._history_keep_rounds:
            return
        cutoff = user_indices[-self._history_keep_rounds]
        marker = {
            "role": "user",
            "content": (
                f"[Earlier turns removed. Only the last {self._history_keep_rounds} "
                "turns are retained in this context window.]"
            ),
        }
        session.messages = [marker] + messages[cutoff:]
        # Format fixed: the original "%d%d" ran the two counts together.
        logger.info(
            "Proactive truncation: kept last %d rounds (%d -> %d messages).",
            self._history_keep_rounds,
            len(messages),
            len(session.messages),
        )
__all__ = ["LiteLLMRuntime"]

View file

@@ -0,0 +1,40 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class RuntimeSettings:
    """Immutable configuration for an agent runtime."""
    model: str          # LiteLLM model string
    temperature: float  # sampling temperature
    top_p: float        # nucleus-sampling cutoff
    request_timeout_seconds: float = 300.0  # wall-clock cap per LLM call
    retry_max_attempts: int = 3             # LLM call attempts before giving up
    retry_backoff_seconds: float = 1.0      # base of the exponential retry backoff
    history_keep_rounds: int = 20           # rounds kept before proactive truncation
    # Optional system prompt override; None = use default from agent/prompt.py
    system_prompt: str | None = None
@dataclass(frozen=True)
class RuntimeTurnRequest:
    """Input for a single agent turn."""
    session_id: str  # conversation/session key
    user_input: str  # user message injected for this turn
@dataclass(frozen=True)
class RuntimeTurnResult:
    """Output of a single agent turn."""
    final_output: str                  # agent-visible text or tool-call summary
    raw_result: Any                    # runtime-specific detail (e.g. tool calls)
    checkpoint_advanced: bool = False  # True when the turn ran `sim resume`
    resume_payload: dict | None = None # parsed `sim resume` stdout, if any
    turn_cost_usd: float = 0.0         # LLM spend attributed to this turn
@dataclass(frozen=True)
class CommandResult:
    """Structured result of one executed CLI command."""
    ok: bool              # exit_code == 0
    exit_code: int
    stdout: str
    stderr: str
    sim_time: str | None  # NOTE(review): executor currently always sets None — confirm intent
    command: str          # the command string that was run
__all__ = ["RuntimeSettings", "RuntimeTurnRequest", "RuntimeTurnResult", "CommandResult"]

View file

@@ -0,0 +1,22 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class RunCommandResult:
    """Normalized outcome of one executed yc-bench CLI command."""
    ok: bool
    exit_code: int
    stdout: str
    stderr: str
    sim_time: str | None
    command: str

def normalize_result(payload):
    """Coerce a raw executor result mapping into a RunCommandResult."""
    def pick(key, default, cast):
        # Read a field with a fallback and coerce it to the target type.
        return cast(payload.get(key, default))

    return RunCommandResult(
        ok=pick("ok", False, bool),
        exit_code=pick("exit_code", 1, int),
        stdout=pick("stdout", "", str),
        stderr=pick("stderr", "", str),
        sim_time=payload.get("sim_time"),
        command=pick("command", "", str),
    )

View file

@@ -0,0 +1,110 @@
from __future__ import annotations
import json
import sys
from contextlib import contextmanager
from decimal import Decimal
from uuid import UUID
import typer
from ..db.session import build_engine, build_session_factory, session_scope
# Root Typer application; sub-command groups are attached below at import time.
app = typer.Typer(name="yc-bench", add_completion=False)
# ---------------------------------------------------------------------------
# Helpers shared across command modules
# ---------------------------------------------------------------------------
@contextmanager
def get_db():
    """Yield a transactional SQLAlchemy session, commit on success.

    Builds a fresh engine and session factory per call; session_scope is
    expected to handle commit/rollback and cleanup — confirm against
    db/session.py.
    """
    engine = build_engine()
    factory = build_session_factory(engine)
    with session_scope(factory) as session:
        yield session
class _JSONEncoder(json.JSONEncoder):
"""Handle UUID, Decimal, datetime serialisation."""
def default(self, o):
if isinstance(o, UUID):
return str(o)
if isinstance(o, Decimal):
return float(o)
from datetime import datetime, date
if isinstance(o, datetime):
return o.isoformat()
if isinstance(o, date):
return o.isoformat()
return super().default(o)
def json_output(data: dict | list) -> None:
    """Print JSON to stdout (captured by run_command executor)."""
    rendered = json.dumps(data, cls=_JSONEncoder, indent=2)
    typer.echo(rendered)
def error_output(message: str, code: int = 1) -> None:
    """Print a JSON error object and exit with `code`.

    NOTE(review): err=False (typer.echo's default) sends the error JSON to
    stdout rather than stderr — presumably so the executor's stdout capture
    sees it; confirm before routing to stderr.
    """
    typer.echo(json.dumps({"error": message}), err=False)
    raise typer.Exit(code=code)
# ---------------------------------------------------------------------------
# Register sub-command groups
# ---------------------------------------------------------------------------
from .sim_commands import sim_app # noqa: E402
from .company_commands import company_app # noqa: E402
from .market_commands import market_app # noqa: E402
from .task_commands import task_app # noqa: E402
from .finance_commands import finance_app # noqa: E402
from .report_commands import report_app # noqa: E402
from .employee_commands import employee_app # noqa: E402
from .scratchpad_commands import scratchpad_app # noqa: E402
# Attach each command group to the root CLI under its sub-command name,
# e.g. `yc-bench sim ...`, `yc-bench task ...`.
app.add_typer(sim_app, name="sim")
app.add_typer(company_app, name="company")
app.add_typer(employee_app, name="employee")
app.add_typer(market_app, name="market")
app.add_typer(task_app, name="task")
app.add_typer(finance_app, name="finance")
app.add_typer(report_app, name="report")
app.add_typer(scratchpad_app, name="scratchpad")
@app.command("run")
def run_command_cli(
    model: str = typer.Option(..., help="LiteLLM model string (e.g. openrouter/z-ai/glm-5)"),
    seed: int = typer.Option(..., help="Random seed for deterministic world generation"),
    horizon_years: int = typer.Option(3, help="Simulation horizon in years"),
    company_name: str = typer.Option("BenchCo", help="Name of the simulated company"),
    start_date: str = typer.Option("2025-01-01", help="Simulation start date (YYYY-MM-DD)"),
    config_name: str = typer.Option(
        "default", "--config",
        help="Preset name ('default', 'fast_test', 'high_reward') or path to a .toml file",
    ),
):
    """Run a full benchmark: migrate DB, seed world, run agent loop to completion."""
    # Load environment variables from the nearest .env up the cwd tree,
    # without overriding variables already set in the environment.
    from dotenv import find_dotenv, load_dotenv
    load_dotenv(find_dotenv(usecwd=True), override=False)
    # Runner imports are deferred until after .env is loaded.
    from ..runner.main import run_benchmark
    from ..runner.args import RunArgs
    args = RunArgs(
        model=model,
        seed=seed,
        horizon_years=horizon_years,
        company_name=company_name,
        start_date=start_date,
        config_name=config_name,
    )
    # The process exit status is whatever run_benchmark returns.
    raise SystemExit(run_benchmark(args))
def app_main():
    """Entry point for `yc-bench` console_script."""
    # Delegate straight to the root Typer application.
    app()

View file

@ -0,0 +1,102 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from decimal import Decimal
import typer
from sqlalchemy import func
from ..core.business_time import first_business_of_month
from ..db.models.company import Company, CompanyPrestige
from ..db.models.employee import Employee
from ..db.models.sim_state import SimState
from ..db.models.task import Task, TaskStatus
from . import get_db, json_output, error_output
# Typer sub-application mounted as `yc-bench company ...`.
company_app = typer.Typer(help="Company status commands.")
def _next_payroll_date(sim_time: datetime) -> datetime:
    """Compute next first-business-day-of-month at 09:00 from sim_time."""
    # Roll forward one calendar month, wrapping December into January.
    year, month = sim_time.year, sim_time.month
    if month == 12:
        year, month = year + 1, 1
    else:
        month += 1
    first_of_next = sim_time.replace(year=year, month=month, day=1)
    return first_business_of_month(first_of_next)
@company_app.command("status")
def company_status():
    """Show company status: funds, prestige, tasks, payroll, risk."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            # error_output raises typer.Exit, so execution stops here.
            error_output("No simulation found. Run `yc-bench sim init` first.")
        company = db.query(Company).filter(Company.id == sim_state.company_id).one_or_none()
        if company is None:
            error_output("Company not found.")
        # Prestige by domain
        prestige_rows = db.query(CompanyPrestige).filter(
            CompanyPrestige.company_id == company.id
        ).all()
        prestige_map = {row.domain.value: float(row.prestige_level) for row in prestige_rows}
        # Task counts
        active_count = db.query(func.count(Task.id)).filter(
            Task.company_id == company.id,
            Task.status == TaskStatus.ACTIVE,
        ).scalar() or 0
        planned_count = db.query(func.count(Task.id)).filter(
            Task.company_id == company.id,
            Task.status == TaskStatus.PLANNED,
        ).scalar() or 0
        # "completed" counts both successful and failed completions.
        completed_count = db.query(func.count(Task.id)).filter(
            Task.company_id == company.id,
            Task.status.in_([TaskStatus.COMPLETED_SUCCESS, TaskStatus.COMPLETED_FAIL]),
        ).scalar() or 0
        cancelled_count = db.query(func.count(Task.id)).filter(
            Task.company_id == company.id,
            Task.status == TaskStatus.CANCELLED,
        ).scalar() or 0
        # Employee count
        employee_count = db.query(func.count(Employee.id)).filter(
            Employee.company_id == company.id
        ).scalar() or 0
        # Monthly payroll estimate (sum of all employee salaries, in cents)
        total_salary = db.query(func.sum(Employee.salary_cents)).filter(
            Employee.company_id == company.id
        ).scalar() or 0
        next_payroll = _next_payroll_date(sim_state.sim_time)
        # Bankruptcy risk: runway in months = current funds / monthly payroll;
        # None when there is no payroll to burn.
        months_runway = round(float(company.funds_cents) / float(total_salary), 2) if total_salary > 0 else None
        json_output({
            "company_id": str(company.id),
            "company_name": company.name,
            "funds_cents": company.funds_cents,
            "prestige": prestige_map,
            "sim_time": sim_state.sim_time.isoformat(),
            "horizon_end": sim_state.horizon_end.isoformat(),
            "tasks": {
                "active": active_count,
                "planned": planned_count,
                "completed": completed_count,
                "cancelled": cancelled_count,
            },
            "employees": employee_count,
            "next_payroll": next_payroll.isoformat(),
            "monthly_payroll_cents": total_salary,
            "risk": {
                "months_runway": months_runway,
                "bankrupt": company.funds_cents < 0,
            },
        })

View file

@ -0,0 +1,62 @@
from __future__ import annotations
import typer
from sqlalchemy import func
from ..db.models.employee import Employee, EmployeeSkillRate
from ..db.models.task import Task, TaskAssignment, TaskStatus
from ..db.models.sim_state import SimState
from . import get_db, json_output, error_output
# Typer sub-application mounted as `yc-bench employee ...`.
employee_app = typer.Typer(help="Employee management commands.")
@employee_app.command("list")
def employee_list():
    """List all employees with their skills and current assignments."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")
        employees = db.query(Employee).filter(
            Employee.company_id == sim_state.company_id
        ).all()
        results = []
        # NOTE: issues one skills query and one assignment query per employee
        # (N+1 pattern); presumably acceptable at bench-sized headcounts.
        for emp in employees:
            # Skills: per-domain work rate (units per hour)
            skills = db.query(EmployeeSkillRate).filter(
                EmployeeSkillRate.employee_id == emp.id
            ).all()
            skill_map = {
                s.domain.value: float(s.rate_domain_per_hour)
                for s in skills
            }
            # Current active assignments (only tasks in ACTIVE status count)
            active_assignments = (
                db.query(TaskAssignment.task_id)
                .join(Task, Task.id == TaskAssignment.task_id)
                .filter(
                    TaskAssignment.employee_id == emp.id,
                    Task.status == TaskStatus.ACTIVE,
                )
                .all()
            )
            active_task_ids = [str(a.task_id) for a in active_assignments]
            results.append({
                "employee_id": str(emp.id),
                "name": emp.name,
                "salary_cents": emp.salary_cents,
                "work_hours_per_day": float(emp.work_hours_per_day),
                "skills": skill_map,
                "active_task_count": len(active_task_ids),
                "active_task_ids": active_task_ids,
            })
        json_output({
            "count": len(results),
            "employees": results,
        })

View file

@ -0,0 +1,75 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Optional
import typer
from sqlalchemy import and_
from ..db.models.ledger import LedgerCategory, LedgerEntry
from ..db.models.sim_state import SimState
from . import get_db, json_output, error_output
# Typer sub-application mounted as `yc-bench finance ...`.
finance_app = typer.Typer(help="Finance and ledger commands.")
def _parse_date(date_str: str) -> datetime:
try:
dt = datetime.strptime(date_str, "%m/%d/%Y")
return dt.replace(tzinfo=timezone.utc)
except ValueError:
raise typer.BadParameter(f"Invalid date format: {date_str}. Use MM/DD/YYYY.")
@finance_app.command("ledger")
def finance_ledger(
    from_date: Optional[str] = typer.Option(None, "--from", help="Start date MM/DD/YYYY"),
    to_date: Optional[str] = typer.Option(None, "--to", help="End date MM/DD/YYYY"),
    category: Optional[str] = typer.Option(None, "--category", help="Filter by ledger category"),
):
    """View ledger entries with optional date and category filters."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")
        query = db.query(LedgerEntry).filter(
            LedgerEntry.company_id == sim_state.company_id
        )
        if from_date is not None:
            from_dt = _parse_date(from_date)
            query = query.filter(LedgerEntry.occurred_at >= from_dt)
        if to_date is not None:
            # NOTE(review): --to parses to midnight UTC, so entries later on
            # the end date itself are excluded — confirm that is intended.
            to_dt = _parse_date(to_date)
            query = query.filter(LedgerEntry.occurred_at <= to_dt)
        if category is not None:
            try:
                cat = LedgerCategory(category)
            except ValueError:
                # error_output raises typer.Exit, so the filter below is skipped.
                error_output(f"Invalid category: {category}. Valid: {[c.value for c in LedgerCategory]}")
            query = query.filter(LedgerEntry.category == cat)
        # Chronological order, oldest first.
        entries = query.order_by(LedgerEntry.occurred_at.asc()).all()
        total_amount = sum(e.amount_cents for e in entries)
        results = [
            {
                "id": str(e.id),
                "occurred_at": e.occurred_at.isoformat(),
                "category": e.category.value,
                "amount_cents": e.amount_cents,
                "ref_type": e.ref_type,
                "ref_id": str(e.ref_id) if e.ref_id else None,
            }
            for e in entries
        ]
        json_output({
            "count": len(results),
            "total_amount_cents": total_amount,
            "entries": results,
        })

View file

@ -0,0 +1,74 @@
from __future__ import annotations
from typing import Optional
import typer
from ..db.models.company import Domain
from ..db.models.task import Task, TaskRequirement, TaskStatus
from ..config import get_world_config
from . import get_db, json_output, error_output
# Typer sub-application mounted as `yc-bench market ...`.
market_app = typer.Typer(help="Market browsing commands.")
@market_app.command("browse")
def market_browse(
    domain: Optional[Domain] = typer.Option(None, "--domain", help="Filter by requirement domain"),
    required_prestige_lte: Optional[int] = typer.Option(None, "--required-prestige-lte", help="Max required prestige"),
    reward_min_cents: Optional[int] = typer.Option(None, "--reward-min-cents", help="Min reward in cents"),
    limit: Optional[int] = typer.Option(None, "--limit", help="Max results (default from experiment config)"),
    offset: int = typer.Option(0, "--offset", help="Offset for pagination"),
):
    """Browse available tasks on the market."""
    # Default page size comes from the active experiment's world config.
    if limit is None:
        limit = get_world_config().market_browse_default_limit
    with get_db() as db:
        query = db.query(Task).filter(Task.status == TaskStatus.MARKET)
        if required_prestige_lte is not None:
            query = query.filter(Task.required_prestige <= required_prestige_lte)
        if reward_min_cents is not None:
            query = query.filter(Task.reward_funds_cents >= reward_min_cents)
        if domain is not None:
            # Filter tasks that have a requirement in the given domain
            # (subquery against the requirement table).
            query = query.filter(
                Task.id.in_(
                    db.query(TaskRequirement.task_id).filter(
                        TaskRequirement.domain == domain
                    )
                )
            )
        # Total matching rows before pagination is applied.
        total = query.count()
        # Highest-paying tasks first.
        tasks = query.order_by(Task.reward_funds_cents.desc()).offset(offset).limit(limit).all()
        results = []
        for task in tasks:
            reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == task.id).all()
            requirements = [
                {
                    "domain": r.domain.value,
                    "required_qty": float(r.required_qty),
                }
                for r in reqs
            ]
            results.append({
                "task_id": str(task.id),
                "title": task.title,
                "description": task.description,
                "required_prestige": task.required_prestige,
                "reward_funds_cents": task.reward_funds_cents,
                "reward_prestige_delta": float(task.reward_prestige_delta),
                "skill_boost_pct": float(task.skill_boost_pct),
                "requirements": requirements,
            })
        json_output({
            "total": total,
            "offset": offset,
            "limit": limit,
            "tasks": results,
        })

View file

@ -0,0 +1,64 @@
from __future__ import annotations
from datetime import date, datetime
from typing import Optional
import typer
from sqlalchemy import and_
from ..db.models.session import MonthlyMetric
from ..db.models.sim_state import SimState
from . import get_db, json_output, error_output
# Typer sub-application mounted as `yc-bench report ...`.
report_app = typer.Typer(help="Reporting commands.")
def _parse_month(month_str: str) -> date:
"""Parse YYYY-MM into a date (first of month)."""
try:
dt = datetime.strptime(month_str, "%Y-%m")
return dt.date().replace(day=1)
except ValueError:
raise typer.BadParameter(f"Invalid month format: {month_str}. Use YYYY-MM.")
@report_app.command("monthly")
def report_monthly(
    from_month: Optional[str] = typer.Option(None, "--from-month", help="Start month YYYY-MM"),
    to_month: Optional[str] = typer.Option(None, "--to-month", help="End month YYYY-MM"),
):
    """View monthly metrics (revenue, cost, return, ending funds)."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")
        query = db.query(MonthlyMetric).filter(
            MonthlyMetric.company_id == sim_state.company_id
        )
        # Both bounds are inclusive and compare against month_start
        # (first day of each month).
        if from_month is not None:
            from_dt = _parse_month(from_month)
            query = query.filter(MonthlyMetric.month_start >= from_dt)
        if to_month is not None:
            to_dt = _parse_month(to_month)
            query = query.filter(MonthlyMetric.month_start <= to_dt)
        metrics = query.order_by(MonthlyMetric.month_start.asc()).all()
        results = [
            {
                "month_start": m.month_start.isoformat(),
                "revenue_cents": m.revenue_cents,
                "cost_cents": m.cost_cents,
                "return_cents": m.return_cents,
                "ending_funds_cents": m.ending_funds_cents,
            }
            for m in metrics
        ]
        json_output({
            "count": len(results),
            "months": results,
        })

View file

@ -0,0 +1,72 @@
from __future__ import annotations
import typer
from ..db.models.scratchpad import Scratchpad
from ..db.models.sim_state import SimState
from . import get_db, json_output, error_output
# Typer sub-application mounted as `yc-bench scratchpad ...`.
scratchpad_app = typer.Typer(help="Agent scratchpad for persistent notes.")
def _get_or_create(db, company_id) -> Scratchpad:
    """Fetch the company's scratchpad row, creating an empty one if absent."""
    pad = db.query(Scratchpad).filter(Scratchpad.company_id == company_id).one_or_none()
    if pad is not None:
        return pad
    # No row yet: create one with empty content and flush so it gets an id.
    pad = Scratchpad(company_id=company_id, content="")
    db.add(pad)
    db.flush()
    return pad
@scratchpad_app.command("read")
def scratchpad_read():
    """Print the scratchpad's current content as JSON."""
    with get_db() as db:
        state = db.query(SimState).first()
        if state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        pad = _get_or_create(db, state.company_id)
        json_output({"content": pad.content})
@scratchpad_app.command("write")
def scratchpad_write(
    content: str = typer.Option(..., help="Text to write (replaces existing content)."),
):
    """Replace the scratchpad content wholesale."""
    with get_db() as db:
        state = db.query(SimState).first()
        if state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        pad = _get_or_create(db, state.company_id)
        pad.content = content
        json_output({"ok": True, "content": pad.content})
@scratchpad_app.command("append")
def scratchpad_append(
    content: str = typer.Option(..., help="Text to append to existing content."),
):
    """Append text to the scratchpad, separated from prior content by a newline."""
    with get_db() as db:
        state = db.query(SimState).first()
        if state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        pad = _get_or_create(db, state.company_id)
        # Only insert the newline separator when something is already stored.
        pad.content = f"{pad.content}\n{content}" if pad.content else content
        json_output({"ok": True, "content": pad.content})
@scratchpad_app.command("clear")
def scratchpad_clear():
    """Reset the scratchpad to an empty string."""
    with get_db() as db:
        state = db.query(SimState).first()
        if state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        pad = _get_or_create(db, state.company_id)
        pad.content = ""
        json_output({"ok": True})

View file

@ -0,0 +1,151 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Optional
import typer
from ..core.engine import advance_time
from ..core.events import fetch_next_event, insert_event
from ..db.models.company import Company
from ..db.models.event import EventType
from ..db.models.sim_state import SimState
from ..config import get_world_config
from ..services.seed_world import SeedWorldRequest, seed_world_transactional
from . import get_db, json_output, error_output
# Typer sub-application mounted as `yc-bench sim ...`.
sim_app = typer.Typer(help="Simulation initialization commands.")
def _parse_date(date_str: str) -> datetime:
"""Parse MM/DD/YYYY into a timezone-aware datetime at 09:00 UTC."""
try:
dt = datetime.strptime(date_str, "%m/%d/%Y")
return dt.replace(hour=9, minute=0, second=0, tzinfo=timezone.utc)
except ValueError:
raise typer.BadParameter(f"Invalid date format: {date_str}. Use MM/DD/YYYY.")
@sim_app.command("init")
def sim_init(
    seed: int = typer.Option(..., help="RNG seed for deterministic generation"),
    start_date: str = typer.Option(..., "--start-date", help="Start date MM/DD/YYYY"),
    horizon_years: int = typer.Option(3, "--horizon-years", help="Simulation horizon in years"),
    company_name: str = typer.Option(..., "--company-name", help="Company name"),
    employee_count: Optional[int] = typer.Option(None, "--employee-count", help="Number of employees (default from experiment config)"),
    market_task_count: Optional[int] = typer.Option(None, "--market-task-count", help="Number of market tasks (default from experiment config)"),
):
    """Initialize a new simulation: seed world, create company, schedule horizon.

    Fails if a simulation already exists (one simulation per database).
    Emits a JSON summary with the company id, start time, and horizon end.
    """
    # Fill unspecified counts from the active experiment's world config.
    _wc = get_world_config()
    if employee_count is None:
        employee_count = _wc.num_employees
    if market_task_count is None:
        market_task_count = _wc.num_market_tasks
    start_dt = _parse_date(start_date)
    # Compute the horizon end date. A Feb 29 start combined with a non-leap
    # target year would make datetime.replace raise ValueError, so clamp
    # that single case to Feb 28.
    try:
        horizon_end = start_dt.replace(year=start_dt.year + horizon_years)
    except ValueError:
        horizon_end = start_dt.replace(year=start_dt.year + horizon_years, day=28)
    with get_db() as db:
        # Check if a simulation already exists
        existing = db.query(SimState).first()
        if existing is not None:
            error_output("A simulation already exists. Only one simulation per database is supported.")
        req = SeedWorldRequest(
            run_seed=seed,
            company_name=company_name,
            horizon_years=horizon_years,
            employee_count=employee_count,
            market_task_count=market_task_count,
            start_date=start_dt,
        )
        result = seed_world_transactional(db, req)
        # Schedule horizon_end event (deterministic id via insert_event helper)
        insert_event(
            db=db,
            company_id=result.company_id,
            event_type=EventType.HORIZON_END,
            scheduled_at=horizon_end,
            payload={"reason": "horizon_end"},
            dedupe_key="horizon_end",
        )
        # Create sim_state row tracking the clock, seed, and horizon.
        sim_state = SimState(
            company_id=result.company_id,
            sim_time=start_dt,
            run_seed=seed,
            horizon_end=horizon_end,
            replenish_counter=0,
        )
        db.add(sim_state)
        db.flush()
        json_output({
            "simulation_id": str(result.company_id),
            "company_id": str(result.company_id),
            "sim_time": start_dt.isoformat(),
            "horizon_end": horizon_end.isoformat(),
            "company_name": company_name,
            "seed": seed,
        })
@sim_app.command("resume")
def sim_resume():
    """Advance simulation to the next queued event checkpoint and return wake results."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        company = db.query(Company).filter(Company.id == sim_state.company_id).one()
        # Find the next unconsumed event up to the horizon.
        next_event = fetch_next_event(
            db=db,
            company_id=sim_state.company_id,
            up_to=sim_state.horizon_end,
        )
        if next_event is None:
            # Nothing left to process: report the terminal condition (if any)
            # without moving sim_time.
            terminal_reason = None
            bankrupt = company.funds_cents < 0
            horizon_reached = sim_state.sim_time >= sim_state.horizon_end
            if bankrupt:
                terminal_reason = "bankruptcy"
            elif horizon_reached:
                terminal_reason = "horizon_end"
            json_output({
                "ok": True,
                "message": "no_pending_events",
                "old_sim_time": sim_state.sim_time.isoformat(),
                "new_sim_time": sim_state.sim_time.isoformat(),
                "events_processed": 0,
                "payrolls_applied": 0,
                "balance_delta": 0,
                "wake_events": [],
                "bankrupt": bankrupt,
                "horizon_reached": horizon_reached,
                "terminal_reason": terminal_reason,
            })
            return
        # Record the checkpoint type before advancing consumes the event.
        checkpoint_type = next_event.event_type.value
        result = advance_time(
            db=db,
            company_id=sim_state.company_id,
            target_time=next_event.scheduled_at,
        )
        terminal_reason = None
        if result.bankrupt:
            terminal_reason = "bankruptcy"
        elif result.horizon_reached:
            terminal_reason = "horizon_end"
        # Expose every field of the advance_time result directly in the
        # JSON payload, plus the checkpoint metadata.
        payload = result.__dict__.copy()
        payload["ok"] = True
        payload["checkpoint_event_type"] = checkpoint_type
        payload["terminal_reason"] = terminal_reason
        json_output(payload)

View file

@ -0,0 +1,460 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from typing import Optional
from uuid import UUID, uuid4
import typer
from sqlalchemy import func
from ..core.business_time import add_business_hours
from ..core.eta import recalculate_etas
from ..db.models.company import Company, CompanyPrestige
from ..db.models.employee import Employee
from ..db.models.event import SimEvent
from ..db.models.sim_state import SimState
from ..db.models.task import Task, TaskAssignment, TaskRequirement, TaskStatus
from ..services.generate_tasks import generate_replacement_task
from . import get_db, json_output, error_output
def _get_world_cfg():
    """Load WorldConfig from the active experiment (YC_BENCH_EXPERIMENT env var)."""
    # Imported inside the function; presumably to avoid an import cycle
    # with yc_bench.config — TODO confirm.
    from yc_bench.config import get_world_config
    return get_world_config()
# Typer sub-application mounted as `yc-bench task ...`.
task_app = typer.Typer(help="Task management commands.")
def _compute_deadline(accepted_at: datetime, total_required_qty: float, cfg=None) -> datetime:
    """Deadline: cfg.deadline_qty_per_day units per business day, minimum cfg.deadline_min_biz_days days."""
    cfg = cfg if cfg is not None else _get_world_cfg()
    hours_per_day = cfg.workday_end_hour - cfg.workday_start_hour
    # int() truncates, so partial days do not extend the deadline; the
    # configured minimum provides the floor.
    days_for_qty = int(total_required_qty / cfg.deadline_qty_per_day)
    biz_days = max(cfg.deadline_min_biz_days, days_for_qty)
    total_hours = Decimal(str(biz_days)) * Decimal(str(hours_per_day))
    return add_business_hours(accepted_at, total_hours)
@task_app.command("accept")
def task_accept(
    task_id: str = typer.Option(..., "--task-id", help="UUID of the task to accept"),
):
    """Accept a market task: transition to planned, assign to company, generate replacement."""
    try:
        tid = UUID(task_id)
    except ValueError:
        error_output(f"Invalid UUID: {task_id}")
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        task = db.query(Task).filter(Task.id == tid).one_or_none()
        if task is None:
            error_output(f"Task {task_id} not found.")
        if task.status != TaskStatus.MARKET:
            error_output(f"Task {task_id} is not in market status (current: {task.status.value}).")
        # Validate prestige requirement: the company's best domain prestige
        # must meet the task's bar (default 1.0 if no prestige rows exist).
        company_id = sim_state.company_id
        prestige_rows = db.query(CompanyPrestige).filter(
            CompanyPrestige.company_id == company_id
        ).all()
        max_prestige = max((float(p.prestige_level) for p in prestige_rows), default=1.0)
        if task.required_prestige > max_prestige:
            error_output(
                f"Company max prestige ({max_prestige}) does not meet task requirement ({task.required_prestige})."
            )
        # Compute deadline from the total required quantity across domains.
        reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == tid).all()
        total_qty = sum(float(r.required_qty) for r in reqs)
        accepted_at = sim_state.sim_time
        deadline = _compute_deadline(accepted_at, total_qty)
        # Transition task to planned and take ownership.
        task.status = TaskStatus.PLANNED
        task.company_id = company_id
        task.accepted_at = accepted_at
        task.deadline = deadline
        # Generate replacement task so the market size stays constant.
        # Its content is derived deterministically from (run_seed, counter).
        counter = sim_state.replenish_counter
        sim_state.replenish_counter = counter + 1
        replacement = generate_replacement_task(
            run_seed=sim_state.run_seed,
            replenish_counter=counter,
            cfg=_get_world_cfg(),
        )
        # NOTE(review): the replacement row id uses uuid4(), which is NOT
        # derived from run_seed — confirm row ids need not be deterministic.
        replacement_row = Task(
            id=uuid4(),
            company_id=None,
            status=TaskStatus.MARKET,
            title=replacement.title,
            description=replacement.description,
            required_prestige=replacement.required_prestige,
            reward_funds_cents=replacement.reward_funds_cents,
            reward_prestige_delta=replacement.reward_prestige_delta,
            skill_boost_pct=replacement.skill_boost_pct,
            accepted_at=None,
            deadline=None,
            completed_at=None,
            success=None,
            halfway_event_emitted=False,
        )
        db.add(replacement_row)
        for domain, qty in replacement.requirements.items():
            db.add(TaskRequirement(
                task_id=replacement_row.id,
                domain=domain,
                required_qty=qty,
                completed_qty=0,
            ))
        db.flush()
        json_output({
            "task_id": str(task.id),
            "status": task.status.value,
            "accepted_at": accepted_at.isoformat(),
            "deadline": deadline.isoformat(),
            "replacement_task_id": str(replacement_row.id),
        })
@task_app.command("assign")
def task_assign(
    task_id: str = typer.Option(..., "--task-id", help="UUID of the task"),
    employee_id: str = typer.Option(..., "--employee-id", help="UUID of the employee"),
):
    """Assign an employee to a task."""
    try:
        tid = UUID(task_id)
        eid = UUID(employee_id)
    except ValueError:
        error_output("Invalid UUID provided.")
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")
        task = db.query(Task).filter(Task.id == tid).one_or_none()
        if task is None:
            error_output(f"Task {task_id} not found.")
        if task.status not in (TaskStatus.PLANNED, TaskStatus.ACTIVE):
            error_output(f"Task {task_id} must be planned or active to assign (current: {task.status.value}).")
        if task.company_id != sim_state.company_id:
            error_output(f"Task {task_id} does not belong to your company.")
        employee = db.query(Employee).filter(Employee.id == eid).one_or_none()
        if employee is None:
            error_output(f"Employee {employee_id} not found.")
        if employee.company_id != sim_state.company_id:
            error_output(f"Employee {employee_id} does not belong to your company.")
        # Check if already assigned (task/employee pair must be unique)
        existing = db.query(TaskAssignment).filter(
            TaskAssignment.task_id == tid,
            TaskAssignment.employee_id == eid,
        ).one_or_none()
        if existing is not None:
            error_output(f"Employee {employee_id} is already assigned to task {task_id}.")
        assignment = TaskAssignment(
            task_id=tid,
            employee_id=eid,
            assigned_at=sim_state.sim_time,
        )
        db.add(assignment)
        db.flush()
        # Recalculate ETAs for all active tasks sharing this employee;
        # planned tasks are skipped here (ETAs are computed on dispatch).
        if task.status == TaskStatus.ACTIVE:
            emp_assignments = db.query(TaskAssignment).filter(
                TaskAssignment.employee_id == eid
            ).all()
            impacted = set()
            for ea in emp_assignments:
                t = db.query(Task).filter(Task.id == ea.task_id).one_or_none()
                if t and t.status == TaskStatus.ACTIVE:
                    impacted.add(t.id)
            if impacted:
                recalculate_etas(db, sim_state.company_id, sim_state.sim_time, impacted, half_threshold=_get_world_cfg().task_half_threshold)
        # Return current assignment list for the task
        assignments = db.query(TaskAssignment).filter(TaskAssignment.task_id == tid).all()
        assignment_list = [
            {
                "employee_id": str(a.employee_id),
                "assigned_at": a.assigned_at.isoformat(),
            }
            for a in assignments
        ]
        json_output({
            "task_id": str(task.id),
            "status": task.status.value,
            "assignments": assignment_list,
        })
@task_app.command("dispatch")
def task_dispatch(
    task_id: str = typer.Option(..., "--task-id", help="UUID of the task to dispatch"),
):
    """Dispatch a planned task to active status."""
    try:
        tid = UUID(task_id)
    except ValueError:
        error_output(f"Invalid UUID: {task_id}")
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")
        task = db.query(Task).filter(Task.id == tid).one_or_none()
        if task is None:
            error_output(f"Task {task_id} not found.")
        if task.status != TaskStatus.PLANNED:
            error_output(f"Task {task_id} must be planned to dispatch (current: {task.status.value}).")
        if task.company_id != sim_state.company_id:
            error_output(f"Task {task_id} does not belong to your company.")
        # Validate at least one assignment exists before activating
        assignment_count = db.query(func.count(TaskAssignment.employee_id)).filter(
            TaskAssignment.task_id == tid
        ).scalar() or 0
        if assignment_count == 0:
            error_output(f"Task {task_id} has no assignments. Assign employees before dispatching.")
        # Transition to active
        task.status = TaskStatus.ACTIVE
        db.flush()
        # Recalculate ETAs for this task and other active tasks that share assigned employees.
        impacted = {tid}
        assigned = db.query(TaskAssignment).filter(TaskAssignment.task_id == tid).all()
        for a in assigned:
            peer_assignments = db.query(TaskAssignment).filter(
                TaskAssignment.employee_id == a.employee_id
            ).all()
            for pa in peer_assignments:
                if pa.task_id == tid:
                    continue
                peer_task = db.query(Task).filter(Task.id == pa.task_id).one_or_none()
                if peer_task and peer_task.status == TaskStatus.ACTIVE:
                    impacted.add(peer_task.id)
        recalculate_etas(db, sim_state.company_id, sim_state.sim_time, impacted, half_threshold=_get_world_cfg().task_half_threshold)
        json_output({
            "task_id": str(task.id),
            "status": task.status.value,
            "assignment_count": assignment_count,
        })
@task_app.command("list")
def task_list(
    status: Optional[str] = typer.Option(None, "--status", help="Filter by task status"),
):
    """List tasks owned by the company."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")
        query = db.query(Task).filter(Task.company_id == sim_state.company_id)
        if status is not None:
            try:
                ts = TaskStatus(status)
            except ValueError:
                # error_output raises typer.Exit, so the filter below is skipped.
                error_output(f"Invalid status: {status}. Valid: {[s.value for s in TaskStatus]}")
            query = query.filter(Task.status == ts)
        # Most recently accepted first; rows with NULL accepted_at sort last.
        tasks = query.order_by(Task.accepted_at.desc().nulls_last()).all()
        results = []
        for task in tasks:
            # Compute progress % across all requirement rows
            reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == task.id).all()
            total_required = sum(float(r.required_qty) for r in reqs)
            total_completed = sum(float(r.completed_qty) for r in reqs)
            progress_pct = (total_completed / total_required * 100) if total_required > 0 else 0.0
            # Deadline risk: flagged only once an active task is already past
            # its deadline. NOTE(review): confirm whether "at risk" should
            # also cover tasks merely projected to miss the deadline.
            at_risk = False
            if task.deadline and task.status == TaskStatus.ACTIVE:
                if sim_state.sim_time > task.deadline:
                    at_risk = True
            results.append({
                "task_id": str(task.id),
                "title": task.title,
                "status": task.status.value,
                "progress_pct": round(progress_pct, 2),
                "deadline": task.deadline.isoformat() if task.deadline else None,
                "at_risk": at_risk,
            })
        json_output({
            "count": len(results),
            "tasks": results,
        })
@task_app.command("inspect")
def task_inspect(
    task_id: str = typer.Option(..., "--task-id", help="UUID of the task to inspect"),
):
    """Inspect detailed task information."""
    try:
        tid = UUID(task_id)
    except ValueError:
        error_output(f"Invalid UUID: {task_id}")
    with get_db() as db:
        # Unlike the other task commands, there is no company-ownership
        # check here: any task (including market tasks) can be inspected.
        task = db.query(Task).filter(Task.id == tid).one_or_none()
        if task is None:
            error_output(f"Task {task_id} not found.")
        # Requirements, with per-domain completion breakdown
        reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == tid).all()
        requirements = []
        for r in reqs:
            requirements.append({
                "domain": r.domain.value,
                "required_qty": float(r.required_qty),
                "completed_qty": float(r.completed_qty),
                "remaining_qty": float(r.required_qty - r.completed_qty),
            })
        # Assignments with employee info (name resolved per assignment)
        assignments_raw = db.query(TaskAssignment).filter(TaskAssignment.task_id == tid).all()
        assignments = []
        for a in assignments_raw:
            emp = db.query(Employee).filter(Employee.id == a.employee_id).one_or_none()
            assignments.append({
                "employee_id": str(a.employee_id),
                "employee_name": emp.name if emp else "unknown",
                "assigned_at": a.assigned_at.isoformat(),
            })
        # Overall progress across every requirement row
        total_required = sum(float(r.required_qty) for r in reqs)
        total_completed = sum(float(r.completed_qty) for r in reqs)
        progress_pct = (total_completed / total_required * 100) if total_required > 0 else 0.0
        json_output({
            "task_id": str(task.id),
            "title": task.title,
            "description": task.description,
            "status": task.status.value,
            "required_prestige": task.required_prestige,
            "reward_funds_cents": task.reward_funds_cents,
            "reward_prestige_delta": float(task.reward_prestige_delta),
            "skill_boost_pct": float(task.skill_boost_pct),
            "accepted_at": task.accepted_at.isoformat() if task.accepted_at else None,
            "deadline": task.deadline.isoformat() if task.deadline else None,
            "completed_at": task.completed_at.isoformat() if task.completed_at else None,
            "success": task.success,
            "progress_pct": round(progress_pct, 2),
            "requirements": requirements,
            "assignments": assignments,
        })
@task_app.command("cancel")
def task_cancel(
    task_id: str = typer.Option(..., "--task-id", help="UUID of the task to cancel"),
    reason: str = typer.Option(..., "--reason", help="Reason for cancellation"),
):
    """Cancel a task and apply prestige penalty.

    Marks the task CANCELLED, subtracts a prestige penalty in every domain the
    task required, consumes the task's pending sim events, and recalculates
    ETAs for active tasks that share the freed employees. Emits a JSON summary
    of the prestige changes and bankruptcy status.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        error_output(f"Invalid UUID: {task_id}")
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")
        task = db.query(Task).filter(Task.id == tid).one_or_none()
        if task is None:
            error_output(f"Task {task_id} not found.")
        if task.status not in (TaskStatus.PLANNED, TaskStatus.ACTIVE):
            error_output(f"Task {task_id} cannot be cancelled (current: {task.status.value}).")
        if task.company_id != sim_state.company_id:
            error_output(f"Task {task_id} does not belong to your company.")
        # Apply prestige penalty: penalty_cancel_multiplier * reward_prestige_delta across task's required domains
        _cfg = _get_world_cfg()
        cancel_penalty = Decimal(str(_cfg.penalty_cancel_multiplier)) * task.reward_prestige_delta
        reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == tid).all()
        penalties_applied = {}
        for req in reqs:
            prestige = db.query(CompanyPrestige).filter(
                CompanyPrestige.company_id == sim_state.company_id,
                CompanyPrestige.domain == req.domain,
            ).one_or_none()
            if prestige is not None:
                old_val = prestige.prestige_level
                # Prestige never drops below the configured floor.
                new_val = max(Decimal(str(_cfg.prestige_min)), prestige.prestige_level - cancel_penalty)
                prestige.prestige_level = new_val
                penalties_applied[req.domain.value] = {
                    "old": float(old_val),
                    "new": float(new_val),
                    "delta": float(old_val - new_val),
                }
        # Set status to cancelled
        task.status = TaskStatus.CANCELLED
        # Drop pending events for this task
        pending_events = db.query(SimEvent).filter(
            SimEvent.company_id == sim_state.company_id,
            SimEvent.consumed == False,  # noqa: E712 — SQLAlchemy filters need `==`, not `is`
            SimEvent.payload["task_id"].astext == str(tid),
        ).all()
        for ev in pending_events:
            ev.consumed = True
        # Recalculate ETAs for tasks sharing freed employees
        cancelled_assignments = db.query(TaskAssignment).filter(
            TaskAssignment.task_id == tid
        ).all()
        freed_emp_ids = {a.employee_id for a in cancelled_assignments}
        impacted = set()
        for emp_id in freed_emp_ids:
            emp_assignments = db.query(TaskAssignment).filter(
                TaskAssignment.employee_id == emp_id
            ).all()
            for ea in emp_assignments:
                if ea.task_id != tid:
                    t = db.query(Task).filter(Task.id == ea.task_id).one_or_none()
                    if t and t.status == TaskStatus.ACTIVE:
                        impacted.add(t.id)
        if impacted:
            # Reuse the config loaded above rather than re-reading it (was a second _get_world_cfg() call).
            recalculate_etas(db, sim_state.company_id, sim_state.sim_time, impacted, half_threshold=_cfg.task_half_threshold)
        # Bankruptcy check
        company = db.query(Company).filter(Company.id == sim_state.company_id).one()
        bankrupt = company.funds_cents < 0
        db.flush()
        json_output({
            "task_id": str(task.id),
            "status": task.status.value,
            "reason": reason,
            "cancel_penalty_per_domain": float(cancel_penalty),
            "prestige_changes": penalties_applied,
            "bankrupt": bankrupt,
        })

44
src/yc_bench/config.py Normal file
View file

@ -0,0 +1,44 @@
from __future__ import annotations
from dataclasses import dataclass
import os
from .agent.runtime.schemas import RuntimeSettings
@dataclass(frozen=True)
class RuntimeConfig:
    """Immutable LLM runtime settings resolved from YC_BENCH_* environment variables."""

    model: str  # LLM model identifier (YC_BENCH_MODEL)
    temperature: float  # sampling temperature, must be >= 0 (YC_BENCH_TEMPERATURE)
    top_p: float  # nucleus sampling cutoff in (0, 1] (YC_BENCH_TOP_P)
    history_keep_rounds: int = 20  # conversation rounds kept in context, >= 1 (YC_BENCH_HISTORY_KEEP_ROUNDS)
def load_runtime_config():
    """Build a RuntimeConfig from YC_BENCH_* environment variables.

    Raises ValueError when a numeric override is out of range.
    """
    chosen_model = os.getenv("YC_BENCH_MODEL", "gpt-5.2")
    temp = float(os.getenv("YC_BENCH_TEMPERATURE", "0"))
    nucleus = float(os.getenv("YC_BENCH_TOP_P", "1"))
    keep_rounds = int(os.getenv("YC_BENCH_HISTORY_KEEP_ROUNDS", "20"))
    # Validate in the same order the variables are read.
    if temp < 0:
        raise ValueError("YC_BENCH_TEMPERATURE must be >= 0")
    if not 0 < nucleus <= 1:
        raise ValueError("YC_BENCH_TOP_P must be in (0, 1]")
    if keep_rounds < 1:
        raise ValueError("YC_BENCH_HISTORY_KEEP_ROUNDS must be >= 1")
    return RuntimeConfig(
        model=chosen_model,
        temperature=temp,
        top_p=nucleus,
        history_keep_rounds=keep_rounds,
    )
def to_runtime_settings(cfg):
    """Convert a RuntimeConfig into the agent runtime's RuntimeSettings."""
    settings_kwargs = {
        "model": cfg.model,
        "temperature": cfg.temperature,
        "top_p": cfg.top_p,
        "history_keep_rounds": cfg.history_keep_rounds,
    }
    return RuntimeSettings(**settings_kwargs)
__all__ = ["RuntimeConfig", "load_runtime_config", "to_runtime_settings"]

View file

@ -0,0 +1,27 @@
import os
from .loader import load_config
from .schema import ExperimentConfig, AgentConfig, LoopConfig, SimConfig, WorldConfig, SalaryTierConfig
def get_world_config() -> WorldConfig:
    """Load WorldConfig from the active experiment (YC_BENCH_EXPERIMENT env var, default: 'default').

    Falls back to default WorldConfig if config loading fails (e.g. outside a benchmark run).
    """
    experiment = os.environ.get("YC_BENCH_EXPERIMENT", "default")
    try:
        cfg = load_config(experiment)
    except Exception:
        # Best-effort fallback: outside a benchmark run the preset may be unavailable.
        return WorldConfig()
    return cfg.world
__all__ = [
"load_config",
"get_world_config",
"ExperimentConfig",
"AgentConfig",
"LoopConfig",
"SimConfig",
"WorldConfig",
"SalaryTierConfig",
]

View file

@ -0,0 +1,124 @@
"""Load ExperimentConfig from a built-in preset name or a user TOML file path.
Usage::
from yc_bench.config import load_config
cfg = load_config("default") # built-in preset
cfg = load_config("fast_test") # built-in preset
cfg = load_config("./my_run.toml") # user file, absolute or relative path
cfg = load_config("/abs/path/run.toml") # user file, absolute path
User TOML files can inherit from a preset with ``extends``::
# my_run.toml
extends = "default"
[agent]
model = "openrouter/anthropic/claude-3.5-sonnet"
[world]
num_employees = 15
Environment variable overrides (applied last, highest priority)::
YC_BENCH_MODEL override agent.model
YC_BENCH_TEMPERATURE override agent.temperature
YC_BENCH_TOP_P override agent.top_p
YC_BENCH_HISTORY_KEEP_ROUNDS override agent.history_keep_rounds
YC_BENCH_AUTO_ADVANCE_TURNS override loop.auto_advance_after_turns
"""
from __future__ import annotations
import os
import tomllib
from importlib.resources import files
from pathlib import Path
from .schema import ExperimentConfig
def load_config(path_or_name: str = "default") -> ExperimentConfig:
    """Load a named preset or a TOML file path, apply env-var overrides."""
    data = _read_raw(path_or_name)
    # `extends = "preset_name"` inherits from a built-in preset; the child wins on conflicts.
    base_name = data.pop("extends", None)
    if base_name is not None:
        data = _deep_merge(_read_preset(base_name), data)
    validated = ExperimentConfig.model_validate(data)
    return _apply_env_overrides(validated)
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _read_raw(path_or_name: str) -> dict:
    """Dispatch to file loading (.toml suffix or a path separator) or preset lookup."""
    candidate = Path(path_or_name)
    is_file_ref = (
        candidate.suffix == ".toml"
        or os.sep in path_or_name
        or "/" in path_or_name
    )
    if is_file_ref:
        if not candidate.exists():
            raise FileNotFoundError(f"Config file not found: {candidate.resolve()}")
        return _read_file(candidate)
    return _read_preset(path_or_name)
def _read_preset(name: str) -> dict:
    """Parse a built-in preset TOML shipped in ``yc_bench.config.presets``.

    Raises ValueError listing the available presets when *name* does not
    resolve to a packaged ``.toml`` resource.
    """
    try:
        resource = files("yc_bench.config.presets") / f"{name}.toml"
        with resource.open("rb") as f:
            return tomllib.load(f)
    except (FileNotFoundError, TypeError) as exc:
        # NOTE(review): TypeError was caught by the original too — presumably for
        # importlib.resources backends that reject malformed names; confirm.
        available = [
            p.stem  # "default.toml" -> "default" (stem, unlike str.replace, can't mangle mid-name matches)
            for p in Path(__file__).parent.joinpath("presets").glob("*.toml")
        ]
        raise ValueError(
            f"Unknown preset '{name}'. Available presets: {available}. "
            "Pass a file path ending in .toml for a custom config."
        ) from exc
def _read_file(path: Path) -> dict:
with open(path, "rb") as f:
return tomllib.load(f)
def _deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge override on top of base. Override wins on conflicts."""
    merged = dict(base)
    for key, incoming in override.items():
        existing = merged.get(key)
        # Only dict-vs-dict collisions merge recursively; anything else is replaced.
        if isinstance(existing, dict) and isinstance(incoming, dict):
            merged[key] = _deep_merge(existing, incoming)
        else:
            merged[key] = incoming
    return merged
def _apply_env_overrides(cfg: ExperimentConfig) -> ExperimentConfig:
    """YC_BENCH_* env vars always win, regardless of config file."""
    env = os.environ.get
    agent_patch: dict = {}
    loop_patch: dict = {}
    raw = env("YC_BENCH_MODEL")
    if raw:
        agent_patch["model"] = raw
    raw = env("YC_BENCH_TEMPERATURE")
    if raw:
        agent_patch["temperature"] = float(raw)
    raw = env("YC_BENCH_TOP_P")
    if raw:
        agent_patch["top_p"] = float(raw)
    raw = env("YC_BENCH_HISTORY_KEEP_ROUNDS")
    if raw:
        agent_patch["history_keep_rounds"] = int(raw)
    raw = env("YC_BENCH_AUTO_ADVANCE_TURNS")
    if raw:
        loop_patch["auto_advance_after_turns"] = int(raw)
    # Copy-on-write: only rebuild the sections that actually changed.
    if agent_patch:
        cfg = cfg.model_copy(update={"agent": cfg.agent.model_copy(update=agent_patch)})
    if loop_patch:
        cfg = cfg.model_copy(update={"loop": cfg.loop.model_copy(update=loop_patch)})
    return cfg
__all__ = ["load_config"]

View file

View file

@ -0,0 +1,76 @@
# challenge — 3-year benchmark designed for differentiated model behavior
#
# Design goals:
# - Models that focus employees on tasks complete on time, build prestige,
# and can reach the 3-year horizon.
# - Models that spread employees or forget to dispatch go bankrupt.
# - Wide outcome variance across models: some die early, some survive.
#
# Economics (seed=1, 5 employees):
# Monthly payroll ≈ $32K. Starting runway ≈ 7.8 months.
# Mode task: 2 domains × 700 units = 1400 total.
# Deadline: max(7, 1400/200) = 7 business days (minimum).
# All 5 focused on one task, split ~2+3: domain1 3×5.1×9=137.7 units/day →
# 700/137.7 = 5.1 days; domain2 2×5.1×9=91.8 → 700/91.8 = 7.6 days > 7 FAIL.
# → Need ≥3 employees per domain: put 3 on domain1 and 3 on domain2 (count ×2)
# or accept single-domain tasks where all 5 focus.
# When done right: complete 3 tasks/month × avg $30K = $90K > $32K payroll ✓
# Late game: prestige-5 tasks pay $30K × 3.2× = $96K each → strong growth.
extends = "default"
name = "challenge"
description = "3-year benchmark calibrated for differentiated model behavior. ~7.8 months starting runway. Strategy and focus determine survival."
# ---------------------------------------------------------------------------
# Simulation
# ---------------------------------------------------------------------------
[sim]
horizon_years = 3
# ---------------------------------------------------------------------------
# Agent loop
# ---------------------------------------------------------------------------
[loop]
# Force a sim advance if the agent goes 5 consecutive turns without one.
auto_advance_after_turns = 5
# 500 turns covers 3 years for an efficient agent:
# advancing every 3 turns → 167 advances × ~14 avg sim days = 2338 days = 6.4 yrs.
# forced advances only → 100 advances × ~14 days = 1400 days = 3.8 yrs.
max_turns = 500
# ---------------------------------------------------------------------------
# World — 5 employees keeps payroll manageable (~$32K/month).
# 200 market tasks gives a deep enough pool to find tasks in any prestige tier.
# ---------------------------------------------------------------------------
[world]
num_employees = 5
num_market_tasks = 200
deadline_qty_per_day = 200.0
# ---------------------------------------------------------------------------
# Task size: mode=700 (smaller than hardened default 1400) so tasks complete
# in ~5-10 sim days when employees are focused. This creates frequent reward
# events visible in the funds curve, and makes payroll crises recoverable.
# ---------------------------------------------------------------------------
[world.dist.required_qty]
type = "triangular"
low = 200
high = 2000
mode = 700
# ---------------------------------------------------------------------------
# Prestige: mode=3 ensures a healthy supply of prestige-1 and prestige-2 tasks
# at the start. Agents that specialise unlock 48 tasks paying 25× more.
# ---------------------------------------------------------------------------
[world.dist.required_prestige]
type = "triangular"
low = 1
high = 8
mode = 3

View file

@ -0,0 +1,180 @@
# YC-Bench default preset — hardened 3-year benchmark
#
# This is the canonical benchmark configuration. Every parameter is listed
# explicitly so this file serves as complete documentation of what is tunable.
# Copy this file and override only what you need.
# experiment identity (top-level fields)
name = "default"
description = "Hardened 3-year benchmark. Most tasks require prestige 3-5, 2-domain work, tight deadlines, costly cancellations."
# ---------------------------------------------------------------------------
# LLM agent
# ---------------------------------------------------------------------------
[agent]
model = "openrouter/z-ai/glm-5"
temperature = 0.0
top_p = 1.0
request_timeout_seconds = 300.0
retry_max_attempts = 3
retry_backoff_seconds = 1.0
# Conversation rounds retained in context before each API call.
# Older rounds are dropped proactively. Scratchpad survives truncation.
history_keep_rounds = 20
# Optional system prompt override. Omit to use the built-in prompt.
# system_prompt = "..."
# ---------------------------------------------------------------------------
# Agent loop
# ---------------------------------------------------------------------------
[loop]
# Consecutive turns without `sim resume` before the loop forces a time-advance.
auto_advance_after_turns = 10
# Hard cap on total turns. Comment out or set to a large number for unlimited.
# max_turns = 500
# ---------------------------------------------------------------------------
# Simulation
# ---------------------------------------------------------------------------
[sim]
start_date = "2025-01-01" # ISO 8601
horizon_years = 3
company_name = "BenchCo"
# ---------------------------------------------------------------------------
# World generation — scalar parameters
# ---------------------------------------------------------------------------
[world]
num_employees = 10
initial_funds_cents = 25_000_000 # $250,000
initial_prestige_level = 1.0
work_hours_per_day = 9.0
num_market_tasks = 500
market_browse_default_limit = 50
# Salary bump per completed task — each assigned employee gets this raise.
# Compounds over time, accelerating payroll pressure.
salary_bump_pct = 0.01
# --- Prestige mechanics ---
prestige_max = 10.0
prestige_min = 1.0
# Cancelling (2x) is worse than failing (1.4x) at the margin — accepting
# is a real commitment, not a free option.
penalty_fail_multiplier = 1.4 # hardened: was 0.8
penalty_cancel_multiplier = 2.0 # hardened: was 1.2
# Extra reward fraction per prestige level above 1.
# At 0.55: a prestige-8 task pays ~4.85x more than a prestige-1 task.
reward_prestige_scale = 0.55 # hardened: was 0.3
# --- Deadline ---
# 320 units/day is ~60% tighter than the base 200 — need real employee-hours.
deadline_qty_per_day = 320.0 # hardened: was 200.0
deadline_min_biz_days = 7
# --- Progress milestone ---
task_half_threshold = 0.5
# --- Business hours ---
workday_start_hour = 9
workday_end_hour = 18
# ---------------------------------------------------------------------------
# World distributions
#
# Each [world.dist.<name>] section specifies the distribution used for one
# random quantity. Change `type` to switch families; tune the parameters to
# reshape the distribution.
#
# Available types: triangular, beta, normal, uniform, constant
# ---------------------------------------------------------------------------
# Prestige level required to accept a task (cast to int after sampling).
# mode=4 means most tasks need prestige 3-5 — the early market is thin,
# forcing agents to choose which domains to specialise in.
[world.dist.required_prestige]
type = "triangular"
low = 1
high = 10
mode = 4 # hardened: base default is mode=1
# Base reward paid on task completion, in cents (scaled further by prestige).
# Higher-prestige tasks automatically pay more via reward_prestige_scale.
[world.dist.reward_funds_cents]
type = "triangular"
low = 500_000 # $5,000
high = 10_000_000 # $100,000
mode = 3_000_000 # $30,000
# Number of domains each task requires work in (cast to int after sampling).
# mode=2: most tasks need 2 domains — single-specialist dominance gone.
[world.dist.domain_count]
type = "triangular"
low = 1
high = 3
mode = 2 # hardened: base default is mode=1
# Work units required per domain per task (cast to int after sampling).
# No trivially-small tasks: every task requires sustained employee-hours.
[world.dist.required_qty]
type = "triangular"
low = 500 # hardened: base default is 200
high = 3000
mode = 1400 # hardened: base default is 800
# Prestige delta awarded per required domain on task success.
# Left-skewed beta: most completions give small gains; occasional large jumps.
[world.dist.reward_prestige_delta]
type = "beta"
alpha = 1.2
beta = 2.8
scale = 2.0
low = 0.0
high = 2.0
# Skill rate boost applied to each assigned employee on task success.
# Expressed as a fraction of current rate: 0.12 → average 12% boost.
[world.dist.skill_boost]
type = "normal"
mean = 0.12
stdev = 0.06
low = 0.01
high = 0.40
# ---------------------------------------------------------------------------
# Salary tiers — all three share values must sum to exactly 1.0
# ---------------------------------------------------------------------------
[world.salary_junior]
name = "junior"
share = 0.50
min_cents = 200_000 # $2,000/month
max_cents = 400_000 # $4,000/month
rate_min = 1.0 # units/hour
rate_max = 6.5
[world.salary_mid]
name = "mid"
share = 0.35
min_cents = 600_000 # $6,000/month
max_cents = 800_000 # $8,000/month
rate_min = 3.5
rate_max = 8.5
[world.salary_senior]
name = "senior"
share = 0.15
min_cents = 1_000_000 # $10,000/month
max_cents = 1_500_000 # $15,000/month
rate_min = 5.5
rate_max = 10.0

View file

@ -0,0 +1,28 @@
# fast_test — quick 1-year run for iteration and debugging
#
# Changes from default:
# - 1-year horizon instead of 3
# - 5 employees, 100 market tasks (faster seeding)
# - auto-advance after 5 idle turns instead of 10
# - 50-turn hard cap (prevents runaway debug runs)
extends = "default"
name = "fast_test"
description = "1-year horizon with minimal world for fast iteration."
[sim]
horizon_years = 1
[loop]
auto_advance_after_turns = 5
max_turns = 50
[world]
num_employees = 5
num_market_tasks = 100
# With 5 employees (~220 units/day team throughput), the hardened default of
# 320 is physically impossible for most tasks. 200 restores parity so a
# focused agent can actually win tasks on time.
deadline_qty_per_day = 200

View file

@ -0,0 +1,28 @@
# high_reward — inflated reward pool to test prestige ceiling behaviour
#
# Changes from default:
# - Task reward funds 2x higher (mode: $60k, max: $200k)
# - Higher prestige delta rewards (up to 3.0 instead of 2.0)
# - Larger market pool (1000 tasks)
extends = "default"
name = "high_reward"
description = "2x reward scaling to test prestige ceiling and late-game behaviour."
[world]
num_market_tasks = 1000
[world.dist.reward_funds_cents]
type = "triangular"
low = 1_000_000 # $10,000
high = 20_000_000 # $200,000
mode = 6_000_000 # $60,000
[world.dist.reward_prestige_delta]
type = "beta"
alpha = 1.2
beta = 2.8
scale = 2.0
low = 0.0
high = 3.0 # raised ceiling: up to 3.0 instead of 2.0

View file

@ -0,0 +1,104 @@
"""Distribution specs and sampling.
Every random draw in world generation is described by a DistSpec. Callers
pass the spec from WorldConfig — changing distribution families or parameters
requires only a TOML edit, no code change.
Supported families
------------------
triangular low, high, mode random.triangular
beta alpha, beta, scale, low, high scale × Beta(α,β), clamped
normal mean, stdev, low, high gauss, clamped
uniform low, high random.uniform
constant value always returns value (useful for ablations)
"""
from __future__ import annotations
import random
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Distribution spec models (one per family)
# ---------------------------------------------------------------------------
class TriangularDist(BaseModel):
    """Triangular distribution on [low, high] with its peak at mode."""

    type: Literal["triangular"] = "triangular"
    low: float
    high: float
    mode: float
class BetaDist(BaseModel):
    """scale × Beta(alpha, beta), clamped to [low, high] when sampled."""

    type: Literal["beta"] = "beta"
    alpha: float
    beta: float
    scale: float = 1.0
    low: float = 0.0
    high: float = 1.0
class NormalDist(BaseModel):
    """Gaussian(mean, stdev), clamped to [low, high] when sampled."""

    type: Literal["normal"] = "normal"
    mean: float
    stdev: float
    low: float
    high: float
class UniformDist(BaseModel):
    """Uniform distribution on [low, high]."""

    type: Literal["uniform"] = "uniform"
    low: float
    high: float
class ConstantDist(BaseModel):
    """Degenerate distribution: always returns `value` (useful for ablations)."""

    type: Literal["constant"] = "constant"
    value: float
# Discriminated union — Pydantic picks the right model from the "type" field
DistSpec = Annotated[
Union[TriangularDist, BetaDist, NormalDist, UniformDist, ConstantDist],
Field(discriminator="type"),
]
# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------
def sample_from_spec(rng: random.Random, spec: DistSpec) -> float:
    """Draw one sample from the given distribution spec.

    Triangular draws are clamped; beta and normal draws are clamped and
    rounded to 4 decimal places; uniform and constant pass through.
    """
    # The spec classes are disjoint, so branch order does not matter.
    if isinstance(spec, ConstantDist):
        return spec.value
    if isinstance(spec, UniformDist):
        return rng.uniform(spec.low, spec.high)
    if isinstance(spec, TriangularDist):
        draw = rng.triangular(spec.low, spec.high, spec.mode)
        return max(spec.low, min(spec.high, draw))
    if isinstance(spec, BetaDist):
        draw = spec.scale * rng.betavariate(spec.alpha, spec.beta)
        return round(max(spec.low, min(spec.high, draw)), 4)
    if isinstance(spec, NormalDist):
        draw = rng.gauss(spec.mean, spec.stdev)
        return round(max(spec.low, min(spec.high, draw)), 4)
    raise TypeError(f"Unknown DistSpec type: {type(spec)}")
__all__ = [
"BetaDist",
"ConstantDist",
"DistSpec",
"NormalDist",
"TriangularDist",
"UniformDist",
"sample_from_spec",
]

View file

@ -0,0 +1,194 @@
"""Pydantic models for all experiment configuration.
Every tunable parameter lives here. TOML files are validated against these
models — Pydantic catches typos and type errors at load time.
"""
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator
from .sampling import BetaDist, ConstantDist, NormalDist, TriangularDist, UniformDist, DistSpec # noqa: F401
# ---------------------------------------------------------------------------
# Salary tier
# ---------------------------------------------------------------------------
class SalaryTierConfig(BaseModel):
    """One salary band: its workforce share, monthly pay range, and skill-rate range."""

    name: str
    share: float  # fraction of employees in this tier (all tiers must sum to 1.0)
    min_cents: int  # minimum monthly salary in cents
    max_cents: int  # maximum monthly salary in cents
    rate_min: float  # minimum skill rate (units/hour)
    rate_max: float  # maximum skill rate (units/hour)
# ---------------------------------------------------------------------------
# World distributions
#
# Each field names a random quantity in world generation and specifies which
# distribution family + parameters to use. Changing `type` switches families;
# changing parameters tunes the shape. See config/sampling.py for all families.
# ---------------------------------------------------------------------------
class WorldDists(BaseModel):
    """Distribution specs for every random quantity drawn during world generation."""

    # Prestige level required to accept a task (result cast to int).
    # Any DistSpec family works — e.g. constant for ablations, uniform for flat sampling.
    required_prestige: DistSpec = Field(
        default_factory=lambda: TriangularDist(low=1, high=10, mode=1)
    )
    # Base reward paid on task completion, in cents (result cast to int).
    reward_funds_cents: DistSpec = Field(
        default_factory=lambda: TriangularDist(low=500_000, high=10_000_000, mode=3_000_000)
    )
    # Number of domains required per task (result cast to int).
    domain_count: DistSpec = Field(
        default_factory=lambda: TriangularDist(low=1, high=3, mode=1)
    )
    # Work units required per domain (result cast to int).
    required_qty: DistSpec = Field(
        default_factory=lambda: TriangularDist(low=200, high=3000, mode=800)
    )
    # Prestige delta awarded on task success.
    reward_prestige_delta: DistSpec = Field(
        default_factory=lambda: BetaDist(alpha=1.2, beta=2.8, scale=2.0, low=0.0, high=2.0)
    )
    # Skill rate boost fraction applied to each assigned employee on task success.
    skill_boost: DistSpec = Field(
        default_factory=lambda: NormalDist(mean=0.12, stdev=0.06, low=0.01, high=0.40)
    )
# ---------------------------------------------------------------------------
# Agent / LLM
# ---------------------------------------------------------------------------
class AgentConfig(BaseModel):
    """LLM agent settings: model selection, sampling, retries, and context size."""

    model: str = "openrouter/z-ai/glm-5"
    temperature: float = 0.0
    top_p: float = 1.0
    request_timeout_seconds: float = 300.0
    retry_max_attempts: int = 3
    retry_backoff_seconds: float = 1.0
    # Conversation rounds kept in context before each API call; older rounds dropped.
    history_keep_rounds: int = 20
    # Optional system prompt override. None = use default from agent/prompt.py
    system_prompt: str | None = None
# ---------------------------------------------------------------------------
# Agent loop
# ---------------------------------------------------------------------------
class LoopConfig(BaseModel):
    """Agent-loop pacing: forced time-advances and the overall turn budget."""

    # Consecutive turns without `sim resume` before the loop forces a time-advance.
    auto_advance_after_turns: int = 10
    # Hard cap on total turns. null = unlimited.
    max_turns: int | None = None
# ---------------------------------------------------------------------------
# Simulation
# ---------------------------------------------------------------------------
class SimConfig(BaseModel):
    """Simulation clock parameters and company identity."""

    start_date: str = "2025-01-01"  # ISO 8601 (YYYY-MM-DD)
    horizon_years: int = 3
    company_name: str = "BenchCo"
# ---------------------------------------------------------------------------
# World generation
# ---------------------------------------------------------------------------
class WorldConfig(BaseModel):
    """All world-generation parameters: workforce, market, prestige, deadlines,
    distributions, and salary tiers. Tier shares are validated to sum to 1.0."""

    # --- Workforce ---
    num_employees: int = 10
    initial_funds_cents: int = 25_000_000  # $250,000
    initial_prestige_level: float = 1.0
    work_hours_per_day: float = 9.0
    # --- Market ---
    num_market_tasks: int = 500
    market_browse_default_limit: int = 50
    # --- Salary bump on task completion ---
    salary_bump_pct: float = 0.01  # 1% raise per assigned employee per completed task
    # --- Prestige mechanics ---
    prestige_max: float = 10.0
    prestige_min: float = 1.0
    penalty_fail_multiplier: float = 0.8
    penalty_cancel_multiplier: float = 1.2
    # Extra reward fraction per prestige level above 1.
    # At 0.55: prestige-8 tasks pay ~4.85x more than prestige-1.
    reward_prestige_scale: float = 0.3
    # --- Deadline computation ---
    deadline_qty_per_day: float = 200.0  # work units assumed completable per business day
    deadline_min_biz_days: int = 7
    # --- Progress milestone ---
    task_half_threshold: float = 0.5
    # --- Business hours ---
    workday_start_hour: int = 9
    workday_end_hour: int = 18
    # --- Distributions (shape of random draws during world generation) ---
    dist: WorldDists = Field(default_factory=WorldDists)
    # --- Salary tiers ---
    salary_junior: SalaryTierConfig = Field(
        default_factory=lambda: SalaryTierConfig(
            name="junior", share=0.50,
            min_cents=200_000, max_cents=400_000,
            rate_min=1.0, rate_max=6.5,
        )
    )
    salary_mid: SalaryTierConfig = Field(
        default_factory=lambda: SalaryTierConfig(
            name="mid", share=0.35,
            min_cents=600_000, max_cents=800_000,
            rate_min=3.5, rate_max=8.5,
        )
    )
    salary_senior: SalaryTierConfig = Field(
        default_factory=lambda: SalaryTierConfig(
            name="senior", share=0.15,
            min_cents=1_000_000, max_cents=1_500_000,
            rate_min=5.5, rate_max=10.0,
        )
    )

    @model_validator(mode="after")
    def _salary_shares_sum_to_one(self) -> WorldConfig:
        # Reject configs where the three tier shares drift from exactly 1.0.
        total = self.salary_junior.share + self.salary_mid.share + self.salary_senior.share
        if abs(total - 1.0) > 1e-6:
            raise ValueError(f"salary tier shares must sum to 1.0, got {total:.6f}")
        return self
# ---------------------------------------------------------------------------
# Top-level experiment
# ---------------------------------------------------------------------------
class ExperimentConfig(BaseModel):
    """Top-level experiment: identity plus agent, loop, sim, and world sections."""

    name: str = "default"
    description: str = ""
    agent: AgentConfig = Field(default_factory=AgentConfig)
    loop: LoopConfig = Field(default_factory=LoopConfig)
    sim: SimConfig = Field(default_factory=SimConfig)
    world: WorldConfig = Field(default_factory=WorldConfig)
__all__ = [
"AgentConfig",
"DistSpec",
"ExperimentConfig",
"LoopConfig",
"SalaryTierConfig",
"SimConfig",
"WorldConfig",
"WorldDists",
]

View file

View file

@ -0,0 +1,159 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta, time
from decimal import Decimal
WORKDAY_START = time(9, 0)
WORKDAY_END = time(18, 0)
WORK_HOURS_PER_DAY = Decimal("9")
@dataclass(frozen=True)
class BusinessCalendar:
    """The daily working window; weekdays are determined separately by is_weekday."""

    workday_start: time = WORKDAY_START  # inclusive start of the working day
    workday_end: time = WORKDAY_END  # exclusive end of the working day
DEFAULT_CALENDAR = BusinessCalendar()
def is_weekday(ts):
    """True for Monday through Friday."""
    return ts.isoweekday() <= 5
def is_business_time(ts, cal=DEFAULT_CALENDAR):
    """True when *ts* is a weekday instant within [workday_start, workday_end)."""
    if not is_weekday(ts):
        return False
    # Strip any tzinfo so the comparison is against naive wall-clock times.
    wall_clock = ts.timetz().replace(tzinfo=None) if ts.tzinfo else ts.time()
    within_window = cal.workday_start <= wall_clock < cal.workday_end
    return within_window
def _day_start(ts, cal):
    """The workday_start instant on *ts*'s calendar date."""
    opening = cal.workday_start
    return ts.replace(hour=opening.hour, minute=opening.minute, second=0, microsecond=0)
def _day_end(ts, cal):
    """The workday_end instant on *ts*'s calendar date."""
    closing = cal.workday_end
    return ts.replace(hour=closing.hour, minute=closing.minute, second=0, microsecond=0)
def _next_weekday_start(ts, cal):
    """Workday start on the first weekday at or after *ts*'s date."""
    cursor = _day_start(ts, cal)
    while not is_weekday(cursor):
        cursor = _day_start(cursor + timedelta(days=1), cal)
    return cursor
def next_business_time(ts, cal):
    """Earliest business instant at or after *ts*."""
    if is_business_time(ts, cal):
        return ts
    if not is_weekday(ts):
        return _next_weekday_start(ts, cal)
    # Weekday but outside the working window: before opening or after closing.
    if ts < _day_start(ts, cal):
        return _day_start(ts, cal)
    if ts >= _day_end(ts, cal):
        return _next_weekday_start(ts + timedelta(days=1), cal)
    # Defensive: a weekday instant inside the window was already handled above.
    raise ValueError(f"No valid business time found after {ts}")
def add_business_hours(ts, hours, cal=DEFAULT_CALENDAR):
    """Return the instant reached after working *hours* business hours from *ts*.

    The clock first snaps forward to the next business instant, then consumes
    whole working days until the remainder fits within a single day.
    Raises ValueError when *hours* is negative.
    """
    hours = Decimal(str(hours))
    if hours < 0:
        raise ValueError(f"Cannot add negative business hours: {hours}")
    if hours == 0:
        # Zero work still snaps forward (e.g. a weekend start moves to Monday's opening).
        return next_business_time(ts, cal)
    cur = next_business_time(ts, cal)
    remaining = hours
    while remaining > 0:
        day_end = _day_end(cur, cal)
        # Working hours left in the current day, as an exact Decimal.
        available = Decimal(str((day_end - cur).total_seconds())) / Decimal("3600")
        if remaining <= available:
            # Fits today: single Decimal->float conversion at the very end.
            return cur + timedelta(seconds=float(remaining * Decimal("3600")))
        remaining -= available
        cur = next_business_time(day_end, cal)
    return cur
def _business_interval_same_day(start, end, cal):
    """Business hours in [start, end) where both fall on *start*'s calendar day
    (*end* may also be exactly the following midnight, as business_hours_between
    passes it).

    Returns Decimal hours; 0 on weekends or when the overlap is empty.
    """
    if end <= start:
        return Decimal("0")
    if not is_weekday(start):
        return Decimal("0")
    day_start = _day_start(start, cal)
    # Fix: clamp to *start*'s working day. The original used _day_end(end, cal),
    # which — when end is the next day's midnight — capped at the NEXT day's
    # workday_end and wrongly counted the 18:00-24:00 evening hours
    # (e.g. Mon 10:00 -> Tue 00:00 yielded 14h instead of 8h).
    day_end = _day_end(start, cal)
    lo = max(start, day_start)
    hi = min(end, day_end)
    if hi <= lo:
        return Decimal("0")
    return Decimal(str((hi - lo).total_seconds())) / Decimal("3600")
def business_hours_between(t0, t1, cal=DEFAULT_CALENDAR):
    """Total business hours between t0 and t1 as a Decimal (0 when t1 <= t0).

    Splits [t0, t1) into per-calendar-day segments at midnight and sums the
    business-hour overlap of each segment.
    """
    if t1 <= t0:
        return Decimal("0")
    cur = t0
    total = Decimal("0")
    while cur < t1:
        # End of the current calendar day (the next midnight), capped at t1.
        next_midnight = (cur + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
        seg_end = min(next_midnight, t1)
        total += _business_interval_same_day(cur, seg_end, cal)
        cur = seg_end
    return total
def business_seconds_between(t0, t1, cal=DEFAULT_CALENDAR):
    """Whole business seconds between t0 and t1 (Decimal rounding via to_integral_value)."""
    hours = business_hours_between(t0, t1, cal)
    return int((hours * Decimal("3600")).to_integral_value())
def first_business_of_month(dt, cal=DEFAULT_CALENDAR):
    """Workday-start instant on the first weekday of *dt*'s month."""
    candidate = dt.replace(
        day=1,
        hour=cal.workday_start.hour,
        minute=cal.workday_start.minute,
        second=0,
        microsecond=0,
    )
    while not is_weekday(candidate):
        candidate = (candidate + timedelta(days=1)).replace(
            hour=cal.workday_start.hour,
            minute=cal.workday_start.minute,
            second=0,
            microsecond=0,
        )
    return candidate
def iter_monthly_payroll_boundaries(start, end, cal=DEFAULT_CALENDAR):
    """List payroll instants (first business morning of each month) in (start, end].

    Boundaries strictly after *start* and at or before *end* are included;
    returns an empty list when end <= start.
    """
    if end <= start:
        return []
    # Anchor the scan at the first-of-month of *start*'s month, at workday start.
    cursor = start.replace(day=1, hour=cal.workday_start.hour, minute=cal.workday_start.minute, second=0, microsecond=0)
    out = []
    while cursor < end:
        boundary = first_business_of_month(cursor, cal)
        # Half-open on the left: a boundary exactly equal to *start* is excluded.
        if start < boundary <= end:
            out.append(boundary)
        # Step to the first day of the next month (handle December rollover).
        if cursor.month == 12:
            cursor = cursor.replace(year=cursor.year + 1, month=1, day=1)
        else:
            cursor = cursor.replace(month=cursor.month + 1, day=1)
    return out
__all__ = [
"BusinessCalendar",
"DEFAULT_CALENDAR",
"is_weekday",
"is_business_time",
"next_business_time",
"add_business_hours",
"business_hours_between",
"business_seconds_between",
"first_business_of_month",
"iter_monthly_payroll_boundaries",
]

210
src/yc_bench/core/engine.py Normal file
View file

@ -0,0 +1,210 @@
"""Simulation engine: advance time with deterministic event processing.
Main loop:
1. Find next action: earliest of (next_event, next_payroll_boundary, target_time)
2. Flush progress from current_time to action_time
3. If payroll: deduct salaries, write ledger entries, bankruptcy check
4. If event: dispatch to handler, consume, bankruptcy check
5. Loop until target or terminal condition
Payroll-event tie-breaking: payroll first at same timestamp (start-of-day obligation).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from typing import Dict, List, Optional
from uuid import UUID
from sqlalchemy.orm import Session
from ..db.models.company import Company
from ..db.models.employee import Employee
from ..db.models.event import EventType, SimEvent
from ..db.models.ledger import LedgerCategory, LedgerEntry
from ..db.models.sim_state import SimState
from .business_time import iter_monthly_payroll_boundaries
from .eta import recalculate_etas
from .events import consume_event, fetch_next_event, insert_event
from .handlers.bankruptcy import handle_bankruptcy
from .handlers.horizon_end import handle_horizon_end
from .handlers.task_complete import handle_task_complete
from .handlers.task_half import handle_task_half
from .progress import flush_progress
@dataclass
class AdvanceResult:
    """Outcome summary returned by advance_time()."""
    old_sim_time: str  # ISO-8601 sim time before the advance
    new_sim_time: str  # ISO-8601 sim time actually reached (may stop early)
    events_processed: int = 0  # number of sim events dispatched and consumed
    payrolls_applied: int = 0  # number of monthly payroll boundaries applied
    balance_delta: int = 0  # net funds change in cents over the advance
    bankrupt: bool = False  # advance halted because funds went negative
    horizon_reached: bool = False  # advance halted on a HORIZON_END event
    wake_events: List[Dict] = field(default_factory=list)  # per-event handler result dicts
def apply_payroll(db: Session, company_id: UUID, time: datetime) -> bool:
    """Deduct monthly salaries for all employees. Returns True if bankrupt after payroll."""
    company = db.query(Company).filter(Company.id == company_id).one()
    total_deducted = 0
    for employee in db.query(Employee).filter(Employee.company_id == company_id).all():
        amount = int(employee.salary_cents)
        total_deducted += amount
        # One ledger row per employee so payroll stays auditable per person.
        entry = LedgerEntry(
            company_id=company_id,
            occurred_at=time,
            category=LedgerCategory.MONTHLY_PAYROLL,
            amount_cents=-amount,
            ref_type="employee",
            ref_id=employee.id,
        )
        db.add(entry)
    company.funds_cents -= total_deducted
    db.flush()
    # Bankrupt when the balance goes strictly negative.
    return company.funds_cents < 0
def dispatch_event(db: Session, event: SimEvent, sim_time: datetime, company_id: UUID) -> Dict:
    """Route event to appropriate handler. Returns result dict."""
    etype = event.event_type
    if etype == EventType.TASK_HALF_PROGRESS:
        half = handle_task_half(db, event)
        return {"type": "task_half", "task_id": str(half.task_id), "handled": half.handled}
    if etype == EventType.TASK_COMPLETED:
        done = handle_task_complete(db, event, sim_time)
        # Recalculate ETAs — freed employees change topology
        from ..config import get_world_config
        recalculate_etas(
            db,
            company_id,
            sim_time,
            half_threshold=get_world_config().task_half_threshold,
        )
        return {
            "type": "task_completed",
            "task_id": str(done.task_id),
            "success": done.success,
            "funds_delta": done.funds_delta,
            "bankrupt": done.bankrupt,
        }
    if etype == EventType.HORIZON_END:
        return {"type": "horizon_end", "reached": handle_horizon_end(db, event).reached}
    if etype == EventType.BANKRUPTCY:
        return {"type": "bankruptcy", "bankrupt": handle_bankruptcy(db, event).bankrupt}
    return {"type": "unknown", "event_type": etype.value}
def advance_time(
    db: Session,
    company_id: UUID,
    target_time: datetime,
) -> AdvanceResult:
    """Advance simulation from current sim_time to target_time, processing all events and payroll.

    Loop: pick the earliest due action among (next payroll boundary, next
    pending event, the target time), flush task progress up to that instant,
    then apply the action. Ties at the same timestamp resolve in the order
    payroll -> event -> target. Stops early on bankruptcy or horizon end,
    leaving sim_time at the stopping point rather than at target_time.
    """
    sim_state = db.query(SimState).filter(SimState.company_id == company_id).one()
    current_time = sim_state.sim_time
    old_time = current_time
    company = db.query(Company).filter(Company.id == company_id).one()
    starting_funds = company.funds_cents
    result = AdvanceResult(
        old_sim_time=old_time.isoformat(),
        new_sim_time=target_time.isoformat(),
    )
    # Pre-compute payroll boundaries (strictly after current_time, up to and
    # including target_time — see iter_monthly_payroll_boundaries).
    payroll_times = iter_monthly_payroll_boundaries(current_time, target_time)
    payroll_idx = 0
    while True:
        # Find next payroll
        next_payroll = None
        if payroll_idx < len(payroll_times):
            next_payroll = payroll_times[payroll_idx]
        # Find next event (re-queried every pass: handlers may insert new events)
        next_event = fetch_next_event(db, company_id, target_time)
        # Determine next action time
        candidates = []
        if next_payroll is not None and next_payroll <= target_time:
            candidates.append(("payroll", next_payroll))
        if next_event is not None:
            candidates.append(("event", next_event.scheduled_at))
        candidates.append(("target", target_time))
        # Sort: earliest time first; at same time, payroll before event before target
        action_priority = {"payroll": 0, "event": 1, "target": 2}
        candidates.sort(key=lambda c: (c[1], action_priority[c[0]]))
        action_type, action_time = candidates[0]
        # Flush progress from current_time to action_time
        if action_time > current_time:
            flush_progress(db, company_id, current_time, action_time)
            current_time = action_time
        if action_type == "target":
            # Nothing due before/equal target; stop advancement.
            break
        if action_type == "payroll":
            bankrupt = apply_payroll(db, company_id, current_time)
            result.payrolls_applied += 1
            payroll_idx += 1
            if bankrupt:
                # Insert bankruptcy event at this time
                insert_event(
                    db, company_id,
                    EventType.BANKRUPTCY,
                    current_time,
                    {"reason": "funds_negative_after_payroll"},
                    dedupe_key=f"bankruptcy:{current_time.isoformat()}",
                )
                result.bankrupt = True
                break
        elif action_type == "event":
            event_result = dispatch_event(db, next_event, current_time, company_id)
            consume_event(db, next_event)
            result.events_processed += 1
            result.wake_events.append(event_result)
            # Check terminal conditions
            if next_event.event_type == EventType.HORIZON_END:
                result.horizon_reached = True
                break
            if next_event.event_type == EventType.BANKRUPTCY:
                result.bankrupt = True
                break
            if event_result.get("bankrupt", False):
                result.bankrupt = True
                break
        # Continue loop: more due actions can still exist at the same timestamp.
    # Update sim_time
    sim_state.sim_time = current_time
    db.flush()
    # Compute balance delta
    company = db.query(Company).filter(Company.id == company_id).one()
    result.balance_delta = company.funds_cents - starting_funds
    result.new_sim_time = current_time.isoformat()
    return result
# Public API of the simulation engine module.
__all__ = [
    "AdvanceResult",
    "advance_time",
    "apply_payroll",
    "dispatch_event",
]

263
src/yc_bench/core/eta.py Normal file
View file

@ -0,0 +1,263 @@
"""ETA solver: compute task completion and halfway times, manage projection events.
Projection events (task_completed, task_half_progress) are inserted into sim_events
and recalculated whenever the topology changes (assign, dispatch, cancel, complete).
"""
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from typing import Dict, List, Optional, Set
from uuid import UUID
from sqlalchemy.orm import Session
from ..db.models.company import Domain
from ..db.models.event import EventType, SimEvent
from ..db.models.task import Task, TaskRequirement, TaskStatus
from .business_time import add_business_hours
from .events import insert_event
from .progress import EffectiveRate, compute_effective_rates
def solve_task_completion_time(
    db: Session,
    task_id: UUID,
    now: datetime,
    rates: List[EffectiveRate],
) -> Optional[datetime]:
    """Solve for the business-time at which a task will complete.

    Completion = all domains reach 100%. Time is max(remaining[d] / rate[d])
    across domains. Returns None if any domain has remaining > 0 and rate == 0
    (impossible to complete).
    """
    reqs = db.query(TaskRequirement).filter(
        TaskRequirement.task_id == task_id
    ).all()
    if not reqs:
        return None
    # Effective per-domain rates for this task only.
    task_rates: Dict[Domain, Decimal] = {
        r.domain: r.rate_per_hour for r in rates if r.task_id == task_id
    }
    worst_hours = Decimal("0")
    for req in reqs:
        outstanding = req.required_qty - req.completed_qty
        if outstanding <= 0:
            continue
        per_hour = task_rates.get(req.domain, Decimal("0"))
        if per_hour <= 0:
            # A domain with work left but no throughput can never finish.
            return None
        worst_hours = max(worst_hours, outstanding / per_hour)
    if worst_hours <= 0:
        # Every requirement is already satisfied.
        return now
    return add_business_hours(now, worst_hours)
def solve_task_halfway_time(
    db: Session,
    task_id: UUID,
    now: datetime,
    rates: List[EffectiveRate],
    half_threshold: float = 0.5,
) -> Optional[datetime]:
    """Solve for the business-time at which weighted progress ratio >= 0.5.

    Weighted ratio is:
        sum_d(completed_d) / sum_d(required_d)
    Each domain progresses linearly until capped at required_d, so the total
    is piecewise linear in elapsed business hours; this solves it segment by
    segment over the cap breakpoints.
    Returns None if reaching 50% is impossible.
    """
    reqs = db.query(TaskRequirement).filter(
        TaskRequirement.task_id == task_id
    ).all()
    if not reqs:
        return None
    # Effective rates scoped to this task.
    rate_map: Dict[Domain, Decimal] = {}
    for r in rates:
        if r.task_id == task_id:
            rate_map[r.domain] = r.rate_per_hour
    total_required = sum((req.required_qty for req in reqs), Decimal("0"))
    if total_required <= 0:
        return now
    target_completed = Decimal(str(half_threshold)) * total_required
    # Check current weighted progress (each domain capped at its requirement).
    current_completed = Decimal("0")
    for req in reqs:
        current_completed += min(req.required_qty, req.completed_qty)
    if current_completed >= target_completed:
        return now
    # Build list of domain progression pieces.
    # cap_hours = hours until domain reaches required_qty
    # cap_hours is None when the domain has work left but zero rate (stalled).
    domains = []
    for req in reqs:
        remaining = req.required_qty - req.completed_qty
        rate = rate_map.get(req.domain, Decimal("0"))
        if remaining > 0 and rate <= 0:
            # This domain can never progress further
            cap_hours = None
        elif remaining <= 0:
            cap_hours = Decimal("0")
        else:
            cap_hours = remaining / rate
        domains.append({
            "completed": req.completed_qty,
            "required": req.required_qty,
            "rate": rate,
            "cap_hours": cap_hours,
        })
    # Piecewise-linear solver over cap breakpoints.
    # In each segment, d(current_completed)/dh = sum(rate_d) for uncapped domains.
    breakpoints = sorted(set(
        d["cap_hours"] for d in domains
        if d["cap_hours"] is not None and d["cap_hours"] > 0
    ))
    h = Decimal("0")
    completed_sum = current_completed
    for bp in breakpoints:
        # Slope of completed_sum during [h, bp]
        slope = Decimal("0")
        for d in domains:
            if d["cap_hours"] is not None and d["cap_hours"] > h:
                slope += d["rate"]
        if slope <= 0:
            # No further progress in this segment.
            if completed_sum >= target_completed:
                return add_business_hours(now, h)
            h = bp
            # Re-evaluate completed amount at breakpoint.
            completed_sum = Decimal("0")
            for d in domains:
                progress = min(d["required"], d["completed"] + d["rate"] * h)
                completed_sum += progress
            continue
        needed = target_completed - completed_sum
        delta_h = needed / slope
        # Target is reached inside this segment.
        if h + delta_h <= bp:
            return add_business_hours(now, h + delta_h)
        # Otherwise advance to the breakpoint and accumulate segment progress.
        completed_sum += slope * (bp - h)
        h = bp
    # After all breakpoints, check remaining slope
    slope = Decimal("0")
    for d in domains:
        if d["cap_hours"] is not None and d["cap_hours"] > h:
            slope += d["rate"]
    if slope > 0:
        needed = target_completed - completed_sum
        if needed <= 0:
            return add_business_hours(now, h)
        delta_h = needed / slope
        return add_business_hours(now, h + delta_h)
    if completed_sum >= target_completed:
        return add_business_hours(now, h)
    return None  # Cannot reach 50%
def recalculate_etas(
    db: Session,
    company_id: UUID,
    now: datetime,
    impacted_task_ids: Optional[Set[UUID]] = None,
    half_threshold: float = 0.5,
) -> None:
    """Recalculate projection events for active tasks.

    1. Delete stale projection events for impacted tasks (or all if None).
    2. Compute effective rates.
    3. For each active task, solve completion and halfway times.
    4. Insert new projection events.

    :param impacted_task_ids: restrict recalculation to these tasks;
        None means every ACTIVE task of the company.
    :param half_threshold: weighted-progress ratio that triggers the
        halfway event.
    """
    # Determine which tasks to recalculate
    if impacted_task_ids is None:
        active_tasks = db.query(Task).filter(
            Task.company_id == company_id,
            Task.status == TaskStatus.ACTIVE,
        ).all()
        task_ids = {t.id for t in active_tasks}
    else:
        task_ids = impacted_task_ids
    if not task_ids:
        return
    # Delete stale unconsumed projection events for these tasks.
    # Matching by dedupe_key prefix catches both :completed and :half events.
    for tid in task_ids:
        stale = db.query(SimEvent).filter(
            SimEvent.company_id == company_id,
            SimEvent.consumed == False,
            SimEvent.event_type.in_([EventType.TASK_COMPLETED, EventType.TASK_HALF_PROGRESS]),
            SimEvent.dedupe_key.like(f"task:{tid}:%"),
        ).all()
        for ev in stale:
            db.delete(ev)
    db.flush()
    # Compute rates for all active tasks (topology-wide, since employee sharing matters)
    rates = compute_effective_rates(db, company_id)
    for tid in task_ids:
        task = db.query(Task).filter(Task.id == tid).one_or_none()
        if task is None or task.status != TaskStatus.ACTIVE:
            continue
        # Completion ETA
        completion_time = solve_task_completion_time(db, tid, now, rates)
        if completion_time is not None:
            insert_event(
                db,
                company_id=company_id,
                event_type=EventType.TASK_COMPLETED,
                scheduled_at=completion_time,
                payload={"task_id": str(tid)},
                dedupe_key=f"task:{tid}:completed",
            )
        # Halfway ETA (only if not already emitted)
        if not task.halfway_event_emitted:
            halfway_time = solve_task_halfway_time(db, tid, now, rates, half_threshold=half_threshold)
            if halfway_time is not None:
                insert_event(
                    db,
                    company_id=company_id,
                    event_type=EventType.TASK_HALF_PROGRESS,
                    scheduled_at=halfway_time,
                    payload={"task_id": str(tid)},
                    dedupe_key=f"task:{tid}:half",
                )
    db.flush()
# Public API of the ETA solver module.
__all__ = [
    "solve_task_completion_time",
    "solve_task_halfway_time",
    "recalculate_etas",
]

128
src/yc_bench/core/events.py Normal file
View file

@ -0,0 +1,128 @@
"""Event infrastructure: fetch, consume, and insert simulation events.
Events are processed in deterministic order: (scheduled_at, priority, id).
Priority by event_type: task_completed=0, bankruptcy=1, task_half=2, horizon_end=3.
"""
from __future__ import annotations
from datetime import datetime
import json
from typing import Any, Dict, Optional
from uuid import UUID
from uuid import uuid5, NAMESPACE_URL
from sqlalchemy import case
from sqlalchemy.orm import Session
from ..db.models.event import EventType, SimEvent
# Priority ordering — lower number = higher priority.
# Events sharing the same scheduled_at are tie-broken by this map (then by
# event id in fetch_next_event), which keeps replays deterministic.
EVENT_PRIORITY: Dict[EventType, int] = {
    EventType.TASK_COMPLETED: 0,
    EventType.BANKRUPTCY: 1,
    EventType.TASK_HALF_PROGRESS: 2,
    EventType.HORIZON_END: 3,
}
def _deterministic_event_id(
company_id: UUID,
event_type: EventType,
scheduled_at: datetime,
dedupe_key: Optional[str],
payload: Dict[str, Any],
) -> UUID:
"""Generate deterministic event UUID to stabilize same-seed replay ordering."""
payload_key = json.dumps(payload, sort_keys=True, separators=(",", ":"))
base = "|".join(
[
str(company_id),
event_type.value,
scheduled_at.isoformat(),
dedupe_key or "",
payload_key,
]
)
return uuid5(NAMESPACE_URL, base)
def fetch_next_event(
    db: Session,
    company_id: UUID,
    up_to: datetime,
) -> Optional[SimEvent]:
    """Fetch the next unconsumed event scheduled at or before up_to.

    Deterministic order: (scheduled_at ASC, priority ASC, id ASC).
    """
    # Translate event_type -> numeric priority inside SQL; unknown types sort last.
    priority = case(
        dict(EVENT_PRIORITY),
        value=SimEvent.event_type,
        else_=99,
    )
    pending = db.query(SimEvent).filter(
        SimEvent.company_id == company_id,
        SimEvent.consumed == False,
        SimEvent.scheduled_at <= up_to,
    )
    ordered = pending.order_by(
        SimEvent.scheduled_at.asc(),
        priority.asc(),
        SimEvent.id.asc(),
    )
    return ordered.first()
def consume_event(db: Session, event: SimEvent) -> None:
    """Mark an event as consumed so fetch_next_event never returns it again."""
    event.consumed = True
    db.flush()
def insert_event(
    db: Session,
    company_id: UUID,
    event_type: EventType,
    scheduled_at: datetime,
    payload: Dict[str, Any],
    dedupe_key: Optional[str] = None,
) -> SimEvent:
    """Insert a new event, with optional idempotent deduplication.

    If dedupe_key is provided and an unconsumed event with the same key exists,
    the existing event is returned unchanged.
    """
    if dedupe_key is not None:
        duplicate = (
            db.query(SimEvent)
            .filter(
                SimEvent.company_id == company_id,
                SimEvent.dedupe_key == dedupe_key,
                SimEvent.consumed == False,
            )
            .first()
        )
        if duplicate is not None:
            # Idempotent: a pending event already covers this request.
            return duplicate
    new_event = SimEvent(
        id=_deterministic_event_id(company_id, event_type, scheduled_at, dedupe_key, payload),
        company_id=company_id,
        event_type=event_type,
        scheduled_at=scheduled_at,
        payload=payload,
        dedupe_key=dedupe_key,
    )
    db.add(new_event)
    db.flush()
    return new_event
# Public API of the event infrastructure module.
__all__ = [
    "EVENT_PRIORITY",
    "fetch_next_event",
    "consume_event",
    "insert_event",
]

View file

View file

@ -0,0 +1,18 @@
"""Handler for bankruptcy events."""
from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy.orm import Session
from ...db.models.event import SimEvent
@dataclass
class BankruptcyResult:
    """Result of handling a bankruptcy event."""
    bankrupt: bool = True  # always True: emitting this result is the signal itself
def handle_bankruptcy(db: Session, event: SimEvent) -> BankruptcyResult:
    """Signal that the company has gone bankrupt.

    No state is mutated here; the engine's dispatch result carries the flag
    that terminates the advance loop.
    """
    return BankruptcyResult(bankrupt=True)

View file

@ -0,0 +1,18 @@
"""Handler for horizon_end events."""
from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy.orm import Session
from ...db.models.event import SimEvent
@dataclass
class HorizonEndResult:
    """Result of handling a horizon_end event."""
    reached: bool = True  # always True: the event firing means the horizon was hit
def handle_horizon_end(db: Session, event: SimEvent) -> HorizonEndResult:
    """Signal that the simulation horizon has been reached.

    No state is mutated here; the engine stops advancing when it sees this.
    """
    return HorizonEndResult(reached=True)

View file

@ -0,0 +1,141 @@
"""Handler for task_completed events.
On completion:
- If completion_time <= deadline: success add reward funds, add prestige, skill-boost employees.
- If completion_time > deadline: fail set completed_fail, apply 0.8 * delta prestige penalty.
After either outcome, recalculate ETAs (freed employees change topology).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from decimal import Decimal
from typing import Dict
from uuid import UUID
from sqlalchemy.orm import Session
from ...db.models.company import Company, CompanyPrestige, Domain
from ...db.models.employee import Employee, EmployeeSkillRate
from ...config import get_world_config
from ...db.models.event import SimEvent
from ...db.models.ledger import LedgerCategory, LedgerEntry
from ...db.models.task import Task, TaskAssignment, TaskRequirement, TaskStatus
@dataclass
class TaskCompleteResult:
    """Outcome of processing one task_completed event."""
    task_id: UUID
    success: bool  # True when completion happened at or before the deadline
    funds_delta: int = 0  # reward cents credited on success (0 on failure)
    prestige_changes: Dict[str, float] = field(default_factory=dict)  # domain value -> prestige delta actually applied
    bankrupt: bool = False  # company funds negative after applying the outcome
def handle_task_complete(db: Session, event: SimEvent, sim_time) -> TaskCompleteResult:
    """Process task completion: finalize progress, determine success/fail, apply rewards/penalties.

    Success (sim_time <= deadline): credit reward funds + ledger entry, raise
    prestige per requirement domain (capped at prestige_max), boost assigned
    employees' skill rates, and bump their salaries.
    Failure: mark COMPLETED_FAIL and apply a prestige penalty of
    penalty_fail_multiplier * reward_prestige_delta (floored at prestige_min).
    """
    task_id = UUID(event.payload["task_id"])
    task = db.query(Task).filter(Task.id == task_id).one()
    company_id = task.company_id
    # Finalize all domain progress to 100%
    reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == task_id).all()
    for req in reqs:
        req.completed_qty = req.required_qty
    db.flush()
    task.completed_at = sim_time
    success = sim_time <= task.deadline
    wc = get_world_config()
    prestige_changes: Dict[str, float] = {}
    funds_delta = 0
    if success:
        task.status = TaskStatus.COMPLETED_SUCCESS
        task.success = True
        # Add reward funds
        company = db.query(Company).filter(Company.id == company_id).one()
        company.funds_cents += task.reward_funds_cents
        funds_delta = task.reward_funds_cents
        # Ledger entry
        db.add(LedgerEntry(
            company_id=company_id,
            occurred_at=sim_time,
            category=LedgerCategory.TASK_REWARD,
            amount_cents=task.reward_funds_cents,
            ref_type="task",
            ref_id=task_id,
        ))
        # Add prestige to each domain (capped at the configured maximum)
        for req in reqs:
            prestige = db.query(CompanyPrestige).filter(
                CompanyPrestige.company_id == company_id,
                CompanyPrestige.domain == req.domain,
            ).one_or_none()
            if prestige is not None:
                old = float(prestige.prestige_level)
                prestige.prestige_level = min(
                    Decimal(str(wc.prestige_max)),
                    prestige.prestige_level + task.reward_prestige_delta,
                )
                prestige_changes[req.domain.value] = float(prestige.prestige_level) - old
        # Skill boost assigned employees
        # NOTE(review): each assigned employee is boosted in every task domain
        # they have a skill row for, regardless of individual contribution.
        assignments = db.query(TaskAssignment).filter(
            TaskAssignment.task_id == task_id
        ).all()
        if task.skill_boost_pct > 0:
            task_domains = {req.domain for req in reqs}
            for a in assignments:
                for domain in task_domains:
                    skill = db.query(EmployeeSkillRate).filter(
                        EmployeeSkillRate.employee_id == a.employee_id,
                        EmployeeSkillRate.domain == domain,
                    ).one_or_none()
                    if skill is not None:
                        boost = skill.rate_domain_per_hour * task.skill_boost_pct
                        skill.rate_domain_per_hour += boost
        # Salary bump: small raise for each employee who contributed to this task
        if wc.salary_bump_pct > 0:
            for a in assignments:
                employee = db.query(Employee).filter(Employee.id == a.employee_id).one_or_none()
                if employee is not None:
                    bump = int(employee.salary_cents * wc.salary_bump_pct)
                    employee.salary_cents += bump
    else:
        task.status = TaskStatus.COMPLETED_FAIL
        task.success = False
        # Apply penalty_fail_multiplier * reward_prestige_delta penalty
        penalty = Decimal(str(wc.penalty_fail_multiplier)) * task.reward_prestige_delta
        for req in reqs:
            prestige = db.query(CompanyPrestige).filter(
                CompanyPrestige.company_id == company_id,
                CompanyPrestige.domain == req.domain,
            ).one_or_none()
            if prestige is not None:
                old = float(prestige.prestige_level)
                prestige.prestige_level = max(
                    Decimal(str(wc.prestige_min)),
                    prestige.prestige_level - penalty,
                )
                prestige_changes[req.domain.value] = float(prestige.prestige_level) - old
    db.flush()
    # Check bankruptcy (reward may not cover an already-negative balance)
    company = db.query(Company).filter(Company.id == company_id).one()
    bankrupt = company.funds_cents < 0
    return TaskCompleteResult(
        task_id=task_id,
        success=success,
        funds_delta=funds_delta,
        prestige_changes=prestige_changes,
        bankrupt=bankrupt,
    )

View file

@ -0,0 +1,30 @@
"""Handler for task_half_progress events."""
from __future__ import annotations
from dataclasses import dataclass
from uuid import UUID
from sqlalchemy.orm import Session
from ...db.models.event import SimEvent
from ...db.models.task import Task
@dataclass
class TaskHalfResult:
    """Result of handling a task_half_progress event."""
    task_id: UUID  # task the halfway event referred to
    handled: bool  # False when the task row no longer exists
def handle_task_half(db: Session, event: SimEvent) -> TaskHalfResult:
    """Mark the task's halfway_event_emitted flag as True.

    Returns handled=False when the task row is missing (e.g. removed before
    the projection event fired); nothing is written in that case.
    """
    task_id = UUID(event.payload["task_id"])
    task = db.query(Task).filter(Task.id == task_id).one_or_none()
    if task is None:
        return TaskHalfResult(task_id=task_id, handled=False)
    task.halfway_event_emitted = True
    db.flush()
    return TaskHalfResult(task_id=task_id, handled=True)

View file

@ -0,0 +1,308 @@
from __future__ import annotations
from dataclasses import dataclass
from decimal import Decimal
from uuid import UUID
from .business_time import business_hours_between
from ..db.models.company import Domain
from ..db.models.employee import EmployeeSkillRate
from ..db.models.task import TaskRequirement, Task, TaskAssignment, TaskStatus
@dataclass(frozen=True)
class RequirementState:
    """Immutable snapshot of one domain requirement on a task."""
    domain: str
    required_qty: Decimal
    completed_qty: Decimal
@dataclass(frozen=True)
class TaskProgressState:
    """Immutable snapshot of a task and all of its requirements."""
    # NOTE(review): annotated str, but flush_progress passes ORM ids (UUIDs);
    # annotations are not enforced here — confirm intended type.
    task_id: str
    status: str
    requirements: tuple[RequirementState, ...]
@dataclass(frozen=True)
class AssignmentState:
    """Immutable (task, employee) assignment pair."""
    task_id: str
    employee_id: str
@dataclass(frozen=True)
class EmployeeRateState:
    """Immutable base skill rate of one employee in one domain."""
    employee_id: str
    domain: str
    rate_domain_per_hour: Decimal
@dataclass(frozen=True)
class ProgressDelta:
    """Change applied to one (task, domain) during a progress window."""
    task_id: str
    domain: str
    delta_qty: Decimal  # after_qty - before_qty (zero when capped/stalled)
    before_qty: Decimal
    after_qty: Decimal
@dataclass(frozen=True)
class TaskProgressSummary:
    """Per-task before/after completion ratios for a progress window."""
    task_id: str
    ratio_before: Decimal
    ratio_after: Decimal
    completed: bool  # every requirement reached its required quantity
@dataclass(frozen=True)
class EffectiveRate:
    """Effective per-hour throughput one task receives in one domain."""
    task_id: UUID
    domain: Domain
    rate_per_hour: Decimal
def _active_assignment_count(assignments):
counts = {}
for a in assignments:
counts[a.employee_id] = counts.get(a.employee_id, 0) + 1
return counts
def _rates_by_employee_domain(rates):
m = {}
for r in rates:
m[(r.employee_id, r.domain)] = r.rate_domain_per_hour
return m
def _effective_rate_for_task_domain(*, task_id, domain, assignments, assignment_counts, base_rates):
total = Decimal("0")
for a in assignments:
if a.task_id != task_id:
continue
k = assignment_counts.get(a.employee_id, 0)
if k <= 0:
continue
base = base_rates.get((a.employee_id, domain), Decimal("0"))
total += base / Decimal(k)
return total
def _weighted_ratio_from_rows(rows, *, task_id_label):
total_completed = Decimal("0")
total_required = Decimal("0")
for req in rows:
if req.required_qty <= 0:
raise ValueError(
f"Task {task_id_label} requirement {req.domain} has quantity {req.required_qty}"
)
completed = req.completed_qty
if completed < 0:
raise ValueError(
f"Task {task_id_label} requirement {req.domain} has completed quantity {req.completed_qty} which is less than 0"
)
if completed > req.required_qty:
completed = req.required_qty
total_completed += completed
total_required += req.required_qty
if total_required == 0:
return Decimal("0")
return total_completed / total_required
def task_progress_ratio(task):
    """Weighted progress ratio of one TaskProgressState (raises on empty requirements)."""
    if not task.requirements:
        raise ValueError(f"Task {task.task_id} has no requirements")
    return _weighted_ratio_from_rows(task.requirements, task_id_label=task.task_id)
def apply_progress_window(*, tasks, assignments, rates, t0, t1):
    """Advance every task's requirements by the business hours in [t0, t1].

    Pure function over immutable snapshots: returns (updated_tasks, deltas,
    summaries) and performs no I/O. Only tasks with status "planned" or
    "active" progress; all others pass through unchanged. Per-domain progress
    is effective_rate * hours, capped at required_qty and floored at 0.
    """
    hours = Decimal(str(business_hours_between(t0, t1)))
    if hours <= 0:
        # Zero-width window: nothing changes, no deltas.
        unchanged = list(tasks)
        summaries = [
            TaskProgressSummary(
                task_id=t.task_id,
                ratio_before=task_progress_ratio(t),
                ratio_after=task_progress_ratio(t),
                completed=all(r.completed_qty >= r.required_qty for r in t.requirements),
            )
            for t in unchanged
        ]
        return unchanged, [], summaries
    assignment_list = list(assignments)
    assignment_counts = _active_assignment_count(assignment_list)
    base_rates = _rates_by_employee_domain(rates)
    updated_tasks = []
    deltas = []
    summaries = []
    for task in tasks:
        ratio_before = task_progress_ratio(task)
        if task.status not in {"planned", "active"}:
            # Inactive tasks are passed through untouched.
            updated_tasks.append(task)
            summaries.append(
                TaskProgressSummary(
                    task_id=task.task_id,
                    ratio_before=ratio_before,
                    ratio_after=ratio_before,
                    completed=all(r.completed_qty >= r.required_qty for r in task.requirements),
                )
            )
            continue
        next_requirements = []
        for req in task.requirements:
            before = Decimal(req.completed_qty)
            required = Decimal(req.required_qty)
            eff_rate = _effective_rate_for_task_domain(
                task_id=task.task_id,
                domain=req.domain,
                assignments=assignment_list,
                assignment_counts=assignment_counts,
                base_rates=base_rates,
            )
            delta = eff_rate * hours
            after = before + delta
            # Progress cap is independent from deadline/failure logic.
            if after > required:
                after = required
            # Guardrail in case of inconsistent negative upstream data.
            if after < 0:
                after = Decimal("0")
            next_requirements.append(RequirementState(
                domain=req.domain,
                required_qty=required,
                completed_qty=after,
            ))
            deltas.append(ProgressDelta(
                task_id=task.task_id,
                domain=req.domain,
                delta_qty=after - before,
                before_qty=before,
                after_qty=after,
            ))
        next_task = TaskProgressState(
            task_id=task.task_id,
            status=task.status,
            requirements=tuple(next_requirements),
        )
        ratio_after = task_progress_ratio(next_task)
        completed = all(r.completed_qty >= r.required_qty for r in next_requirements)
        updated_tasks.append(next_task)
        summaries.append(
            TaskProgressSummary(
                task_id=task.task_id,
                ratio_before=ratio_before,
                ratio_after=ratio_after,
                completed=completed,
            )
        )
    return updated_tasks, deltas, summaries
def compute_task_progress_ratio(db, task_id):
    """Weighted completion ratio for one task read from the DB.

    Returns Decimal("0") when the task has no requirement rows.
    """
    reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == task_id).all()
    if not reqs:
        return Decimal("0")
    return _weighted_ratio_from_rows(reqs, task_id_label=task_id)
def compute_effective_rates(db, company_id):
    """Effective per-domain rates for every ACTIVE task of a company.

    An employee assigned to k tasks contributes base_rate / k to each; tasks
    with no assignments get an explicit zero rate for every requirement so
    callers can distinguish "stalled" from "unknown".
    """
    active_tasks = db.query(Task).filter(
        Task.company_id == company_id, Task.status == TaskStatus.ACTIVE
    ).all()
    if not active_tasks:
        return []
    task_ids = [t.id for t in active_tasks]
    requirements = db.query(TaskRequirement).filter(TaskRequirement.task_id.in_(task_ids)).all()
    assignments = db.query(TaskAssignment).filter(TaskAssignment.task_id.in_(task_ids)).all()
    if not assignments:
        # Nobody is working: every requirement progresses at rate 0.
        return [
            EffectiveRate(task_id=req.task_id, domain=req.domain, rate_per_hour=Decimal("0"))
            for req in requirements
        ]
    assignment_counts = {}
    assignments_by_task = {}
    for a in assignments:
        assignments_by_task.setdefault(a.task_id, []).append(a)
        assignment_counts[a.employee_id] = assignment_counts.get(a.employee_id, 0) + 1
    employee_ids = list(assignment_counts.keys())
    skill_rows = db.query(EmployeeSkillRate).filter(
        EmployeeSkillRate.employee_id.in_(employee_ids)
    ).all()
    base_rates = {
        (s.employee_id, s.domain): Decimal(s.rate_domain_per_hour) for s in skill_rows
    }
    out = []
    for req in requirements:
        # Reuse the shared helper (previously duplicated inline) so DB-backed
        # and in-memory rate computations cannot drift apart.
        rate = _effective_rate_for_task_domain(
            task_id=req.task_id,
            domain=req.domain,
            assignments=assignments_by_task.get(req.task_id, []),
            assignment_counts=assignment_counts,
            base_rates=base_rates,
        )
        out.append(EffectiveRate(
            task_id=req.task_id,
            domain=req.domain,
            rate_per_hour=rate,
        ))
    return out
def flush_progress(db, company_id, t0, t1):
    """Apply business-time progress over [t0, t1] to ACTIVE tasks and persist it.

    Loads task/requirement/assignment/rate rows, runs the pure
    apply_progress_window solver on immutable snapshots, then writes the new
    completed quantities back onto the requirement rows.
    """
    active_tasks = db.query(Task).filter(
        Task.company_id == company_id, Task.status == TaskStatus.ACTIVE
    ).all()
    if not active_tasks:
        return
    task_ids = [t.id for t in active_tasks]
    req_rows = db.query(TaskRequirement).filter(TaskRequirement.task_id.in_(task_ids)).all()
    asg_rows = db.query(TaskAssignment).filter(TaskAssignment.task_id.in_(task_ids)).all()
    emp_ids = list({a.employee_id for a in asg_rows})
    rate_rows = db.query(EmployeeSkillRate).filter(
        EmployeeSkillRate.employee_id.in_(emp_ids)
    ).all()
    # Snapshot requirements and index the ORM rows for the write-back pass.
    reqs_by_task = {}
    req_index = {}
    for r in req_rows:
        req_index[(r.task_id, r.domain)] = r
        reqs_by_task.setdefault(r.task_id, []).append(
            RequirementState(
                domain=r.domain,
                required_qty=Decimal(r.required_qty),
                completed_qty=Decimal(r.completed_qty),
            )
        )
    task_states = [
        TaskProgressState(
            task_id=t.id,
            status=t.status.value if hasattr(t.status, "value") else str(t.status),
            requirements=tuple(reqs_by_task.get(t.id, [])),
        )
        for t in active_tasks
    ]
    assignment_states = [
        AssignmentState(task_id=a.task_id, employee_id=a.employee_id)
        for a in asg_rows
    ]
    rate_states = [
        EmployeeRateState(
            employee_id=s.employee_id,
            domain=s.domain,
            rate_domain_per_hour=Decimal(s.rate_domain_per_hour),
        )
        for s in rate_rows
    ]
    updated_tasks, _, _ = apply_progress_window(
        tasks=task_states,
        assignments=assignment_states,
        rates=rate_states,
        t0=t0,
        t1=t1,
    )
    # Write back. Use .get so a (task, domain) pair missing from the index is
    # skipped — the original indexed directly and then dead-checked for None,
    # which would have raised KeyError instead of skipping as intended.
    for updated in updated_tasks:
        for req in updated.requirements:
            row = req_index.get((updated.task_id, req.domain))
            if row is not None:
                row.completed_qty = req.completed_qty
    db.flush()

View file

@ -0,0 +1,49 @@
from __future__ import annotations
from dataclasses import dataclass
import argparse
from pathlib import Path
@dataclass(frozen=True)
class RunArgs:
    """Validated CLI arguments for `yc-bench run`."""
    model: str  # model identifier handed to the agent runtime
    seed: int  # deterministic world seed
    horizon_years: int | None  # None = defer to sim.horizon_years in config
    company_name: str
    start_date: str  # YYYY-MM-DD (legacy MM/DD/YYYY also accepted downstream)
    config_name: str = "default"  # preset name or path to a .toml file
def build_parser():
    """Argument parser for the `yc-bench run` subcommand."""
    parser = argparse.ArgumentParser(prog="yc-bench run")
    parser.add_argument("--model", required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument(
        "--horizon-years",
        type=int,
        default=None,
        help="Simulation horizon in years (default: read from config's sim.horizon_years)",
    )
    parser.add_argument("--company-name", default="BenchCo")
    parser.add_argument(
        "--start-date",
        default="2025-01-01",
        help="Simulation start date (YYYY-MM-DD)",
    )
    parser.add_argument(
        "--config",
        dest="config_name",
        default="default",
        help="Preset name ('default', 'fast_test', 'high_reward') or path to a .toml file",
    )
    return parser
def parse_run_args(argv):
    """Parse and validate argv into an immutable RunArgs."""
    parser = build_parser()
    namespace = parser.parse_args(argv)
    _validate(namespace, parser)
    return RunArgs(
        model=namespace.model,
        seed=namespace.seed,
        horizon_years=namespace.horizon_years,
        company_name=namespace.company_name,
        start_date=namespace.start_date,
        config_name=namespace.config_name,
    )
def _validate(ns, parser):
if ns.horizon_years is not None and ns.horizon_years <= 0:
parser.error("--horizon-years must be int > 0")
# Explicit public surface of the args module.
__all__ = ["RunArgs", "build_parser", "parse_run_args"]

226
src/yc_bench/runner/main.py Normal file
View file

@ -0,0 +1,226 @@
"""Benchmark entrypoint: provisions DB, seeds world, runs agent loop to completion."""
from __future__ import annotations
import json
import logging
import os
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from ..agent.commands.executor import run_command
from ..agent.loop import run_agent_loop
from ..agent.run_state import RunState
from ..agent.runtime.factory import build_runtime
from ..agent.runtime.schemas import RuntimeSettings
from ..db.session import build_engine, build_session_factory, session_scope, init_db
from .args import parse_run_args
logger = logging.getLogger(__name__)
def _parse_date(date_str: str) -> datetime:
"""Accept ISO (2025-01-01) or legacy MM/DD/YYYY format."""
fmt = "%Y-%m-%d" if "-" in date_str else "%m/%d/%Y"
dt = datetime.strptime(date_str, fmt)
return dt.replace(hour=9, minute=0, second=0, tzinfo=timezone.utc)
def _wipe_simulation(db) -> None:
    """Delete all simulation rows so the DB can be reseeded cleanly."""
    # Imports are local to this function (presumably to avoid module-level
    # import cycles with the db.models package — confirm).
    from ..db.models.ledger import LedgerEntry
    from ..db.models.task import Task, TaskAssignment, TaskRequirement
    from ..db.models.employee import Employee, EmployeeSkillRate
    from ..db.models.company import Company, CompanyPrestige
    from ..db.models.sim_state import SimState
    from ..db.models.event import SimEvent
    from ..db.models.scratchpad import Scratchpad
    # Deletion proceeds children-before-parents (TaskAssignment/TaskRequirement
    # before Task, EmployeeSkillRate before Employee, CompanyPrestige before
    # Company) — presumably to satisfy FK constraints; keep this order.
    # synchronize_session=False skips in-session bookkeeping for the bulk deletes.
    db.query(Scratchpad).delete(synchronize_session=False)
    db.query(LedgerEntry).delete(synchronize_session=False)
    db.query(TaskAssignment).delete(synchronize_session=False)
    db.query(TaskRequirement).delete(synchronize_session=False)
    db.query(Task).delete(synchronize_session=False)
    db.query(SimEvent).delete(synchronize_session=False)
    db.query(EmployeeSkillRate).delete(synchronize_session=False)
    db.query(Employee).delete(synchronize_session=False)
    db.query(CompanyPrestige).delete(synchronize_session=False)
    db.query(Company).delete(synchronize_session=False)
    db.query(SimState).delete(synchronize_session=False)
    # Push the deletes now so a reseed in the same transaction sees a clean slate.
    db.flush()
def _init_simulation(db_factory, args, experiment_cfg, horizon_years):
    """Seed world directly (no CLI round-trip) using the experiment WorldConfig.

    Behavior:
      * If a SimState row exists and the run is still live, resume it and
        return the existing company_id.
      * If the existing run is terminal (company funds < 0, or sim_time has
        reached horizon_end), wipe all simulation rows and reseed.
      * Otherwise seed a fresh world, schedule the HORIZON_END event, and
        create the initial SimState row.

    Returns the company_id of the resumed or newly seeded simulation.
    """
    # Local imports — presumably to avoid import cycles at module load; confirm.
    from ..db.models.event import EventType
    from ..db.models.sim_state import SimState
    from ..db.models.company import Company
    from ..core.events import insert_event
    from ..services.seed_world import SeedWorldRequest, seed_world_transactional
    with db_factory() as db:
        existing = db.query(SimState).first()
        if existing is not None:
            company = db.query(Company).filter(Company.id == existing.company_id).first()
            # Terminal conditions: negative funds (bankrupt) or horizon reached.
            bankrupt = company is not None and company.funds_cents < 0
            horizon_reached = existing.sim_time >= existing.horizon_end
            if bankrupt or horizon_reached:
                logger.info(
                    "Existing simulation is terminal (bankrupt=%s horizon_reached=%s) — reseeding.",
                    bankrupt, horizon_reached,
                )
                _wipe_simulation(db)
            else:
                logger.info("Resuming non-terminal simulation (company_id=%s, sim_time=%s).",
                            existing.company_id, existing.sim_time)
                return existing.company_id
        start_dt = _parse_date(args.start_date)
        # Horizon end is a pure year bump of the start date (same month/day/time).
        horizon_end = start_dt.replace(year=start_dt.year + horizon_years)
        world = experiment_cfg.world
        req = SeedWorldRequest(
            run_seed=args.seed,
            company_name=args.company_name,
            horizon_years=horizon_years,
            employee_count=world.num_employees,
            market_task_count=world.num_market_tasks,
            start_date=start_dt,
            cfg=world,
        )
        logger.info(
            "Initializing simulation: seed=%d employees=%d tasks=%d horizon=%dy",
            args.seed, world.num_employees, world.num_market_tasks, horizon_years,
        )
        result = seed_world_transactional(db, req)
        # Schedule the run-ending event; dedupe_key guards against duplicates.
        insert_event(
            db=db,
            company_id=result.company_id,
            event_type=EventType.HORIZON_END,
            scheduled_at=horizon_end,
            payload={"reason": "horizon_end"},
            dedupe_key="horizon_end",
        )
        db.add(SimState(
            company_id=result.company_id,
            sim_time=start_dt,
            run_seed=args.seed,
            horizon_end=horizon_end,
            replenish_counter=0,
        ))
        db.flush()
        logger.info("Simulation initialized: company_id=%s", result.company_id)
        return result.company_id
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def run_benchmark(args):
    """Run a full benchmark: migrate, seed, loop until terminal.

    Args:
        args: a RunArgs instance (see runner/args.py).

    Returns:
        0 on a clean terminal state, 1 when the agent loop ended with
        terminal_reason == "error".

    Side effects: configures logging, mutates os.environ
    (YC_BENCH_EXPERIMENT, possibly DATABASE_URL), creates db/ and results/
    directories, and writes the rollout JSON file.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(name)s %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # Load experiment config — preset name or path to a .toml file
    from yc_bench.config import load_config
    experiment_cfg = load_config(args.config_name)
    # Propagate config name to CLI subprocess calls (e.g. task accept → generate_replacement_task)
    os.environ["YC_BENCH_EXPERIMENT"] = args.config_name
    # CLI --model always overrides the experiment's agent model
    experiment_cfg = experiment_cfg.model_copy(
        update={"agent": experiment_cfg.agent.model_copy(update={"model": args.model})}
    )
    # --horizon-years CLI flag overrides config; fall back to sim.horizon_years from config
    horizon_years = args.horizon_years if args.horizon_years is not None else experiment_cfg.sim.horizon_years
    logger.info(
        "YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy",
        experiment_cfg.name, args.model, args.seed, horizon_years,
    )
    # 1. Build engine and create all tables
    # If DATABASE_URL is not explicitly set, default to db/<seed>_<slug>.db
    if not os.environ.get("DATABASE_URL"):
        # Slashes in model names (e.g. provider/model) are not path-safe.
        slug = args.model.replace("/", "_")
        db_dir = Path("db")
        db_dir.mkdir(exist_ok=True)
        os.environ["DATABASE_URL"] = f"sqlite:///{db_dir}/{args.seed}_{slug}.db"
    engine = build_engine()
    init_db(engine)
    factory = build_session_factory(engine)
    # Each db_factory() use opens a fresh transactional session scope.
    @contextmanager
    def db_factory():
        with session_scope(factory) as session:
            yield session
    # 3. Init simulation using experiment world config
    company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
    # 4. Build runtime settings from experiment AgentConfig
    agent_cfg = experiment_cfg.agent
    settings = RuntimeSettings(
        model=agent_cfg.model,
        temperature=agent_cfg.temperature,
        top_p=agent_cfg.top_p,
        request_timeout_seconds=agent_cfg.request_timeout_seconds,
        retry_max_attempts=agent_cfg.retry_max_attempts,
        retry_backoff_seconds=agent_cfg.retry_backoff_seconds,
        history_keep_rounds=agent_cfg.history_keep_rounds,
        system_prompt=agent_cfg.system_prompt,
    )
    runtime = build_runtime(settings, command_executor=run_command)
    # 5. Build run state
    session_id = f"run-{args.seed}-{args.model}"
    run_state = RunState(
        session_id=session_id,
        seed=args.seed,
        model=args.model,
        horizon_years=horizon_years,
    )
    # 6. Run agent loop
    loop_cfg = experiment_cfg.loop
    final_state = run_agent_loop(
        runtime=runtime,
        db_factory=db_factory,
        company_id=company_id,
        run_state=run_state,
        command_executor=run_command,
        auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
        max_turns=loop_cfg.max_turns,
    )
    # 7. Save full rollout (with transcript) and print summary
    rollout = final_state.full_rollout()
    summary = final_state.summary()
    logger.info("Run complete: %s", json.dumps(summary, indent=2))
    # Write full rollout (includes transcript with commands)
    slug = args.model.replace("/", "_")
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)
    results_path = results_dir / f"yc_bench_result_{args.seed}_{slug}.json"
    results_path.write_text(json.dumps(rollout, indent=2))
    logger.info("Full rollout written to %s (%d turns)", results_path, len(rollout.get("transcript", [])))
    return 0 if final_state.terminal_reason != "error" else 1
def main(argv=None):
    """CLI entry point: load .env from the cwd upward, then run the benchmark."""
    from dotenv import find_dotenv, load_dotenv

    # override=False: real environment variables win over .env values.
    load_dotenv(find_dotenv(usecwd=True), override=False)
    return run_benchmark(parse_run_args(argv))
if __name__ == "__main__":
raise SystemExit(main())

View file

@ -0,0 +1,33 @@
"""Session utilities for the benchmark runner."""
from __future__ import annotations
from datetime import datetime, timezone
from uuid import uuid4
from sqlalchemy.orm import Session
from ..db.models.event import EventType
from ..db.models.session import Session as SessionModel
def open_session(db: Session, company_id, wake_reason: EventType) -> SessionModel:
    """Create, persist, and return a new (still-open) session record."""
    opened_at = datetime.now(timezone.utc)
    record = SessionModel(
        id=uuid4(),
        company_id=company_id,
        started_at=opened_at,
        ended_at=None,  # open until close_session stamps it
        wake_reason=wake_reason,
    )
    db.add(record)
    db.flush()
    return record
def close_session(db: Session, session: SessionModel) -> None:
    """Stamp the session's end time (UTC now) and flush the change."""
    closed_at = datetime.now(timezone.utc)
    session.ended_at = closed_at
    db.flush()
__all__ = ["open_session", "close_session"]

View file

@ -0,0 +1,107 @@
from __future__ import annotations
from dataclasses import dataclass
from ..config.schema import WorldConfig
from ..db.models.company import Domain
from .rng import RngStreams, sample_right_skew_triangular_int
_ALL_DOMAINS = list(Domain)
@dataclass(frozen=True)
class GeneratedEmployee:
    """Immutable spec for one deterministically generated employee."""
    name: str  # deterministic label, e.g. "Emp_3"
    work_hours_per_day: float  # copied from WorldConfig.work_hours_per_day
    salary_cents: int  # right-skew triangular draw within the tier's [min_cents, max_cents]
    tier: str  # salary tier name, picked share-weighted among cfg's tiers
    rates_by_domain: dict[str, float]  # per-domain rate, uniform in the tier's [rate_min, rate_max]; NOTE(review): keys are Domain enum members in practice, despite the str annotation — confirm
def _salary_tiers(cfg):
return (cfg.salary_junior, cfg.salary_mid, cfg.salary_senior)
def _pick_tier_name(rng, cfg):
    """Sample a tier name with probability proportional to each tier's share."""
    draw = rng.random()
    tiers = _salary_tiers(cfg)
    cumulative = 0.0
    for tier in tiers:
        cumulative += tier.share
        if cumulative >= draw:
            return tier.name
    # Shares may not sum to exactly 1.0 — fall back to the last tier.
    return tiers[-1].name
def _tier_by_name(cfg, tier_name):
    """Look up a salary tier by its name; raise ValueError if unknown."""
    found = next((t for t in _salary_tiers(cfg) if t.name == tier_name), None)
    if found is None:
        raise ValueError(f"Tier {tier_name} not found")
    return found
def _sample_salary_cents(rng, cfg, tier_name):
    """Draw a right-skewed salary (in cents) within the named tier's bounds."""
    bounds = _tier_by_name(cfg, tier_name)
    return sample_right_skew_triangular_int(rng, bounds.min_cents, bounds.max_cents)
def _sample_rates_by_domain(rng, cfg, tier_name):
    """Sample a per-domain rate, uniform within the tier's rate range."""
    tier = _tier_by_name(cfg, tier_name)
    rates = {}
    # Iteration order over _ALL_DOMAINS fixes the RNG draw order.
    for domain in _ALL_DOMAINS:
        rates[domain] = round(rng.uniform(tier.rate_min, tier.rate_max), 4)
    return rates
def generate_employees(*, run_seed, count, cfg=None):
    """Deterministically generate *count* GeneratedEmployee specs.

    Each employee draws from its own named RNG stream ("employee_<idx>"),
    so results are stable and independent of one another.
    """
    if cfg is None:
        cfg = WorldConfig()
    if count <= 0:
        return []
    streams = RngStreams(run_seed)

    def _build(idx):
        rng = streams.stream(f"employee_{idx}")
        # Draw order matters for determinism: tier, then salary, then rates.
        tier_name = _pick_tier_name(rng, cfg)
        return GeneratedEmployee(
            name=f"Emp_{idx}",
            work_hours_per_day=cfg.work_hours_per_day,
            salary_cents=_sample_salary_cents(rng, cfg, tier_name),
            tier=tier_name,
            rates_by_domain=_sample_rates_by_domain(rng, cfg, tier_name),
        )

    return [_build(idx) for idx in range(1, count + 1)]
def build_employee_rows(*, run_seed, company_id, count, cfg=None):
    """Generate employees and flatten them into insertable row dicts.

    Returns (employee_rows, skill_rows). Skill rows carry a transient
    "_employee_name" key so callers can link them to inserted employees.
    """
    generated = generate_employees(run_seed=run_seed, count=count, cfg=cfg)
    employee_rows = [
        {
            "company_id": company_id,
            "name": emp.name,
            "work_hours_per_day": emp.work_hours_per_day,
            "salary_cents": emp.salary_cents,
        }
        for emp in generated
    ]
    skill_rows = [
        {
            "_employee_name": emp.name,
            "domain": domain,
            "rate_domain_per_hour": rate,
        }
        for emp in generated
        for domain, rate in emp.rates_by_domain.items()
    ]
    return employee_rows, skill_rows
__all__ = [
"build_employee_rows",
"GeneratedEmployee",
"generate_employees",
]

View file

@ -0,0 +1,163 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from ..config.sampling import sample_from_spec
from ..config.schema import WorldConfig
from ..db.models.company import Domain
from .rng import RngStreams, sample_without_replacement
from .task_catalog import pick_task_text
@dataclass(frozen=True)
class GeneratedTask:
    """Immutable spec for one generated market task (pre-insertion shape)."""
    title: str  # catalog title suffixed with "[DOMAIN-serial]"
    description: str  # catalog description for the primary (first-requirement) domain
    required_prestige: int  # minimum prestige to accept; stratified for early tasks
    reward_funds_cents: int  # cash reward, scaled up with required_prestige
    reward_prestige_delta: float  # prestige reward sampled from cfg.dist.reward_prestige_delta
    skill_boost_pct: float  # skill-boost reward sampled from cfg.dist.skill_boost
    status: str  # always "market" at generation time
    company_id: Any | None  # always None at generation time (unclaimed)
    accepted_at: datetime | None  # None at generation time
    deadline: datetime | None  # None at generation time
    completed_at: datetime | None  # None at generation time
    success: bool | None  # None at generation time
    halfway_event_emitted: bool  # always False at generation time
    requirements: dict[str, int]  # required qty per domain; NOTE(review): keys are Domain enum members in practice, despite the str annotation — confirm
# First 10 market tasks are given explicit prestige values to guarantee a
# climbable ladder from the start (avoids runs where all early tasks need
# prestige 4+ before any are completable).
_STRATIFIED_PRESTIGE = [1, 1, 1, 1, 2, 2, 2, 3, 3, 4]
_ALL_DOMAINS = list(Domain)
def _sample_required_prestige(rng, cfg, index=None):
    """Prestige requirement: stratified for the first seeded tasks, else sampled."""
    use_ladder = index is not None and index < len(_STRATIFIED_PRESTIGE)
    if use_ladder:
        return _STRATIFIED_PRESTIGE[index]
    return int(sample_from_spec(rng, cfg.dist.required_prestige))
def _sample_reward_funds_cents(rng, cfg, prestige=1):
    """Sample a base cash reward, then scale it for higher-prestige tasks."""
    base = int(sample_from_spec(rng, cfg.dist.reward_funds_cents))
    # Prestige-1 tasks get the base reward; each level adds reward_prestige_scale.
    multiplier = 1 + cfg.reward_prestige_scale * (prestige - 1)
    return int(base * multiplier)
def _sample_reward_prestige_delta(rng, cfg):
    """Sample the prestige reward for completing a task."""
    spec = cfg.dist.reward_prestige_delta
    return sample_from_spec(rng, spec)
def _sample_skill_boost_pct(rng, cfg):
    """Sample the skill-boost percentage reward for a task."""
    spec = cfg.dist.skill_boost
    return sample_from_spec(rng, spec)
def _sample_domain_count(rng, cfg):
    """Sample how many distinct domains a task requires."""
    drawn = sample_from_spec(rng, cfg.dist.domain_count)
    return int(drawn)
def _sample_required_qty(rng, cfg):
    """Sample the required quantity for one domain of a task."""
    drawn = sample_from_spec(rng, cfg.dist.required_qty)
    return int(drawn)
def _sample_requirements(rng, cfg):
    """Pick k distinct domains, then a required quantity for each.

    RNG draw order (k, then the domain sample, then per-domain quantities)
    is preserved for determinism.
    """
    k = _sample_domain_count(rng, cfg)
    picked = sample_without_replacement(rng, _ALL_DOMAINS, k)
    requirements = {}
    for domain in picked:
        requirements[domain] = _sample_required_qty(rng, cfg)
    return requirements
def _pick_title_desc(rng, primary_domain, serial):
    """Pick catalog text for the primary domain; tag the title with a serial."""
    base_title, description = pick_task_text(rng, primary_domain)
    # Enum domains expose .value; fall back to str() for anything else.
    if hasattr(primary_domain, "value"):
        domain_str = primary_domain.value
    else:
        domain_str = str(primary_domain)
    return f"{base_title} [{domain_str.upper()}-{serial}]", description
def _make_task(rng, cfg, prestige, serial, requirements):
    """Assemble a GeneratedTask in the fresh, unclaimed "market" state."""
    # The first requirement key is treated as the task's primary domain.
    primary_domain = next(iter(requirements))
    title, description = _pick_title_desc(rng, primary_domain, serial)
    return GeneratedTask(
        title=title,
        description=description,
        required_prestige=prestige,
        reward_funds_cents=_sample_reward_funds_cents(rng, cfg, prestige=prestige),
        reward_prestige_delta=_sample_reward_prestige_delta(rng, cfg),
        skill_boost_pct=_sample_skill_boost_pct(rng, cfg),
        status="market",
        company_id=None,
        accepted_at=None,
        deadline=None,
        completed_at=None,
        success=None,
        halfway_event_emitted=False,
        requirements=requirements,
    )
def generate_tasks(*, run_seed, count, cfg=None):
    """Deterministically generate *count* market tasks from *run_seed*.

    Each task draws from its own named stream ("task_<serial>"); the first
    few tasks use the stratified prestige ladder (0-based index).
    """
    if cfg is None:
        cfg = WorldConfig()
    if count <= 0:
        return []
    streams = RngStreams(run_seed)
    tasks = []
    for serial in range(1, count + 1):
        rng = streams.stream(f"task_{serial}")
        reqs = _sample_requirements(rng, cfg)
        prestige = _sample_required_prestige(rng, cfg, index=serial - 1)
        tasks.append(_make_task(rng, cfg, prestige, serial=serial, requirements=reqs))
    return tasks
def build_task_rows(*, run_seed, count, cfg=None):
    """Generate tasks and flatten them into insertable row dicts.

    Returns (task_rows, requirement_rows). Requirement rows carry a transient
    "_task_title" key so callers can link them to inserted tasks.
    """
    generated = generate_tasks(run_seed=run_seed, count=count, cfg=cfg)
    # Row keys mirror GeneratedTask field names (minus `requirements`).
    row_fields = (
        "title", "description", "required_prestige", "reward_funds_cents",
        "reward_prestige_delta", "skill_boost_pct", "status", "company_id",
        "accepted_at", "deadline", "completed_at", "success",
        "halfway_event_emitted",
    )
    task_rows = [
        {name: getattr(task, name) for name in row_fields}
        for task in generated
    ]
    requirement_rows = [
        {
            "_task_title": task.title,
            "domain": domain,
            "required_qty": qty,
            "completed_qty": 0,
        }
        for task in generated
        for domain, qty in task.requirements.items()
    ]
    return task_rows, requirement_rows
def generate_replacement_task(*, run_seed, replenish_counter, cfg=None):
    """Generate one replacement market task, keyed by the replenish counter.

    Uses a dedicated "replenish_<n>" stream so replacements are deterministic
    and independent of the initial task set. Replacement prestige is always
    sampled (no stratified ladder).
    """
    if cfg is None:
        cfg = WorldConfig()
    rng = RngStreams(run_seed).stream(f"replenish_{replenish_counter}")
    requirements = _sample_requirements(rng, cfg)
    prestige = _sample_required_prestige(rng, cfg)
    return _make_task(rng, cfg, prestige, serial=replenish_counter, requirements=requirements)
__all__ = [
"build_task_rows",
"generate_replacement_task",
"generate_tasks",
"GeneratedTask",
]

View file

@ -0,0 +1,42 @@
from __future__ import annotations
import hashlib
import random
from dataclasses import dataclass
def _stable_seed(run_seed, stream_key):
raw = f"{run_seed}:{stream_key}".encode("utf-8")
digest = hashlib.sha256(raw).digest()
return int.from_bytes(digest[:8], byteorder="big", signed=False)
@dataclass(frozen=True)
class RngStreams:
    """Factory for named, deterministic RNG streams tied to one run seed."""

    run_seed: int  # the run-wide master seed

    def stream(self, stream_key):
        """Return a fresh random.Random seeded from (run_seed, stream_key)."""
        derived = _stable_seed(self.run_seed, stream_key)
        return random.Random(derived)
def clamp_numeric(value, low, high):
    """Clamp *value* to [low, high]; equivalent to max(low, min(high, value))."""
    capped = high if value > high else value
    return low if capped < low else capped
def sample_triangular_int(rng, low, high, mode):
    """Draw from a triangular distribution, round to int, clamp into [low, high]."""
    drawn = round(rng.triangular(low, high, mode))
    # Inline clamp (same semantics as clamp_numeric): rounding can step
    # just outside the bounds.
    return int(max(low, min(high, drawn)))
def sample_right_skew_triangular_int(rng, low, high):
    """Triangular int draw with the mode pinned at *high* (mass skews right)."""
    return sample_triangular_int(rng, low, high, mode=high)
def sample_normal_clamped_float(rng, mean, stdev, low, high):
    """Gaussian draw clamped into [low, high], rounded to 4 decimal places."""
    clamped = clamp_numeric(rng.gauss(mean, stdev), low, high)
    return round(clamped, 4)
def sample_beta_scaled(rng, alpha: float, beta: float, scale: float) -> float:
    """Sample Beta(alpha, beta), multiply by *scale*, round to 4 decimals.

    Used for prestige reward deltas.
    """
    draw = rng.betavariate(alpha, beta)
    return round(scale * draw, 4)
def sample_left_skew_0_2(rng):
    """Backward-compat alias using default Beta(1.2, 2.8) * 2.0 params."""
    return sample_beta_scaled(rng, alpha=1.2, beta=2.8, scale=2.0)
def sample_without_replacement(rng, population, k):
    """Draw *k* distinct elements from *population* using *rng*."""
    pool = list(population)  # rng.sample needs a sequence, not a bare iterable
    return rng.sample(pool, k)

View file

@ -0,0 +1,143 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from uuid import uuid4
from sqlalchemy.orm import Session
from ..config.schema import WorldConfig
from ..db.models.company import Company, CompanyPrestige, Domain
from ..db.models.employee import Employee, EmployeeSkillRate
from ..db.models.task import Task, TaskRequirement, TaskStatus
from .generate_employees import generate_employees
from .generate_tasks import generate_tasks
_ALL_DOMAINS = list(Domain)
@dataclass(frozen=True)
class SeedWorldRequest:
    """Parameters for seeding a fresh simulation world."""
    run_seed: int  # master seed for all deterministic RNG streams
    company_name: str  # display name for the seeded company
    horizon_years: int  # not referenced by seed_world in this module — presumably carried for callers; confirm
    employee_count: int  # must be > 0 (validated in seed_world)
    market_task_count: int  # must be > 0 (validated in seed_world)
    start_date: datetime | None = None  # echoed back as SeedWorldResult.seeded_at
    cfg: WorldConfig = field(default_factory=WorldConfig)  # world-generation tunables
@dataclass(frozen=True)
class SeedWorldResult:
    """Outcome of seeding: the new company's id and the seed timestamp."""
    company_id: str  # NOTE(review): seed_world assigns Company.id (a uuid4) here despite the str annotation — confirm intended type
    seeded_at: datetime  # echoes SeedWorldRequest.start_date; None when no start_date was given
def _seed_company(db, req):
    """Insert the player company with its configured starting funds."""
    row = Company(
        id=uuid4(),
        name=req.company_name,
        funds_cents=req.cfg.initial_funds_cents,
    )
    db.add(row)
    # Flush immediately so the row exists for dependent inserts.
    db.flush()
    return row
def _seed_company_prestige(db, company, cfg):
    """Give the company its initial prestige level in every domain."""
    rows = [
        CompanyPrestige(
            company_id=company.id,
            domain=domain,
            prestige_level=cfg.initial_prestige_level,
        )
        for domain in _ALL_DOMAINS
    ]
    for row in rows:
        db.add(row)
def _seed_employees(db, company, req):
    """Insert generated employees plus one skill-rate row per domain each."""
    specs = generate_employees(run_seed=req.run_seed, count=req.employee_count, cfg=req.cfg)
    for spec in specs:
        row = Employee(
            id=uuid4(),
            company_id=company.id,
            name=spec.name,
            work_hours_per_day=spec.work_hours_per_day,
            salary_cents=spec.salary_cents,
        )
        db.add(row)
        for domain, rate in spec.rates_by_domain.items():
            db.add(
                EmployeeSkillRate(
                    employee_id=row.id,
                    domain=domain,
                    rate_domain_per_hour=rate,
                )
            )
def _seed_market_tasks(db, company, req):
    """Insert the initial pool of unclaimed market tasks and their requirements.

    Note: *company* is accepted for signature symmetry but unused — market
    tasks start with company_id=None until accepted.
    """
    specs = generate_tasks(run_seed=req.run_seed, count=req.market_task_count, cfg=req.cfg)
    for spec in specs:
        row = Task(
            id=uuid4(),
            company_id=None,
            status=TaskStatus.MARKET,
            title=spec.title,
            description=spec.description,
            required_prestige=spec.required_prestige,
            reward_funds_cents=spec.reward_funds_cents,
            reward_prestige_delta=spec.reward_prestige_delta,
            skill_boost_pct=spec.skill_boost_pct,
            accepted_at=None,
            deadline=None,
            completed_at=None,
            success=None,
            halfway_event_emitted=False,
        )
        db.add(row)
        for domain, qty in spec.requirements.items():
            db.add(
                TaskRequirement(
                    task_id=row.id,
                    domain=domain,
                    required_qty=qty,
                    completed_qty=0,
                )
            )
def seed_world(db, req):
    """Seed company, prestige, employees, and market tasks for a new run.

    Raises ValueError when either count is non-positive. seeded_at simply
    echoes req.start_date (which may be None).
    """
    if req.employee_count <= 0:
        raise ValueError("employee_count must be positive")
    if req.market_task_count <= 0:
        raise ValueError("market_task_count must be positive")
    company = _seed_company(db, req)
    _seed_company_prestige(db, company, req.cfg)
    _seed_employees(db, company, req)
    _seed_market_tasks(db, company, req)
    return SeedWorldResult(company_id=company.id, seeded_at=req.start_date)
def seed_world_transactional(db, req):
    """seed_world plus a flush, so all rows are pushed within the transaction."""
    outcome = seed_world(db, req)
    db.flush()
    return outcome
__all__ = [
"SeedWorldRequest",
"SeedWorldResult",
"seed_world",
"seed_world_transactional",
]

View file

@ -0,0 +1,365 @@
"""Realistic AI-startup task titles and descriptions, keyed by domain.
Each domain has a pool of (title, description) tuples. The generator picks
from these deterministically using the seeded RNG, cycling if the pool is
exhausted.
"""
from __future__ import annotations
from ..db.models.company import Domain
TASK_POOL: dict[Domain, list[tuple[str, str]]] = {
Domain.SYSTEM: [
(
"Set Up GPU-Aware K8s Cluster with Auto-Scaling",
"Deploy a Kubernetes cluster with NVIDIA GPU operator, node auto-scaling based on inference queue depth, and spot instance fallback for training workloads.",
),
(
"Build CI/CD Pipeline for ML Model Registry",
"Create a CI pipeline that runs training validation, pushes versioned model artifacts to a registry, and auto-deploys to a staging inference endpoint.",
),
(
"Implement Blue-Green Deployment for LLM Serving",
"Set up zero-downtime model swaps for a vLLM serving cluster with automated rollback triggered by latency and error-rate thresholds.",
),
(
"Deploy Observability Stack for AI Workloads",
"Stand up Grafana, Prometheus, and OpenTelemetry with custom dashboards tracking GPU utilization, token throughput, time-to-first-token, and per-request cost.",
),
(
"Terraform Multi-Region Inference Infrastructure",
"Write IaC modules to provision inference endpoints across 3+ regions with global load balancing, failover routing, and centralized logging.",
),
(
"Container Image Optimization for ML Serving",
"Reduce Docker image sizes for PyTorch/CUDA serving containers from 15 GB to under 4 GB using multi-stage builds and distroless bases to cut cold-start times.",
),
(
"Implement Secret Rotation and API Key Management",
"Build an automated secret rotation system for API keys, database credentials, and model provider tokens across staging and production environments.",
),
(
"Set Up Cost Monitoring and GPU Budget Alerts",
"Integrate cloud billing APIs with a dashboard showing per-team GPU spend, cost-per-inference breakdowns, and automated alerts when daily spend exceeds thresholds.",
),
(
"Build Canary Release Pipeline for Embedding Models",
"Implement a canary deployment system that gradually shifts traffic to new embedding model versions, comparing retrieval quality metrics in real time.",
),
(
"Migrate Inference Workloads to Serverless GPU",
"Evaluate and migrate bursty inference workloads to serverless GPU providers, benchmarking cold-start latency against always-on instances.",
),
(
"Implement Disaster Recovery for Training Checkpoints",
"Design a cross-region checkpoint backup system with automated integrity verification, ensuring training runs can resume within 15 minutes of any single-region failure.",
),
(
"Build Internal Developer Platform for ML Engineers",
"Create a self-service portal where ML engineers can request GPU instances, spin up Jupyter environments, and launch training jobs without touching infrastructure.",
),
],
Domain.RESEARCH: [
(
"Design Benchmark for Legal Document QA",
"Create a benchmark suite of 2,000+ annotated legal questions across contract law and compliance, with human-expert baselines and an automated evaluation harness.",
),
(
"Investigate MoE Routing for Multilingual Models",
"Research and prototype alternative Mixture-of-Experts routing strategies that improve expert utilization for low-resource languages without degrading high-resource performance.",
),
(
"Reproduce and Extend Speculative Decoding Results",
"Replicate speculative decoding paper results on Llama-3 class models, then test novel draft model architectures that improve acceptance rates on code generation.",
),
(
"Develop RAG Hallucination Detection Framework",
"Build a systematic evaluation pipeline measuring faithfulness, relevance, and attribution accuracy for retrieval-augmented generation systems.",
),
(
"Prototype LoRA Merging for Multi-Tenant Serving",
"Research methods for dynamically composing multiple LoRA adapters at inference time, measuring quality degradation versus serving separate fine-tuned models.",
),
(
"Benchmark Long-Context Retrieval Across 128K Models",
"Systematically evaluate needle-in-a-haystack and multi-hop reasoning performance across frontier models at various context lengths with reproducible results.",
),
(
"Investigate Synthetic Data Quality for Code Generation",
"Develop automated quality scoring methods for synthetically generated code training data, correlating filter thresholds with downstream model performance.",
),
(
"Research KV-Cache Compression Techniques",
"Prototype and benchmark KV-cache eviction and quantization strategies for long-running conversational agents under fixed memory budgets.",
),
(
"Build Ablation Study Framework for Prompt Engineering",
"Create an experimentation harness for testing prompt variations across multiple models and tasks with statistical significance testing and cost tracking.",
),
(
"Explore Constitutional AI for Domain-Specific Safety",
"Adapt constitutional AI methods to create a self-improving safety filter for a healthcare chatbot, defining domain-specific principles and measuring accuracy.",
),
(
"Develop Novel Chunking Strategies for Technical RAG",
"Research and benchmark alternative document chunking methods—semantic, AST-aware, sliding window—specifically for API documentation and code repositories.",
),
(
"Prototype Test-Time Compute Scaling for Math Reasoning",
"Implement best-of-N sampling, tree search, and self-verification approaches for math reasoning, measuring the compute-accuracy Pareto frontier.",
),
],
Domain.DATA: [
(
"Build Web Scraping Pipeline for Industry News Corpus",
"Design a pipeline that crawls 50+ AI/tech news sources daily, deduplicates articles, extracts structured metadata, and loads clean text into a vector store.",
),
(
"Create Annotation Platform for Dialogue Quality",
"Build an annotation workflow where human raters score LLM conversation logs on helpfulness, accuracy, and safety, with inter-rater agreement tracking.",
),
(
"Implement PII Detection and Redaction Pipeline",
"Deploy a pipeline to detect and redact personally identifiable information from training data, with audit logging and configurable redaction strategies.",
),
(
"Curate Instruction-Tuning Dataset from Internal Docs",
"Extract, clean, and convert 10,000+ pages of internal documentation into high-quality instruction-response pairs suitable for fine-tuning.",
),
(
"Build Data Quality Monitoring for Feature Store",
"Implement data validation checks on streaming feature pipelines, alerting on schema drift, null-rate spikes, and distribution shifts before they affect models.",
),
(
"Design ETL Pipeline for Multi-Modal Training Data",
"Build a DAG pipeline that ingests images, PDFs, and structured data, applies OCR and layout detection, and produces unified records for vision-language training.",
),
(
"Implement Deduplication for Large Text Corpora",
"Deploy MinHash LSH-based near-deduplication at scale for 100M+ documents with configurable similarity thresholds and a review UI for borderline cases.",
),
(
"Build Synthetic Data Pipeline for Rare Edge Cases",
"Create a system that uses frontier LLMs to generate realistic synthetic examples for underrepresented categories in a classification dataset.",
),
(
"Create Data Versioning and Lineage Tracking System",
"Set up data versioning integrated with the ML training pipeline so every model checkpoint can be traced back to the exact dataset snapshot used.",
),
(
"Build Customer Feedback Loop into Training Pipeline",
"Implement a system where end-user thumbs-up/down signals are routed, reviewed, and selectively incorporated into fine-tuning datasets with human approval.",
),
(
"Migrate Legacy Warehouse to ML-Ready Lakehouse",
"Transform and migrate 5 years of product analytics data from a legacy SQL warehouse into a Parquet-based lakehouse optimized for feature engineering.",
),
],
Domain.FRONTEND: [
(
"Build Interactive LLM Playground with Streaming",
"Create a web app where users test multiple LLM providers side-by-side with streaming output, adjustable parameters, and conversation history persistence.",
),
(
"Design Admin Dashboard for AI Agent Monitoring",
"Build a dashboard showing real-time agent execution traces, tool call sequences, token usage graphs, and cost breakdowns with drill-down filtering.",
),
(
"Create Document Chat Interface for RAG Product",
"Implement a drag-and-drop document upload UI with a conversational interface showing source citations, confidence indicators, and reference highlighting.",
),
(
"Build Annotation Review and Approval Interface",
"Design a UI for data team leads to review annotator work, resolve disagreements, view agreement stats, and approve batches for training inclusion.",
),
(
"Implement Prompt Management Studio",
"Build a collaborative app where teams version, test, and A/B deploy prompt templates with visual diffs, rollback, and per-version performance analytics.",
),
(
"Create Customer-Facing AI Usage Analytics Dashboard",
"Build an embeddable dashboard showing API call volumes, latency percentiles, token consumption, and cost trends for enterprise customers.",
),
(
"Build Visual Pipeline Editor for No-Code AI Workflows",
"Create a node-based drag-and-drop editor where non-technical users chain data sources, LLM calls, and output actions into automated AI workflows.",
),
(
"Design Chat Widget for Website Embedding",
"Build a lightweight, brandable chat widget under 50 KB that customers embed on their sites, with streaming responses and escalation-to-human capability.",
),
(
"Build Model Comparison Results Viewer",
"Create a web interface displaying benchmark results across models in interactive tables and charts with filtering by task type and model size.",
),
(
"Implement Real-Time Collaboration for AI Writing Tool",
"Add multiplayer editing to an AI writing tool using CRDTs, with per-user cursors, AI suggestion tracking, and version history.",
),
(
"Create Enterprise RAG Onboarding Wizard",
"Build a step-by-step setup wizard guiding enterprise customers through connecting data sources, configuring chunking, testing retrieval, and deploying their endpoint.",
),
],
Domain.BACKEND: [
(
"Build Multi-Tenant LLM Gateway with Rate Limiting",
"Implement an API gateway that proxies requests to multiple LLM providers, enforces per-tenant rate limits, tracks usage, and handles automatic failover.",
),
(
"Implement OAuth2 + SAML SSO for Enterprise Platform",
"Add enterprise authentication supporting SAML 2.0, OIDC, and SCIM provisioning for customers integrating with their identity provider.",
),
(
"Design Webhook System for Async AI Job Completion",
"Build a reliable webhook delivery system with exponential backoff, signature verification, dead letter queue, and a webhook management API.",
),
(
"Create Unified Embedding API with Caching Layer",
"Build a microservice abstracting over multiple embedding providers with a Redis-backed cache, batch processing, and automatic model version migration.",
),
(
"Build Conversation Memory Service for Multi-Session Agents",
"Implement a service that stores, summarizes, and retrieves conversation history across sessions using structured storage and semantic vector search.",
),
(
"Implement Usage-Based Billing with Stripe Integration",
"Build a metering system that tracks token consumption per customer, aggregates monthly invoices, and syncs with Stripe for automated usage-based charging.",
),
(
"Create Plugin Marketplace Backend",
"Design the API and data model for a marketplace where third-party developers register, version, and distribute plugins for the AI platform.",
),
(
"Build RAG Ingestion Service with Chunking and Indexing",
"Implement an async document processing service that accepts PDFs, DOCX, and HTML, chunks them, generates embeddings, and upserts into a vector store.",
),
(
"Implement Audit Logging and Compliance API",
"Build a tamper-evident audit log system recording all AI interactions and admin actions, with an API for compliance queries and SOC 2 / HIPAA exports.",
),
(
"Design Multi-Model Routing and Fallback Service",
"Create a smart routing layer directing requests to the optimal model based on task complexity, latency requirements, and cost, with provider failover.",
),
(
"Build File Processing Service for Vision-Language Models",
"Implement an async service that accepts images and documents, runs them through vision-language models for extraction, and returns structured JSON output.",
),
(
"Implement Streaming API with Server-Sent Events",
"Build an SSE-based streaming endpoint for LLM responses with connection resumption, partial response caching, and graceful degradation.",
),
],
Domain.TRAINING: [
(
"Fine-Tune Llama-3 8B for Domain-Specific Support",
"Run supervised fine-tuning on 50K curated customer support conversations using QLoRA, targeting 15% accuracy improvement over the base model.",
),
(
"Implement RLHF Pipeline for Code Generation Model",
"Build an end-to-end RLHF pipeline with a reward model trained on human preference data and PPO training loop evaluated against HumanEval.",
),
(
"Distill GPT-4 Class Model into Efficient 3B Model",
"Use knowledge distillation with synthetic data to create a compact model retaining 90%+ teacher performance on targeted tasks at 10x lower inference cost.",
),
(
"Train Custom Embedding Model for Vertical Search",
"Fine-tune a sentence-transformers model on domain-specific query-document pairs with contrastive learning, hard negative mining, and retrieval benchmarks.",
),
(
"Build Hyperparameter Search for Fine-Tuning Jobs",
"Implement an Optuna-based HPO system searching over learning rate, LoRA rank, batch size, and data mixing ratios with early stopping.",
),
(
"Run Continued Pre-Training on Proprietary Corpus",
"Execute continued pre-training of a 7B base model on 10B tokens of domain-specific text with careful learning rate scheduling to avoid catastrophic forgetting.",
),
(
"Train Reward Model from Preference Annotations",
"Collect and process 20K pairwise preference annotations, train a Bradley-Terry reward model, and validate calibration against held-out human judgments.",
),
(
"Build Multi-GPU Training Infra with DeepSpeed",
"Set up distributed training using DeepSpeed ZeRO Stage 3 across an 8-node GPU cluster with checkpoint sharding and fault-tolerant resumption.",
),
(
"Implement DPO Fine-Tuning Pipeline",
"Build a Direct Preference Optimization pipeline as a simpler RLHF alternative, comparing quality and training stability on the same preference dataset.",
),
(
"Train Vision-Language Adapter for Document Understanding",
"Fine-tune a LoRA adapter on a VLM for extracting structured data from invoices, receipts, and forms with 95%+ field-level accuracy.",
),
(
"Build Eval-Driven Training Loop with Auto Checkpointing",
"Implement a training harness that runs benchmarks every N steps, auto-saves the best checkpoint, detects instability, and alerts on loss spikes.",
),
(
"Fine-Tune Whisper for Industry-Specific Transcription",
"Adapt Whisper-large for medical dictation using 500 hours of labeled audio, targeting 30% WER reduction on domain-specific terminology.",
),
],
Domain.HARDWARE: [
(
"Optimize LLM Inference Latency with TensorRT-LLM",
"Convert a 70B model to TensorRT-LLM with INT8/FP8 quantization, continuous batching, and paged attention, targeting sub-200ms time-to-first-token.",
),
(
"Deploy On-Device ML Model for Mobile Classification",
"Convert a PyTorch vision model to Core ML and TFLite, optimize with quantization-aware training, and benchmark on iPhone and Pixel hardware.",
),
(
"Build GPU Cluster Scheduling with Fair-Share Queuing",
"Implement a scheduler for a shared GPU cluster enforcing per-team quotas, priority queuing, preemption policies, and utilization-based chargeback.",
),
(
"Implement Quantization Pipeline (GPTQ/AWQ/GGUF)",
"Build an automated pipeline that takes any model, produces GPTQ, AWQ, and GGUF quantized variants, runs quality regression, and publishes passing models.",
),
(
"Deploy Edge Inference for Real-Time Video Analytics",
"Set up an NVIDIA Jetson-based inference node running YOLO and a lightweight LLM for on-premises real-time camera analysis with local data processing.",
),
(
"Optimize vLLM Serving for Production Workload",
"Profile and tune vLLM parameters—max batch size, KV cache, swap space, tensor parallelism—for target throughput at P99 latency SLA.",
),
(
"Build Multi-GPU Inference with Tensor Parallelism",
"Configure and benchmark a 70B+ model serving across 4-8 GPUs with tensor and pipeline parallelism, optimizing throughput versus latency tradeoffs.",
),
(
"Implement Dynamic Batching for Inference Requests",
"Build a request batching layer that groups incoming requests by sequence length and priority, maximizing GPU utilization within per-request latency SLAs.",
),
(
"Design Hybrid CPU/GPU Inference Architecture",
"Architect a system routing lightweight requests to CPU inference and complex requests to GPU instances, reducing overall compute cost by 40%.",
),
(
"Set Up Triton Inference Server for Multi-Model Serving",
"Deploy NVIDIA Triton to serve embedding, reranking, and generation models on shared GPU infrastructure with dynamic batching and concurrency control.",
),
(
"Build GPU Health Monitoring and Failover System",
"Implement a daemon detecting GPU memory errors, thermal throttling, and NVLink degradation, automatically draining affected nodes and redistributing workloads.",
),
(
"Benchmark Specialized AI Accelerators vs H100",
"Evaluate Groq, Cerebras, and custom ASICs against H100 GPUs, producing a cost-per-token and latency comparison with a migration recommendation.",
),
(
"Implement Speculative Decoding in Production Stack",
"Integrate speculative decoding with a small draft model into the existing serving infrastructure, measuring real-world throughput improvement.",
),
],
}
def pick_task_text(rng, domain: Domain) -> tuple[str, str]:
    """Return a (title, description) pair for *domain*, chosen via *rng*.

    Selection depends only on the state of *rng*, so a fixed seed always
    yields the same task text for a given domain.
    """
    candidates = TASK_POOL[domain]
    last_index = len(candidates) - 1
    return candidates[rng.randint(0, last_index)]