Add db/ source files that were blocked by overly broad gitignore

The old `db/` pattern in .gitignore matched src/yc_bench/db/ too,
preventing all ORM models and session.py from being committed.
Previous commit fixed .gitignore to `/db/`; this adds the 10 missing files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
adit jain 2026-02-26 21:19:45 -08:00
parent a11b2828a9
commit db7d9f218a
10 changed files with 608 additions and 0 deletions

11
src/yc_bench/db/base.py Normal file
View file

@ -0,0 +1,11 @@
from __future__ import annotations
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass
class TimestampMixin:
pass
__all__ = ["Base", "TimestampMixin"]

View file

@ -0,0 +1,61 @@
from __future__ import annotations
from decimal import Decimal
from uuid import uuid4
from enum import Enum
from sqlalchemy import BigInteger, CheckConstraint, Enum as SAEnum, ForeignKey, Numeric, String, Uuid
from sqlalchemy.orm import mapped_column
from ..base import Base
class Domain(str, Enum):
SYSTEM = "system"
RESEARCH = "research"
DATA = "data"
FRONTEND = "frontend"
BACKEND = "backend"
TRAINING = "training"
HARDWARE = "hardware"
class Company(Base):
__tablename__ = "companies"
id = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid4,
)
name = mapped_column(
String(255),
nullable=False,
)
funds_cents = mapped_column(
BigInteger,
nullable=False,
)
class CompanyPrestige(Base):
__tablename__ = "company_prestige"
__table_args__ = (
CheckConstraint("prestige_level >= 1 AND prestige_level <= 10", name="ck_company_prestige_prestige_level_range"),
)
company_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("companies.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
)
domain = mapped_column(
SAEnum(Domain, name="domain", values_callable=lambda e: [x.value for x in e]),
primary_key=True,
nullable=False,
)
prestige_level = mapped_column(
Numeric(6, 3),
nullable=False,
default=Decimal("1.000"),
)
__all__ = ["Domain", "Company", "CompanyPrestige"]

View file

@ -0,0 +1,66 @@
from __future__ import annotations
from uuid import uuid4
from decimal import Decimal
from sqlalchemy import BigInteger, CheckConstraint, Enum as SAEnum, ForeignKey, Numeric, String, Uuid
from sqlalchemy.orm import mapped_column
from ..base import Base
from .company import Domain
class Employee(Base):
__tablename__ = "employees"
__table_args__ = (
CheckConstraint("work_hours_per_day > 0", name="ck_employees_work_hours_per_day_gt_0"),
CheckConstraint("salary_cents >= 0", name="ck_employees_salary_cents_gte_0"),
)
id = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid4,
)
company_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("companies.id", ondelete="CASCADE"),
nullable=False,
)
name = mapped_column(
String(255),
nullable=False,
)
work_hours_per_day = mapped_column(
Numeric(5, 2),
nullable=False,
default=Decimal("9.00"),
)
salary_cents = mapped_column(
BigInteger,
nullable=False,
)
class EmployeeSkillRate(Base):
__tablename__ = "employee_skill_rates"
__table_args__ = (
CheckConstraint("rate_domain_per_hour >= 0", name="ck_employee_skill_rates_rate_domain_per_hour_gte_0"),
)
employee_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("employees.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
)
domain = mapped_column(
SAEnum(Domain, name="domain", values_callable=lambda e: [x.value for x in e]),
primary_key=True,
nullable=False,
)
rate_domain_per_hour = mapped_column(
Numeric(12, 4),
nullable=False,
default=Decimal("1.0000"),
)
__all__ = ["Employee", "EmployeeSkillRate"]

View file

@ -0,0 +1,53 @@
from __future__ import annotations
from enum import Enum
from uuid import uuid4
from sqlalchemy import Boolean, DateTime, Enum as SAEnum, ForeignKey, JSON, String, Uuid, text
from sqlalchemy.orm import mapped_column
from ..base import Base
class EventType(str, Enum):
TASK_HALF_PROGRESS = "task_half_progress"
TASK_COMPLETED = "task_completed"
BANKRUPTCY = "bankruptcy"
HORIZON_END = "horizon_end"
class SimEvent(Base):
__tablename__ = "sim_events"
id = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid4,
)
company_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("companies.id", ondelete="CASCADE"),
nullable=False,
)
event_type = mapped_column(
SAEnum(EventType, name="event_type", values_callable=lambda e: [x.value for x in e]),
nullable=False,
)
scheduled_at = mapped_column(
DateTime(timezone=True),
nullable=False,
)
payload = mapped_column(
JSON,
nullable=False,
)
dedupe_key = mapped_column( # guardrail against duplicates in event recomputation
String(255),
nullable=True,
)
consumed = mapped_column(
Boolean,
nullable=False,
default=False,
server_default=text("false"),
)
__all__ = ["EventType", "SimEvent"]

View file

@ -0,0 +1,51 @@
from __future__ import annotations
from enum import Enum
from uuid import uuid4
from sqlalchemy import BigInteger, DateTime, Enum as SAEnum, ForeignKey, String, Uuid
from sqlalchemy.orm import mapped_column
from ..base import Base
class LedgerCategory(str, Enum):
MONTHLY_PAYROLL = "monthly_payroll"
TASK_REWARD = "task_reward"
TASK_FAIL_PENALTY = "task_fail_penalty"
TASK_CANCEL_PENALTY = "task_cancel_penalty"
class LedgerEntry(Base):
__tablename__ = "ledger_entries"
id = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid4,
)
company_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("companies.id", ondelete="CASCADE"),
nullable=False,
)
occurred_at = mapped_column(
DateTime(timezone=True),
nullable=False,
)
category = mapped_column(
SAEnum(LedgerCategory, name="ledger_category", values_callable=lambda e: [x.value for x in e]),
nullable=False,
)
amount_cents = mapped_column(
BigInteger,
nullable=False,
)
ref_type = mapped_column(
String(64),
nullable=True,
)
ref_id = mapped_column( # can refer multiple tables, no foreign key
Uuid(as_uuid=True),
nullable=True,
)
__all__ = ["LedgerCategory", "LedgerEntry"]

View file

@ -0,0 +1,27 @@
from __future__ import annotations
from uuid import uuid4
from sqlalchemy import ForeignKey, Text, Uuid
from sqlalchemy.orm import mapped_column
from ..base import Base
class Scratchpad(Base):
__tablename__ = "scratchpads"
company_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("companies.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
)
content = mapped_column(
Text,
nullable=False,
default="",
)
__all__ = ["Scratchpad"]

View file

@ -0,0 +1,68 @@
from __future__ import annotations
from uuid import uuid4
from sqlalchemy import BigInteger, DateTime, ForeignKey, String, Uuid, Date, Enum as SAEnum
from sqlalchemy.orm import mapped_column
from ..base import Base
from .event import EventType
class Session(Base):
__tablename__ = "sessions"
id = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid4,
)
company_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("companies.id", ondelete="CASCADE"),
nullable=False,
)
started_at = mapped_column(
DateTime(timezone=True),
nullable=False,
)
ended_at = mapped_column(
DateTime(timezone=True),
nullable=True,
)
wake_reason = mapped_column(
SAEnum(EventType, name="event_type", values_callable=lambda e: [x.value for x in e]),
nullable=False,
)
class MonthlyMetric(Base):
__tablename__ = "monthly_metrics"
company_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("companies.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
)
month_start = mapped_column(
Date,
primary_key=True,
nullable=False,
)
revenue_cents = mapped_column(
BigInteger,
nullable=False,
)
cost_cents = mapped_column(
BigInteger,
nullable=False,
)
return_cents = mapped_column(
BigInteger,
nullable=False,
)
ending_funds_cents = mapped_column(
BigInteger,
nullable=False,
)
__all__ = ["Session", "MonthlyMetric"]

View file

@ -0,0 +1,37 @@
from __future__ import annotations
from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, Uuid
from sqlalchemy.orm import mapped_column
from ..base import Base
class SimState(Base):
__tablename__ = "sim_state"
company_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("companies.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
)
sim_time = mapped_column(
DateTime(timezone=True),
nullable=False,
)
run_seed = mapped_column(
Integer,
nullable=False,
)
horizon_end = mapped_column(
DateTime(timezone=True),
nullable=False,
)
replenish_counter = mapped_column(
Integer,
nullable=False,
default=0,
server_default="0",
)
__all__ = ["SimState"]

View file

@ -0,0 +1,140 @@
from __future__ import annotations
from enum import Enum
from uuid import uuid4
from sqlalchemy import BigInteger, Boolean, CheckConstraint, DateTime, Enum as SAEnum, ForeignKey, Integer, Numeric, String, Uuid, text
from sqlalchemy.orm import mapped_column
from ..base import Base
from .company import Domain
class TaskStatus(str, Enum):
MARKET = "market"
PLANNED = "planned"
ACTIVE = "active"
COMPLETED_SUCCESS = "completed_success"
COMPLETED_FAIL = "completed_fail"
CANCELLED = "cancelled"
class Task(Base):
__tablename__ = "tasks"
__table_args__ = (
CheckConstraint("required_prestige >= 1 AND required_prestige <= 10", name="ck_tasks_required_prestige_range"),
CheckConstraint("skill_boost_pct >= 0", name="ck_tasks_skill_boost_pct_range"),
CheckConstraint("reward_funds_cents >= 0", name="ck_tasks_reward_funds_cents_gte_0"),
CheckConstraint("reward_prestige_delta >= 0 AND reward_prestige_delta <= 5", name="ck_tasks_reward_prestige_delta_range"),
)
id = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid4,
)
company_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("companies.id", ondelete="CASCADE"),
nullable=True,
)
status = mapped_column(
SAEnum(TaskStatus, name="task_status", values_callable=lambda e: [x.value for x in e]),
nullable=False,
default=TaskStatus.MARKET,
)
title = mapped_column(
String(255),
nullable=False,
)
description = mapped_column(
String,
nullable=False,
)
required_prestige = mapped_column(
Integer,
nullable=False,
)
reward_funds_cents = mapped_column(
BigInteger,
nullable=False,
)
reward_prestige_delta = mapped_column(
Numeric(6, 3),
nullable=False,
)
skill_boost_pct = mapped_column(
Numeric(6, 4),
nullable=False,
)
accepted_at = mapped_column(
DateTime(timezone=True),
nullable=True,
)
deadline = mapped_column(
DateTime(timezone=True),
nullable=True,
)
completed_at = mapped_column(
DateTime(timezone=True),
nullable=True,
)
success = mapped_column(
Boolean,
nullable=True,
)
halfway_event_emitted = mapped_column(
Boolean,
nullable=False,
default=False,
server_default=text("false"),
)
class TaskRequirement(Base):
__tablename__ = "task_requirements"
__table_args__ = (
CheckConstraint("required_qty >= 200 AND required_qty <= 3000", name="ck_task_requirements_required_qty_range"),
CheckConstraint("completed_qty >= 0", name="ck_task_requirements_completed_qty_gte_0"),
CheckConstraint("completed_qty <= required_qty", name="ck_task_requirements_completed_qty_lte_required_qty"),
)
task_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("tasks.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
)
domain = mapped_column(
SAEnum(Domain, name="domain", values_callable=lambda e: [x.value for x in e]),
primary_key=True,
nullable=False,
)
required_qty = mapped_column(
Numeric(14, 4),
nullable=False,
)
completed_qty = mapped_column(
Numeric(14, 4),
nullable=False,
default=0,
)
class TaskAssignment(Base):
__tablename__ = "task_assignments"
task_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("tasks.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
)
employee_id = mapped_column(
Uuid(as_uuid=True),
ForeignKey("employees.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
)
assigned_at = mapped_column(
DateTime(timezone=True),
nullable=False,
)
__all__ = ["TaskStatus", "Task", "TaskRequirement", "TaskAssignment"]

View file

@ -0,0 +1,94 @@
from __future__ import annotations
import os
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
def _get_database_url() -> str:
default = "sqlite:///db/yc_bench.db"
url = os.environ.get("DATABASE_URL", default)
# Auto-create parent directory for sqlite:/// relative paths
if url.startswith("sqlite:///") and not url.startswith("sqlite:////"):
import pathlib
db_path = pathlib.Path(url[len("sqlite:///"):])
db_path.parent.mkdir(parents=True, exist_ok=True)
return url
def _maybe_register_psycopg_enum_dumpers(url: str) -> None:
"""Register psycopg3 enum value dumpers — only needed for PostgreSQL.
psycopg3 sends enum .name (uppercase) by default; the DB stores .value
(lowercase). This adapter fixes that mismatch. SQLite doesn't need it
because SQLAlchemy stores enums as plain VARCHAR using .value directly.
"""
if not url.startswith("postgresql"):
return
import enum
import psycopg
from psycopg.adapt import Dumper
from psycopg.pq import Format
class _EnumValueDumper(Dumper):
format = Format.TEXT
def dump(self, obj):
return obj.value.encode("utf-8") if isinstance(obj, enum.Enum) else str(obj).encode("utf-8")
from .models.company import Domain
from .models.event import EventType
from .models.ledger import LedgerCategory
from .models.task import TaskStatus
for cls in (Domain, EventType, LedgerCategory, TaskStatus):
psycopg.adapters.register_dumper(cls, _EnumValueDumper)
def build_engine(url: str | None = None):
db_url = url or _get_database_url()
_maybe_register_psycopg_enum_dumpers(db_url)
kwargs: dict = {"echo": False, "future": True}
if db_url.startswith("sqlite"):
# Required for SQLAlchemy's connection pool when used across threads
kwargs["connect_args"] = {"check_same_thread": False}
return create_engine(db_url, **kwargs)
def build_session_factory(engine):
return sessionmaker(
bind=engine,
autoflush=False,
autocommit=False,
expire_on_commit=False,
class_=Session,
)
@contextmanager
def session_scope(session_factory):
session = session_factory()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def init_db(engine) -> None:
"""Create all tables that do not yet exist. Safe to call on every startup."""
from .base import Base
# Import all models so SQLAlchemy registers them with Base.metadata before create_all.
from .models import company, employee, event, ledger, scratchpad, session, sim_state, task # noqa: F401
Base.metadata.create_all(engine)
__all__ = ["build_engine", "build_session_factory", "session_scope", "init_db"]