diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go index 4f31eb6..91533d6 100644 --- a/backend/internal/auth/middleware.go +++ b/backend/internal/auth/middleware.go @@ -24,28 +24,19 @@ func (m *Middleware) RequireAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token := extractBearerToken(r) if token == "" { - http.Error(w, "missing authorization token", http.StatusUnauthorized) + http.Error(w, `{"error":"missing authorization token"}`, http.StatusUnauthorized) return } userID, err := m.verifyJWT(token) if err != nil { - http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized) + http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized) return } ctx := ContextWithUserID(r.Context(), userID) - - // Resolve tenant from user_tenants - var tenantID uuid.UUID - err = m.db.GetContext(r.Context(), &tenantID, - "SELECT tenant_id FROM user_tenants WHERE user_id = $1 LIMIT 1", userID) - if err != nil { - http.Error(w, "no tenant found for user", http.StatusForbidden) - return - } - ctx = ContextWithTenantID(ctx, tenantID) - + // Tenant resolution is handled by TenantResolver middleware for scoped routes. + // Tenant management routes handle their own access control. next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/backend/internal/auth/tenant_resolver.go b/backend/internal/auth/tenant_resolver.go index 6358d4d..24688d0 100644 --- a/backend/internal/auth/tenant_resolver.go +++ b/backend/internal/auth/tenant_resolver.go @@ -2,20 +2,21 @@ package auth import ( "context" - "fmt" + "log/slog" "net/http" "github.com/google/uuid" ) -// TenantLookup resolves the default tenant for a user. +// TenantLookup resolves and verifies tenant access for a user. // Defined as an interface to avoid circular dependency with services. type TenantLookup interface { FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) + VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error) } // TenantResolver is middleware that resolves the tenant from X-Tenant-ID header -// or defaults to the user's first tenant. +// or defaults to the user's first tenant. Always verifies user has access. type TenantResolver struct { lookup TenantLookup } @@ -28,7 +29,7 @@ func (tr *TenantResolver) Resolve(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { userID, ok := UserFromContext(r.Context()) if !ok { - http.Error(w, "unauthorized", http.StatusUnauthorized) + http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized) return } @@ -37,19 +38,33 @@ func (tr *TenantResolver) Resolve(next http.Handler) http.Handler { if header := r.Header.Get("X-Tenant-ID"); header != "" { parsed, err := uuid.Parse(header) if err != nil { - http.Error(w, fmt.Sprintf("invalid X-Tenant-ID: %v", err), http.StatusBadRequest) + http.Error(w, `{"error":"invalid X-Tenant-ID"}`, http.StatusBadRequest) return } + + // Verify user has access to this tenant + hasAccess, err := tr.lookup.VerifyAccess(r.Context(), userID, parsed) + if err != nil { + slog.Error("tenant access check failed", "error", err, "user_id", userID, "tenant_id", parsed) + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) + return + } + if !hasAccess { + http.Error(w, `{"error":"no access to tenant"}`, http.StatusForbidden) + return + } + tenantID = parsed } else { // Default to user's first tenant first, err := tr.lookup.FirstTenantForUser(r.Context(), userID) if err != nil { - http.Error(w, fmt.Sprintf("resolving tenant: %v", err), http.StatusInternalServerError) + slog.Error("failed to resolve default tenant", "error", err, "user_id", userID) + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) return } if first == nil { - http.Error(w, "no tenant found for user", http.StatusBadRequest) + http.Error(w, `{"error":"no tenant found for user"}`, http.StatusBadRequest) return } tenantID = *first diff --git a/backend/internal/auth/tenant_resolver_test.go b/backend/internal/auth/tenant_resolver_test.go index dfb8e2d..a542bdf 100644 --- a/backend/internal/auth/tenant_resolver_test.go +++ b/backend/internal/auth/tenant_resolver_test.go @@ -10,17 +10,23 @@ import ( ) type mockTenantLookup struct { - tenantID *uuid.UUID - err error + tenantID *uuid.UUID + err error + hasAccess bool + accessErr error } func (m *mockTenantLookup) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) { return m.tenantID, m.err } +func (m *mockTenantLookup) VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error) { + return m.hasAccess, m.accessErr +} + func TestTenantResolver_FromHeader(t *testing.T) { tenantID := uuid.New() - tr := NewTenantResolver(&mockTenantLookup{}) + tr := NewTenantResolver(&mockTenantLookup{hasAccess: true}) var gotTenantID uuid.UUID next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -47,6 +53,26 @@ func TestTenantResolver_FromHeader(t *testing.T) { } } +func TestTenantResolver_FromHeader_NoAccess(t *testing.T) { + tenantID := uuid.New() + tr := NewTenantResolver(&mockTenantLookup{hasAccess: false}) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("next should not be called") + }) + + r := httptest.NewRequest("GET", "/api/cases", nil) + r.Header.Set("X-Tenant-ID", tenantID.String()) + r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + tr.Resolve(next).ServeHTTP(w, r) + + if w.Code != http.StatusForbidden { + t.Errorf("expected 403, got %d", w.Code) + } +} + func TestTenantResolver_DefaultsToFirst(t *testing.T) { tenantID := uuid.New() tr := NewTenantResolver(&mockTenantLookup{tenantID: &tenantID}) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3b78b3a..620a7f4 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -13,6 +13,7 @@ type Config struct { SupabaseServiceKey string SupabaseJWTSecret string AnthropicAPIKey string + FrontendOrigin string } func Load() (*Config, error) { @@ -24,6 +25,7 @@ func Load() (*Config, error) { SupabaseServiceKey: os.Getenv("SUPABASE_SERVICE_KEY"), SupabaseJWTSecret: os.Getenv("SUPABASE_JWT_SECRET"), AnthropicAPIKey: os.Getenv("ANTHROPIC_API_KEY"), + FrontendOrigin: getEnv("FRONTEND_ORIGIN", "https://kanzlai.msbls.de"), } if cfg.DatabaseURL == "" { diff --git a/backend/internal/handlers/ai.go b/backend/internal/handlers/ai.go index 806c9b3..1a03c37 100644 --- a/backend/internal/handlers/ai.go +++ b/backend/internal/handlers/ai.go @@ -5,18 +5,16 @@ import ( "io" "net/http" - "github.com/jmoiron/sqlx" - + "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" "mgit.msbls.de/m/KanzlAI-mGMT/internal/services" ) type AIHandler struct { ai *services.AIService - db *sqlx.DB } -func NewAIHandler(ai *services.AIService, db *sqlx.DB) *AIHandler { - return &AIHandler{ai: ai, db: db} +func NewAIHandler(ai *services.AIService) *AIHandler { + return &AIHandler{ai: ai} } // ExtractDeadlines handles POST /api/ai/extract-deadlines @@ -61,10 +59,14 @@ func (h *AIHandler) ExtractDeadlines(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "provide either a PDF file or text") return } + if len(text) > maxDescriptionLen { + writeError(w, http.StatusBadRequest, "text exceeds maximum length") + return + } deadlines, err := h.ai.ExtractDeadlines(r.Context(), pdfData, text) if err != nil { - writeError(w, http.StatusInternalServerError, "AI extraction failed: "+err.Error()) + internalError(w, "AI deadline extraction failed", err) return } @@ -77,9 +79,9 @@ func (h *AIHandler) ExtractDeadlines(w http.ResponseWriter, r *http.Request) { // SummarizeCase handles POST /api/ai/summarize-case // Accepts JSON {"case_id": "uuid"}. func (h *AIHandler) SummarizeCase(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -104,7 +106,7 @@ func (h *AIHandler) SummarizeCase(w http.ResponseWriter, r *http.Request) { summary, err := h.ai.SummarizeCase(r.Context(), tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, "AI summarization failed: "+err.Error()) + internalError(w, "AI case summarization failed", err) return } diff --git a/backend/internal/handlers/ai_handler_test.go b/backend/internal/handlers/ai_handler_test.go index 7fb3f67..9622413 100644 --- a/backend/internal/handlers/ai_handler_test.go +++ b/backend/internal/handlers/ai_handler_test.go @@ -42,7 +42,7 @@ func TestAIExtractDeadlines_InvalidJSON(t *testing.T) { } } -func TestAISummarizeCase_MissingCaseID(t *testing.T) { +func TestAISummarizeCase_MissingTenant(t *testing.T) { h := &AIHandler{} body := `{"case_id":""}` @@ -52,9 +52,9 @@ func TestAISummarizeCase_MissingCaseID(t *testing.T) { h.SummarizeCase(w, r) - // Without auth context, the resolveTenant will fail first - if w.Code != http.StatusUnauthorized { - t.Errorf("expected 401, got %d", w.Code) + // Without tenant context, TenantFromContext returns !ok → 403 + if w.Code != http.StatusForbidden { + t.Errorf("expected 403, got %d", w.Code) } } @@ -67,8 +67,8 @@ func TestAISummarizeCase_InvalidJSON(t *testing.T) { h.SummarizeCase(w, r) - // Without auth context, the resolveTenant will fail first - if w.Code != http.StatusUnauthorized { - t.Errorf("expected 401, got %d", w.Code) + // Without tenant context, TenantFromContext returns !ok → 403 + if w.Code != http.StatusForbidden { + t.Errorf("expected 403, got %d", w.Code) } } diff --git a/backend/internal/handlers/appointments.go b/backend/internal/handlers/appointments.go index d8acaab..188fc4e 100644 --- a/backend/internal/handlers/appointments.go +++ b/backend/internal/handlers/appointments.go @@ -121,6 +121,10 @@ func (h *AppointmentHandler) Create(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "title is required") return } + if msg := validateStringLength("title", req.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } if req.StartAt.IsZero() { writeError(w, http.StatusBadRequest, "start_at is required") return @@ -188,6 +192,10 @@ func (h *AppointmentHandler) Update(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "title is required") return } + if msg := validateStringLength("title", req.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } if req.StartAt.IsZero() { writeError(w, http.StatusBadRequest, "start_at is required") return diff --git a/backend/internal/handlers/caldav.go b/backend/internal/handlers/caldav.go index cb38e72..8f93e4a 100644 --- a/backend/internal/handlers/caldav.go +++ b/backend/internal/handlers/caldav.go @@ -27,7 +27,7 @@ func (h *CalDAVHandler) TriggerSync(w http.ResponseWriter, r *http.Request) { cfg, err := h.svc.LoadTenantConfig(tenantID) if err != nil { - writeError(w, http.StatusBadRequest, err.Error()) + writeError(w, http.StatusBadRequest, "CalDAV not configured for this tenant") return } diff --git a/backend/internal/handlers/cases.go b/backend/internal/handlers/cases.go index a10d9d5..8cd9d33 100644 --- a/backend/internal/handlers/cases.go +++ b/backend/internal/handlers/cases.go @@ -28,18 +28,25 @@ func (h *CaseHandler) List(w http.ResponseWriter, r *http.Request) { limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) offset, _ := strconv.Atoi(r.URL.Query().Get("offset")) + limit, offset = clampPagination(limit, offset) + + search := r.URL.Query().Get("search") + if msg := validateStringLength("search", search, maxSearchLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } filter := services.CaseFilter{ Status: r.URL.Query().Get("status"), Type: r.URL.Query().Get("type"), - Search: r.URL.Query().Get("search"), + Search: search, Limit: limit, Offset: offset, } cases, total, err := h.svc.List(r.Context(), tenantID, filter) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to list cases", err) return } @@ -66,10 +73,18 @@ func (h *CaseHandler) Create(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "case_number and title are required") return } + if msg := validateStringLength("case_number", input.CaseNumber, maxCaseNumberLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } + if msg := validateStringLength("title", input.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } c, err := h.svc.Create(r.Context(), tenantID, userID, input) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to create case", err) return } @@ -91,7 +106,7 @@ func (h *CaseHandler) Get(w http.ResponseWriter, r *http.Request) { detail, err := h.svc.GetByID(r.Context(), tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to get case", err) return } if detail == nil { @@ -121,10 +136,22 @@ func (h *CaseHandler) Update(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "invalid JSON body") return } + if input.Title != nil { + if msg := validateStringLength("title", *input.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } + } + if input.CaseNumber != nil { + if msg := validateStringLength("case_number", *input.CaseNumber, maxCaseNumberLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } + } updated, err := h.svc.Update(r.Context(), tenantID, caseID, userID, input) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to update case", err) return } if updated == nil { diff --git a/backend/internal/handlers/dashboard.go b/backend/internal/handlers/dashboard.go index 09b5dbe..806800b 100644 --- a/backend/internal/handlers/dashboard.go +++ b/backend/internal/handlers/dashboard.go @@ -24,7 +24,7 @@ func (h *DashboardHandler) Get(w http.ResponseWriter, r *http.Request) { data, err := h.svc.Get(r.Context(), tenantID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to load dashboard", err) return } diff --git a/backend/internal/handlers/deadlines.go b/backend/internal/handlers/deadlines.go index 5b2b265..f8bdac2 100644 --- a/backend/internal/handlers/deadlines.go +++ b/backend/internal/handlers/deadlines.go @@ -4,27 +4,25 @@ import ( "encoding/json" "net/http" - "github.com/jmoiron/sqlx" - + "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" "mgit.msbls.de/m/KanzlAI-mGMT/internal/services" ) // DeadlineHandlers holds handlers for deadline CRUD endpoints type DeadlineHandlers struct { deadlines *services.DeadlineService - db *sqlx.DB } // NewDeadlineHandlers creates deadline handlers -func NewDeadlineHandlers(ds *services.DeadlineService, db *sqlx.DB) *DeadlineHandlers { - return &DeadlineHandlers{deadlines: ds, db: db} +func NewDeadlineHandlers(ds *services.DeadlineService) *DeadlineHandlers { + return &DeadlineHandlers{deadlines: ds} } // Get handles GET /api/deadlines/{deadlineID} func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -36,7 +34,7 @@ func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) { deadline, err := h.deadlines.GetByID(tenantID, deadlineID) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to fetch deadline") + internalError(w, "failed to fetch deadline", err) return } if deadline == nil { @@ -49,15 +47,15 @@ func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) { // ListAll handles GET /api/deadlines func (h *DeadlineHandlers) ListAll(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } deadlines, err := h.deadlines.ListAll(tenantID) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list deadlines") + internalError(w, "failed to list deadlines", err) return } @@ -66,9 +64,9 @@ func (h *DeadlineHandlers) ListAll(w http.ResponseWriter, r *http.Request) { // ListForCase handles GET /api/cases/{caseID}/deadlines func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -80,7 +78,7 @@ func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) { deadlines, err := h.deadlines.ListForCase(tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list deadlines") + internalError(w, "failed to list deadlines for case", err) return } @@ -89,9 +87,9 @@ func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) { // Create handles POST /api/cases/{caseID}/deadlines func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -112,10 +110,14 @@ func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "title and due_date are required") return } + if msg := validateStringLength("title", input.Title, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } deadline, err := h.deadlines.Create(tenantID, input) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to create deadline") + internalError(w, "failed to create deadline", err) return } @@ -124,9 +126,9 @@ func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) { // Update handles PUT /api/deadlines/{deadlineID} func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -144,7 +146,7 @@ func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) { deadline, err := h.deadlines.Update(tenantID, deadlineID, input) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to update deadline") + internalError(w, "failed to update deadline", err) return } if deadline == nil { @@ -157,9 +159,9 @@ func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) { // Complete handles PATCH /api/deadlines/{deadlineID}/complete func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -171,7 +173,7 @@ func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) { deadline, err := h.deadlines.Complete(tenantID, deadlineID) if err != nil { - writeError(w, http.StatusInternalServerError, "failed to complete deadline") + internalError(w, "failed to complete deadline", err) return } if deadline == nil { @@ -184,9 +186,9 @@ func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) { // Delete handles DELETE /api/deadlines/{deadlineID} func (h *DeadlineHandlers) Delete(w http.ResponseWriter, r *http.Request) { - tenantID, err := resolveTenant(r, h.db) - if err != nil { - handleTenantError(w, err) + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") return } @@ -196,9 +198,8 @@ func (h *DeadlineHandlers) Delete(w http.ResponseWriter, r *http.Request) { return } - err = h.deadlines.Delete(tenantID, deadlineID) - if err != nil { - writeError(w, http.StatusNotFound, err.Error()) + if err := h.deadlines.Delete(tenantID, deadlineID); err != nil { + writeError(w, http.StatusNotFound, "deadline not found") return } diff --git a/backend/internal/handlers/documents.go b/backend/internal/handlers/documents.go index c15c0cb..3a9098c 100644 --- a/backend/internal/handlers/documents.go +++ b/backend/internal/handlers/documents.go @@ -36,7 +36,7 @@ func (h *DocumentHandler) ListByCase(w http.ResponseWriter, r *http.Request) { docs, err := h.svc.ListByCase(r.Context(), tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to list documents", err) return } @@ -98,7 +98,7 @@ func (h *DocumentHandler) Upload(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusNotFound, "case not found") return } - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to upload document", err) return } @@ -121,16 +121,16 @@ func (h *DocumentHandler) Download(w http.ResponseWriter, r *http.Request) { body, contentType, title, err := h.svc.Download(r.Context(), tenantID, docID) if err != nil { if err.Error() == "document not found" || err.Error() == "document has no file" { - writeError(w, http.StatusNotFound, err.Error()) + writeError(w, http.StatusNotFound, "document not found") return } - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to download document", err) return } defer body.Close() w.Header().Set("Content-Type", contentType) - w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, title)) + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, sanitizeFilename(title))) io.Copy(w, body) } @@ -149,7 +149,7 @@ func (h *DocumentHandler) GetMeta(w http.ResponseWriter, r *http.Request) { doc, err := h.svc.GetByID(r.Context(), tenantID, docID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to get document metadata", err) return } if doc == nil { diff --git a/backend/internal/handlers/helpers.go b/backend/internal/handlers/helpers.go index 785768e..88d0d79 100644 --- a/backend/internal/handlers/helpers.go +++ b/backend/internal/handlers/helpers.go @@ -2,12 +2,12 @@ package handlers import ( "encoding/json" + "log/slog" "net/http" + "strings" + "unicode/utf8" "github.com/google/uuid" - "github.com/jmoiron/sqlx" - - "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" ) func writeJSON(w http.ResponseWriter, status int, v any) { @@ -20,62 +20,9 @@ func writeError(w http.ResponseWriter, status int, msg string) { writeJSON(w, status, map[string]string{"error": msg}) } -// resolveTenant gets the tenant ID for the authenticated user. -// Checks X-Tenant-ID header first, then falls back to user's first tenant. -func resolveTenant(r *http.Request, db *sqlx.DB) (uuid.UUID, error) { - userID, ok := auth.UserFromContext(r.Context()) - if !ok { - return uuid.Nil, errUnauthorized - } - - // Check header first - if headerVal := r.Header.Get("X-Tenant-ID"); headerVal != "" { - tenantID, err := uuid.Parse(headerVal) - if err != nil { - return uuid.Nil, errInvalidTenant - } - // Verify user has access to this tenant - var count int - err = db.Get(&count, - `SELECT COUNT(*) FROM user_tenants WHERE user_id = $1 AND tenant_id = $2`, - userID, tenantID) - if err != nil || count == 0 { - return uuid.Nil, errTenantAccess - } - return tenantID, nil - } - - // Fall back to user's first tenant - var tenantID uuid.UUID - err := db.Get(&tenantID, - `SELECT tenant_id FROM user_tenants WHERE user_id = $1 ORDER BY created_at LIMIT 1`, - userID) - if err != nil { - return uuid.Nil, errNoTenant - } - return tenantID, nil -} - -type apiError struct { - msg string - status int -} - -func (e *apiError) Error() string { return e.msg } - -var ( - errUnauthorized = &apiError{msg: "unauthorized", status: http.StatusUnauthorized} - errInvalidTenant = &apiError{msg: "invalid tenant ID", status: http.StatusBadRequest} - errTenantAccess = &apiError{msg: "no access to tenant", status: http.StatusForbidden} - errNoTenant = &apiError{msg: "no tenant found for user", status: http.StatusBadRequest} -) - -// handleTenantError writes the appropriate error response for tenant resolution errors -func handleTenantError(w http.ResponseWriter, err error) { - if ae, ok := err.(*apiError); ok { - writeError(w, ae.status, ae.msg) - return - } +// internalError logs the real error and returns a generic message to the client. +func internalError(w http.ResponseWriter, msg string, err error) { + slog.Error(msg, "error", err) writeError(w, http.StatusInternalServerError, "internal error") } @@ -88,3 +35,74 @@ func parsePathUUID(r *http.Request, key string) (uuid.UUID, error) { func parseUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) } + +// --- Input validation helpers --- + +const ( + maxTitleLen = 500 + maxDescriptionLen = 10000 + maxCaseNumberLen = 100 + maxSearchLen = 200 + maxPaginationLimit = 100 +) + +// validateStringLength checks if a string exceeds the given max length. +func validateStringLength(field, value string, maxLen int) string { + if utf8.RuneCountInString(value) > maxLen { + return field + " exceeds maximum length" + } + return "" +} + +// clampPagination enforces sane pagination defaults and limits. +func clampPagination(limit, offset int) (int, int) { + if limit <= 0 { + limit = 20 + } + if limit > maxPaginationLimit { + limit = maxPaginationLimit + } + if offset < 0 { + offset = 0 + } + return limit, offset +} + +// sanitizeFilename removes characters unsafe for Content-Disposition headers. +func sanitizeFilename(name string) string { + // Remove control characters, quotes, and backslashes + var b strings.Builder + for _, r := range name { + if r < 32 || r == '"' || r == '\\' || r == '/' { + b.WriteRune('_') + } else { + b.WriteRune(r) + } + } + return b.String() +} + +// maskSettingsPassword masks the CalDAV password in tenant settings JSON before returning to clients. +func maskSettingsPassword(settings json.RawMessage) json.RawMessage { + if len(settings) == 0 { + return settings + } + var m map[string]json.RawMessage + if err := json.Unmarshal(settings, &m); err != nil { + return settings + } + caldavRaw, ok := m["caldav"] + if !ok { + return settings + } + var caldav map[string]json.RawMessage + if err := json.Unmarshal(caldavRaw, &caldav); err != nil { + return settings + } + if _, ok := caldav["password"]; ok { + caldav["password"], _ = json.Marshal("********") + } + m["caldav"], _ = json.Marshal(caldav) + result, _ := json.Marshal(m) + return result +} diff --git a/backend/internal/handlers/notes.go b/backend/internal/handlers/notes.go index 68c9173..e935007 100644 --- a/backend/internal/handlers/notes.go +++ b/backend/internal/handlers/notes.go @@ -60,6 +60,10 @@ func (h *NoteHandler) Create(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "content is required") return } + if msg := validateStringLength("content", input.Content, maxDescriptionLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } var createdBy *uuid.UUID if userID != uuid.Nil { @@ -100,6 +104,10 @@ func (h *NoteHandler) Update(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "content is required") return } + if msg := validateStringLength("content", req.Content, maxDescriptionLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } note, err := h.svc.Update(r.Context(), tenantID, noteID, req.Content) if err != nil { diff --git a/backend/internal/handlers/parties.go b/backend/internal/handlers/parties.go index 178e06b..0b220cd 100644 --- a/backend/internal/handlers/parties.go +++ b/backend/internal/handlers/parties.go @@ -34,7 +34,7 @@ func (h *PartyHandler) List(w http.ResponseWriter, r *http.Request) { parties, err := h.svc.ListByCase(r.Context(), tenantID, caseID) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to list parties", err) return } @@ -67,13 +67,18 @@ func (h *PartyHandler) Create(w http.ResponseWriter, r *http.Request) { return } + if msg := validateStringLength("name", input.Name, maxTitleLen); msg != "" { + writeError(w, http.StatusBadRequest, msg) + return + } + party, err := h.svc.Create(r.Context(), tenantID, caseID, userID, input) if err != nil { if err == sql.ErrNoRows { writeError(w, http.StatusNotFound, "case not found") return } - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to create party", err) return } @@ -101,7 +106,7 @@ func (h *PartyHandler) Update(w http.ResponseWriter, r *http.Request) { updated, err := h.svc.Update(r.Context(), tenantID, partyID, input) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + internalError(w, "failed to update party", err) return } if updated == nil { diff --git a/backend/internal/handlers/tenant_handler.go b/backend/internal/handlers/tenant_handler.go index 1db14f7..52bb29f 100644 --- a/backend/internal/handlers/tenant_handler.go +++ b/backend/internal/handlers/tenant_handler.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "log/slog" "net/http" "github.com/google/uuid" @@ -41,7 +42,8 @@ func (h *TenantHandler) CreateTenant(w http.ResponseWriter, r *http.Request) { tenant, err := h.svc.Create(r.Context(), userID, req.Name, req.Slug) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to create tenant", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } @@ -58,10 +60,16 @@ func (h *TenantHandler) ListTenants(w http.ResponseWriter, r *http.Request) { tenants, err := h.svc.ListForUser(r.Context(), userID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to list tenants", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } + // Mask CalDAV passwords in tenant settings + for i := range tenants { + tenants[i].Settings = maskSettingsPassword(tenants[i].Settings) + } + jsonResponse(w, tenants, http.StatusOK) } @@ -82,7 +90,8 @@ func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) { // Verify user has access to this tenant role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role == "" { @@ -92,7 +101,8 @@ func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) { tenant, err := h.svc.GetByID(r.Context(), tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get tenant", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if tenant == nil { @@ -100,6 +110,9 @@ func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) { return } + // Mask CalDAV password before returning + tenant.Settings = maskSettingsPassword(tenant.Settings) + jsonResponse(w, tenant, http.StatusOK) } @@ -120,7 +133,8 @@ func (h *TenantHandler) InviteUser(w http.ResponseWriter, r *http.Request) { // Only owners and admins can invite role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role != "owner" && role != "admin" { @@ -150,7 +164,8 @@ func (h *TenantHandler) InviteUser(w http.ResponseWriter, r *http.Request) { ut, err := h.svc.InviteByEmail(r.Context(), tenantID, req.Email, req.Role) if err != nil { - jsonError(w, err.Error(), http.StatusBadRequest) + // These are user-facing validation errors (user not found, already member) + jsonError(w, "failed to invite user", http.StatusBadRequest) return } @@ -180,7 +195,8 @@ func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { // Only owners and admins can remove members (or user removing themselves) role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role != "owner" && role != "admin" && userID != memberID { @@ -189,7 +205,8 @@ func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { } if err := h.svc.RemoveMember(r.Context(), tenantID, memberID); err != nil { - jsonError(w, err.Error(), http.StatusBadRequest) + // These are user-facing validation errors (not a member, last owner, etc.) + jsonError(w, "failed to remove member", http.StatusBadRequest) return } @@ -213,7 +230,8 @@ func (h *TenantHandler) UpdateSettings(w http.ResponseWriter, r *http.Request) { // Only owners and admins can update settings role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role != "owner" && role != "admin" { @@ -229,10 +247,14 @@ func (h *TenantHandler) UpdateSettings(w http.ResponseWriter, r *http.Request) { tenant, err := h.svc.UpdateSettings(r.Context(), tenantID, settings) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to update settings", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } + // Mask CalDAV password before returning + tenant.Settings = maskSettingsPassword(tenant.Settings) + jsonResponse(w, tenant, http.StatusOK) } @@ -253,7 +275,8 @@ func (h *TenantHandler) ListMembers(w http.ResponseWriter, r *http.Request) { // Verify user has access role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to get user role", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } if role == "" { @@ -263,7 +286,8 @@ func (h *TenantHandler) ListMembers(w http.ResponseWriter, r *http.Request) { members, err := h.svc.ListMembers(r.Context(), tenantID) if err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) + slog.Error("failed to list members", "error", err) + jsonError(w, "internal error", http.StatusInternalServerError) return } diff --git a/backend/internal/middleware/security.go b/backend/internal/middleware/security.go new file mode 100644 index 0000000..eccc6ea --- /dev/null +++ b/backend/internal/middleware/security.go @@ -0,0 +1,49 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// SecurityHeaders adds standard security headers to all responses. +func SecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-XSS-Protection", "1; mode=block") + w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + next.ServeHTTP(w, r) + }) +} + +// CORS returns middleware that restricts cross-origin requests to the given origin. +// If allowedOrigin is empty, CORS headers are not set (same-origin only). +func CORS(allowedOrigin string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + if allowedOrigin != "" && origin != "" && matchOrigin(origin, allowedOrigin) { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Tenant-ID") + w.Header().Set("Access-Control-Max-Age", "86400") + w.Header().Set("Vary", "Origin") + } + + // Handle preflight + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// matchOrigin checks if the request origin matches the allowed origin. +func matchOrigin(origin, allowed string) bool { + return strings.EqualFold(strings.TrimRight(origin, "/"), strings.TrimRight(allowed, "/")) +} diff --git a/backend/internal/router/router.go b/backend/internal/router/router.go index 3c6f8a5..02b9dd9 100644 --- a/backend/internal/router/router.go +++ b/backend/internal/router/router.go @@ -34,7 +34,7 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config, calDAVSvc *se var aiH *handlers.AIHandler if cfg.AnthropicAPIKey != "" { aiSvc := services.NewAIService(cfg.AnthropicAPIKey, db) - aiH = handlers.NewAIHandler(aiSvc, db) + aiH = handlers.NewAIHandler(aiSvc) } // Middleware @@ -48,7 +48,7 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config, calDAVSvc *se caseH := handlers.NewCaseHandler(caseSvc) partyH := handlers.NewPartyHandler(partySvc) apptH := handlers.NewAppointmentHandler(appointmentSvc) - deadlineH := handlers.NewDeadlineHandlers(deadlineSvc, db) + deadlineH := handlers.NewDeadlineHandlers(deadlineSvc) ruleH := handlers.NewDeadlineRuleHandlers(deadlineRuleSvc) calcH := handlers.NewCalculateHandlers(calculator, deadlineRuleSvc) dashboardH := handlers.NewDashboardHandler(dashboardSvc) @@ -149,14 +149,20 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config, calDAVSvc *se mux.Handle("/api/", authMW.RequireAuth(api)) - return requestLogger(mux) + // Apply security middleware stack: CORS -> Security Headers -> Request Logger -> Routes + var handler http.Handler = mux + handler = requestLogger(handler) + handler = middleware.SecurityHeaders(handler) + handler = middleware.CORS(cfg.FrontendOrigin)(handler) + + return handler } func handleHealth(db *sqlx.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := db.Ping(); err != nil { w.WriteHeader(http.StatusServiceUnavailable) - json.NewEncoder(w).Encode(map[string]string{"status": "error", "error": err.Error()}) + json.NewEncoder(w).Encode(map[string]string{"status": "error"}) return } w.Header().Set("Content-Type", "application/json") @@ -194,4 +200,3 @@ func requestLogger(next http.Handler) http.Handler { ) }) } - diff --git a/backend/internal/services/tenant_service.go b/backend/internal/services/tenant_service.go index 7ed5614..0c0f52e 100644 --- a/backend/internal/services/tenant_service.go +++ b/backend/internal/services/tenant_service.go @@ -101,6 +101,19 @@ func (s *TenantService) GetUserRole(ctx context.Context, userID, tenantID uuid.U return role, nil } +// VerifyAccess checks if a user has access to a given tenant. +func (s *TenantService) VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error) { + var exists bool + err := s.db.GetContext(ctx, &exists, + `SELECT EXISTS(SELECT 1 FROM user_tenants WHERE user_id = $1 AND tenant_id = $2)`, + userID, tenantID, + ) + if err != nil { + return false, fmt.Errorf("verify tenant access: %w", err) + } + return exists, nil +} + // FirstTenantForUser returns the user's first tenant (by name), used as default. func (s *TenantService) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) { var tenantID uuid.UUID