293 lines
7.4 KiB
Go
293 lines
7.4 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/iliaivanov/spec-kit-remote/cmd/dev-pod-api/internal/model"
|
|
)
|
|
|
|
// mockKeyValidator implements KeyValidator for testing.
|
|
type mockKeyValidator struct {
|
|
users map[string]struct {
|
|
user *model.User
|
|
role string
|
|
}
|
|
}
|
|
|
|
func newMockValidator() *mockKeyValidator {
|
|
return &mockKeyValidator{
|
|
users: make(map[string]struct {
|
|
user *model.User
|
|
role string
|
|
}),
|
|
}
|
|
}
|
|
|
|
func (m *mockKeyValidator) addKey(key string, user *model.User, role string) {
|
|
m.users[key] = struct {
|
|
user *model.User
|
|
role string
|
|
}{user: user, role: role}
|
|
}
|
|
|
|
func (m *mockKeyValidator) ValidateKey(_ context.Context, key string) (*model.User, string, error) {
|
|
entry, ok := m.users[key]
|
|
if !ok {
|
|
return nil, "", errors.New("invalid key")
|
|
}
|
|
return entry.user, entry.role, nil
|
|
}
|
|
|
|
func TestAuthMiddleware_ValidKey(t *testing.T) {
|
|
kv := newMockValidator()
|
|
user := &model.User{ID: "alice", Quota: model.DefaultQuota()}
|
|
kv.addKey("dpk_valid123", user, model.RoleUser)
|
|
|
|
var capturedUser *model.User
|
|
var capturedRole string
|
|
handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedUser = UserFromContext(r.Context())
|
|
capturedRole = RoleFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer dpk_valid123")
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", rr.Code)
|
|
}
|
|
if capturedUser == nil || capturedUser.ID != "alice" {
|
|
t.Fatalf("expected user alice, got %v", capturedUser)
|
|
}
|
|
if capturedRole != model.RoleUser {
|
|
t.Fatalf("expected role user, got %s", capturedRole)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_AdminKey(t *testing.T) {
|
|
kv := newMockValidator()
|
|
admin := &model.User{ID: "admin-user", Quota: model.DefaultQuota()}
|
|
kv.addKey("dpk_admin456", admin, model.RoleAdmin)
|
|
|
|
var capturedRole string
|
|
handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedRole = RoleFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer dpk_admin456")
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", rr.Code)
|
|
}
|
|
if capturedRole != model.RoleAdmin {
|
|
t.Fatalf("expected role admin, got %s", capturedRole)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_MissingHeader(t *testing.T) {
|
|
kv := newMockValidator()
|
|
handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatal("handler should not be called")
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected 401, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
|
|
kv := newMockValidator()
|
|
handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatal("handler should not be called")
|
|
}))
|
|
|
|
tests := []struct {
|
|
name string
|
|
value string
|
|
}{
|
|
{"no bearer prefix", "dpk_abc123"},
|
|
{"basic auth", "Basic dXNlcjpwYXNz"},
|
|
{"empty bearer", "Bearer "},
|
|
{"bearer lowercase", "bearer dpk_abc123"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", tt.value)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected 401, got %d", rr.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_InvalidKey(t *testing.T) {
|
|
kv := newMockValidator()
|
|
handler := AuthMiddleware(kv)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatal("handler should not be called")
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer dpk_doesnotexist")
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected 401, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_UnderLimit(t *testing.T) {
|
|
rl := NewRateLimiter(1.0, 60)
|
|
for i := range 60 {
|
|
if !rl.Allow("user1") {
|
|
t.Fatalf("request %d should be allowed (under burst limit)", i+1)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_OverLimit(t *testing.T) {
|
|
rl := NewRateLimiter(1.0, 5)
|
|
|
|
// Exhaust the burst
|
|
for range 5 {
|
|
rl.Allow("user1")
|
|
}
|
|
|
|
// Next request should be rejected
|
|
if rl.Allow("user1") {
|
|
t.Fatal("request should be rejected after burst exhausted")
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_DifferentKeys(t *testing.T) {
|
|
rl := NewRateLimiter(1.0, 2)
|
|
|
|
// Exhaust user1's burst
|
|
rl.Allow("user1")
|
|
rl.Allow("user1")
|
|
if rl.Allow("user1") {
|
|
t.Fatal("user1 should be rate limited")
|
|
}
|
|
|
|
// user2 should still be allowed
|
|
if !rl.Allow("user2") {
|
|
t.Fatal("user2 should not be affected by user1's rate limit")
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_Refill(t *testing.T) {
|
|
rl := NewRateLimiter(1.0, 2)
|
|
|
|
// Use a controllable time function
|
|
now := time.Now()
|
|
rl.nowFunc = func() time.Time { return now }
|
|
|
|
// Exhaust burst
|
|
rl.Allow("user1")
|
|
rl.Allow("user1")
|
|
if rl.Allow("user1") {
|
|
t.Fatal("should be rate limited")
|
|
}
|
|
|
|
// Advance time by 1 second (1 token refill at 1/s)
|
|
now = now.Add(1 * time.Second)
|
|
if !rl.Allow("user1") {
|
|
t.Fatal("should be allowed after token refill")
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_Middleware(t *testing.T) {
|
|
rl := NewRateLimiter(1.0, 2)
|
|
|
|
handlerCalls := 0
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalls++
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Set up authenticated context
|
|
user := &model.User{ID: "alice"}
|
|
for i := range 3 {
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
ctx := context.WithValue(req.Context(), ctxKeyUser, user)
|
|
req = req.WithContext(ctx)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if i < 2 {
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("request %d: expected 200, got %d", i+1, rr.Code)
|
|
}
|
|
} else {
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("request %d: expected 429, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
if handlerCalls != 2 {
|
|
t.Fatalf("expected 2 handler calls, got %d", handlerCalls)
|
|
}
|
|
}
|
|
|
|
func TestRequestLogger(t *testing.T) {
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusCreated)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusCreated {
|
|
t.Fatalf("expected 201, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestStatusWriter_DefaultStatus(t *testing.T) {
|
|
rr := httptest.NewRecorder()
|
|
sw := &statusWriter{ResponseWriter: rr, status: http.StatusOK}
|
|
sw.Write([]byte("hello"))
|
|
if sw.status != http.StatusOK {
|
|
t.Fatalf("expected default 200, got %d", sw.status)
|
|
}
|
|
}
|
|
|
|
func TestStatusWriter_ExplicitStatus(t *testing.T) {
|
|
rr := httptest.NewRecorder()
|
|
sw := &statusWriter{ResponseWriter: rr, status: http.StatusOK}
|
|
sw.WriteHeader(http.StatusNotFound)
|
|
if sw.status != http.StatusNotFound {
|
|
t.Fatalf("expected 404, got %d", sw.status)
|
|
}
|
|
|
|
// Second WriteHeader should be ignored
|
|
sw.WriteHeader(http.StatusOK)
|
|
if sw.status != http.StatusNotFound {
|
|
t.Fatalf("expected 404 after double write, got %d", sw.status)
|
|
}
|
|
}
|