Mirror of https://github.com/collinear-ai/yc-bench.git
Synced 2026-05-02 17:45:54 +00:00

Initial commit
Commit 3a1c562827
59 changed files with 8418 additions and 0 deletions

src/yc_bench/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""YC-Bench: long-horizon deterministic benchmark for LLM agents."""

__version__ = "0.1.0"

src/yc_bench/__main__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from yc_bench.cli import app_main

app_main()

src/yc_bench/agent/commands/executor.py (new file, 86 lines)
@@ -0,0 +1,86 @@
from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path

from .policy import parse_bench_command


def _resolve_yc_bench() -> str:
    """Find the yc-bench script in the same directory as the running Python."""
    venv_bin = Path(sys.executable).parent
    candidate = venv_bin / "yc-bench"
    if candidate.exists():
        return str(candidate)
    return "yc-bench"  # fallback to PATH lookup


def run_command(
    command,
    *,
    cwd=None,
    timeout_seconds=60.0,
    command_exists=None,
):
    ok, err, argv = parse_bench_command(command)
    if not ok:
        return {
            "ok": False,
            "exit_code": 2,
            "stdout": "",
            "stderr": err or "invalid command",
            "sim_time": None,
            "command": command if isinstance(command, str) else str(command),
        }

    if command_exists is not None and not command_exists(argv):
        return {
            "ok": False,
            "exit_code": 127,
            "stdout": "",
            "stderr": f"command not found: {' '.join(argv)}",
            "sim_time": None,
            "command": command,
        }

    # Resolve yc-bench to the venv-local script
    if argv[0] == "yc-bench":
        argv[0] = _resolve_yc_bench()

    try:
        proc = subprocess.run(
            argv,
            shell=False,
            cwd=cwd,
            text=True,
            capture_output=True,
            timeout=timeout_seconds,
        )
        return {
            "ok": proc.returncode == 0,
            "exit_code": proc.returncode,
            "stdout": proc.stdout,
            "stderr": proc.stderr,
            "sim_time": None,
            "command": command,
        }
    except subprocess.TimeoutExpired as exc:
        return {
            "ok": False,
            "exit_code": 124,
            "stdout": exc.stdout or "",
            "stderr": f"command timed out after {timeout_seconds} seconds",
            "sim_time": None,
            "command": command,
        }
    except Exception as exc:
        return {
            "ok": False,
            "exit_code": 1,
            "stdout": "",
            "stderr": str(exc),
            "sim_time": None,
            "command": command,
        }


__all__ = ["run_command"]

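For orientation, a minimal usage sketch of the executor above (editor's illustration, not part of the commit; the command string and timeout value are arbitrary, and the dict keys mirror the return shape defined in run_command):

from yc_bench.agent.commands.executor import run_command

# Only `yc-bench ...` strings pass the policy check; anything else is rejected
# with exit_code 2 before a subprocess is ever spawned.
result = run_command("yc-bench company status", timeout_seconds=30.0)
if result["ok"]:
    print(result["stdout"])                    # JSON emitted by the CLI
else:
    print(result["exit_code"], result["stderr"])
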
src/yc_bench/agent/commands/policy.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from __future__ import annotations

import shlex


def parse_bench_command(command):
    if not isinstance(command, str) or not command.strip():
        return False, "command must be a non-empty string", None

    try:
        argv = shlex.split(command)
    except ValueError:
        return False, f"invalid command syntax: {command}", None

    if not argv:
        return False, "command must be a non-empty string", None

    if argv[0] != "yc-bench":
        return False, "only top-level `yc-bench` commands are allowed", None

    return True, None, argv


__all__ = ["parse_bench_command"]

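The parser's contract, shown with two small examples (editor's sketch; the return values follow directly from the code above):

from yc_bench.agent.commands.policy import parse_bench_command

ok, err, argv = parse_bench_command("yc-bench task accept --task-id abc")
# ok == True, err is None,
# argv == ["yc-bench", "task", "accept", "--task-id", "abc"]

ok, err, argv = parse_bench_command("rm -rf /")
# ok == False, err == "only top-level `yc-bench` commands are allowed", argv is None
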
src/yc_bench/agent/loop.py (new file, 213 lines)
@@ -0,0 +1,213 @@
from __future__ import annotations

import json
import logging

from sqlalchemy import func
from sqlalchemy.orm import Session

from ..db.models.company import Company
from ..db.models.employee import Employee
from ..db.models.sim_state import SimState
from ..db.models.task import Task, TaskStatus
from .prompt import build_initial_user_prompt, build_turn_context
from .run_state import RunState, TerminalReason
from .runtime.base import AgentRuntime
from .runtime.schemas import RuntimeTurnRequest

logger = logging.getLogger(__name__)


def _snapshot_state(db: Session, company_id):
    """Read current simulation state from DB for context building."""
    sim_state = db.query(SimState).filter(SimState.company_id == company_id).one()
    company = db.query(Company).filter(Company.id == company_id).one()

    active_count = db.query(func.count(Task.id)).filter(
        Task.company_id == company_id,
        Task.status == TaskStatus.ACTIVE,
    ).scalar() or 0

    planned_count = db.query(func.count(Task.id)).filter(
        Task.company_id == company_id,
        Task.status == TaskStatus.PLANNED,
    ).scalar() or 0

    employee_count = db.query(func.count(Employee.id)).filter(
        Employee.company_id == company_id,
    ).scalar() or 0

    monthly_payroll = db.query(func.sum(Employee.salary_cents)).filter(
        Employee.company_id == company_id,
    ).scalar() or 0

    return {
        "sim_time": sim_state.sim_time.isoformat(),
        "horizon_end": sim_state.horizon_end.isoformat(),
        "funds_cents": company.funds_cents,
        "active_tasks": active_count,
        "planned_tasks": planned_count,
        "employee_count": employee_count,
        "monthly_payroll_cents": int(monthly_payroll),
        "bankrupt": company.funds_cents < 0,
    }


def _extract_commands(raw_result) -> list[str]:
    """Extract CLI commands the agent executed from the raw_result dict.

    raw_result is {"tool_calls": [{"command": "yc-bench ...", "result": "..."}, ...]}.
    """
    commands = []
    try:
        for tc in (raw_result or {}).get("tool_calls", []):
            cmd = tc.get("command", "")
            result = tc.get("result", "")
            if cmd:
                commands.append(f"{cmd} -> {result[:500]}" if result else cmd)
    except Exception as exc:
        logger.debug("Could not extract commands from raw_result: %s", exc)
    return commands


def _auto_resume(command_executor) -> tuple[dict | None, str | None]:
    """Force-advance simulation time via sim resume. Returns (payload, error_msg)."""
    try:
        raw = command_executor("yc-bench sim resume")
        stdout = raw.get("stdout", "")
        if isinstance(stdout, str) and stdout.strip():
            payload = json.loads(stdout)
            if isinstance(payload, dict):
                return payload, None
        return None, raw.get("stderr", "sim resume returned no payload")
    except Exception as exc:
        return None, str(exc)


def _build_resume_handoff_user_input(payload: dict) -> str:
    """Build next-turn user message from sim resume payload."""
    wake_events = payload.get("wake_events") or []
    lines = [
        "Simulation advanced via `yc-bench sim resume`.",
        f"- old_sim_time: {payload.get('old_sim_time')}",
        f"- new_sim_time: {payload.get('new_sim_time')}",
        f"- checkpoint_event_type: {payload.get('checkpoint_event_type')}",
        f"- events_processed: {payload.get('events_processed')}",
        f"- payrolls_applied: {payload.get('payrolls_applied')}",
        f"- balance_delta: {payload.get('balance_delta')}",
        f"- bankrupt: {payload.get('bankrupt')}",
        f"- horizon_reached: {payload.get('horizon_reached')}",
        f"- terminal_reason: {payload.get('terminal_reason')}",
        f"- wake_events_count: {len(wake_events)}",
    ]
    for ev in wake_events:
        lines.append(f"- wake_event: {ev}")
    lines.append(
        "You are now at the new checkpoint. Query live state via yc-bench commands and decide next actions."
    )
    return "\n".join(lines)


def run_agent_loop(
    runtime: AgentRuntime,
    db_factory,
    company_id,
    run_state: RunState,
    command_executor=None,
    auto_advance_after_turns: int = 10,
    max_turns: int | None = None,
) -> RunState:
    run_state.start()
    turns_since_resume = 0  # consecutive turns without sim resume

    logger.info(
        "Starting agent loop: model=%s seed=%d auto_advance_after=%d turns max_turns=%s",
        run_state.model, run_state.seed, auto_advance_after_turns, max_turns or "unlimited",
    )

    while not run_state.terminal:
        if max_turns is not None and run_state.turn_count >= max_turns:
            logger.info("Reached max_turns=%d, stopping.", max_turns)
            run_state.mark_terminal(TerminalReason.ERROR, f"max_turns={max_turns} reached")
            break
        turn_num = run_state.turn_count + 1
        if run_state.turn_count == 0:
            with db_factory() as db:
                snapshot = _snapshot_state(db, company_id)
            user_input = build_initial_user_prompt(**snapshot)
        elif run_state.next_user_input is not None:
            user_input = run_state.next_user_input
            run_state.next_user_input = None
        else:
            with db_factory() as db:
                snapshot = _snapshot_state(db, company_id)
            user_input = build_turn_context(
                turn_number=turn_num,
                **snapshot,
            )

        try:
            result = runtime.run_turn(
                RuntimeTurnRequest(
                    session_id=run_state.session_id,
                    user_input=user_input,
                )
            )
            agent_output = result.final_output
        except Exception as exc:
            logger.error("Runtime error on turn %d: %s", turn_num, exc)
            run_state.mark_terminal(TerminalReason.ERROR, str(exc))
            break

        commands_executed = _extract_commands(result.raw_result)

        resume_payload = result.resume_payload
        if result.checkpoint_advanced and resume_payload is not None:
            logger.info("Turn %d: agent called sim resume.", turn_num)
            turns_since_resume = 0
        else:
            turns_since_resume += 1
            if command_executor is not None and turns_since_resume >= auto_advance_after_turns:
                logger.info(
                    "Turn %d: %d consecutive turns without sim resume; auto-advancing.",
                    turn_num, turns_since_resume,
                )
                resume_payload, err = _auto_resume(command_executor)
                if err:
                    logger.warning("Auto-resume failed on turn %d: %s", turn_num, err)
                else:
                    turns_since_resume = 0

        if resume_payload is not None:
            run_state.next_user_input = _build_resume_handoff_user_input(resume_payload)
            reason = resume_payload.get("terminal_reason")
            if reason == "bankruptcy":
                run_state.mark_terminal(TerminalReason.BANKRUPTCY, reason)
            elif reason == "horizon_end":
                run_state.mark_terminal(TerminalReason.HORIZON_END, reason)
            if run_state.terminal:
                logger.info("Terminal after turn %d: %s", turn_num, reason)

        run_state.record_turn(
            user_input=user_input,
            agent_output=agent_output,
            commands_executed=commands_executed,
            turn_cost_usd=getattr(result, "turn_cost_usd", 0.0),
        )

        logger.info(
            "Turn %d complete. Agent output length: %d, commands: %d",
            turn_num, len(agent_output), len(commands_executed),
        )

    logger.info(
        "Agent loop finished: turns=%d terminal=%s reason=%s",
        run_state.turn_count,
        run_state.terminal,
        run_state.terminal_reason,
    )

    return run_state


__all__ = ["run_agent_loop"]

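To make the resume handoff concrete, here is a hedged example of the message _build_resume_handoff_user_input produces from a `sim resume` payload (editor's illustration; it imports a private helper purely for demonstration, and every field value below is invented):

from yc_bench.agent.loop import _build_resume_handoff_user_input

payload = {
    "old_sim_time": "2025-01-02T09:00:00+00:00",   # illustrative values only
    "new_sim_time": "2025-01-20T14:00:00+00:00",
    "checkpoint_event_type": "task_completed",
    "events_processed": 3,
    "payrolls_applied": 1,
    "balance_delta": -1_200_000,
    "bankrupt": False,
    "horizon_reached": False,
    "terminal_reason": None,
    "wake_events": [{"type": "task_completed", "task_id": "task-1", "success": True}],
}
print(_build_resume_handoff_user_input(payload))
# Emits the "Simulation advanced via `yc-bench sim resume`." bullet list, one
# "- wake_event: ..." line per event, and the closing instruction to query
# live state at the new checkpoint.
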
src/yc_bench/agent/prompt.py (new file, 185 lines)
@@ -0,0 +1,185 @@
"""System prompt and user-message builders for the YC-Bench agent."""
from __future__ import annotations

SYSTEM_PROMPT = """\
You are the autonomous CEO of an AI startup in a deterministic business simulation. \
Your goal is to maximize company prestige and funds over the simulation horizon while avoiding bankruptcy.

## How It Works

- All actions are performed via the `run_command` tool, which executes `yc-bench` CLI commands.
- All commands return JSON. Parse the output to make decisions.
- Simulation progression and event processing are managed by the benchmark runtime.
- Business hours are weekdays 09:00-18:00. Nights, weekends, and Feb 29 are skipped.
- Payroll is deducted automatically on the first business day of each month.
- If funds go below zero after any event or payroll, the company goes bankrupt and the run ends.

## Available Commands

### Observe
- `yc-bench company status` — funds, prestige, employee count, payroll, bankruptcy risk
- `yc-bench employee list` — list all employees with IDs, salaries, skill rates, and current assignments
- `yc-bench market browse [--domain X] [--required-prestige-lte N] [--reward-min-cents N] [--limit N] [--offset N]` — browse available tasks (default limit 50; the response includes a `total` field — if total > 50, paginate with --offset to see more)
- `yc-bench task list [--status X]` — list your tasks (planned, active, completed, cancelled)
- `yc-bench task inspect --task-id <UUID>` — detailed task info (requirements, assignments, progress)
- `yc-bench finance ledger [--from MM/DD/YYYY] [--to MM/DD/YYYY] [--category X]` — financial history
- `yc-bench report monthly [--from-month YYYY-MM] [--to-month YYYY-MM]` — monthly P&L
- `yc-bench scratchpad read` — read your persistent notes

### Memory (scratchpad)
- `yc-bench scratchpad write --content "text"` — overwrite scratchpad with new notes
- `yc-bench scratchpad append --content "text"` — append a line to existing notes
- `yc-bench scratchpad clear` — erase all notes
- Use the scratchpad to store key decisions, task deadlines, employee assignments, and strategy notes. Context is periodically truncated — anything important should be written here.

### Act
- `yc-bench task accept --task-id <UUID>` — accept a market task (sets deadline, generates replacement)
- `yc-bench task assign --task-id <UUID> --employee-id <UUID>` — assign employee to task
- `yc-bench task dispatch --task-id <UUID>` — start work on a planned task (must have assignments)
- `yc-bench task cancel --task-id <UUID> --reason "text"` — cancel a task (prestige penalty: 1.2x reward delta)
- `yc-bench sim resume` — advance simulation to the next checkpoint event and return wake events

## Strategy Guidelines

1. **Check company status first** to understand your financial position and runway.
2. **Browse the market** for tasks you can accept (check prestige requirements).
3. **Accept tasks** that match your capabilities and offer a good reward-to-risk ratio.
4. **Assign employees strategically** — employees split throughput across active tasks. Focus employees on fewer tasks for faster completion.
5. **Dispatch tasks** once assigned, then continue monitoring progress/events via status and reports.
6. **Monitor deadlines** — completing after deadline causes failure (0.8x prestige penalty). Cancel hopeless tasks early (1.2x penalty, but stops bleeding time).
7. **Watch payroll** — monthly salaries are deducted automatically. Don't let runway drop to zero.
8. **Use status checks** to track critical milestones and risks.
9. **Successful tasks** award funds + prestige + employee skill boosts. Build momentum.

## Key Rules

- Task completion at or before deadline = success (reward funds + prestige + skill boost)
- Task completion after deadline = failure (0.8x prestige penalty, no reward)
- Task cancellation = 1.2x prestige penalty per domain
- Employee throughput = base_rate / number_of_active_tasks_assigned
- Time advances only when you run `yc-bench sim resume`
- Prestige is clamped [1, 10]. Funds are in cents.
"""


def build_turn_context(
    turn_number: int,
    sim_time: str,
    horizon_end: str,
    funds_cents: int,
    active_tasks: int,
    planned_tasks: int,
    employee_count: int,
    monthly_payroll_cents: int,
    bankrupt: bool,
    last_wake_events: list | None = None,
) -> str:
    """Build per-turn context message injected as user input."""
    runway_months = (
        round(funds_cents / monthly_payroll_cents, 1)
        if monthly_payroll_cents > 0
        else None
    )
    runway_str = f"~{runway_months} months" if runway_months is not None else "∞ (no payroll)"

    parts = [
        f"## Turn {turn_number} — Simulation State",
        f"- **Current time**: {sim_time}",
        f"- **Horizon end**: {horizon_end}",
        f"- **Funds**: ${funds_cents / 100:,.2f} ({funds_cents} cents)",
        f"- **Monthly payroll**: ${monthly_payroll_cents / 100:,.2f}",
        f"- **Runway**: {runway_str}",
        f"- **Employees**: {employee_count}",
        f"- **Active tasks**: {active_tasks}",
        f"- **Planned tasks**: {planned_tasks}",
    ]

    if bankrupt:
        parts.append("\n**WARNING: Company is bankrupt. Run will terminate.**")

    if last_wake_events:
        parts.append("\n### Events since last turn:")
        for ev in last_wake_events:
            ev_type = ev.get("type", "unknown")
            if ev_type == "task_completed":
                success = ev.get("success", False)
                tid = ev.get("task_id", "?")
                parts.append(f"- Task {tid}: {'SUCCESS' if success else 'FAILED'}")
            elif ev_type == "task_half":
                parts.append(f"- Task {ev.get('task_id', '?')}: 50% progress reached")
            elif ev_type == "horizon_end":
                parts.append("- **Horizon end reached. Simulation complete.**")
            elif ev_type == "bankruptcy":
                parts.append("- **BANKRUPTCY. Simulation terminated.**")
            else:
                parts.append(f"- Event: {ev_type}")

    if active_tasks == 0 and planned_tasks == 0:
        parts.append(
            "\n**ACTION REQUIRED**: No tasks are running. "
            "Accept a task, assign employees to it, dispatch it, then call `yc-bench sim resume`. "
            "Do this now — every turn without active tasks burns runway."
        )
    elif planned_tasks > 0 and active_tasks == 0:
        parts.append(
            "\n**ACTION REQUIRED**: You have planned tasks but none are dispatched. "
            "Assign employees and dispatch now, then call `yc-bench sim resume`."
        )
    else:
        parts.append("\nDecide your next actions. Use `run_command` to execute CLI commands.")

    return "\n".join(parts)


def build_initial_user_prompt(
    sim_time: str,
    horizon_end: str,
    funds_cents: int,
    active_tasks: int,
    planned_tasks: int,
    employee_count: int,
    monthly_payroll_cents: int,
    bankrupt: bool,
) -> str:
    """Build the one-time initial user message at run start."""
    runway_months = (
        round(funds_cents / monthly_payroll_cents, 1)
        if monthly_payroll_cents > 0
        else None
    )
    runway_str = f"~{runway_months} months" if runway_months is not None else "∞"

    lines = [
        "## Simulation Start — Take Immediate Action",
        f"- current_time: {sim_time}",
        f"- horizon_end: {horizon_end}",
        f"- funds: ${funds_cents / 100:,.2f}",
        f"- monthly_payroll: ${monthly_payroll_cents / 100:,.2f}",
        f"- runway: {runway_str}",
        f"- employees: {employee_count}",
        f"- active_tasks: {active_tasks}",
        f"- planned_tasks: {planned_tasks}",
        "",
        "**Your immediate priority**: generate revenue before payroll drains your runway.",
        "You MUST complete these steps now (multiple commands per turn are fine):",
        "1. `yc-bench market browse --required-prestige-lte 1` — find tasks you can accept",
        "2. `yc-bench task accept --task-id <UUID>` — accept 2-3 suitable tasks",
        "3. `yc-bench employee list` — get employee IDs",
        "4. `yc-bench task assign --task-id <UUID> --employee-id <UUID>` — assign employees",
        "5. `yc-bench task dispatch --task-id <UUID>` — start work on each assigned task",
        "6. `yc-bench sim resume` — advance time to collect the first task completion event",
        "",
        "Do not spend multiple turns just browsing. Accept and dispatch tasks immediately.",
    ]
    if bankrupt:
        lines.append("WARNING: company is already bankrupt at initialization.")
    return "\n".join(lines)


__all__ = ["SYSTEM_PROMPT", "build_turn_context", "build_initial_user_prompt"]

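Both builders compute runway as funds divided by monthly payroll, rounded to one decimal place. A quick worked example (the numbers are made up):

funds_cents = 50_000_00            # $50,000.00 in the bank
monthly_payroll_cents = 20_000_00  # $20,000.00 of salaries per month
runway_months = round(funds_cents / monthly_payroll_cents, 1)
print(runway_months)               # 2.5, rendered as "~2.5 months" in the turn context
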
src/yc_bench/agent/run_state.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""Run state: tracks the progress and terminal status of a benchmark run."""
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional


class TerminalReason(str, Enum):
    BANKRUPTCY = "bankruptcy"
    HORIZON_END = "horizon_end"
    ERROR = "error"


@dataclass
class TranscriptEntry:
    turn: int
    timestamp: str
    user_input: str
    agent_output: str
    commands_executed: List[str] = field(default_factory=list)


@dataclass
class RunState:
    """Mutable state for a single benchmark run."""

    session_id: str
    seed: int
    model: str
    horizon_years: int

    turn_count: int = 0
    terminal: bool = False
    terminal_reason: Optional[TerminalReason] = None
    terminal_detail: Optional[str] = None
    started_at: Optional[str] = None
    ended_at: Optional[str] = None
    transcript: List[TranscriptEntry] = field(default_factory=list)
    next_user_input: Optional[str] = None
    total_cost_usd: float = 0.0

    def start(self) -> None:
        self.started_at = datetime.now(timezone.utc).isoformat()

    def record_turn(
        self,
        user_input: str,
        agent_output: str,
        commands_executed: List[str] | None = None,
        turn_cost_usd: float = 0.0,
    ) -> None:
        self.turn_count += 1
        self.total_cost_usd += turn_cost_usd
        self.transcript.append(TranscriptEntry(
            turn=self.turn_count,
            timestamp=datetime.now(timezone.utc).isoformat(),
            user_input=user_input,
            agent_output=agent_output,
            commands_executed=commands_executed or [],
        ))

    def mark_terminal(self, reason: TerminalReason, detail: str = "") -> None:
        self.terminal = True
        self.terminal_reason = reason
        self.terminal_detail = detail
        self.ended_at = datetime.now(timezone.utc).isoformat()

    def should_stop(self) -> bool:
        return self.terminal

    def full_rollout(self) -> Dict[str, Any]:
        """Full results including transcript for saving to disk."""
        return {
            "session_id": self.session_id,
            "model": self.model,
            "seed": self.seed,
            "horizon_years": self.horizon_years,
            "turns_completed": self.turn_count,
            "terminal": self.terminal,
            "terminal_reason": self.terminal_reason.value if self.terminal_reason else None,
            "terminal_detail": self.terminal_detail,
            "total_cost_usd": round(self.total_cost_usd, 6),
            "started_at": self.started_at,
            "ended_at": self.ended_at,
            "transcript": [
                {
                    "turn": t.turn,
                    "timestamp": t.timestamp,
                    "user_input": t.user_input,
                    "agent_output": t.agent_output,
                    "commands_executed": t.commands_executed,
                }
                for t in self.transcript
            ],
        }

    def summary(self) -> Dict[str, Any]:
        """Summary without transcript for logging."""
        rollout = self.full_rollout()
        rollout.pop("transcript", None)
        return rollout


__all__ = ["TerminalReason", "TranscriptEntry", "RunState"]

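A small self-contained sketch of the RunState lifecycle as the loop uses it (editor's illustration; the session id, model name, commands, and cost are placeholders):

from yc_bench.agent.run_state import RunState, TerminalReason

state = RunState(session_id="demo", seed=7, model="example-model", horizon_years=3)
state.start()
state.record_turn(
    user_input="## Turn 1 — Simulation State ...",
    agent_output="Executed 2 tool call(s): ...",
    commands_executed=["yc-bench company status", "yc-bench market browse"],
    turn_cost_usd=0.0123,
)
state.mark_terminal(TerminalReason.HORIZON_END, "horizon_end")
print(state.summary())                              # everything except the transcript
print(len(state.full_rollout()["transcript"]))      # 1
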
src/yc_bench/agent/runtime/base.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod


class AgentRuntime(ABC):
    @abstractmethod
    def run_turn(self, request):
        raise NotImplementedError

    @abstractmethod
    def clear_session(self, session_id):
        raise NotImplementedError


__all__ = ["AgentRuntime"]

src/yc_bench/agent/runtime/factory.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from __future__ import annotations

from .base import AgentRuntime
from .litellm_runtime import LiteLLMRuntime
from .schemas import RuntimeSettings


def build_runtime(settings, command_executor):
    return LiteLLMRuntime(
        settings=settings,
        command_executor=command_executor,
    )


__all__ = ["build_runtime"]

src/yc_bench/agent/runtime/litellm_runtime.py (new file, 294 lines)
@@ -0,0 +1,294 @@
from __future__ import annotations

import json
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from dataclasses import dataclass, field

import litellm

from .base import AgentRuntime
from .schemas import RuntimeTurnResult
from ..prompt import SYSTEM_PROMPT
from ..tools.run_command_schema import normalize_result

logger = logging.getLogger(__name__)

litellm.suppress_debug_info = True

# Tool schema passed to the LLM on every call
_RUN_COMMAND_TOOL = {
    "type": "function",
    "function": {
        "name": "run_command",
        "description": (
            "Execute one benchmark CLI command inside the sandbox "
            "and return structured execution output."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "The full yc-bench CLI command to execute.",
                }
            },
            "required": ["command"],
        },
    },
}


@dataclass
class _Session:
    messages: list = field(default_factory=list)


class LiteLLMRuntime(AgentRuntime):

    def __init__(self, settings, command_executor):
        self._settings = settings
        self._command_executor = command_executor
        self._sessions: dict[str, _Session] = {}

        self._request_timeout_seconds = settings.request_timeout_seconds
        self._retry_max_attempts = settings.retry_max_attempts
        self._retry_backoff_seconds = settings.retry_backoff_seconds

        if self._request_timeout_seconds <= 0:
            raise ValueError("request_timeout_seconds must be > 0")
        if self._retry_max_attempts <= 0:
            raise ValueError("retry_max_attempts must be > 0")
        if self._retry_backoff_seconds <= 0:
            raise ValueError("retry_backoff_seconds must be > 0")

        # API key: prefer OPENAI_API_KEY, fall back to OPENROUTER_API_KEY.
        # For openrouter/ prefixed models LiteLLM also reads OPENROUTER_API_KEY
        # automatically, so either approach works.
        self._api_key = (
            os.environ.get("OPENAI_API_KEY")
            or os.environ.get("OPENROUTER_API_KEY")
            or None
        )
        # Base URL: only needed for raw OpenAI-compatible endpoints.
        # openrouter/ model prefix is handled natively by LiteLLM without this.
        self._api_base = os.environ.get("OPENAI_BASE_URL") or None

        self._history_keep_rounds = settings.history_keep_rounds

        logger.info(
            "LiteLLMRuntime configured: model=%s api_base=%s history_keep_rounds=%d",
            self._settings.model,
            self._api_base or "default",
            self._history_keep_rounds,
        )

    # ------------------------------------------------------------------
    # AgentRuntime interface
    # ------------------------------------------------------------------

    def run_turn(self, request):
        session = self._get_or_create_session(request.session_id)
        # Proactively drop old rounds before appending new input.
        self._proactive_truncate(session)
        session.messages.append({"role": "user", "content": request.user_input})

        result = None
        last_err = None

        for attempt in range(1, self._retry_max_attempts + 1):
            try:
                result = self._run_with_timeout(session)
                break
            except Exception as e:
                last_err = e
                if self._is_context_length_error(e):
                    logger.warning(
                        "Context-length error on attempt %d despite proactive truncation "
                        "(history_keep_rounds=%d). Consider reducing YC_BENCH_HISTORY_KEEP_ROUNDS.",
                        attempt,
                        self._history_keep_rounds,
                    )
                    continue
                logger.warning("Turn attempt %d failed: %s", attempt, e)
                if attempt >= self._retry_max_attempts:
                    raise RuntimeError(
                        f"Failed to run turn after {self._retry_max_attempts} attempts"
                    ) from last_err
                time.sleep(self._retry_backoff_seconds * (2 ** (attempt - 1)))

        if result is None:
            raise RuntimeError("run_turn failed without result") from last_err

        final_output, tool_calls_made, resume_payload, turn_cost = result
        return RuntimeTurnResult(
            final_output=final_output,
            raw_result={"tool_calls": tool_calls_made},
            checkpoint_advanced=resume_payload is not None,
            resume_payload=resume_payload,
            turn_cost_usd=turn_cost,
        )

    def clear_session(self, session_id):
        self._sessions.pop(session_id, None)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _run_with_timeout(self, session):
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(self._do_turn, session)
            try:
                return future.result(timeout=self._request_timeout_seconds)
            except FuturesTimeoutError as exc:
                future.cancel()
                raise TimeoutError(
                    f"LiteLLM call timed out after {self._request_timeout_seconds}s"
                ) from exc

    def _do_turn(self, session):
        """One LLM call + tool execution. Returns (final_output, tool_calls_made, resume_payload, cost_usd)."""
        system_prompt = self._settings.system_prompt or SYSTEM_PROMPT
        messages = [{"role": "system", "content": system_prompt}] + session.messages

        kwargs = dict(
            model=self._settings.model,
            messages=messages,
            tools=[_RUN_COMMAND_TOOL],
            tool_choice="auto",
            temperature=self._settings.temperature,
            top_p=self._settings.top_p,
            timeout=self._request_timeout_seconds,
        )
        if self._api_base:
            kwargs["api_base"] = self._api_base
        if self._api_key:
            kwargs["api_key"] = self._api_key

        response = litellm.completion(**kwargs)

        # Log token usage and cost per call
        turn_cost = 0.0
        usage = getattr(response, "usage", None)
        if usage:
            cost = getattr(response, "_hidden_params", {}).get("response_cost") or 0
            turn_cost = float(cost)
            logger.info(
                "LLM call: prompt_tokens=%s completion_tokens=%s cost=$%.6f",
                getattr(usage, "prompt_tokens", "?"),
                getattr(usage, "completion_tokens", "?"),
                turn_cost,
            )

        message = response.choices[0].message
        tool_calls = getattr(message, "tool_calls", None) or []

        tool_calls_made = []
        resume_payload = None

        if tool_calls:
            # Persist assistant message with tool calls
            session.messages.append({
                "role": "assistant",
                "content": message.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            })

            for tc in tool_calls:
                try:
                    args = json.loads(tc.function.arguments)
                    command = args.get("command", "")
                except Exception:
                    command = ""

                raw = self._command_executor(command)
                normalized = normalize_result(raw)
                tool_result_str = json.dumps(normalized.__dict__)

                tool_calls_made.append({"command": command, "result": tool_result_str})

                # Extract resume payload when the agent advances simulation time
                if command.startswith("yc-bench sim resume"):
                    try:
                        stdout = normalized.__dict__.get("stdout", "")
                        if isinstance(stdout, str) and stdout.strip():
                            payload = json.loads(stdout)
                            if isinstance(payload, dict):
                                resume_payload = payload
                    except Exception:
                        pass

                session.messages.append({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": tool_result_str,
                })

            cmds = [tc["command"] for tc in tool_calls_made]
            final_output = f"Executed {len(tool_calls)} tool call(s): {', '.join(cmds)}"
        else:
            content = message.content or ""
            session.messages.append({"role": "assistant", "content": content})
            final_output = content

        return final_output, tool_calls_made, resume_payload, turn_cost

    def _get_or_create_session(self, session_id: str) -> _Session:
        if session_id not in self._sessions:
            self._sessions[session_id] = _Session()
        return self._sessions[session_id]

    def _is_context_length_error(self, err: Exception) -> bool:
        text = str(err).lower()
        patterns = (
            "context length",
            "maximum context",
            "max context",
            "too many tokens",
            "prompt is too long",
            "token limit",
            "context window",
        )
        return any(p in text for p in patterns)

    def _round_start_indices(self, messages: list) -> list[int]:
        """Return indices of user messages — each marks the start of a round."""
        return [i for i, m in enumerate(messages) if m.get("role") == "user"]

    def _proactive_truncate(self, session: _Session) -> None:
        """Drop oldest rounds before each turn so at most history_keep_rounds remain."""
        messages = session.messages
        user_indices = self._round_start_indices(messages)
        if len(user_indices) <= self._history_keep_rounds:
            return
        cutoff = user_indices[-self._history_keep_rounds]
        marker = {
            "role": "user",
            "content": (
                f"[Earlier turns removed. Only the last {self._history_keep_rounds} "
                "turns are retained in this context window.]"
            ),
        }
        session.messages = [marker] + messages[cutoff:]
        logger.info(
            "Proactive truncation: kept last %d rounds (%d → %d messages).",
            self._history_keep_rounds,
            len(messages),
            len(session.messages),
        )


__all__ = ["LiteLLMRuntime"]

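The history window is managed per "round": one user message plus everything that follows until the next user message. A standalone re-statement of _proactive_truncate's rule, shown on a toy message list with history_keep_rounds assumed to be 2 (editor's sketch, not part of the commit):

# Illustration only: same cutoff logic as _proactive_truncate, outside the class.
history_keep_rounds = 2
messages = [
    {"role": "user", "content": "turn 1"},
    {"role": "assistant", "content": "..."},
    {"role": "user", "content": "turn 2"},
    {"role": "assistant", "content": "..."},
    {"role": "user", "content": "turn 3"},
]
user_indices = [i for i, m in enumerate(messages) if m["role"] == "user"]
if len(user_indices) > history_keep_rounds:
    cutoff = user_indices[-history_keep_rounds]   # index of the oldest round to keep
    marker = {
        "role": "user",
        "content": f"[Earlier turns removed. Only the last {history_keep_rounds} "
                    "turns are retained in this context window.]",
    }
    messages = [marker] + messages[cutoff:]
print([m["content"] for m in messages])
# ['[Earlier turns removed. ...]', 'turn 2', '...', 'turn 3']
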
src/yc_bench/agent/runtime/schemas.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class RuntimeSettings:
    model: str
    temperature: float
    top_p: float
    request_timeout_seconds: float = 300.0
    retry_max_attempts: int = 3
    retry_backoff_seconds: float = 1.0
    history_keep_rounds: int = 20
    # Optional system prompt override; None = use default from agent/prompt.py
    system_prompt: str | None = None


@dataclass(frozen=True)
class RuntimeTurnRequest:
    session_id: str
    user_input: str


@dataclass(frozen=True)
class RuntimeTurnResult:
    final_output: str
    raw_result: Any
    checkpoint_advanced: bool = False
    resume_payload: dict | None = None
    turn_cost_usd: float = 0.0


@dataclass(frozen=True)
class CommandResult:
    ok: bool
    exit_code: int
    stdout: str
    stderr: str
    sim_time: str | None
    command: str


__all__ = ["RuntimeSettings", "RuntimeTurnRequest", "RuntimeTurnResult", "CommandResult"]

src/yc_bench/agent/tools/run_command_schema.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class RunCommandResult:
    ok: bool
    exit_code: int
    stdout: str
    stderr: str
    sim_time: str | None
    command: str


def normalize_result(payload):
    return RunCommandResult(
        ok=bool(payload.get("ok", False)),
        exit_code=int(payload.get("exit_code", 1)),
        stdout=str(payload.get("stdout", "")),
        stderr=str(payload.get("stderr", "")),
        sim_time=payload.get("sim_time"),
        command=str(payload.get("command", "")),
    )

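normalize_result coerces whatever dict the executor returns into the frozen RunCommandResult the runtime serialises back to the model. A hedged example (the input dict is illustrative):

from yc_bench.agent.tools.run_command_schema import normalize_result

raw = {
    "ok": True,
    "exit_code": 0,
    "stdout": '{"funds_cents": 5000000}',
    "stderr": "",
    "sim_time": None,
    "command": "yc-bench company status",
}
print(normalize_result(raw))
# RunCommandResult(ok=True, exit_code=0, stdout='{"funds_cents": 5000000}',
#                  stderr='', sim_time=None, command='yc-bench company status')
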
src/yc_bench/cli/__init__.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from __future__ import annotations

import json
import sys
from contextlib import contextmanager
from decimal import Decimal
from uuid import UUID

import typer

from ..db.session import build_engine, build_session_factory, session_scope

app = typer.Typer(name="yc-bench", add_completion=False)


# ---------------------------------------------------------------------------
# Helpers shared across command modules
# ---------------------------------------------------------------------------

@contextmanager
def get_db():
    """Yield a transactional SQLAlchemy session, commit on success."""
    engine = build_engine()
    factory = build_session_factory(engine)
    with session_scope(factory) as session:
        yield session


class _JSONEncoder(json.JSONEncoder):
    """Handle UUID, Decimal, datetime serialisation."""

    def default(self, o):
        if isinstance(o, UUID):
            return str(o)
        if isinstance(o, Decimal):
            return float(o)
        from datetime import datetime, date
        if isinstance(o, datetime):
            return o.isoformat()
        if isinstance(o, date):
            return o.isoformat()
        return super().default(o)


def json_output(data: dict | list) -> None:
    """Print JSON to stdout (captured by run_command executor)."""
    typer.echo(json.dumps(data, cls=_JSONEncoder, indent=2))


def error_output(message: str, code: int = 1) -> None:
    """Print JSON error and exit."""
    typer.echo(json.dumps({"error": message}), err=False)
    raise typer.Exit(code=code)


# ---------------------------------------------------------------------------
# Register sub-command groups
# ---------------------------------------------------------------------------

from .sim_commands import sim_app  # noqa: E402
from .company_commands import company_app  # noqa: E402
from .market_commands import market_app  # noqa: E402
from .task_commands import task_app  # noqa: E402
from .finance_commands import finance_app  # noqa: E402
from .report_commands import report_app  # noqa: E402
from .employee_commands import employee_app  # noqa: E402
from .scratchpad_commands import scratchpad_app  # noqa: E402

app.add_typer(sim_app, name="sim")
app.add_typer(company_app, name="company")
app.add_typer(employee_app, name="employee")
app.add_typer(market_app, name="market")
app.add_typer(task_app, name="task")
app.add_typer(finance_app, name="finance")
app.add_typer(report_app, name="report")
app.add_typer(scratchpad_app, name="scratchpad")


@app.command("run")
def run_command_cli(
    model: str = typer.Option(..., help="LiteLLM model string (e.g. openrouter/z-ai/glm-5)"),
    seed: int = typer.Option(..., help="Random seed for deterministic world generation"),
    horizon_years: int = typer.Option(3, help="Simulation horizon in years"),
    company_name: str = typer.Option("BenchCo", help="Name of the simulated company"),
    start_date: str = typer.Option("2025-01-01", help="Simulation start date (YYYY-MM-DD)"),
    config_name: str = typer.Option(
        "default", "--config",
        help="Preset name ('default', 'fast_test', 'high_reward') or path to a .toml file",
    ),
):
    """Run a full benchmark: migrate DB, seed world, run agent loop to completion."""
    from dotenv import find_dotenv, load_dotenv
    load_dotenv(find_dotenv(usecwd=True), override=False)

    from ..runner.main import run_benchmark
    from ..runner.args import RunArgs
    args = RunArgs(
        model=model,
        seed=seed,
        horizon_years=horizon_years,
        company_name=company_name,
        start_date=start_date,
        config_name=config_name,
    )
    raise SystemExit(run_benchmark(args))


def app_main():
    """Entry point for `yc-bench` console_script."""
    app()

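The custom encoder exists so CLI handlers can pass ORM values straight to json_output. A hedged example of the three conversions it performs (editor's illustration; it imports the module's private encoder purely for demonstration, and the payload values are invented):

import json
from datetime import datetime, timezone
from decimal import Decimal
from uuid import uuid4

from yc_bench.cli import _JSONEncoder

payload = {
    "company_id": uuid4(),                                          # UUID -> str
    "prestige": Decimal("3.50"),                                    # Decimal -> float
    "sim_time": datetime(2025, 1, 1, 9, 0, tzinfo=timezone.utc),    # datetime -> ISO 8601
}
print(json.dumps(payload, cls=_JSONEncoder, indent=2))
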
src/yc_bench/cli/company_commands.py (new file, 102 lines)
@@ -0,0 +1,102 @@
from __future__ import annotations

from datetime import datetime, timedelta, timezone
from decimal import Decimal

import typer
from sqlalchemy import func

from ..core.business_time import first_business_of_month
from ..db.models.company import Company, CompanyPrestige
from ..db.models.employee import Employee
from ..db.models.sim_state import SimState
from ..db.models.task import Task, TaskStatus
from . import get_db, json_output, error_output

company_app = typer.Typer(help="Company status commands.")


def _next_payroll_date(sim_time: datetime) -> datetime:
    """Compute next first-business-day-of-month at 09:00 from sim_time."""
    if sim_time.month == 12:
        next_month = sim_time.replace(year=sim_time.year + 1, month=1, day=1)
    else:
        next_month = sim_time.replace(month=sim_time.month + 1, day=1)
    return first_business_of_month(next_month)


@company_app.command("status")
def company_status():
    """Show company status: funds, prestige, tasks, payroll, risk."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")

        company = db.query(Company).filter(Company.id == sim_state.company_id).one_or_none()
        if company is None:
            error_output("Company not found.")

        # Prestige by domain
        prestige_rows = db.query(CompanyPrestige).filter(
            CompanyPrestige.company_id == company.id
        ).all()
        prestige_map = {row.domain.value: float(row.prestige_level) for row in prestige_rows}

        # Task counts
        active_count = db.query(func.count(Task.id)).filter(
            Task.company_id == company.id,
            Task.status == TaskStatus.ACTIVE,
        ).scalar() or 0

        planned_count = db.query(func.count(Task.id)).filter(
            Task.company_id == company.id,
            Task.status == TaskStatus.PLANNED,
        ).scalar() or 0

        completed_count = db.query(func.count(Task.id)).filter(
            Task.company_id == company.id,
            Task.status.in_([TaskStatus.COMPLETED_SUCCESS, TaskStatus.COMPLETED_FAIL]),
        ).scalar() or 0

        cancelled_count = db.query(func.count(Task.id)).filter(
            Task.company_id == company.id,
            Task.status == TaskStatus.CANCELLED,
        ).scalar() or 0

        # Employee count
        employee_count = db.query(func.count(Employee.id)).filter(
            Employee.company_id == company.id
        ).scalar() or 0

        # Monthly payroll estimate
        total_salary = db.query(func.sum(Employee.salary_cents)).filter(
            Employee.company_id == company.id
        ).scalar() or 0

        next_payroll = _next_payroll_date(sim_state.sim_time)

        # Bankruptcy risk
        months_runway = round(float(company.funds_cents) / float(total_salary), 2) if total_salary > 0 else None

        json_output({
            "company_id": str(company.id),
            "company_name": company.name,
            "funds_cents": company.funds_cents,
            "prestige": prestige_map,
            "sim_time": sim_state.sim_time.isoformat(),
            "horizon_end": sim_state.horizon_end.isoformat(),
            "tasks": {
                "active": active_count,
                "planned": planned_count,
                "completed": completed_count,
                "cancelled": cancelled_count,
            },
            "employees": employee_count,
            "next_payroll": next_payroll.isoformat(),
            "monthly_payroll_cents": total_salary,
            "risk": {
                "months_runway": months_runway,
                "bankrupt": company.funds_cents < 0,
            },
        })

src/yc_bench/cli/employee_commands.py (new file, 62 lines)
@@ -0,0 +1,62 @@
from __future__ import annotations

import typer
from sqlalchemy import func

from ..db.models.employee import Employee, EmployeeSkillRate
from ..db.models.task import Task, TaskAssignment, TaskStatus
from ..db.models.sim_state import SimState
from . import get_db, json_output, error_output

employee_app = typer.Typer(help="Employee management commands.")


@employee_app.command("list")
def employee_list():
    """List all employees with their skills and current assignments."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")

        employees = db.query(Employee).filter(
            Employee.company_id == sim_state.company_id
        ).all()

        results = []
        for emp in employees:
            # Skills
            skills = db.query(EmployeeSkillRate).filter(
                EmployeeSkillRate.employee_id == emp.id
            ).all()
            skill_map = {
                s.domain.value: float(s.rate_domain_per_hour)
                for s in skills
            }

            # Current active assignments
            active_assignments = (
                db.query(TaskAssignment.task_id)
                .join(Task, Task.id == TaskAssignment.task_id)
                .filter(
                    TaskAssignment.employee_id == emp.id,
                    Task.status == TaskStatus.ACTIVE,
                )
                .all()
            )
            active_task_ids = [str(a.task_id) for a in active_assignments]

            results.append({
                "employee_id": str(emp.id),
                "name": emp.name,
                "salary_cents": emp.salary_cents,
                "work_hours_per_day": float(emp.work_hours_per_day),
                "skills": skill_map,
                "active_task_count": len(active_task_ids),
                "active_task_ids": active_task_ids,
            })

        json_output({
            "count": len(results),
            "employees": results,
        })

src/yc_bench/cli/finance_commands.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from __future__ import annotations

from datetime import datetime, timezone
from typing import Optional

import typer
from sqlalchemy import and_

from ..db.models.ledger import LedgerCategory, LedgerEntry
from ..db.models.sim_state import SimState
from . import get_db, json_output, error_output

finance_app = typer.Typer(help="Finance and ledger commands.")


def _parse_date(date_str: str) -> datetime:
    try:
        dt = datetime.strptime(date_str, "%m/%d/%Y")
        return dt.replace(tzinfo=timezone.utc)
    except ValueError:
        raise typer.BadParameter(f"Invalid date format: {date_str}. Use MM/DD/YYYY.")


@finance_app.command("ledger")
def finance_ledger(
    from_date: Optional[str] = typer.Option(None, "--from", help="Start date MM/DD/YYYY"),
    to_date: Optional[str] = typer.Option(None, "--to", help="End date MM/DD/YYYY"),
    category: Optional[str] = typer.Option(None, "--category", help="Filter by ledger category"),
):
    """View ledger entries with optional date and category filters."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")

        query = db.query(LedgerEntry).filter(
            LedgerEntry.company_id == sim_state.company_id
        )

        if from_date is not None:
            from_dt = _parse_date(from_date)
            query = query.filter(LedgerEntry.occurred_at >= from_dt)

        if to_date is not None:
            to_dt = _parse_date(to_date)
            query = query.filter(LedgerEntry.occurred_at <= to_dt)

        if category is not None:
            try:
                cat = LedgerCategory(category)
            except ValueError:
                error_output(f"Invalid category: {category}. Valid: {[c.value for c in LedgerCategory]}")
            query = query.filter(LedgerEntry.category == cat)

        entries = query.order_by(LedgerEntry.occurred_at.asc()).all()

        total_amount = sum(e.amount_cents for e in entries)

        results = [
            {
                "id": str(e.id),
                "occurred_at": e.occurred_at.isoformat(),
                "category": e.category.value,
                "amount_cents": e.amount_cents,
                "ref_type": e.ref_type,
                "ref_id": str(e.ref_id) if e.ref_id else None,
            }
            for e in entries
        ]

        json_output({
            "count": len(results),
            "total_amount_cents": total_amount,
            "entries": results,
        })

src/yc_bench/cli/market_commands.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from __future__ import annotations

from typing import Optional

import typer

from ..db.models.company import Domain
from ..db.models.task import Task, TaskRequirement, TaskStatus
from ..config import get_world_config
from . import get_db, json_output, error_output

market_app = typer.Typer(help="Market browsing commands.")


@market_app.command("browse")
def market_browse(
    domain: Optional[Domain] = typer.Option(None, "--domain", help="Filter by requirement domain"),
    required_prestige_lte: Optional[int] = typer.Option(None, "--required-prestige-lte", help="Max required prestige"),
    reward_min_cents: Optional[int] = typer.Option(None, "--reward-min-cents", help="Min reward in cents"),
    limit: Optional[int] = typer.Option(None, "--limit", help="Max results (default from experiment config)"),
    offset: int = typer.Option(0, "--offset", help="Offset for pagination"),
):
    """Browse available tasks on the market."""
    if limit is None:
        limit = get_world_config().market_browse_default_limit
    with get_db() as db:
        query = db.query(Task).filter(Task.status == TaskStatus.MARKET)

        if required_prestige_lte is not None:
            query = query.filter(Task.required_prestige <= required_prestige_lte)

        if reward_min_cents is not None:
            query = query.filter(Task.reward_funds_cents >= reward_min_cents)

        if domain is not None:
            # Filter tasks that have a requirement in the given domain
            query = query.filter(
                Task.id.in_(
                    db.query(TaskRequirement.task_id).filter(
                        TaskRequirement.domain == domain
                    )
                )
            )

        total = query.count()
        tasks = query.order_by(Task.reward_funds_cents.desc()).offset(offset).limit(limit).all()

        results = []
        for task in tasks:
            reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == task.id).all()
            requirements = [
                {
                    "domain": r.domain.value,
                    "required_qty": float(r.required_qty),
                }
                for r in reqs
            ]
            results.append({
                "task_id": str(task.id),
                "title": task.title,
                "description": task.description,
                "required_prestige": task.required_prestige,
                "reward_funds_cents": task.reward_funds_cents,
                "reward_prestige_delta": float(task.reward_prestige_delta),
                "skill_boost_pct": float(task.skill_boost_pct),
                "requirements": requirements,
            })

        json_output({
            "total": total,
            "offset": offset,
            "limit": limit,
            "tasks": results,
        })

src/yc_bench/cli/report_commands.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from __future__ import annotations

from datetime import date, datetime
from typing import Optional

import typer
from sqlalchemy import and_

from ..db.models.session import MonthlyMetric
from ..db.models.sim_state import SimState
from . import get_db, json_output, error_output

report_app = typer.Typer(help="Reporting commands.")


def _parse_month(month_str: str) -> date:
    """Parse YYYY-MM into a date (first of month)."""
    try:
        dt = datetime.strptime(month_str, "%Y-%m")
        return dt.date().replace(day=1)
    except ValueError:
        raise typer.BadParameter(f"Invalid month format: {month_str}. Use YYYY-MM.")


@report_app.command("monthly")
def report_monthly(
    from_month: Optional[str] = typer.Option(None, "--from-month", help="Start month YYYY-MM"),
    to_month: Optional[str] = typer.Option(None, "--to-month", help="End month YYYY-MM"),
):
    """View monthly metrics (revenue, cost, return, ending funds)."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found.")

        query = db.query(MonthlyMetric).filter(
            MonthlyMetric.company_id == sim_state.company_id
        )

        if from_month is not None:
            from_dt = _parse_month(from_month)
            query = query.filter(MonthlyMetric.month_start >= from_dt)

        if to_month is not None:
            to_dt = _parse_month(to_month)
            query = query.filter(MonthlyMetric.month_start <= to_dt)

        metrics = query.order_by(MonthlyMetric.month_start.asc()).all()

        results = [
            {
                "month_start": m.month_start.isoformat(),
                "revenue_cents": m.revenue_cents,
                "cost_cents": m.cost_cents,
                "return_cents": m.return_cents,
                "ending_funds_cents": m.ending_funds_cents,
            }
            for m in metrics
        ]

        json_output({
            "count": len(results),
            "months": results,
        })

src/yc_bench/cli/scratchpad_commands.py (new file, 72 lines)
@@ -0,0 +1,72 @@
from __future__ import annotations

import typer

from ..db.models.scratchpad import Scratchpad
from ..db.models.sim_state import SimState
from . import get_db, json_output, error_output

scratchpad_app = typer.Typer(help="Agent scratchpad for persistent notes.")


def _get_or_create(db, company_id) -> Scratchpad:
    row = db.query(Scratchpad).filter(Scratchpad.company_id == company_id).one_or_none()
    if row is None:
        row = Scratchpad(company_id=company_id, content="")
        db.add(row)
        db.flush()
    return row


@scratchpad_app.command("read")
def scratchpad_read():
    """Read the current scratchpad content."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        row = _get_or_create(db, sim_state.company_id)
        json_output({"content": row.content})


@scratchpad_app.command("write")
def scratchpad_write(
    content: str = typer.Option(..., help="Text to write (replaces existing content)."),
):
    """Overwrite the scratchpad with new content."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        row = _get_or_create(db, sim_state.company_id)
        row.content = content
        json_output({"ok": True, "content": row.content})


@scratchpad_app.command("append")
def scratchpad_append(
    content: str = typer.Option(..., help="Text to append to existing content."),
):
    """Append text to the scratchpad (adds a newline separator)."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        row = _get_or_create(db, sim_state.company_id)
        if row.content:
            row.content = row.content + "\n" + content
        else:
            row.content = content
        json_output({"ok": True, "content": row.content})


@scratchpad_app.command("clear")
def scratchpad_clear():
    """Clear the scratchpad."""
    with get_db() as db:
        sim_state = db.query(SimState).first()
        if sim_state is None:
            error_output("No simulation found. Run `yc-bench sim init` first.")
        row = _get_or_create(db, sim_state.company_id)
        row.content = ""
        json_output({"ok": True})

151
src/yc_bench/cli/sim_commands.py
Normal file
@ -0,0 +1,151 @@
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
|
||||
from ..core.engine import advance_time
|
||||
from ..core.events import fetch_next_event, insert_event
|
||||
from ..db.models.company import Company
|
||||
from ..db.models.event import EventType
|
||||
from ..db.models.sim_state import SimState
|
||||
from ..config import get_world_config
|
||||
from ..services.seed_world import SeedWorldRequest, seed_world_transactional
|
||||
from . import get_db, json_output, error_output
|
||||
|
||||
sim_app = typer.Typer(help="Simulation initialization commands.")
|
||||
|
||||
|
||||
def _parse_date(date_str: str) -> datetime:
|
||||
"""Parse MM/DD/YYYY into a timezone-aware datetime at 09:00 UTC."""
|
||||
try:
|
||||
dt = datetime.strptime(date_str, "%m/%d/%Y")
|
||||
return dt.replace(hour=9, minute=0, second=0, tzinfo=timezone.utc)
|
||||
except ValueError:
|
||||
raise typer.BadParameter(f"Invalid date format: {date_str}. Use MM/DD/YYYY.")
|
||||
|
||||
|
||||
@sim_app.command("init")
|
||||
def sim_init(
|
||||
seed: int = typer.Option(..., help="RNG seed for deterministic generation"),
|
||||
start_date: str = typer.Option(..., "--start-date", help="Start date MM/DD/YYYY"),
|
||||
horizon_years: int = typer.Option(3, "--horizon-years", help="Simulation horizon in years"),
|
||||
company_name: str = typer.Option(..., "--company-name", help="Company name"),
|
||||
employee_count: Optional[int] = typer.Option(None, "--employee-count", help="Number of employees (default from experiment config)"),
|
||||
market_task_count: Optional[int] = typer.Option(None, "--market-task-count", help="Number of market tasks (default from experiment config)"),
|
||||
):
|
||||
"""Initialize a new simulation: seed world, create company, schedule horizon."""
|
||||
_wc = get_world_config()
|
||||
if employee_count is None:
|
||||
employee_count = _wc.num_employees
|
||||
if market_task_count is None:
|
||||
market_task_count = _wc.num_market_tasks
|
||||
|
||||
start_dt = _parse_date(start_date)
|
||||
horizon_end = start_dt.replace(year=start_dt.year + horizon_years)
|
||||
|
||||
with get_db() as db:
|
||||
# Check if a simulation already exists
|
||||
existing = db.query(SimState).first()
|
||||
if existing is not None:
|
||||
error_output("A simulation already exists. Only one simulation per database is supported.")
|
||||
|
||||
req = SeedWorldRequest(
|
||||
run_seed=seed,
|
||||
company_name=company_name,
|
||||
horizon_years=horizon_years,
|
||||
employee_count=employee_count,
|
||||
market_task_count=market_task_count,
|
||||
start_date=start_dt,
|
||||
)
|
||||
result = seed_world_transactional(db, req)
|
||||
|
||||
# Schedule horizon_end event (deterministic id via insert_event helper)
|
||||
insert_event(
|
||||
db=db,
|
||||
company_id=result.company_id,
|
||||
event_type=EventType.HORIZON_END,
|
||||
scheduled_at=horizon_end,
|
||||
payload={"reason": "horizon_end"},
|
||||
dedupe_key="horizon_end",
|
||||
)
|
||||
|
||||
# Create sim_state row
|
||||
sim_state = SimState(
|
||||
company_id=result.company_id,
|
||||
sim_time=start_dt,
|
||||
run_seed=seed,
|
||||
horizon_end=horizon_end,
|
||||
replenish_counter=0,
|
||||
)
|
||||
db.add(sim_state)
|
||||
db.flush()
|
||||
|
||||
json_output({
|
||||
"simulation_id": str(result.company_id),
|
||||
"company_id": str(result.company_id),
|
||||
"sim_time": start_dt.isoformat(),
|
||||
"horizon_end": horizon_end.isoformat(),
|
||||
"company_name": company_name,
|
||||
"seed": seed,
|
||||
})
|
||||
|
||||
|
||||
@sim_app.command("resume")
|
||||
def sim_resume():
|
||||
"""Advance simulation to the next queued event checkpoint and return wake results."""
|
||||
with get_db() as db:
|
||||
sim_state = db.query(SimState).first()
|
||||
if sim_state is None:
|
||||
error_output("No simulation found. Run `yc-bench sim init` first.")
|
||||
company = db.query(Company).filter(Company.id == sim_state.company_id).one()
|
||||
|
||||
next_event = fetch_next_event(
|
||||
db=db,
|
||||
company_id=sim_state.company_id,
|
||||
up_to=sim_state.horizon_end,
|
||||
)
|
||||
|
||||
if next_event is None:
|
||||
terminal_reason = None
|
||||
bankrupt = company.funds_cents < 0
|
||||
horizon_reached = sim_state.sim_time >= sim_state.horizon_end
|
||||
if bankrupt:
|
||||
terminal_reason = "bankruptcy"
|
||||
elif horizon_reached:
|
||||
terminal_reason = "horizon_end"
|
||||
|
||||
json_output({
|
||||
"ok": True,
|
||||
"message": "no_pending_events",
|
||||
"old_sim_time": sim_state.sim_time.isoformat(),
|
||||
"new_sim_time": sim_state.sim_time.isoformat(),
|
||||
"events_processed": 0,
|
||||
"payrolls_applied": 0,
|
||||
"balance_delta": 0,
|
||||
"wake_events": [],
|
||||
"bankrupt": bankrupt,
|
||||
"horizon_reached": horizon_reached,
|
||||
"terminal_reason": terminal_reason,
|
||||
})
|
||||
return
|
||||
|
||||
checkpoint_type = next_event.event_type.value
|
||||
result = advance_time(
|
||||
db=db,
|
||||
company_id=sim_state.company_id,
|
||||
target_time=next_event.scheduled_at,
|
||||
)
|
||||
|
||||
terminal_reason = None
|
||||
if result.bankrupt:
|
||||
terminal_reason = "bankruptcy"
|
||||
elif result.horizon_reached:
|
||||
terminal_reason = "horizon_end"
|
||||
|
||||
payload = result.__dict__.copy()
|
||||
payload["ok"] = True
|
||||
payload["checkpoint_event_type"] = checkpoint_type
|
||||
payload["terminal_reason"] = terminal_reason
|
||||
json_output(payload)
|
||||
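The resume command is designed to be called in a loop: each call jumps the clock to the next queued checkpoint and reports whether a terminal state was reached. A minimal driver sketch follows (illustrative only, not part of this commit); it assumes `yc-bench` is on PATH and that `json_output` writes a single JSON object to stdout.

import json
import subprocess

# Keep resuming until sim_resume reports a terminal_reason ("bankruptcy" or
# "horizon_end") or there is nothing left in the event queue.
while True:
    proc = subprocess.run(["yc-bench", "sim", "resume"], capture_output=True, text=True, check=True)
    result = json.loads(proc.stdout)
    if result.get("terminal_reason") or result.get("message") == "no_pending_events":
        print("run finished:", result.get("terminal_reason"))
        break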
460
src/yc_bench/cli/task_commands.py
Normal file
@ -0,0 +1,460 @@
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import typer
|
||||
from sqlalchemy import func
|
||||
|
||||
from ..core.business_time import add_business_hours
|
||||
from ..core.eta import recalculate_etas
|
||||
from ..db.models.company import Company, CompanyPrestige
|
||||
from ..db.models.employee import Employee
|
||||
from ..db.models.event import SimEvent
|
||||
from ..db.models.sim_state import SimState
|
||||
from ..db.models.task import Task, TaskAssignment, TaskRequirement, TaskStatus
|
||||
from ..services.generate_tasks import generate_replacement_task
|
||||
from . import get_db, json_output, error_output
|
||||
|
||||
|
||||
def _get_world_cfg():
|
||||
"""Load WorldConfig from the active experiment (YC_BENCH_EXPERIMENT env var)."""
|
||||
from yc_bench.config import get_world_config
|
||||
return get_world_config()
|
||||
|
||||
task_app = typer.Typer(help="Task management commands.")
|
||||
|
||||
|
||||
def _compute_deadline(accepted_at: datetime, total_required_qty: float, cfg=None) -> datetime:
|
||||
"""Deadline: cfg.deadline_qty_per_day units per business day, minimum cfg.deadline_min_biz_days days."""
|
||||
if cfg is None:
|
||||
cfg = _get_world_cfg()
|
||||
work_hours = cfg.workday_end_hour - cfg.workday_start_hour
|
||||
biz_days = max(cfg.deadline_min_biz_days, int(total_required_qty / cfg.deadline_qty_per_day))
|
||||
return add_business_hours(accepted_at, Decimal(str(biz_days)) * Decimal(str(work_hours)))
|
||||
|
||||
|
||||
@task_app.command("accept")
|
||||
def task_accept(
|
||||
task_id: str = typer.Option(..., "--task-id", help="UUID of the task to accept"),
|
||||
):
|
||||
"""Accept a market task: transition to planned, assign to company, generate replacement."""
|
||||
try:
|
||||
tid = UUID(task_id)
|
||||
except ValueError:
|
||||
error_output(f"Invalid UUID: {task_id}")
|
||||
|
||||
with get_db() as db:
|
||||
sim_state = db.query(SimState).first()
|
||||
if sim_state is None:
|
||||
error_output("No simulation found. Run `yc-bench sim init` first.")
|
||||
|
||||
task = db.query(Task).filter(Task.id == tid).one_or_none()
|
||||
if task is None:
|
||||
error_output(f"Task {task_id} not found.")
|
||||
if task.status != TaskStatus.MARKET:
|
||||
error_output(f"Task {task_id} is not in market status (current: {task.status.value}).")
|
||||
|
||||
# Validate prestige requirement
|
||||
company_id = sim_state.company_id
|
||||
prestige_rows = db.query(CompanyPrestige).filter(
|
||||
CompanyPrestige.company_id == company_id
|
||||
).all()
|
||||
max_prestige = max((float(p.prestige_level) for p in prestige_rows), default=1.0)
|
||||
|
||||
if task.required_prestige > max_prestige:
|
||||
error_output(
|
||||
f"Company max prestige ({max_prestige}) does not meet task requirement ({task.required_prestige})."
|
||||
)
|
||||
|
||||
# Compute deadline
|
||||
reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == tid).all()
|
||||
total_qty = sum(float(r.required_qty) for r in reqs)
|
||||
accepted_at = sim_state.sim_time
|
||||
deadline = _compute_deadline(accepted_at, total_qty)
|
||||
|
||||
# Transition task
|
||||
task.status = TaskStatus.PLANNED
|
||||
task.company_id = company_id
|
||||
task.accepted_at = accepted_at
|
||||
task.deadline = deadline
|
||||
|
||||
# Generate replacement task
|
||||
counter = sim_state.replenish_counter
|
||||
sim_state.replenish_counter = counter + 1
|
||||
replacement = generate_replacement_task(
|
||||
run_seed=sim_state.run_seed,
|
||||
replenish_counter=counter,
|
||||
cfg=_get_world_cfg(),
|
||||
)
|
||||
|
||||
replacement_row = Task(
|
||||
id=uuid4(),
|
||||
company_id=None,
|
||||
status=TaskStatus.MARKET,
|
||||
title=replacement.title,
|
||||
description=replacement.description,
|
||||
required_prestige=replacement.required_prestige,
|
||||
reward_funds_cents=replacement.reward_funds_cents,
|
||||
reward_prestige_delta=replacement.reward_prestige_delta,
|
||||
skill_boost_pct=replacement.skill_boost_pct,
|
||||
accepted_at=None,
|
||||
deadline=None,
|
||||
completed_at=None,
|
||||
success=None,
|
||||
halfway_event_emitted=False,
|
||||
)
|
||||
db.add(replacement_row)
|
||||
|
||||
for domain, qty in replacement.requirements.items():
|
||||
db.add(TaskRequirement(
|
||||
task_id=replacement_row.id,
|
||||
domain=domain,
|
||||
required_qty=qty,
|
||||
completed_qty=0,
|
||||
))
|
||||
|
||||
db.flush()
|
||||
|
||||
json_output({
|
||||
"task_id": str(task.id),
|
||||
"status": task.status.value,
|
||||
"accepted_at": accepted_at.isoformat(),
|
||||
"deadline": deadline.isoformat(),
|
||||
"replacement_task_id": str(replacement_row.id),
|
||||
})
|
||||
|
||||
|
||||
@task_app.command("assign")
|
||||
def task_assign(
|
||||
task_id: str = typer.Option(..., "--task-id", help="UUID of the task"),
|
||||
employee_id: str = typer.Option(..., "--employee-id", help="UUID of the employee"),
|
||||
):
|
||||
"""Assign an employee to a task."""
|
||||
try:
|
||||
tid = UUID(task_id)
|
||||
eid = UUID(employee_id)
|
||||
except ValueError:
|
||||
error_output("Invalid UUID provided.")
|
||||
|
||||
with get_db() as db:
|
||||
sim_state = db.query(SimState).first()
|
||||
if sim_state is None:
|
||||
error_output("No simulation found.")
|
||||
|
||||
task = db.query(Task).filter(Task.id == tid).one_or_none()
|
||||
if task is None:
|
||||
error_output(f"Task {task_id} not found.")
|
||||
if task.status not in (TaskStatus.PLANNED, TaskStatus.ACTIVE):
|
||||
error_output(f"Task {task_id} must be planned or active to assign (current: {task.status.value}).")
|
||||
if task.company_id != sim_state.company_id:
|
||||
error_output(f"Task {task_id} does not belong to your company.")
|
||||
|
||||
employee = db.query(Employee).filter(Employee.id == eid).one_or_none()
|
||||
if employee is None:
|
||||
error_output(f"Employee {employee_id} not found.")
|
||||
if employee.company_id != sim_state.company_id:
|
||||
error_output(f"Employee {employee_id} does not belong to your company.")
|
||||
|
||||
# Check if already assigned
|
||||
existing = db.query(TaskAssignment).filter(
|
||||
TaskAssignment.task_id == tid,
|
||||
TaskAssignment.employee_id == eid,
|
||||
).one_or_none()
|
||||
if existing is not None:
|
||||
error_output(f"Employee {employee_id} is already assigned to task {task_id}.")
|
||||
|
||||
assignment = TaskAssignment(
|
||||
task_id=tid,
|
||||
employee_id=eid,
|
||||
assigned_at=sim_state.sim_time,
|
||||
)
|
||||
db.add(assignment)
|
||||
db.flush()
|
||||
|
||||
# Recalculate ETAs for all active tasks sharing this employee
|
||||
if task.status == TaskStatus.ACTIVE:
|
||||
emp_assignments = db.query(TaskAssignment).filter(
|
||||
TaskAssignment.employee_id == eid
|
||||
).all()
|
||||
impacted = set()
|
||||
for ea in emp_assignments:
|
||||
t = db.query(Task).filter(Task.id == ea.task_id).one_or_none()
|
||||
if t and t.status == TaskStatus.ACTIVE:
|
||||
impacted.add(t.id)
|
||||
if impacted:
|
||||
recalculate_etas(db, sim_state.company_id, sim_state.sim_time, impacted, half_threshold=_get_world_cfg().task_half_threshold)
|
||||
|
||||
# Return current assignment list
|
||||
assignments = db.query(TaskAssignment).filter(TaskAssignment.task_id == tid).all()
|
||||
assignment_list = [
|
||||
{
|
||||
"employee_id": str(a.employee_id),
|
||||
"assigned_at": a.assigned_at.isoformat(),
|
||||
}
|
||||
for a in assignments
|
||||
]
|
||||
|
||||
json_output({
|
||||
"task_id": str(task.id),
|
||||
"status": task.status.value,
|
||||
"assignments": assignment_list,
|
||||
})
|
||||
|
||||
|
||||
@task_app.command("dispatch")
|
||||
def task_dispatch(
|
||||
task_id: str = typer.Option(..., "--task-id", help="UUID of the task to dispatch"),
|
||||
):
|
||||
"""Dispatch a planned task to active status."""
|
||||
try:
|
||||
tid = UUID(task_id)
|
||||
except ValueError:
|
||||
error_output(f"Invalid UUID: {task_id}")
|
||||
|
||||
with get_db() as db:
|
||||
sim_state = db.query(SimState).first()
|
||||
if sim_state is None:
|
||||
error_output("No simulation found.")
|
||||
|
||||
task = db.query(Task).filter(Task.id == tid).one_or_none()
|
||||
if task is None:
|
||||
error_output(f"Task {task_id} not found.")
|
||||
if task.status != TaskStatus.PLANNED:
|
||||
error_output(f"Task {task_id} must be planned to dispatch (current: {task.status.value}).")
|
||||
if task.company_id != sim_state.company_id:
|
||||
error_output(f"Task {task_id} does not belong to your company.")
|
||||
|
||||
# Validate at least one assignment
|
||||
assignment_count = db.query(func.count(TaskAssignment.employee_id)).filter(
|
||||
TaskAssignment.task_id == tid
|
||||
).scalar() or 0
|
||||
if assignment_count == 0:
|
||||
error_output(f"Task {task_id} has no assignments. Assign employees before dispatching.")
|
||||
|
||||
# Transition to active
|
||||
task.status = TaskStatus.ACTIVE
|
||||
db.flush()
|
||||
|
||||
# Recalculate ETAs for this task and other active tasks that share assigned employees.
|
||||
impacted = {tid}
|
||||
assigned = db.query(TaskAssignment).filter(TaskAssignment.task_id == tid).all()
|
||||
for a in assigned:
|
||||
peer_assignments = db.query(TaskAssignment).filter(
|
||||
TaskAssignment.employee_id == a.employee_id
|
||||
).all()
|
||||
for pa in peer_assignments:
|
||||
if pa.task_id == tid:
|
||||
continue
|
||||
peer_task = db.query(Task).filter(Task.id == pa.task_id).one_or_none()
|
||||
if peer_task and peer_task.status == TaskStatus.ACTIVE:
|
||||
impacted.add(peer_task.id)
|
||||
recalculate_etas(db, sim_state.company_id, sim_state.sim_time, impacted, half_threshold=_get_world_cfg().task_half_threshold)
|
||||
|
||||
json_output({
|
||||
"task_id": str(task.id),
|
||||
"status": task.status.value,
|
||||
"assignment_count": assignment_count,
|
||||
})
|
||||
|
||||
|
||||
@task_app.command("list")
|
||||
def task_list(
|
||||
status: Optional[str] = typer.Option(None, "--status", help="Filter by task status"),
|
||||
):
|
||||
"""List tasks owned by the company."""
|
||||
with get_db() as db:
|
||||
sim_state = db.query(SimState).first()
|
||||
if sim_state is None:
|
||||
error_output("No simulation found.")
|
||||
|
||||
query = db.query(Task).filter(Task.company_id == sim_state.company_id)
|
||||
|
||||
if status is not None:
|
||||
try:
|
||||
ts = TaskStatus(status)
|
||||
except ValueError:
|
||||
error_output(f"Invalid status: {status}. Valid: {[s.value for s in TaskStatus]}")
|
||||
query = query.filter(Task.status == ts)
|
||||
|
||||
tasks = query.order_by(Task.accepted_at.desc().nulls_last()).all()
|
||||
results = []
|
||||
for task in tasks:
|
||||
# Compute progress %
|
||||
reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == task.id).all()
|
||||
total_required = sum(float(r.required_qty) for r in reqs)
|
||||
total_completed = sum(float(r.completed_qty) for r in reqs)
|
||||
progress_pct = (total_completed / total_required * 100) if total_required > 0 else 0.0
|
||||
|
||||
# Deadline risk
|
||||
at_risk = False
|
||||
if task.deadline and task.status == TaskStatus.ACTIVE:
|
||||
if sim_state.sim_time > task.deadline:
|
||||
at_risk = True
|
||||
|
||||
results.append({
|
||||
"task_id": str(task.id),
|
||||
"title": task.title,
|
||||
"status": task.status.value,
|
||||
"progress_pct": round(progress_pct, 2),
|
||||
"deadline": task.deadline.isoformat() if task.deadline else None,
|
||||
"at_risk": at_risk,
|
||||
})
|
||||
|
||||
json_output({
|
||||
"count": len(results),
|
||||
"tasks": results,
|
||||
})
|
||||
|
||||
|
||||
@task_app.command("inspect")
|
||||
def task_inspect(
|
||||
task_id: str = typer.Option(..., "--task-id", help="UUID of the task to inspect"),
|
||||
):
|
||||
"""Inspect detailed task information."""
|
||||
try:
|
||||
tid = UUID(task_id)
|
||||
except ValueError:
|
||||
error_output(f"Invalid UUID: {task_id}")
|
||||
|
||||
with get_db() as db:
|
||||
task = db.query(Task).filter(Task.id == tid).one_or_none()
|
||||
if task is None:
|
||||
error_output(f"Task {task_id} not found.")
|
||||
|
||||
# Requirements
|
||||
reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == tid).all()
|
||||
requirements = []
|
||||
for r in reqs:
|
||||
requirements.append({
|
||||
"domain": r.domain.value,
|
||||
"required_qty": float(r.required_qty),
|
||||
"completed_qty": float(r.completed_qty),
|
||||
"remaining_qty": float(r.required_qty - r.completed_qty),
|
||||
})
|
||||
|
||||
# Assignments with employee info
|
||||
assignments_raw = db.query(TaskAssignment).filter(TaskAssignment.task_id == tid).all()
|
||||
assignments = []
|
||||
for a in assignments_raw:
|
||||
emp = db.query(Employee).filter(Employee.id == a.employee_id).one_or_none()
|
||||
assignments.append({
|
||||
"employee_id": str(a.employee_id),
|
||||
"employee_name": emp.name if emp else "unknown",
|
||||
"assigned_at": a.assigned_at.isoformat(),
|
||||
})
|
||||
|
||||
total_required = sum(float(r.required_qty) for r in reqs)
|
||||
total_completed = sum(float(r.completed_qty) for r in reqs)
|
||||
progress_pct = (total_completed / total_required * 100) if total_required > 0 else 0.0
|
||||
|
||||
json_output({
|
||||
"task_id": str(task.id),
|
||||
"title": task.title,
|
||||
"description": task.description,
|
||||
"status": task.status.value,
|
||||
"required_prestige": task.required_prestige,
|
||||
"reward_funds_cents": task.reward_funds_cents,
|
||||
"reward_prestige_delta": float(task.reward_prestige_delta),
|
||||
"skill_boost_pct": float(task.skill_boost_pct),
|
||||
"accepted_at": task.accepted_at.isoformat() if task.accepted_at else None,
|
||||
"deadline": task.deadline.isoformat() if task.deadline else None,
|
||||
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
|
||||
"success": task.success,
|
||||
"progress_pct": round(progress_pct, 2),
|
||||
"requirements": requirements,
|
||||
"assignments": assignments,
|
||||
})
|
||||
|
||||
|
||||
@task_app.command("cancel")
|
||||
def task_cancel(
|
||||
task_id: str = typer.Option(..., "--task-id", help="UUID of the task to cancel"),
|
||||
reason: str = typer.Option(..., "--reason", help="Reason for cancellation"),
|
||||
):
|
||||
"""Cancel a task and apply prestige penalty."""
|
||||
try:
|
||||
tid = UUID(task_id)
|
||||
except ValueError:
|
||||
error_output(f"Invalid UUID: {task_id}")
|
||||
|
||||
with get_db() as db:
|
||||
sim_state = db.query(SimState).first()
|
||||
if sim_state is None:
|
||||
error_output("No simulation found.")
|
||||
|
||||
task = db.query(Task).filter(Task.id == tid).one_or_none()
|
||||
if task is None:
|
||||
error_output(f"Task {task_id} not found.")
|
||||
if task.status not in (TaskStatus.PLANNED, TaskStatus.ACTIVE):
|
||||
error_output(f"Task {task_id} cannot be cancelled (current: {task.status.value}).")
|
||||
if task.company_id != sim_state.company_id:
|
||||
error_output(f"Task {task_id} does not belong to your company.")
|
||||
|
||||
# Apply prestige penalty: penalty_cancel_multiplier * reward_prestige_delta across task's required domains
|
||||
_cfg = _get_world_cfg()
|
||||
cancel_penalty = Decimal(str(_cfg.penalty_cancel_multiplier)) * task.reward_prestige_delta
|
||||
reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == tid).all()
|
||||
penalties_applied = {}
|
||||
|
||||
for req in reqs:
|
||||
prestige = db.query(CompanyPrestige).filter(
|
||||
CompanyPrestige.company_id == sim_state.company_id,
|
||||
CompanyPrestige.domain == req.domain,
|
||||
).one_or_none()
|
||||
if prestige is not None:
|
||||
old_val = prestige.prestige_level
|
||||
new_val = max(Decimal(str(_cfg.prestige_min)), prestige.prestige_level - cancel_penalty)
|
||||
prestige.prestige_level = new_val
|
||||
penalties_applied[req.domain.value] = {
|
||||
"old": float(old_val),
|
||||
"new": float(new_val),
|
||||
"delta": float(old_val - new_val),
|
||||
}
|
||||
|
||||
# Set status to cancelled
|
||||
task.status = TaskStatus.CANCELLED
|
||||
|
||||
# Drop pending events for this task
|
||||
pending_events = db.query(SimEvent).filter(
|
||||
SimEvent.company_id == sim_state.company_id,
|
||||
SimEvent.consumed == False,
|
||||
SimEvent.payload["task_id"].astext == str(tid),
|
||||
).all()
|
||||
for ev in pending_events:
|
||||
ev.consumed = True
|
||||
|
||||
# Recalculate ETAs for tasks sharing freed employees
|
||||
cancelled_assignments = db.query(TaskAssignment).filter(
|
||||
TaskAssignment.task_id == tid
|
||||
).all()
|
||||
freed_emp_ids = {a.employee_id for a in cancelled_assignments}
|
||||
impacted = set()
|
||||
for emp_id in freed_emp_ids:
|
||||
emp_assignments = db.query(TaskAssignment).filter(
|
||||
TaskAssignment.employee_id == emp_id
|
||||
).all()
|
||||
for ea in emp_assignments:
|
||||
if ea.task_id != tid:
|
||||
t = db.query(Task).filter(Task.id == ea.task_id).one_or_none()
|
||||
if t and t.status == TaskStatus.ACTIVE:
|
||||
impacted.add(t.id)
|
||||
if impacted:
|
||||
recalculate_etas(db, sim_state.company_id, sim_state.sim_time, impacted, half_threshold=_get_world_cfg().task_half_threshold)
|
||||
|
||||
# Bankruptcy check
|
||||
company = db.query(Company).filter(Company.id == sim_state.company_id).one()
|
||||
bankrupt = company.funds_cents < 0
|
||||
|
||||
db.flush()
|
||||
|
||||
json_output({
|
||||
"task_id": str(task.id),
|
||||
"status": task.status.value,
|
||||
"reason": reason,
|
||||
"cancel_penalty_per_domain": float(cancel_penalty),
|
||||
"prestige_changes": penalties_applied,
|
||||
"bankrupt": bankrupt,
|
||||
})
|
||||
44
src/yc_bench/config.py
Normal file
@ -0,0 +1,44 @@
from __future__ import annotations

from dataclasses import dataclass
import os

from .agent.runtime.schemas import RuntimeSettings

@dataclass(frozen=True)
class RuntimeConfig:
    model: str
    temperature: float
    top_p: float
    history_keep_rounds: int = 20

def load_runtime_config():
    model = os.getenv("YC_BENCH_MODEL", "gpt-5.2")

    temperature = float(os.getenv("YC_BENCH_TEMPERATURE", "0"))
    top_p = float(os.getenv("YC_BENCH_TOP_P", "1"))
    history_keep_rounds = int(os.getenv("YC_BENCH_HISTORY_KEEP_ROUNDS", "20"))

    if temperature < 0:
        raise ValueError("YC_BENCH_TEMPERATURE must be >= 0")
    if top_p <= 0 or top_p > 1:
        raise ValueError("YC_BENCH_TOP_P must be in (0, 1]")
    if history_keep_rounds < 1:
        raise ValueError("YC_BENCH_HISTORY_KEEP_ROUNDS must be >= 1")

    return RuntimeConfig(
        model=model,
        temperature=temperature,
        top_p=top_p,
        history_keep_rounds=history_keep_rounds,
    )

def to_runtime_settings(cfg):
    return RuntimeSettings(
        model=cfg.model,
        temperature=cfg.temperature,
        top_p=cfg.top_p,
        history_keep_rounds=cfg.history_keep_rounds,
    )

__all__ = ["RuntimeConfig", "load_runtime_config", "to_runtime_settings"]
27
src/yc_bench/config/__init__.py
Normal file
@ -0,0 +1,27 @@
import os

from .loader import load_config
from .schema import ExperimentConfig, AgentConfig, LoopConfig, SimConfig, WorldConfig, SalaryTierConfig


def get_world_config() -> WorldConfig:
    """Load WorldConfig from the active experiment (YC_BENCH_EXPERIMENT env var, default: 'default').

    Falls back to default WorldConfig if config loading fails (e.g. outside a benchmark run).
    """
    try:
        return load_config(os.environ.get("YC_BENCH_EXPERIMENT", "default")).world
    except Exception:
        return WorldConfig()


__all__ = [
    "load_config",
    "get_world_config",
    "ExperimentConfig",
    "AgentConfig",
    "LoopConfig",
    "SimConfig",
    "WorldConfig",
    "SalaryTierConfig",
]
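A minimal usage sketch for the helper above (illustrative only, not part of this commit): the active experiment is chosen with the YC_BENCH_EXPERIMENT environment variable, and any load failure silently falls back to a default WorldConfig.

import os
from yc_bench.config import get_world_config

os.environ["YC_BENCH_EXPERIMENT"] = "fast_test"   # preset name; unset means "default"
world = get_world_config()
print(world.num_employees, world.deadline_qty_per_day)   # 5 and 200.0 under the fast_test preset below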
124
src/yc_bench/config/loader.py
Normal file
@ -0,0 +1,124 @@
"""Load ExperimentConfig from a built-in preset name or a user TOML file path.
|
||||
|
||||
Usage::
|
||||
|
||||
from yc_bench.config import load_config
|
||||
|
||||
cfg = load_config("default") # built-in preset
|
||||
cfg = load_config("fast_test") # built-in preset
|
||||
cfg = load_config("./my_run.toml") # user file, absolute or relative path
|
||||
cfg = load_config("/abs/path/run.toml") # user file, absolute path
|
||||
|
||||
User TOML files can inherit from a preset with ``extends``::
|
||||
|
||||
# my_run.toml
|
||||
extends = "default"
|
||||
|
||||
[agent]
|
||||
model = "openrouter/anthropic/claude-3.5-sonnet"
|
||||
|
||||
[world]
|
||||
num_employees = 15
|
||||
|
||||
Environment variable overrides (applied last, highest priority)::
|
||||
|
||||
YC_BENCH_MODEL override agent.model
|
||||
YC_BENCH_TEMPERATURE override agent.temperature
|
||||
YC_BENCH_TOP_P override agent.top_p
|
||||
YC_BENCH_HISTORY_KEEP_ROUNDS override agent.history_keep_rounds
|
||||
YC_BENCH_AUTO_ADVANCE_TURNS override loop.auto_advance_after_turns
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tomllib
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
from .schema import ExperimentConfig
|
||||
|
||||
|
||||
def load_config(path_or_name: str = "default") -> ExperimentConfig:
|
||||
"""Load a named preset or a TOML file path, apply env-var overrides."""
|
||||
raw = _read_raw(path_or_name)
|
||||
|
||||
# Handle `extends = "preset_name"` inheritance
|
||||
if "extends" in raw:
|
||||
base_name = raw.pop("extends")
|
||||
base_raw = _read_preset(base_name)
|
||||
raw = _deep_merge(base_raw, raw)
|
||||
|
||||
cfg = ExperimentConfig.model_validate(raw)
|
||||
return _apply_env_overrides(cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _read_raw(path_or_name: str) -> dict:
|
||||
p = Path(path_or_name)
|
||||
# Treat as a file path if it has a .toml suffix or contains a path separator
|
||||
if p.suffix == ".toml" or os.sep in path_or_name or "/" in path_or_name:
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {p.resolve()}")
|
||||
return _read_file(p)
|
||||
return _read_preset(path_or_name)
|
||||
|
||||
|
||||
def _read_preset(name: str) -> dict:
|
||||
try:
|
||||
resource = files("yc_bench.config.presets") / f"{name}.toml"
|
||||
with resource.open("rb") as f:
|
||||
return tomllib.load(f)
|
||||
except (FileNotFoundError, TypeError):
|
||||
available = [
|
||||
p.name.replace(".toml", "")
|
||||
for p in Path(__file__).parent.joinpath("presets").glob("*.toml")
|
||||
]
|
||||
raise ValueError(
|
||||
f"Unknown preset '{name}'. Available presets: {available}. "
|
||||
"Pass a file path ending in .toml for a custom config."
|
||||
)
|
||||
|
||||
|
||||
def _read_file(path: Path) -> dict:
|
||||
with open(path, "rb") as f:
|
||||
return tomllib.load(f)
|
||||
|
||||
|
||||
def _deep_merge(base: dict, override: dict) -> dict:
|
||||
"""Recursively merge override on top of base. Override wins on conflicts."""
|
||||
result = dict(base)
|
||||
for key, value in override.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def _apply_env_overrides(cfg: ExperimentConfig) -> ExperimentConfig:
|
||||
"""YC_BENCH_* env vars always win, regardless of config file."""
|
||||
agent_updates: dict = {}
|
||||
loop_updates: dict = {}
|
||||
|
||||
if v := os.environ.get("YC_BENCH_MODEL"):
|
||||
agent_updates["model"] = v
|
||||
if v := os.environ.get("YC_BENCH_TEMPERATURE"):
|
||||
agent_updates["temperature"] = float(v)
|
||||
if v := os.environ.get("YC_BENCH_TOP_P"):
|
||||
agent_updates["top_p"] = float(v)
|
||||
if v := os.environ.get("YC_BENCH_HISTORY_KEEP_ROUNDS"):
|
||||
agent_updates["history_keep_rounds"] = int(v)
|
||||
if v := os.environ.get("YC_BENCH_AUTO_ADVANCE_TURNS"):
|
||||
loop_updates["auto_advance_after_turns"] = int(v)
|
||||
|
||||
if agent_updates:
|
||||
cfg = cfg.model_copy(update={"agent": cfg.agent.model_copy(update=agent_updates)})
|
||||
if loop_updates:
|
||||
cfg = cfg.model_copy(update={"loop": cfg.loop.model_copy(update=loop_updates)})
|
||||
return cfg
|
||||
|
||||
|
||||
__all__ = ["load_config"]
|
||||
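A short sketch of the precedence the loader docstring describes (illustrative only, not part of this commit): presets load by name, `extends` merges them under user files, and YC_BENCH_* environment variables are applied last and win over any file value.

import os
from yc_bench.config import load_config

cfg = load_config("challenge")             # built-in preset; challenge.toml extends default
assert cfg.world.num_employees == 5

os.environ["YC_BENCH_TEMPERATURE"] = "0.2"
cfg = load_config("challenge")             # env override applied after the file values are merged
assert cfg.agent.temperature == 0.2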
0
src/yc_bench/config/presets/.gitkeep
Normal file
76
src/yc_bench/config/presets/challenge.toml
Normal file
@ -0,0 +1,76 @@
# challenge — 3-year benchmark designed for differentiated model behavior
#
# Design goals:
# - Models that focus employees on tasks complete on time, build prestige,
#   and can reach the 3-year horizon.
# - Models that spread employees or forget to dispatch go bankrupt.
# - Wide outcome variance across models: some die early, some survive.
#
# Economics (seed=1, 5 employees):
#   Monthly payroll ≈ $32K. Starting runway ≈ 7.8 months.
#   Mode task: 2 domains × 700 units = 1400 total.
#   Deadline: max(7, 1400/200) = 7 business days (minimum).
#   All 5 focused on one task, split ~2+3: domain1 3×5.1×9=137.7 units/day →
#   700/137.7 = 5.1 days; domain2 2×5.1×9=91.8 → 700/91.8 = 7.6 days > 7 FAIL.
#   → Need ≥3 employees per domain: put 3 on domain1 and 3 on domain2 (count ×2)
#   or accept single-domain tasks where all 5 focus.
#   When done right: complete 3 tasks/month × avg $30K = $90K > $32K payroll ✓
#   Late game: prestige-5 tasks pay $30K × 3.2× = $96K each → strong growth.

extends = "default"

name = "challenge"
description = "3-year benchmark calibrated for differentiated model behavior. ~7.8 months starting runway. Strategy and focus determine survival."

# ---------------------------------------------------------------------------
# Simulation
# ---------------------------------------------------------------------------

[sim]
horizon_years = 3

# ---------------------------------------------------------------------------
# Agent loop
# ---------------------------------------------------------------------------

[loop]
# Force a sim advance if the agent goes 5 consecutive turns without one.
auto_advance_after_turns = 5

# 500 turns covers 3 years for an efficient agent:
#   advancing every 3 turns → 167 advances × ~14 avg sim days = 2338 days = 6.4 yrs.
#   forced advances only → 100 advances × ~14 days = 1400 days = 3.8 yrs.
max_turns = 500

# ---------------------------------------------------------------------------
# World — 5 employees keeps payroll manageable (~$32K/month).
# 200 market tasks gives a deep enough pool to find tasks in any prestige tier.
# ---------------------------------------------------------------------------

[world]
num_employees = 5
num_market_tasks = 200
deadline_qty_per_day = 200.0

# ---------------------------------------------------------------------------
# Task size: mode=700 (smaller than hardened default 1400) so tasks complete
# in ~5-10 sim days when employees are focused. This creates frequent reward
# events visible in the funds curve, and makes payroll crises recoverable.
# ---------------------------------------------------------------------------

[world.dist.required_qty]
type = "triangular"
low = 200
high = 2000
mode = 700

# ---------------------------------------------------------------------------
# Prestige: mode=3 ensures a healthy supply of prestige-1 and prestige-2 tasks
# at the start. Agents that specialise unlock 4–8 tasks paying 2–5× more.
# ---------------------------------------------------------------------------

[world.dist.required_prestige]
type = "triangular"
low = 1
high = 8
mode = 3
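The calibration arithmetic in the header comment can be re-derived directly. The sketch below is illustrative only (not part of this commit) and assumes the comment's average rate of 5.1 units/hour and the 9-hour workday inherited from the default preset.

rate_per_hour, workday_hours = 5.1, 9
domain_qty = 700                                        # mode task: 2 domains x 700 units

three_assigned = 3 * rate_per_hour * workday_hours      # 137.7 units/day -> 700 / 137.7 ≈ 5.1 days, on time
two_assigned = 2 * rate_per_hour * workday_hours        # 91.8 units/day  -> 700 / 91.8  ≈ 7.6 days, past the deadline
deadline_days = max(7, (2 * domain_qty) / 200)          # max(deadline_min_biz_days, total qty / deadline_qty_per_day) = 7
print(three_assigned, two_assigned, deadline_days)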
180
src/yc_bench/config/presets/default.toml
Normal file
@ -0,0 +1,180 @@
# YC-Bench default preset — hardened 3-year benchmark
|
||||
#
|
||||
# This is the canonical benchmark configuration. Every parameter is listed
|
||||
# explicitly so this file serves as complete documentation of what is tunable.
|
||||
# Copy this file and override only what you need.
|
||||
|
||||
# experiment identity (top-level fields)
|
||||
name = "default"
|
||||
description = "Hardened 3-year benchmark. Most tasks require prestige 3-5, 2-domain work, tight deadlines, costly cancellations."
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
[agent]
|
||||
model = "openrouter/z-ai/glm-5"
|
||||
temperature = 0.0
|
||||
top_p = 1.0
|
||||
request_timeout_seconds = 300.0
|
||||
retry_max_attempts = 3
|
||||
retry_backoff_seconds = 1.0
|
||||
|
||||
# Conversation rounds retained in context before each API call.
|
||||
# Older rounds are dropped proactively. Scratchpad survives truncation.
|
||||
history_keep_rounds = 20
|
||||
|
||||
# Optional system prompt override. Omit to use the built-in prompt.
|
||||
# system_prompt = "..."
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
[loop]
|
||||
# Consecutive turns without `sim resume` before the loop forces a time-advance.
|
||||
auto_advance_after_turns = 10
|
||||
|
||||
# Hard cap on total turns. Comment out or set to a large number for unlimited.
|
||||
# max_turns = 500
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Simulation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
[sim]
|
||||
start_date = "2025-01-01" # ISO 8601
|
||||
horizon_years = 3
|
||||
company_name = "BenchCo"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# World generation — scalar parameters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
[world]
|
||||
num_employees = 10
|
||||
initial_funds_cents = 25_000_000 # $250,000
|
||||
initial_prestige_level = 1.0
|
||||
work_hours_per_day = 9.0
|
||||
|
||||
num_market_tasks = 500
|
||||
market_browse_default_limit = 50
|
||||
|
||||
# Salary bump per completed task — each assigned employee gets this raise.
|
||||
# Compounds over time, accelerating payroll pressure.
|
||||
salary_bump_pct = 0.01
|
||||
|
||||
# --- Prestige mechanics ---
|
||||
prestige_max = 10.0
|
||||
prestige_min = 1.0
|
||||
|
||||
# Cancelling (2x) is worse than failing (1.4x) at the margin — accepting
|
||||
# is a real commitment, not a free option.
|
||||
penalty_fail_multiplier = 1.4 # hardened: was 0.8
|
||||
penalty_cancel_multiplier = 2.0 # hardened: was 1.2
|
||||
|
||||
# Extra reward fraction per prestige level above 1.
|
||||
# At 0.55: a prestige-8 task pays ~4.85x as much as a prestige-1 task.
|
||||
reward_prestige_scale = 0.55 # hardened: was 0.3
|
||||
|
||||
# --- Deadline ---
|
||||
# 320 units/day is ~60% tighter than the base 200 — need real employee-hours.
|
||||
deadline_qty_per_day = 320.0 # hardened: was 200.0
|
||||
deadline_min_biz_days = 7
|
||||
|
||||
# --- Progress milestone ---
|
||||
task_half_threshold = 0.5
|
||||
|
||||
# --- Business hours ---
|
||||
workday_start_hour = 9
|
||||
workday_end_hour = 18
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# World distributions
|
||||
#
|
||||
# Each [world.dist.<name>] section specifies the distribution used for one
|
||||
# random quantity. Change `type` to switch families; tune the parameters to
|
||||
# reshape the distribution.
|
||||
#
|
||||
# Available types: triangular, beta, normal, uniform, constant
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Prestige level required to accept a task (cast to int after sampling).
|
||||
# mode=4 means most tasks need prestige 3-5 — the early market is thin,
|
||||
# forcing agents to choose which domains to specialise in.
|
||||
[world.dist.required_prestige]
|
||||
type = "triangular"
|
||||
low = 1
|
||||
high = 10
|
||||
mode = 4 # hardened: base default is mode=1
|
||||
|
||||
# Base reward paid on task completion, in cents (scaled further by prestige).
|
||||
# Higher-prestige tasks automatically pay more via reward_prestige_scale.
|
||||
[world.dist.reward_funds_cents]
|
||||
type = "triangular"
|
||||
low = 500_000 # $5,000
|
||||
high = 10_000_000 # $100,000
|
||||
mode = 3_000_000 # $30,000
|
||||
|
||||
# Number of domains each task requires work in (cast to int after sampling).
|
||||
# mode=2: most tasks need 2 domains — single-specialist dominance gone.
|
||||
[world.dist.domain_count]
|
||||
type = "triangular"
|
||||
low = 1
|
||||
high = 3
|
||||
mode = 2 # hardened: base default is mode=1
|
||||
|
||||
# Work units required per domain per task (cast to int after sampling).
|
||||
# No trivially-small tasks: every task requires sustained employee-hours.
|
||||
[world.dist.required_qty]
|
||||
type = "triangular"
|
||||
low = 500 # hardened: base default is 200
|
||||
high = 3000
|
||||
mode = 1400 # hardened: base default is 800
|
||||
|
||||
# Prestige delta awarded per required domain on task success.
|
||||
# Left-skewed beta: most completions give small gains; occasional large jumps.
|
||||
[world.dist.reward_prestige_delta]
|
||||
type = "beta"
|
||||
alpha = 1.2
|
||||
beta = 2.8
|
||||
scale = 2.0
|
||||
low = 0.0
|
||||
high = 2.0
|
||||
|
||||
# Skill rate boost applied to each assigned employee on task success.
|
||||
# Expressed as a fraction of current rate: 0.12 → average 12% boost.
|
||||
[world.dist.skill_boost]
|
||||
type = "normal"
|
||||
mean = 0.12
|
||||
stdev = 0.06
|
||||
low = 0.01
|
||||
high = 0.40
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Salary tiers — all three share values must sum to exactly 1.0
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
[world.salary_junior]
|
||||
name = "junior"
|
||||
share = 0.50
|
||||
min_cents = 200_000 # $2,000/month
|
||||
max_cents = 400_000 # $4,000/month
|
||||
rate_min = 1.0 # units/hour
|
||||
rate_max = 6.5
|
||||
|
||||
[world.salary_mid]
|
||||
name = "mid"
|
||||
share = 0.35
|
||||
min_cents = 600_000 # $6,000/month
|
||||
max_cents = 800_000 # $8,000/month
|
||||
rate_min = 3.5
|
||||
rate_max = 8.5
|
||||
|
||||
[world.salary_senior]
|
||||
name = "senior"
|
||||
share = 0.15
|
||||
min_cents = 1_000_000 # $10,000/month
|
||||
max_cents = 1_500_000 # $15,000/month
|
||||
rate_min = 5.5
|
||||
rate_max = 10.0
|
||||
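The reward_prestige_scale comments imply a multiplier of 1 + scale * (prestige - 1). That formula is an inference from the figures quoted above (~4.85x at prestige 8, and the 3.2x at prestige 5 cited in the challenge preset), not code from this commit.

scale = 0.55
for prestige in (1, 5, 8):
    print(prestige, 1 + scale * (prestige - 1))   # 1.0, 3.2, 4.85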
28
src/yc_bench/config/presets/fast_test.toml
Normal file
@ -0,0 +1,28 @@
# fast_test — quick 1-year run for iteration and debugging
#
# Changes from default:
# - 1-year horizon instead of 3
# - 5 employees, 100 market tasks (faster seeding)
# - auto-advance after 5 idle turns instead of 10
# - 50-turn hard cap (prevents runaway debug runs)

extends = "default"

name = "fast_test"
description = "1-year horizon with minimal world for fast iteration."

[sim]
horizon_years = 1

[loop]
auto_advance_after_turns = 5
max_turns = 50

[world]
num_employees = 5
num_market_tasks = 100

# With 5 employees (~220 units/day team throughput), the hardened default of
# 320 is physically impossible for most tasks. 200 restores parity so a
# focused agent can actually win tasks on time.
deadline_qty_per_day = 200
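The "~220 units/day" figure in the comment is consistent with 5 employees working the 9-hour day at an average rate of roughly 4.9 units/hour; the exact per-employee rates are drawn from the seeded salary tiers, so this is only a back-of-the-envelope check.

print(5 * 4.9 * 9)   # ≈ 220.5 units/day of team throughput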
28
src/yc_bench/config/presets/high_reward.toml
Normal file
@ -0,0 +1,28 @@
# high_reward — inflated reward pool to test prestige ceiling behaviour
#
# Changes from default:
# - Task reward funds 2x higher (mode: $60k, max: $200k)
# - Higher prestige delta rewards (up to 3.0 instead of 2.0)
# - Larger market pool (1000 tasks)

extends = "default"

name = "high_reward"
description = "2x reward scaling to test prestige ceiling and late-game behaviour."

[world]
num_market_tasks = 1000

[world.dist.reward_funds_cents]
type = "triangular"
low = 1_000_000      # $10,000
high = 20_000_000    # $200,000
mode = 6_000_000     # $60,000

[world.dist.reward_prestige_delta]
type = "beta"
alpha = 1.2
beta = 2.8
scale = 2.0
low = 0.0
high = 3.0  # raised ceiling: up to 3.0 instead of 2.0
104
src/yc_bench/config/sampling.py
Normal file
@ -0,0 +1,104 @@
"""Distribution specs and sampling.
|
||||
|
||||
Every random draw in world generation is described by a DistSpec. Callers
|
||||
pass the spec from WorldConfig — changing distribution families or parameters
|
||||
only requires a TOML edit, no code change.
|
||||
|
||||
Supported families
|
||||
------------------
|
||||
triangular low, high, mode → random.triangular
|
||||
beta alpha, beta, scale, low, high → scale × Beta(α,β), clamped
|
||||
normal mean, stdev, low, high → gauss, clamped
|
||||
uniform low, high → random.uniform
|
||||
constant value → always returns value (useful for ablations)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from typing import Annotated, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Distribution spec models (one per family)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TriangularDist(BaseModel):
|
||||
type: Literal["triangular"] = "triangular"
|
||||
low: float
|
||||
high: float
|
||||
mode: float
|
||||
|
||||
|
||||
class BetaDist(BaseModel):
|
||||
type: Literal["beta"] = "beta"
|
||||
alpha: float
|
||||
beta: float
|
||||
scale: float = 1.0
|
||||
low: float = 0.0
|
||||
high: float = 1.0
|
||||
|
||||
|
||||
class NormalDist(BaseModel):
|
||||
type: Literal["normal"] = "normal"
|
||||
mean: float
|
||||
stdev: float
|
||||
low: float
|
||||
high: float
|
||||
|
||||
|
||||
class UniformDist(BaseModel):
|
||||
type: Literal["uniform"] = "uniform"
|
||||
low: float
|
||||
high: float
|
||||
|
||||
|
||||
class ConstantDist(BaseModel):
|
||||
type: Literal["constant"] = "constant"
|
||||
value: float
|
||||
|
||||
|
||||
# Discriminated union — Pydantic picks the right model from the "type" field
|
||||
DistSpec = Annotated[
|
||||
Union[TriangularDist, BetaDist, NormalDist, UniformDist, ConstantDist],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sampling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def sample_from_spec(rng: random.Random, spec: DistSpec) -> float:
|
||||
"""Draw one sample from the given distribution spec."""
|
||||
if isinstance(spec, TriangularDist):
|
||||
val = rng.triangular(spec.low, spec.high, spec.mode)
|
||||
return max(spec.low, min(spec.high, val))
|
||||
|
||||
if isinstance(spec, BetaDist):
|
||||
val = spec.scale * rng.betavariate(spec.alpha, spec.beta)
|
||||
return round(max(spec.low, min(spec.high, val)), 4)
|
||||
|
||||
if isinstance(spec, NormalDist):
|
||||
val = rng.gauss(spec.mean, spec.stdev)
|
||||
return round(max(spec.low, min(spec.high, val)), 4)
|
||||
|
||||
if isinstance(spec, UniformDist):
|
||||
return rng.uniform(spec.low, spec.high)
|
||||
|
||||
if isinstance(spec, ConstantDist):
|
||||
return spec.value
|
||||
|
||||
raise TypeError(f"Unknown DistSpec type: {type(spec)}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BetaDist",
|
||||
"ConstantDist",
|
||||
"DistSpec",
|
||||
"NormalDist",
|
||||
"TriangularDist",
|
||||
"UniformDist",
|
||||
"sample_from_spec",
|
||||
]
|
||||
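A small usage sketch (illustrative only, not part of this commit): every draw goes through sample_from_spec with a seeded random.Random, which is what keeps world generation reproducible for a given run seed.

import random
from yc_bench.config.sampling import BetaDist, TriangularDist, sample_from_spec

rng = random.Random(1)                                     # seeded for reproducibility
qty = sample_from_spec(rng, TriangularDist(low=500, high=3000, mode=1400))
delta = sample_from_spec(rng, BetaDist(alpha=1.2, beta=2.8, scale=2.0, low=0.0, high=2.0))
print(int(qty), delta)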
194
src/yc_bench/config/schema.py
Normal file
@ -0,0 +1,194 @@
"""Pydantic models for all experiment configuration.
|
||||
|
||||
Every tunable parameter lives here. TOML files are validated against these
|
||||
models — Pydantic catches typos and type errors at load time.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from .sampling import BetaDist, ConstantDist, NormalDist, TriangularDist, UniformDist, DistSpec # noqa: F401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Salary tier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SalaryTierConfig(BaseModel):
|
||||
name: str
|
||||
share: float # fraction of employees in this tier (all tiers must sum to 1.0)
|
||||
min_cents: int # minimum monthly salary in cents
|
||||
max_cents: int # maximum monthly salary in cents
|
||||
rate_min: float # minimum skill rate (units/hour)
|
||||
rate_max: float # maximum skill rate (units/hour)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# World distributions
|
||||
#
|
||||
# Each field names a random quantity in world generation and specifies which
|
||||
# distribution family + parameters to use. Changing `type` switches families;
|
||||
# changing parameters tunes the shape. See config/sampling.py for all families.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WorldDists(BaseModel):
|
||||
# Prestige level required to accept a task (result cast to int).
|
||||
# Any DistSpec family works — e.g. constant for ablations, uniform for flat sampling.
|
||||
required_prestige: DistSpec = Field(
|
||||
default_factory=lambda: TriangularDist(low=1, high=10, mode=1)
|
||||
)
|
||||
# Base reward paid on task completion, in cents (result cast to int).
|
||||
reward_funds_cents: DistSpec = Field(
|
||||
default_factory=lambda: TriangularDist(low=500_000, high=10_000_000, mode=3_000_000)
|
||||
)
|
||||
# Number of domains required per task (result cast to int).
|
||||
domain_count: DistSpec = Field(
|
||||
default_factory=lambda: TriangularDist(low=1, high=3, mode=1)
|
||||
)
|
||||
# Work units required per domain (result cast to int).
|
||||
required_qty: DistSpec = Field(
|
||||
default_factory=lambda: TriangularDist(low=200, high=3000, mode=800)
|
||||
)
|
||||
# Prestige delta awarded on task success.
|
||||
reward_prestige_delta: DistSpec = Field(
|
||||
default_factory=lambda: BetaDist(alpha=1.2, beta=2.8, scale=2.0, low=0.0, high=2.0)
|
||||
)
|
||||
# Skill rate boost fraction applied to each assigned employee on task success.
|
||||
skill_boost: DistSpec = Field(
|
||||
default_factory=lambda: NormalDist(mean=0.12, stdev=0.06, low=0.01, high=0.40)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent / LLM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
model: str = "openrouter/z-ai/glm-5"
|
||||
temperature: float = 0.0
|
||||
top_p: float = 1.0
|
||||
request_timeout_seconds: float = 300.0
|
||||
retry_max_attempts: int = 3
|
||||
retry_backoff_seconds: float = 1.0
|
||||
# Conversation rounds kept in context before each API call; older rounds dropped.
|
||||
history_keep_rounds: int = 20
|
||||
# Optional system prompt override. None = use default from agent/prompt.py
|
||||
system_prompt: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class LoopConfig(BaseModel):
|
||||
# Consecutive turns without `sim resume` before the loop forces a time-advance.
|
||||
auto_advance_after_turns: int = 10
|
||||
# Hard cap on total turns. null = unlimited.
|
||||
max_turns: int | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Simulation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SimConfig(BaseModel):
|
||||
start_date: str = "2025-01-01" # ISO 8601 (YYYY-MM-DD)
|
||||
horizon_years: int = 3
|
||||
company_name: str = "BenchCo"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# World generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WorldConfig(BaseModel):
|
||||
# --- Workforce ---
|
||||
num_employees: int = 10
|
||||
initial_funds_cents: int = 25_000_000 # $250,000
|
||||
initial_prestige_level: float = 1.0
|
||||
work_hours_per_day: float = 9.0
|
||||
|
||||
# --- Market ---
|
||||
num_market_tasks: int = 500
|
||||
market_browse_default_limit: int = 50
|
||||
|
||||
# --- Salary bump on task completion ---
|
||||
salary_bump_pct: float = 0.01 # 1% raise per assigned employee per completed task
|
||||
|
||||
# --- Prestige mechanics ---
|
||||
prestige_max: float = 10.0
|
||||
prestige_min: float = 1.0
|
||||
penalty_fail_multiplier: float = 0.8
|
||||
penalty_cancel_multiplier: float = 1.2
|
||||
# Extra reward fraction per prestige level above 1.
|
||||
# e.g. at 0.55, prestige-8 tasks pay ~4.85x as much as prestige-1.
|
||||
reward_prestige_scale: float = 0.3
|
||||
|
||||
# --- Deadline computation ---
|
||||
deadline_qty_per_day: float = 200.0 # work units assumed completable per business day
|
||||
deadline_min_biz_days: int = 7
|
||||
|
||||
# --- Progress milestone ---
|
||||
task_half_threshold: float = 0.5
|
||||
|
||||
# --- Business hours ---
|
||||
workday_start_hour: int = 9
|
||||
workday_end_hour: int = 18
|
||||
|
||||
# --- Distributions (shape of random draws during world generation) ---
|
||||
dist: WorldDists = Field(default_factory=WorldDists)
|
||||
|
||||
# --- Salary tiers ---
|
||||
salary_junior: SalaryTierConfig = Field(
|
||||
default_factory=lambda: SalaryTierConfig(
|
||||
name="junior", share=0.50,
|
||||
min_cents=200_000, max_cents=400_000,
|
||||
rate_min=1.0, rate_max=6.5,
|
||||
)
|
||||
)
|
||||
salary_mid: SalaryTierConfig = Field(
|
||||
default_factory=lambda: SalaryTierConfig(
|
||||
name="mid", share=0.35,
|
||||
min_cents=600_000, max_cents=800_000,
|
||||
rate_min=3.5, rate_max=8.5,
|
||||
)
|
||||
)
|
||||
salary_senior: SalaryTierConfig = Field(
|
||||
default_factory=lambda: SalaryTierConfig(
|
||||
name="senior", share=0.15,
|
||||
min_cents=1_000_000, max_cents=1_500_000,
|
||||
rate_min=5.5, rate_max=10.0,
|
||||
)
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _salary_shares_sum_to_one(self) -> WorldConfig:
|
||||
total = self.salary_junior.share + self.salary_mid.share + self.salary_senior.share
|
||||
if abs(total - 1.0) > 1e-6:
|
||||
raise ValueError(f"salary tier shares must sum to 1.0, got {total:.6f}")
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level experiment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ExperimentConfig(BaseModel):
|
||||
name: str = "default"
|
||||
description: str = ""
|
||||
agent: AgentConfig = Field(default_factory=AgentConfig)
|
||||
loop: LoopConfig = Field(default_factory=LoopConfig)
|
||||
sim: SimConfig = Field(default_factory=SimConfig)
|
||||
world: WorldConfig = Field(default_factory=WorldConfig)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentConfig",
|
||||
"DistSpec",
|
||||
"ExperimentConfig",
|
||||
"LoopConfig",
|
||||
"SalaryTierConfig",
|
||||
"SimConfig",
|
||||
"WorldConfig",
|
||||
"WorldDists",
|
||||
]
|
||||
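The salary-share validator above is easiest to see in a quick sketch. The snippet below is illustrative only and not part of the commit; it assumes this module is importable as yc_bench.config.schema, which is how other files in this commit import it. The default tiers (0.50/0.35/0.15) pass, while a senior share of 0.05 makes the total 0.90 and is rejected at construction time.

# Illustrative sketch only (assumed module path: yc_bench.config.schema).
from pydantic import ValidationError

from yc_bench.config.schema import ExperimentConfig, SalaryTierConfig, WorldConfig

cfg = ExperimentConfig(world=WorldConfig(num_employees=25))  # OK: default shares sum to 1.0

try:
    WorldConfig(salary_senior=SalaryTierConfig(
        name="senior", share=0.05,           # 0.50 + 0.35 + 0.05 = 0.90 -> invalid
        min_cents=1_000_000, max_cents=1_500_000,
        rate_min=5.5, rate_max=10.0,
    ))
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["msg"])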
0
src/yc_bench/core/__init__.py
Normal file
159
src/yc_bench/core/business_time.py
Normal file
@ -0,0 +1,159 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta, time
from decimal import Decimal

WORKDAY_START = time(9, 0)
WORKDAY_END = time(18, 0)
WORK_HOURS_PER_DAY = Decimal("9")

@dataclass(frozen=True)
class BusinessCalendar:
    workday_start: time = WORKDAY_START
    workday_end: time = WORKDAY_END

DEFAULT_CALENDAR = BusinessCalendar()

def is_weekday(ts):
    return ts.weekday() < 5

def is_business_time(ts, cal=DEFAULT_CALENDAR):
    if not is_weekday(ts):
        return False
    t = ts.timetz().replace(tzinfo=None) if ts.tzinfo else ts.time()
    return cal.workday_start <= t < cal.workday_end

def _day_start(ts, cal):
    return ts.replace(
        hour=cal.workday_start.hour,
        minute=cal.workday_start.minute,
        second=0,
        microsecond=0,
    )

def _day_end(ts, cal):
    return ts.replace(
        hour=cal.workday_end.hour,
        minute=cal.workday_end.minute,
        second=0,
        microsecond=0,
    )

def _next_weekday_start(ts, cal):
    cur = _day_start(ts, cal)
    while not is_weekday(cur):
        cur += timedelta(days=1)
        cur = _day_start(cur, cal)
    return cur

def next_business_time(ts, cal):
    if is_business_time(ts, cal):
        return ts
    if not is_weekday(ts):
        return _next_weekday_start(ts, cal)

    day_start = _day_start(ts, cal)
    day_end = _day_end(ts, cal)

    if ts < day_start:
        return day_start

    if ts >= day_end:
        return _next_weekday_start(ts + timedelta(days=1), cal)

    raise ValueError(f"No valid business time found after {ts}")

def add_business_hours(ts, hours, cal=DEFAULT_CALENDAR):
    hours = Decimal(str(hours))
    if hours < 0:
        raise ValueError(f"Cannot add negative business hours: {hours}")
    if hours == 0:
        return next_business_time(ts, cal)

    cur = next_business_time(ts, cal)
    remaining = hours

    while remaining > 0:
        day_end = _day_end(cur, cal)
        available = Decimal(str((day_end - cur).total_seconds())) / Decimal("3600")

        if remaining <= available:
            return cur + timedelta(seconds=float(remaining * Decimal("3600")))

        remaining -= available
        cur = next_business_time(day_end, cal)

    return cur

def _business_interval_same_day(start, end, cal):
    if end <= start:
        return Decimal("0")
    if not is_weekday(start):
        return Decimal("0")

    day_start = _day_start(start, cal)
    day_end = _day_end(end, cal)

    lo = max(start, day_start)
    hi = min(end, day_end)

    if hi <= lo:
        return Decimal("0")
    return Decimal(str((hi - lo).total_seconds())) / Decimal("3600")

def business_hours_between(t0, t1, cal=DEFAULT_CALENDAR):
    if t1 <= t0:
        return Decimal("0")

    cur = t0
    total = Decimal("0")
    while cur < t1:
        next_midnight = (cur + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
        seg_end = min(next_midnight, t1)
        total += _business_interval_same_day(cur, seg_end, cal)
        cur = seg_end

    return total

def business_seconds_between(t0, t1, cal=DEFAULT_CALENDAR):
    return int((business_hours_between(t0, t1, cal) * Decimal("3600")).to_integral_value())

def first_business_of_month(dt, cal=DEFAULT_CALENDAR):
    first = dt.replace(day=1, hour=cal.workday_start.hour, minute=cal.workday_start.minute, second=0, microsecond=0)
    while not is_weekday(first):
        first += timedelta(days=1)
        first = first.replace(hour=cal.workday_start.hour, minute=cal.workday_start.minute, second=0, microsecond=0)
    return first

def iter_monthly_payroll_boundaries(start, end, cal=DEFAULT_CALENDAR):
    if end <= start:
        return []

    cursor = start.replace(day=1, hour=cal.workday_start.hour, minute=cal.workday_start.minute, second=0, microsecond=0)
    out = []

    while cursor < end:
        boundary = first_business_of_month(cursor, cal)
        if start < boundary <= end:
            out.append(boundary)

        if cursor.month == 12:
            cursor = cursor.replace(year=cursor.year + 1, month=1, day=1)
        else:
            cursor = cursor.replace(month=cursor.month + 1, day=1)

    return out

__all__ = [
    "BusinessCalendar",
    "DEFAULT_CALENDAR",
    "is_weekday",
    "is_business_time",
    "next_business_time",
    "add_business_hours",
    "business_hours_between",
    "business_seconds_between",
    "first_business_of_month",
    "iter_monthly_payroll_boundaries",
]
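As a quick sanity sketch of the calendar arithmetic above (illustrative only, not part of the commit): a Friday 16:00 start plus 5 business hours should land on Monday at 12:00, since each workday spans 09:00-18:00 and weekends are skipped.

# Illustrative sketch of add_business_hours / business_hours_between.
from datetime import datetime

from yc_bench.core.business_time import add_business_hours, business_hours_between

friday = datetime(2025, 1, 3, 16, 0)           # Fri 2025-01-03 16:00
monday = add_business_hours(friday, 5)          # 2h left on Friday + 3h Monday morning
print(monday)                                   # expected: 2025-01-06 12:00:00

# The inverse accounting: business hours between Friday 16:00 and Monday 12:00.
print(business_hours_between(friday, monday))   # expected: Decimal('5')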
210
src/yc_bench/core/engine.py
Normal file
@ -0,0 +1,210 @@
"""Simulation engine: advance time with deterministic event processing.
|
||||
|
||||
Main loop:
|
||||
1. Find next action: earliest of (next_event, next_payroll_boundary, target_time)
|
||||
2. Flush progress from current_time to action_time
|
||||
3. If payroll: deduct salaries, write ledger entries, bankruptcy check
|
||||
4. If event: dispatch to handler, consume, bankruptcy check
|
||||
5. Loop until target or terminal condition
|
||||
|
||||
Payroll-event tie-breaking: payroll first at same timestamp (start-of-day obligation).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..db.models.company import Company
|
||||
from ..db.models.employee import Employee
|
||||
from ..db.models.event import EventType, SimEvent
|
||||
from ..db.models.ledger import LedgerCategory, LedgerEntry
|
||||
from ..db.models.sim_state import SimState
|
||||
from .business_time import iter_monthly_payroll_boundaries
|
||||
from .eta import recalculate_etas
|
||||
from .events import consume_event, fetch_next_event, insert_event
|
||||
from .handlers.bankruptcy import handle_bankruptcy
|
||||
from .handlers.horizon_end import handle_horizon_end
|
||||
from .handlers.task_complete import handle_task_complete
|
||||
from .handlers.task_half import handle_task_half
|
||||
from .progress import flush_progress
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdvanceResult:
|
||||
old_sim_time: str
|
||||
new_sim_time: str
|
||||
events_processed: int = 0
|
||||
payrolls_applied: int = 0
|
||||
balance_delta: int = 0
|
||||
bankrupt: bool = False
|
||||
horizon_reached: bool = False
|
||||
wake_events: List[Dict] = field(default_factory=list)
|
||||
|
||||
|
||||
def apply_payroll(db: Session, company_id: UUID, time: datetime) -> bool:
|
||||
"""Deduct monthly salaries for all employees. Returns True if bankrupt after payroll."""
|
||||
company = db.query(Company).filter(Company.id == company_id).one()
|
||||
employees = db.query(Employee).filter(Employee.company_id == company_id).all()
|
||||
|
||||
total_payroll = 0
|
||||
for emp in employees:
|
||||
salary = int(emp.salary_cents)
|
||||
total_payroll += salary
|
||||
db.add(LedgerEntry(
|
||||
company_id=company_id,
|
||||
occurred_at=time,
|
||||
category=LedgerCategory.MONTHLY_PAYROLL,
|
||||
amount_cents=-salary,
|
||||
ref_type="employee",
|
||||
ref_id=emp.id,
|
||||
))
|
||||
|
||||
company.funds_cents -= total_payroll
|
||||
db.flush()
|
||||
|
||||
return company.funds_cents < 0
|
||||
|
||||
|
||||
def dispatch_event(db: Session, event: SimEvent, sim_time: datetime, company_id: UUID) -> Dict:
|
||||
"""Route event to appropriate handler. Returns result dict."""
|
||||
if event.event_type == EventType.TASK_HALF_PROGRESS:
|
||||
result = handle_task_half(db, event)
|
||||
return {"type": "task_half", "task_id": str(result.task_id), "handled": result.handled}
|
||||
|
||||
elif event.event_type == EventType.TASK_COMPLETED:
|
||||
result = handle_task_complete(db, event, sim_time)
|
||||
# Recalculate ETAs — freed employees change topology
|
||||
from ..config import get_world_config
|
||||
recalculate_etas(db, company_id, sim_time, half_threshold=get_world_config().task_half_threshold)
|
||||
return {
|
||||
"type": "task_completed",
|
||||
"task_id": str(result.task_id),
|
||||
"success": result.success,
|
||||
"funds_delta": result.funds_delta,
|
||||
"bankrupt": result.bankrupt,
|
||||
}
|
||||
|
||||
elif event.event_type == EventType.HORIZON_END:
|
||||
result = handle_horizon_end(db, event)
|
||||
return {"type": "horizon_end", "reached": result.reached}
|
||||
|
||||
elif event.event_type == EventType.BANKRUPTCY:
|
||||
result = handle_bankruptcy(db, event)
|
||||
return {"type": "bankruptcy", "bankrupt": result.bankrupt}
|
||||
|
||||
return {"type": "unknown", "event_type": event.event_type.value}
|
||||
|
||||
|
||||
def advance_time(
|
||||
db: Session,
|
||||
company_id: UUID,
|
||||
target_time: datetime,
|
||||
) -> AdvanceResult:
|
||||
"""Advance simulation from current sim_time to target_time, processing all events and payroll."""
|
||||
sim_state = db.query(SimState).filter(SimState.company_id == company_id).one()
|
||||
current_time = sim_state.sim_time
|
||||
old_time = current_time
|
||||
|
||||
company = db.query(Company).filter(Company.id == company_id).one()
|
||||
starting_funds = company.funds_cents
|
||||
|
||||
result = AdvanceResult(
|
||||
old_sim_time=old_time.isoformat(),
|
||||
new_sim_time=target_time.isoformat(),
|
||||
)
|
||||
|
||||
# Pre-compute payroll boundaries
|
||||
payroll_times = iter_monthly_payroll_boundaries(current_time, target_time)
|
||||
payroll_idx = 0
|
||||
|
||||
while True:
|
||||
# Find next payroll
|
||||
next_payroll = None
|
||||
if payroll_idx < len(payroll_times):
|
||||
next_payroll = payroll_times[payroll_idx]
|
||||
|
||||
# Find next event
|
||||
next_event = fetch_next_event(db, company_id, target_time)
|
||||
|
||||
# Determine next action time
|
||||
candidates = []
|
||||
if next_payroll is not None and next_payroll <= target_time:
|
||||
candidates.append(("payroll", next_payroll))
|
||||
if next_event is not None:
|
||||
candidates.append(("event", next_event.scheduled_at))
|
||||
candidates.append(("target", target_time))
|
||||
|
||||
# Sort: earliest time first; at same time, payroll before event before target
|
||||
action_priority = {"payroll": 0, "event": 1, "target": 2}
|
||||
candidates.sort(key=lambda c: (c[1], action_priority[c[0]]))
|
||||
|
||||
action_type, action_time = candidates[0]
|
||||
|
||||
# Flush progress from current_time to action_time
|
||||
if action_time > current_time:
|
||||
flush_progress(db, company_id, current_time, action_time)
|
||||
current_time = action_time
|
||||
|
||||
if action_type == "target":
|
||||
# Nothing due before/equal target; stop advancement.
|
||||
break
|
||||
|
||||
if action_type == "payroll":
|
||||
bankrupt = apply_payroll(db, company_id, current_time)
|
||||
result.payrolls_applied += 1
|
||||
payroll_idx += 1
|
||||
|
||||
if bankrupt:
|
||||
# Insert bankruptcy event at this time
|
||||
insert_event(
|
||||
db, company_id,
|
||||
EventType.BANKRUPTCY,
|
||||
current_time,
|
||||
{"reason": "funds_negative_after_payroll"},
|
||||
dedupe_key=f"bankruptcy:{current_time.isoformat()}",
|
||||
)
|
||||
result.bankrupt = True
|
||||
break
|
||||
|
||||
elif action_type == "event":
|
||||
event_result = dispatch_event(db, next_event, current_time, company_id)
|
||||
consume_event(db, next_event)
|
||||
result.events_processed += 1
|
||||
result.wake_events.append(event_result)
|
||||
|
||||
# Check terminal conditions
|
||||
if next_event.event_type == EventType.HORIZON_END:
|
||||
result.horizon_reached = True
|
||||
break
|
||||
if next_event.event_type == EventType.BANKRUPTCY:
|
||||
result.bankrupt = True
|
||||
break
|
||||
if event_result.get("bankrupt", False):
|
||||
result.bankrupt = True
|
||||
break
|
||||
|
||||
# Continue loop: more due actions can still exist at the same timestamp.
|
||||
|
||||
# Update sim_time
|
||||
sim_state.sim_time = current_time
|
||||
db.flush()
|
||||
|
||||
# Compute balance delta
|
||||
company = db.query(Company).filter(Company.id == company_id).one()
|
||||
result.balance_delta = company.funds_cents - starting_funds
|
||||
result.new_sim_time = current_time.isoformat()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AdvanceResult",
|
||||
"advance_time",
|
||||
"apply_payroll",
|
||||
"dispatch_event",
|
||||
]
|
||||
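The deterministic ordering in advance_time comes down to the candidate sort shown above. A small standalone sketch of just that rule (illustrative only, not part of the commit): when payroll, an event, and the plain target all fall due at the same timestamp, payroll is handled first and the target stop last.

# Illustrative sketch of the advance_time tie-breaking rule.
from datetime import datetime

action_priority = {"payroll": 0, "event": 1, "target": 2}
t = datetime(2025, 2, 3, 9, 0)

candidates = [("event", t), ("target", t), ("payroll", t)]
candidates.sort(key=lambda c: (c[1], action_priority[c[0]]))
print([kind for kind, _ in candidates])  # ['payroll', 'event', 'target']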
263
src/yc_bench/core/eta.py
Normal file
@ -0,0 +1,263 @@
"""ETA solver: compute task completion and halfway times, manage projection events.
|
||||
|
||||
Projection events (task_completed, task_half_progress) are inserted into sim_events
|
||||
and recalculated whenever the topology changes (assign, dispatch, cancel, complete).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Set
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..db.models.company import Domain
|
||||
from ..db.models.event import EventType, SimEvent
|
||||
from ..db.models.task import Task, TaskRequirement, TaskStatus
|
||||
from .business_time import add_business_hours
|
||||
from .events import insert_event
|
||||
from .progress import EffectiveRate, compute_effective_rates
|
||||
|
||||
|
||||
def solve_task_completion_time(
|
||||
db: Session,
|
||||
task_id: UUID,
|
||||
now: datetime,
|
||||
rates: List[EffectiveRate],
|
||||
) -> Optional[datetime]:
|
||||
"""Solve for the business-time at which a task will complete.
|
||||
|
||||
Completion = all domains reach 100%. Time is max(remaining[d] / rate[d]) across domains.
|
||||
Returns None if any domain has remaining > 0 and rate == 0 (impossible to complete).
|
||||
"""
|
||||
reqs = db.query(TaskRequirement).filter(
|
||||
TaskRequirement.task_id == task_id
|
||||
).all()
|
||||
|
||||
if not reqs:
|
||||
return None
|
||||
|
||||
# Build rate map for this task
|
||||
rate_map: Dict[Domain, Decimal] = {}
|
||||
for r in rates:
|
||||
if r.task_id == task_id:
|
||||
rate_map[r.domain] = r.rate_per_hour
|
||||
|
||||
max_hours = Decimal("0")
|
||||
for req in reqs:
|
||||
remaining = req.required_qty - req.completed_qty
|
||||
if remaining <= 0:
|
||||
continue
|
||||
rate = rate_map.get(req.domain, Decimal("0"))
|
||||
if rate <= 0:
|
||||
return None # Can't complete this domain
|
||||
hours = remaining / rate
|
||||
if hours > max_hours:
|
||||
max_hours = hours
|
||||
|
||||
if max_hours <= 0:
|
||||
# Already complete
|
||||
return now
|
||||
|
||||
return add_business_hours(now, max_hours)
|
||||
|
||||
|
||||
def solve_task_halfway_time(
|
||||
db: Session,
|
||||
task_id: UUID,
|
||||
now: datetime,
|
||||
rates: List[EffectiveRate],
|
||||
half_threshold: float = 0.5,
|
||||
) -> Optional[datetime]:
|
||||
"""Solve for the business-time at which weighted progress ratio >= 0.5.
|
||||
|
||||
Weighted ratio is:
|
||||
sum_d(completed_d) / sum_d(required_d)
|
||||
|
||||
Each domain progresses linearly until capped at required_d.
|
||||
Returns None if reaching 50% is impossible.
|
||||
"""
|
||||
reqs = db.query(TaskRequirement).filter(
|
||||
TaskRequirement.task_id == task_id
|
||||
).all()
|
||||
|
||||
if not reqs:
|
||||
return None
|
||||
|
||||
rate_map: Dict[Domain, Decimal] = {}
|
||||
for r in rates:
|
||||
if r.task_id == task_id:
|
||||
rate_map[r.domain] = r.rate_per_hour
|
||||
|
||||
total_required = sum((req.required_qty for req in reqs), Decimal("0"))
|
||||
if total_required <= 0:
|
||||
return now
|
||||
target_completed = Decimal(str(half_threshold)) * total_required
|
||||
|
||||
# Check current weighted progress
|
||||
current_completed = Decimal("0")
|
||||
for req in reqs:
|
||||
current_completed += min(req.required_qty, req.completed_qty)
|
||||
|
||||
if current_completed >= target_completed:
|
||||
return now
|
||||
|
||||
# Build list of domain progression pieces.
|
||||
# cap_hours = hours until domain reaches required_qty
|
||||
domains = []
|
||||
for req in reqs:
|
||||
remaining = req.required_qty - req.completed_qty
|
||||
rate = rate_map.get(req.domain, Decimal("0"))
|
||||
if remaining > 0 and rate <= 0:
|
||||
# This domain can never progress further
|
||||
cap_hours = None
|
||||
elif remaining <= 0:
|
||||
cap_hours = Decimal("0")
|
||||
else:
|
||||
cap_hours = remaining / rate
|
||||
domains.append({
|
||||
"completed": req.completed_qty,
|
||||
"required": req.required_qty,
|
||||
"rate": rate,
|
||||
"cap_hours": cap_hours,
|
||||
})
|
||||
|
||||
# Piecewise-linear solver over cap breakpoints.
|
||||
# In each segment, d(current_completed)/dh = sum(rate_d) for uncapped domains.
|
||||
breakpoints = sorted(set(
|
||||
d["cap_hours"] for d in domains
|
||||
if d["cap_hours"] is not None and d["cap_hours"] > 0
|
||||
))
|
||||
|
||||
h = Decimal("0")
|
||||
completed_sum = current_completed
|
||||
|
||||
for bp in breakpoints:
|
||||
# Slope of completed_sum during [h, bp]
|
||||
slope = Decimal("0")
|
||||
for d in domains:
|
||||
if d["cap_hours"] is not None and d["cap_hours"] > h:
|
||||
slope += d["rate"]
|
||||
|
||||
if slope <= 0:
|
||||
# No further progress in this segment.
|
||||
if completed_sum >= target_completed:
|
||||
return add_business_hours(now, h)
|
||||
h = bp
|
||||
# Re-evaluate completed amount at breakpoint.
|
||||
completed_sum = Decimal("0")
|
||||
for d in domains:
|
||||
progress = min(d["required"], d["completed"] + d["rate"] * h)
|
||||
completed_sum += progress
|
||||
continue
|
||||
|
||||
needed = target_completed - completed_sum
|
||||
delta_h = needed / slope
|
||||
|
||||
if h + delta_h <= bp:
|
||||
return add_business_hours(now, h + delta_h)
|
||||
|
||||
completed_sum += slope * (bp - h)
|
||||
h = bp
|
||||
|
||||
# After all breakpoints, check remaining slope
|
||||
slope = Decimal("0")
|
||||
for d in domains:
|
||||
if d["cap_hours"] is not None and d["cap_hours"] > h:
|
||||
slope += d["rate"]
|
||||
|
||||
if slope > 0:
|
||||
needed = target_completed - completed_sum
|
||||
if needed <= 0:
|
||||
return add_business_hours(now, h)
|
||||
delta_h = needed / slope
|
||||
return add_business_hours(now, h + delta_h)
|
||||
|
||||
if completed_sum >= target_completed:
|
||||
return add_business_hours(now, h)
|
||||
|
||||
return None # Cannot reach 50%
|
||||
|
||||
|
||||
def recalculate_etas(
|
||||
db: Session,
|
||||
company_id: UUID,
|
||||
now: datetime,
|
||||
impacted_task_ids: Optional[Set[UUID]] = None,
|
||||
half_threshold: float = 0.5,
|
||||
) -> None:
|
||||
"""Recalculate projection events for active tasks.
|
||||
|
||||
1. Delete stale projection events for impacted tasks (or all if None).
|
||||
2. Compute effective rates.
|
||||
3. For each active task, solve completion and halfway times.
|
||||
4. Insert new projection events.
|
||||
"""
|
||||
# Determine which tasks to recalculate
|
||||
if impacted_task_ids is None:
|
||||
active_tasks = db.query(Task).filter(
|
||||
Task.company_id == company_id,
|
||||
Task.status == TaskStatus.ACTIVE,
|
||||
).all()
|
||||
task_ids = {t.id for t in active_tasks}
|
||||
else:
|
||||
task_ids = impacted_task_ids
|
||||
|
||||
if not task_ids:
|
||||
return
|
||||
|
||||
# Delete stale unconsumed projection events for these tasks
|
||||
for tid in task_ids:
|
||||
stale = db.query(SimEvent).filter(
|
||||
SimEvent.company_id == company_id,
|
||||
SimEvent.consumed == False,
|
||||
SimEvent.event_type.in_([EventType.TASK_COMPLETED, EventType.TASK_HALF_PROGRESS]),
|
||||
SimEvent.dedupe_key.like(f"task:{tid}:%"),
|
||||
).all()
|
||||
for ev in stale:
|
||||
db.delete(ev)
|
||||
|
||||
db.flush()
|
||||
|
||||
# Compute rates for all active tasks (topology-wide, since employee sharing matters)
|
||||
rates = compute_effective_rates(db, company_id)
|
||||
|
||||
for tid in task_ids:
|
||||
task = db.query(Task).filter(Task.id == tid).one_or_none()
|
||||
if task is None or task.status != TaskStatus.ACTIVE:
|
||||
continue
|
||||
|
||||
# Completion ETA
|
||||
completion_time = solve_task_completion_time(db, tid, now, rates)
|
||||
if completion_time is not None:
|
||||
insert_event(
|
||||
db,
|
||||
company_id=company_id,
|
||||
event_type=EventType.TASK_COMPLETED,
|
||||
scheduled_at=completion_time,
|
||||
payload={"task_id": str(tid)},
|
||||
dedupe_key=f"task:{tid}:completed",
|
||||
)
|
||||
|
||||
# Halfway ETA (only if not already emitted)
|
||||
if not task.halfway_event_emitted:
|
||||
halfway_time = solve_task_halfway_time(db, tid, now, rates, half_threshold=half_threshold)
|
||||
if halfway_time is not None:
|
||||
insert_event(
|
||||
db,
|
||||
company_id=company_id,
|
||||
event_type=EventType.TASK_HALF_PROGRESS,
|
||||
scheduled_at=halfway_time,
|
||||
payload={"task_id": str(tid)},
|
||||
dedupe_key=f"task:{tid}:half",
|
||||
)
|
||||
|
||||
db.flush()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"solve_task_completion_time",
|
||||
"solve_task_halfway_time",
|
||||
"recalculate_etas",
|
||||
]
|
||||
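A small numeric sketch of the weighted-ratio idea behind solve_task_halfway_time above (illustrative only, not part of the commit): with two domains of 100 required units each, progressing at 10/h and 2/h, the 50% mark (100 of 200 total units) is reached while both domains are still uncapped, so the combined slope is 12 units/hour.

# Illustrative sketch of the halfway-point calculation under these assumed numbers.
from decimal import Decimal

required = Decimal("200")                # total required units across both domains
target = Decimal("0.5") * required       # the 50% milestone
slope = Decimal("10") + Decimal("2")     # combined uncapped rate, units/hour
print(target / slope)                    # ~8.33 hours, then mapped onto the calendar via add_business_hours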
128
src/yc_bench/core/events.py
Normal file
@ -0,0 +1,128 @@
"""Event infrastructure: fetch, consume, and insert simulation events.
|
||||
|
||||
Events are processed in deterministic order: (scheduled_at, priority, id).
|
||||
Priority by event_type: task_completed=0, bankruptcy=1, task_half=2, horizon_end=3.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from uuid import UUID
|
||||
from uuid import uuid5, NAMESPACE_URL
|
||||
|
||||
from sqlalchemy import case
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..db.models.event import EventType, SimEvent
|
||||
|
||||
# Priority ordering — lower number = higher priority
|
||||
EVENT_PRIORITY: Dict[EventType, int] = {
|
||||
EventType.TASK_COMPLETED: 0,
|
||||
EventType.BANKRUPTCY: 1,
|
||||
EventType.TASK_HALF_PROGRESS: 2,
|
||||
EventType.HORIZON_END: 3,
|
||||
}
|
||||
|
||||
|
||||
def _deterministic_event_id(
|
||||
company_id: UUID,
|
||||
event_type: EventType,
|
||||
scheduled_at: datetime,
|
||||
dedupe_key: Optional[str],
|
||||
payload: Dict[str, Any],
|
||||
) -> UUID:
|
||||
"""Generate deterministic event UUID to stabilize same-seed replay ordering."""
|
||||
payload_key = json.dumps(payload, sort_keys=True, separators=(",", ":"))
|
||||
base = "|".join(
|
||||
[
|
||||
str(company_id),
|
||||
event_type.value,
|
||||
scheduled_at.isoformat(),
|
||||
dedupe_key or "",
|
||||
payload_key,
|
||||
]
|
||||
)
|
||||
return uuid5(NAMESPACE_URL, base)
|
||||
|
||||
|
||||
def fetch_next_event(
|
||||
db: Session,
|
||||
company_id: UUID,
|
||||
up_to: datetime,
|
||||
) -> Optional[SimEvent]:
|
||||
"""Fetch the next unconsumed event scheduled at or before up_to.
|
||||
|
||||
Deterministic order: (scheduled_at ASC, priority ASC, id ASC).
|
||||
"""
|
||||
priority_expr = case(
|
||||
{et: p for et, p in EVENT_PRIORITY.items()},
|
||||
value=SimEvent.event_type,
|
||||
else_=99,
|
||||
)
|
||||
|
||||
event = (
|
||||
db.query(SimEvent)
|
||||
.filter(
|
||||
SimEvent.company_id == company_id,
|
||||
SimEvent.consumed == False,
|
||||
SimEvent.scheduled_at <= up_to,
|
||||
)
|
||||
.order_by(
|
||||
SimEvent.scheduled_at.asc(),
|
||||
priority_expr.asc(),
|
||||
SimEvent.id.asc(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
|
||||
def consume_event(db: Session, event: SimEvent) -> None:
|
||||
"""Mark an event as consumed."""
|
||||
event.consumed = True
|
||||
db.flush()
|
||||
|
||||
|
||||
def insert_event(
|
||||
db: Session,
|
||||
company_id: UUID,
|
||||
event_type: EventType,
|
||||
scheduled_at: datetime,
|
||||
payload: Dict[str, Any],
|
||||
dedupe_key: Optional[str] = None,
|
||||
) -> SimEvent:
|
||||
"""Insert a new event, with optional idempotent deduplication.
|
||||
|
||||
If dedupe_key is provided and an unconsumed event with the same key exists,
|
||||
the existing event is returned unchanged.
|
||||
"""
|
||||
if dedupe_key is not None:
|
||||
existing = db.query(SimEvent).filter(
|
||||
SimEvent.company_id == company_id,
|
||||
SimEvent.dedupe_key == dedupe_key,
|
||||
SimEvent.consumed == False,
|
||||
).first()
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
event = SimEvent(
|
||||
id=_deterministic_event_id(company_id, event_type, scheduled_at, dedupe_key, payload),
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
scheduled_at=scheduled_at,
|
||||
payload=payload,
|
||||
dedupe_key=dedupe_key,
|
||||
)
|
||||
db.add(event)
|
||||
db.flush()
|
||||
return event
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EVENT_PRIORITY",
|
||||
"fetch_next_event",
|
||||
"consume_event",
|
||||
"insert_event",
|
||||
]
|
||||
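The (scheduled_at, priority, id) ordering that fetch_next_event expresses in SQL can be sketched on plain tuples (illustrative only, not part of the commit; priorities taken from EVENT_PRIORITY above): at an identical timestamp, a task_completed event outranks a bankruptcy event, which outranks horizon_end.

# Illustrative sketch of the deterministic event ordering on plain tuples.
from datetime import datetime
from uuid import UUID

pending = [
    (datetime(2025, 3, 3, 9, 0), 3, UUID(int=3)),  # horizon_end
    (datetime(2025, 3, 3, 9, 0), 0, UUID(int=1)),  # task_completed
    (datetime(2025, 3, 3, 9, 0), 1, UUID(int=2)),  # bankruptcy
]
print(min(pending))  # the task_completed tuple sorts first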
0
src/yc_bench/core/handlers/__init__.py
Normal file
18
src/yc_bench/core/handlers/bankruptcy.py
Normal file
@ -0,0 +1,18 @@
"""Handler for bankruptcy events."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...db.models.event import SimEvent
|
||||
|
||||
|
||||
@dataclass
|
||||
class BankruptcyResult:
|
||||
bankrupt: bool = True
|
||||
|
||||
|
||||
def handle_bankruptcy(db: Session, event: SimEvent) -> BankruptcyResult:
|
||||
"""Signal that the company has gone bankrupt."""
|
||||
return BankruptcyResult(bankrupt=True)
|
||||
18
src/yc_bench/core/handlers/horizon_end.py
Normal file
@ -0,0 +1,18 @@
"""Handler for horizon_end events."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...db.models.event import SimEvent
|
||||
|
||||
|
||||
@dataclass
|
||||
class HorizonEndResult:
|
||||
reached: bool = True
|
||||
|
||||
|
||||
def handle_horizon_end(db: Session, event: SimEvent) -> HorizonEndResult:
|
||||
"""Signal that the simulation horizon has been reached."""
|
||||
return HorizonEndResult(reached=True)
|
||||
141
src/yc_bench/core/handlers/task_complete.py
Normal file
@ -0,0 +1,141 @@
"""Handler for task_completed events.
|
||||
|
||||
On completion:
|
||||
- If completion_time <= deadline: success → add reward funds, add prestige, skill-boost employees.
|
||||
- If completion_time > deadline: fail → set completed_fail, apply 0.8 * delta prestige penalty.
|
||||
After either outcome, recalculate ETAs (freed employees change topology).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...db.models.company import Company, CompanyPrestige, Domain
|
||||
from ...db.models.employee import Employee, EmployeeSkillRate
|
||||
from ...config import get_world_config
|
||||
from ...db.models.event import SimEvent
|
||||
from ...db.models.ledger import LedgerCategory, LedgerEntry
|
||||
from ...db.models.task import Task, TaskAssignment, TaskRequirement, TaskStatus
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskCompleteResult:
|
||||
task_id: UUID
|
||||
success: bool
|
||||
funds_delta: int = 0
|
||||
prestige_changes: Dict[str, float] = field(default_factory=dict)
|
||||
bankrupt: bool = False
|
||||
|
||||
|
||||
def handle_task_complete(db: Session, event: SimEvent, sim_time) -> TaskCompleteResult:
|
||||
"""Process task completion: finalize progress, determine success/fail, apply rewards/penalties."""
|
||||
task_id = UUID(event.payload["task_id"])
|
||||
task = db.query(Task).filter(Task.id == task_id).one()
|
||||
company_id = task.company_id
|
||||
|
||||
# Finalize all domain progress to 100%
|
||||
reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == task_id).all()
|
||||
for req in reqs:
|
||||
req.completed_qty = req.required_qty
|
||||
db.flush()
|
||||
|
||||
task.completed_at = sim_time
|
||||
success = sim_time <= task.deadline
|
||||
|
||||
wc = get_world_config()
|
||||
prestige_changes: Dict[str, float] = {}
|
||||
funds_delta = 0
|
||||
|
||||
if success:
|
||||
task.status = TaskStatus.COMPLETED_SUCCESS
|
||||
task.success = True
|
||||
|
||||
# Add reward funds
|
||||
company = db.query(Company).filter(Company.id == company_id).one()
|
||||
company.funds_cents += task.reward_funds_cents
|
||||
funds_delta = task.reward_funds_cents
|
||||
|
||||
# Ledger entry
|
||||
db.add(LedgerEntry(
|
||||
company_id=company_id,
|
||||
occurred_at=sim_time,
|
||||
category=LedgerCategory.TASK_REWARD,
|
||||
amount_cents=task.reward_funds_cents,
|
||||
ref_type="task",
|
||||
ref_id=task_id,
|
||||
))
|
||||
|
||||
# Add prestige to each domain
|
||||
for req in reqs:
|
||||
prestige = db.query(CompanyPrestige).filter(
|
||||
CompanyPrestige.company_id == company_id,
|
||||
CompanyPrestige.domain == req.domain,
|
||||
).one_or_none()
|
||||
if prestige is not None:
|
||||
old = float(prestige.prestige_level)
|
||||
prestige.prestige_level = min(
|
||||
Decimal(str(wc.prestige_max)),
|
||||
prestige.prestige_level + task.reward_prestige_delta,
|
||||
)
|
||||
prestige_changes[req.domain.value] = float(prestige.prestige_level) - old
|
||||
|
||||
# Skill boost assigned employees
|
||||
assignments = db.query(TaskAssignment).filter(
|
||||
TaskAssignment.task_id == task_id
|
||||
).all()
|
||||
if task.skill_boost_pct > 0:
|
||||
task_domains = {req.domain for req in reqs}
|
||||
for a in assignments:
|
||||
for domain in task_domains:
|
||||
skill = db.query(EmployeeSkillRate).filter(
|
||||
EmployeeSkillRate.employee_id == a.employee_id,
|
||||
EmployeeSkillRate.domain == domain,
|
||||
).one_or_none()
|
||||
if skill is not None:
|
||||
boost = skill.rate_domain_per_hour * task.skill_boost_pct
|
||||
skill.rate_domain_per_hour += boost
|
||||
|
||||
# Salary bump: small raise for each employee who contributed to this task
|
||||
if wc.salary_bump_pct > 0:
|
||||
for a in assignments:
|
||||
employee = db.query(Employee).filter(Employee.id == a.employee_id).one_or_none()
|
||||
if employee is not None:
|
||||
bump = int(employee.salary_cents * wc.salary_bump_pct)
|
||||
employee.salary_cents += bump
|
||||
|
||||
else:
|
||||
task.status = TaskStatus.COMPLETED_FAIL
|
||||
task.success = False
|
||||
|
||||
# Apply penalty_fail_multiplier * reward_prestige_delta penalty
|
||||
penalty = Decimal(str(wc.penalty_fail_multiplier)) * task.reward_prestige_delta
|
||||
for req in reqs:
|
||||
prestige = db.query(CompanyPrestige).filter(
|
||||
CompanyPrestige.company_id == company_id,
|
||||
CompanyPrestige.domain == req.domain,
|
||||
).one_or_none()
|
||||
if prestige is not None:
|
||||
old = float(prestige.prestige_level)
|
||||
prestige.prestige_level = max(
|
||||
Decimal(str(wc.prestige_min)),
|
||||
prestige.prestige_level - penalty,
|
||||
)
|
||||
prestige_changes[req.domain.value] = float(prestige.prestige_level) - old
|
||||
|
||||
db.flush()
|
||||
|
||||
# Check bankruptcy
|
||||
company = db.query(Company).filter(Company.id == company_id).one()
|
||||
bankrupt = company.funds_cents < 0
|
||||
|
||||
return TaskCompleteResult(
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
funds_delta=funds_delta,
|
||||
prestige_changes=prestige_changes,
|
||||
bankrupt=bankrupt,
|
||||
)
|
||||
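The prestige clamping used on completion is simple Decimal arithmetic; a quick sketch with the WorldConfig defaults from earlier in this commit (prestige_min=1.0, prestige_max=10.0, penalty_fail_multiplier=0.8), illustrative only:

# Illustrative sketch of the success/fail prestige clamps.
from decimal import Decimal

prestige_max = Decimal("10.0")
prestige_min = Decimal("1.0")
reward_prestige_delta = Decimal("0.5")

# Success: the gain is clamped at the ceiling.
level = Decimal("9.8")
print(min(prestige_max, level + reward_prestige_delta))   # 10.0

# Failure: the penalty is 0.8 * delta, clamped at the floor.
level = Decimal("1.2")
penalty = Decimal("0.8") * reward_prestige_delta
print(max(prestige_min, level - penalty))                  # 1.0 (0.8 would fall below the floor)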
30
src/yc_bench/core/handlers/task_half.py
Normal file
@ -0,0 +1,30 @@
"""Handler for task_half_progress events."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...db.models.event import SimEvent
|
||||
from ...db.models.task import Task
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskHalfResult:
|
||||
task_id: UUID
|
||||
handled: bool
|
||||
|
||||
|
||||
def handle_task_half(db: Session, event: SimEvent) -> TaskHalfResult:
|
||||
"""Mark the task's halfway_event_emitted flag as True."""
|
||||
task_id = UUID(event.payload["task_id"])
|
||||
task = db.query(Task).filter(Task.id == task_id).one_or_none()
|
||||
|
||||
if task is None:
|
||||
return TaskHalfResult(task_id=task_id, handled=False)
|
||||
|
||||
task.halfway_event_emitted = True
|
||||
db.flush()
|
||||
|
||||
return TaskHalfResult(task_id=task_id, handled=True)
|
||||
308
src/yc_bench/core/progress.py
Normal file
@ -0,0 +1,308 @@
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from .business_time import business_hours_between
|
||||
from ..db.models.company import Domain
|
||||
from ..db.models.employee import EmployeeSkillRate
|
||||
from ..db.models.task import TaskRequirement, Task, TaskAssignment, TaskStatus
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequirementState:
|
||||
domain: str
|
||||
required_qty: Decimal
|
||||
completed_qty: Decimal
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskProgressState:
|
||||
task_id: str
|
||||
status: str
|
||||
requirements: tuple[RequirementState, ...]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssignmentState:
|
||||
task_id: str
|
||||
employee_id: str
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EmployeeRateState:
|
||||
employee_id: str
|
||||
domain: str
|
||||
rate_domain_per_hour: Decimal
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgressDelta:
|
||||
task_id: str
|
||||
domain: str
|
||||
delta_qty: Decimal
|
||||
before_qty: Decimal
|
||||
after_qty: Decimal
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskProgressSummary:
|
||||
task_id: str
|
||||
ratio_before: Decimal
|
||||
ratio_after: Decimal
|
||||
completed: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EffectiveRate:
|
||||
task_id: UUID
|
||||
domain: Domain
|
||||
rate_per_hour: Decimal
|
||||
|
||||
def _active_assignment_count(assignments):
|
||||
counts = {}
|
||||
for a in assignments:
|
||||
counts[a.employee_id] = counts.get(a.employee_id, 0) + 1
|
||||
return counts
|
||||
|
||||
def _rates_by_employee_domain(rates):
|
||||
m = {}
|
||||
for r in rates:
|
||||
m[(r.employee_id, r.domain)] = r.rate_domain_per_hour
|
||||
return m
|
||||
|
||||
def _effective_rate_for_task_domain(*, task_id, domain, assignments, assignment_counts, base_rates):
|
||||
total = Decimal("0")
|
||||
for a in assignments:
|
||||
if a.task_id != task_id:
|
||||
continue
|
||||
k = assignment_counts.get(a.employee_id, 0)
|
||||
if k <= 0:
|
||||
continue
|
||||
base = base_rates.get((a.employee_id, domain), Decimal("0"))
|
||||
total += base / Decimal(k)
|
||||
return total
|
||||
|
||||
def _weighted_ratio_from_rows(rows, *, task_id_label):
|
||||
total_completed = Decimal("0")
|
||||
total_required = Decimal("0")
|
||||
for req in rows:
|
||||
if req.required_qty <= 0:
|
||||
raise ValueError(
|
||||
f"Task {task_id_label} requirement {req.domain} has quantity {req.required_qty}"
|
||||
)
|
||||
completed = req.completed_qty
|
||||
if completed < 0:
|
||||
raise ValueError(
|
||||
f"Task {task_id_label} requirement {req.domain} has completed quantity {req.completed_qty} which is less than 0"
|
||||
)
|
||||
if completed > req.required_qty:
|
||||
completed = req.required_qty
|
||||
|
||||
total_completed += completed
|
||||
total_required += req.required_qty
|
||||
|
||||
if total_required == 0:
|
||||
return Decimal("0")
|
||||
return total_completed / total_required
|
||||
|
||||
def task_progress_ratio(task):
|
||||
if not task.requirements:
|
||||
raise ValueError(f"Task {task.task_id} has no requirements")
|
||||
return _weighted_ratio_from_rows(task.requirements, task_id_label=task.task_id)
|
||||
|
||||
def apply_progress_window(*, tasks, assignments, rates, t0, t1):
|
||||
hours = Decimal(str(business_hours_between(t0, t1)))
|
||||
if hours <= 0:
|
||||
unchanged = list(tasks)
|
||||
summaries = [
|
||||
TaskProgressSummary(
|
||||
task_id=t.task_id,
|
||||
ratio_before=task_progress_ratio(t),
|
||||
ratio_after=task_progress_ratio(t),
|
||||
completed=all(r.completed_qty >= r.required_qty for r in t.requirements),
|
||||
)
|
||||
for t in unchanged
|
||||
]
|
||||
return unchanged, [], summaries
|
||||
|
||||
assignment_list = list(assignments)
|
||||
assignment_counts = _active_assignment_count(assignment_list)
|
||||
base_rates = _rates_by_employee_domain(rates)
|
||||
|
||||
updated_tasks = []
|
||||
deltas = []
|
||||
summaries = []
|
||||
|
||||
for task in tasks:
|
||||
ratio_before = task_progress_ratio(task)
|
||||
|
||||
if task.status not in {"planned", "active"}:
|
||||
updated_tasks.append(task)
|
||||
summaries.append(
|
||||
TaskProgressSummary(
|
||||
task_id=task.task_id,
|
||||
ratio_before=ratio_before,
|
||||
ratio_after=ratio_before,
|
||||
completed=all(r.completed_qty >= r.required_qty for r in task.requirements),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
next_requirements = []
|
||||
for req in task.requirements:
|
||||
before = Decimal(req.completed_qty)
|
||||
required = Decimal(req.required_qty)
|
||||
|
||||
eff_rate = _effective_rate_for_task_domain(
|
||||
task_id=task.task_id,
|
||||
domain=req.domain,
|
||||
assignments=assignment_list,
|
||||
assignment_counts=assignment_counts,
|
||||
base_rates=base_rates,
|
||||
)
|
||||
delta = eff_rate * hours
|
||||
after = before + delta
|
||||
# Progress cap is independent from deadline/failure logic.
|
||||
if after > required:
|
||||
after = required
|
||||
# Guardrail in case of inconsistent negative upstream data.
|
||||
if after < 0:
|
||||
after = Decimal("0")
|
||||
|
||||
next_requirements.append(RequirementState(
|
||||
domain=req.domain,
|
||||
required_qty=required,
|
||||
completed_qty=after,
|
||||
))
|
||||
|
||||
deltas.append(ProgressDelta(
|
||||
task_id=task.task_id,
|
||||
domain=req.domain,
|
||||
delta_qty=after - before,
|
||||
before_qty=before,
|
||||
after_qty=after,
|
||||
))
|
||||
|
||||
next_task = TaskProgressState(
|
||||
task_id=task.task_id,
|
||||
status=task.status,
|
||||
requirements=tuple(next_requirements),
|
||||
)
|
||||
ratio_after = task_progress_ratio(next_task)
|
||||
completed = all(r.completed_qty >= r.required_qty for r in next_requirements)
|
||||
updated_tasks.append(next_task)
|
||||
summaries.append(
|
||||
TaskProgressSummary(
|
||||
task_id=task.task_id,
|
||||
ratio_before=ratio_before,
|
||||
ratio_after=ratio_after,
|
||||
completed=completed,
|
||||
)
|
||||
)
|
||||
return updated_tasks, deltas, summaries
|
||||
|
||||
def compute_task_progress_ratio(db, task_id):
|
||||
reqs = db.query(TaskRequirement).filter(TaskRequirement.task_id == task_id).all()
|
||||
if not reqs:
|
||||
return Decimal("0")
|
||||
return _weighted_ratio_from_rows(reqs, task_id_label=task_id)
|
||||
|
||||
def compute_effective_rates(db, company_id):
|
||||
active_tasks = db.query(Task).filter(Task.company_id == company_id, Task.status == TaskStatus.ACTIVE).all()
|
||||
if not active_tasks:
|
||||
return []
|
||||
|
||||
task_ids = [t.id for t in active_tasks]
|
||||
requirements = db.query(TaskRequirement).filter(TaskRequirement.task_id.in_(task_ids)).all()
|
||||
assignments = db.query(TaskAssignment).filter(TaskAssignment.task_id.in_(task_ids)).all()
|
||||
|
||||
if not assignments:
|
||||
out = []
|
||||
for req in requirements:
|
||||
out.append(EffectiveRate(
|
||||
task_id=req.task_id,
|
||||
domain=req.domain,
|
||||
rate_per_hour=Decimal("0"),
|
||||
))
|
||||
return out
|
||||
|
||||
assignment_counts = {}
|
||||
assignments_by_task = {}
|
||||
for a in assignments:
|
||||
assignments_by_task.setdefault(a.task_id, []).append(a)
|
||||
assignment_counts[a.employee_id] = assignment_counts.get(a.employee_id, 0) + 1
|
||||
|
||||
employee_ids = list(assignment_counts.keys())
|
||||
skill_rows = db.query(EmployeeSkillRate).filter(EmployeeSkillRate.employee_id.in_(employee_ids)).all()
|
||||
|
||||
base_rates = {}
|
||||
for s in skill_rows:
|
||||
base_rates[(s.employee_id, s.domain)] = Decimal(s.rate_domain_per_hour)
|
||||
|
||||
out = []
|
||||
for req in requirements:
|
||||
total = Decimal("0")
|
||||
for a in assignments_by_task.get(req.task_id, []):
|
||||
k = assignment_counts.get(a.employee_id, 0)
|
||||
if k <= 0:
|
||||
continue
|
||||
base = base_rates.get((a.employee_id, req.domain), Decimal("0"))
|
||||
total += base / Decimal(k)
|
||||
|
||||
out.append(EffectiveRate(
|
||||
task_id=req.task_id,
|
||||
domain=req.domain,
|
||||
rate_per_hour=total,
|
||||
))
|
||||
return out
|
||||
|
||||
def flush_progress(db, company_id, t0, t1):
|
||||
active_tasks = db.query(Task).filter(Task.company_id == company_id, Task.status == TaskStatus.ACTIVE).all()
|
||||
if not active_tasks:
|
||||
return
|
||||
task_ids = [t.id for t in active_tasks]
|
||||
req_rows = db.query(TaskRequirement).filter(TaskRequirement.task_id.in_(task_ids)).all()
|
||||
asg_rows = db.query(TaskAssignment).filter(TaskAssignment.task_id.in_(task_ids)).all()
|
||||
emp_ids = list({a.employee_id for a in asg_rows})
|
||||
rate_rows = db.query(EmployeeSkillRate).filter(EmployeeSkillRate.employee_id.in_(emp_ids)).all()
|
||||
|
||||
reqs_by_task = {}
|
||||
req_index = {}
|
||||
for r in req_rows:
|
||||
req_index[(r.task_id, r.domain)] = r
|
||||
reqs_by_task.setdefault(r.task_id, []).append(
|
||||
RequirementState(
|
||||
domain=r.domain,
|
||||
required_qty=Decimal(r.required_qty),
|
||||
completed_qty=Decimal(r.completed_qty),
|
||||
)
|
||||
)
|
||||
|
||||
task_states = [
|
||||
TaskProgressState(
|
||||
task_id=t.id,
|
||||
status=t.status.value if hasattr(t.status, "value") else str(t.status),
|
||||
requirements=tuple(reqs_by_task.get(t.id, [])),
|
||||
)
|
||||
for t in active_tasks
|
||||
]
|
||||
assignments_states = [
|
||||
AssignmentState(
|
||||
task_id=a.task_id,
|
||||
employee_id=a.employee_id,
|
||||
)
|
||||
for a in asg_rows
|
||||
]
|
||||
employee_rate_states = [
|
||||
EmployeeRateState(
|
||||
employee_id=s.employee_id,
|
||||
domain=s.domain,
|
||||
rate_domain_per_hour=Decimal(s.rate_domain_per_hour),
|
||||
)
|
||||
for s in rate_rows
|
||||
]
|
||||
|
||||
updated_tasks, _, _ = apply_progress_window(tasks=task_states, assignments=assignments_states, rates=employee_rate_states, t0=t0, t1=t1)
|
||||
|
||||
for ut in updated_tasks:
|
||||
for req in ut.requirements:
|
||||
row = req_index[(ut.task_id, req.domain)]
|
||||
if row is not None:
|
||||
row.completed_qty = req.completed_qty
|
||||
|
||||
db.flush()
|
||||
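The rate-splitting rule that _effective_rate_for_task_domain applies above is worth seeing in isolation (illustrative only, not part of the commit): an employee assigned to k active tasks contributes base_rate / k to each of them, so effective rates shrink as the same people are spread across more work.

# Illustrative sketch of the per-assignment rate split.
from decimal import Decimal

base_rate = Decimal("6.0")        # employee's units/hour in one domain
assignment_counts = {"emp": 3}    # the employee is assigned to 3 tasks at once

effective = base_rate / Decimal(assignment_counts["emp"])
print(effective)                  # 2.0 units/hour flow into each of the 3 tasks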
49
src/yc_bench/runner/args.py
Normal file
@ -0,0 +1,49 @@
from __future__ import annotations

from dataclasses import dataclass
import argparse
from pathlib import Path

@dataclass(frozen=True)
class RunArgs:
    model: str
    seed: int
    horizon_years: int | None  # None = defer to sim.horizon_years in config
    company_name: str
    start_date: str
    config_name: str = "default"

def build_parser():
    parser = argparse.ArgumentParser(
        prog="yc-bench run",
    )
    parser.add_argument("--model", required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--horizon-years", type=int, default=None,
                        help="Simulation horizon in years (default: read from config's sim.horizon_years)")
    parser.add_argument("--company-name", default="BenchCo")
    parser.add_argument("--start-date", default="2025-01-01", help="Simulation start date (YYYY-MM-DD)")
    parser.add_argument(
        "--config", dest="config_name", default="default",
        help="Preset name ('default', 'fast_test', 'high_reward') or path to a .toml file",
    )
    return parser

def parse_run_args(argv):
    parser = build_parser()
    ns = parser.parse_args(argv)
    _validate(ns, parser)
    return RunArgs(
        model=ns.model,
        seed=ns.seed,
        horizon_years=ns.horizon_years,
        company_name=ns.company_name,
        start_date=ns.start_date,
        config_name=ns.config_name,
    )

def _validate(ns, parser):
    if ns.horizon_years is not None and ns.horizon_years <= 0:
        parser.error("--horizon-years must be int > 0")

__all__ = ["RunArgs", "build_parser", "parse_run_args"]
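A typical invocation of the parser above, as an illustrative sketch (not part of the commit; the model string is only a placeholder):

# Illustrative usage of parse_run_args.
from yc_bench.runner.args import parse_run_args

args = parse_run_args([
    "--model", "openai/gpt-4o",   # placeholder model name
    "--seed", "42",
    "--config", "fast_test",
])
print(args.horizon_years)   # None -> run_benchmark falls back to sim.horizon_years
print(args.company_name)    # 'BenchCo' (default)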
226
src/yc_bench/runner/main.py
Normal file
@ -0,0 +1,226 @@
"""Benchmark entrypoint: provisions DB, seeds world, runs agent loop to completion."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from ..agent.commands.executor import run_command
|
||||
from ..agent.loop import run_agent_loop
|
||||
from ..agent.run_state import RunState
|
||||
from ..agent.runtime.factory import build_runtime
|
||||
from ..agent.runtime.schemas import RuntimeSettings
|
||||
from ..db.session import build_engine, build_session_factory, session_scope, init_db
|
||||
from .args import parse_run_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_date(date_str: str) -> datetime:
|
||||
"""Accept ISO (2025-01-01) or legacy MM/DD/YYYY format."""
|
||||
fmt = "%Y-%m-%d" if "-" in date_str else "%m/%d/%Y"
|
||||
dt = datetime.strptime(date_str, fmt)
|
||||
return dt.replace(hour=9, minute=0, second=0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _wipe_simulation(db) -> None:
|
||||
"""Delete all simulation rows so the DB can be reseeded cleanly."""
|
||||
from ..db.models.ledger import LedgerEntry
|
||||
from ..db.models.task import Task, TaskAssignment, TaskRequirement
|
||||
from ..db.models.employee import Employee, EmployeeSkillRate
|
||||
from ..db.models.company import Company, CompanyPrestige
|
||||
from ..db.models.sim_state import SimState
|
||||
from ..db.models.event import SimEvent
|
||||
from ..db.models.scratchpad import Scratchpad
|
||||
|
||||
db.query(Scratchpad).delete(synchronize_session=False)
|
||||
db.query(LedgerEntry).delete(synchronize_session=False)
|
||||
db.query(TaskAssignment).delete(synchronize_session=False)
|
||||
db.query(TaskRequirement).delete(synchronize_session=False)
|
||||
db.query(Task).delete(synchronize_session=False)
|
||||
db.query(SimEvent).delete(synchronize_session=False)
|
||||
db.query(EmployeeSkillRate).delete(synchronize_session=False)
|
||||
db.query(Employee).delete(synchronize_session=False)
|
||||
db.query(CompanyPrestige).delete(synchronize_session=False)
|
||||
db.query(Company).delete(synchronize_session=False)
|
||||
db.query(SimState).delete(synchronize_session=False)
|
||||
db.flush()
|
||||
|
||||
|
||||
def _init_simulation(db_factory, args, experiment_cfg, horizon_years):
|
||||
"""Seed world directly (no CLI round-trip) using the experiment WorldConfig."""
|
||||
from ..db.models.event import EventType
|
||||
from ..db.models.sim_state import SimState
|
||||
from ..db.models.company import Company
|
||||
from ..core.events import insert_event
|
||||
from ..services.seed_world import SeedWorldRequest, seed_world_transactional
|
||||
|
||||
with db_factory() as db:
|
||||
existing = db.query(SimState).first()
|
||||
if existing is not None:
|
||||
company = db.query(Company).filter(Company.id == existing.company_id).first()
|
||||
bankrupt = company is not None and company.funds_cents < 0
|
||||
horizon_reached = existing.sim_time >= existing.horizon_end
|
||||
if bankrupt or horizon_reached:
|
||||
logger.info(
|
||||
"Existing simulation is terminal (bankrupt=%s horizon_reached=%s) — reseeding.",
|
||||
bankrupt, horizon_reached,
|
||||
)
|
||||
_wipe_simulation(db)
|
||||
else:
|
||||
logger.info("Resuming non-terminal simulation (company_id=%s, sim_time=%s).",
|
||||
existing.company_id, existing.sim_time)
|
||||
return existing.company_id
|
||||
|
||||
start_dt = _parse_date(args.start_date)
|
||||
horizon_end = start_dt.replace(year=start_dt.year + horizon_years)
|
||||
world = experiment_cfg.world
|
||||
|
||||
req = SeedWorldRequest(
|
||||
run_seed=args.seed,
|
||||
company_name=args.company_name,
|
||||
horizon_years=horizon_years,
|
||||
employee_count=world.num_employees,
|
||||
market_task_count=world.num_market_tasks,
|
||||
start_date=start_dt,
|
||||
cfg=world,
|
||||
)
|
||||
logger.info(
|
||||
"Initializing simulation: seed=%d employees=%d tasks=%d horizon=%dy",
|
||||
args.seed, world.num_employees, world.num_market_tasks, horizon_years,
|
||||
)
|
||||
result = seed_world_transactional(db, req)
|
||||
|
||||
insert_event(
|
||||
db=db,
|
||||
company_id=result.company_id,
|
||||
event_type=EventType.HORIZON_END,
|
||||
scheduled_at=horizon_end,
|
||||
payload={"reason": "horizon_end"},
|
||||
dedupe_key="horizon_end",
|
||||
)
|
||||
db.add(SimState(
|
||||
company_id=result.company_id,
|
||||
sim_time=start_dt,
|
||||
run_seed=args.seed,
|
||||
horizon_end=horizon_end,
|
||||
replenish_counter=0,
|
||||
))
|
||||
db.flush()
|
||||
|
||||
logger.info("Simulation initialized: company_id=%s", result.company_id)
|
||||
return result.company_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def run_benchmark(args):
|
||||
"""Run a full benchmark: migrate, seed, loop until terminal."""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[%(asctime)s] %(name)s %(levelname)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
# Load experiment config — preset name or path to a .toml file
|
||||
from yc_bench.config import load_config
|
||||
experiment_cfg = load_config(args.config_name)
|
||||
# Propagate config name to CLI subprocess calls (e.g. task accept → generate_replacement_task)
|
||||
os.environ["YC_BENCH_EXPERIMENT"] = args.config_name
|
||||
# CLI --model always overrides the experiment's agent model
|
||||
experiment_cfg = experiment_cfg.model_copy(
|
||||
update={"agent": experiment_cfg.agent.model_copy(update={"model": args.model})}
|
||||
)
|
||||
# --horizon-years CLI flag overrides config; fall back to sim.horizon_years from config
|
||||
horizon_years = args.horizon_years if args.horizon_years is not None else experiment_cfg.sim.horizon_years
|
||||
|
||||
logger.info(
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years,
|
||||
)
|
||||
|
||||
# 1. Build engine and create all tables
|
||||
# If DATABASE_URL is not explicitly set, default to db/<seed>_<slug>.db
|
||||
if not os.environ.get("DATABASE_URL"):
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_dir}/{args.seed}_{slug}.db"
|
||||
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory():
|
||||
with session_scope(factory) as session:
|
||||
yield session
|
||||
|
||||
# 3. Init simulation using experiment world config
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 4. Build runtime settings from experiment AgentConfig
|
||||
agent_cfg = experiment_cfg.agent
|
||||
settings = RuntimeSettings(
|
||||
model=agent_cfg.model,
|
||||
temperature=agent_cfg.temperature,
|
||||
top_p=agent_cfg.top_p,
|
||||
request_timeout_seconds=agent_cfg.request_timeout_seconds,
|
||||
retry_max_attempts=agent_cfg.retry_max_attempts,
|
||||
retry_backoff_seconds=agent_cfg.retry_backoff_seconds,
|
||||
history_keep_rounds=agent_cfg.history_keep_rounds,
|
||||
system_prompt=agent_cfg.system_prompt,
|
||||
)
|
||||
runtime = build_runtime(settings, command_executor=run_command)
|
||||
|
||||
# 5. Build run state
|
||||
session_id = f"run-{args.seed}-{args.model}"
|
||||
run_state = RunState(
|
||||
session_id=session_id,
|
||||
seed=args.seed,
|
||||
model=args.model,
|
||||
horizon_years=horizon_years,
|
||||
)
|
||||
|
||||
# 6. Run agent loop
|
||||
loop_cfg = experiment_cfg.loop
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
)
|
||||
|
||||
# 7. Save full rollout (with transcript) and print summary
|
||||
rollout = final_state.full_rollout()
|
||||
summary = final_state.summary()
|
||||
logger.info("Run complete: %s", json.dumps(summary, indent=2))
|
||||
|
||||
# Write full rollout (includes transcript with commands)
|
||||
slug = args.model.replace("/", "_")
|
||||
results_dir = Path("results")
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
results_path = results_dir / f"yc_bench_result_{args.seed}_{slug}.json"
|
||||
results_path.write_text(json.dumps(rollout, indent=2))
|
||||
logger.info("Full rollout written to %s (%d turns)", results_path, len(rollout.get("transcript", [])))
|
||||
|
||||
return 0 if final_state.terminal_reason != "error" else 1
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
load_dotenv(find_dotenv(usecwd=True), override=False) # searches cwd upward for .env
|
||||
args = parse_run_args(argv)
|
||||
return run_benchmark(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
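The default DATABASE_URL that run_benchmark derives when the environment variable is unset follows the seed-and-slug pattern above; a quick illustrative sketch (the model name is a placeholder, not from this commit):

# Illustrative sketch of the default SQLite URL derivation.
from pathlib import Path

seed, model = 42, "openai/gpt-4o"            # placeholder values
slug = model.replace("/", "_")
db_dir = Path("db")
print(f"sqlite:///{db_dir}/{seed}_{slug}.db")  # sqlite:///db/42_openai_gpt-4o.db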
33
src/yc_bench/runner/session.py
Normal file
@ -0,0 +1,33 @@
"""Session utilities for the benchmark runner."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..db.models.event import EventType
|
||||
from ..db.models.session import Session as SessionModel
|
||||
|
||||
|
||||
def open_session(db: Session, company_id, wake_reason: EventType) -> SessionModel:
|
||||
"""Create and persist a new session record."""
|
||||
session = SessionModel(
|
||||
id=uuid4(),
|
||||
company_id=company_id,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
ended_at=None,
|
||||
wake_reason=wake_reason,
|
||||
)
|
||||
db.add(session)
|
||||
db.flush()
|
||||
return session
|
||||
|
||||
|
||||
def close_session(db: Session, session: SessionModel) -> None:
|
||||
"""Close an open session record."""
|
||||
session.ended_at = datetime.now(timezone.utc)
|
||||
db.flush()
|
||||
|
||||
|
||||
__all__ = ["open_session", "close_session"]
|
||||
107
src/yc_bench/services/generate_employees.py
Normal file
@ -0,0 +1,107 @@
from __future__ import annotations

from dataclasses import dataclass

from ..config.schema import WorldConfig
from ..db.models.company import Domain
from .rng import RngStreams, sample_right_skew_triangular_int

_ALL_DOMAINS = list(Domain)


@dataclass(frozen=True)
class GeneratedEmployee:
    name: str
    work_hours_per_day: float
    salary_cents: int
    tier: str
    rates_by_domain: dict[str, float]


def _salary_tiers(cfg):
    return (cfg.salary_junior, cfg.salary_mid, cfg.salary_senior)


def _pick_tier_name(rng, cfg):
    x = rng.random()
    acc = 0.0
    for tier in _salary_tiers(cfg):
        acc += tier.share
        if acc >= x:
            return tier.name
    return _salary_tiers(cfg)[-1].name


def _tier_by_name(cfg, tier_name):
    for tier in _salary_tiers(cfg):
        if tier.name == tier_name:
            return tier
    raise ValueError(f"Tier {tier_name} not found")


def _sample_salary_cents(rng, cfg, tier_name):
    tier = _tier_by_name(cfg, tier_name)
    return sample_right_skew_triangular_int(rng, tier.min_cents, tier.max_cents)


def _sample_rates_by_domain(rng, cfg, tier_name):
    tier = _tier_by_name(cfg, tier_name)
    lo, hi = tier.rate_min, tier.rate_max
    return {domain: round(rng.uniform(lo, hi), 4) for domain in _ALL_DOMAINS}


def generate_employees(*, run_seed, count, cfg=None):
    if cfg is None:
        cfg = WorldConfig()
    if count <= 0:
        return []

    employees = []
    streams = RngStreams(run_seed)

    for idx in range(1, count + 1):
        rng = streams.stream(f"employee_{idx}")
        tier_name = _pick_tier_name(rng, cfg)

        employees.append(
            GeneratedEmployee(
                name=f"Emp_{idx}",
                work_hours_per_day=cfg.work_hours_per_day,
                salary_cents=_sample_salary_cents(rng, cfg, tier_name),
                tier=tier_name,
                rates_by_domain=_sample_rates_by_domain(rng, cfg, tier_name),
            )
        )
    return employees


def build_employee_rows(*, run_seed, company_id, count, cfg=None):
    generated = generate_employees(run_seed=run_seed, count=count, cfg=cfg)
    employee_rows = []
    skill_rows = []

    for emp in generated:
        employee_rows.append(
            {
                "company_id": company_id,
                "name": emp.name,
                "work_hours_per_day": emp.work_hours_per_day,
                "salary_cents": emp.salary_cents,
            }
        )
        for domain, rate in emp.rates_by_domain.items():
            skill_rows.append(
                {
                    "_employee_name": emp.name,
                    "domain": domain,
                    "rate_domain_per_hour": rate,
                }
            )
    return employee_rows, skill_rows


__all__ = [
    "build_employee_rows",
    "GeneratedEmployee",
    "generate_employees",
]
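Because each employee draws from its own seeded stream, the generator is fully deterministic for a given run_seed. A small sketch (seed and count are arbitrary; the WorldConfig defaults are assumed to define the salary tiers):

from yc_bench.services.generate_employees import generate_employees

first = generate_employees(run_seed=7, count=3)
second = generate_employees(run_seed=7, count=3)

# Same seed, same roster: names, tiers, salaries and per-domain rates all match.
assert first == second
for emp in first:
    print(emp.name, emp.tier, emp.salary_cents)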
163
src/yc_bench/services/generate_tasks.py
Normal file
@@ -0,0 +1,163 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from typing import Any

from ..config.sampling import sample_from_spec
from ..config.schema import WorldConfig
from ..db.models.company import Domain
from .rng import RngStreams, sample_without_replacement
from .task_catalog import pick_task_text


@dataclass(frozen=True)
class GeneratedTask:
    title: str
    description: str
    required_prestige: int
    reward_funds_cents: int
    reward_prestige_delta: float
    skill_boost_pct: float
    status: str
    company_id: Any | None
    accepted_at: datetime | None
    deadline: datetime | None
    completed_at: datetime | None
    success: bool | None
    halfway_event_emitted: bool
    requirements: dict[str, int]


# First 10 market tasks are given explicit prestige values to guarantee a
# climbable ladder from the start (avoids runs where all early tasks need
# prestige 4+ before any are completable).
_STRATIFIED_PRESTIGE = [1, 1, 1, 1, 2, 2, 2, 3, 3, 4]

_ALL_DOMAINS = list(Domain)


def _sample_required_prestige(rng, cfg, index=None):
    if index is not None and index < len(_STRATIFIED_PRESTIGE):
        return _STRATIFIED_PRESTIGE[index]
    return int(sample_from_spec(rng, cfg.dist.required_prestige))


def _sample_reward_funds_cents(rng, cfg, prestige=1):
    base = int(sample_from_spec(rng, cfg.dist.reward_funds_cents))
    # Scale reward by prestige: higher-prestige tasks pay proportionally more
    return int(base * (1 + cfg.reward_prestige_scale * (prestige - 1)))


def _sample_reward_prestige_delta(rng, cfg):
    return sample_from_spec(rng, cfg.dist.reward_prestige_delta)


def _sample_skill_boost_pct(rng, cfg):
    return sample_from_spec(rng, cfg.dist.skill_boost)


def _sample_domain_count(rng, cfg):
    return int(sample_from_spec(rng, cfg.dist.domain_count))


def _sample_required_qty(rng, cfg):
    return int(sample_from_spec(rng, cfg.dist.required_qty))


def _sample_requirements(rng, cfg):
    k = _sample_domain_count(rng, cfg)
    picked_domains = sample_without_replacement(rng, _ALL_DOMAINS, k)
    return {domain: _sample_required_qty(rng, cfg) for domain in picked_domains}


def _pick_title_desc(rng, primary_domain, serial):
    title, description = pick_task_text(rng, primary_domain)
    domain_str = primary_domain.value if hasattr(primary_domain, "value") else str(primary_domain)
    title = f"{title} [{domain_str.upper()}-{serial}]"
    return title, description


def _make_task(rng, cfg, prestige, serial, requirements):
    title, description = _pick_title_desc(rng, next(iter(requirements)), serial)
    return GeneratedTask(
        title=title,
        description=description,
        required_prestige=prestige,
        reward_funds_cents=_sample_reward_funds_cents(rng, cfg, prestige=prestige),
        reward_prestige_delta=_sample_reward_prestige_delta(rng, cfg),
        skill_boost_pct=_sample_skill_boost_pct(rng, cfg),
        status="market",
        company_id=None,
        accepted_at=None,
        deadline=None,
        completed_at=None,
        success=None,
        halfway_event_emitted=False,
        requirements=requirements,
    )


def generate_tasks(*, run_seed, count, cfg=None):
    if cfg is None:
        cfg = WorldConfig()
    if count <= 0:
        return []

    streams = RngStreams(run_seed)
    out = []
    for idx in range(1, count + 1):
        rng = streams.stream(f"task_{idx}")
        requirements = _sample_requirements(rng, cfg)
        prestige = _sample_required_prestige(rng, cfg, index=idx - 1)
        out.append(_make_task(rng, cfg, prestige, serial=idx, requirements=requirements))
    return out


def build_task_rows(*, run_seed, count, cfg=None):
    generated = generate_tasks(run_seed=run_seed, count=count, cfg=cfg)
    task_rows = []
    requirement_rows = []

    for task in generated:
        task_rows.append({
            "title": task.title,
            "description": task.description,
            "required_prestige": task.required_prestige,
            "reward_funds_cents": task.reward_funds_cents,
            "reward_prestige_delta": task.reward_prestige_delta,
            "skill_boost_pct": task.skill_boost_pct,
            "status": task.status,
            "company_id": task.company_id,
            "accepted_at": task.accepted_at,
            "deadline": task.deadline,
            "completed_at": task.completed_at,
            "success": task.success,
            "halfway_event_emitted": task.halfway_event_emitted,
        })
        for domain, qty in task.requirements.items():
            requirement_rows.append({
                "_task_title": task.title,
                "domain": domain,
                "required_qty": qty,
                "completed_qty": 0,
            })
    return task_rows, requirement_rows


def generate_replacement_task(*, run_seed, replenish_counter, cfg=None):
    if cfg is None:
        cfg = WorldConfig()
    streams = RngStreams(run_seed)
    rng = streams.stream(f"replenish_{replenish_counter}")
    requirements = _sample_requirements(rng, cfg)
    prestige = _sample_required_prestige(rng, cfg)
    return _make_task(rng, cfg, prestige, serial=replenish_counter, requirements=requirements)


__all__ = [
    "build_task_rows",
    "generate_replacement_task",
    "generate_tasks",
    "GeneratedTask",
]
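The same per-index streaming applies to tasks, and the index-based stratification means the first ten generated tasks always follow the prestige ladder 1, 1, 1, 1, 2, 2, 2, 3, 3, 4. A minimal sketch (arbitrary seed; WorldConfig defaults are assumed for the sampling specs):

from yc_bench.services.generate_tasks import generate_tasks

tasks = generate_tasks(run_seed=7, count=12)

# The first ten prestige values come from _STRATIFIED_PRESTIGE; the rest are sampled.
print([t.required_prestige for t in tasks])
print(tasks[0].title, tasks[0].requirements)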
42
src/yc_bench/services/rng.py
Normal file
@@ -0,0 +1,42 @@
from __future__ import annotations

import hashlib
import random
from dataclasses import dataclass


def _stable_seed(run_seed, stream_key):
    raw = f"{run_seed}:{stream_key}".encode("utf-8")
    digest = hashlib.sha256(raw).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=False)


@dataclass(frozen=True)
class RngStreams:
    run_seed: int

    def stream(self, stream_key):
        return random.Random(_stable_seed(self.run_seed, stream_key))


def clamp_numeric(value, low, high):
    return max(low, min(high, value))


def sample_triangular_int(rng, low, high, mode):
    return int(clamp_numeric(round(rng.triangular(low, high, mode)), low, high))


def sample_right_skew_triangular_int(rng, low, high):
    return sample_triangular_int(rng, low, high, high)


def sample_normal_clamped_float(rng, mean, stdev, low, high):
    val = clamp_numeric(rng.gauss(mean, stdev), low, high)
    return round(val, 4)


def sample_beta_scaled(rng, alpha: float, beta: float, scale: float) -> float:
    """Sample from Beta(alpha, beta) multiplied by scale. Used for prestige reward deltas."""
    return round(scale * rng.betavariate(alpha, beta), 4)


def sample_left_skew_0_2(rng):
    """Backward-compat alias using default Beta(1.2, 2.8) * 2.0 params."""
    return sample_beta_scaled(rng, 1.2, 2.8, 2.0)


def sample_without_replacement(rng, population, k):
    return rng.sample(list(population), k)
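RngStreams hashes the run seed together with a stream key, so distinct keys give independent but reproducible generators. A short sketch using only the helpers defined above (the seed, key names, and salary bounds are arbitrary):

from yc_bench.services.rng import RngStreams, sample_right_skew_triangular_int

streams = RngStreams(run_seed=123)

# The same key always yields the same sequence; a different key diverges.
a = streams.stream("employee_1").random()
b = streams.stream("employee_1").random()
c = streams.stream("employee_2").random()
assert a == b
print(a, c)

# Salary-style draw skewed toward the upper bound (cents).
rng = streams.stream("salary_demo")
print(sample_right_skew_triangular_int(rng, 6_000_000, 12_000_000))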
143
src/yc_bench/services/seed_world.py
Normal file
@@ -0,0 +1,143 @@
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timezone
from uuid import uuid4

from sqlalchemy.orm import Session

from ..config.schema import WorldConfig
from ..db.models.company import Company, CompanyPrestige, Domain
from ..db.models.employee import Employee, EmployeeSkillRate
from ..db.models.task import Task, TaskRequirement, TaskStatus

from .generate_employees import generate_employees
from .generate_tasks import generate_tasks

_ALL_DOMAINS = list(Domain)


@dataclass(frozen=True)
class SeedWorldRequest:
    run_seed: int
    company_name: str
    horizon_years: int
    employee_count: int
    market_task_count: int
    start_date: datetime | None = None
    cfg: WorldConfig = field(default_factory=WorldConfig)


@dataclass(frozen=True)
class SeedWorldResult:
    company_id: str
    seeded_at: datetime


def _seed_company(db, req):
    company = Company(
        id=uuid4(),
        name=req.company_name,
        funds_cents=req.cfg.initial_funds_cents,
    )
    db.add(company)
    db.flush()
    return company


def _seed_company_prestige(db, company, cfg):
    for domain in _ALL_DOMAINS:
        db.add(
            CompanyPrestige(
                company_id=company.id,
                domain=domain,
                prestige_level=cfg.initial_prestige_level,
            )
        )


def _seed_employees(db, company, req):
    generated = generate_employees(run_seed=req.run_seed, count=req.employee_count, cfg=req.cfg)
    for emp in generated:
        employee = Employee(
            id=uuid4(),
            company_id=company.id,
            name=emp.name,
            work_hours_per_day=emp.work_hours_per_day,
            salary_cents=emp.salary_cents,
        )
        db.add(employee)

        for domain, rate in emp.rates_by_domain.items():
            db.add(
                EmployeeSkillRate(
                    employee_id=employee.id,
                    domain=domain,
                    rate_domain_per_hour=rate,
                )
            )


def _seed_market_tasks(db, company, req):
    generated = generate_tasks(run_seed=req.run_seed, count=req.market_task_count, cfg=req.cfg)
    for task in generated:
        task_row = Task(
            id=uuid4(),
            company_id=None,
            status=TaskStatus.MARKET,
            title=task.title,
            description=task.description,
            required_prestige=task.required_prestige,
            reward_funds_cents=task.reward_funds_cents,
            reward_prestige_delta=task.reward_prestige_delta,
            skill_boost_pct=task.skill_boost_pct,
            accepted_at=None,
            deadline=None,
            completed_at=None,
            success=None,
            halfway_event_emitted=False,
        )
        db.add(task_row)

        for domain, qty in task.requirements.items():
            db.add(
                TaskRequirement(
                    task_id=task_row.id,
                    domain=domain,
                    required_qty=qty,
                    completed_qty=0,
                )
            )


def seed_world(db, req):
    if req.employee_count <= 0:
        raise ValueError("employee_count must be positive")
    if req.market_task_count <= 0:
        raise ValueError("market_task_count must be positive")

    # Fall back to "now" when no explicit start date is supplied, so the
    # result always carries a concrete timestamp (SeedWorldResult declares
    # seeded_at as datetime, not datetime | None).
    seeded_at = req.start_date if req.start_date is not None else datetime.now(timezone.utc)

    company = _seed_company(db, req)
    _seed_company_prestige(db, company, req.cfg)
    _seed_employees(db, company, req)
    _seed_market_tasks(db, company, req)

    return SeedWorldResult(
        company_id=company.id,
        seeded_at=seeded_at,
    )


def seed_world_transactional(db, req):
    result = seed_world(db, req)
    db.flush()
    return result


__all__ = [
    "SeedWorldRequest",
    "SeedWorldResult",
    "seed_world",
    "seed_world_transactional",
]
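Seeding is a single call once a SQLAlchemy session is open. A hedged sketch (SessionLocal stands in for whatever session factory the project's db package exposes; it is an assumption here, not the actual name, and the request values are arbitrary):

from yc_bench.services.seed_world import SeedWorldRequest, seed_world_transactional

req = SeedWorldRequest(
    run_seed=7,
    company_name="Acme AI",
    horizon_years=3,
    employee_count=5,
    market_task_count=30,
)

with SessionLocal() as db:  # hypothetical session factory
    result = seed_world_transactional(db, req)
    db.commit()
    print(result.company_id, result.seeded_at)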
365
src/yc_bench/services/task_catalog.py
Normal file
@@ -0,0 +1,365 @@
"""Realistic AI-startup task titles and descriptions, keyed by domain.
|
||||
|
||||
Each domain has a pool of (title, description) tuples. The generator picks
|
||||
from these deterministically using the seeded RNG, cycling if the pool is
|
||||
exhausted.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from ..db.models.company import Domain
|
||||
|
||||
TASK_POOL: dict[Domain, list[tuple[str, str]]] = {
|
||||
Domain.SYSTEM: [
|
||||
(
|
||||
"Set Up GPU-Aware K8s Cluster with Auto-Scaling",
|
||||
"Deploy a Kubernetes cluster with NVIDIA GPU operator, node auto-scaling based on inference queue depth, and spot instance fallback for training workloads.",
|
||||
),
|
||||
(
|
||||
"Build CI/CD Pipeline for ML Model Registry",
|
||||
"Create a CI pipeline that runs training validation, pushes versioned model artifacts to a registry, and auto-deploys to a staging inference endpoint.",
|
||||
),
|
||||
(
|
||||
"Implement Blue-Green Deployment for LLM Serving",
|
||||
"Set up zero-downtime model swaps for a vLLM serving cluster with automated rollback triggered by latency and error-rate thresholds.",
|
||||
),
|
||||
(
|
||||
"Deploy Observability Stack for AI Workloads",
|
||||
"Stand up Grafana, Prometheus, and OpenTelemetry with custom dashboards tracking GPU utilization, token throughput, time-to-first-token, and per-request cost.",
|
||||
),
|
||||
(
|
||||
"Terraform Multi-Region Inference Infrastructure",
|
||||
"Write IaC modules to provision inference endpoints across 3+ regions with global load balancing, failover routing, and centralized logging.",
|
||||
),
|
||||
(
|
||||
"Container Image Optimization for ML Serving",
|
||||
"Reduce Docker image sizes for PyTorch/CUDA serving containers from 15 GB to under 4 GB using multi-stage builds and distroless bases to cut cold-start times.",
|
||||
),
|
||||
(
|
||||
"Implement Secret Rotation and API Key Management",
|
||||
"Build an automated secret rotation system for API keys, database credentials, and model provider tokens across staging and production environments.",
|
||||
),
|
||||
(
|
||||
"Set Up Cost Monitoring and GPU Budget Alerts",
|
||||
"Integrate cloud billing APIs with a dashboard showing per-team GPU spend, cost-per-inference breakdowns, and automated alerts when daily spend exceeds thresholds.",
|
||||
),
|
||||
(
|
||||
"Build Canary Release Pipeline for Embedding Models",
|
||||
"Implement a canary deployment system that gradually shifts traffic to new embedding model versions, comparing retrieval quality metrics in real time.",
|
||||
),
|
||||
(
|
||||
"Migrate Inference Workloads to Serverless GPU",
|
||||
"Evaluate and migrate bursty inference workloads to serverless GPU providers, benchmarking cold-start latency against always-on instances.",
|
||||
),
|
||||
(
|
||||
"Implement Disaster Recovery for Training Checkpoints",
|
||||
"Design a cross-region checkpoint backup system with automated integrity verification, ensuring training runs can resume within 15 minutes of any single-region failure.",
|
||||
),
|
||||
(
|
||||
"Build Internal Developer Platform for ML Engineers",
|
||||
"Create a self-service portal where ML engineers can request GPU instances, spin up Jupyter environments, and launch training jobs without touching infrastructure.",
|
||||
),
|
||||
],
|
||||
Domain.RESEARCH: [
|
||||
(
|
||||
"Design Benchmark for Legal Document QA",
|
||||
"Create a benchmark suite of 2,000+ annotated legal questions across contract law and compliance, with human-expert baselines and an automated evaluation harness.",
|
||||
),
|
||||
(
|
||||
"Investigate MoE Routing for Multilingual Models",
|
||||
"Research and prototype alternative Mixture-of-Experts routing strategies that improve expert utilization for low-resource languages without degrading high-resource performance.",
|
||||
),
|
||||
(
|
||||
"Reproduce and Extend Speculative Decoding Results",
|
||||
"Replicate speculative decoding paper results on Llama-3 class models, then test novel draft model architectures that improve acceptance rates on code generation.",
|
||||
),
|
||||
(
|
||||
"Develop RAG Hallucination Detection Framework",
|
||||
"Build a systematic evaluation pipeline measuring faithfulness, relevance, and attribution accuracy for retrieval-augmented generation systems.",
|
||||
),
|
||||
(
|
||||
"Prototype LoRA Merging for Multi-Tenant Serving",
|
||||
"Research methods for dynamically composing multiple LoRA adapters at inference time, measuring quality degradation versus serving separate fine-tuned models.",
|
||||
),
|
||||
(
|
||||
"Benchmark Long-Context Retrieval Across 128K Models",
|
||||
"Systematically evaluate needle-in-a-haystack and multi-hop reasoning performance across frontier models at various context lengths with reproducible results.",
|
||||
),
|
||||
(
|
||||
"Investigate Synthetic Data Quality for Code Generation",
|
||||
"Develop automated quality scoring methods for synthetically generated code training data, correlating filter thresholds with downstream model performance.",
|
||||
),
|
||||
(
|
||||
"Research KV-Cache Compression Techniques",
|
||||
"Prototype and benchmark KV-cache eviction and quantization strategies for long-running conversational agents under fixed memory budgets.",
|
||||
),
|
||||
(
|
||||
"Build Ablation Study Framework for Prompt Engineering",
|
||||
"Create an experimentation harness for testing prompt variations across multiple models and tasks with statistical significance testing and cost tracking.",
|
||||
),
|
||||
(
|
||||
"Explore Constitutional AI for Domain-Specific Safety",
|
||||
"Adapt constitutional AI methods to create a self-improving safety filter for a healthcare chatbot, defining domain-specific principles and measuring accuracy.",
|
||||
),
|
||||
(
|
||||
"Develop Novel Chunking Strategies for Technical RAG",
|
||||
"Research and benchmark alternative document chunking methods—semantic, AST-aware, sliding window—specifically for API documentation and code repositories.",
|
||||
),
|
||||
(
|
||||
"Prototype Test-Time Compute Scaling for Math Reasoning",
|
||||
"Implement best-of-N sampling, tree search, and self-verification approaches for math reasoning, measuring the compute-accuracy Pareto frontier.",
|
||||
),
|
||||
],
|
||||
Domain.DATA: [
|
||||
(
|
||||
"Build Web Scraping Pipeline for Industry News Corpus",
|
||||
"Design a pipeline that crawls 50+ AI/tech news sources daily, deduplicates articles, extracts structured metadata, and loads clean text into a vector store.",
|
||||
),
|
||||
(
|
||||
"Create Annotation Platform for Dialogue Quality",
|
||||
"Build an annotation workflow where human raters score LLM conversation logs on helpfulness, accuracy, and safety, with inter-rater agreement tracking.",
|
||||
),
|
||||
(
|
||||
"Implement PII Detection and Redaction Pipeline",
|
||||
"Deploy a pipeline to detect and redact personally identifiable information from training data, with audit logging and configurable redaction strategies.",
|
||||
),
|
||||
(
|
||||
"Curate Instruction-Tuning Dataset from Internal Docs",
|
||||
"Extract, clean, and convert 10,000+ pages of internal documentation into high-quality instruction-response pairs suitable for fine-tuning.",
|
||||
),
|
||||
(
|
||||
"Build Data Quality Monitoring for Feature Store",
|
||||
"Implement data validation checks on streaming feature pipelines, alerting on schema drift, null-rate spikes, and distribution shifts before they affect models.",
|
||||
),
|
||||
(
|
||||
"Design ETL Pipeline for Multi-Modal Training Data",
|
||||
"Build a DAG pipeline that ingests images, PDFs, and structured data, applies OCR and layout detection, and produces unified records for vision-language training.",
|
||||
),
|
||||
(
|
||||
"Implement Deduplication for Large Text Corpora",
|
||||
"Deploy MinHash LSH-based near-deduplication at scale for 100M+ documents with configurable similarity thresholds and a review UI for borderline cases.",
|
||||
),
|
||||
(
|
||||
"Build Synthetic Data Pipeline for Rare Edge Cases",
|
||||
"Create a system that uses frontier LLMs to generate realistic synthetic examples for underrepresented categories in a classification dataset.",
|
||||
),
|
||||
(
|
||||
"Create Data Versioning and Lineage Tracking System",
|
||||
"Set up data versioning integrated with the ML training pipeline so every model checkpoint can be traced back to the exact dataset snapshot used.",
|
||||
),
|
||||
(
|
||||
"Build Customer Feedback Loop into Training Pipeline",
|
||||
"Implement a system where end-user thumbs-up/down signals are routed, reviewed, and selectively incorporated into fine-tuning datasets with human approval.",
|
||||
),
|
||||
(
|
||||
"Migrate Legacy Warehouse to ML-Ready Lakehouse",
|
||||
"Transform and migrate 5 years of product analytics data from a legacy SQL warehouse into a Parquet-based lakehouse optimized for feature engineering.",
|
||||
),
|
||||
],
|
||||
Domain.FRONTEND: [
|
||||
(
|
||||
"Build Interactive LLM Playground with Streaming",
|
||||
"Create a web app where users test multiple LLM providers side-by-side with streaming output, adjustable parameters, and conversation history persistence.",
|
||||
),
|
||||
(
|
||||
"Design Admin Dashboard for AI Agent Monitoring",
|
||||
"Build a dashboard showing real-time agent execution traces, tool call sequences, token usage graphs, and cost breakdowns with drill-down filtering.",
|
||||
),
|
||||
(
|
||||
"Create Document Chat Interface for RAG Product",
|
||||
"Implement a drag-and-drop document upload UI with a conversational interface showing source citations, confidence indicators, and reference highlighting.",
|
||||
),
|
||||
(
|
||||
"Build Annotation Review and Approval Interface",
|
||||
"Design a UI for data team leads to review annotator work, resolve disagreements, view agreement stats, and approve batches for training inclusion.",
|
||||
),
|
||||
(
|
||||
"Implement Prompt Management Studio",
|
||||
"Build a collaborative app where teams version, test, and A/B deploy prompt templates with visual diffs, rollback, and per-version performance analytics.",
|
||||
),
|
||||
(
|
||||
"Create Customer-Facing AI Usage Analytics Dashboard",
|
||||
"Build an embeddable dashboard showing API call volumes, latency percentiles, token consumption, and cost trends for enterprise customers.",
|
||||
),
|
||||
(
|
||||
"Build Visual Pipeline Editor for No-Code AI Workflows",
|
||||
"Create a node-based drag-and-drop editor where non-technical users chain data sources, LLM calls, and output actions into automated AI workflows.",
|
||||
),
|
||||
(
|
||||
"Design Chat Widget for Website Embedding",
|
||||
"Build a lightweight, brandable chat widget under 50 KB that customers embed on their sites, with streaming responses and escalation-to-human capability.",
|
||||
),
|
||||
(
|
||||
"Build Model Comparison Results Viewer",
|
||||
"Create a web interface displaying benchmark results across models in interactive tables and charts with filtering by task type and model size.",
|
||||
),
|
||||
(
|
||||
"Implement Real-Time Collaboration for AI Writing Tool",
|
||||
"Add multiplayer editing to an AI writing tool using CRDTs, with per-user cursors, AI suggestion tracking, and version history.",
|
||||
),
|
||||
(
|
||||
"Create Enterprise RAG Onboarding Wizard",
|
||||
"Build a step-by-step setup wizard guiding enterprise customers through connecting data sources, configuring chunking, testing retrieval, and deploying their endpoint.",
|
||||
),
|
||||
],
|
||||
Domain.BACKEND: [
|
||||
(
|
||||
"Build Multi-Tenant LLM Gateway with Rate Limiting",
|
||||
"Implement an API gateway that proxies requests to multiple LLM providers, enforces per-tenant rate limits, tracks usage, and handles automatic failover.",
|
||||
),
|
||||
(
|
||||
"Implement OAuth2 + SAML SSO for Enterprise Platform",
|
||||
"Add enterprise authentication supporting SAML 2.0, OIDC, and SCIM provisioning for customers integrating with their identity provider.",
|
||||
),
|
||||
(
|
||||
"Design Webhook System for Async AI Job Completion",
|
||||
"Build a reliable webhook delivery system with exponential backoff, signature verification, dead letter queue, and a webhook management API.",
|
||||
),
|
||||
(
|
||||
"Create Unified Embedding API with Caching Layer",
|
||||
"Build a microservice abstracting over multiple embedding providers with a Redis-backed cache, batch processing, and automatic model version migration.",
|
||||
),
|
||||
(
|
||||
"Build Conversation Memory Service for Multi-Session Agents",
|
||||
"Implement a service that stores, summarizes, and retrieves conversation history across sessions using structured storage and semantic vector search.",
|
||||
),
|
||||
(
|
||||
"Implement Usage-Based Billing with Stripe Integration",
|
||||
"Build a metering system that tracks token consumption per customer, aggregates monthly invoices, and syncs with Stripe for automated usage-based charging.",
|
||||
),
|
||||
(
|
||||
"Create Plugin Marketplace Backend",
|
||||
"Design the API and data model for a marketplace where third-party developers register, version, and distribute plugins for the AI platform.",
|
||||
),
|
||||
(
|
||||
"Build RAG Ingestion Service with Chunking and Indexing",
|
||||
"Implement an async document processing service that accepts PDFs, DOCX, and HTML, chunks them, generates embeddings, and upserts into a vector store.",
|
||||
),
|
||||
(
|
||||
"Implement Audit Logging and Compliance API",
|
||||
"Build a tamper-evident audit log system recording all AI interactions and admin actions, with an API for compliance queries and SOC 2 / HIPAA exports.",
|
||||
),
|
||||
(
|
||||
"Design Multi-Model Routing and Fallback Service",
|
||||
"Create a smart routing layer directing requests to the optimal model based on task complexity, latency requirements, and cost, with provider failover.",
|
||||
),
|
||||
(
|
||||
"Build File Processing Service for Vision-Language Models",
|
||||
"Implement an async service that accepts images and documents, runs them through vision-language models for extraction, and returns structured JSON output.",
|
||||
),
|
||||
(
|
||||
"Implement Streaming API with Server-Sent Events",
|
||||
"Build an SSE-based streaming endpoint for LLM responses with connection resumption, partial response caching, and graceful degradation.",
|
||||
),
|
||||
],
|
||||
Domain.TRAINING: [
|
||||
(
|
||||
"Fine-Tune Llama-3 8B for Domain-Specific Support",
|
||||
"Run supervised fine-tuning on 50K curated customer support conversations using QLoRA, targeting 15% accuracy improvement over the base model.",
|
||||
),
|
||||
(
|
||||
"Implement RLHF Pipeline for Code Generation Model",
|
||||
"Build an end-to-end RLHF pipeline with a reward model trained on human preference data and PPO training loop evaluated against HumanEval.",
|
||||
),
|
||||
(
|
||||
"Distill GPT-4 Class Model into Efficient 3B Model",
|
||||
"Use knowledge distillation with synthetic data to create a compact model retaining 90%+ teacher performance on targeted tasks at 10x lower inference cost.",
|
||||
),
|
||||
(
|
||||
"Train Custom Embedding Model for Vertical Search",
|
||||
"Fine-tune a sentence-transformers model on domain-specific query-document pairs with contrastive learning, hard negative mining, and retrieval benchmarks.",
|
||||
),
|
||||
(
|
||||
"Build Hyperparameter Search for Fine-Tuning Jobs",
|
||||
"Implement an Optuna-based HPO system searching over learning rate, LoRA rank, batch size, and data mixing ratios with early stopping.",
|
||||
),
|
||||
(
|
||||
"Run Continued Pre-Training on Proprietary Corpus",
|
||||
"Execute continued pre-training of a 7B base model on 10B tokens of domain-specific text with careful learning rate scheduling to avoid catastrophic forgetting.",
|
||||
),
|
||||
(
|
||||
"Train Reward Model from Preference Annotations",
|
||||
"Collect and process 20K pairwise preference annotations, train a Bradley-Terry reward model, and validate calibration against held-out human judgments.",
|
||||
),
|
||||
(
|
||||
"Build Multi-GPU Training Infra with DeepSpeed",
|
||||
"Set up distributed training using DeepSpeed ZeRO Stage 3 across an 8-node GPU cluster with checkpoint sharding and fault-tolerant resumption.",
|
||||
),
|
||||
(
|
||||
"Implement DPO Fine-Tuning Pipeline",
|
||||
"Build a Direct Preference Optimization pipeline as a simpler RLHF alternative, comparing quality and training stability on the same preference dataset.",
|
||||
),
|
||||
(
|
||||
"Train Vision-Language Adapter for Document Understanding",
|
||||
"Fine-tune a LoRA adapter on a VLM for extracting structured data from invoices, receipts, and forms with 95%+ field-level accuracy.",
|
||||
),
|
||||
(
|
||||
"Build Eval-Driven Training Loop with Auto Checkpointing",
|
||||
"Implement a training harness that runs benchmarks every N steps, auto-saves the best checkpoint, detects instability, and alerts on loss spikes.",
|
||||
),
|
||||
(
|
||||
"Fine-Tune Whisper for Industry-Specific Transcription",
|
||||
"Adapt Whisper-large for medical dictation using 500 hours of labeled audio, targeting 30% WER reduction on domain-specific terminology.",
|
||||
),
|
||||
],
|
||||
Domain.HARDWARE: [
|
||||
(
|
||||
"Optimize LLM Inference Latency with TensorRT-LLM",
|
||||
"Convert a 70B model to TensorRT-LLM with INT8/FP8 quantization, continuous batching, and paged attention, targeting sub-200ms time-to-first-token.",
|
||||
),
|
||||
(
|
||||
"Deploy On-Device ML Model for Mobile Classification",
|
||||
"Convert a PyTorch vision model to Core ML and TFLite, optimize with quantization-aware training, and benchmark on iPhone and Pixel hardware.",
|
||||
),
|
||||
(
|
||||
"Build GPU Cluster Scheduling with Fair-Share Queuing",
|
||||
"Implement a scheduler for a shared GPU cluster enforcing per-team quotas, priority queuing, preemption policies, and utilization-based chargeback.",
|
||||
),
|
||||
(
|
||||
"Implement Quantization Pipeline (GPTQ/AWQ/GGUF)",
|
||||
"Build an automated pipeline that takes any model, produces GPTQ, AWQ, and GGUF quantized variants, runs quality regression, and publishes passing models.",
|
||||
),
|
||||
(
|
||||
"Deploy Edge Inference for Real-Time Video Analytics",
|
||||
"Set up an NVIDIA Jetson-based inference node running YOLO and a lightweight LLM for on-premises real-time camera analysis with local data processing.",
|
||||
),
|
||||
(
|
||||
"Optimize vLLM Serving for Production Workload",
|
||||
"Profile and tune vLLM parameters—max batch size, KV cache, swap space, tensor parallelism—for target throughput at P99 latency SLA.",
|
||||
),
|
||||
(
|
||||
"Build Multi-GPU Inference with Tensor Parallelism",
|
||||
"Configure and benchmark a 70B+ model serving across 4-8 GPUs with tensor and pipeline parallelism, optimizing throughput versus latency tradeoffs.",
|
||||
),
|
||||
(
|
||||
"Implement Dynamic Batching for Inference Requests",
|
||||
"Build a request batching layer that groups incoming requests by sequence length and priority, maximizing GPU utilization within per-request latency SLAs.",
|
||||
),
|
||||
(
|
||||
"Design Hybrid CPU/GPU Inference Architecture",
|
||||
"Architect a system routing lightweight requests to CPU inference and complex requests to GPU instances, reducing overall compute cost by 40%.",
|
||||
),
|
||||
(
|
||||
"Set Up Triton Inference Server for Multi-Model Serving",
|
||||
"Deploy NVIDIA Triton to serve embedding, reranking, and generation models on shared GPU infrastructure with dynamic batching and concurrency control.",
|
||||
),
|
||||
(
|
||||
"Build GPU Health Monitoring and Failover System",
|
||||
"Implement a daemon detecting GPU memory errors, thermal throttling, and NVLink degradation, automatically draining affected nodes and redistributing workloads.",
|
||||
),
|
||||
(
|
||||
"Benchmark Specialized AI Accelerators vs H100",
|
||||
"Evaluate Groq, Cerebras, and custom ASICs against H100 GPUs, producing a cost-per-token and latency comparison with a migration recommendation.",
|
||||
),
|
||||
(
|
||||
"Implement Speculative Decoding in Production Stack",
|
||||
"Integrate speculative decoding with a small draft model into the existing serving infrastructure, measuring real-world throughput improvement.",
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def pick_task_text(rng, domain: Domain) -> tuple[str, str]:
|
||||
"""Deterministically pick a (title, description) for *domain* using *rng*."""
|
||||
pool = TASK_POOL[domain]
|
||||
idx = rng.randint(0, len(pool) - 1)
|
||||
return pool[idx]
|
||||
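Because pick_task_text consumes the caller's RNG, the same stream position always yields the same catalog entry. A quick sketch using only the names defined above (the fixed seed is arbitrary):

import random

from yc_bench.db.models.company import Domain
from yc_bench.services.task_catalog import pick_task_text

rng = random.Random(0)
title, description = pick_task_text(rng, Domain.BACKEND)
print(title)
print(description)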