dev-pod-api-build/internal/store/users.go
2026-04-16 04:16:36 +00:00

280 lines
8.7 KiB
Go

package store
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
)
// Sentinel errors returned by Store methods. Store methods wrap them with
// extra context (e.g. the offending id), so callers should match them with
// errors.Is rather than ==.
var (
	// ErrNotFound is returned when a requested resource does not exist.
	ErrNotFound = errors.New("not found")
	// ErrDuplicate is returned when a resource already exists.
	ErrDuplicate = errors.New("already exists")
)
// CreateUser inserts a new user with the given quota and returns the record
// as it was stored.
//
// It returns an error wrapping ErrDuplicate when the insert hits a
// unique_violation (a user with this id already exists), and a wrapped
// "insert user" error for any other database failure.
func (s *Store) CreateUser(ctx context.Context, id string, quota model.Quota) (*model.User, error) {
	// PostgreSQL timestamps only keep microsecond precision. Truncate up front
	// so the CreatedAt returned here is exactly what GetUser/ListUsers will
	// read back later, rather than differing in the sub-microsecond digits.
	// NOTE(review): assumes created_at is a default-precision timestamp(tz).
	now := time.Now().UTC().Truncate(time.Microsecond)
	_, err := s.db.Exec(ctx,
		`INSERT INTO users (id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests)
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
		id, now, quota.MaxConcurrentPods, quota.MaxCPUPerPod, quota.MaxRAMGBPerPod,
		quota.MonthlyPodHours, quota.MonthlyAIRequests,
	)
	if err != nil {
		if isDuplicateError(err) {
			return nil, fmt.Errorf("user %q: %w", id, ErrDuplicate)
		}
		return nil, fmt.Errorf("insert user: %w", err)
	}
	return &model.User{
		ID:        id,
		CreatedAt: now,
		Quota:     quota,
	}, nil
}
// GetUser looks up a single user by id.
//
// When no row matches, the error wraps ErrNotFound. Any other query or scan
// failure is wrapped as "query user".
func (s *Store) GetUser(ctx context.Context, id string) (*model.User, error) {
	const query = `SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
monthly_pod_hours, monthly_ai_requests
FROM users WHERE id = $1`

	user, err := scanUser(s.db.QueryRow(ctx, query, id))
	switch {
	case err == nil:
		return user, nil
	case errors.Is(err, pgx.ErrNoRows):
		return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
	default:
		return nil, fmt.Errorf("query user: %w", err)
	}
}
// ListUsers returns all users ordered by created_at, oldest first.
//
// The slice is nil when there are no users. If the query, any row scan, or
// row iteration fails, it returns a nil slice plus the wrapped error, never a
// partially filled list.
func (s *Store) ListUsers(ctx context.Context) ([]model.User, error) {
	rows, err := s.db.Query(ctx,
		`SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
monthly_pod_hours, monthly_ai_requests
FROM users ORDER BY created_at`)
	if err != nil {
		return nil, fmt.Errorf("query users: %w", err)
	}
	defer rows.Close()

	var users []model.User
	for rows.Next() {
		// pgx.Rows satisfies pgx.Row, so reuse the shared column mapping
		// instead of repeating the Scan destination list here.
		u, err := scanUser(rows)
		if err != nil {
			return nil, fmt.Errorf("scan user: %w", err)
		}
		users = append(users, *u)
	}
	// Errors that stop iteration early are only reported via rows.Err. Do not
	// hand back the rows read so far as though they were the full list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate users: %w", err)
	}
	return users, nil
}
// DeleteUser removes the user with the given id.
//
// Related api_keys and usage_records rows are expected to go with it via a
// database-level cascade. NOTE(review): that cascade is enforced by the
// schema, not by this code. When no user matched, the error wraps ErrNotFound.
func (s *Store) DeleteUser(ctx context.Context, id string) error {
	result, execErr := s.db.Exec(ctx, `DELETE FROM users WHERE id = $1`, id)
	switch {
	case execErr != nil:
		return fmt.Errorf("delete user: %w", execErr)
	case result.RowsAffected() == 0:
		return fmt.Errorf("user %q: %w", id, ErrNotFound)
	default:
		return nil
	}
}
// UpdateQuotas applies a partial quota update. Only the fields of req that are
// non-nil are written, and the user is returned as stored after the update.
//
// If req sets no fields, no UPDATE is issued and the current user is returned
// via GetUser. A missing user yields an error wrapping ErrNotFound; other
// failures are wrapped as "update quotas".
func (s *Store) UpdateQuotas(ctx context.Context, id string, req model.UpdateQuotasRequest) (*model.User, error) {
	var (
		assignments []string
		params      []any
	)
	// set records "column = $N", where N is the 1-based position of value in
	// params, so placeholders always line up with the argument list. Column
	// names are fixed literals below, never caller input.
	set := func(column string, value any) {
		params = append(params, value)
		assignments = append(assignments, fmt.Sprintf("%s = $%d", column, len(params)))
	}

	if req.MaxConcurrentPods != nil {
		set("max_concurrent_pods", *req.MaxConcurrentPods)
	}
	if req.MaxCPUPerPod != nil {
		set("max_cpu_per_pod", *req.MaxCPUPerPod)
	}
	if req.MaxRAMGBPerPod != nil {
		set("max_ram_gb_per_pod", *req.MaxRAMGBPerPod)
	}
	if req.MonthlyPodHours != nil {
		set("monthly_pod_hours", *req.MonthlyPodHours)
	}
	if req.MonthlyAIRequests != nil {
		set("monthly_ai_requests", *req.MonthlyAIRequests)
	}

	if len(assignments) == 0 {
		return s.GetUser(ctx, id)
	}

	// The user id is bound as the final positional parameter.
	params = append(params, id)
	query := fmt.Sprintf(
		`UPDATE users SET %s WHERE id = $%d
RETURNING id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
monthly_pod_hours, monthly_ai_requests`,
		strings.Join(assignments, ", "), len(params))

	user, err := scanUser(s.db.QueryRow(ctx, query, params...))
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
	}
	if err != nil {
		return nil, fmt.Errorf("update quotas: %w", err)
	}
	return user, nil
}
// GenerateAPIKey creates a new random API key.
//
// The plaintext is model.APIKeyPrefix (documented as "dpk_") followed by 24
// lowercase hex characters, which encode 12 bytes read from crypto/rand. The
// second result is HashKey(plaintext). CreateAPIKey stores only that hash, and
// ValidateKey re-hashes a presented key to look it up. An error is returned
// only if the system random source fails.
func GenerateAPIKey() (plaintext string, hash string, err error) {
	var secret [12]byte
	if _, readErr := rand.Read(secret[:]); readErr != nil {
		return "", "", fmt.Errorf("generate random bytes: %w", readErr)
	}
	plaintext = model.APIKeyPrefix + hex.EncodeToString(secret[:])
	hash = HashKey(plaintext)
	return plaintext, hash, nil
}
// HashKey returns the SHA-256 digest of key as a 64-character lowercase hex
// string. It is deterministic, so the same plaintext always maps to the same
// stored value.
func HashKey(key string) string {
	hasher := sha256.New()
	hasher.Write([]byte(key)) // hash.Hash.Write is documented to never return an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// CreateAPIKey stores a hashed API key for a user with the given role.
//
// keyHash is expected to be the output of HashKey. This function only ever
// receives and stores the hash, never the plaintext.
//
// Errors:
//   - wraps ErrDuplicate on a unique_violation (presumably the key_hash is
//     already present), matching CreateUser's handling of duplicate inserts;
//   - wraps ErrNotFound on a foreign_key_violation, i.e. userID does not refer
//     to an existing user. NOTE(review): assumes api_keys.user_id references
//     users(id), as the cascade described on DeleteUser implies;
//   - otherwise wraps the database error as "insert api key".
func (s *Store) CreateAPIKey(ctx context.Context, userID, role, keyHash string) error {
	_, err := s.db.Exec(ctx,
		`INSERT INTO api_keys (key_hash, user_id, role, created_at) VALUES ($1, $2, $3, $4)`,
		keyHash, userID, role, time.Now().UTC(),
	)
	if err == nil {
		return nil
	}
	if isDuplicateError(err) {
		return fmt.Errorf("api key for user %q: %w", userID, ErrDuplicate)
	}
	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) && pgErr.Code == "23503" { // foreign_key_violation
		return fmt.Errorf("user %q: %w", userID, ErrNotFound)
	}
	return fmt.Errorf("insert api key: %w", err)
}
// ValidateKey authenticates a plaintext API key.
//
// The key is hashed with HashKey and looked up in api_keys, joined to its
// owning user. On success it returns that user and the key's role. It also
// makes a best-effort attempt to set the key's last_used_at to the current UTC
// time; if that fails, a warning is logged via slog and authentication still
// succeeds. An unknown key yields an error wrapping ErrNotFound.
func (s *Store) ValidateKey(ctx context.Context, key string) (*model.User, string, error) {
	keyHash := HashKey(key)

	var (
		user model.User
		role string
	)
	scanErr := s.db.QueryRow(ctx,
		`SELECT u.id, u.created_at, u.max_concurrent_pods, u.max_cpu_per_pod, u.max_ram_gb_per_pod,
u.monthly_pod_hours, u.monthly_ai_requests, k.role
FROM api_keys k JOIN users u ON k.user_id = u.id
WHERE k.key_hash = $1`, keyHash).Scan(
		&user.ID,
		&user.CreatedAt,
		&user.Quota.MaxConcurrentPods,
		&user.Quota.MaxCPUPerPod,
		&user.Quota.MaxRAMGBPerPod,
		&user.Quota.MonthlyPodHours,
		&user.Quota.MonthlyAIRequests,
		&role,
	)
	switch {
	case errors.Is(scanErr, pgx.ErrNoRows):
		return nil, "", fmt.Errorf("invalid api key: %w", ErrNotFound)
	case scanErr != nil:
		return nil, "", fmt.Errorf("validate key: %w", scanErr)
	}

	// The key is already authenticated; recording usage must not block access.
	_, touchErr := s.db.Exec(ctx,
		`UPDATE api_keys SET last_used_at = $1 WHERE key_hash = $2`,
		time.Now().UTC(), keyHash)
	if touchErr != nil {
		slog.Warn("failed to update api key last_used_at", "error", touchErr)
	}
	return &user, role, nil
}
// scanUser reads one users row into a freshly allocated model.User.
//
// The row must yield, in this order: id, created_at, max_concurrent_pods,
// max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests.
// Scan errors, including pgx.ErrNoRows, are returned as-is so callers can
// classify them with errors.Is.
func scanUser(row pgx.Row) (*model.User, error) {
	u := new(model.User)
	q := &u.Quota
	if err := row.Scan(
		&u.ID,
		&u.CreatedAt,
		&q.MaxConcurrentPods,
		&q.MaxCPUPerPod,
		&q.MaxRAMGBPerPod,
		&q.MonthlyPodHours,
		&q.MonthlyAIRequests,
	); err != nil {
		return nil, err
	}
	return u, nil
}
// SaveForgejoToken sets the Forgejo API token on the user's row, overwriting
// any previous value. NOTE(review): the token is written exactly as given;
// this code does not encrypt it.
//
// When no user has the given userID, the error wraps ErrNotFound.
func (s *Store) SaveForgejoToken(ctx context.Context, userID, token string) error {
	const stmt = `UPDATE users SET forgejo_token = $1 WHERE id = $2`
	result, err := s.db.Exec(ctx, stmt, token, userID)
	if err != nil {
		return fmt.Errorf("update forgejo token: %w", err)
	}
	if result.RowsAffected() != 0 {
		return nil
	}
	return fmt.Errorf("user %q: %w", userID, ErrNotFound)
}
// GetForgejoToken returns the Forgejo API token stored for a user.
//
// A user who exists but has no token gets "" and a nil error. A SQL NULL in
// forgejo_token is treated the same as an empty string. When no user has the
// given userID, the error wraps ErrNotFound; other failures are wrapped as
// "query forgejo token".
func (s *Store) GetForgejoToken(ctx context.Context, userID string) (string, error) {
	// CreateUser never sets forgejo_token, so it may be NULL. pgx cannot scan
	// NULL into a plain string (it returns a scan error). Scanning through a
	// pointer decodes NULL as nil instead.
	var token *string
	err := s.db.QueryRow(ctx, `SELECT forgejo_token FROM users WHERE id = $1`, userID).Scan(&token)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
		}
		return "", fmt.Errorf("query forgejo token: %w", err)
	}
	if token == nil {
		return "", nil
	}
	return *token, nil
}
// SaveTailscaleKey sets the Tailscale auth key on the user's row, overwriting
// any previous value. NOTE(review): the key is written exactly as given; this
// code does not encrypt it.
//
// When no user has the given userID, the error wraps ErrNotFound.
func (s *Store) SaveTailscaleKey(ctx context.Context, userID, key string) error {
	result, execErr := s.db.Exec(ctx, `UPDATE users SET tailscale_key = $1 WHERE id = $2`, key, userID)
	switch {
	case execErr != nil:
		return fmt.Errorf("update tailscale key: %w", execErr)
	case result.RowsAffected() == 0:
		return fmt.Errorf("user %q: %w", userID, ErrNotFound)
	default:
		return nil
	}
}
// GetTailscaleKey returns the Tailscale auth key stored for a user.
//
// A user who exists but has no key gets "" and a nil error. A SQL NULL in
// tailscale_key is treated the same as an empty string. When no user has the
// given userID, the error wraps ErrNotFound; other failures are wrapped as
// "query tailscale key".
func (s *Store) GetTailscaleKey(ctx context.Context, userID string) (string, error) {
	// CreateUser never sets tailscale_key, so it may be NULL. pgx cannot scan
	// NULL into a plain string (it returns a scan error). Scanning through a
	// pointer decodes NULL as nil instead.
	var key *string
	err := s.db.QueryRow(ctx, `SELECT tailscale_key FROM users WHERE id = $1`, userID).Scan(&key)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
		}
		return "", fmt.Errorf("query tailscale key: %w", err)
	}
	if key == nil {
		return "", nil
	}
	return *key, nil
}
// isDuplicateError reports whether err, or any error it wraps, is a
// PostgreSQL unique_violation (SQLSTATE 23505).
func isDuplicateError(err error) bool {
	var pgErr *pgconn.PgError
	return errors.As(err, &pgErr) && pgErr.Code == "23505"
}