280 lines
8.7 KiB
Go
280 lines
8.7 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
|
|
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
|
)
|
|
|
|
// Sentinel errors returned (usually wrapped with context via %w) by Store
// methods. Callers should test for them with errors.Is.
var (
	// ErrNotFound is returned when a requested resource does not exist.
	ErrNotFound = errors.New("not found")

	// ErrDuplicate is returned when a resource already exists.
	ErrDuplicate = errors.New("already exists")
)
|
|
|
|
// CreateUser inserts a new user with the given quota.
|
|
func (s *Store) CreateUser(ctx context.Context, id string, quota model.Quota) (*model.User, error) {
|
|
now := time.Now().UTC()
|
|
_, err := s.db.Exec(ctx,
|
|
`INSERT INTO users (id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
|
id, now, quota.MaxConcurrentPods, quota.MaxCPUPerPod, quota.MaxRAMGBPerPod,
|
|
quota.MonthlyPodHours, quota.MonthlyAIRequests,
|
|
)
|
|
if err != nil {
|
|
if isDuplicateError(err) {
|
|
return nil, fmt.Errorf("user %q: %w", id, ErrDuplicate)
|
|
}
|
|
return nil, fmt.Errorf("insert user: %w", err)
|
|
}
|
|
return &model.User{
|
|
ID: id,
|
|
CreatedAt: now,
|
|
Quota: quota,
|
|
}, nil
|
|
}
|
|
|
|
// GetUser retrieves a user by ID.
|
|
func (s *Store) GetUser(ctx context.Context, id string) (*model.User, error) {
|
|
row := s.db.QueryRow(ctx,
|
|
`SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
|
|
monthly_pod_hours, monthly_ai_requests
|
|
FROM users WHERE id = $1`, id)
|
|
|
|
u, err := scanUser(row)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
|
|
}
|
|
return nil, fmt.Errorf("query user: %w", err)
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
// ListUsers returns all users ordered by creation time.
|
|
func (s *Store) ListUsers(ctx context.Context) ([]model.User, error) {
|
|
rows, err := s.db.Query(ctx,
|
|
`SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
|
|
monthly_pod_hours, monthly_ai_requests
|
|
FROM users ORDER BY created_at`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query users: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var users []model.User
|
|
for rows.Next() {
|
|
var u model.User
|
|
if err := rows.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
|
|
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests); err != nil {
|
|
return nil, fmt.Errorf("scan user: %w", err)
|
|
}
|
|
users = append(users, u)
|
|
}
|
|
return users, rows.Err()
|
|
}
|
|
|
|
// DeleteUser removes a user by ID. Cascades to api_keys and usage_records.
|
|
func (s *Store) DeleteUser(ctx context.Context, id string) error {
|
|
tag, err := s.db.Exec(ctx, `DELETE FROM users WHERE id = $1`, id)
|
|
if err != nil {
|
|
return fmt.Errorf("delete user: %w", err)
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
return fmt.Errorf("user %q: %w", id, ErrNotFound)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateQuotas updates specific quota fields for a user.
|
|
func (s *Store) UpdateQuotas(ctx context.Context, id string, req model.UpdateQuotasRequest) (*model.User, error) {
|
|
var sets []string
|
|
var args []any
|
|
argIdx := 1
|
|
|
|
if req.MaxConcurrentPods != nil {
|
|
sets = append(sets, fmt.Sprintf("max_concurrent_pods = $%d", argIdx))
|
|
args = append(args, *req.MaxConcurrentPods)
|
|
argIdx++
|
|
}
|
|
if req.MaxCPUPerPod != nil {
|
|
sets = append(sets, fmt.Sprintf("max_cpu_per_pod = $%d", argIdx))
|
|
args = append(args, *req.MaxCPUPerPod)
|
|
argIdx++
|
|
}
|
|
if req.MaxRAMGBPerPod != nil {
|
|
sets = append(sets, fmt.Sprintf("max_ram_gb_per_pod = $%d", argIdx))
|
|
args = append(args, *req.MaxRAMGBPerPod)
|
|
argIdx++
|
|
}
|
|
if req.MonthlyPodHours != nil {
|
|
sets = append(sets, fmt.Sprintf("monthly_pod_hours = $%d", argIdx))
|
|
args = append(args, *req.MonthlyPodHours)
|
|
argIdx++
|
|
}
|
|
if req.MonthlyAIRequests != nil {
|
|
sets = append(sets, fmt.Sprintf("monthly_ai_requests = $%d", argIdx))
|
|
args = append(args, *req.MonthlyAIRequests)
|
|
argIdx++
|
|
}
|
|
|
|
if len(sets) == 0 {
|
|
return s.GetUser(ctx, id)
|
|
}
|
|
|
|
query := fmt.Sprintf(
|
|
`UPDATE users SET %s WHERE id = $%d
|
|
RETURNING id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
|
|
monthly_pod_hours, monthly_ai_requests`,
|
|
strings.Join(sets, ", "), argIdx)
|
|
args = append(args, id)
|
|
|
|
row := s.db.QueryRow(ctx, query, args...)
|
|
u, err := scanUser(row)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
|
|
}
|
|
return nil, fmt.Errorf("update quotas: %w", err)
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
// GenerateAPIKey creates a new random API key.
|
|
// Returns the plaintext key (dpk_ + 24 hex chars) and its SHA-256 hash.
|
|
func GenerateAPIKey() (plaintext string, hash string, err error) {
|
|
b := make([]byte, 12)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return "", "", fmt.Errorf("generate random bytes: %w", err)
|
|
}
|
|
plain := model.APIKeyPrefix + hex.EncodeToString(b)
|
|
return plain, HashKey(plain), nil
|
|
}
|
|
|
|
// HashKey returns the SHA-256 digest of key as 64 lowercase hex characters.
// This deterministic value is what api_keys.key_hash is matched against.
func HashKey(key string) string {
	hasher := sha256.New()
	// hash.Hash.Write is documented never to return an error.
	hasher.Write([]byte(key))
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
// CreateAPIKey stores a hashed API key for a user.
|
|
func (s *Store) CreateAPIKey(ctx context.Context, userID, role, keyHash string) error {
|
|
_, err := s.db.Exec(ctx,
|
|
`INSERT INTO api_keys (key_hash, user_id, role, created_at) VALUES ($1, $2, $3, $4)`,
|
|
keyHash, userID, role, time.Now().UTC(),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("insert api key: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ValidateKey checks an API key and returns the associated user and role.
|
|
func (s *Store) ValidateKey(ctx context.Context, key string) (*model.User, string, error) {
|
|
hash := HashKey(key)
|
|
row := s.db.QueryRow(ctx,
|
|
`SELECT u.id, u.created_at, u.max_concurrent_pods, u.max_cpu_per_pod, u.max_ram_gb_per_pod,
|
|
u.monthly_pod_hours, u.monthly_ai_requests, k.role
|
|
FROM api_keys k JOIN users u ON k.user_id = u.id
|
|
WHERE k.key_hash = $1`, hash)
|
|
|
|
var u model.User
|
|
var role string
|
|
err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
|
|
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests, &role)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, "", fmt.Errorf("invalid api key: %w", ErrNotFound)
|
|
}
|
|
return nil, "", fmt.Errorf("validate key: %w", err)
|
|
}
|
|
|
|
// Update last_used_at (best-effort, don't fail auth on update error)
|
|
if _, err := s.db.Exec(ctx,
|
|
`UPDATE api_keys SET last_used_at = $1 WHERE key_hash = $2`,
|
|
time.Now().UTC(), hash); err != nil {
|
|
slog.Warn("failed to update api key last_used_at", "error", err)
|
|
}
|
|
|
|
return &u, role, nil
|
|
}
|
|
|
|
// scanUser scans a user row into a model.User.
|
|
func scanUser(row pgx.Row) (*model.User, error) {
|
|
var u model.User
|
|
err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
|
|
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &u, nil
|
|
}
|
|
|
|
// SaveForgejoToken stores a Forgejo API token for a user.
|
|
func (s *Store) SaveForgejoToken(ctx context.Context, userID, token string) error {
|
|
tag, err := s.db.Exec(ctx, `UPDATE users SET forgejo_token = $1 WHERE id = $2`, token, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("update forgejo token: %w", err)
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
return fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetForgejoToken retrieves the Forgejo API token for a user.
|
|
func (s *Store) GetForgejoToken(ctx context.Context, userID string) (string, error) {
|
|
var token string
|
|
err := s.db.QueryRow(ctx, `SELECT forgejo_token FROM users WHERE id = $1`, userID).Scan(&token)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
|
}
|
|
return "", fmt.Errorf("query forgejo token: %w", err)
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
// SaveTailscaleKey stores a Tailscale auth key for a user.
|
|
func (s *Store) SaveTailscaleKey(ctx context.Context, userID, key string) error {
|
|
tag, err := s.db.Exec(ctx, `UPDATE users SET tailscale_key = $1 WHERE id = $2`, key, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("update tailscale key: %w", err)
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
return fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetTailscaleKey retrieves the Tailscale auth key for a user.
|
|
func (s *Store) GetTailscaleKey(ctx context.Context, userID string) (string, error) {
|
|
var key string
|
|
err := s.db.QueryRow(ctx, `SELECT tailscale_key FROM users WHERE id = $1`, userID).Scan(&key)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
|
}
|
|
return "", fmt.Errorf("query tailscale key: %w", err)
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func isDuplicateError(err error) bool {
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
return pgErr.Code == "23505" // unique_violation
|
|
}
|
|
return false
|
|
}
|