diff --git a/src/yc_bench/db/base.py b/src/yc_bench/db/base.py new file mode 100644 index 0000000..e493d0b --- /dev/null +++ b/src/yc_bench/db/base.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from sqlalchemy.orm import DeclarativeBase + +class Base(DeclarativeBase): + pass + +class TimestampMixin: + pass + +__all__ = ["Base", "TimestampMixin"] \ No newline at end of file diff --git a/src/yc_bench/db/models/company.py b/src/yc_bench/db/models/company.py new file mode 100644 index 0000000..d6a6db1 --- /dev/null +++ b/src/yc_bench/db/models/company.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from decimal import Decimal +from uuid import uuid4 +from enum import Enum + +from sqlalchemy import BigInteger, CheckConstraint, Enum as SAEnum, ForeignKey, Numeric, String, Uuid +from sqlalchemy.orm import mapped_column + +from ..base import Base + +class Domain(str, Enum): + SYSTEM = "system" + RESEARCH = "research" + DATA = "data" + FRONTEND = "frontend" + BACKEND = "backend" + TRAINING = "training" + HARDWARE = "hardware" + +class Company(Base): + __tablename__ = "companies" + + id = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid4, + ) + name = mapped_column( + String(255), + nullable=False, + ) + funds_cents = mapped_column( + BigInteger, + nullable=False, + ) + +class CompanyPrestige(Base): + __tablename__ = "company_prestige" + __table_args__ = ( + CheckConstraint("prestige_level >= 1 AND prestige_level <= 10", name="ck_company_prestige_prestige_level_range"), + ) + + company_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("companies.id", ondelete="CASCADE"), + primary_key=True, + nullable=False, + ) + domain = mapped_column( + SAEnum(Domain, name="domain", values_callable=lambda e: [x.value for x in e]), + primary_key=True, + nullable=False, + ) + prestige_level = mapped_column( + Numeric(6, 3), + nullable=False, + default=Decimal("1.000"), + ) + +__all__ = ["Domain", "Company", "CompanyPrestige"] diff --git a/src/yc_bench/db/models/employee.py b/src/yc_bench/db/models/employee.py new file mode 100644 index 0000000..fb663f2 --- /dev/null +++ b/src/yc_bench/db/models/employee.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from uuid import uuid4 + +from decimal import Decimal +from sqlalchemy import BigInteger, CheckConstraint, Enum as SAEnum, ForeignKey, Numeric, String, Uuid +from sqlalchemy.orm import mapped_column + +from ..base import Base +from .company import Domain + +class Employee(Base): + __tablename__ = "employees" + __table_args__ = ( + CheckConstraint("work_hours_per_day > 0", name="ck_employees_work_hours_per_day_gt_0"), + CheckConstraint("salary_cents >= 0", name="ck_employees_salary_cents_gte_0"), + ) + + id = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid4, + ) + company_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("companies.id", ondelete="CASCADE"), + nullable=False, + ) + name = mapped_column( + String(255), + nullable=False, + ) + work_hours_per_day = mapped_column( + Numeric(5, 2), + nullable=False, + default=Decimal("9.00"), + ) + salary_cents = mapped_column( + BigInteger, + nullable=False, + ) + +class EmployeeSkillRate(Base): + __tablename__ = "employee_skill_rates" + __table_args__ = ( + CheckConstraint("rate_domain_per_hour >= 0", name="ck_employee_skill_rates_rate_domain_per_hour_gte_0"), + ) + + employee_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("employees.id", ondelete="CASCADE"), + primary_key=True, + nullable=False, + ) + domain = mapped_column( + SAEnum(Domain, name="domain", values_callable=lambda e: [x.value for x in e]), + primary_key=True, + nullable=False, + ) + rate_domain_per_hour = mapped_column( + Numeric(12, 4), + nullable=False, + default=Decimal("1.0000"), + ) + +__all__ = ["Employee", "EmployeeSkillRate"] diff --git a/src/yc_bench/db/models/event.py b/src/yc_bench/db/models/event.py new file mode 100644 index 0000000..1fba9e4 --- /dev/null +++ b/src/yc_bench/db/models/event.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from enum import Enum +from uuid import uuid4 + +from sqlalchemy import Boolean, DateTime, Enum as SAEnum, ForeignKey, JSON, String, Uuid, text +from sqlalchemy.orm import mapped_column + +from ..base import Base + +class EventType(str, Enum): + TASK_HALF_PROGRESS = "task_half_progress" + TASK_COMPLETED = "task_completed" + BANKRUPTCY = "bankruptcy" + HORIZON_END = "horizon_end" + +class SimEvent(Base): + __tablename__ = "sim_events" + + id = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid4, + ) + company_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("companies.id", ondelete="CASCADE"), + nullable=False, + ) + event_type = mapped_column( + SAEnum(EventType, name="event_type", values_callable=lambda e: [x.value for x in e]), + nullable=False, + ) + scheduled_at = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + payload = mapped_column( + JSON, + nullable=False, + ) + dedupe_key = mapped_column( # guardrail against duplicates in event recomputation + String(255), + nullable=True, + ) + consumed = mapped_column( + Boolean, + nullable=False, + default=False, + server_default=text("false"), + ) + +__all__ = ["EventType", "SimEvent"] diff --git a/src/yc_bench/db/models/ledger.py b/src/yc_bench/db/models/ledger.py new file mode 100644 index 0000000..7103f56 --- /dev/null +++ b/src/yc_bench/db/models/ledger.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from enum import Enum +from uuid import uuid4 + +from sqlalchemy import BigInteger, DateTime, Enum as SAEnum, ForeignKey, String, Uuid +from sqlalchemy.orm import mapped_column + +from ..base import Base + +class LedgerCategory(str, Enum): + MONTHLY_PAYROLL = "monthly_payroll" + TASK_REWARD = "task_reward" + TASK_FAIL_PENALTY = "task_fail_penalty" + TASK_CANCEL_PENALTY = "task_cancel_penalty" + +class LedgerEntry(Base): + __tablename__ = "ledger_entries" + + id = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid4, + ) + company_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("companies.id", ondelete="CASCADE"), + nullable=False, + ) + occurred_at = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + category = mapped_column( + SAEnum(LedgerCategory, name="ledger_category", values_callable=lambda e: [x.value for x in e]), + nullable=False, + ) + amount_cents = mapped_column( + BigInteger, + nullable=False, + ) + ref_type = mapped_column( + String(64), + nullable=True, + ) + ref_id = mapped_column( # can refer multiple tables, no foreign key + Uuid(as_uuid=True), + nullable=True, + ) + +__all__ = ["LedgerCategory", "LedgerEntry"] \ No newline at end of file diff --git a/src/yc_bench/db/models/scratchpad.py b/src/yc_bench/db/models/scratchpad.py new file mode 100644 index 0000000..39e6697 --- /dev/null +++ b/src/yc_bench/db/models/scratchpad.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from uuid import uuid4 + +from sqlalchemy import ForeignKey, Text, Uuid +from sqlalchemy.orm import mapped_column + +from ..base import Base + + +class Scratchpad(Base): + __tablename__ = "scratchpads" + + company_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("companies.id", ondelete="CASCADE"), + primary_key=True, + nullable=False, + ) + content = mapped_column( + Text, + nullable=False, + default="", + ) + + +__all__ = ["Scratchpad"] diff --git a/src/yc_bench/db/models/session.py b/src/yc_bench/db/models/session.py new file mode 100644 index 0000000..497f3d4 --- /dev/null +++ b/src/yc_bench/db/models/session.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from uuid import uuid4 + +from sqlalchemy import BigInteger, DateTime, ForeignKey, String, Uuid, Date, Enum as SAEnum +from sqlalchemy.orm import mapped_column + +from ..base import Base +from .event import EventType + +class Session(Base): + __tablename__ = "sessions" + + id = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid4, + ) + company_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("companies.id", ondelete="CASCADE"), + nullable=False, + ) + started_at = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + ended_at = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + wake_reason = mapped_column( + SAEnum(EventType, name="event_type", values_callable=lambda e: [x.value for x in e]), + nullable=False, + ) + +class MonthlyMetric(Base): + __tablename__ = "monthly_metrics" + + company_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("companies.id", ondelete="CASCADE"), + primary_key=True, + nullable=False, + ) + month_start = mapped_column( + Date, + primary_key=True, + nullable=False, + ) + revenue_cents = mapped_column( + BigInteger, + nullable=False, + ) + cost_cents = mapped_column( + BigInteger, + nullable=False, + ) + return_cents = mapped_column( + BigInteger, + nullable=False, + ) + ending_funds_cents = mapped_column( + BigInteger, + nullable=False, + ) + +__all__ = ["Session", "MonthlyMetric"] \ No newline at end of file diff --git a/src/yc_bench/db/models/sim_state.py b/src/yc_bench/db/models/sim_state.py new file mode 100644 index 0000000..d6e9e54 --- /dev/null +++ b/src/yc_bench/db/models/sim_state.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, Uuid +from sqlalchemy.orm import mapped_column + +from ..base import Base + + +class SimState(Base): + __tablename__ = "sim_state" + + company_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("companies.id", ondelete="CASCADE"), + primary_key=True, + nullable=False, + ) + sim_time = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + run_seed = mapped_column( + Integer, + nullable=False, + ) + horizon_end = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + replenish_counter = mapped_column( + Integer, + nullable=False, + default=0, + server_default="0", + ) + +__all__ = ["SimState"] diff --git a/src/yc_bench/db/models/task.py b/src/yc_bench/db/models/task.py new file mode 100644 index 0000000..26deda2 --- /dev/null +++ b/src/yc_bench/db/models/task.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from enum import Enum +from uuid import uuid4 + +from sqlalchemy import BigInteger, Boolean, CheckConstraint, DateTime, Enum as SAEnum, ForeignKey, Integer, Numeric, String, Uuid, text +from sqlalchemy.orm import mapped_column + +from ..base import Base +from .company import Domain + +class TaskStatus(str, Enum): + MARKET = "market" + PLANNED = "planned" + ACTIVE = "active" + COMPLETED_SUCCESS = "completed_success" + COMPLETED_FAIL = "completed_fail" + CANCELLED = "cancelled" + +class Task(Base): + __tablename__ = "tasks" + __table_args__ = ( + CheckConstraint("required_prestige >= 1 AND required_prestige <= 10", name="ck_tasks_required_prestige_range"), + CheckConstraint("skill_boost_pct >= 0", name="ck_tasks_skill_boost_pct_range"), + CheckConstraint("reward_funds_cents >= 0", name="ck_tasks_reward_funds_cents_gte_0"), + CheckConstraint("reward_prestige_delta >= 0 AND reward_prestige_delta <= 5", name="ck_tasks_reward_prestige_delta_range"), + ) + + id = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid4, + ) + company_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("companies.id", ondelete="CASCADE"), + nullable=True, + ) + status = mapped_column( + SAEnum(TaskStatus, name="task_status", values_callable=lambda e: [x.value for x in e]), + nullable=False, + default=TaskStatus.MARKET, + ) + title = mapped_column( + String(255), + nullable=False, + ) + description = mapped_column( + String, + nullable=False, + ) + required_prestige = mapped_column( + Integer, + nullable=False, + ) + reward_funds_cents = mapped_column( + BigInteger, + nullable=False, + ) + reward_prestige_delta = mapped_column( + Numeric(6, 3), + nullable=False, + ) + skill_boost_pct = mapped_column( + Numeric(6, 4), + nullable=False, + ) + accepted_at = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + deadline = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + completed_at = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + success = mapped_column( + Boolean, + nullable=True, + ) + halfway_event_emitted = mapped_column( + Boolean, + nullable=False, + default=False, + server_default=text("false"), + ) + +class TaskRequirement(Base): + __tablename__ = "task_requirements" + __table_args__ = ( + CheckConstraint("required_qty >= 200 AND required_qty <= 3000", name="ck_task_requirements_required_qty_range"), + CheckConstraint("completed_qty >= 0", name="ck_task_requirements_completed_qty_gte_0"), + CheckConstraint("completed_qty <= required_qty", name="ck_task_requirements_completed_qty_lte_required_qty"), + ) + + task_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("tasks.id", ondelete="CASCADE"), + primary_key=True, + nullable=False, + ) + domain = mapped_column( + SAEnum(Domain, name="domain", values_callable=lambda e: [x.value for x in e]), + primary_key=True, + nullable=False, + ) + required_qty = mapped_column( + Numeric(14, 4), + nullable=False, + ) + completed_qty = mapped_column( + Numeric(14, 4), + nullable=False, + default=0, + ) + +class TaskAssignment(Base): + __tablename__ = "task_assignments" + + task_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("tasks.id", ondelete="CASCADE"), + primary_key=True, + nullable=False, + ) + employee_id = mapped_column( + Uuid(as_uuid=True), + ForeignKey("employees.id", ondelete="CASCADE"), + primary_key=True, + nullable=False, + ) + assigned_at = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + +__all__ = ["TaskStatus", "Task", "TaskRequirement", "TaskAssignment"] diff --git a/src/yc_bench/db/session.py b/src/yc_bench/db/session.py new file mode 100644 index 0000000..bffb061 --- /dev/null +++ b/src/yc_bench/db/session.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import os +from contextlib import contextmanager + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + + +def _get_database_url() -> str: + default = "sqlite:///db/yc_bench.db" + url = os.environ.get("DATABASE_URL", default) + # Auto-create parent directory for sqlite:/// relative paths + if url.startswith("sqlite:///") and not url.startswith("sqlite:////"): + import pathlib + db_path = pathlib.Path(url[len("sqlite:///"):]) + db_path.parent.mkdir(parents=True, exist_ok=True) + return url + + +def _maybe_register_psycopg_enum_dumpers(url: str) -> None: + """Register psycopg3 enum value dumpers — only needed for PostgreSQL. + + psycopg3 sends enum .name (uppercase) by default; the DB stores .value + (lowercase). This adapter fixes that mismatch. SQLite doesn't need it + because SQLAlchemy stores enums as plain VARCHAR using .value directly. + """ + if not url.startswith("postgresql"): + return + + import enum + import psycopg + from psycopg.adapt import Dumper + from psycopg.pq import Format + + class _EnumValueDumper(Dumper): + format = Format.TEXT + + def dump(self, obj): + return obj.value.encode("utf-8") if isinstance(obj, enum.Enum) else str(obj).encode("utf-8") + + from .models.company import Domain + from .models.event import EventType + from .models.ledger import LedgerCategory + from .models.task import TaskStatus + + for cls in (Domain, EventType, LedgerCategory, TaskStatus): + psycopg.adapters.register_dumper(cls, _EnumValueDumper) + + +def build_engine(url: str | None = None): + db_url = url or _get_database_url() + _maybe_register_psycopg_enum_dumpers(db_url) + + kwargs: dict = {"echo": False, "future": True} + if db_url.startswith("sqlite"): + # Required for SQLAlchemy's connection pool when used across threads + kwargs["connect_args"] = {"check_same_thread": False} + + return create_engine(db_url, **kwargs) + + +def build_session_factory(engine): + return sessionmaker( + bind=engine, + autoflush=False, + autocommit=False, + expire_on_commit=False, + class_=Session, + ) + + +@contextmanager +def session_scope(session_factory): + session = session_factory() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + +def init_db(engine) -> None: + """Create all tables that do not yet exist. Safe to call on every startup.""" + from .base import Base + # Import all models so SQLAlchemy registers them with Base.metadata before create_all. + from .models import company, employee, event, ledger, scratchpad, session, sim_state, task # noqa: F401 + Base.metadata.create_all(engine) + + +__all__ = ["build_engine", "build_session_factory", "session_scope", "init_db"]