build source

This commit is contained in:
build 2026-04-16 04:16:36 +00:00
commit ee1fec43ed
4171 changed files with 1351288 additions and 0 deletions

BIN
internal/store/._runners.go Normal file

Binary file not shown.

Binary file not shown.

BIN
internal/store/._store.go Normal file

Binary file not shown.

Binary file not shown.

BIN
internal/store/._usage.go Normal file

Binary file not shown.

Binary file not shown.

BIN
internal/store/._users.go Normal file

Binary file not shown.

Binary file not shown.

227
internal/store/runners.go Normal file
View file

@ -0,0 +1,227 @@
package store
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
)
// GenerateRunnerID creates a unique runner identifier.
func GenerateRunnerID() (string, error) {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
return "runner-" + hex.EncodeToString(b), nil
}
// CreateRunner inserts a new runner record.
func (s *Store) CreateRunner(ctx context.Context, r *model.Runner) error {
_, err := s.db.Exec(ctx,
`INSERT INTO runners (id, user_id, repo_url, branch, tools, task, status, webhook_delivery_id, pod_name, cpu_req, mem_req, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)`,
r.ID, r.User, r.RepoURL, r.Branch, r.Tools, r.Task, string(r.Status),
r.WebhookDeliveryID, r.PodName, r.CPUReq, r.MemReq, r.CreatedAt,
)
if err != nil {
if isDuplicateError(err) {
return fmt.Errorf("runner %q: %w", r.ID, ErrDuplicate)
}
return fmt.Errorf("insert runner: %w", err)
}
return nil
}
// GetRunner retrieves a runner by ID.
func (s *Store) GetRunner(ctx context.Context, id string) (*model.Runner, error) {
row := s.db.QueryRow(ctx,
`SELECT id, user_id, repo_url, branch, tools, task, status,
forgejo_runner_id, webhook_delivery_id, pod_name, cpu_req, mem_req,
created_at, claimed_at, completed_at
FROM runners WHERE id = $1`, id)
r, err := scanRunner(row)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, fmt.Errorf("runner %q: %w", id, ErrNotFound)
}
return nil, fmt.Errorf("query runner: %w", err)
}
return r, nil
}
// ListRunners returns runners, optionally filtered by user and/or status.
func (s *Store) ListRunners(ctx context.Context, userFilter string, statusFilter string) ([]model.Runner, error) {
query := `SELECT id, user_id, repo_url, branch, tools, task, status,
forgejo_runner_id, webhook_delivery_id, pod_name, cpu_req, mem_req,
created_at, claimed_at, completed_at
FROM runners WHERE 1=1`
var args []any
argIdx := 1
if userFilter != "" {
query += fmt.Sprintf(" AND user_id = $%d", argIdx)
args = append(args, userFilter)
argIdx++
}
if statusFilter != "" {
query += fmt.Sprintf(" AND status = $%d", argIdx)
args = append(args, statusFilter)
argIdx++
}
query += " ORDER BY created_at DESC"
rows, err := s.db.Query(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("query runners: %w", err)
}
defer rows.Close()
var runners []model.Runner
for rows.Next() {
var r model.Runner
var claimedAt, completedAt *time.Time
if err := rows.Scan(&r.ID, &r.User, &r.RepoURL, &r.Branch, &r.Tools, &r.Task,
&r.Status, &r.ForgejoRunnerID, &r.WebhookDeliveryID, &r.PodName, &r.CPUReq, &r.MemReq,
&r.CreatedAt, &claimedAt, &completedAt); err != nil {
return nil, fmt.Errorf("scan runner: %w", err)
}
r.ClaimedAt = claimedAt
r.CompletedAt = completedAt
runners = append(runners, r)
}
return runners, rows.Err()
}
// UpdateRunnerStatus transitions a runner to a new status with state machine validation.
func (s *Store) UpdateRunnerStatus(ctx context.Context, id string, newStatus model.RunnerStatus, forgejoRunnerID string) error {
current, err := s.GetRunner(ctx, id)
if err != nil {
return err
}
if !current.Status.CanTransitionTo(newStatus) {
return fmt.Errorf("invalid transition from %s to %s", current.Status, newStatus)
}
now := time.Now().UTC()
var claimedAt, completedAt *time.Time
if newStatus == model.RunnerStatusJobClaimed {
claimedAt = &now
}
if newStatus.IsTerminal() {
completedAt = &now
}
query := `UPDATE runners SET status = $1, forgejo_runner_id = CASE WHEN $2 = '' THEN forgejo_runner_id ELSE $2 END`
args := []any{string(newStatus), forgejoRunnerID}
argIdx := 3
if claimedAt != nil {
query += fmt.Sprintf(", claimed_at = $%d", argIdx)
args = append(args, *claimedAt)
argIdx++
}
if completedAt != nil {
query += fmt.Sprintf(", completed_at = $%d", argIdx)
args = append(args, *completedAt)
argIdx++
}
query += fmt.Sprintf(" WHERE id = $%d", argIdx)
args = append(args, id)
tag, err := s.db.Exec(ctx, query, args...)
if err != nil {
return fmt.Errorf("update runner status: %w", err)
}
if tag.RowsAffected() == 0 {
return fmt.Errorf("runner %q: %w", id, ErrNotFound)
}
return nil
}
// DeleteRunner removes a runner record by ID.
func (s *Store) DeleteRunner(ctx context.Context, id string) error {
tag, err := s.db.Exec(ctx, `DELETE FROM runners WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete runner: %w", err)
}
if tag.RowsAffected() == 0 {
return fmt.Errorf("runner %q: %w", id, ErrNotFound)
}
return nil
}
// IsDeliveryProcessed checks if a webhook delivery ID has already been processed.
func (s *Store) IsDeliveryProcessed(ctx context.Context, deliveryID string) (bool, error) {
if deliveryID == "" {
return false, nil
}
var exists bool
err := s.db.QueryRow(ctx,
`SELECT EXISTS(SELECT 1 FROM runners WHERE webhook_delivery_id = $1)`,
deliveryID).Scan(&exists)
if err != nil {
return false, fmt.Errorf("check delivery: %w", err)
}
return exists, nil
}
// GetRunnersForCleanup returns runners in terminal states (completed/failed).
func (s *Store) GetRunnersForCleanup(ctx context.Context) ([]model.Runner, error) {
return s.ListRunners(ctx, "", "")
}
// GetStaleRunners returns runners older than the given TTL that aren't in terminal/cleanup states.
func (s *Store) GetStaleRunners(ctx context.Context, ttl time.Duration) ([]model.Runner, error) {
cutoff := time.Now().UTC().Add(-ttl)
rows, err := s.db.Query(ctx,
`SELECT id, user_id, repo_url, branch, tools, task, status,
forgejo_runner_id, webhook_delivery_id, pod_name, cpu_req, mem_req,
created_at, claimed_at, completed_at
FROM runners
WHERE created_at < $1
AND status NOT IN ('completed', 'failed', 'cleanup_pending')
ORDER BY created_at`, cutoff)
if err != nil {
return nil, fmt.Errorf("query stale runners: %w", err)
}
defer rows.Close()
var runners []model.Runner
for rows.Next() {
var r model.Runner
var claimedAt, completedAt *time.Time
if err := rows.Scan(&r.ID, &r.User, &r.RepoURL, &r.Branch, &r.Tools, &r.Task,
&r.Status, &r.ForgejoRunnerID, &r.WebhookDeliveryID, &r.PodName, &r.CPUReq, &r.MemReq,
&r.CreatedAt, &claimedAt, &completedAt); err != nil {
return nil, fmt.Errorf("scan stale runner: %w", err)
}
r.ClaimedAt = claimedAt
r.CompletedAt = completedAt
runners = append(runners, r)
}
return runners, rows.Err()
}
func scanRunner(row pgx.Row) (*model.Runner, error) {
var r model.Runner
var claimedAt, completedAt *time.Time
err := row.Scan(&r.ID, &r.User, &r.RepoURL, &r.Branch, &r.Tools, &r.Task,
&r.Status, &r.ForgejoRunnerID, &r.WebhookDeliveryID, &r.PodName, &r.CPUReq, &r.MemReq,
&r.CreatedAt, &claimedAt, &completedAt)
if err != nil {
return nil, err
}
r.ClaimedAt = claimedAt
r.CompletedAt = completedAt
return &r, nil
}

View file

@ -0,0 +1,385 @@
package store
import (
"context"
"errors"
"testing"
"time"
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
)
func createTestUser(t *testing.T, s *Store, id string) {
t.Helper()
ctx := context.Background()
_, err := s.CreateUser(ctx, id, model.DefaultQuota())
if err != nil {
t.Fatalf("create test user %q: %v", id, err)
}
}
func newTestRunner(user, id string) *model.Runner {
return &model.Runner{
ID: id,
User: user,
RepoURL: "ilia/test-repo",
Branch: "main",
Tools: "go",
Task: "implement feature",
Status: model.RunnerStatusReceived,
CPUReq: "2",
MemReq: "4Gi",
CreatedAt: time.Now().UTC(),
}
}
func cleanRunners(t *testing.T) {
t.Helper()
ctx := context.Background()
if _, err := testPool.Exec(ctx, "DELETE FROM runners"); err != nil {
t.Fatalf("clean runners: %v", err)
}
}
func TestCreateRunner(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
r := newTestRunner("alice", "runner-abc123")
if err := s.CreateRunner(ctx, r); err != nil {
t.Fatalf("create runner: %v", err)
}
got, err := s.GetRunner(ctx, "runner-abc123")
if err != nil {
t.Fatalf("get runner: %v", err)
}
if got.ID != "runner-abc123" {
t.Errorf("got id %q, want %q", got.ID, "runner-abc123")
}
if got.User != "alice" {
t.Errorf("got user %q, want %q", got.User, "alice")
}
if got.Status != model.RunnerStatusReceived {
t.Errorf("got status %q, want %q", got.Status, model.RunnerStatusReceived)
}
if got.RepoURL != "ilia/test-repo" {
t.Errorf("got repo %q, want %q", got.RepoURL, "ilia/test-repo")
}
}
func TestCreateRunner_Duplicate(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
r := newTestRunner("alice", "runner-dup1")
if err := s.CreateRunner(ctx, r); err != nil {
t.Fatalf("first create: %v", err)
}
err := s.CreateRunner(ctx, r)
if err == nil {
t.Fatal("expected error for duplicate runner")
}
if !errors.Is(err, ErrDuplicate) {
t.Errorf("got %v, want ErrDuplicate", err)
}
}
func TestGetRunner_NotFound(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
ctx := context.Background()
_, err := s.GetRunner(ctx, "nonexistent")
if err == nil {
t.Fatal("expected error for nonexistent runner")
}
if !errors.Is(err, ErrNotFound) {
t.Errorf("got %v, want ErrNotFound", err)
}
}
func TestListRunners(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
createTestUser(t, s, "bob")
ctx := context.Background()
r1 := newTestRunner("alice", "runner-list1")
r2 := newTestRunner("alice", "runner-list2")
r3 := newTestRunner("bob", "runner-list3")
for _, r := range []*model.Runner{r1, r2, r3} {
if err := s.CreateRunner(ctx, r); err != nil {
t.Fatalf("create runner %s: %v", r.ID, err)
}
}
all, err := s.ListRunners(ctx, "", "")
if err != nil {
t.Fatalf("list all: %v", err)
}
if len(all) != 3 {
t.Fatalf("expected 3 runners, got %d", len(all))
}
aliceRunners, err := s.ListRunners(ctx, "alice", "")
if err != nil {
t.Fatalf("list alice: %v", err)
}
if len(aliceRunners) != 2 {
t.Fatalf("expected 2 runners for alice, got %d", len(aliceRunners))
}
byStatus, err := s.ListRunners(ctx, "", "received")
if err != nil {
t.Fatalf("list by status: %v", err)
}
if len(byStatus) != 3 {
t.Fatalf("expected 3 received runners, got %d", len(byStatus))
}
}
func TestUpdateRunnerStatus_ValidTransition(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
r := newTestRunner("alice", "runner-trans1")
if err := s.CreateRunner(ctx, r); err != nil {
t.Fatalf("create: %v", err)
}
if err := s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusPodCreating, ""); err != nil {
t.Fatalf("transition to pod_creating: %v", err)
}
got, _ := s.GetRunner(ctx, r.ID)
if got.Status != model.RunnerStatusPodCreating {
t.Errorf("expected pod_creating, got %s", got.Status)
}
if err := s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusRunnerRegistered, "forgejo-42"); err != nil {
t.Fatalf("transition to runner_registered: %v", err)
}
got, _ = s.GetRunner(ctx, r.ID)
if got.Status != model.RunnerStatusRunnerRegistered {
t.Errorf("expected runner_registered, got %s", got.Status)
}
if got.ForgejoRunnerID != "forgejo-42" {
t.Errorf("expected forgejo_runner_id 'forgejo-42', got %q", got.ForgejoRunnerID)
}
}
func TestUpdateRunnerStatus_InvalidTransition(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
r := newTestRunner("alice", "runner-invalid1")
if err := s.CreateRunner(ctx, r); err != nil {
t.Fatalf("create: %v", err)
}
err := s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusCompleted, "")
if err == nil {
t.Fatal("expected error for invalid transition received -> completed")
}
}
func TestUpdateRunnerStatus_SetsClaimedAt(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
r := newTestRunner("alice", "runner-claimed1")
if err := s.CreateRunner(ctx, r); err != nil {
t.Fatalf("create: %v", err)
}
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusPodCreating, "")
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusRunnerRegistered, "")
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusJobClaimed, "")
got, _ := s.GetRunner(ctx, r.ID)
if got.ClaimedAt == nil {
t.Error("expected claimed_at to be set")
}
}
func TestUpdateRunnerStatus_SetsCompletedAt(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
r := newTestRunner("alice", "runner-complete1")
if err := s.CreateRunner(ctx, r); err != nil {
t.Fatalf("create: %v", err)
}
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusPodCreating, "")
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusRunnerRegistered, "")
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusJobClaimed, "")
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusCompleted, "")
got, _ := s.GetRunner(ctx, r.ID)
if got.CompletedAt == nil {
t.Error("expected completed_at to be set")
}
}
func TestDeleteRunner(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
r := newTestRunner("alice", "runner-del1")
if err := s.CreateRunner(ctx, r); err != nil {
t.Fatalf("create: %v", err)
}
if err := s.DeleteRunner(ctx, r.ID); err != nil {
t.Fatalf("delete: %v", err)
}
_, err := s.GetRunner(ctx, r.ID)
if !errors.Is(err, ErrNotFound) {
t.Errorf("expected ErrNotFound after delete, got %v", err)
}
}
func TestDeleteRunner_NotFound(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
ctx := context.Background()
err := s.DeleteRunner(ctx, "nonexistent")
if !errors.Is(err, ErrNotFound) {
t.Errorf("expected ErrNotFound, got %v", err)
}
}
func TestIsDeliveryProcessed(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
isDupe, err := s.IsDeliveryProcessed(ctx, "delivery-1")
if err != nil {
t.Fatalf("check delivery: %v", err)
}
if isDupe {
t.Error("expected delivery-1 to not be processed yet")
}
r := newTestRunner("alice", "runner-dedupe1")
r.WebhookDeliveryID = "delivery-1"
if err := s.CreateRunner(ctx, r); err != nil {
t.Fatalf("create: %v", err)
}
isDupe, err = s.IsDeliveryProcessed(ctx, "delivery-1")
if err != nil {
t.Fatalf("check delivery after create: %v", err)
}
if !isDupe {
t.Error("expected delivery-1 to be processed")
}
}
func TestIsDeliveryProcessed_EmptyID(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
ctx := context.Background()
isDupe, err := s.IsDeliveryProcessed(ctx, "")
if err != nil {
t.Fatalf("check empty delivery: %v", err)
}
if isDupe {
t.Error("expected empty delivery ID to return false")
}
}
func TestWebhookDeliveryDedupe_UniqueConstraint(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
r1 := newTestRunner("alice", "runner-uniq1")
r1.WebhookDeliveryID = "webhook-abc"
if err := s.CreateRunner(ctx, r1); err != nil {
t.Fatalf("first insert: %v", err)
}
r2 := newTestRunner("alice", "runner-uniq2")
r2.WebhookDeliveryID = "webhook-abc"
err := s.CreateRunner(ctx, r2)
if err == nil {
t.Fatal("expected duplicate error for same webhook_delivery_id")
}
if !errors.Is(err, ErrDuplicate) {
t.Errorf("expected ErrDuplicate, got %v", err)
}
}
func TestGetStaleRunners(t *testing.T) {
s := newTestStore(t)
cleanRunners(t)
createTestUser(t, s, "alice")
ctx := context.Background()
old := newTestRunner("alice", "runner-stale1")
old.CreatedAt = time.Now().UTC().Add(-3 * time.Hour)
if err := s.CreateRunner(ctx, old); err != nil {
t.Fatalf("create old runner: %v", err)
}
fresh := newTestRunner("alice", "runner-fresh1")
fresh.CreatedAt = time.Now().UTC()
if err := s.CreateRunner(ctx, fresh); err != nil {
t.Fatalf("create fresh runner: %v", err)
}
stale, err := s.GetStaleRunners(ctx, 2*time.Hour)
if err != nil {
t.Fatalf("get stale: %v", err)
}
if len(stale) != 1 {
t.Fatalf("expected 1 stale runner, got %d", len(stale))
}
if stale[0].ID != "runner-stale1" {
t.Errorf("expected stale runner runner-stale1, got %s", stale[0].ID)
}
}
func TestGenerateRunnerID(t *testing.T) {
id1, err := GenerateRunnerID()
if err != nil {
t.Fatalf("generate: %v", err)
}
if len(id1) == 0 {
t.Error("expected non-empty ID")
}
id2, _ := GenerateRunnerID()
if id1 == id2 {
t.Error("expected unique IDs")
}
if id1[:7] != "runner-" {
t.Errorf("expected runner- prefix, got %s", id1[:7])
}
}

100
internal/store/store.go Normal file
View file

@ -0,0 +1,100 @@
package store
import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
// DBTX is the interface for database query operations.
// Satisfied by *pgxpool.Pool, *pgx.Conn, and pgx.Tx.
type DBTX interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}
// Store provides database operations for the dev-pod API.
type Store struct {
db DBTX
}
// New creates a Store with the given database connection.
func New(db DBTX) *Store {
return &Store{db: db}
}
// NewPool creates a new pgx connection pool.
func NewPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) {
pool, err := pgxpool.New(ctx, databaseURL)
if err != nil {
return nil, fmt.Errorf("connect to database: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("ping database: %w", err)
}
return pool, nil
}
// Migrate runs database migrations to create required tables.
func (s *Store) Migrate(ctx context.Context) error {
migrations := []string{
`CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
max_concurrent_pods INTEGER NOT NULL DEFAULT 3,
max_cpu_per_pod INTEGER NOT NULL DEFAULT 8,
max_ram_gb_per_pod INTEGER NOT NULL DEFAULT 16,
monthly_pod_hours INTEGER NOT NULL DEFAULT 500,
monthly_ai_requests INTEGER NOT NULL DEFAULT 10000
)`,
`CREATE TABLE IF NOT EXISTS api_keys (
key_hash TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
role TEXT NOT NULL DEFAULT 'user',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_used_at TIMESTAMPTZ
)`,
`CREATE TABLE IF NOT EXISTS usage_records (
id BIGSERIAL PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
pod_name TEXT NOT NULL,
event_type TEXT NOT NULL,
value DOUBLE PRECISION,
recorded_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_usage_user_month ON usage_records(user_id, recorded_at)`,
`ALTER TABLE users ADD COLUMN IF NOT EXISTS forgejo_token TEXT NOT NULL DEFAULT ''`,
`CREATE TABLE IF NOT EXISTS runners (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id),
repo_url TEXT NOT NULL,
branch TEXT NOT NULL DEFAULT 'main',
tools TEXT NOT NULL DEFAULT '',
task TEXT NOT NULL DEFAULT '',
status TEXT NOT NULL DEFAULT 'received',
forgejo_runner_id TEXT NOT NULL DEFAULT '',
webhook_delivery_id TEXT NOT NULL DEFAULT '',
pod_name TEXT NOT NULL DEFAULT '',
cpu_req TEXT NOT NULL DEFAULT '2',
mem_req TEXT NOT NULL DEFAULT '4Gi',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
claimed_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ
)`,
`CREATE UNIQUE INDEX IF NOT EXISTS idx_runners_webhook_delivery
ON runners(webhook_delivery_id) WHERE webhook_delivery_id != ''`,
`CREATE INDEX IF NOT EXISTS idx_runners_status ON runners(status)`,
`ALTER TABLE users ADD COLUMN IF NOT EXISTS tailscale_key TEXT NOT NULL DEFAULT ''`,
}
for _, m := range migrations {
if _, err := s.db.Exec(ctx, m); err != nil {
return fmt.Errorf("run migration: %w", err)
}
}
return nil
}

View file

@ -0,0 +1,57 @@
package store
import (
"context"
"fmt"
"log"
"os"
"testing"
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
"github.com/jackc/pgx/v5/pgxpool"
)
var testPool *pgxpool.Pool
func TestMain(m *testing.M) {
postgres := embeddedpostgres.NewDatabase(
embeddedpostgres.DefaultConfig().
Port(15432).
Database("devpod_test"),
)
if err := postgres.Start(); err != nil {
log.Fatalf("start embedded postgres: %v", err)
}
ctx := context.Background()
pool, err := pgxpool.New(ctx, "postgres://postgres:postgres@localhost:15432/devpod_test?sslmode=disable")
if err != nil {
postgres.Stop()
log.Fatalf("connect to test db: %v", err)
}
testPool = pool
code := m.Run()
pool.Close()
if err := postgres.Stop(); err != nil {
log.Printf("stop embedded postgres: %v", err)
}
os.Exit(code)
}
func newTestStore(t *testing.T) *Store {
t.Helper()
ctx := context.Background()
s := New(testPool)
if err := s.Migrate(ctx); err != nil {
t.Fatalf("migrate: %v", err)
}
// Clean tables for test isolation (order matters: foreign keys)
for _, table := range []string{"runners", "usage_records", "api_keys", "users"} {
if _, err := testPool.Exec(ctx, fmt.Sprintf("DELETE FROM %s", table)); err != nil {
t.Fatalf("clean %s: %v", table, err)
}
}
return s
}

178
internal/store/usage.go Normal file
View file

@ -0,0 +1,178 @@
package store
import (
"context"
"fmt"
"time"
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
)
// RecordPodStart records a pod start event for usage tracking.
func (s *Store) RecordPodStart(ctx context.Context, userID, podName string) error {
_, err := s.db.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ($1, $2, $3, 0, NOW())`,
userID, podName, model.EventPodStart)
if err != nil {
return fmt.Errorf("record pod start: %w", err)
}
return nil
}
// RecordPodStop records a pod stop event for usage tracking.
func (s *Store) RecordPodStop(ctx context.Context, userID, podName string) error {
_, err := s.db.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ($1, $2, $3, 0, NOW())`,
userID, podName, model.EventPodStop)
if err != nil {
return fmt.Errorf("record pod stop: %w", err)
}
return nil
}
// RecordResourceSample records periodic CPU and memory usage samples.
// cpuMillicores is CPU usage in millicores, memBytes is memory in bytes.
func (s *Store) RecordResourceSample(ctx context.Context, userID, podName string, cpuMillicores, memBytes float64) error {
now := time.Now().UTC()
_, err := s.db.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ($1, $2, $3, $4, $5), ($1, $2, $6, $7, $5)`,
userID, podName, model.EventCPUSample, cpuMillicores, now,
model.EventMemSample, memBytes)
if err != nil {
return fmt.Errorf("record resource sample: %w", err)
}
return nil
}
// RecordAIRequest records an AI proxy request event.
func (s *Store) RecordAIRequest(ctx context.Context, userID string) error {
_, err := s.db.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ($1, '', $2, 1, NOW())`,
userID, model.EventAIRequest)
if err != nil {
return fmt.Errorf("record ai request: %w", err)
}
return nil
}
// GetUsage returns aggregated usage for a user for a given month.
func (s *Store) GetUsage(ctx context.Context, userID string, year int, month time.Month, monthlyBudget int) (*model.UsageSummary, error) {
return s.getUsage(ctx, userID, year, month, monthlyBudget, time.Now().UTC())
}
// getUsage is the internal implementation that accepts a "now" parameter for testing.
func (s *Store) getUsage(ctx context.Context, userID string, year int, month time.Month, monthlyBudget int, now time.Time) (*model.UsageSummary, error) {
monthStart := time.Date(year, month, 1, 0, 0, 0, 0, time.UTC)
monthEnd := monthStart.AddDate(0, 1, 0)
// Calculate pod-hours from pod_start/pod_stop pairs.
// For each pod_start, find the nearest subsequent pod_stop for the same user/pod.
// If no stop found, use "now" (pod is still running).
// Sessions that cross month boundaries are clamped to [monthStart, monthEnd].
var podHours float64
err := s.db.QueryRow(ctx, `
WITH sessions AS (
SELECT
GREATEST(recorded_at, $2) AS start_time,
(
SELECT MIN(r2.recorded_at)
FROM usage_records r2
WHERE r2.user_id = r.user_id
AND r2.pod_name = r.pod_name
AND r2.event_type = 'pod_stop'
AND r2.recorded_at > r.recorded_at
) AS stop_time
FROM usage_records r
WHERE r.user_id = $1
AND r.event_type = 'pod_start'
AND r.recorded_at < $3
)
SELECT COALESCE(SUM(
EXTRACT(EPOCH FROM LEAST(COALESCE(stop_time, $4), $3) - start_time) / 3600.0
), 0)
FROM sessions
WHERE COALESCE(stop_time, $4) > $2`,
userID, monthStart, monthEnd, now).Scan(&podHours)
if err != nil {
return nil, fmt.Errorf("calculate pod hours: %w", err)
}
// CPU-hours from resource samples.
// Each sample represents 60s of CPU at the sampled rate (millicores).
// CPU-hours = sum(millicores) * 60s / 3600s / 1000m = sum(millicores) / 60000
var cpuHours float64
err = s.db.QueryRow(ctx, `
SELECT COALESCE(SUM(value) / 60000.0, 0)
FROM usage_records
WHERE user_id = $1 AND event_type = $2
AND recorded_at >= $3 AND recorded_at < $4`,
userID, model.EventCPUSample, monthStart, monthEnd).Scan(&cpuHours)
if err != nil {
return nil, fmt.Errorf("calculate cpu hours: %w", err)
}
// Count AI requests.
var aiRequests int64
err = s.db.QueryRow(ctx, `
SELECT COUNT(*)
FROM usage_records
WHERE user_id = $1 AND event_type = $2
AND recorded_at >= $3 AND recorded_at < $4`,
userID, model.EventAIRequest, monthStart, monthEnd).Scan(&aiRequests)
if err != nil {
return nil, fmt.Errorf("count ai requests: %w", err)
}
var budgetUsedPct float64
if monthlyBudget > 0 {
budgetUsedPct = (podHours / float64(monthlyBudget)) * 100
}
return &model.UsageSummary{
PodHours: podHours,
CPUHours: cpuHours,
AIRequests: aiRequests,
BudgetUsedPct: budgetUsedPct,
}, nil
}
// GetDailyUsage returns daily usage breakdown for a user for a given month.
// Pod-hours per day are estimated from CPU sample count (each sample = 1/60 hour of pod time).
func (s *Store) GetDailyUsage(ctx context.Context, userID string, year int, month time.Month) ([]model.DailyUsage, error) {
monthStart := time.Date(year, month, 1, 0, 0, 0, 0, time.UTC)
monthEnd := monthStart.AddDate(0, 1, 0)
rows, err := s.db.Query(ctx, `
SELECT
DATE(recorded_at) AS day,
COALESCE(SUM(CASE WHEN event_type = 'cpu_sample' THEN 1.0/60.0 ELSE 0 END), 0) AS pod_hours,
COALESCE(SUM(CASE WHEN event_type = 'cpu_sample' THEN value / 60000.0 ELSE 0 END), 0) AS cpu_hours,
COUNT(*) FILTER (WHERE event_type = 'ai_request') AS ai_requests
FROM usage_records
WHERE user_id = $1
AND recorded_at >= $2 AND recorded_at < $3
AND event_type IN ('cpu_sample', 'ai_request')
GROUP BY DATE(recorded_at)
ORDER BY DATE(recorded_at)`,
userID, monthStart, monthEnd)
if err != nil {
return nil, fmt.Errorf("query daily usage: %w", err)
}
defer rows.Close()
var result []model.DailyUsage
for rows.Next() {
var d model.DailyUsage
var day time.Time
if err := rows.Scan(&day, &d.PodHours, &d.CPUHours, &d.AIRequests); err != nil {
return nil, fmt.Errorf("scan daily usage: %w", err)
}
d.Date = day.Format("2006-01-02")
result = append(result, d)
}
return result, rows.Err()
}

View file

@ -0,0 +1,399 @@
package store
import (
"context"
"math"
"testing"
"time"
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
)
func TestRecordPodStart(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
if err := s.RecordPodStart(ctx, "alice", "main"); err != nil {
t.Fatalf("record pod start: %v", err)
}
var count int
if err := testPool.QueryRow(ctx,
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'pod_start'`).Scan(&count); err != nil {
t.Fatalf("query: %v", err)
}
if count != 1 {
t.Errorf("expected 1 pod_start record, got %d", count)
}
}
func TestRecordPodStop(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
if err := s.RecordPodStop(ctx, "alice", "main"); err != nil {
t.Fatalf("record pod stop: %v", err)
}
var count int
if err := testPool.QueryRow(ctx,
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'pod_stop'`).Scan(&count); err != nil {
t.Fatalf("query: %v", err)
}
if count != 1 {
t.Errorf("expected 1 pod_stop record, got %d", count)
}
}
func TestRecordResourceSample(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
if err := s.RecordResourceSample(ctx, "alice", "main", 2500, 1073741824); err != nil {
t.Fatalf("record resource sample: %v", err)
}
// Should have both cpu and mem records
var cpuCount, memCount int
testPool.QueryRow(ctx,
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'cpu_sample'`).Scan(&cpuCount)
testPool.QueryRow(ctx,
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'mem_sample'`).Scan(&memCount)
if cpuCount != 1 {
t.Errorf("expected 1 cpu_sample, got %d", cpuCount)
}
if memCount != 1 {
t.Errorf("expected 1 mem_sample, got %d", memCount)
}
// Verify CPU value
var cpuValue float64
testPool.QueryRow(ctx,
`SELECT value FROM usage_records WHERE user_id = 'alice' AND event_type = 'cpu_sample'`).Scan(&cpuValue)
if cpuValue != 2500 {
t.Errorf("expected cpu value 2500, got %f", cpuValue)
}
}
func TestRecordAIRequest(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
for i := 0; i < 5; i++ {
if err := s.RecordAIRequest(ctx, "alice"); err != nil {
t.Fatalf("record ai request %d: %v", i, err)
}
}
var count int
testPool.QueryRow(ctx,
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'ai_request'`).Scan(&count)
if count != 5 {
t.Errorf("expected 5 ai_request records, got %d", count)
}
}
func TestGetUsage_Empty(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500,
time.Date(2026, time.March, 15, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("get usage: %v", err)
}
if usage.PodHours != 0 {
t.Errorf("expected 0 pod hours, got %f", usage.PodHours)
}
if usage.CPUHours != 0 {
t.Errorf("expected 0 cpu hours, got %f", usage.CPUHours)
}
if usage.AIRequests != 0 {
t.Errorf("expected 0 ai requests, got %d", usage.AIRequests)
}
if usage.BudgetUsedPct != 0 {
t.Errorf("expected 0 budget used, got %f", usage.BudgetUsedPct)
}
}
func TestGetUsage_PodHours_CompletedSession(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
start := time.Date(2026, time.March, 15, 10, 0, 0, 0, time.UTC)
stop := time.Date(2026, time.March, 15, 12, 0, 0, 0, time.UTC) // 2 hours later
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'pod_start', 0, $1)`, start)
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'pod_stop', 0, $1)`, stop)
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, stop)
if err != nil {
t.Fatalf("get usage: %v", err)
}
if math.Abs(usage.PodHours-2.0) > 0.01 {
t.Errorf("expected ~2.0 pod hours, got %f", usage.PodHours)
}
}
func TestGetUsage_PodHours_RunningPod(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
start := time.Date(2026, time.March, 15, 10, 0, 0, 0, time.UTC)
now := time.Date(2026, time.March, 15, 13, 0, 0, 0, time.UTC) // 3 hours later, no stop
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'pod_start', 0, $1)`, start)
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
if err != nil {
t.Fatalf("get usage: %v", err)
}
if math.Abs(usage.PodHours-3.0) > 0.01 {
t.Errorf("expected ~3.0 pod hours (running pod), got %f", usage.PodHours)
}
}
func TestGetUsage_PodHours_MultiplePods(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
// Pod 1: 2 hours
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'pod1', 'pod_start', 0, $1)`,
time.Date(2026, time.March, 10, 8, 0, 0, 0, time.UTC))
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'pod1', 'pod_stop', 0, $1)`,
time.Date(2026, time.March, 10, 10, 0, 0, 0, time.UTC))
// Pod 2: 5 hours
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'pod2', 'pod_start', 0, $1)`,
time.Date(2026, time.March, 12, 14, 0, 0, 0, time.UTC))
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'pod2', 'pod_stop', 0, $1)`,
time.Date(2026, time.March, 12, 19, 0, 0, 0, time.UTC))
now := time.Date(2026, time.March, 20, 0, 0, 0, 0, time.UTC)
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
if err != nil {
t.Fatalf("get usage: %v", err)
}
if math.Abs(usage.PodHours-7.0) > 0.01 {
t.Errorf("expected ~7.0 pod hours (2+5), got %f", usage.PodHours)
}
}
func TestGetUsage_CPUHours(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
// 10 samples of 3000 millicores
// CPU-hours = 10 * 3000 / 60000 = 0.5
for i := 0; i < 10; i++ {
ts := time.Date(2026, time.March, 15, 10, i, 0, 0, time.UTC)
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'cpu_sample', 3000, $1)`, ts)
}
now := time.Date(2026, time.March, 20, 0, 0, 0, 0, time.UTC)
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
if err != nil {
t.Fatalf("get usage: %v", err)
}
if math.Abs(usage.CPUHours-0.5) > 0.001 {
t.Errorf("expected ~0.5 cpu hours, got %f", usage.CPUHours)
}
}
func TestGetUsage_AIRequests(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
for i := 0; i < 42; i++ {
ts := time.Date(2026, time.March, 15, 10, 0, i, 0, time.UTC)
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', '', 'ai_request', 1, $1)`, ts)
}
now := time.Date(2026, time.March, 20, 0, 0, 0, 0, time.UTC)
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
if err != nil {
t.Fatalf("get usage: %v", err)
}
if usage.AIRequests != 42 {
t.Errorf("expected 42 ai requests, got %d", usage.AIRequests)
}
}
func TestGetUsage_BudgetPercent(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
// 100 hours used, budget = 500 → 20%
start := time.Date(2026, time.March, 1, 0, 0, 0, 0, time.UTC)
stop := time.Date(2026, time.March, 5, 4, 0, 0, 0, time.UTC) // 100 hours
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'pod_start', 0, $1)`, start)
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'pod_stop', 0, $1)`, stop)
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, stop)
if err != nil {
t.Fatalf("get usage: %v", err)
}
if math.Abs(usage.BudgetUsedPct-20.0) > 0.1 {
t.Errorf("expected ~20%% budget used, got %f", usage.BudgetUsedPct)
}
}
func TestGetUsage_DifferentMonths(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
// March data
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'pod_start', 0, $1)`,
time.Date(2026, time.March, 15, 10, 0, 0, 0, time.UTC))
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'pod_stop', 0, $1)`,
time.Date(2026, time.March, 15, 12, 0, 0, 0, time.UTC))
// February data should not be counted
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'other', 'pod_start', 0, $1)`,
time.Date(2026, time.February, 15, 10, 0, 0, 0, time.UTC))
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'other', 'pod_stop', 0, $1)`,
time.Date(2026, time.February, 15, 20, 0, 0, 0, time.UTC))
now := time.Date(2026, time.March, 20, 0, 0, 0, 0, time.UTC)
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
if err != nil {
t.Fatalf("get usage: %v", err)
}
// Should only count March: 2 hours
if math.Abs(usage.PodHours-2.0) > 0.01 {
t.Errorf("expected ~2.0 pod hours (March only), got %f", usage.PodHours)
}
}
func TestGetDailyUsage(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
// Day 1: 5 CPU samples of 2000m + 3 AI requests
for i := 0; i < 5; i++ {
ts := time.Date(2026, time.March, 10, 10, i, 0, 0, time.UTC)
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'cpu_sample', 2000, $1)`, ts)
}
for i := 0; i < 3; i++ {
ts := time.Date(2026, time.March, 10, 11, i, 0, 0, time.UTC)
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', '', 'ai_request', 1, $1)`, ts)
}
// Day 2: 10 CPU samples of 4000m + 7 AI requests
for i := 0; i < 10; i++ {
ts := time.Date(2026, time.March, 11, 14, i, 0, 0, time.UTC)
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', 'main', 'cpu_sample', 4000, $1)`, ts)
}
for i := 0; i < 7; i++ {
ts := time.Date(2026, time.March, 11, 15, i, 0, 0, time.UTC)
testPool.Exec(ctx,
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
VALUES ('alice', '', 'ai_request', 1, $1)`, ts)
}
daily, err := s.GetDailyUsage(ctx, "alice", 2026, time.March)
if err != nil {
t.Fatalf("get daily usage: %v", err)
}
if len(daily) != 2 {
t.Fatalf("expected 2 days, got %d", len(daily))
}
// Day 1
if daily[0].Date != "2026-03-10" {
t.Errorf("day 1 date: got %q, want 2026-03-10", daily[0].Date)
}
// pod_hours = 5 samples * (1/60) = 5/60
if math.Abs(daily[0].PodHours-5.0/60.0) > 0.001 {
t.Errorf("day 1 pod hours: got %f, want %f", daily[0].PodHours, 5.0/60.0)
}
// cpu_hours = 5 * 2000 / 60000 = 10000/60000
if math.Abs(daily[0].CPUHours-10000.0/60000.0) > 0.001 {
t.Errorf("day 1 cpu hours: got %f, want %f", daily[0].CPUHours, 10000.0/60000.0)
}
if daily[0].AIRequests != 3 {
t.Errorf("day 1 ai requests: got %d, want 3", daily[0].AIRequests)
}
// Day 2
if daily[1].Date != "2026-03-11" {
t.Errorf("day 2 date: got %q, want 2026-03-11", daily[1].Date)
}
if daily[1].AIRequests != 7 {
t.Errorf("day 2 ai requests: got %d, want 7", daily[1].AIRequests)
}
}
func TestGetDailyUsage_Empty(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
s.CreateUser(ctx, "alice", model.DefaultQuota())
daily, err := s.GetDailyUsage(ctx, "alice", 2026, time.March)
if err != nil {
t.Fatalf("get daily usage: %v", err)
}
if len(daily) != 0 {
t.Errorf("expected 0 days, got %d", len(daily))
}
}

280
internal/store/users.go Normal file
View file

@ -0,0 +1,280 @@
package store
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
)
// ErrNotFound is returned when a requested resource does not exist.
var ErrNotFound = errors.New("not found")
// ErrDuplicate is returned when a resource already exists.
var ErrDuplicate = errors.New("already exists")
// CreateUser inserts a new user with the given quota.
func (s *Store) CreateUser(ctx context.Context, id string, quota model.Quota) (*model.User, error) {
now := time.Now().UTC()
_, err := s.db.Exec(ctx,
`INSERT INTO users (id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests)
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
id, now, quota.MaxConcurrentPods, quota.MaxCPUPerPod, quota.MaxRAMGBPerPod,
quota.MonthlyPodHours, quota.MonthlyAIRequests,
)
if err != nil {
if isDuplicateError(err) {
return nil, fmt.Errorf("user %q: %w", id, ErrDuplicate)
}
return nil, fmt.Errorf("insert user: %w", err)
}
return &model.User{
ID: id,
CreatedAt: now,
Quota: quota,
}, nil
}
// GetUser retrieves a user by ID.
func (s *Store) GetUser(ctx context.Context, id string) (*model.User, error) {
row := s.db.QueryRow(ctx,
`SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
monthly_pod_hours, monthly_ai_requests
FROM users WHERE id = $1`, id)
u, err := scanUser(row)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
}
return nil, fmt.Errorf("query user: %w", err)
}
return u, nil
}
// ListUsers returns all users ordered by creation time.
func (s *Store) ListUsers(ctx context.Context) ([]model.User, error) {
rows, err := s.db.Query(ctx,
`SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
monthly_pod_hours, monthly_ai_requests
FROM users ORDER BY created_at`)
if err != nil {
return nil, fmt.Errorf("query users: %w", err)
}
defer rows.Close()
var users []model.User
for rows.Next() {
var u model.User
if err := rows.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests); err != nil {
return nil, fmt.Errorf("scan user: %w", err)
}
users = append(users, u)
}
return users, rows.Err()
}
// DeleteUser removes a user by ID. Cascades to api_keys and usage_records.
func (s *Store) DeleteUser(ctx context.Context, id string) error {
tag, err := s.db.Exec(ctx, `DELETE FROM users WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete user: %w", err)
}
if tag.RowsAffected() == 0 {
return fmt.Errorf("user %q: %w", id, ErrNotFound)
}
return nil
}
// UpdateQuotas updates specific quota fields for a user.
func (s *Store) UpdateQuotas(ctx context.Context, id string, req model.UpdateQuotasRequest) (*model.User, error) {
var sets []string
var args []any
argIdx := 1
if req.MaxConcurrentPods != nil {
sets = append(sets, fmt.Sprintf("max_concurrent_pods = $%d", argIdx))
args = append(args, *req.MaxConcurrentPods)
argIdx++
}
if req.MaxCPUPerPod != nil {
sets = append(sets, fmt.Sprintf("max_cpu_per_pod = $%d", argIdx))
args = append(args, *req.MaxCPUPerPod)
argIdx++
}
if req.MaxRAMGBPerPod != nil {
sets = append(sets, fmt.Sprintf("max_ram_gb_per_pod = $%d", argIdx))
args = append(args, *req.MaxRAMGBPerPod)
argIdx++
}
if req.MonthlyPodHours != nil {
sets = append(sets, fmt.Sprintf("monthly_pod_hours = $%d", argIdx))
args = append(args, *req.MonthlyPodHours)
argIdx++
}
if req.MonthlyAIRequests != nil {
sets = append(sets, fmt.Sprintf("monthly_ai_requests = $%d", argIdx))
args = append(args, *req.MonthlyAIRequests)
argIdx++
}
if len(sets) == 0 {
return s.GetUser(ctx, id)
}
query := fmt.Sprintf(
`UPDATE users SET %s WHERE id = $%d
RETURNING id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
monthly_pod_hours, monthly_ai_requests`,
strings.Join(sets, ", "), argIdx)
args = append(args, id)
row := s.db.QueryRow(ctx, query, args...)
u, err := scanUser(row)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
}
return nil, fmt.Errorf("update quotas: %w", err)
}
return u, nil
}
// GenerateAPIKey creates a new random API key.
// Returns the plaintext key (dpk_ + 24 hex chars) and its SHA-256 hash.
func GenerateAPIKey() (plaintext string, hash string, err error) {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
return "", "", fmt.Errorf("generate random bytes: %w", err)
}
plain := model.APIKeyPrefix + hex.EncodeToString(b)
return plain, HashKey(plain), nil
}
// HashKey returns the hex-encoded SHA-256 hash of a key.
func HashKey(key string) string {
h := sha256.Sum256([]byte(key))
return hex.EncodeToString(h[:])
}
// CreateAPIKey stores a hashed API key for a user.
func (s *Store) CreateAPIKey(ctx context.Context, userID, role, keyHash string) error {
_, err := s.db.Exec(ctx,
`INSERT INTO api_keys (key_hash, user_id, role, created_at) VALUES ($1, $2, $3, $4)`,
keyHash, userID, role, time.Now().UTC(),
)
if err != nil {
return fmt.Errorf("insert api key: %w", err)
}
return nil
}
// ValidateKey checks an API key and returns the associated user and role.
func (s *Store) ValidateKey(ctx context.Context, key string) (*model.User, string, error) {
hash := HashKey(key)
row := s.db.QueryRow(ctx,
`SELECT u.id, u.created_at, u.max_concurrent_pods, u.max_cpu_per_pod, u.max_ram_gb_per_pod,
u.monthly_pod_hours, u.monthly_ai_requests, k.role
FROM api_keys k JOIN users u ON k.user_id = u.id
WHERE k.key_hash = $1`, hash)
var u model.User
var role string
err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests, &role)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, "", fmt.Errorf("invalid api key: %w", ErrNotFound)
}
return nil, "", fmt.Errorf("validate key: %w", err)
}
// Update last_used_at (best-effort, don't fail auth on update error)
if _, err := s.db.Exec(ctx,
`UPDATE api_keys SET last_used_at = $1 WHERE key_hash = $2`,
time.Now().UTC(), hash); err != nil {
slog.Warn("failed to update api key last_used_at", "error", err)
}
return &u, role, nil
}
// scanUser scans a user row into a model.User.
func scanUser(row pgx.Row) (*model.User, error) {
var u model.User
err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests)
if err != nil {
return nil, err
}
return &u, nil
}
// SaveForgejoToken stores a Forgejo API token for a user.
func (s *Store) SaveForgejoToken(ctx context.Context, userID, token string) error {
tag, err := s.db.Exec(ctx, `UPDATE users SET forgejo_token = $1 WHERE id = $2`, token, userID)
if err != nil {
return fmt.Errorf("update forgejo token: %w", err)
}
if tag.RowsAffected() == 0 {
return fmt.Errorf("user %q: %w", userID, ErrNotFound)
}
return nil
}
// GetForgejoToken retrieves the Forgejo API token for a user.
func (s *Store) GetForgejoToken(ctx context.Context, userID string) (string, error) {
var token string
err := s.db.QueryRow(ctx, `SELECT forgejo_token FROM users WHERE id = $1`, userID).Scan(&token)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
}
return "", fmt.Errorf("query forgejo token: %w", err)
}
return token, nil
}
// SaveTailscaleKey stores a Tailscale auth key for a user.
func (s *Store) SaveTailscaleKey(ctx context.Context, userID, key string) error {
tag, err := s.db.Exec(ctx, `UPDATE users SET tailscale_key = $1 WHERE id = $2`, key, userID)
if err != nil {
return fmt.Errorf("update tailscale key: %w", err)
}
if tag.RowsAffected() == 0 {
return fmt.Errorf("user %q: %w", userID, ErrNotFound)
}
return nil
}
// GetTailscaleKey retrieves the Tailscale auth key for a user.
func (s *Store) GetTailscaleKey(ctx context.Context, userID string) (string, error) {
var key string
err := s.db.QueryRow(ctx, `SELECT tailscale_key FROM users WHERE id = $1`, userID).Scan(&key)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
}
return "", fmt.Errorf("query tailscale key: %w", err)
}
return key, nil
}
func isDuplicateError(err error) bool {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.Code == "23505" // unique_violation
}
return false
}

View file

@ -0,0 +1,442 @@
package store
import (
"context"
"errors"
"strings"
"testing"
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
)
// --- User CRUD tests ---
func TestCreateUser(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
quota := model.DefaultQuota()
u, err := s.CreateUser(ctx, "alice", quota)
if err != nil {
t.Fatalf("create user: %v", err)
}
if u.ID != "alice" {
t.Errorf("got id %q, want %q", u.ID, "alice")
}
if u.CreatedAt.IsZero() {
t.Error("created_at should not be zero")
}
if u.Quota != quota {
t.Errorf("got quota %+v, want %+v", u.Quota, quota)
}
}
func TestCreateUserDuplicate(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
quota := model.DefaultQuota()
if _, err := s.CreateUser(ctx, "alice", quota); err != nil {
t.Fatalf("first create: %v", err)
}
_, err := s.CreateUser(ctx, "alice", quota)
if err == nil {
t.Fatal("expected error for duplicate user")
}
if !errors.Is(err, ErrDuplicate) {
t.Errorf("got %v, want ErrDuplicate", err)
}
}
func TestGetUser(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
quota := model.DefaultQuota()
created, err := s.CreateUser(ctx, "bob", quota)
if err != nil {
t.Fatalf("create: %v", err)
}
got, err := s.GetUser(ctx, "bob")
if err != nil {
t.Fatalf("get: %v", err)
}
if got.ID != created.ID {
t.Errorf("got id %q, want %q", got.ID, created.ID)
}
if got.Quota != quota {
t.Errorf("got quota %+v, want %+v", got.Quota, quota)
}
}
func TestGetUserNotFound(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
_, err := s.GetUser(ctx, "nonexistent")
if err == nil {
t.Fatal("expected error")
}
if !errors.Is(err, ErrNotFound) {
t.Errorf("got %v, want ErrNotFound", err)
}
}
func TestListUsers(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
quota := model.DefaultQuota()
// Empty list
users, err := s.ListUsers(ctx)
if err != nil {
t.Fatalf("list empty: %v", err)
}
if len(users) != 0 {
t.Errorf("got %d users, want 0", len(users))
}
// Create users
if _, err := s.CreateUser(ctx, "alice", quota); err != nil {
t.Fatalf("create alice: %v", err)
}
if _, err := s.CreateUser(ctx, "bob", quota); err != nil {
t.Fatalf("create bob: %v", err)
}
users, err = s.ListUsers(ctx)
if err != nil {
t.Fatalf("list: %v", err)
}
if len(users) != 2 {
t.Fatalf("got %d users, want 2", len(users))
}
// Ordered by created_at
if users[0].ID != "alice" {
t.Errorf("first user: got %q, want %q", users[0].ID, "alice")
}
if users[1].ID != "bob" {
t.Errorf("second user: got %q, want %q", users[1].ID, "bob")
}
}
func TestDeleteUser(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
t.Fatalf("create: %v", err)
}
if err := s.DeleteUser(ctx, "alice"); err != nil {
t.Fatalf("delete: %v", err)
}
// Should be gone
_, err := s.GetUser(ctx, "alice")
if !errors.Is(err, ErrNotFound) {
t.Errorf("got %v after delete, want ErrNotFound", err)
}
}
func TestDeleteUserNotFound(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
err := s.DeleteUser(ctx, "ghost")
if !errors.Is(err, ErrNotFound) {
t.Errorf("got %v, want ErrNotFound", err)
}
}
func TestDeleteUserCascadesAPIKeys(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
t.Fatalf("create user: %v", err)
}
plain, hash, err := GenerateAPIKey()
if err != nil {
t.Fatalf("generate key: %v", err)
}
if err := s.CreateAPIKey(ctx, "alice", model.RoleUser, hash); err != nil {
t.Fatalf("create key: %v", err)
}
// Delete user should cascade to api_keys
if err := s.DeleteUser(ctx, "alice"); err != nil {
t.Fatalf("delete: %v", err)
}
// Key should no longer validate
_, _, err = s.ValidateKey(ctx, plain)
if !errors.Is(err, ErrNotFound) {
t.Errorf("got %v after cascade delete, want ErrNotFound", err)
}
}
// --- Quota tests ---
func TestUpdateQuotas(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
t.Fatalf("create: %v", err)
}
newPods := 5
newCPU := 16
updated, err := s.UpdateQuotas(ctx, "alice", model.UpdateQuotasRequest{
MaxConcurrentPods: &newPods,
MaxCPUPerPod: &newCPU,
})
if err != nil {
t.Fatalf("update quotas: %v", err)
}
if updated.Quota.MaxConcurrentPods != 5 {
t.Errorf("max_concurrent_pods: got %d, want 5", updated.Quota.MaxConcurrentPods)
}
if updated.Quota.MaxCPUPerPod != 16 {
t.Errorf("max_cpu_per_pod: got %d, want 16", updated.Quota.MaxCPUPerPod)
}
// Unchanged fields should keep defaults
if updated.Quota.MaxRAMGBPerPod != 16 {
t.Errorf("max_ram_gb_per_pod: got %d, want 16", updated.Quota.MaxRAMGBPerPod)
}
}
func TestUpdateQuotasNoChanges(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
t.Fatalf("create: %v", err)
}
// Empty update should return current user
u, err := s.UpdateQuotas(ctx, "alice", model.UpdateQuotasRequest{})
if err != nil {
t.Fatalf("update no changes: %v", err)
}
if u.Quota != model.DefaultQuota() {
t.Errorf("quota changed unexpectedly: %+v", u.Quota)
}
}
func TestUpdateQuotasNotFound(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
newPods := 5
_, err := s.UpdateQuotas(ctx, "ghost", model.UpdateQuotasRequest{
MaxConcurrentPods: &newPods,
})
if !errors.Is(err, ErrNotFound) {
t.Errorf("got %v, want ErrNotFound", err)
}
}
func TestQuotaStorageAndRetrieval(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
customQuota := model.Quota{
MaxConcurrentPods: 10,
MaxCPUPerPod: 32,
MaxRAMGBPerPod: 64,
MonthlyPodHours: 1000,
MonthlyAIRequests: 50000,
}
if _, err := s.CreateUser(ctx, "power-user", customQuota); err != nil {
t.Fatalf("create: %v", err)
}
got, err := s.GetUser(ctx, "power-user")
if err != nil {
t.Fatalf("get: %v", err)
}
if got.Quota != customQuota {
t.Errorf("quota mismatch:\ngot: %+v\nwant: %+v", got.Quota, customQuota)
}
}
// --- API key tests ---
func TestGenerateAPIKey(t *testing.T) {
plain, hash, err := GenerateAPIKey()
if err != nil {
t.Fatalf("generate: %v", err)
}
// Check prefix
if !strings.HasPrefix(plain, model.APIKeyPrefix) {
t.Errorf("key %q missing prefix %q", plain, model.APIKeyPrefix)
}
// Check length: dpk_ (4) + 24 hex chars = 28
if len(plain) != 28 {
t.Errorf("key length: got %d, want 28", len(plain))
}
// Hash should be consistent
if HashKey(plain) != hash {
t.Error("hash mismatch")
}
// Two keys should be different
plain2, _, err := GenerateAPIKey()
if err != nil {
t.Fatalf("generate second: %v", err)
}
if plain == plain2 {
t.Error("two generated keys should not be equal")
}
}
func TestHashKeyConsistency(t *testing.T) {
key := "dpk_abcdef1234567890abcdef12"
h1 := HashKey(key)
h2 := HashKey(key)
if h1 != h2 {
t.Error("hash should be deterministic")
}
// SHA-256 produces 64 hex chars
if len(h1) != 64 {
t.Errorf("hash length: got %d, want 64", len(h1))
}
}
func TestCreateAndValidateAPIKey(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
t.Fatalf("create user: %v", err)
}
plain, hash, err := GenerateAPIKey()
if err != nil {
t.Fatalf("generate key: %v", err)
}
if err := s.CreateAPIKey(ctx, "alice", model.RoleUser, hash); err != nil {
t.Fatalf("create key: %v", err)
}
// Validate returns correct user and role
u, role, err := s.ValidateKey(ctx, plain)
if err != nil {
t.Fatalf("validate: %v", err)
}
if u.ID != "alice" {
t.Errorf("user: got %q, want %q", u.ID, "alice")
}
if role != model.RoleUser {
t.Errorf("role: got %q, want %q", role, model.RoleUser)
}
}
func TestValidateAdminKey(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
if _, err := s.CreateUser(ctx, "admin1", model.DefaultQuota()); err != nil {
t.Fatalf("create user: %v", err)
}
plain, hash, err := GenerateAPIKey()
if err != nil {
t.Fatalf("generate key: %v", err)
}
if err := s.CreateAPIKey(ctx, "admin1", model.RoleAdmin, hash); err != nil {
t.Fatalf("create key: %v", err)
}
_, role, err := s.ValidateKey(ctx, plain)
if err != nil {
t.Fatalf("validate: %v", err)
}
if role != model.RoleAdmin {
t.Errorf("role: got %q, want %q", role, model.RoleAdmin)
}
}
func TestValidateInvalidKey(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
_, _, err := s.ValidateKey(ctx, "dpk_invalid_key_here_nope")
if err == nil {
t.Fatal("expected error for invalid key")
}
if !errors.Is(err, ErrNotFound) {
t.Errorf("got %v, want ErrNotFound", err)
}
}
func TestMultipleKeysPerUser(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
t.Fatalf("create user: %v", err)
}
// Create two keys
plain1, hash1, _ := GenerateAPIKey()
plain2, hash2, _ := GenerateAPIKey()
if err := s.CreateAPIKey(ctx, "alice", model.RoleUser, hash1); err != nil {
t.Fatalf("create key 1: %v", err)
}
if err := s.CreateAPIKey(ctx, "alice", model.RoleAdmin, hash2); err != nil {
t.Fatalf("create key 2: %v", err)
}
// Both should validate
u1, role1, err := s.ValidateKey(ctx, plain1)
if err != nil {
t.Fatalf("validate key 1: %v", err)
}
if u1.ID != "alice" || role1 != model.RoleUser {
t.Errorf("key1: got user=%q role=%q, want alice/user", u1.ID, role1)
}
u2, role2, err := s.ValidateKey(ctx, plain2)
if err != nil {
t.Fatalf("validate key 2: %v", err)
}
if u2.ID != "alice" || role2 != model.RoleAdmin {
t.Errorf("key2: got user=%q role=%q, want alice/admin", u2.ID, role2)
}
}
func TestValidateKeyUpdatesLastUsed(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
t.Fatalf("create user: %v", err)
}
plain, hash, _ := GenerateAPIKey()
if err := s.CreateAPIKey(ctx, "alice", model.RoleUser, hash); err != nil {
t.Fatalf("create key: %v", err)
}
// Validate should update last_used_at
if _, _, err := s.ValidateKey(ctx, plain); err != nil {
t.Fatalf("validate: %v", err)
}
// Check last_used_at is set
var lastUsed *string
err := testPool.QueryRow(ctx,
`SELECT last_used_at::text FROM api_keys WHERE key_hash = $1`, hash).Scan(&lastUsed)
if err != nil {
t.Fatalf("query last_used_at: %v", err)
}
if lastUsed == nil {
t.Error("last_used_at should be set after validation")
}
}