package api import ( "context" "errors" "io" "log/slog" "net/http" "net/http/httptest" "testing" "time" "github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model" ) // mockKeyValidator implements KeyValidator for testing. type mockKeyValidator struct { users map[string]struct { user *model.User role string } } func newMockValidator() *mockKeyValidator { return &mockKeyValidator{ users: make(map[string]struct { user *model.User role string }), } } func (m *mockKeyValidator) addKey(key string, user *model.User, role string) { m.users[key] = struct { user *model.User role string }{user: user, role: role} } func (m *mockKeyValidator) ValidateKey(_ context.Context, key string) (*model.User, string, error) { entry, ok := m.users[key] if !ok { return nil, "", errors.New("invalid key") } return entry.user, entry.role, nil } func TestAuthMiddleware_ValidKey(t *testing.T) { kv := newMockValidator() user := &model.User{ID: "alice", Quota: model.DefaultQuota()} kv.addKey("dpk_valid123", user, model.RoleUser) var capturedUser *model.User var capturedRole string handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedUser = UserFromContext(r.Context()) capturedRole = RoleFromContext(r.Context()) w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer dpk_valid123") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatalf("expected 200, got %d", rr.Code) } if capturedUser == nil || capturedUser.ID != "alice" { t.Fatalf("expected user alice, got %v", capturedUser) } if capturedRole != model.RoleUser { t.Fatalf("expected role user, got %s", capturedRole) } } func TestAuthMiddleware_AdminKey(t *testing.T) { kv := newMockValidator() admin := &model.User{ID: "admin-user", Quota: model.DefaultQuota()} kv.addKey("dpk_admin456", admin, model.RoleAdmin) var capturedRole string handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedRole = RoleFromContext(r.Context()) w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer dpk_admin456") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatalf("expected 200, got %d", rr.Code) } if capturedRole != model.RoleAdmin { t.Fatalf("expected role admin, got %s", capturedRole) } } func TestAuthMiddleware_MissingHeader(t *testing.T) { kv := newMockValidator() handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("handler should not be called") })) req := httptest.NewRequest(http.MethodGet, "/", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", rr.Code) } } func TestAuthMiddleware_InvalidFormat(t *testing.T) { kv := newMockValidator() handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("handler should not be called") })) tests := []struct { name string value string }{ {"no bearer prefix", "dpk_abc123"}, {"basic auth", "Basic dXNlcjpwYXNz"}, {"empty bearer", "Bearer "}, {"bearer lowercase", "bearer dpk_abc123"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", tt.value) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", rr.Code) } }) } } func TestAuthMiddleware_InvalidKey(t *testing.T) { kv := newMockValidator() handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("handler should not be called") })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer dpk_doesnotexist") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", rr.Code) } } func TestRateLimiter_UnderLimit(t *testing.T) { rl := NewRateLimiter(1.0, 60) for i := range 60 { if !rl.Allow("user1") { t.Fatalf("request %d should be allowed (under burst limit)", i+1) } } } func TestRateLimiter_OverLimit(t *testing.T) { rl := NewRateLimiter(1.0, 5) // Exhaust the burst for range 5 { rl.Allow("user1") } // Next request should be rejected if rl.Allow("user1") { t.Fatal("request should be rejected after burst exhausted") } } func TestRateLimiter_DifferentKeys(t *testing.T) { rl := NewRateLimiter(1.0, 2) // Exhaust user1's burst rl.Allow("user1") rl.Allow("user1") if rl.Allow("user1") { t.Fatal("user1 should be rate limited") } // user2 should still be allowed if !rl.Allow("user2") { t.Fatal("user2 should not be affected by user1's rate limit") } } func TestRateLimiter_Refill(t *testing.T) { rl := NewRateLimiter(1.0, 2) // Use a controllable time function now := time.Now() rl.nowFunc = func() time.Time { return now } // Exhaust burst rl.Allow("user1") rl.Allow("user1") if rl.Allow("user1") { t.Fatal("should be rate limited") } // Advance time by 1 second (1 token refill at 1/s) now = now.Add(1 * time.Second) if !rl.Allow("user1") { t.Fatal("should be allowed after token refill") } } func TestRateLimiter_Middleware(t *testing.T) { rl := NewRateLimiter(1.0, 2) handlerCalls := 0 handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerCalls++ w.WriteHeader(http.StatusOK) })) // Set up authenticated context user := &model.User{ID: "alice"} for i := range 3 { req := httptest.NewRequest(http.MethodGet, "/", nil) ctx := context.WithValue(req.Context(), ctxKeyUser, user) req = req.WithContext(ctx) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if i < 2 { if rr.Code != http.StatusOK { t.Fatalf("request %d: expected 200, got %d", i+1, rr.Code) } } else { if rr.Code != http.StatusTooManyRequests { t.Fatalf("request %d: expected 429, got %d", i+1, rr.Code) } } } if handlerCalls != 2 { t.Fatalf("expected 2 handler calls, got %d", handlerCalls) } } func TestRequestLogger(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) })) req := httptest.NewRequest(http.MethodPost, "/test", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusCreated { t.Fatalf("expected 201, got %d", rr.Code) } } func TestStatusWriter_DefaultStatus(t *testing.T) { rr := httptest.NewRecorder() sw := &statusWriter{ResponseWriter: rr, status: http.StatusOK} sw.Write([]byte("hello")) if sw.status != http.StatusOK { t.Fatalf("expected default 200, got %d", sw.status) } } func TestStatusWriter_ExplicitStatus(t *testing.T) { rr := httptest.NewRecorder() sw := &statusWriter{ResponseWriter: rr, status: http.StatusOK} sw.WriteHeader(http.StatusNotFound) if sw.status != http.StatusNotFound { t.Fatalf("expected 404, got %d", sw.status) } // Second WriteHeader should be ignored sw.WriteHeader(http.StatusOK) if sw.status != http.StatusNotFound { t.Fatalf("expected 404 after double write, got %d", sw.status) } }