From c15d5b72f2575ae109c8ce35d15340599bff1e09 Mon Sep 17 00:00:00 2001 From: m Date: Mon, 30 Mar 2026 11:01:14 +0200 Subject: [PATCH] =?UTF-8?q?fix:=20critical=20security=20hardening=20?= =?UTF-8?q?=E2=80=94=20tenant=20isolation,=20CORS,=20error=20leaking,=20in?= =?UTF-8?q?put=20validation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Tenant isolation bypass (CRITICAL): TenantResolver now verifies user has access to X-Tenant-ID via user_tenants lookup before setting context. Added VerifyAccess method to TenantLookup interface and TenantService. 2. Consolidated tenant resolution: Removed duplicate resolveTenant() from helpers.go and tenant resolution from auth middleware. TenantResolver is now the single source of truth. Deadlines and AI handlers use auth.TenantFromContext() instead of direct DB queries. 3. CalDAV credential masking: tenant settings responses now mask CalDAV passwords with "********" via maskSettingsPassword helper. Applied to GetTenant, ListTenants, and UpdateSettings responses. 4. CORS + security headers: New middleware/security.go with CORS (restricted to FRONTEND_ORIGIN) and security headers (X-Frame-Options, X-Content-Type-Options, HSTS, Referrer-Policy, X-XSS-Protection). 5. Internal error leaking: All writeError(w, 500, err.Error()) replaced with internalError() that logs via slog and returns generic "internal error" to client. Same for jsonError in tenant handler. 6. Input validation: Max length on title (500), description (10000), case_number (100), search (200). Pagination clamped to max 100. Content-Disposition filename sanitized against header injection. Regression test added for tenant access denial (403 on unauthorized X-Tenant-ID). All existing tests pass, go vet clean. --- backend/internal/auth/middleware.go | 17 +-- backend/internal/auth/tenant_resolver.go | 29 +++- backend/internal/auth/tenant_resolver_test.go | 32 ++++- backend/internal/config/config.go | 2 + backend/internal/handlers/ai.go | 22 +-- backend/internal/handlers/ai_handler_test.go | 14 +- backend/internal/handlers/appointments.go | 8 ++ backend/internal/handlers/caldav.go | 2 +- backend/internal/handlers/cases.go | 37 ++++- backend/internal/handlers/dashboard.go | 2 +- backend/internal/handlers/deadlines.go | 71 ++++----- backend/internal/handlers/documents.go | 12 +- backend/internal/handlers/helpers.go | 136 ++++++++++-------- backend/internal/handlers/notes.go | 8 ++ backend/internal/handlers/parties.go | 11 +- backend/internal/handlers/tenant_handler.go | 48 +++++-- backend/internal/middleware/security.go | 49 +++++++ backend/internal/router/router.go | 15 +- backend/internal/services/tenant_service.go | 13 ++ 19 files changed, 361 insertions(+), 167 deletions(-) create mode 100644 backend/internal/middleware/security.go diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go index 4f31eb6..91533d6 100644 --- a/backend/internal/auth/middleware.go +++ b/backend/internal/auth/middleware.go @@ -24,28 +24,19 @@ func (m *Middleware) RequireAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token := extractBearerToken(r) if token == "" { - http.Error(w, "missing authorization token", http.StatusUnauthorized) + http.Error(w, `{"error":"missing authorization token"}`, http.StatusUnauthorized) return } userID, err := m.verifyJWT(token) if err != nil { - http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized) + http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized) return } ctx := ContextWithUserID(r.Context(), userID) - - // Resolve tenant from user_tenants - var tenantID uuid.UUID - err = m.db.GetContext(r.Context(), &tenantID, - "SELECT tenant_id FROM user_tenants WHERE user_id = $1 LIMIT 1", userID) - if err != nil { - http.Error(w, "no tenant found for user", http.StatusForbidden) - return - } - ctx = ContextWithTenantID(ctx, tenantID) - + // Tenant resolution is handled by TenantResolver middleware for scoped routes. + // Tenant management routes handle their own access control. next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/backend/internal/auth/tenant_resolver.go b/backend/internal/auth/tenant_resolver.go index 6358d4d..24688d0 100644 --- a/backend/internal/auth/tenant_resolver.go +++ b/backend/internal/auth/tenant_resolver.go @@ -2,20 +2,21 @@ package auth import ( "context" - "fmt" + "log/slog" "net/http" "github.com/google/uuid" ) -// TenantLookup resolves the default tenant for a user. +// TenantLookup resolves and verifies tenant access for a user. // Defined as an interface to avoid circular dependency with services. type TenantLookup interface { FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) + VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error) } // TenantResolver is middleware that resolves the tenant from X-Tenant-ID header -// or defaults to the user's first tenant. +// or defaults to the user's first tenant. Always verifies user has access. type TenantResolver struct { lookup TenantLookup } @@ -28,7 +29,7 @@ func (tr *TenantResolver) Resolve(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { userID, ok := UserFromContext(r.Context()) if !ok { - http.Error(w, "unauthorized", http.StatusUnauthorized) + http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized) return } @@ -37,19 +38,33 @@ func (tr *TenantResolver) Resolve(next http.Handler) http.Handler { if header := r.Header.Get("X-Tenant-ID"); header != "" { parsed, err := uuid.Parse(header) if err != nil { - http.Error(w, fmt.Sprintf("invalid X-Tenant-ID: %v", err), http.StatusBadRequest) + http.Error(w, `{"error":"invalid X-Tenant-ID"}`, http.StatusBadRequest) return } + + // Verify user has access to this tenant + hasAccess, err := tr.lookup.VerifyAccess(r.Context(), userID, parsed) + if err != nil { + slog.Error("tenant access check failed", "error", err, "user_id", userID, "tenant_id", parsed) + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) + return + } + if !hasAccess { + http.Error(w, `{"error":"no access to tenant"}`, http.StatusForbidden) + return + } + tenantID = parsed } else { // Default to user's first tenant first, err := tr.lookup.FirstTenantForUser(r.Context(), userID) if err != nil { - http.Error(w, fmt.Sprintf("resolving tenant: %v", err), http.StatusInternalServerError) + slog.Error("failed to resolve default tenant", "error", err, "user_id", userID) + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) return } if first == nil { - http.Error(w, "no tenant found for user", http.StatusBadRequest) + http.Error(w, `{"error":"no tenant found for user"}`, http.StatusBadRequest) return } tenantID = *first diff --git a/backend/internal/auth/tenant_resolver_test.go b/backend/internal/auth/tenant_resolver_test.go index dfb8e2d..a542bdf 100644 --- a/backend/internal/auth/tenant_resolver_test.go +++ b/backend/internal/auth/tenant_resolver_test.go @@ -10,17 +10,23 @@ import ( ) type mockTenantLookup struct { - tenantID *uuid.UUID - err error + tenantID *uuid.UUID + err error + hasAccess bool + accessErr error } func (m *mockTenantLookup) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) { return m.tenantID, m.err } +func (m *mockTenantLookup) VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error) { + return m.hasAccess, m.accessErr +} + func TestTenantResolver_FromHeader(t *testing.T) { tenantID := uuid.New() - tr := NewTenantResolver(&mockTenantLookup{}) + tr := NewTenantResolver(&mockTenantLookup{hasAccess: true}) var gotTenantID uuid.UUID next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -47,6 +53,26 @@ func TestTenantResolver_FromHeader(t *testing.T) { } } +func TestTenantResolver_FromHeader_NoAccess(t *testing.T) { + tenantID := uuid.New() + tr := NewTenantResolver(&mockTenantLookup{hasAccess: false}) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("next should not be called") + }) + + r := httptest.NewRequest("GET", "/api/cases", nil) + r.Header.Set("X-Tenant-ID", tenantID.String()) + r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + tr.Resolve(next).ServeHTTP(w, r) + + if w.Code != http.StatusForbidden { + t.Errorf("expected 403, got %d", w.Code) + } +} + func TestTenantResolver_DefaultsToFirst(t *testing.T) { tenantID := uuid.New() tr := NewTenantResolver(&mockTenantLookup{tenantID: &tenantID}) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3b78b3a..620a7f4 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -13,6 +13,7 @@ type Config struct { SupabaseServiceKey string SupabaseJWTSecret string AnthropicAPIKey string + FrontendOrigin string } func Load() (*Config, error) { @@ -24,6 +25,7 @@ func Load() (*Config, error) { SupabaseServiceKey: os.Getenv("SUPABASE_SERVICE_KEY"), SupabaseJWTSecret: os.Getenv("SUPABASE_JWT_SECRET"), AnthropicAPIKey: os.Getenv("ANTHROPIC_API_KEY"), + FrontendOrigin: getEnv("FRONTEND_ORIGIN", "https://kanzlai.msbls.de"), } if cfg.DatabaseURL == "" { diff --git a/backend/internal/handlers/ai.go b/backend/internal/handlers/ai.go index 806c9b3..1a03c37 100644 --- a/backend/internal/handlers/ai.go +++ b/backend/internal/handlers/ai.go @@ -5,18 +5,16 @@ import ( "io" "net/http" - "github.com/jmoiron/sqlx" - + "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" "mgit.msbls.de/m/KanzlAI-mGMT/internal/services" ) type AIHandler struct { ai *services.AIService - db *sqlx.DB } -func NewAIHandler(ai *services.AIService, db *sqlx.DB) *AIHandler { - return &AIHandler{ai: ai, db: db} +func NewAIHandler(ai *services.AIService) *AIHandler { + return &AIHandler{ai: ai} } // ExtractDeadlines handles POST /api/ai/extract-deadlines @@ -61,10 +59,14 @@ func (h *AIHandler) ExtractDeadlines(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "provide either a PDF file or text") return } + if len(text) > maxDescriptionLen { + writeError(w, http.StatusBadRequest, "text exceeds maximum length") + return + } deadlines, err := h.ai.ExtractDeadlines(r.Context(), pdfData, text) if err != nil { - writeError(w, http.StatusInternalServerError, "AI extraction failed: "+err.Error()) + internalError(w, "AI deadline extraction failed", err) return } @@ -77,9 +79,9 @@ func (h *AIHandler) ExtractDeadlines(w http.ResponseWriter, r *http.Request) { // SummarizeCase handles POST /api/ai/summarize-case // Accepts JSON {"case_id": "uuid"}. func (h *AIHandler) SummarizeCase(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -104,7 +106,7 @@ func (h *AIHandler) SummarizeCase(w http.ResponseWriter, r *http.Request) { summary, err := h.ai.SummarizeCase(r.Context(), tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, "AI summarization failed: "+err.Error()) + internalError(w, "AI case summarization failed", err) return } diff --git a/backend/internal/handlers/ai_handler_test.go b/backend/internal/handlers/ai_handler_test.go index 7fb3f67..9622413 100644 --- a/backend/internal/handlers/ai_handler_test.go +++ b/backend/internal/handlers/ai_handler_test.go @@ -42,7 +42,7 @@ func TestAIExtractDeadlines_InvalidJSON(t *testing.T) { } } -func TestAISummarizeCase_MissingCaseID(t *testing.T) { +func TestAISummarizeCase_MissingTenant(t *testing.T) { h := &AIHandler{} body := `{"case_id":""}` @@ -52,9 +52,9 @@ func TestAISummarizeCase_MissingCaseID(t *testing.T) { h.SummarizeCase(w, r) - // Without auth context, the resolveTenant will fail first - if w.Code != http.StatusUnauthorized { - t.Errorf("expected 401, got %d", w.Code) + // Without tenant context, TenantFromContext returns !ok → 403 + if w.Code != http.StatusForbidden { + t.Errorf("expected 403, got %d", w.Code) } } @@ -67,8 +67,8 @@ func TestAISummarizeCase_InvalidJSON(t *testing.T) { h.SummarizeCase(w, r) - // Without auth context, the resolveTenant will fail first - if w.Code != http.StatusUnauthorized { - t.Errorf("expected 401, got %d", w.Code) + // Without tenant context, TenantFromContext returns !ok → 403 + if w.Code != http.StatusForbidden { + t.Errorf("expected 403, got %d", w.Code) } } diff --git a/backend/internal/handlers/appointments.go b/backend/internal/handlers/appointments.go index d8acaab..188fc4e 100644 --- a/backend/internal/handlers/appointments.go +++ b/backend/internal/handlers/appointments.go @@ -121,6 +121,10 @@ func (h *AppointmentHandler) Create(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "title is required") return } + if msg := validateStringLength("title", req.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } if req.StartAt.IsZero() { writeError(w, http.StatusBadRequest, "start_at is required") return @@ -188,6 +192,10 @@ func (h *AppointmentHandler) Update(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "title is required") return } + if msg := validateStringLength("title", req.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } if req.StartAt.IsZero() { writeError(w, http.StatusBadRequest, "start_at is required") return diff --git a/backend/internal/handlers/caldav.go b/backend/internal/handlers/caldav.go index cb38e72..8f93e4a 100644 --- a/backend/internal/handlers/caldav.go +++ b/backend/internal/handlers/caldav.go @@ -27,7 +27,7 @@ func (h *CalDAVHandler) TriggerSync(w http.ResponseWriter, r *http.Request) { cfg, err := h.svc.LoadTenantConfig(tenantID) if err != nil { - writeError(w, http.StatusBadRequest, err.Error()) + writeError(w, http.StatusBadRequest, "CalDAV not configured for this tenant") return } diff --git a/backend/internal/handlers/cases.go b/backend/internal/handlers/cases.go index a10d9d5..8cd9d33 100644 --- a/backend/internal/handlers/cases.go +++ b/backend/internal/handlers/cases.go @@ -28,18 +28,25 @@ func (h *CaseHandler) List(w http.ResponseWriter, r *http.Request) { limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) offset, _ := strconv.Atoi(r.URL.Query().Get("offset")) + limit, offset = clampPagination(limit, offset) + + search := r.URL.Query().Get("search") + if msg := validateStringLength("search", search, maxSearchLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } filter := services.CaseFilter{ Status: r.URL.Query().Get("status"), Type: r.URL.Query().Get("type"), - Search: r.URL.Query().Get("search"), + Search: search, Limit: limit, Offset: offset, } cases, total, err := h.svc.List(r.Context(), tenantID, filter) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to list cases", err) return } @@ -66,10 +73,18 @@ func (h *CaseHandler) Create(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "case_number and title are required") return } + if msg := validateStringLength("case_number", input.CaseNumber, maxCaseNumberLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } + if msg := validateStringLength("title", input.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } c, err := h.svc.Create(r.Context(), tenantID, userID, input) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to create case", err) return } @@ -91,7 +106,7 @@ func (h *CaseHandler) Get(w http.ResponseWriter, r *http.Request) { detail, err := h.svc.GetByID(r.Context(), tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to get case", err) return } if detail == nil { @@ -121,10 +136,22 @@ func (h *CaseHandler) Update(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "invalid JSON body") return } + if input.Title != nil { + if msg := validateStringLength("title", *input.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } + } + if input.CaseNumber != nil { + if msg := validateStringLength("case_number", *input.CaseNumber, maxCaseNumberLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } + } updated, err := h.svc.Update(r.Context(), tenantID, caseID, userID, input) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to update case", err) return } if updated == nil { diff --git a/backend/internal/handlers/dashboard.go b/backend/internal/handlers/dashboard.go index 09b5dbe..806800b 100644 --- a/backend/internal/handlers/dashboard.go +++ b/backend/internal/handlers/dashboard.go @@ -24,7 +24,7 @@ func (h *DashboardHandler) Get(w http.ResponseWriter, r *http.Request) { data, err := h.svc.Get(r.Context(), tenantID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to load dashboard", err) return } diff --git a/backend/internal/handlers/deadlines.go b/backend/internal/handlers/deadlines.go index 5b2b265..f8bdac2 100644 --- a/backend/internal/handlers/deadlines.go +++ b/backend/internal/handlers/deadlines.go @@ -4,27 +4,25 @@ import ( "encoding/json" "net/http" - "github.com/jmoiron/sqlx" - + "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" "mgit.msbls.de/m/KanzlAI-mGMT/internal/services" ) // DeadlineHandlers holds handlers for deadline CRUD endpoints type DeadlineHandlers struct { deadlines *services.DeadlineService - db *sqlx.DB } // NewDeadlineHandlers creates deadline handlers -func NewDeadlineHandlers(ds *services.DeadlineService, db *sqlx.DB) *DeadlineHandlers { - return &DeadlineHandlers{deadlines: ds, db: db} +func NewDeadlineHandlers(ds *services.DeadlineService) *DeadlineHandlers { + return &DeadlineHandlers{deadlines: ds} } // Get handles GET /api/deadlines/{deadlineID} func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -36,7 +34,7 @@ func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) { deadline, err := h.deadlines.GetByID(tenantID, deadlineID) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to fetch deadline") + internalError(w, "failed to fetch deadline", err) return } if deadline == nil { @@ -49,15 +47,15 @@ func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) { // ListAll handles GET /api/deadlines func (h *DeadlineHandlers) ListAll(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } deadlines, err := h.deadlines.ListAll(tenantID) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list deadlines") + internalError(w, "failed to list deadlines", err) return } @@ -66,9 +64,9 @@ func (h *DeadlineHandlers) ListAll(w http.ResponseWriter, r *http.Request) { // ListForCase handles GET /api/cases/{caseID}/deadlines func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -80,7 +78,7 @@ func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) { deadlines, err := h.deadlines.ListForCase(tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list deadlines") + internalError(w, "failed to list deadlines for case", err) return } @@ -89,9 +87,9 @@ func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) { // Create handles POST /api/cases/{caseID}/deadlines func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -112,10 +110,14 @@ func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "title and due_date are required") return } + if msg := validateStringLength("title", input.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } deadline, err := h.deadlines.Create(tenantID, input) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to create deadline") + internalError(w, "failed to create deadline", err) return } @@ -124,9 +126,9 @@ func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) { // Update handles PUT /api/deadlines/{deadlineID} func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -144,7 +146,7 @@ func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) { deadline, err := h.deadlines.Update(tenantID, deadlineID, input) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to update deadline") + internalError(w, "failed to update deadline", err) return } if deadline == nil { @@ -157,9 +159,9 @@ func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) { // Complete handles PATCH /api/deadlines/{deadlineID}/complete func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -171,7 +173,7 @@ func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) { deadline, err := h.deadlines.Complete(tenantID, deadlineID) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to complete deadline") + internalError(w, "failed to complete deadline", err) return } if deadline == nil { @@ -184,9 +186,9 @@ func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) { // Delete handles DELETE /api/deadlines/{deadlineID} func (h *DeadlineHandlers) Delete(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -196,9 +198,8 @@ func (h *DeadlineHandlers) Delete(w http.ResponseWriter, r *http.Request) { return } - err = h.deadlines.Delete(tenantID, deadlineID) - if err != nil { - writeError(w, http.StatusNotFound, err.Error()) + if err := h.deadlines.Delete(tenantID, deadlineID); err != nil { + writeError(w, http.StatusNotFound, "deadline not found") return } diff --git a/backend/internal/handlers/documents.go b/backend/internal/handlers/documents.go index c15c0cb..3a9098c 100644 --- a/backend/internal/handlers/documents.go +++ b/backend/internal/handlers/documents.go @@ -36,7 +36,7 @@ func (h *DocumentHandler) ListByCase(w http.ResponseWriter, r *http.Request) { docs, err := h.svc.ListByCase(r.Context(), tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to list documents", err) return } @@ -98,7 +98,7 @@ func (h *DocumentHandler) Upload(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusNotFound, "case not found") return } - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to upload document", err) return } @@ -121,16 +121,16 @@ func (h *DocumentHandler) Download(w http.ResponseWriter, r *http.Request) { body, contentType, title, err := h.svc.Download(r.Context(), tenantID, docID) if err != nil { if err.Error() == "document not found" || err.Error() == "document has no file" { - writeError(w, http.StatusNotFound, err.Error()) + writeError(w, http.StatusNotFound, "document not found") return } - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to download document", err) return } defer body.Close() w.Header().Set("Content-Type", contentType) - w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, title)) + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, sanitizeFilename(title))) io.Copy(w, body) } @@ -149,7 +149,7 @@ func (h *DocumentHandler) GetMeta(w http.ResponseWriter, r *http.Request) { doc, err := h.svc.GetByID(r.Context(), tenantID, docID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to get document metadata", err) return } if doc == nil { diff --git a/backend/internal/handlers/helpers.go b/backend/internal/handlers/helpers.go index 785768e..88d0d79 100644 --- a/backend/internal/handlers/helpers.go +++ b/backend/internal/handlers/helpers.go @@ -2,12 +2,12 @@ package handlers import ( "encoding/json" + "log/slog" "net/http" + "strings" + "unicode/utf8" "github.com/google/uuid" - "github.com/jmoiron/sqlx" - - "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" ) func writeJSON(w http.ResponseWriter, status int, v any) { @@ -20,62 +20,9 @@ func writeError(w http.ResponseWriter, status int, msg string) { writeJSON(w, status, map[string]string{"error": msg}) } -// resolveTenant gets the tenant ID for the authenticated user. -// Checks X-Tenant-ID header first, then falls back to user's first tenant. -func resolveTenant(r *http.Request, db *sqlx.DB) (uuid.UUID, error) { - userID, ok := auth.UserFromContext(r.Context()) - if !ok { - return uuid.Nil, errUnauthorized - } - - // Check header first - if headerVal := r.Header.Get("X-Tenant-ID"); headerVal != "" { - tenantID, err := uuid.Parse(headerVal) - if err != nil { - return uuid.Nil, errInvalidTenant - } - // Verify user has access to this tenant - var count int - err = db.Get(&count, - `SELECT COUNT(*) FROM user_tenants WHERE user_id = $1 AND tenant_id = $2`, - userID, tenantID) - if err != nil || count == 0 { - return uuid.Nil, errTenantAccess - } - return tenantID, nil - } - - // Fall back to user's first tenant - var tenantID uuid.UUID - err := db.Get(&tenantID, - `SELECT tenant_id FROM user_tenants WHERE user_id = $1 ORDER BY created_at LIMIT 1`, - userID) - if err != nil { - return uuid.Nil, errNoTenant - } - return tenantID, nil -} - -type apiError struct { - msg string - status int -} - -func (e *apiError) Error() string { return e.msg } - -var ( - errUnauthorized = &apiError{msg: "unauthorized", status: http.StatusUnauthorized} - errInvalidTenant = &apiError{msg: "invalid tenant ID", status: http.StatusBadRequest} - errTenantAccess = &apiError{msg: "no access to tenant", status: http.StatusForbidden} - errNoTenant = &apiError{msg: "no tenant found for user", status: http.StatusBadRequest} -) - -// handleTenantError writes the appropriate error response for tenant resolution errors -func handleTenantError(w http.ResponseWriter, err error) { - if ae, ok := err.(*apiError); ok { - writeError(w, ae.status, ae.msg) - return - } +// internalError logs the real error and returns a generic message to the client. +func internalError(w http.ResponseWriter, msg string, err error) { + slog.Error(msg, "error", err) writeError(w, http.StatusInternalServerError, "internal error") } @@ -88,3 +35,74 @@ func parsePathUUID(r *http.Request, key string) (uuid.UUID, error) { func parseUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) } + +// --- Input validation helpers --- + +const ( + maxTitleLen = 500 + maxDescriptionLen = 10000 + maxCaseNumberLen = 100 + maxSearchLen = 200 + maxPaginationLimit = 100 +) + +// validateStringLength checks if a string exceeds the given max length. +func validateStringLength(field, value string, maxLen int) string { + if utf8.RuneCountInString(value) > maxLen { + return field + " exceeds maximum length" + } + return "" +} + +// clampPagination enforces sane pagination defaults and limits. +func clampPagination(limit, offset int) (int, int) { + if limit <= 0 { + limit = 20 + } + if limit > maxPaginationLimit { + limit = maxPaginationLimit + } + if offset < 0 { + offset = 0 + } + return limit, offset +} + +// sanitizeFilename removes characters unsafe for Content-Disposition headers. +func sanitizeFilename(name string) string { + // Remove control characters, quotes, and backslashes + var b strings.Builder + for _, r := range name { + if r < 32 || r == '"' || r == '\\' || r == '/' { + b.WriteRune('_') + } else { + b.WriteRune(r) + } + } + return b.String() +} + +// maskSettingsPassword masks the CalDAV password in tenant settings JSON before returning to clients. +func maskSettingsPassword(settings json.RawMessage) json.RawMessage { + if len(settings) == 0 { + return settings + } + var m map[string]json.RawMessage + if err := json.Unmarshal(settings, &m); err != nil { + return settings + } + caldavRaw, ok := m["caldav"] + if !ok { + return settings + } + var caldav map[string]json.RawMessage + if err := json.Unmarshal(caldavRaw, &caldav); err != nil { + return settings + } + if _, ok := caldav["password"]; ok { + caldav["password"], _ = json.Marshal("********") + } + m["caldav"], _ = json.Marshal(caldav) + result, _ := json.Marshal(m) + return result +} diff --git a/backend/internal/handlers/notes.go b/backend/internal/handlers/notes.go index 68c9173..e935007 100644 --- a/backend/internal/handlers/notes.go +++ b/backend/internal/handlers/notes.go @@ -60,6 +60,10 @@ func (h *NoteHandler) Create(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "content is required") return } + if msg := validateStringLength("content", input.Content, maxDescriptionLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } var createdBy *uuid.UUID if userID != uuid.Nil { @@ -100,6 +104,10 @@ func (h *NoteHandler) Update(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "content is required") return } + if msg := validateStringLength("content", req.Content, maxDescriptionLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } note, err := h.svc.Update(r.Context(), tenantID, noteID, req.Content) if err != nil { diff --git a/backend/internal/handlers/parties.go b/backend/internal/handlers/parties.go index 178e06b..0b220cd 100644 --- a/backend/internal/handlers/parties.go +++ b/backend/internal/handlers/parties.go @@ -34,7 +34,7 @@ func (h *PartyHandler) List(w http.ResponseWriter, r *http.Request) { parties, err := h.svc.ListByCase(r.Context(), tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to list parties", err) return } @@ -67,13 +67,18 @@ func (h *PartyHandler) Create(w http.ResponseWriter, r *http.Request) { return } + if msg := validateStringLength("name", input.Name, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } + party, err := h.svc.Create(r.Context(), tenantID, caseID, userID, input) if err != nil { if err == sql.ErrNoRows { writeError(w, http.StatusNotFound, "case not found") return } - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to create party", err) return } @@ -101,7 +106,7 @@ func (h *PartyHandler) Update(w http.ResponseWriter, r *http.Request) { updated, err := h.svc.Update(r.Context(), tenantID, partyID, input) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to update party", err) return } if updated == nil { diff --git a/backend/internal/handlers/tenant_handler.go b/backend/internal/handlers/tenant_handler.go index 1db14f7..52bb29f 100644 --- a/backend/internal/handlers/tenant_handler.go +++ b/backend/internal/handlers/tenant_handler.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "log/slog" "net/http" "github.com/google/uuid" @@ -41,7 +42,8 @@ func (h *TenantHandler) CreateTenant(w http.ResponseWriter, r *http.Request) { tenant, err := h.svc.Create(r.Context(), userID, req.Name, req.Slug) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to create tenant", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } @@ -58,10 +60,16 @@ func (h *TenantHandler) ListTenants(w http.ResponseWriter, r *http.Request) { tenants, err := h.svc.ListForUser(r.Context(), userID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to list tenants", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } + // Mask CalDAV passwords in tenant settings + for i := range tenants { + tenants[i].Settings = maskSettingsPassword(tenants[i].Settings) + } + jsonResponse(w, tenants, http.StatusOK) } @@ -82,7 +90,8 @@ func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) { // Verify user has access to this tenant role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role == "" { @@ -92,7 +101,8 @@ func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) { tenant, err := h.svc.GetByID(r.Context(), tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get tenant", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if tenant == nil { @@ -100,6 +110,9 @@ func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) { return } + // Mask CalDAV password before returning + tenant.Settings = maskSettingsPassword(tenant.Settings) + jsonResponse(w, tenant, http.StatusOK) } @@ -120,7 +133,8 @@ func (h *TenantHandler) InviteUser(w http.ResponseWriter, r *http.Request) { // Only owners and admins can invite role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role != "owner" && role != "admin" { @@ -150,7 +164,8 @@ func (h *TenantHandler) InviteUser(w http.ResponseWriter, r *http.Request) { ut, err := h.svc.InviteByEmail(r.Context(), tenantID, req.Email, req.Role) if err != nil { - jsonError(w, err.Error(), http.StatusBadRequest) + // These are user-facing validation errors (user not found, already member) + jsonError(w, "failed to invite user", http.StatusBadRequest) return } @@ -180,7 +195,8 @@ func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { // Only owners and admins can remove members (or user removing themselves) role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role != "owner" && role != "admin" && userID != memberID { @@ -189,7 +205,8 @@ func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { } if err := h.svc.RemoveMember(r.Context(), tenantID, memberID); err != nil { - jsonError(w, err.Error(), http.StatusBadRequest) + // These are user-facing validation errors (not a member, last owner, etc.) + jsonError(w, "failed to remove member", http.StatusBadRequest) return } @@ -213,7 +230,8 @@ func (h *TenantHandler) UpdateSettings(w http.ResponseWriter, r *http.Request) { // Only owners and admins can update settings role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role != "owner" && role != "admin" { @@ -229,10 +247,14 @@ func (h *TenantHandler) UpdateSettings(w http.ResponseWriter, r *http.Request) { tenant, err := h.svc.UpdateSettings(r.Context(), tenantID, settings) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to update settings", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } + // Mask CalDAV password before returning + tenant.Settings = maskSettingsPassword(tenant.Settings) + jsonResponse(w, tenant, http.StatusOK) } @@ -253,7 +275,8 @@ func (h *TenantHandler) ListMembers(w http.ResponseWriter, r *http.Request) { // Verify user has access role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role == "" { @@ -263,7 +286,8 @@ func (h *TenantHandler) ListMembers(w http.ResponseWriter, r *http.Request) { members, err := h.svc.ListMembers(r.Context(), tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to list members", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } diff --git a/backend/internal/middleware/security.go b/backend/internal/middleware/security.go new file mode 100644 index 0000000..eccc6ea --- /dev/null +++ b/backend/internal/middleware/security.go @@ -0,0 +1,49 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// SecurityHeaders adds standard security headers to all responses. +func SecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-XSS-Protection", "1; mode=block") + w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + next.ServeHTTP(w, r) + }) +} + +// CORS returns middleware that restricts cross-origin requests to the given origin. +// If allowedOrigin is empty, CORS headers are not set (same-origin only). +func CORS(allowedOrigin string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + if allowedOrigin != "" && origin != "" && matchOrigin(origin, allowedOrigin) { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Tenant-ID") + w.Header().Set("Access-Control-Max-Age", "86400") + w.Header().Set("Vary", "Origin") + } + + // Handle preflight + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// matchOrigin checks if the request origin matches the allowed origin. +func matchOrigin(origin, allowed string) bool { + return strings.EqualFold(strings.TrimRight(origin, "/"), strings.TrimRight(allowed, "/")) +} diff --git a/backend/internal/router/router.go b/backend/internal/router/router.go index 3c6f8a5..02b9dd9 100644 --- a/backend/internal/router/router.go +++ b/backend/internal/router/router.go @@ -34,7 +34,7 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config, calDAVSvc *se var aiH *handlers.AIHandler if cfg.AnthropicAPIKey != "" { aiSvc := services.NewAIService(cfg.AnthropicAPIKey, db) - aiH = handlers.NewAIHandler(aiSvc, db) + aiH = handlers.NewAIHandler(aiSvc) } // Middleware @@ -48,7 +48,7 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config, calDAVSvc *se caseH := handlers.NewCaseHandler(caseSvc) partyH := handlers.NewPartyHandler(partySvc) apptH := handlers.NewAppointmentHandler(appointmentSvc) - deadlineH := handlers.NewDeadlineHandlers(deadlineSvc, db) + deadlineH := handlers.NewDeadlineHandlers(deadlineSvc) ruleH := handlers.NewDeadlineRuleHandlers(deadlineRuleSvc) calcH := handlers.NewCalculateHandlers(calculator, deadlineRuleSvc) dashboardH := handlers.NewDashboardHandler(dashboardSvc) @@ -149,14 +149,20 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config, calDAVSvc *se mux.Handle("/api/", authMW.RequireAuth(api)) - return requestLogger(mux) + // Apply security middleware stack: CORS -> Security Headers -> Request Logger -> Routes + var handler http.Handler = mux + handler = requestLogger(handler) + handler = middleware.SecurityHeaders(handler) + handler = middleware.CORS(cfg.FrontendOrigin)(handler) + + return handler } func handleHealth(db *sqlx.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := db.Ping(); err != nil { w.WriteHeader(http.StatusServiceUnavailable) - json.NewEncoder(w).Encode(map[string]string{"status": "error", "error": err.Error()}) + json.NewEncoder(w).Encode(map[string]string{"status": "error"}) return } w.Header().Set("Content-Type", "application/json") @@ -194,4 +200,3 @@ func requestLogger(next http.Handler) http.Handler { ) }) } - diff --git a/backend/internal/services/tenant_service.go b/backend/internal/services/tenant_service.go index 7ed5614..0c0f52e 100644 --- a/backend/internal/services/tenant_service.go +++ b/backend/internal/services/tenant_service.go @@ -101,6 +101,19 @@ func (s *TenantService) GetUserRole(ctx context.Context, userID, tenantID uuid.U return role, nil } +// VerifyAccess checks if a user has access to a given tenant. +func (s *TenantService) VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error) { + var exists bool + err := s.db.GetContext(ctx, &exists, + `SELECT EXISTS(SELECT 1 FROM user_tenants WHERE user_id = $1 AND tenant_id = $2)`, + userID, tenantID, + ) + if err != nil { + return false, fmt.Errorf("verify tenant access: %w", err) + } + return exists, nil +} + // FirstTenantForUser returns the user's first tenant (by name), used as default. func (s *TenantService) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) { var tenantID uuid.UUID