build source
This commit is contained in:
commit
ee1fec43ed
4171 changed files with 1351288 additions and 0 deletions
BIN
internal/store/._runners.go
Normal file
BIN
internal/store/._runners.go
Normal file
Binary file not shown.
BIN
internal/store/._runners_test.go
Normal file
BIN
internal/store/._runners_test.go
Normal file
Binary file not shown.
BIN
internal/store/._store.go
Normal file
BIN
internal/store/._store.go
Normal file
Binary file not shown.
BIN
internal/store/._testhelper_test.go
Normal file
BIN
internal/store/._testhelper_test.go
Normal file
Binary file not shown.
BIN
internal/store/._usage.go
Normal file
BIN
internal/store/._usage.go
Normal file
Binary file not shown.
BIN
internal/store/._usage_test.go
Normal file
BIN
internal/store/._usage_test.go
Normal file
Binary file not shown.
BIN
internal/store/._users.go
Normal file
BIN
internal/store/._users.go
Normal file
Binary file not shown.
BIN
internal/store/._users_test.go
Normal file
BIN
internal/store/._users_test.go
Normal file
Binary file not shown.
227
internal/store/runners.go
Normal file
227
internal/store/runners.go
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
||||
)
|
||||
|
||||
// GenerateRunnerID creates a unique runner identifier.
|
||||
func GenerateRunnerID() (string, error) {
|
||||
b := make([]byte, 8)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
return "runner-" + hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// CreateRunner inserts a new runner record.
|
||||
func (s *Store) CreateRunner(ctx context.Context, r *model.Runner) error {
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO runners (id, user_id, repo_url, branch, tools, task, status, webhook_delivery_id, pod_name, cpu_req, mem_req, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)`,
|
||||
r.ID, r.User, r.RepoURL, r.Branch, r.Tools, r.Task, string(r.Status),
|
||||
r.WebhookDeliveryID, r.PodName, r.CPUReq, r.MemReq, r.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if isDuplicateError(err) {
|
||||
return fmt.Errorf("runner %q: %w", r.ID, ErrDuplicate)
|
||||
}
|
||||
return fmt.Errorf("insert runner: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRunner retrieves a runner by ID.
|
||||
func (s *Store) GetRunner(ctx context.Context, id string) (*model.Runner, error) {
|
||||
row := s.db.QueryRow(ctx,
|
||||
`SELECT id, user_id, repo_url, branch, tools, task, status,
|
||||
forgejo_runner_id, webhook_delivery_id, pod_name, cpu_req, mem_req,
|
||||
created_at, claimed_at, completed_at
|
||||
FROM runners WHERE id = $1`, id)
|
||||
|
||||
r, err := scanRunner(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, fmt.Errorf("runner %q: %w", id, ErrNotFound)
|
||||
}
|
||||
return nil, fmt.Errorf("query runner: %w", err)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// ListRunners returns runners, optionally filtered by user and/or status.
|
||||
func (s *Store) ListRunners(ctx context.Context, userFilter string, statusFilter string) ([]model.Runner, error) {
|
||||
query := `SELECT id, user_id, repo_url, branch, tools, task, status,
|
||||
forgejo_runner_id, webhook_delivery_id, pod_name, cpu_req, mem_req,
|
||||
created_at, claimed_at, completed_at
|
||||
FROM runners WHERE 1=1`
|
||||
var args []any
|
||||
argIdx := 1
|
||||
|
||||
if userFilter != "" {
|
||||
query += fmt.Sprintf(" AND user_id = $%d", argIdx)
|
||||
args = append(args, userFilter)
|
||||
argIdx++
|
||||
}
|
||||
if statusFilter != "" {
|
||||
query += fmt.Sprintf(" AND status = $%d", argIdx)
|
||||
args = append(args, statusFilter)
|
||||
argIdx++
|
||||
}
|
||||
query += " ORDER BY created_at DESC"
|
||||
|
||||
rows, err := s.db.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query runners: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var runners []model.Runner
|
||||
for rows.Next() {
|
||||
var r model.Runner
|
||||
var claimedAt, completedAt *time.Time
|
||||
if err := rows.Scan(&r.ID, &r.User, &r.RepoURL, &r.Branch, &r.Tools, &r.Task,
|
||||
&r.Status, &r.ForgejoRunnerID, &r.WebhookDeliveryID, &r.PodName, &r.CPUReq, &r.MemReq,
|
||||
&r.CreatedAt, &claimedAt, &completedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan runner: %w", err)
|
||||
}
|
||||
r.ClaimedAt = claimedAt
|
||||
r.CompletedAt = completedAt
|
||||
runners = append(runners, r)
|
||||
}
|
||||
return runners, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateRunnerStatus transitions a runner to a new status with state machine validation.
|
||||
func (s *Store) UpdateRunnerStatus(ctx context.Context, id string, newStatus model.RunnerStatus, forgejoRunnerID string) error {
|
||||
current, err := s.GetRunner(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !current.Status.CanTransitionTo(newStatus) {
|
||||
return fmt.Errorf("invalid transition from %s to %s", current.Status, newStatus)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
var claimedAt, completedAt *time.Time
|
||||
if newStatus == model.RunnerStatusJobClaimed {
|
||||
claimedAt = &now
|
||||
}
|
||||
if newStatus.IsTerminal() {
|
||||
completedAt = &now
|
||||
}
|
||||
|
||||
query := `UPDATE runners SET status = $1, forgejo_runner_id = CASE WHEN $2 = '' THEN forgejo_runner_id ELSE $2 END`
|
||||
args := []any{string(newStatus), forgejoRunnerID}
|
||||
argIdx := 3
|
||||
|
||||
if claimedAt != nil {
|
||||
query += fmt.Sprintf(", claimed_at = $%d", argIdx)
|
||||
args = append(args, *claimedAt)
|
||||
argIdx++
|
||||
}
|
||||
if completedAt != nil {
|
||||
query += fmt.Sprintf(", completed_at = $%d", argIdx)
|
||||
args = append(args, *completedAt)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += fmt.Sprintf(" WHERE id = $%d", argIdx)
|
||||
args = append(args, id)
|
||||
|
||||
tag, err := s.db.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update runner status: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("runner %q: %w", id, ErrNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRunner removes a runner record by ID.
|
||||
func (s *Store) DeleteRunner(ctx context.Context, id string) error {
|
||||
tag, err := s.db.Exec(ctx, `DELETE FROM runners WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete runner: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("runner %q: %w", id, ErrNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsDeliveryProcessed checks if a webhook delivery ID has already been processed.
|
||||
func (s *Store) IsDeliveryProcessed(ctx context.Context, deliveryID string) (bool, error) {
|
||||
if deliveryID == "" {
|
||||
return false, nil
|
||||
}
|
||||
var exists bool
|
||||
err := s.db.QueryRow(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM runners WHERE webhook_delivery_id = $1)`,
|
||||
deliveryID).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("check delivery: %w", err)
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetRunnersForCleanup returns runners in terminal states (completed/failed).
|
||||
func (s *Store) GetRunnersForCleanup(ctx context.Context) ([]model.Runner, error) {
|
||||
return s.ListRunners(ctx, "", "")
|
||||
}
|
||||
|
||||
// GetStaleRunners returns runners older than the given TTL that aren't in terminal/cleanup states.
|
||||
func (s *Store) GetStaleRunners(ctx context.Context, ttl time.Duration) ([]model.Runner, error) {
|
||||
cutoff := time.Now().UTC().Add(-ttl)
|
||||
rows, err := s.db.Query(ctx,
|
||||
`SELECT id, user_id, repo_url, branch, tools, task, status,
|
||||
forgejo_runner_id, webhook_delivery_id, pod_name, cpu_req, mem_req,
|
||||
created_at, claimed_at, completed_at
|
||||
FROM runners
|
||||
WHERE created_at < $1
|
||||
AND status NOT IN ('completed', 'failed', 'cleanup_pending')
|
||||
ORDER BY created_at`, cutoff)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query stale runners: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var runners []model.Runner
|
||||
for rows.Next() {
|
||||
var r model.Runner
|
||||
var claimedAt, completedAt *time.Time
|
||||
if err := rows.Scan(&r.ID, &r.User, &r.RepoURL, &r.Branch, &r.Tools, &r.Task,
|
||||
&r.Status, &r.ForgejoRunnerID, &r.WebhookDeliveryID, &r.PodName, &r.CPUReq, &r.MemReq,
|
||||
&r.CreatedAt, &claimedAt, &completedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan stale runner: %w", err)
|
||||
}
|
||||
r.ClaimedAt = claimedAt
|
||||
r.CompletedAt = completedAt
|
||||
runners = append(runners, r)
|
||||
}
|
||||
return runners, rows.Err()
|
||||
}
|
||||
|
||||
func scanRunner(row pgx.Row) (*model.Runner, error) {
|
||||
var r model.Runner
|
||||
var claimedAt, completedAt *time.Time
|
||||
err := row.Scan(&r.ID, &r.User, &r.RepoURL, &r.Branch, &r.Tools, &r.Task,
|
||||
&r.Status, &r.ForgejoRunnerID, &r.WebhookDeliveryID, &r.PodName, &r.CPUReq, &r.MemReq,
|
||||
&r.CreatedAt, &claimedAt, &completedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.ClaimedAt = claimedAt
|
||||
r.CompletedAt = completedAt
|
||||
return &r, nil
|
||||
}
|
||||
385
internal/store/runners_test.go
Normal file
385
internal/store/runners_test.go
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
||||
)
|
||||
|
||||
func createTestUser(t *testing.T, s *Store, id string) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
_, err := s.CreateUser(ctx, id, model.DefaultQuota())
|
||||
if err != nil {
|
||||
t.Fatalf("create test user %q: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRunner(user, id string) *model.Runner {
|
||||
return &model.Runner{
|
||||
ID: id,
|
||||
User: user,
|
||||
RepoURL: "ilia/test-repo",
|
||||
Branch: "main",
|
||||
Tools: "go",
|
||||
Task: "implement feature",
|
||||
Status: model.RunnerStatusReceived,
|
||||
CPUReq: "2",
|
||||
MemReq: "4Gi",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func cleanRunners(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
if _, err := testPool.Exec(ctx, "DELETE FROM runners"); err != nil {
|
||||
t.Fatalf("clean runners: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRunner(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
r := newTestRunner("alice", "runner-abc123")
|
||||
if err := s.CreateRunner(ctx, r); err != nil {
|
||||
t.Fatalf("create runner: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.GetRunner(ctx, "runner-abc123")
|
||||
if err != nil {
|
||||
t.Fatalf("get runner: %v", err)
|
||||
}
|
||||
if got.ID != "runner-abc123" {
|
||||
t.Errorf("got id %q, want %q", got.ID, "runner-abc123")
|
||||
}
|
||||
if got.User != "alice" {
|
||||
t.Errorf("got user %q, want %q", got.User, "alice")
|
||||
}
|
||||
if got.Status != model.RunnerStatusReceived {
|
||||
t.Errorf("got status %q, want %q", got.Status, model.RunnerStatusReceived)
|
||||
}
|
||||
if got.RepoURL != "ilia/test-repo" {
|
||||
t.Errorf("got repo %q, want %q", got.RepoURL, "ilia/test-repo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRunner_Duplicate(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
r := newTestRunner("alice", "runner-dup1")
|
||||
if err := s.CreateRunner(ctx, r); err != nil {
|
||||
t.Fatalf("first create: %v", err)
|
||||
}
|
||||
err := s.CreateRunner(ctx, r)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate runner")
|
||||
}
|
||||
if !errors.Is(err, ErrDuplicate) {
|
||||
t.Errorf("got %v, want ErrDuplicate", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRunner_NotFound(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.GetRunner(ctx, "nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent runner")
|
||||
}
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("got %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListRunners(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
createTestUser(t, s, "bob")
|
||||
ctx := context.Background()
|
||||
|
||||
r1 := newTestRunner("alice", "runner-list1")
|
||||
r2 := newTestRunner("alice", "runner-list2")
|
||||
r3 := newTestRunner("bob", "runner-list3")
|
||||
for _, r := range []*model.Runner{r1, r2, r3} {
|
||||
if err := s.CreateRunner(ctx, r); err != nil {
|
||||
t.Fatalf("create runner %s: %v", r.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
all, err := s.ListRunners(ctx, "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("list all: %v", err)
|
||||
}
|
||||
if len(all) != 3 {
|
||||
t.Fatalf("expected 3 runners, got %d", len(all))
|
||||
}
|
||||
|
||||
aliceRunners, err := s.ListRunners(ctx, "alice", "")
|
||||
if err != nil {
|
||||
t.Fatalf("list alice: %v", err)
|
||||
}
|
||||
if len(aliceRunners) != 2 {
|
||||
t.Fatalf("expected 2 runners for alice, got %d", len(aliceRunners))
|
||||
}
|
||||
|
||||
byStatus, err := s.ListRunners(ctx, "", "received")
|
||||
if err != nil {
|
||||
t.Fatalf("list by status: %v", err)
|
||||
}
|
||||
if len(byStatus) != 3 {
|
||||
t.Fatalf("expected 3 received runners, got %d", len(byStatus))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRunnerStatus_ValidTransition(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
r := newTestRunner("alice", "runner-trans1")
|
||||
if err := s.CreateRunner(ctx, r); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
if err := s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusPodCreating, ""); err != nil {
|
||||
t.Fatalf("transition to pod_creating: %v", err)
|
||||
}
|
||||
|
||||
got, _ := s.GetRunner(ctx, r.ID)
|
||||
if got.Status != model.RunnerStatusPodCreating {
|
||||
t.Errorf("expected pod_creating, got %s", got.Status)
|
||||
}
|
||||
|
||||
if err := s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusRunnerRegistered, "forgejo-42"); err != nil {
|
||||
t.Fatalf("transition to runner_registered: %v", err)
|
||||
}
|
||||
|
||||
got, _ = s.GetRunner(ctx, r.ID)
|
||||
if got.Status != model.RunnerStatusRunnerRegistered {
|
||||
t.Errorf("expected runner_registered, got %s", got.Status)
|
||||
}
|
||||
if got.ForgejoRunnerID != "forgejo-42" {
|
||||
t.Errorf("expected forgejo_runner_id 'forgejo-42', got %q", got.ForgejoRunnerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRunnerStatus_InvalidTransition(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
r := newTestRunner("alice", "runner-invalid1")
|
||||
if err := s.CreateRunner(ctx, r); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
err := s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusCompleted, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid transition received -> completed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRunnerStatus_SetsClaimedAt(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
r := newTestRunner("alice", "runner-claimed1")
|
||||
if err := s.CreateRunner(ctx, r); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusPodCreating, "")
|
||||
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusRunnerRegistered, "")
|
||||
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusJobClaimed, "")
|
||||
|
||||
got, _ := s.GetRunner(ctx, r.ID)
|
||||
if got.ClaimedAt == nil {
|
||||
t.Error("expected claimed_at to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRunnerStatus_SetsCompletedAt(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
r := newTestRunner("alice", "runner-complete1")
|
||||
if err := s.CreateRunner(ctx, r); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusPodCreating, "")
|
||||
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusRunnerRegistered, "")
|
||||
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusJobClaimed, "")
|
||||
_ = s.UpdateRunnerStatus(ctx, r.ID, model.RunnerStatusCompleted, "")
|
||||
|
||||
got, _ := s.GetRunner(ctx, r.ID)
|
||||
if got.CompletedAt == nil {
|
||||
t.Error("expected completed_at to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRunner(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
r := newTestRunner("alice", "runner-del1")
|
||||
if err := s.CreateRunner(ctx, r); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
if err := s.DeleteRunner(ctx, r.ID); err != nil {
|
||||
t.Fatalf("delete: %v", err)
|
||||
}
|
||||
|
||||
_, err := s.GetRunner(ctx, r.ID)
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound after delete, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRunner_NotFound(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := s.DeleteRunner(ctx, "nonexistent")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDeliveryProcessed(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
isDupe, err := s.IsDeliveryProcessed(ctx, "delivery-1")
|
||||
if err != nil {
|
||||
t.Fatalf("check delivery: %v", err)
|
||||
}
|
||||
if isDupe {
|
||||
t.Error("expected delivery-1 to not be processed yet")
|
||||
}
|
||||
|
||||
r := newTestRunner("alice", "runner-dedupe1")
|
||||
r.WebhookDeliveryID = "delivery-1"
|
||||
if err := s.CreateRunner(ctx, r); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
isDupe, err = s.IsDeliveryProcessed(ctx, "delivery-1")
|
||||
if err != nil {
|
||||
t.Fatalf("check delivery after create: %v", err)
|
||||
}
|
||||
if !isDupe {
|
||||
t.Error("expected delivery-1 to be processed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDeliveryProcessed_EmptyID(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
ctx := context.Background()
|
||||
|
||||
isDupe, err := s.IsDeliveryProcessed(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("check empty delivery: %v", err)
|
||||
}
|
||||
if isDupe {
|
||||
t.Error("expected empty delivery ID to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookDeliveryDedupe_UniqueConstraint(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
r1 := newTestRunner("alice", "runner-uniq1")
|
||||
r1.WebhookDeliveryID = "webhook-abc"
|
||||
if err := s.CreateRunner(ctx, r1); err != nil {
|
||||
t.Fatalf("first insert: %v", err)
|
||||
}
|
||||
|
||||
r2 := newTestRunner("alice", "runner-uniq2")
|
||||
r2.WebhookDeliveryID = "webhook-abc"
|
||||
err := s.CreateRunner(ctx, r2)
|
||||
if err == nil {
|
||||
t.Fatal("expected duplicate error for same webhook_delivery_id")
|
||||
}
|
||||
if !errors.Is(err, ErrDuplicate) {
|
||||
t.Errorf("expected ErrDuplicate, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStaleRunners(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
cleanRunners(t)
|
||||
createTestUser(t, s, "alice")
|
||||
ctx := context.Background()
|
||||
|
||||
old := newTestRunner("alice", "runner-stale1")
|
||||
old.CreatedAt = time.Now().UTC().Add(-3 * time.Hour)
|
||||
if err := s.CreateRunner(ctx, old); err != nil {
|
||||
t.Fatalf("create old runner: %v", err)
|
||||
}
|
||||
|
||||
fresh := newTestRunner("alice", "runner-fresh1")
|
||||
fresh.CreatedAt = time.Now().UTC()
|
||||
if err := s.CreateRunner(ctx, fresh); err != nil {
|
||||
t.Fatalf("create fresh runner: %v", err)
|
||||
}
|
||||
|
||||
stale, err := s.GetStaleRunners(ctx, 2*time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("get stale: %v", err)
|
||||
}
|
||||
if len(stale) != 1 {
|
||||
t.Fatalf("expected 1 stale runner, got %d", len(stale))
|
||||
}
|
||||
if stale[0].ID != "runner-stale1" {
|
||||
t.Errorf("expected stale runner runner-stale1, got %s", stale[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRunnerID(t *testing.T) {
|
||||
id1, err := GenerateRunnerID()
|
||||
if err != nil {
|
||||
t.Fatalf("generate: %v", err)
|
||||
}
|
||||
if len(id1) == 0 {
|
||||
t.Error("expected non-empty ID")
|
||||
}
|
||||
|
||||
id2, _ := GenerateRunnerID()
|
||||
if id1 == id2 {
|
||||
t.Error("expected unique IDs")
|
||||
}
|
||||
|
||||
if id1[:7] != "runner-" {
|
||||
t.Errorf("expected runner- prefix, got %s", id1[:7])
|
||||
}
|
||||
}
|
||||
100
internal/store/store.go
Normal file
100
internal/store/store.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// DBTX is the interface for database query operations.
|
||||
// Satisfied by *pgxpool.Pool, *pgx.Conn, and pgx.Tx.
|
||||
type DBTX interface {
|
||||
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
|
||||
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
||||
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||
}
|
||||
|
||||
// Store provides database operations for the dev-pod API.
|
||||
type Store struct {
|
||||
db DBTX
|
||||
}
|
||||
|
||||
// New creates a Store with the given database connection.
|
||||
func New(db DBTX) *Store {
|
||||
return &Store{db: db}
|
||||
}
|
||||
|
||||
// NewPool creates a new pgx connection pool.
|
||||
func NewPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) {
|
||||
pool, err := pgxpool.New(ctx, databaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to database: %w", err)
|
||||
}
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("ping database: %w", err)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Migrate runs database migrations to create required tables.
|
||||
func (s *Store) Migrate(ctx context.Context) error {
|
||||
migrations := []string{
|
||||
`CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
max_concurrent_pods INTEGER NOT NULL DEFAULT 3,
|
||||
max_cpu_per_pod INTEGER NOT NULL DEFAULT 8,
|
||||
max_ram_gb_per_pod INTEGER NOT NULL DEFAULT 16,
|
||||
monthly_pod_hours INTEGER NOT NULL DEFAULT 500,
|
||||
monthly_ai_requests INTEGER NOT NULL DEFAULT 10000
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS api_keys (
|
||||
key_hash TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
role TEXT NOT NULL DEFAULT 'user',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
last_used_at TIMESTAMPTZ
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS usage_records (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
pod_name TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
value DOUBLE PRECISION,
|
||||
recorded_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_usage_user_month ON usage_records(user_id, recorded_at)`,
|
||||
`ALTER TABLE users ADD COLUMN IF NOT EXISTS forgejo_token TEXT NOT NULL DEFAULT ''`,
|
||||
`CREATE TABLE IF NOT EXISTS runners (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL REFERENCES users(id),
|
||||
repo_url TEXT NOT NULL,
|
||||
branch TEXT NOT NULL DEFAULT 'main',
|
||||
tools TEXT NOT NULL DEFAULT '',
|
||||
task TEXT NOT NULL DEFAULT '',
|
||||
status TEXT NOT NULL DEFAULT 'received',
|
||||
forgejo_runner_id TEXT NOT NULL DEFAULT '',
|
||||
webhook_delivery_id TEXT NOT NULL DEFAULT '',
|
||||
pod_name TEXT NOT NULL DEFAULT '',
|
||||
cpu_req TEXT NOT NULL DEFAULT '2',
|
||||
mem_req TEXT NOT NULL DEFAULT '4Gi',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
claimed_at TIMESTAMPTZ,
|
||||
completed_at TIMESTAMPTZ
|
||||
)`,
|
||||
`CREATE UNIQUE INDEX IF NOT EXISTS idx_runners_webhook_delivery
|
||||
ON runners(webhook_delivery_id) WHERE webhook_delivery_id != ''`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_runners_status ON runners(status)`,
|
||||
`ALTER TABLE users ADD COLUMN IF NOT EXISTS tailscale_key TEXT NOT NULL DEFAULT ''`,
|
||||
}
|
||||
for _, m := range migrations {
|
||||
if _, err := s.db.Exec(ctx, m); err != nil {
|
||||
return fmt.Errorf("run migration: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
57
internal/store/testhelper_test.go
Normal file
57
internal/store/testhelper_test.go
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var testPool *pgxpool.Pool
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
postgres := embeddedpostgres.NewDatabase(
|
||||
embeddedpostgres.DefaultConfig().
|
||||
Port(15432).
|
||||
Database("devpod_test"),
|
||||
)
|
||||
if err := postgres.Start(); err != nil {
|
||||
log.Fatalf("start embedded postgres: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
pool, err := pgxpool.New(ctx, "postgres://postgres:postgres@localhost:15432/devpod_test?sslmode=disable")
|
||||
if err != nil {
|
||||
postgres.Stop()
|
||||
log.Fatalf("connect to test db: %v", err)
|
||||
}
|
||||
testPool = pool
|
||||
|
||||
code := m.Run()
|
||||
|
||||
pool.Close()
|
||||
if err := postgres.Stop(); err != nil {
|
||||
log.Printf("stop embedded postgres: %v", err)
|
||||
}
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func newTestStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
s := New(testPool)
|
||||
if err := s.Migrate(ctx); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
// Clean tables for test isolation (order matters: foreign keys)
|
||||
for _, table := range []string{"runners", "usage_records", "api_keys", "users"} {
|
||||
if _, err := testPool.Exec(ctx, fmt.Sprintf("DELETE FROM %s", table)); err != nil {
|
||||
t.Fatalf("clean %s: %v", table, err)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
178
internal/store/usage.go
Normal file
178
internal/store/usage.go
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
||||
)
|
||||
|
||||
// RecordPodStart records a pod start event for usage tracking.
|
||||
func (s *Store) RecordPodStart(ctx context.Context, userID, podName string) error {
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ($1, $2, $3, 0, NOW())`,
|
||||
userID, podName, model.EventPodStart)
|
||||
if err != nil {
|
||||
return fmt.Errorf("record pod start: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordPodStop records a pod stop event for usage tracking.
|
||||
func (s *Store) RecordPodStop(ctx context.Context, userID, podName string) error {
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ($1, $2, $3, 0, NOW())`,
|
||||
userID, podName, model.EventPodStop)
|
||||
if err != nil {
|
||||
return fmt.Errorf("record pod stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordResourceSample records periodic CPU and memory usage samples.
|
||||
// cpuMillicores is CPU usage in millicores, memBytes is memory in bytes.
|
||||
func (s *Store) RecordResourceSample(ctx context.Context, userID, podName string, cpuMillicores, memBytes float64) error {
|
||||
now := time.Now().UTC()
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ($1, $2, $3, $4, $5), ($1, $2, $6, $7, $5)`,
|
||||
userID, podName, model.EventCPUSample, cpuMillicores, now,
|
||||
model.EventMemSample, memBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("record resource sample: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordAIRequest records an AI proxy request event.
|
||||
func (s *Store) RecordAIRequest(ctx context.Context, userID string) error {
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ($1, '', $2, 1, NOW())`,
|
||||
userID, model.EventAIRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("record ai request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUsage returns aggregated usage for a user for a given month.
|
||||
func (s *Store) GetUsage(ctx context.Context, userID string, year int, month time.Month, monthlyBudget int) (*model.UsageSummary, error) {
|
||||
return s.getUsage(ctx, userID, year, month, monthlyBudget, time.Now().UTC())
|
||||
}
|
||||
|
||||
// getUsage is the internal implementation that accepts a "now" parameter for testing.
|
||||
func (s *Store) getUsage(ctx context.Context, userID string, year int, month time.Month, monthlyBudget int, now time.Time) (*model.UsageSummary, error) {
|
||||
monthStart := time.Date(year, month, 1, 0, 0, 0, 0, time.UTC)
|
||||
monthEnd := monthStart.AddDate(0, 1, 0)
|
||||
|
||||
// Calculate pod-hours from pod_start/pod_stop pairs.
|
||||
// For each pod_start, find the nearest subsequent pod_stop for the same user/pod.
|
||||
// If no stop found, use "now" (pod is still running).
|
||||
// Sessions that cross month boundaries are clamped to [monthStart, monthEnd].
|
||||
var podHours float64
|
||||
err := s.db.QueryRow(ctx, `
|
||||
WITH sessions AS (
|
||||
SELECT
|
||||
GREATEST(recorded_at, $2) AS start_time,
|
||||
(
|
||||
SELECT MIN(r2.recorded_at)
|
||||
FROM usage_records r2
|
||||
WHERE r2.user_id = r.user_id
|
||||
AND r2.pod_name = r.pod_name
|
||||
AND r2.event_type = 'pod_stop'
|
||||
AND r2.recorded_at > r.recorded_at
|
||||
) AS stop_time
|
||||
FROM usage_records r
|
||||
WHERE r.user_id = $1
|
||||
AND r.event_type = 'pod_start'
|
||||
AND r.recorded_at < $3
|
||||
)
|
||||
SELECT COALESCE(SUM(
|
||||
EXTRACT(EPOCH FROM LEAST(COALESCE(stop_time, $4), $3) - start_time) / 3600.0
|
||||
), 0)
|
||||
FROM sessions
|
||||
WHERE COALESCE(stop_time, $4) > $2`,
|
||||
userID, monthStart, monthEnd, now).Scan(&podHours)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("calculate pod hours: %w", err)
|
||||
}
|
||||
|
||||
// CPU-hours from resource samples.
|
||||
// Each sample represents 60s of CPU at the sampled rate (millicores).
|
||||
// CPU-hours = sum(millicores) * 60s / 3600s / 1000m = sum(millicores) / 60000
|
||||
var cpuHours float64
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(SUM(value) / 60000.0, 0)
|
||||
FROM usage_records
|
||||
WHERE user_id = $1 AND event_type = $2
|
||||
AND recorded_at >= $3 AND recorded_at < $4`,
|
||||
userID, model.EventCPUSample, monthStart, monthEnd).Scan(&cpuHours)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("calculate cpu hours: %w", err)
|
||||
}
|
||||
|
||||
// Count AI requests.
|
||||
var aiRequests int64
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM usage_records
|
||||
WHERE user_id = $1 AND event_type = $2
|
||||
AND recorded_at >= $3 AND recorded_at < $4`,
|
||||
userID, model.EventAIRequest, monthStart, monthEnd).Scan(&aiRequests)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count ai requests: %w", err)
|
||||
}
|
||||
|
||||
var budgetUsedPct float64
|
||||
if monthlyBudget > 0 {
|
||||
budgetUsedPct = (podHours / float64(monthlyBudget)) * 100
|
||||
}
|
||||
|
||||
return &model.UsageSummary{
|
||||
PodHours: podHours,
|
||||
CPUHours: cpuHours,
|
||||
AIRequests: aiRequests,
|
||||
BudgetUsedPct: budgetUsedPct,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDailyUsage returns daily usage breakdown for a user for a given month.
|
||||
// Pod-hours per day are estimated from CPU sample count (each sample = 1/60 hour of pod time).
|
||||
func (s *Store) GetDailyUsage(ctx context.Context, userID string, year int, month time.Month) ([]model.DailyUsage, error) {
|
||||
monthStart := time.Date(year, month, 1, 0, 0, 0, 0, time.UTC)
|
||||
monthEnd := monthStart.AddDate(0, 1, 0)
|
||||
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT
|
||||
DATE(recorded_at) AS day,
|
||||
COALESCE(SUM(CASE WHEN event_type = 'cpu_sample' THEN 1.0/60.0 ELSE 0 END), 0) AS pod_hours,
|
||||
COALESCE(SUM(CASE WHEN event_type = 'cpu_sample' THEN value / 60000.0 ELSE 0 END), 0) AS cpu_hours,
|
||||
COUNT(*) FILTER (WHERE event_type = 'ai_request') AS ai_requests
|
||||
FROM usage_records
|
||||
WHERE user_id = $1
|
||||
AND recorded_at >= $2 AND recorded_at < $3
|
||||
AND event_type IN ('cpu_sample', 'ai_request')
|
||||
GROUP BY DATE(recorded_at)
|
||||
ORDER BY DATE(recorded_at)`,
|
||||
userID, monthStart, monthEnd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query daily usage: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var result []model.DailyUsage
|
||||
for rows.Next() {
|
||||
var d model.DailyUsage
|
||||
var day time.Time
|
||||
if err := rows.Scan(&day, &d.PodHours, &d.CPUHours, &d.AIRequests); err != nil {
|
||||
return nil, fmt.Errorf("scan daily usage: %w", err)
|
||||
}
|
||||
d.Date = day.Format("2006-01-02")
|
||||
result = append(result, d)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
399
internal/store/usage_test.go
Normal file
399
internal/store/usage_test.go
Normal file
|
|
@ -0,0 +1,399 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
||||
)
|
||||
|
||||
func TestRecordPodStart(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
if err := s.RecordPodStart(ctx, "alice", "main"); err != nil {
|
||||
t.Fatalf("record pod start: %v", err)
|
||||
}
|
||||
|
||||
var count int
|
||||
if err := testPool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'pod_start'`).Scan(&count); err != nil {
|
||||
t.Fatalf("query: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 pod_start record, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordPodStop(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
if err := s.RecordPodStop(ctx, "alice", "main"); err != nil {
|
||||
t.Fatalf("record pod stop: %v", err)
|
||||
}
|
||||
|
||||
var count int
|
||||
if err := testPool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'pod_stop'`).Scan(&count); err != nil {
|
||||
t.Fatalf("query: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 pod_stop record, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordResourceSample(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
if err := s.RecordResourceSample(ctx, "alice", "main", 2500, 1073741824); err != nil {
|
||||
t.Fatalf("record resource sample: %v", err)
|
||||
}
|
||||
|
||||
// Should have both cpu and mem records
|
||||
var cpuCount, memCount int
|
||||
testPool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'cpu_sample'`).Scan(&cpuCount)
|
||||
testPool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'mem_sample'`).Scan(&memCount)
|
||||
|
||||
if cpuCount != 1 {
|
||||
t.Errorf("expected 1 cpu_sample, got %d", cpuCount)
|
||||
}
|
||||
if memCount != 1 {
|
||||
t.Errorf("expected 1 mem_sample, got %d", memCount)
|
||||
}
|
||||
|
||||
// Verify CPU value
|
||||
var cpuValue float64
|
||||
testPool.QueryRow(ctx,
|
||||
`SELECT value FROM usage_records WHERE user_id = 'alice' AND event_type = 'cpu_sample'`).Scan(&cpuValue)
|
||||
if cpuValue != 2500 {
|
||||
t.Errorf("expected cpu value 2500, got %f", cpuValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAIRequest(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := s.RecordAIRequest(ctx, "alice"); err != nil {
|
||||
t.Fatalf("record ai request %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
var count int
|
||||
testPool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM usage_records WHERE user_id = 'alice' AND event_type = 'ai_request'`).Scan(&count)
|
||||
if count != 5 {
|
||||
t.Errorf("expected 5 ai_request records, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsage_Empty(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500,
|
||||
time.Date(2026, time.March, 15, 12, 0, 0, 0, time.UTC))
|
||||
if err != nil {
|
||||
t.Fatalf("get usage: %v", err)
|
||||
}
|
||||
if usage.PodHours != 0 {
|
||||
t.Errorf("expected 0 pod hours, got %f", usage.PodHours)
|
||||
}
|
||||
if usage.CPUHours != 0 {
|
||||
t.Errorf("expected 0 cpu hours, got %f", usage.CPUHours)
|
||||
}
|
||||
if usage.AIRequests != 0 {
|
||||
t.Errorf("expected 0 ai requests, got %d", usage.AIRequests)
|
||||
}
|
||||
if usage.BudgetUsedPct != 0 {
|
||||
t.Errorf("expected 0 budget used, got %f", usage.BudgetUsedPct)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsage_PodHours_CompletedSession(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
start := time.Date(2026, time.March, 15, 10, 0, 0, 0, time.UTC)
|
||||
stop := time.Date(2026, time.March, 15, 12, 0, 0, 0, time.UTC) // 2 hours later
|
||||
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'pod_start', 0, $1)`, start)
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'pod_stop', 0, $1)`, stop)
|
||||
|
||||
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, stop)
|
||||
if err != nil {
|
||||
t.Fatalf("get usage: %v", err)
|
||||
}
|
||||
|
||||
if math.Abs(usage.PodHours-2.0) > 0.01 {
|
||||
t.Errorf("expected ~2.0 pod hours, got %f", usage.PodHours)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsage_PodHours_RunningPod(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
start := time.Date(2026, time.March, 15, 10, 0, 0, 0, time.UTC)
|
||||
now := time.Date(2026, time.March, 15, 13, 0, 0, 0, time.UTC) // 3 hours later, no stop
|
||||
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'pod_start', 0, $1)`, start)
|
||||
|
||||
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
|
||||
if err != nil {
|
||||
t.Fatalf("get usage: %v", err)
|
||||
}
|
||||
|
||||
if math.Abs(usage.PodHours-3.0) > 0.01 {
|
||||
t.Errorf("expected ~3.0 pod hours (running pod), got %f", usage.PodHours)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsage_PodHours_MultiplePods(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
// Pod 1: 2 hours
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'pod1', 'pod_start', 0, $1)`,
|
||||
time.Date(2026, time.March, 10, 8, 0, 0, 0, time.UTC))
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'pod1', 'pod_stop', 0, $1)`,
|
||||
time.Date(2026, time.March, 10, 10, 0, 0, 0, time.UTC))
|
||||
|
||||
// Pod 2: 5 hours
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'pod2', 'pod_start', 0, $1)`,
|
||||
time.Date(2026, time.March, 12, 14, 0, 0, 0, time.UTC))
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'pod2', 'pod_stop', 0, $1)`,
|
||||
time.Date(2026, time.March, 12, 19, 0, 0, 0, time.UTC))
|
||||
|
||||
now := time.Date(2026, time.March, 20, 0, 0, 0, 0, time.UTC)
|
||||
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
|
||||
if err != nil {
|
||||
t.Fatalf("get usage: %v", err)
|
||||
}
|
||||
|
||||
if math.Abs(usage.PodHours-7.0) > 0.01 {
|
||||
t.Errorf("expected ~7.0 pod hours (2+5), got %f", usage.PodHours)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsage_CPUHours(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
// 10 samples of 3000 millicores
|
||||
// CPU-hours = 10 * 3000 / 60000 = 0.5
|
||||
for i := 0; i < 10; i++ {
|
||||
ts := time.Date(2026, time.March, 15, 10, i, 0, 0, time.UTC)
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'cpu_sample', 3000, $1)`, ts)
|
||||
}
|
||||
|
||||
now := time.Date(2026, time.March, 20, 0, 0, 0, 0, time.UTC)
|
||||
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
|
||||
if err != nil {
|
||||
t.Fatalf("get usage: %v", err)
|
||||
}
|
||||
|
||||
if math.Abs(usage.CPUHours-0.5) > 0.001 {
|
||||
t.Errorf("expected ~0.5 cpu hours, got %f", usage.CPUHours)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsage_AIRequests(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
for i := 0; i < 42; i++ {
|
||||
ts := time.Date(2026, time.March, 15, 10, 0, i, 0, time.UTC)
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', '', 'ai_request', 1, $1)`, ts)
|
||||
}
|
||||
|
||||
now := time.Date(2026, time.March, 20, 0, 0, 0, 0, time.UTC)
|
||||
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
|
||||
if err != nil {
|
||||
t.Fatalf("get usage: %v", err)
|
||||
}
|
||||
|
||||
if usage.AIRequests != 42 {
|
||||
t.Errorf("expected 42 ai requests, got %d", usage.AIRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsage_BudgetPercent(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
// 100 hours used, budget = 500 → 20%
|
||||
start := time.Date(2026, time.March, 1, 0, 0, 0, 0, time.UTC)
|
||||
stop := time.Date(2026, time.March, 5, 4, 0, 0, 0, time.UTC) // 100 hours
|
||||
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'pod_start', 0, $1)`, start)
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'pod_stop', 0, $1)`, stop)
|
||||
|
||||
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, stop)
|
||||
if err != nil {
|
||||
t.Fatalf("get usage: %v", err)
|
||||
}
|
||||
|
||||
if math.Abs(usage.BudgetUsedPct-20.0) > 0.1 {
|
||||
t.Errorf("expected ~20%% budget used, got %f", usage.BudgetUsedPct)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsage_DifferentMonths(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
// March data
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'pod_start', 0, $1)`,
|
||||
time.Date(2026, time.March, 15, 10, 0, 0, 0, time.UTC))
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'pod_stop', 0, $1)`,
|
||||
time.Date(2026, time.March, 15, 12, 0, 0, 0, time.UTC))
|
||||
|
||||
// February data should not be counted
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'other', 'pod_start', 0, $1)`,
|
||||
time.Date(2026, time.February, 15, 10, 0, 0, 0, time.UTC))
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'other', 'pod_stop', 0, $1)`,
|
||||
time.Date(2026, time.February, 15, 20, 0, 0, 0, time.UTC))
|
||||
|
||||
now := time.Date(2026, time.March, 20, 0, 0, 0, 0, time.UTC)
|
||||
usage, err := s.getUsage(ctx, "alice", 2026, time.March, 500, now)
|
||||
if err != nil {
|
||||
t.Fatalf("get usage: %v", err)
|
||||
}
|
||||
|
||||
// Should only count March: 2 hours
|
||||
if math.Abs(usage.PodHours-2.0) > 0.01 {
|
||||
t.Errorf("expected ~2.0 pod hours (March only), got %f", usage.PodHours)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDailyUsage(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
// Day 1: 5 CPU samples of 2000m + 3 AI requests
|
||||
for i := 0; i < 5; i++ {
|
||||
ts := time.Date(2026, time.March, 10, 10, i, 0, 0, time.UTC)
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'cpu_sample', 2000, $1)`, ts)
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
ts := time.Date(2026, time.March, 10, 11, i, 0, 0, time.UTC)
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', '', 'ai_request', 1, $1)`, ts)
|
||||
}
|
||||
|
||||
// Day 2: 10 CPU samples of 4000m + 7 AI requests
|
||||
for i := 0; i < 10; i++ {
|
||||
ts := time.Date(2026, time.March, 11, 14, i, 0, 0, time.UTC)
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', 'main', 'cpu_sample', 4000, $1)`, ts)
|
||||
}
|
||||
for i := 0; i < 7; i++ {
|
||||
ts := time.Date(2026, time.March, 11, 15, i, 0, 0, time.UTC)
|
||||
testPool.Exec(ctx,
|
||||
`INSERT INTO usage_records (user_id, pod_name, event_type, value, recorded_at)
|
||||
VALUES ('alice', '', 'ai_request', 1, $1)`, ts)
|
||||
}
|
||||
|
||||
daily, err := s.GetDailyUsage(ctx, "alice", 2026, time.March)
|
||||
if err != nil {
|
||||
t.Fatalf("get daily usage: %v", err)
|
||||
}
|
||||
|
||||
if len(daily) != 2 {
|
||||
t.Fatalf("expected 2 days, got %d", len(daily))
|
||||
}
|
||||
|
||||
// Day 1
|
||||
if daily[0].Date != "2026-03-10" {
|
||||
t.Errorf("day 1 date: got %q, want 2026-03-10", daily[0].Date)
|
||||
}
|
||||
// pod_hours = 5 samples * (1/60) = 5/60
|
||||
if math.Abs(daily[0].PodHours-5.0/60.0) > 0.001 {
|
||||
t.Errorf("day 1 pod hours: got %f, want %f", daily[0].PodHours, 5.0/60.0)
|
||||
}
|
||||
// cpu_hours = 5 * 2000 / 60000 = 10000/60000
|
||||
if math.Abs(daily[0].CPUHours-10000.0/60000.0) > 0.001 {
|
||||
t.Errorf("day 1 cpu hours: got %f, want %f", daily[0].CPUHours, 10000.0/60000.0)
|
||||
}
|
||||
if daily[0].AIRequests != 3 {
|
||||
t.Errorf("day 1 ai requests: got %d, want 3", daily[0].AIRequests)
|
||||
}
|
||||
|
||||
// Day 2
|
||||
if daily[1].Date != "2026-03-11" {
|
||||
t.Errorf("day 2 date: got %q, want 2026-03-11", daily[1].Date)
|
||||
}
|
||||
if daily[1].AIRequests != 7 {
|
||||
t.Errorf("day 2 ai requests: got %d, want 7", daily[1].AIRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDailyUsage_Empty(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
s.CreateUser(ctx, "alice", model.DefaultQuota())
|
||||
|
||||
daily, err := s.GetDailyUsage(ctx, "alice", 2026, time.March)
|
||||
if err != nil {
|
||||
t.Fatalf("get daily usage: %v", err)
|
||||
}
|
||||
|
||||
if len(daily) != 0 {
|
||||
t.Errorf("expected 0 days, got %d", len(daily))
|
||||
}
|
||||
}
|
||||
280
internal/store/users.go
Normal file
280
internal/store/users.go
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned when a requested resource does not exist.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// ErrDuplicate is returned when a resource already exists.
|
||||
var ErrDuplicate = errors.New("already exists")
|
||||
|
||||
// CreateUser inserts a new user with the given quota.
|
||||
func (s *Store) CreateUser(ctx context.Context, id string, quota model.Quota) (*model.User, error) {
|
||||
now := time.Now().UTC()
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO users (id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
||||
id, now, quota.MaxConcurrentPods, quota.MaxCPUPerPod, quota.MaxRAMGBPerPod,
|
||||
quota.MonthlyPodHours, quota.MonthlyAIRequests,
|
||||
)
|
||||
if err != nil {
|
||||
if isDuplicateError(err) {
|
||||
return nil, fmt.Errorf("user %q: %w", id, ErrDuplicate)
|
||||
}
|
||||
return nil, fmt.Errorf("insert user: %w", err)
|
||||
}
|
||||
return &model.User{
|
||||
ID: id,
|
||||
CreatedAt: now,
|
||||
Quota: quota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUser retrieves a user by ID.
|
||||
func (s *Store) GetUser(ctx context.Context, id string) (*model.User, error) {
|
||||
row := s.db.QueryRow(ctx,
|
||||
`SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
|
||||
monthly_pod_hours, monthly_ai_requests
|
||||
FROM users WHERE id = $1`, id)
|
||||
|
||||
u, err := scanUser(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
|
||||
}
|
||||
return nil, fmt.Errorf("query user: %w", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// ListUsers returns all users ordered by creation time.
|
||||
func (s *Store) ListUsers(ctx context.Context) ([]model.User, error) {
|
||||
rows, err := s.db.Query(ctx,
|
||||
`SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
|
||||
monthly_pod_hours, monthly_ai_requests
|
||||
FROM users ORDER BY created_at`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query users: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []model.User
|
||||
for rows.Next() {
|
||||
var u model.User
|
||||
if err := rows.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
|
||||
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests); err != nil {
|
||||
return nil, fmt.Errorf("scan user: %w", err)
|
||||
}
|
||||
users = append(users, u)
|
||||
}
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
// DeleteUser removes a user by ID. Cascades to api_keys and usage_records.
|
||||
func (s *Store) DeleteUser(ctx context.Context, id string) error {
|
||||
tag, err := s.db.Exec(ctx, `DELETE FROM users WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete user: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("user %q: %w", id, ErrNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateQuotas updates specific quota fields for a user.
|
||||
func (s *Store) UpdateQuotas(ctx context.Context, id string, req model.UpdateQuotasRequest) (*model.User, error) {
|
||||
var sets []string
|
||||
var args []any
|
||||
argIdx := 1
|
||||
|
||||
if req.MaxConcurrentPods != nil {
|
||||
sets = append(sets, fmt.Sprintf("max_concurrent_pods = $%d", argIdx))
|
||||
args = append(args, *req.MaxConcurrentPods)
|
||||
argIdx++
|
||||
}
|
||||
if req.MaxCPUPerPod != nil {
|
||||
sets = append(sets, fmt.Sprintf("max_cpu_per_pod = $%d", argIdx))
|
||||
args = append(args, *req.MaxCPUPerPod)
|
||||
argIdx++
|
||||
}
|
||||
if req.MaxRAMGBPerPod != nil {
|
||||
sets = append(sets, fmt.Sprintf("max_ram_gb_per_pod = $%d", argIdx))
|
||||
args = append(args, *req.MaxRAMGBPerPod)
|
||||
argIdx++
|
||||
}
|
||||
if req.MonthlyPodHours != nil {
|
||||
sets = append(sets, fmt.Sprintf("monthly_pod_hours = $%d", argIdx))
|
||||
args = append(args, *req.MonthlyPodHours)
|
||||
argIdx++
|
||||
}
|
||||
if req.MonthlyAIRequests != nil {
|
||||
sets = append(sets, fmt.Sprintf("monthly_ai_requests = $%d", argIdx))
|
||||
args = append(args, *req.MonthlyAIRequests)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
return s.GetUser(ctx, id)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
`UPDATE users SET %s WHERE id = $%d
|
||||
RETURNING id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
|
||||
monthly_pod_hours, monthly_ai_requests`,
|
||||
strings.Join(sets, ", "), argIdx)
|
||||
args = append(args, id)
|
||||
|
||||
row := s.db.QueryRow(ctx, query, args...)
|
||||
u, err := scanUser(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
|
||||
}
|
||||
return nil, fmt.Errorf("update quotas: %w", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// GenerateAPIKey creates a new random API key.
|
||||
// Returns the plaintext key (dpk_ + 24 hex chars) and its SHA-256 hash.
|
||||
func GenerateAPIKey() (plaintext string, hash string, err error) {
|
||||
b := make([]byte, 12)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
plain := model.APIKeyPrefix + hex.EncodeToString(b)
|
||||
return plain, HashKey(plain), nil
|
||||
}
|
||||
|
||||
// HashKey returns the hex-encoded SHA-256 hash of a key.
|
||||
func HashKey(key string) string {
|
||||
h := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// CreateAPIKey stores a hashed API key for a user.
|
||||
func (s *Store) CreateAPIKey(ctx context.Context, userID, role, keyHash string) error {
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO api_keys (key_hash, user_id, role, created_at) VALUES ($1, $2, $3, $4)`,
|
||||
keyHash, userID, role, time.Now().UTC(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert api key: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateKey checks an API key and returns the associated user and role.
|
||||
func (s *Store) ValidateKey(ctx context.Context, key string) (*model.User, string, error) {
|
||||
hash := HashKey(key)
|
||||
row := s.db.QueryRow(ctx,
|
||||
`SELECT u.id, u.created_at, u.max_concurrent_pods, u.max_cpu_per_pod, u.max_ram_gb_per_pod,
|
||||
u.monthly_pod_hours, u.monthly_ai_requests, k.role
|
||||
FROM api_keys k JOIN users u ON k.user_id = u.id
|
||||
WHERE k.key_hash = $1`, hash)
|
||||
|
||||
var u model.User
|
||||
var role string
|
||||
err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
|
||||
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests, &role)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, "", fmt.Errorf("invalid api key: %w", ErrNotFound)
|
||||
}
|
||||
return nil, "", fmt.Errorf("validate key: %w", err)
|
||||
}
|
||||
|
||||
// Update last_used_at (best-effort, don't fail auth on update error)
|
||||
if _, err := s.db.Exec(ctx,
|
||||
`UPDATE api_keys SET last_used_at = $1 WHERE key_hash = $2`,
|
||||
time.Now().UTC(), hash); err != nil {
|
||||
slog.Warn("failed to update api key last_used_at", "error", err)
|
||||
}
|
||||
|
||||
return &u, role, nil
|
||||
}
|
||||
|
||||
// scanUser scans a user row into a model.User.
|
||||
func scanUser(row pgx.Row) (*model.User, error) {
|
||||
var u model.User
|
||||
err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
|
||||
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// SaveForgejoToken stores a Forgejo API token for a user.
|
||||
func (s *Store) SaveForgejoToken(ctx context.Context, userID, token string) error {
|
||||
tag, err := s.db.Exec(ctx, `UPDATE users SET forgejo_token = $1 WHERE id = $2`, token, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update forgejo token: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetForgejoToken retrieves the Forgejo API token for a user.
|
||||
func (s *Store) GetForgejoToken(ctx context.Context, userID string) (string, error) {
|
||||
var token string
|
||||
err := s.db.QueryRow(ctx, `SELECT forgejo_token FROM users WHERE id = $1`, userID).Scan(&token)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
||||
}
|
||||
return "", fmt.Errorf("query forgejo token: %w", err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SaveTailscaleKey stores a Tailscale auth key for a user.
|
||||
func (s *Store) SaveTailscaleKey(ctx context.Context, userID, key string) error {
|
||||
tag, err := s.db.Exec(ctx, `UPDATE users SET tailscale_key = $1 WHERE id = $2`, key, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update tailscale key: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTailscaleKey retrieves the Tailscale auth key for a user.
|
||||
func (s *Store) GetTailscaleKey(ctx context.Context, userID string) (string, error) {
|
||||
var key string
|
||||
err := s.db.QueryRow(ctx, `SELECT tailscale_key FROM users WHERE id = $1`, userID).Scan(&key)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
||||
}
|
||||
return "", fmt.Errorf("query tailscale key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func isDuplicateError(err error) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) {
|
||||
return pgErr.Code == "23505" // unique_violation
|
||||
}
|
||||
return false
|
||||
}
|
||||
442
internal/store/users_test.go
Normal file
442
internal/store/users_test.go
Normal file
|
|
@ -0,0 +1,442 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
||||
)
|
||||
|
||||
// --- User CRUD tests ---
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
quota := model.DefaultQuota()
|
||||
|
||||
u, err := s.CreateUser(ctx, "alice", quota)
|
||||
if err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
if u.ID != "alice" {
|
||||
t.Errorf("got id %q, want %q", u.ID, "alice")
|
||||
}
|
||||
if u.CreatedAt.IsZero() {
|
||||
t.Error("created_at should not be zero")
|
||||
}
|
||||
if u.Quota != quota {
|
||||
t.Errorf("got quota %+v, want %+v", u.Quota, quota)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUserDuplicate(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
quota := model.DefaultQuota()
|
||||
|
||||
if _, err := s.CreateUser(ctx, "alice", quota); err != nil {
|
||||
t.Fatalf("first create: %v", err)
|
||||
}
|
||||
_, err := s.CreateUser(ctx, "alice", quota)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate user")
|
||||
}
|
||||
if !errors.Is(err, ErrDuplicate) {
|
||||
t.Errorf("got %v, want ErrDuplicate", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUser(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
quota := model.DefaultQuota()
|
||||
|
||||
created, err := s.CreateUser(ctx, "bob", quota)
|
||||
if err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.GetUser(ctx, "bob")
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
if got.ID != created.ID {
|
||||
t.Errorf("got id %q, want %q", got.ID, created.ID)
|
||||
}
|
||||
if got.Quota != quota {
|
||||
t.Errorf("got quota %+v, want %+v", got.Quota, quota)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserNotFound(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.GetUser(ctx, "nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("got %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUsers(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
quota := model.DefaultQuota()
|
||||
|
||||
// Empty list
|
||||
users, err := s.ListUsers(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list empty: %v", err)
|
||||
}
|
||||
if len(users) != 0 {
|
||||
t.Errorf("got %d users, want 0", len(users))
|
||||
}
|
||||
|
||||
// Create users
|
||||
if _, err := s.CreateUser(ctx, "alice", quota); err != nil {
|
||||
t.Fatalf("create alice: %v", err)
|
||||
}
|
||||
if _, err := s.CreateUser(ctx, "bob", quota); err != nil {
|
||||
t.Fatalf("create bob: %v", err)
|
||||
}
|
||||
|
||||
users, err = s.ListUsers(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list: %v", err)
|
||||
}
|
||||
if len(users) != 2 {
|
||||
t.Fatalf("got %d users, want 2", len(users))
|
||||
}
|
||||
// Ordered by created_at
|
||||
if users[0].ID != "alice" {
|
||||
t.Errorf("first user: got %q, want %q", users[0].ID, "alice")
|
||||
}
|
||||
if users[1].ID != "bob" {
|
||||
t.Errorf("second user: got %q, want %q", users[1].ID, "bob")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
if err := s.DeleteUser(ctx, "alice"); err != nil {
|
||||
t.Fatalf("delete: %v", err)
|
||||
}
|
||||
|
||||
// Should be gone
|
||||
_, err := s.GetUser(ctx, "alice")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("got %v after delete, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUserNotFound(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := s.DeleteUser(ctx, "ghost")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("got %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUserCascadesAPIKeys(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
plain, hash, err := GenerateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
if err := s.CreateAPIKey(ctx, "alice", model.RoleUser, hash); err != nil {
|
||||
t.Fatalf("create key: %v", err)
|
||||
}
|
||||
|
||||
// Delete user should cascade to api_keys
|
||||
if err := s.DeleteUser(ctx, "alice"); err != nil {
|
||||
t.Fatalf("delete: %v", err)
|
||||
}
|
||||
|
||||
// Key should no longer validate
|
||||
_, _, err = s.ValidateKey(ctx, plain)
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("got %v after cascade delete, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Quota tests ---
|
||||
|
||||
func TestUpdateQuotas(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
newPods := 5
|
||||
newCPU := 16
|
||||
updated, err := s.UpdateQuotas(ctx, "alice", model.UpdateQuotasRequest{
|
||||
MaxConcurrentPods: &newPods,
|
||||
MaxCPUPerPod: &newCPU,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("update quotas: %v", err)
|
||||
}
|
||||
if updated.Quota.MaxConcurrentPods != 5 {
|
||||
t.Errorf("max_concurrent_pods: got %d, want 5", updated.Quota.MaxConcurrentPods)
|
||||
}
|
||||
if updated.Quota.MaxCPUPerPod != 16 {
|
||||
t.Errorf("max_cpu_per_pod: got %d, want 16", updated.Quota.MaxCPUPerPod)
|
||||
}
|
||||
// Unchanged fields should keep defaults
|
||||
if updated.Quota.MaxRAMGBPerPod != 16 {
|
||||
t.Errorf("max_ram_gb_per_pod: got %d, want 16", updated.Quota.MaxRAMGBPerPod)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateQuotasNoChanges(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
// Empty update should return current user
|
||||
u, err := s.UpdateQuotas(ctx, "alice", model.UpdateQuotasRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("update no changes: %v", err)
|
||||
}
|
||||
if u.Quota != model.DefaultQuota() {
|
||||
t.Errorf("quota changed unexpectedly: %+v", u.Quota)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateQuotasNotFound(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
newPods := 5
|
||||
_, err := s.UpdateQuotas(ctx, "ghost", model.UpdateQuotasRequest{
|
||||
MaxConcurrentPods: &newPods,
|
||||
})
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("got %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuotaStorageAndRetrieval(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
customQuota := model.Quota{
|
||||
MaxConcurrentPods: 10,
|
||||
MaxCPUPerPod: 32,
|
||||
MaxRAMGBPerPod: 64,
|
||||
MonthlyPodHours: 1000,
|
||||
MonthlyAIRequests: 50000,
|
||||
}
|
||||
if _, err := s.CreateUser(ctx, "power-user", customQuota); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.GetUser(ctx, "power-user")
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
if got.Quota != customQuota {
|
||||
t.Errorf("quota mismatch:\ngot: %+v\nwant: %+v", got.Quota, customQuota)
|
||||
}
|
||||
}
|
||||
|
||||
// --- API key tests ---
|
||||
|
||||
func TestGenerateAPIKey(t *testing.T) {
|
||||
plain, hash, err := GenerateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generate: %v", err)
|
||||
}
|
||||
|
||||
// Check prefix
|
||||
if !strings.HasPrefix(plain, model.APIKeyPrefix) {
|
||||
t.Errorf("key %q missing prefix %q", plain, model.APIKeyPrefix)
|
||||
}
|
||||
|
||||
// Check length: dpk_ (4) + 24 hex chars = 28
|
||||
if len(plain) != 28 {
|
||||
t.Errorf("key length: got %d, want 28", len(plain))
|
||||
}
|
||||
|
||||
// Hash should be consistent
|
||||
if HashKey(plain) != hash {
|
||||
t.Error("hash mismatch")
|
||||
}
|
||||
|
||||
// Two keys should be different
|
||||
plain2, _, err := GenerateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generate second: %v", err)
|
||||
}
|
||||
if plain == plain2 {
|
||||
t.Error("two generated keys should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashKeyConsistency(t *testing.T) {
|
||||
key := "dpk_abcdef1234567890abcdef12"
|
||||
h1 := HashKey(key)
|
||||
h2 := HashKey(key)
|
||||
if h1 != h2 {
|
||||
t.Error("hash should be deterministic")
|
||||
}
|
||||
// SHA-256 produces 64 hex chars
|
||||
if len(h1) != 64 {
|
||||
t.Errorf("hash length: got %d, want 64", len(h1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndValidateAPIKey(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
|
||||
plain, hash, err := GenerateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
if err := s.CreateAPIKey(ctx, "alice", model.RoleUser, hash); err != nil {
|
||||
t.Fatalf("create key: %v", err)
|
||||
}
|
||||
|
||||
// Validate returns correct user and role
|
||||
u, role, err := s.ValidateKey(ctx, plain)
|
||||
if err != nil {
|
||||
t.Fatalf("validate: %v", err)
|
||||
}
|
||||
if u.ID != "alice" {
|
||||
t.Errorf("user: got %q, want %q", u.ID, "alice")
|
||||
}
|
||||
if role != model.RoleUser {
|
||||
t.Errorf("role: got %q, want %q", role, model.RoleUser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAdminKey(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := s.CreateUser(ctx, "admin1", model.DefaultQuota()); err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
|
||||
plain, hash, err := GenerateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
if err := s.CreateAPIKey(ctx, "admin1", model.RoleAdmin, hash); err != nil {
|
||||
t.Fatalf("create key: %v", err)
|
||||
}
|
||||
|
||||
_, role, err := s.ValidateKey(ctx, plain)
|
||||
if err != nil {
|
||||
t.Fatalf("validate: %v", err)
|
||||
}
|
||||
if role != model.RoleAdmin {
|
||||
t.Errorf("role: got %q, want %q", role, model.RoleAdmin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateInvalidKey(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.ValidateKey(ctx, "dpk_invalid_key_here_nope")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid key")
|
||||
}
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("got %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleKeysPerUser(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
|
||||
// Create two keys
|
||||
plain1, hash1, _ := GenerateAPIKey()
|
||||
plain2, hash2, _ := GenerateAPIKey()
|
||||
if err := s.CreateAPIKey(ctx, "alice", model.RoleUser, hash1); err != nil {
|
||||
t.Fatalf("create key 1: %v", err)
|
||||
}
|
||||
if err := s.CreateAPIKey(ctx, "alice", model.RoleAdmin, hash2); err != nil {
|
||||
t.Fatalf("create key 2: %v", err)
|
||||
}
|
||||
|
||||
// Both should validate
|
||||
u1, role1, err := s.ValidateKey(ctx, plain1)
|
||||
if err != nil {
|
||||
t.Fatalf("validate key 1: %v", err)
|
||||
}
|
||||
if u1.ID != "alice" || role1 != model.RoleUser {
|
||||
t.Errorf("key1: got user=%q role=%q, want alice/user", u1.ID, role1)
|
||||
}
|
||||
|
||||
u2, role2, err := s.ValidateKey(ctx, plain2)
|
||||
if err != nil {
|
||||
t.Fatalf("validate key 2: %v", err)
|
||||
}
|
||||
if u2.ID != "alice" || role2 != model.RoleAdmin {
|
||||
t.Errorf("key2: got user=%q role=%q, want alice/admin", u2.ID, role2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateKeyUpdatesLastUsed(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := s.CreateUser(ctx, "alice", model.DefaultQuota()); err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
plain, hash, _ := GenerateAPIKey()
|
||||
if err := s.CreateAPIKey(ctx, "alice", model.RoleUser, hash); err != nil {
|
||||
t.Fatalf("create key: %v", err)
|
||||
}
|
||||
|
||||
// Validate should update last_used_at
|
||||
if _, _, err := s.ValidateKey(ctx, plain); err != nil {
|
||||
t.Fatalf("validate: %v", err)
|
||||
}
|
||||
|
||||
// Check last_used_at is set
|
||||
var lastUsed *string
|
||||
err := testPool.QueryRow(ctx,
|
||||
`SELECT last_used_at::text FROM api_keys WHERE key_hash = $1`, hash).Scan(&lastUsed)
|
||||
if err != nil {
|
||||
t.Fatalf("query last_used_at: %v", err)
|
||||
}
|
||||
if lastUsed == nil {
|
||||
t.Error("last_used_at should be set after validation")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue