build source
This commit is contained in:
commit
ee1fec43ed
4171 changed files with 1351288 additions and 0 deletions
280
internal/store/users.go
Normal file
280
internal/store/users.go
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned when a requested resource does not exist.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// ErrDuplicate is returned when a resource already exists.
|
||||
var ErrDuplicate = errors.New("already exists")
|
||||
|
||||
// CreateUser inserts a new user with the given quota.
|
||||
func (s *Store) CreateUser(ctx context.Context, id string, quota model.Quota) (*model.User, error) {
|
||||
now := time.Now().UTC()
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO users (id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod, monthly_pod_hours, monthly_ai_requests)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
||||
id, now, quota.MaxConcurrentPods, quota.MaxCPUPerPod, quota.MaxRAMGBPerPod,
|
||||
quota.MonthlyPodHours, quota.MonthlyAIRequests,
|
||||
)
|
||||
if err != nil {
|
||||
if isDuplicateError(err) {
|
||||
return nil, fmt.Errorf("user %q: %w", id, ErrDuplicate)
|
||||
}
|
||||
return nil, fmt.Errorf("insert user: %w", err)
|
||||
}
|
||||
return &model.User{
|
||||
ID: id,
|
||||
CreatedAt: now,
|
||||
Quota: quota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUser retrieves a user by ID.
|
||||
func (s *Store) GetUser(ctx context.Context, id string) (*model.User, error) {
|
||||
row := s.db.QueryRow(ctx,
|
||||
`SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
|
||||
monthly_pod_hours, monthly_ai_requests
|
||||
FROM users WHERE id = $1`, id)
|
||||
|
||||
u, err := scanUser(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
|
||||
}
|
||||
return nil, fmt.Errorf("query user: %w", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// ListUsers returns all users ordered by creation time.
|
||||
func (s *Store) ListUsers(ctx context.Context) ([]model.User, error) {
|
||||
rows, err := s.db.Query(ctx,
|
||||
`SELECT id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
|
||||
monthly_pod_hours, monthly_ai_requests
|
||||
FROM users ORDER BY created_at`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query users: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []model.User
|
||||
for rows.Next() {
|
||||
var u model.User
|
||||
if err := rows.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
|
||||
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests); err != nil {
|
||||
return nil, fmt.Errorf("scan user: %w", err)
|
||||
}
|
||||
users = append(users, u)
|
||||
}
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
// DeleteUser removes a user by ID. Cascades to api_keys and usage_records.
|
||||
func (s *Store) DeleteUser(ctx context.Context, id string) error {
|
||||
tag, err := s.db.Exec(ctx, `DELETE FROM users WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete user: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("user %q: %w", id, ErrNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateQuotas updates specific quota fields for a user.
|
||||
func (s *Store) UpdateQuotas(ctx context.Context, id string, req model.UpdateQuotasRequest) (*model.User, error) {
|
||||
var sets []string
|
||||
var args []any
|
||||
argIdx := 1
|
||||
|
||||
if req.MaxConcurrentPods != nil {
|
||||
sets = append(sets, fmt.Sprintf("max_concurrent_pods = $%d", argIdx))
|
||||
args = append(args, *req.MaxConcurrentPods)
|
||||
argIdx++
|
||||
}
|
||||
if req.MaxCPUPerPod != nil {
|
||||
sets = append(sets, fmt.Sprintf("max_cpu_per_pod = $%d", argIdx))
|
||||
args = append(args, *req.MaxCPUPerPod)
|
||||
argIdx++
|
||||
}
|
||||
if req.MaxRAMGBPerPod != nil {
|
||||
sets = append(sets, fmt.Sprintf("max_ram_gb_per_pod = $%d", argIdx))
|
||||
args = append(args, *req.MaxRAMGBPerPod)
|
||||
argIdx++
|
||||
}
|
||||
if req.MonthlyPodHours != nil {
|
||||
sets = append(sets, fmt.Sprintf("monthly_pod_hours = $%d", argIdx))
|
||||
args = append(args, *req.MonthlyPodHours)
|
||||
argIdx++
|
||||
}
|
||||
if req.MonthlyAIRequests != nil {
|
||||
sets = append(sets, fmt.Sprintf("monthly_ai_requests = $%d", argIdx))
|
||||
args = append(args, *req.MonthlyAIRequests)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
return s.GetUser(ctx, id)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
`UPDATE users SET %s WHERE id = $%d
|
||||
RETURNING id, created_at, max_concurrent_pods, max_cpu_per_pod, max_ram_gb_per_pod,
|
||||
monthly_pod_hours, monthly_ai_requests`,
|
||||
strings.Join(sets, ", "), argIdx)
|
||||
args = append(args, id)
|
||||
|
||||
row := s.db.QueryRow(ctx, query, args...)
|
||||
u, err := scanUser(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, fmt.Errorf("user %q: %w", id, ErrNotFound)
|
||||
}
|
||||
return nil, fmt.Errorf("update quotas: %w", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// GenerateAPIKey creates a new random API key.
|
||||
// Returns the plaintext key (dpk_ + 24 hex chars) and its SHA-256 hash.
|
||||
func GenerateAPIKey() (plaintext string, hash string, err error) {
|
||||
b := make([]byte, 12)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
plain := model.APIKeyPrefix + hex.EncodeToString(b)
|
||||
return plain, HashKey(plain), nil
|
||||
}
|
||||
|
||||
// HashKey returns the hex-encoded SHA-256 hash of a key.
|
||||
func HashKey(key string) string {
|
||||
h := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// CreateAPIKey stores a hashed API key for a user.
|
||||
func (s *Store) CreateAPIKey(ctx context.Context, userID, role, keyHash string) error {
|
||||
_, err := s.db.Exec(ctx,
|
||||
`INSERT INTO api_keys (key_hash, user_id, role, created_at) VALUES ($1, $2, $3, $4)`,
|
||||
keyHash, userID, role, time.Now().UTC(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert api key: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateKey checks an API key and returns the associated user and role.
|
||||
func (s *Store) ValidateKey(ctx context.Context, key string) (*model.User, string, error) {
|
||||
hash := HashKey(key)
|
||||
row := s.db.QueryRow(ctx,
|
||||
`SELECT u.id, u.created_at, u.max_concurrent_pods, u.max_cpu_per_pod, u.max_ram_gb_per_pod,
|
||||
u.monthly_pod_hours, u.monthly_ai_requests, k.role
|
||||
FROM api_keys k JOIN users u ON k.user_id = u.id
|
||||
WHERE k.key_hash = $1`, hash)
|
||||
|
||||
var u model.User
|
||||
var role string
|
||||
err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
|
||||
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests, &role)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, "", fmt.Errorf("invalid api key: %w", ErrNotFound)
|
||||
}
|
||||
return nil, "", fmt.Errorf("validate key: %w", err)
|
||||
}
|
||||
|
||||
// Update last_used_at (best-effort, don't fail auth on update error)
|
||||
if _, err := s.db.Exec(ctx,
|
||||
`UPDATE api_keys SET last_used_at = $1 WHERE key_hash = $2`,
|
||||
time.Now().UTC(), hash); err != nil {
|
||||
slog.Warn("failed to update api key last_used_at", "error", err)
|
||||
}
|
||||
|
||||
return &u, role, nil
|
||||
}
|
||||
|
||||
// scanUser scans a user row into a model.User.
|
||||
func scanUser(row pgx.Row) (*model.User, error) {
|
||||
var u model.User
|
||||
err := row.Scan(&u.ID, &u.CreatedAt, &u.Quota.MaxConcurrentPods, &u.Quota.MaxCPUPerPod,
|
||||
&u.Quota.MaxRAMGBPerPod, &u.Quota.MonthlyPodHours, &u.Quota.MonthlyAIRequests)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// SaveForgejoToken stores a Forgejo API token for a user.
|
||||
func (s *Store) SaveForgejoToken(ctx context.Context, userID, token string) error {
|
||||
tag, err := s.db.Exec(ctx, `UPDATE users SET forgejo_token = $1 WHERE id = $2`, token, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update forgejo token: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetForgejoToken retrieves the Forgejo API token for a user.
|
||||
func (s *Store) GetForgejoToken(ctx context.Context, userID string) (string, error) {
|
||||
var token string
|
||||
err := s.db.QueryRow(ctx, `SELECT forgejo_token FROM users WHERE id = $1`, userID).Scan(&token)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
||||
}
|
||||
return "", fmt.Errorf("query forgejo token: %w", err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SaveTailscaleKey stores a Tailscale auth key for a user.
|
||||
func (s *Store) SaveTailscaleKey(ctx context.Context, userID, key string) error {
|
||||
tag, err := s.db.Exec(ctx, `UPDATE users SET tailscale_key = $1 WHERE id = $2`, key, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update tailscale key: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTailscaleKey retrieves the Tailscale auth key for a user.
|
||||
func (s *Store) GetTailscaleKey(ctx context.Context, userID string) (string, error) {
|
||||
var key string
|
||||
err := s.db.QueryRow(ctx, `SELECT tailscale_key FROM users WHERE id = $1`, userID).Scan(&key)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", fmt.Errorf("user %q: %w", userID, ErrNotFound)
|
||||
}
|
||||
return "", fmt.Errorf("query tailscale key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func isDuplicateError(err error) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) {
|
||||
return pgErr.Code == "23505" // unique_violation
|
||||
}
|
||||
return false
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue