package api import ( "context" "log/slog" "net/http" "strings" "sync" "time" "github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model" ) type contextKey string const ( ctxKeyUser contextKey = "user" ctxKeyRole contextKey = "role" ) // UserFromContext returns the authenticated user from the request context. func UserFromContext(ctx context.Context) *model.User { u, _ := ctx.Value(ctxKeyUser).(*model.User) return u } // RoleFromContext returns the authenticated user's role from the request context. func RoleFromContext(ctx context.Context) string { r, _ := ctx.Value(ctxKeyRole).(string) return r } // KeyValidator validates an API key and returns the associated user and role. type KeyValidator interface { ValidateKey(ctx context.Context, key string) (*model.User, string, error) } // AuthMiddleware returns middleware that validates Bearer tokens via the KeyValidator. func AuthMiddleware(kv KeyValidator) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") if auth == "" { writeError(w, http.StatusUnauthorized, "missing authorization header") return } token, found := strings.CutPrefix(auth, "Bearer ") if !found || token == "" { writeError(w, http.StatusUnauthorized, "invalid authorization format, expected: Bearer ") return } user, role, err := kv.ValidateKey(r.Context(), token) if err != nil { writeError(w, http.StatusUnauthorized, "invalid api key") return } ctx := context.WithValue(r.Context(), ctxKeyUser, user) ctx = context.WithValue(ctx, ctxKeyRole, role) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // bucket tracks rate limiting state for a single key. type bucket struct { tokens float64 lastSeen time.Time } // RateLimiter implements a per-key token bucket rate limiter. type RateLimiter struct { mu sync.Mutex buckets map[string]*bucket rate float64 // tokens per second burst float64 // max tokens nowFunc func() time.Time callCount int // tracks calls for periodic cleanup } // NewRateLimiter creates a rate limiter with the given rate (req/s) and burst size. func NewRateLimiter(ratePerSecond, burst float64) *RateLimiter { return &RateLimiter{ buckets: make(map[string]*bucket), rate: ratePerSecond, burst: burst, nowFunc: time.Now, } } // Allow checks if a request for the given key is allowed. func (rl *RateLimiter) Allow(key string) bool { rl.mu.Lock() defer rl.mu.Unlock() now := rl.nowFunc() // Evict stale buckets every 100 calls to prevent unbounded growth. rl.callCount++ if rl.callCount%100 == 0 { const staleThreshold = 10 * time.Minute for k, b := range rl.buckets { if now.Sub(b.lastSeen) > staleThreshold { delete(rl.buckets, k) } } } b, exists := rl.buckets[key] if !exists { rl.buckets[key] = &bucket{tokens: rl.burst - 1, lastSeen: now} return true } elapsed := now.Sub(b.lastSeen).Seconds() b.tokens += elapsed * rl.rate if b.tokens > rl.burst { b.tokens = rl.burst } b.lastSeen = now if b.tokens < 1 { return false } b.tokens-- return true } // MaxBodySize returns middleware that limits request body size to the given number of bytes. func MaxBodySize(maxBytes int64) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Body != nil { r.Body = http.MaxBytesReader(w, r.Body, maxBytes) } next.ServeHTTP(w, r) }) } } // Middleware returns HTTP middleware that rate-limits by authenticated user ID. func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := UserFromContext(r.Context()) key := "anonymous" if user != nil { key = user.ID } if !rl.Allow(key) { writeError(w, http.StatusTooManyRequests, "rate limit exceeded") return } next.ServeHTTP(w, r) }) } // statusWriter wraps http.ResponseWriter to capture the status code. type statusWriter struct { http.ResponseWriter status int wrote bool } func (sw *statusWriter) WriteHeader(code int) { if !sw.wrote { sw.status = code sw.wrote = true } sw.ResponseWriter.WriteHeader(code) } func (sw *statusWriter) Write(b []byte) (int, error) { if !sw.wrote { sw.status = http.StatusOK sw.wrote = true } return sw.ResponseWriter.Write(b) } // RequestLogger returns middleware that logs each request using slog. func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(sw, r) logger.Info("request", "method", r.Method, "path", r.URL.Path, "status", sw.status, "duration_ms", time.Since(start).Milliseconds(), ) }) } }