dev-pod-api-build/internal/api/middleware.go
2026-04-16 04:16:36 +00:00

198 lines
5 KiB
Go

package api
import (
"context"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
)
// contextKey is the unexported type used for this package's request-context
// keys, so they cannot collide with keys set by other packages.
type contextKey string

// Context keys under which the authenticated user and that user's role are
// stored on the request context.
const (
	ctxKeyUser = contextKey("user")
	ctxKeyRole = contextKey("role")
)
// UserFromContext returns the authenticated user stored on ctx, or nil when
// ctx carries no *model.User under the user key.
func UserFromContext(ctx context.Context) *model.User {
	if user, ok := ctx.Value(ctxKeyUser).(*model.User); ok {
		return user
	}
	return nil
}
// RoleFromContext returns the authenticated user's role stored on ctx, or the
// empty string when ctx carries no string under the role key.
func RoleFromContext(ctx context.Context) string {
	role, ok := ctx.Value(ctxKeyRole).(string)
	if !ok {
		return ""
	}
	return role
}
// KeyValidator validates an API key and returns the associated user and role.
//
// AuthMiddleware calls ValidateKey with the request context and the token
// taken from the Authorization header. Any non-nil error is treated as an
// invalid key and answered with 401; on a nil error the returned user and
// role string are stored on the request context.
type KeyValidator interface {
	ValidateKey(ctx context.Context, key string) (*model.User, string, error)
}
// AuthMiddleware returns middleware that authenticates requests carrying an
// "Authorization: Bearer <token>" header.
//
// The token is passed to kv.ValidateKey. On success the returned user and role
// are stored on the request context (readable through UserFromContext and
// RoleFromContext) and the request is forwarded to next. A missing header, a
// malformed header, or a validation error ends the request with a 401 response
// that carries a "WWW-Authenticate: Bearer" challenge; next is not called.
func AuthMiddleware(kv KeyValidator) func(http.Handler) http.Handler {
	// reject writes the 401 response. RFC 9110 requires every 401 to include
	// a WWW-Authenticate challenge, so it is set before the error body.
	reject := func(w http.ResponseWriter, msg string) {
		w.Header().Set("WWW-Authenticate", "Bearer")
		writeError(w, http.StatusUnauthorized, msg)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			auth := r.Header.Get("Authorization")
			if auth == "" {
				reject(w, "missing authorization header")
				return
			}

			// Authentication scheme names are case-insensitive, so "bearer"
			// and "BEARER" are accepted as well as "Bearer". Everything after
			// the first space is the token, exactly as before.
			scheme, token, found := strings.Cut(auth, " ")
			if !found || !strings.EqualFold(scheme, "Bearer") || token == "" {
				reject(w, "invalid authorization format, expected: Bearer <token>")
				return
			}

			user, role, err := kv.ValidateKey(r.Context(), token)
			if err != nil {
				// The validator's error is deliberately not echoed to the client.
				reject(w, "invalid api key")
				return
			}

			ctx := context.WithValue(r.Context(), ctxKeyUser, user)
			ctx = context.WithValue(ctx, ctxKeyRole, role)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// bucket holds the token-bucket state for a single rate-limit key.
type bucket struct {
	// tokens is the number of tokens currently available; each allowed
	// request spends one.
	tokens float64
	// lastSeen is the time the bucket was last refilled. It drives both the
	// refill calculation and the staleness check.
	lastSeen time.Time
}

// RateLimiter implements a per-key token bucket rate limiter.
//
// Every key owns a bucket that starts with burst tokens, gains rate tokens per
// second up to burst, and pays one token per allowed request. All state is
// guarded by mu. Because it embeds a mutex, a RateLimiter must be used through
// a pointer and not copied.
type RateLimiter struct {
	mu      sync.Mutex
	buckets map[string]*bucket
	// rate is the refill speed in tokens per second.
	rate float64
	// burst is the bucket capacity, i.e. the maximum number of stored tokens.
	burst float64
	// nowFunc is the clock; NewRateLimiter sets it to time.Now.
	nowFunc func() time.Time
	// callCount counts Allow calls since the last eviction sweep.
	callCount int
}

// NewRateLimiter creates a rate limiter with the given rate (req/s) and burst size.
//
// A burst below 1 can never hold a whole token, so such a limiter rejects
// every request.
func NewRateLimiter(ratePerSecond, burst float64) *RateLimiter {
	return &RateLimiter{
		buckets: make(map[string]*bucket),
		rate:    ratePerSecond,
		burst:   burst,
		nowFunc: time.Now,
	}
}

// Allow reports whether a request for key may proceed, spending one token from
// the key's bucket when it may.
//
// An unknown key starts with a full bucket. A known key is first refilled for
// the time elapsed since it was last seen (capped at burst). The request is
// allowed only if at least one whole token is available afterwards.
//
// Once every 100 calls Allow sweeps the map and drops buckets that have been
// idle for more than ten minutes AND would be full again after refilling.
// Dropping such a bucket is indistinguishable from keeping it, whereas dropping
// a still-depleted bucket would hand the key a fresh burst and let it bypass a
// slow rate limit.
func (rl *RateLimiter) Allow(key string) bool {
	const (
		sweepEvery = 100
		staleAfter = 10 * time.Minute
	)

	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := rl.nowFunc()

	rl.callCount++
	if rl.callCount >= sweepEvery {
		// Resetting (rather than using a modulo) keeps the counter bounded.
		rl.callCount = 0
		for k, b := range rl.buckets {
			idle := now.Sub(b.lastSeen)
			if idle > staleAfter && b.tokens+idle.Seconds()*rl.rate >= rl.burst {
				delete(rl.buckets, k)
			}
		}
	}

	b, exists := rl.buckets[key]
	if !exists {
		b = &bucket{tokens: rl.burst, lastSeen: now}
		rl.buckets[key] = b
	} else if elapsed := now.Sub(b.lastSeen); elapsed > 0 {
		// Only refill when the clock moved forward; a clock that jumps
		// backwards must neither drain tokens nor rewind lastSeen.
		b.tokens += elapsed.Seconds() * rl.rate
		if b.tokens > rl.burst {
			b.tokens = rl.burst
		}
		b.lastSeen = now
	}

	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}
// MaxBodySize returns middleware that caps the request body at maxBytes.
//
// A non-nil body is wrapped in http.MaxBytesReader before the request reaches
// next, so reads past the cap fail inside the downstream handler. Requests
// with a nil body are forwarded unchanged.
func MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		capped := func(w http.ResponseWriter, r *http.Request) {
			if body := r.Body; body != nil {
				r.Body = http.MaxBytesReader(w, body, maxBytes)
			}
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(capped)
	}
}
// Middleware returns HTTP middleware that rate-limits by authenticated user ID.
//
// The bucket key is the ID of the user found on the request context; requests
// without a user all share the "anonymous" bucket. A request that Allow
// rejects is answered with 429 and never reaches next.
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		bucketKey := "anonymous"
		if u := UserFromContext(r.Context()); u != nil {
			bucketKey = u.ID
		}

		if rl.Allow(bucketKey) {
			next.ServeHTTP(w, r)
			return
		}
		writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
	}
	return http.HandlerFunc(limited)
}
// statusWriter wraps http.ResponseWriter to capture the status code.
//
// It also forwards Flush and exposes Unwrap, so wrapping a writer does not hide
// streaming support from handlers or from http.ResponseController.
type statusWriter struct {
	http.ResponseWriter
	// status is the final status code recorded for the response.
	status int
	// wrote is true once a final status has been recorded.
	wrote bool
}

// WriteHeader records the first final status code and forwards the call to the
// wrapped writer. Informational 1xx codes (other than 101 Switching Protocols)
// are forwarded but not recorded, because a final status still follows them.
func (sw *statusWriter) WriteHeader(code int) {
	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
	if !sw.wrote && !informational {
		sw.status = code
		sw.wrote = true
	}
	sw.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. A Write without a preceding
// WriteHeader implies 200 OK, which is recorded here.
func (sw *statusWriter) Write(b []byte) (int, error) {
	if !sw.wrote {
		sw.status = http.StatusOK
		sw.wrote = true
	}
	return sw.ResponseWriter.Write(b)
}

// Flush implements http.Flusher by delegating to the wrapped writer when that
// writer supports flushing; otherwise it does nothing. Flushing before any
// header was written sends an implicit 200 OK, which is recorded here.
func (sw *statusWriter) Flush() {
	flusher, ok := sw.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	if !sw.wrote {
		sw.status = http.StatusOK
		sw.wrote = true
	}
	flusher.Flush()
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// the optional interfaces it implements.
func (sw *statusWriter) Unwrap() http.ResponseWriter {
	return sw.ResponseWriter
}
// RequestLogger returns middleware that logs each request using slog.
//
// After next returns, one Info-level "request" record is written with the
// method, URL path, captured status code (200 when the handler never set one)
// and the handling time in whole milliseconds.
func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		logged := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			rec := &statusWriter{ResponseWriter: w, status: http.StatusOK}

			next.ServeHTTP(rec, r)

			logger.Info("request",
				slog.String("method", r.Method),
				slog.String("path", r.URL.Path),
				slog.Int("status", rec.status),
				slog.Int64("duration_ms", time.Since(began).Milliseconds()),
			)
		}
		return http.HandlerFunc(logged)
	}
}