198 lines
5 KiB
Go
198 lines
5 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
|
)
|
|
|
|
// contextKey is the unexported type of the keys this package stores in a
// request context. Using a dedicated type keeps these keys from colliding with
// plain string keys set by other packages.
type contextKey string

// Context keys under which AuthMiddleware stores the authenticated user and
// that user's role.
const (
	ctxKeyUser = contextKey("user")
	ctxKeyRole = contextKey("role")
)
|
|
|
|
// UserFromContext returns the authenticated user from the request context.
|
|
func UserFromContext(ctx context.Context) *model.User {
|
|
u, _ := ctx.Value(ctxKeyUser).(*model.User)
|
|
return u
|
|
}
|
|
|
|
// RoleFromContext returns the authenticated user's role from the request context.
|
|
func RoleFromContext(ctx context.Context) string {
|
|
r, _ := ctx.Value(ctxKeyRole).(string)
|
|
return r
|
|
}
|
|
|
|
// KeyValidator validates an API key and returns the associated user and role.
|
|
type KeyValidator interface {
|
|
ValidateKey(ctx context.Context, key string) (*model.User, string, error)
|
|
}
|
|
|
|
// AuthMiddleware returns middleware that validates Bearer tokens via the KeyValidator.
|
|
func AuthMiddleware(kv KeyValidator) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
auth := r.Header.Get("Authorization")
|
|
if auth == "" {
|
|
writeError(w, http.StatusUnauthorized, "missing authorization header")
|
|
return
|
|
}
|
|
|
|
token, found := strings.CutPrefix(auth, "Bearer ")
|
|
if !found || token == "" {
|
|
writeError(w, http.StatusUnauthorized, "invalid authorization format, expected: Bearer <token>")
|
|
return
|
|
}
|
|
|
|
user, role, err := kv.ValidateKey(r.Context(), token)
|
|
if err != nil {
|
|
writeError(w, http.StatusUnauthorized, "invalid api key")
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), ctxKeyUser, user)
|
|
ctx = context.WithValue(ctx, ctxKeyRole, role)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// bucket tracks rate limiting state for a single key.
type bucket struct {
	tokens   float64   // tokens currently available; one is spent per allowed request
	lastSeen time.Time // time of the last refill of tokens
}

// RateLimiter implements a per-key token bucket rate limiter.
//
// Every key owns a bucket holding at most burst tokens that refills at rate
// tokens per second. Allow serialises access to the limiter's state with mu.
type RateLimiter struct {
	mu        sync.Mutex
	buckets   map[string]*bucket
	rate      float64          // tokens per second
	burst     float64          // max tokens
	nowFunc   func() time.Time // clock used by Allow; NewRateLimiter sets it to time.Now
	callCount int              // number of Allow calls, used to schedule periodic cleanup
}

// NewRateLimiter creates a rate limiter with the given rate (req/s) and burst size.
func NewRateLimiter(ratePerSecond, burst float64) *RateLimiter {
	return &RateLimiter{
		buckets: make(map[string]*bucket),
		rate:    ratePerSecond,
		burst:   burst,
		nowFunc: time.Now,
	}
}

// Allow reports whether a request for the given key may proceed, spending one
// token from the key's bucket when it may.
//
// A key seen for the first time starts with a full bucket of burst tokens.
// When the bucket holds less than one token the request is rejected and no
// token is spent.
func (rl *RateLimiter) Allow(key string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := rl.nowFunc()

	// Sweep idle buckets every 100 calls to prevent unbounded growth.
	rl.callCount++
	if rl.callCount%100 == 0 {
		rl.evictIdleLocked(now)
	}

	b, exists := rl.buckets[key]
	if !exists {
		// A new key starts full and then pays for this request through the
		// same check as every other request, so a burst below one token
		// never lets a request through.
		b = &bucket{tokens: rl.burst, lastSeen: now}
		rl.buckets[key] = b
	}

	// Refill for the time that has passed. If the clock did not advance
	// (or stepped backwards) nothing is added and lastSeen is kept, so a
	// backwards step can neither drain tokens nor be credited twice later.
	if elapsed := now.Sub(b.lastSeen).Seconds(); elapsed > 0 {
		b.tokens += elapsed * rl.rate
		if b.tokens > rl.burst {
			b.tokens = rl.burst
		}
		b.lastSeen = now
	}

	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}

// evictIdleLocked removes buckets that have been idle for more than ten
// minutes and would be full again by now. The caller must hold rl.mu.
//
// Only fully refilled buckets are dropped: recreating such a bucket later
// yields the same full state, whereas dropping a partially refilled one would
// hand its key a fresh burst and let it exceed the configured rate.
func (rl *RateLimiter) evictIdleLocked(now time.Time) {
	const staleThreshold = 10 * time.Minute
	for k, b := range rl.buckets {
		idle := now.Sub(b.lastSeen)
		if idle > staleThreshold && b.tokens+idle.Seconds()*rl.rate >= rl.burst {
			delete(rl.buckets, k)
		}
	}
}
|
|
|
|
// MaxBodySize returns middleware that caps the request body at maxBytes by
// wrapping it in http.MaxBytesReader before the next handler runs. Requests
// with a nil body are passed through unchanged.
func MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		capped := func(w http.ResponseWriter, r *http.Request) {
			if body := r.Body; body != nil {
				r.Body = http.MaxBytesReader(w, body, maxBytes)
			}
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(capped)
	}
}
|
|
|
|
// Middleware returns HTTP middleware that rate-limits by authenticated user ID.
|
|
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
user := UserFromContext(r.Context())
|
|
key := "anonymous"
|
|
if user != nil {
|
|
key = user.ID
|
|
}
|
|
|
|
if !rl.Allow(key) {
|
|
writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// statusWriter wraps http.ResponseWriter to capture the status code.
type statusWriter struct {
	http.ResponseWriter
	status int  // final status code of the response
	wrote  bool // whether the final status has been recorded
}

// WriteHeader records the first final status code and forwards the call to the
// wrapped writer.
//
// 1xx informational codes other than 101 Switching Protocols are forwarded
// but not recorded: they may be sent before the final header, so recording one
// would make the captured status something like 103 instead of the response's
// real status.
func (sw *statusWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !sw.wrote && !informational {
		sw.status = code
		sw.wrote = true
	}
	sw.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. If no final status has been recorded
// yet it records 200 OK, matching the implicit WriteHeader(http.StatusOK) that
// net/http performs on the first Write.
func (sw *statusWriter) Write(b []byte) (int, error) {
	if !sw.wrote {
		sw.status = http.StatusOK
		sw.wrote = true
	}
	return sw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped ResponseWriter so that http.ResponseController
// can reach optional capabilities (flushing, hijacking, deadlines) that the
// wrapper itself does not implement.
func (sw *statusWriter) Unwrap() http.ResponseWriter {
	return sw.ResponseWriter
}
|
|
|
|
// RequestLogger returns middleware that logs each request using slog.
|
|
func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
|
next.ServeHTTP(sw, r)
|
|
logger.Info("request",
|
|
"method", r.Method,
|
|
"path", r.URL.Path,
|
|
"status", sw.status,
|
|
"duration_ms", time.Since(start).Milliseconds(),
|
|
)
|
|
})
|
|
}
|
|
}
|