build source
This commit is contained in:
commit
ee1fec43ed
4171 changed files with 1351288 additions and 0 deletions
198
internal/api/middleware.go
Normal file
198
internal/api/middleware.go
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
ctxKeyUser contextKey = "user"
|
||||
ctxKeyRole contextKey = "role"
|
||||
)
|
||||
|
||||
// UserFromContext returns the authenticated user from the request context.
|
||||
func UserFromContext(ctx context.Context) *model.User {
|
||||
u, _ := ctx.Value(ctxKeyUser).(*model.User)
|
||||
return u
|
||||
}
|
||||
|
||||
// RoleFromContext returns the authenticated user's role from the request context.
|
||||
func RoleFromContext(ctx context.Context) string {
|
||||
r, _ := ctx.Value(ctxKeyRole).(string)
|
||||
return r
|
||||
}
|
||||
|
||||
// KeyValidator validates an API key and returns the associated user and role.
|
||||
type KeyValidator interface {
|
||||
ValidateKey(ctx context.Context, key string) (*model.User, string, error)
|
||||
}
|
||||
|
||||
// AuthMiddleware returns middleware that validates Bearer tokens via the KeyValidator.
|
||||
func AuthMiddleware(kv KeyValidator) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth == "" {
|
||||
writeError(w, http.StatusUnauthorized, "missing authorization header")
|
||||
return
|
||||
}
|
||||
|
||||
token, found := strings.CutPrefix(auth, "Bearer ")
|
||||
if !found || token == "" {
|
||||
writeError(w, http.StatusUnauthorized, "invalid authorization format, expected: Bearer <token>")
|
||||
return
|
||||
}
|
||||
|
||||
user, role, err := kv.ValidateKey(r.Context(), token)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "invalid api key")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), ctxKeyUser, user)
|
||||
ctx = context.WithValue(ctx, ctxKeyRole, role)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// bucket tracks rate limiting state for a single key.
|
||||
type bucket struct {
|
||||
tokens float64
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// RateLimiter implements a per-key token bucket rate limiter.
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string]*bucket
|
||||
rate float64 // tokens per second
|
||||
burst float64 // max tokens
|
||||
nowFunc func() time.Time
|
||||
callCount int // tracks calls for periodic cleanup
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a rate limiter with the given rate (req/s) and burst size.
|
||||
func NewRateLimiter(ratePerSecond, burst float64) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
buckets: make(map[string]*bucket),
|
||||
rate: ratePerSecond,
|
||||
burst: burst,
|
||||
nowFunc: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key is allowed.
|
||||
func (rl *RateLimiter) Allow(key string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := rl.nowFunc()
|
||||
|
||||
// Evict stale buckets every 100 calls to prevent unbounded growth.
|
||||
rl.callCount++
|
||||
if rl.callCount%100 == 0 {
|
||||
const staleThreshold = 10 * time.Minute
|
||||
for k, b := range rl.buckets {
|
||||
if now.Sub(b.lastSeen) > staleThreshold {
|
||||
delete(rl.buckets, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b, exists := rl.buckets[key]
|
||||
if !exists {
|
||||
rl.buckets[key] = &bucket{tokens: rl.burst - 1, lastSeen: now}
|
||||
return true
|
||||
}
|
||||
|
||||
elapsed := now.Sub(b.lastSeen).Seconds()
|
||||
b.tokens += elapsed * rl.rate
|
||||
if b.tokens > rl.burst {
|
||||
b.tokens = rl.burst
|
||||
}
|
||||
b.lastSeen = now
|
||||
|
||||
if b.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
b.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
// MaxBodySize returns middleware that limits request body size to the given number of bytes.
|
||||
func MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Body != nil {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware returns HTTP middleware that rate-limits by authenticated user ID.
|
||||
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := UserFromContext(r.Context())
|
||||
key := "anonymous"
|
||||
if user != nil {
|
||||
key = user.ID
|
||||
}
|
||||
|
||||
if !rl.Allow(key) {
|
||||
writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// statusWriter wraps http.ResponseWriter to capture the status code.
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
wrote bool
|
||||
}
|
||||
|
||||
func (sw *statusWriter) WriteHeader(code int) {
|
||||
if !sw.wrote {
|
||||
sw.status = code
|
||||
sw.wrote = true
|
||||
}
|
||||
sw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (sw *statusWriter) Write(b []byte) (int, error) {
|
||||
if !sw.wrote {
|
||||
sw.status = http.StatusOK
|
||||
sw.wrote = true
|
||||
}
|
||||
return sw.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// RequestLogger returns middleware that logs each request using slog.
|
||||
func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(sw, r)
|
||||
logger.Info("request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", sw.status,
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue