dev-pod-api-build/internal/api/middleware_test.go
2026-04-16 04:16:36 +00:00

293 lines
7.4 KiB
Go

package api
import (
"context"
"errors"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
)
// mockKeyEntry is the user and role that mockKeyValidator returns for a key.
type mockKeyEntry struct {
	user *model.User
	role string
}

// mockKeyValidator implements KeyValidator for testing. It maps raw API keys
// to the user and role that ValidateKey should report for them.
type mockKeyValidator struct {
	users map[string]mockKeyEntry
}

// newMockValidator returns an empty validator; every key is rejected until it
// is registered with addKey.
func newMockValidator() *mockKeyValidator {
	return &mockKeyValidator{users: make(map[string]mockKeyEntry)}
}

// addKey registers key so that ValidateKey returns user and role for it.
// Registering the same key again replaces the previous entry.
func (m *mockKeyValidator) addKey(key string, user *model.User, role string) {
	m.users[key] = mockKeyEntry{user: user, role: role}
}

// ValidateKey returns the user and role registered for key, or an
// "invalid key" error if key was never added. The context is ignored.
func (m *mockKeyValidator) ValidateKey(_ context.Context, key string) (*model.User, string, error) {
	entry, ok := m.users[key]
	if !ok {
		return nil, "", errors.New("invalid key")
	}
	return entry.user, entry.role, nil
}
// TestAuthMiddleware_ValidKey checks that a registered user-role key passes
// authentication and that the wrapped handler sees the resolved user and role
// in the request context.
func TestAuthMiddleware_ValidKey(t *testing.T) {
	validator := newMockValidator()
	validator.addKey("dpk_valid123", &model.User{ID: "alice", Quota: model.DefaultQuota()}, model.RoleUser)

	var (
		gotUser *model.User
		gotRole string
	)
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		gotUser, gotRole = UserFromContext(ctx), RoleFromContext(ctx)
		w.WriteHeader(http.StatusOK)
	})
	protected := AuthMiddleware(validator)(next)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer dpk_valid123")
	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, req)

	switch {
	case rec.Code != http.StatusOK:
		t.Fatalf("expected 200, got %d", rec.Code)
	case gotUser == nil || gotUser.ID != "alice":
		t.Fatalf("expected user alice, got %v", gotUser)
	case gotRole != model.RoleUser:
		t.Fatalf("expected role user, got %s", gotRole)
	}
}
// TestAuthMiddleware_AdminKey checks that a key registered with the admin role
// is authenticated and that the admin role is exposed in the request context.
func TestAuthMiddleware_AdminKey(t *testing.T) {
	validator := newMockValidator()
	validator.addKey("dpk_admin456", &model.User{ID: "admin-user", Quota: model.DefaultQuota()}, model.RoleAdmin)

	var seenRole string
	protected := AuthMiddleware(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seenRole = RoleFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer dpk_admin456")
	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected 200, got %d", got)
	}
	if seenRole != model.RoleAdmin {
		t.Fatalf("expected role admin, got %s", seenRole)
	}
}
func TestAuthMiddleware_MissingHeader(t *testing.T) {
kv := newMockValidator()
handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("handler should not be called")
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", rr.Code)
}
}
// TestAuthMiddleware_InvalidFormat checks that Authorization headers not of
// the form "Bearer <key>" (missing scheme, other schemes, an empty key, or a
// lowercase "bearer") are rejected with 401 before the wrapped handler runs.
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
	kv := newMockValidator()
	tests := []struct {
		name  string
		value string
	}{
		{"no bearer prefix", "dpk_abc123"},
		{"basic auth", "Basic dXNlcjpwYXNz"},
		{"empty bearer", "Bearer "},
		{"bearer lowercase", "bearer dpk_abc123"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Build the handler inside the subtest so the "must not be called"
			// guard fails this subtest's t. Calling Fatal (FailNow) on the parent
			// t from a subtest goroutine is not permitted by package testing.
			handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				t.Fatal("handler should not be called")
			}))
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.Header.Set("Authorization", tt.value)
			rr := httptest.NewRecorder()
			handler.ServeHTTP(rr, req)
			if rr.Code != http.StatusUnauthorized {
				t.Fatalf("expected 401, got %d", rr.Code)
			}
		})
	}
}
func TestAuthMiddleware_InvalidKey(t *testing.T) {
kv := newMockValidator()
handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("handler should not be called")
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Authorization", "Bearer dpk_doesnotexist")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", rr.Code)
}
}
// TestRateLimiter_UnderLimit checks that a fresh key can make a full burst of
// back-to-back requests without being throttled.
func TestRateLimiter_UnderLimit(t *testing.T) {
	const burst = 60
	limiter := NewRateLimiter(1.0, burst)
	for n := 1; n <= burst; n++ {
		if !limiter.Allow("user1") {
			t.Fatalf("request %d should be allowed (under burst limit)", n)
		}
	}
}
// TestRateLimiter_OverLimit checks that every request within the burst is
// admitted and that the first request after the burst is exhausted is
// rejected.
func TestRateLimiter_OverLimit(t *testing.T) {
	rl := NewRateLimiter(1.0, 5)
	// Freeze the clock so no tokens are refilled between calls; otherwise the
	// outcome would depend on how quickly the loop runs.
	now := time.Now()
	rl.nowFunc = func() time.Time { return now }
	// Exhaust the burst, checking that each request in it is actually allowed.
	for i := range 5 {
		if !rl.Allow("user1") {
			t.Fatalf("request %d should be allowed (within burst)", i+1)
		}
	}
	// Next request should be rejected
	if rl.Allow("user1") {
		t.Fatal("request should be rejected after burst exhausted")
	}
}
// TestRateLimiter_DifferentKeys checks that limits are tracked per key:
// draining user1's burst does not throttle user2.
func TestRateLimiter_DifferentKeys(t *testing.T) {
	rl := NewRateLimiter(1.0, 2)
	// Freeze the clock so no tokens are refilled during the test.
	now := time.Now()
	rl.nowFunc = func() time.Time { return now }
	// Exhaust user1's burst, checking that both requests are admitted.
	for i := range 2 {
		if !rl.Allow("user1") {
			t.Fatalf("user1 request %d should be allowed (within burst)", i+1)
		}
	}
	if rl.Allow("user1") {
		t.Fatal("user1 should be rate limited")
	}
	// user2 should still be allowed
	if !rl.Allow("user2") {
		t.Fatal("user2 should not be affected by user1's rate limit")
	}
}
// TestRateLimiter_Refill checks that a throttled key is admitted again once
// enough time has passed to refill one token (1s at 1 token/s).
func TestRateLimiter_Refill(t *testing.T) {
	limiter := NewRateLimiter(1.0, 2)

	// Drive the limiter from a controllable clock so the refill is deterministic.
	clock := time.Now()
	limiter.nowFunc = func() time.Time { return clock }

	// Spend both burst tokens; the next request must be throttled.
	for range 2 {
		limiter.Allow("user1")
	}
	if limiter.Allow("user1") {
		t.Fatal("should be rate limited")
	}

	clock = clock.Add(time.Second)
	if !limiter.Allow("user1") {
		t.Fatal("should be allowed after token refill")
	}
}
// TestRateLimiter_Middleware checks that with a burst of 2 the first two
// requests from one user reach the handler and the third is answered with 429
// without invoking it.
func TestRateLimiter_Middleware(t *testing.T) {
	rl := NewRateLimiter(1.0, 2)
	// Freeze the clock so no token is refilled between requests; the 429 on
	// the third request must not depend on scheduling delays.
	now := time.Now()
	rl.nowFunc = func() time.Time { return now }

	handlerCalls := 0
	handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalls++
		w.WriteHeader(http.StatusOK)
	}))

	// Simulate an authenticated request by placing the user in the request
	// context under ctxKeyUser.
	user := &model.User{ID: "alice"}
	want := []int{http.StatusOK, http.StatusOK, http.StatusTooManyRequests}
	for i, code := range want {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req = req.WithContext(context.WithValue(req.Context(), ctxKeyUser, user))
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)
		if rr.Code != code {
			t.Fatalf("request %d: expected %d, got %d", i+1, code, rr.Code)
		}
	}
	if handlerCalls != 2 {
		t.Fatalf("expected 2 handler calls, got %d", handlerCalls)
	}
}
// TestRequestLogger checks that wrapping a handler with RequestLogger passes
// through the status code the handler writes. Log output is discarded.
func TestRequestLogger(t *testing.T) {
	quiet := slog.New(slog.NewTextHandler(io.Discard, nil))
	created := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusCreated)
	})
	logged := RequestLogger(quiet)(created)

	rec := httptest.NewRecorder()
	logged.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/test", nil))

	if got := rec.Code; got != http.StatusCreated {
		t.Fatalf("expected 201, got %d", got)
	}
}
// TestStatusWriter_DefaultStatus checks that writing a body without an explicit
// WriteHeader call leaves the recorded status at 200.
func TestStatusWriter_DefaultStatus(t *testing.T) {
	rr := httptest.NewRecorder()
	sw := &statusWriter{ResponseWriter: rr, status: http.StatusOK}
	// Check the write error instead of discarding it, so a failing Write
	// surfaces here rather than as a confusing status mismatch.
	if _, err := sw.Write([]byte("hello")); err != nil {
		t.Fatalf("unexpected write error: %v", err)
	}
	if sw.status != http.StatusOK {
		t.Fatalf("expected default 200, got %d", sw.status)
	}
}
// TestStatusWriter_ExplicitStatus checks that statusWriter records the status
// passed to the first WriteHeader call and ignores a later one.
func TestStatusWriter_ExplicitStatus(t *testing.T) {
	sw := &statusWriter{ResponseWriter: httptest.NewRecorder(), status: http.StatusOK}

	sw.WriteHeader(http.StatusNotFound)
	if got := sw.status; got != http.StatusNotFound {
		t.Fatalf("expected 404, got %d", got)
	}

	// A second WriteHeader must not overwrite the recorded status.
	sw.WriteHeader(http.StatusOK)
	if got := sw.status; got != http.StatusNotFound {
		t.Fatalf("expected 404 after double write, got %d", got)
	}
}