package store import ( "context" "crypto/rand" "crypto/sha256" "encoding/hex" "errors" "fmt" "log/slog" "strings" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model" ) // ErrNotFound is returned when a requested resource does not exist. var ErrNotFound = errors.New("not found") // ErrDuplicate is returned when a resource already exists. var ErrDuplicate = errors.New("already exists") // CreateUser inserts a new user with the given quota. func (s *Store) CreateUser(ctx context.Context, id string, quota model.Quota) (*model.User, error) { now := time.Now().UTC() _, err := s.db.Exec(ctx, `INSERT INTO users (id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests) VALUES ($1, $2, $3, $4, $5, $6, $7)`, id, now, quota.MaxConcurrentPods, quota.MaxCPUPerPod, quota.MaxRAMGBPerPod, quota.MonthlyPodHours, quota.MonthlyAIRequests, ) if err != nil { if isDuplicateError(err) { return nil, fmt.Errorf("user %q: %w", id, ErrDuplicate) } return nil, fmt.Errorf("insert user: %w", err) } return &model.User{ ID: id, CreatedAt: now, Quota: quota, }, nil } // GetUser retrieves a user by ID. func (s *Store) GetUser(ctx context.Context, id string) (*model.User, error) { row := s.db.QueryRow(ctx, `SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests FROM users WHERE id = $1`, id) u, err := scanUser(row) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, fmt.Errorf("user %q: %w", id, ErrNotFound) } return nil, fmt.Errorf("query user: %w", err) } return u, nil } // ListUsers returns all users ordered by creation time. func (s *Store) ListUsers(ctx context.Context) ([]model.User, error) { rows, err := s.db.Query(ctx, `SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests FROM users ORDER BY created_at`) if err != nil { return nil, fmt.Errorf("query users: %w", err) } defer rows.Close() var users []model.User for rows.Next() { var u model.User if err := rows.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod, &u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests); err != nil { return nil, fmt.Errorf("scan user: %w", err) } users = append(users, u) } return users, rows.Err() } // DeleteUser removes a user by ID. Cascades to api_keys and usage_records. func (s *Store) DeleteUser(ctx context.Context, id string) error { tag, err := s.db.Exec(ctx, `DELETE FROM users WHERE id = $1`, id) if err != nil { return fmt.Errorf("delete user: %w", err) } if tag.RowsAffected() == 0 { return fmt.Errorf("user %q: %w", id, ErrNotFound) } return nil } // UpdateQuotas updates specific quota fields for a user. func (s *Store) UpdateQuotas(ctx context.Context, id string, req model.UpdateQuotasRequest) (*model.User, error) { var sets []string var args []any argIdx := 1 if req.MaxConcurrentPods != nil { sets = append(sets, fmt.Sprintf("max_concurrent_pods = $%d", argIdx)) args = append(args, *req.MaxConcurrentPods) argIdx++ } if req.MaxCPUPerPod != nil { sets = append(sets, fmt.Sprintf("max_cpu_per_pod = $%d", argIdx)) args = append(args, *req.MaxCPUPerPod) argIdx++ } if req.MaxRAMGBPerPod != nil { sets = append(sets, fmt.Sprintf("max_ram_gb_per_pod = $%d", argIdx)) args = append(args, *req.MaxRAMGBPerPod) argIdx++ } if req.MonthlyPodHours != nil { sets = append(sets, fmt.Sprintf("monthly_pod_hours = $%d", argIdx)) args = append(args, *req.MonthlyPodHours) argIdx++ } if req.MonthlyAIRequests != nil { sets = append(sets, fmt.Sprintf("monthly_ai_requests = $%d", argIdx)) args = append(args, *req.MonthlyAIRequests) argIdx++ } if len(sets) == 0 { return s.GetUser(ctx, id) } query := fmt.Sprintf( `UPDATE users SET %s WHERE id = $%d RETURNING id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests`, strings.Join(sets, ", "), argIdx) args = append(args, id) row := s.db.QueryRow(ctx, query, args...) u, err := scanUser(row) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, fmt.Errorf("user %q: %w", id, ErrNotFound) } return nil, fmt.Errorf("update quotas: %w", err) } return u, nil } // GenerateAPIKey creates a new random API key. // Returns the plaintext key (dpk_ + 24 hex chars) and its SHA-256 hash. func GenerateAPIKey() (plaintext string, hash string, err error) { b := make([]byte, 12) if _, err := rand.Read(b); err != nil { return "", "", fmt.Errorf("generate random bytes: %w", err) } plain := model.APIKeyPrefix + hex.EncodeToString(b) return plain, HashKey(plain), nil } // HashKey returns the hex-encoded SHA-256 hash of a key. func HashKey(key string) string { h := sha256.Sum256([]byte(key)) return hex.EncodeToString(h[:]) } // CreateAPIKey stores a hashed API key for a user. func (s *Store) CreateAPIKey(ctx context.Context, userID, role, keyHash string) error { _, err := s.db.Exec(ctx, `INSERT INTO api_keys (key_hash, user_id, role, created_at) VALUES ($1, $2, $3, $4)`, keyHash, userID, role, time.Now().UTC(), ) if err != nil { return fmt.Errorf("insert api key: %w", err) } return nil } // ValidateKey checks an API key and returns the associated user and role. func (s *Store) ValidateKey(ctx context.Context, key string) (*model.User, string, error) { hash := HashKey(key) row := s.db.QueryRow(ctx, `SELECT u.id, u.created_at, u.max_concurrent_pods, u.max_cpu_per_pod, u.max_ram_gb_per_pod, u.monthly_pod_hours, u.monthly_ai_requests, k.role FROM api_keys k JOIN users u ON k.user_id = u.id WHERE k.key_hash = $1`, hash) var u model.User var role string err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod, &u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests, &role) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, "", fmt.Errorf("invalid api key: %w", ErrNotFound) } return nil, "", fmt.Errorf("validate key: %w", err) } // Update last_used_at (best-effort, don't fail auth on update error) if _, err := s.db.Exec(ctx, `UPDATE api_keys SET last_used_at = $1 WHERE key_hash = $2`, time.Now().UTC(), hash); err != nil { slog.Warn("failed to update api key last_used_at", "error", err) } return &u, role, nil } // scanUser scans a user row into a model.User. func scanUser(row pgx.Row) (*model.User, error) { var u model.User err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod, &u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests) if err != nil { return nil, err } return &u, nil } // SaveForgejoToken stores a Forgejo API token for a user. func (s *Store) SaveForgejoToken(ctx context.Context, userID, token string) error { tag, err := s.db.Exec(ctx, `UPDATE users SET forgejo_token = $1 WHERE id = $2`, token, userID) if err != nil { return fmt.Errorf("update forgejo token: %w", err) } if tag.RowsAffected() == 0 { return fmt.Errorf("user %q: %w", userID, ErrNotFound) } return nil } // GetForgejoToken retrieves the Forgejo API token for a user. func (s *Store) GetForgejoToken(ctx context.Context, userID string) (string, error) { var token string err := s.db.QueryRow(ctx, `SELECT forgejo_token FROM users WHERE id = $1`, userID).Scan(&token) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return "", fmt.Errorf("user %q: %w", userID, ErrNotFound) } return "", fmt.Errorf("query forgejo token: %w", err) } return token, nil } // SaveTailscaleKey stores a Tailscale auth key for a user. func (s *Store) SaveTailscaleKey(ctx context.Context, userID, key string) error { tag, err := s.db.Exec(ctx, `UPDATE users SET tailscale_key = $1 WHERE id = $2`, key, userID) if err != nil { return fmt.Errorf("update tailscale key: %w", err) } if tag.RowsAffected() == 0 { return fmt.Errorf("user %q: %w", userID, ErrNotFound) } return nil } // GetTailscaleKey retrieves the Tailscale auth key for a user. func (s *Store) GetTailscaleKey(ctx context.Context, userID string) (string, error) { var key string err := s.db.QueryRow(ctx, `SELECT tailscale_key FROM users WHERE id = $1`, userID).Scan(&key) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return "", fmt.Errorf("user %q: %w", userID, ErrNotFound) } return "", fmt.Errorf("query tailscale key: %w", err) } return key, nil } func isDuplicateError(err error) bool { var pgErr *pgconn.PgError if errors.As(err, &pgErr) { return pgErr.Code == "23505" // unique_violation } return false }