package auth import ( "context" "net/http" "net/http/httptest" "testing" "github.com/google/uuid" ) type mockTenantLookup struct { tenantID *uuid.UUID err error hasAccess bool accessErr error role string } func (m *mockTenantLookup) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) { return m.tenantID, m.err } func (m *mockTenantLookup) VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error) { return m.hasAccess, m.accessErr } func (m *mockTenantLookup) GetUserRole(ctx context.Context, userID, tenantID uuid.UUID) (string, error) { if m.role != "" { return m.role, m.err } if m.hasAccess { return "associate", m.err } return "", m.err } func TestTenantResolver_FromHeader(t *testing.T) { tenantID := uuid.New() tr := NewTenantResolver(&mockTenantLookup{hasAccess: true, role: "partner"}) var gotTenantID uuid.UUID next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id, ok := TenantFromContext(r.Context()) if !ok { t.Fatal("tenant ID not in context") } gotTenantID = id w.WriteHeader(http.StatusOK) }) r := httptest.NewRequest("GET", "/api/cases", nil) r.Header.Set("X-Tenant-ID", tenantID.String()) r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) w := httptest.NewRecorder() tr.Resolve(next).ServeHTTP(w, r) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d", w.Code) } if gotTenantID != tenantID { t.Errorf("expected tenant %s, got %s", tenantID, gotTenantID) } } func TestTenantResolver_FromHeader_NoAccess(t *testing.T) { tenantID := uuid.New() tr := NewTenantResolver(&mockTenantLookup{hasAccess: false}) next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("next should not be called") }) r := httptest.NewRequest("GET", "/api/cases", nil) r.Header.Set("X-Tenant-ID", tenantID.String()) r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) w := httptest.NewRecorder() tr.Resolve(next).ServeHTTP(w, r) if w.Code != http.StatusForbidden { t.Errorf("expected 403, got %d", w.Code) } } func TestTenantResolver_DefaultsToFirst(t *testing.T) { tenantID := uuid.New() tr := NewTenantResolver(&mockTenantLookup{tenantID: &tenantID, role: "owner"}) var gotTenantID uuid.UUID next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id, _ := TenantFromContext(r.Context()) gotTenantID = id w.WriteHeader(http.StatusOK) }) r := httptest.NewRequest("GET", "/api/cases", nil) r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) w := httptest.NewRecorder() tr.Resolve(next).ServeHTTP(w, r) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d", w.Code) } if gotTenantID != tenantID { t.Errorf("expected tenant %s, got %s", tenantID, gotTenantID) } } func TestTenantResolver_NoUser(t *testing.T) { tr := NewTenantResolver(&mockTenantLookup{}) next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("next should not be called") }) r := httptest.NewRequest("GET", "/api/cases", nil) w := httptest.NewRecorder() tr.Resolve(next).ServeHTTP(w, r) if w.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", w.Code) } } func TestTenantResolver_InvalidHeader(t *testing.T) { tr := NewTenantResolver(&mockTenantLookup{}) next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("next should not be called") }) r := httptest.NewRequest("GET", "/api/cases", nil) r.Header.Set("X-Tenant-ID", "not-a-uuid") r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) w := httptest.NewRecorder() tr.Resolve(next).ServeHTTP(w, r) if w.Code != http.StatusBadRequest { t.Errorf("expected 400, got %d", w.Code) } } func TestTenantResolver_NoTenantForUser(t *testing.T) { tr := NewTenantResolver(&mockTenantLookup{tenantID: nil}) next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("next should not be called") }) r := httptest.NewRequest("GET", "/api/cases", nil) r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) w := httptest.NewRecorder() tr.Resolve(next).ServeHTTP(w, r) if w.Code != http.StatusBadRequest { t.Errorf("expected 400, got %d", w.Code) } }