Compare commits

..

2 Commits

Author SHA1 Message Date
m
c324a2b5c7 fix: critical security hardening — tenant isolation, CORS, error masking, input validation 2026-03-30 11:02:52 +02:00
m
c15d5b72f2 fix: critical security hardening — tenant isolation, CORS, error leaking, input validation
1. Tenant isolation bypass (CRITICAL): TenantResolver now verifies user
   has access to X-Tenant-ID via user_tenants lookup before setting context.
   Added VerifyAccess method to TenantLookup interface and TenantService.

2. Consolidated tenant resolution: Removed duplicate resolveTenant() from
   helpers.go and tenant resolution from auth middleware. TenantResolver is
   now the single source of truth. Deadlines and AI handlers use
   auth.TenantFromContext() instead of direct DB queries.

3. CalDAV credential masking: tenant settings responses now mask CalDAV
   passwords with "********" via maskSettingsPassword helper. Applied to
   GetTenant, ListTenants, and UpdateSettings responses.

4. CORS + security headers: New middleware/security.go with CORS
   (restricted to FRONTEND_ORIGIN) and security headers (X-Frame-Options,
   X-Content-Type-Options, HSTS, Referrer-Policy, X-XSS-Protection).

5. Internal error leaking: All writeError(w, 500, err.Error()) replaced
   with internalError() that logs via slog and returns generic "internal
   error" to client. Same for jsonError in tenant handler.

6. Input validation: Max length on title (500), description (10000),
   case_number (100), search (200). Pagination clamped to max 100.
   Content-Disposition filename sanitized against header injection.

Regression test added for tenant access denial (403 on unauthorized
X-Tenant-ID). All existing tests pass, go vet clean.
2026-03-30 11:01:14 +02:00
32 changed files with 397 additions and 699 deletions

View File

@@ -11,8 +11,6 @@ type contextKey string
const ( const (
userIDKey contextKey = "user_id" userIDKey contextKey = "user_id"
tenantIDKey contextKey = "tenant_id" tenantIDKey contextKey = "tenant_id"
ipKey contextKey = "ip_address"
userAgentKey contextKey = "user_agent"
) )
func ContextWithUserID(ctx context.Context, userID uuid.UUID) context.Context { func ContextWithUserID(ctx context.Context, userID uuid.UUID) context.Context {
@@ -32,23 +30,3 @@ func TenantFromContext(ctx context.Context) (uuid.UUID, bool) {
id, ok := ctx.Value(tenantIDKey).(uuid.UUID) id, ok := ctx.Value(tenantIDKey).(uuid.UUID)
return id, ok return id, ok
} }
func ContextWithRequestInfo(ctx context.Context, ip, userAgent string) context.Context {
ctx = context.WithValue(ctx, ipKey, ip)
ctx = context.WithValue(ctx, userAgentKey, userAgent)
return ctx
}
func IPFromContext(ctx context.Context) *string {
if v, ok := ctx.Value(ipKey).(string); ok && v != "" {
return &v
}
return nil
}
func UserAgentFromContext(ctx context.Context) *string {
if v, ok := ctx.Value(userAgentKey).(string); ok && v != "" {
return &v
}
return nil
}

View File

@@ -24,35 +24,19 @@ func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := extractBearerToken(r) token := extractBearerToken(r)
if token == "" { if token == "" {
http.Error(w, "missing authorization token", http.StatusUnauthorized) http.Error(w, `{"error":"missing authorization token"}`, http.StatusUnauthorized)
return return
} }
userID, err := m.verifyJWT(token) userID, err := m.verifyJWT(token)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized) http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized)
return return
} }
ctx := ContextWithUserID(r.Context(), userID) ctx := ContextWithUserID(r.Context(), userID)
// Tenant resolution is handled by TenantResolver middleware for scoped routes.
// Resolve tenant from user_tenants // Tenant management routes handle their own access control.
var tenantID uuid.UUID
err = m.db.GetContext(r.Context(), &tenantID,
"SELECT tenant_id FROM user_tenants WHERE user_id = $1 LIMIT 1", userID)
if err != nil {
http.Error(w, "no tenant found for user", http.StatusForbidden)
return
}
ctx = ContextWithTenantID(ctx, tenantID)
// Capture IP and user-agent for audit logging
ip := r.Header.Get("X-Forwarded-For")
if ip == "" {
ip = r.RemoteAddr
}
ctx = ContextWithRequestInfo(ctx, ip, r.UserAgent())
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }

View File

@@ -2,20 +2,21 @@ package auth
import ( import (
"context" "context"
"fmt" "log/slog"
"net/http" "net/http"
"github.com/google/uuid" "github.com/google/uuid"
) )
// TenantLookup resolves the default tenant for a user. // TenantLookup resolves and verifies tenant access for a user.
// Defined as an interface to avoid circular dependency with services. // Defined as an interface to avoid circular dependency with services.
type TenantLookup interface { type TenantLookup interface {
FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error)
VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error)
} }
// TenantResolver is middleware that resolves the tenant from X-Tenant-ID header // TenantResolver is middleware that resolves the tenant from X-Tenant-ID header
// or defaults to the user's first tenant. // or defaults to the user's first tenant. Always verifies user has access.
type TenantResolver struct { type TenantResolver struct {
lookup TenantLookup lookup TenantLookup
} }
@@ -28,7 +29,7 @@ func (tr *TenantResolver) Resolve(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userID, ok := UserFromContext(r.Context()) userID, ok := UserFromContext(r.Context())
if !ok { if !ok {
http.Error(w, "unauthorized", http.StatusUnauthorized) http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
return return
} }
@@ -37,19 +38,33 @@ func (tr *TenantResolver) Resolve(next http.Handler) http.Handler {
if header := r.Header.Get("X-Tenant-ID"); header != "" { if header := r.Header.Get("X-Tenant-ID"); header != "" {
parsed, err := uuid.Parse(header) parsed, err := uuid.Parse(header)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("invalid X-Tenant-ID: %v", err), http.StatusBadRequest) http.Error(w, `{"error":"invalid X-Tenant-ID"}`, http.StatusBadRequest)
return return
} }
// Verify user has access to this tenant
hasAccess, err := tr.lookup.VerifyAccess(r.Context(), userID, parsed)
if err != nil {
slog.Error("tenant access check failed", "error", err, "user_id", userID, "tenant_id", parsed)
http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError)
return
}
if !hasAccess {
http.Error(w, `{"error":"no access to tenant"}`, http.StatusForbidden)
return
}
tenantID = parsed tenantID = parsed
} else { } else {
// Default to user's first tenant // Default to user's first tenant
first, err := tr.lookup.FirstTenantForUser(r.Context(), userID) first, err := tr.lookup.FirstTenantForUser(r.Context(), userID)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("resolving tenant: %v", err), http.StatusInternalServerError) slog.Error("failed to resolve default tenant", "error", err, "user_id", userID)
http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError)
return return
} }
if first == nil { if first == nil {
http.Error(w, "no tenant found for user", http.StatusBadRequest) http.Error(w, `{"error":"no tenant found for user"}`, http.StatusBadRequest)
return return
} }
tenantID = *first tenantID = *first

View File

@@ -12,15 +12,21 @@ import (
type mockTenantLookup struct { type mockTenantLookup struct {
tenantID *uuid.UUID tenantID *uuid.UUID
err error err error
hasAccess bool
accessErr error
} }
func (m *mockTenantLookup) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) { func (m *mockTenantLookup) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) {
return m.tenantID, m.err return m.tenantID, m.err
} }
func (m *mockTenantLookup) VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error) {
return m.hasAccess, m.accessErr
}
func TestTenantResolver_FromHeader(t *testing.T) { func TestTenantResolver_FromHeader(t *testing.T) {
tenantID := uuid.New() tenantID := uuid.New()
tr := NewTenantResolver(&mockTenantLookup{}) tr := NewTenantResolver(&mockTenantLookup{hasAccess: true})
var gotTenantID uuid.UUID var gotTenantID uuid.UUID
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -47,6 +53,26 @@ func TestTenantResolver_FromHeader(t *testing.T) {
} }
} }
func TestTenantResolver_FromHeader_NoAccess(t *testing.T) {
tenantID := uuid.New()
tr := NewTenantResolver(&mockTenantLookup{hasAccess: false})
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("next should not be called")
})
r := httptest.NewRequest("GET", "/api/cases", nil)
r.Header.Set("X-Tenant-ID", tenantID.String())
r = r.WithContext(ContextWithUserID(r.Context(), uuid.New()))
w := httptest.NewRecorder()
tr.Resolve(next).ServeHTTP(w, r)
if w.Code != http.StatusForbidden {
t.Errorf("expected 403, got %d", w.Code)
}
}
func TestTenantResolver_DefaultsToFirst(t *testing.T) { func TestTenantResolver_DefaultsToFirst(t *testing.T) {
tenantID := uuid.New() tenantID := uuid.New()
tr := NewTenantResolver(&mockTenantLookup{tenantID: &tenantID}) tr := NewTenantResolver(&mockTenantLookup{tenantID: &tenantID})

View File

@@ -13,6 +13,7 @@ type Config struct {
SupabaseServiceKey string SupabaseServiceKey string
SupabaseJWTSecret string SupabaseJWTSecret string
AnthropicAPIKey string AnthropicAPIKey string
FrontendOrigin string
} }
func Load() (*Config, error) { func Load() (*Config, error) {
@@ -24,6 +25,7 @@ func Load() (*Config, error) {
SupabaseServiceKey: os.Getenv("SUPABASE_SERVICE_KEY"), SupabaseServiceKey: os.Getenv("SUPABASE_SERVICE_KEY"),
SupabaseJWTSecret: os.Getenv("SUPABASE_JWT_SECRET"), SupabaseJWTSecret: os.Getenv("SUPABASE_JWT_SECRET"),
AnthropicAPIKey: os.Getenv("ANTHROPIC_API_KEY"), AnthropicAPIKey: os.Getenv("ANTHROPIC_API_KEY"),
FrontendOrigin: getEnv("FRONTEND_ORIGIN", "https://kanzlai.msbls.de"),
} }
if cfg.DatabaseURL == "" { if cfg.DatabaseURL == "" {

View File

@@ -5,18 +5,16 @@ import (
"io" "io"
"net/http" "net/http"
"github.com/jmoiron/sqlx" "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services" "mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
) )
type AIHandler struct { type AIHandler struct {
ai *services.AIService ai *services.AIService
db *sqlx.DB
} }
func NewAIHandler(ai *services.AIService, db *sqlx.DB) *AIHandler { func NewAIHandler(ai *services.AIService) *AIHandler {
return &AIHandler{ai: ai, db: db} return &AIHandler{ai: ai}
} }
// ExtractDeadlines handles POST /api/ai/extract-deadlines // ExtractDeadlines handles POST /api/ai/extract-deadlines
@@ -61,10 +59,14 @@ func (h *AIHandler) ExtractDeadlines(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "provide either a PDF file or text") writeError(w, http.StatusBadRequest, "provide either a PDF file or text")
return return
} }
if len(text) > maxDescriptionLen {
writeError(w, http.StatusBadRequest, "text exceeds maximum length")
return
}
deadlines, err := h.ai.ExtractDeadlines(r.Context(), pdfData, text) deadlines, err := h.ai.ExtractDeadlines(r.Context(), pdfData, text)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, "AI extraction failed: "+err.Error()) internalError(w, "AI deadline extraction failed", err)
return return
} }
@@ -77,9 +79,9 @@ func (h *AIHandler) ExtractDeadlines(w http.ResponseWriter, r *http.Request) {
// SummarizeCase handles POST /api/ai/summarize-case // SummarizeCase handles POST /api/ai/summarize-case
// Accepts JSON {"case_id": "uuid"}. // Accepts JSON {"case_id": "uuid"}.
func (h *AIHandler) SummarizeCase(w http.ResponseWriter, r *http.Request) { func (h *AIHandler) SummarizeCase(w http.ResponseWriter, r *http.Request) {
tenantID, err := resolveTenant(r, h.db) tenantID, ok := auth.TenantFromContext(r.Context())
if err != nil { if !ok {
handleTenantError(w, err) writeError(w, http.StatusForbidden, "missing tenant")
return return
} }
@@ -104,7 +106,7 @@ func (h *AIHandler) SummarizeCase(w http.ResponseWriter, r *http.Request) {
summary, err := h.ai.SummarizeCase(r.Context(), tenantID, caseID) summary, err := h.ai.SummarizeCase(r.Context(), tenantID, caseID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, "AI summarization failed: "+err.Error()) internalError(w, "AI case summarization failed", err)
return return
} }

View File

@@ -42,7 +42,7 @@ func TestAIExtractDeadlines_InvalidJSON(t *testing.T) {
} }
} }
func TestAISummarizeCase_MissingCaseID(t *testing.T) { func TestAISummarizeCase_MissingTenant(t *testing.T) {
h := &AIHandler{} h := &AIHandler{}
body := `{"case_id":""}` body := `{"case_id":""}`
@@ -52,9 +52,9 @@ func TestAISummarizeCase_MissingCaseID(t *testing.T) {
h.SummarizeCase(w, r) h.SummarizeCase(w, r)
// Without auth context, the resolveTenant will fail first // Without tenant context, TenantFromContext returns !ok → 403
if w.Code != http.StatusUnauthorized { if w.Code != http.StatusForbidden {
t.Errorf("expected 401, got %d", w.Code) t.Errorf("expected 403, got %d", w.Code)
} }
} }
@@ -67,8 +67,8 @@ func TestAISummarizeCase_InvalidJSON(t *testing.T) {
h.SummarizeCase(w, r) h.SummarizeCase(w, r)
// Without auth context, the resolveTenant will fail first // Without tenant context, TenantFromContext returns !ok → 403
if w.Code != http.StatusUnauthorized { if w.Code != http.StatusForbidden {
t.Errorf("expected 401, got %d", w.Code) t.Errorf("expected 403, got %d", w.Code)
} }
} }

View File

@@ -121,6 +121,10 @@ func (h *AppointmentHandler) Create(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "title is required") writeError(w, http.StatusBadRequest, "title is required")
return return
} }
if msg := validateStringLength("title", req.Title, maxTitleLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
if req.StartAt.IsZero() { if req.StartAt.IsZero() {
writeError(w, http.StatusBadRequest, "start_at is required") writeError(w, http.StatusBadRequest, "start_at is required")
return return
@@ -188,6 +192,10 @@ func (h *AppointmentHandler) Update(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "title is required") writeError(w, http.StatusBadRequest, "title is required")
return return
} }
if msg := validateStringLength("title", req.Title, maxTitleLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
if req.StartAt.IsZero() { if req.StartAt.IsZero() {
writeError(w, http.StatusBadRequest, "start_at is required") writeError(w, http.StatusBadRequest, "start_at is required")
return return

View File

@@ -1,63 +0,0 @@
package handlers
import (
"net/http"
"strconv"
"github.com/google/uuid"
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
)
type AuditLogHandler struct {
svc *services.AuditService
}
func NewAuditLogHandler(svc *services.AuditService) *AuditLogHandler {
return &AuditLogHandler{svc: svc}
}
func (h *AuditLogHandler) List(w http.ResponseWriter, r *http.Request) {
tenantID, ok := auth.TenantFromContext(r.Context())
if !ok {
writeError(w, http.StatusForbidden, "missing tenant")
return
}
q := r.URL.Query()
page, _ := strconv.Atoi(q.Get("page"))
limit, _ := strconv.Atoi(q.Get("limit"))
filter := services.AuditFilter{
EntityType: q.Get("entity_type"),
From: q.Get("from"),
To: q.Get("to"),
Page: page,
Limit: limit,
}
if idStr := q.Get("entity_id"); idStr != "" {
if id, err := uuid.Parse(idStr); err == nil {
filter.EntityID = &id
}
}
if idStr := q.Get("user_id"); idStr != "" {
if id, err := uuid.Parse(idStr); err == nil {
filter.UserID = &id
}
}
entries, total, err := h.svc.List(r.Context(), tenantID, filter)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to fetch audit log")
return
}
writeJSON(w, http.StatusOK, map[string]any{
"entries": entries,
"total": total,
"page": filter.Page,
"limit": filter.Limit,
})
}

View File

@@ -27,7 +27,7 @@ func (h *CalDAVHandler) TriggerSync(w http.ResponseWriter, r *http.Request) {
cfg, err := h.svc.LoadTenantConfig(tenantID) cfg, err := h.svc.LoadTenantConfig(tenantID)
if err != nil { if err != nil {
writeError(w, http.StatusBadRequest, err.Error()) writeError(w, http.StatusBadRequest, "CalDAV not configured for this tenant")
return return
} }

View File

@@ -28,18 +28,25 @@ func (h *CaseHandler) List(w http.ResponseWriter, r *http.Request) {
limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
offset, _ := strconv.Atoi(r.URL.Query().Get("offset")) offset, _ := strconv.Atoi(r.URL.Query().Get("offset"))
limit, offset = clampPagination(limit, offset)
search := r.URL.Query().Get("search")
if msg := validateStringLength("search", search, maxSearchLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
filter := services.CaseFilter{ filter := services.CaseFilter{
Status: r.URL.Query().Get("status"), Status: r.URL.Query().Get("status"),
Type: r.URL.Query().Get("type"), Type: r.URL.Query().Get("type"),
Search: r.URL.Query().Get("search"), Search: search,
Limit: limit, Limit: limit,
Offset: offset, Offset: offset,
} }
cases, total, err := h.svc.List(r.Context(), tenantID, filter) cases, total, err := h.svc.List(r.Context(), tenantID, filter)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to list cases", err)
return return
} }
@@ -66,10 +73,18 @@ func (h *CaseHandler) Create(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "case_number and title are required") writeError(w, http.StatusBadRequest, "case_number and title are required")
return return
} }
if msg := validateStringLength("case_number", input.CaseNumber, maxCaseNumberLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
if msg := validateStringLength("title", input.Title, maxTitleLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
c, err := h.svc.Create(r.Context(), tenantID, userID, input) c, err := h.svc.Create(r.Context(), tenantID, userID, input)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to create case", err)
return return
} }
@@ -91,7 +106,7 @@ func (h *CaseHandler) Get(w http.ResponseWriter, r *http.Request) {
detail, err := h.svc.GetByID(r.Context(), tenantID, caseID) detail, err := h.svc.GetByID(r.Context(), tenantID, caseID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to get case", err)
return return
} }
if detail == nil { if detail == nil {
@@ -121,10 +136,22 @@ func (h *CaseHandler) Update(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "invalid JSON body") writeError(w, http.StatusBadRequest, "invalid JSON body")
return return
} }
if input.Title != nil {
if msg := validateStringLength("title", *input.Title, maxTitleLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
}
if input.CaseNumber != nil {
if msg := validateStringLength("case_number", *input.CaseNumber, maxCaseNumberLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
}
updated, err := h.svc.Update(r.Context(), tenantID, caseID, userID, input) updated, err := h.svc.Update(r.Context(), tenantID, caseID, userID, input)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to update case", err)
return return
} }
if updated == nil { if updated == nil {

View File

@@ -24,7 +24,7 @@ func (h *DashboardHandler) Get(w http.ResponseWriter, r *http.Request) {
data, err := h.svc.Get(r.Context(), tenantID) data, err := h.svc.Get(r.Context(), tenantID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to load dashboard", err)
return return
} }

View File

@@ -4,27 +4,25 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/jmoiron/sqlx" "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services" "mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
) )
// DeadlineHandlers holds handlers for deadline CRUD endpoints // DeadlineHandlers holds handlers for deadline CRUD endpoints
type DeadlineHandlers struct { type DeadlineHandlers struct {
deadlines *services.DeadlineService deadlines *services.DeadlineService
db *sqlx.DB
} }
// NewDeadlineHandlers creates deadline handlers // NewDeadlineHandlers creates deadline handlers
func NewDeadlineHandlers(ds *services.DeadlineService, db *sqlx.DB) *DeadlineHandlers { func NewDeadlineHandlers(ds *services.DeadlineService) *DeadlineHandlers {
return &DeadlineHandlers{deadlines: ds, db: db} return &DeadlineHandlers{deadlines: ds}
} }
// Get handles GET /api/deadlines/{deadlineID} // Get handles GET /api/deadlines/{deadlineID}
func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) { func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) {
tenantID, err := resolveTenant(r, h.db) tenantID, ok := auth.TenantFromContext(r.Context())
if err != nil { if !ok {
handleTenantError(w, err) writeError(w, http.StatusForbidden, "missing tenant")
return return
} }
@@ -36,7 +34,7 @@ func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) {
deadline, err := h.deadlines.GetByID(tenantID, deadlineID) deadline, err := h.deadlines.GetByID(tenantID, deadlineID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, "failed to fetch deadline") internalError(w, "failed to fetch deadline", err)
return return
} }
if deadline == nil { if deadline == nil {
@@ -49,15 +47,15 @@ func (h *DeadlineHandlers) Get(w http.ResponseWriter, r *http.Request) {
// ListAll handles GET /api/deadlines // ListAll handles GET /api/deadlines
func (h *DeadlineHandlers) ListAll(w http.ResponseWriter, r *http.Request) { func (h *DeadlineHandlers) ListAll(w http.ResponseWriter, r *http.Request) {
tenantID, err := resolveTenant(r, h.db) tenantID, ok := auth.TenantFromContext(r.Context())
if err != nil { if !ok {
handleTenantError(w, err) writeError(w, http.StatusForbidden, "missing tenant")
return return
} }
deadlines, err := h.deadlines.ListAll(tenantID) deadlines, err := h.deadlines.ListAll(tenantID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, "failed to list deadlines") internalError(w, "failed to list deadlines", err)
return return
} }
@@ -66,9 +64,9 @@ func (h *DeadlineHandlers) ListAll(w http.ResponseWriter, r *http.Request) {
// ListForCase handles GET /api/cases/{caseID}/deadlines // ListForCase handles GET /api/cases/{caseID}/deadlines
func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) { func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) {
tenantID, err := resolveTenant(r, h.db) tenantID, ok := auth.TenantFromContext(r.Context())
if err != nil { if !ok {
handleTenantError(w, err) writeError(w, http.StatusForbidden, "missing tenant")
return return
} }
@@ -80,7 +78,7 @@ func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) {
deadlines, err := h.deadlines.ListForCase(tenantID, caseID) deadlines, err := h.deadlines.ListForCase(tenantID, caseID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, "failed to list deadlines") internalError(w, "failed to list deadlines for case", err)
return return
} }
@@ -89,9 +87,9 @@ func (h *DeadlineHandlers) ListForCase(w http.ResponseWriter, r *http.Request) {
// Create handles POST /api/cases/{caseID}/deadlines // Create handles POST /api/cases/{caseID}/deadlines
func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) { func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) {
tenantID, err := resolveTenant(r, h.db) tenantID, ok := auth.TenantFromContext(r.Context())
if err != nil { if !ok {
handleTenantError(w, err) writeError(w, http.StatusForbidden, "missing tenant")
return return
} }
@@ -112,10 +110,14 @@ func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "title and due_date are required") writeError(w, http.StatusBadRequest, "title and due_date are required")
return return
} }
if msg := validateStringLength("title", input.Title, maxTitleLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
deadline, err := h.deadlines.Create(r.Context(), tenantID, input) deadline, err := h.deadlines.Create(tenantID, input)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, "failed to create deadline") internalError(w, "failed to create deadline", err)
return return
} }
@@ -124,9 +126,9 @@ func (h *DeadlineHandlers) Create(w http.ResponseWriter, r *http.Request) {
// Update handles PUT /api/deadlines/{deadlineID} // Update handles PUT /api/deadlines/{deadlineID}
func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) { func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) {
tenantID, err := resolveTenant(r, h.db) tenantID, ok := auth.TenantFromContext(r.Context())
if err != nil { if !ok {
handleTenantError(w, err) writeError(w, http.StatusForbidden, "missing tenant")
return return
} }
@@ -142,9 +144,9 @@ func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) {
return return
} }
deadline, err := h.deadlines.Update(r.Context(), tenantID, deadlineID, input) deadline, err := h.deadlines.Update(tenantID, deadlineID, input)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, "failed to update deadline") internalError(w, "failed to update deadline", err)
return return
} }
if deadline == nil { if deadline == nil {
@@ -157,9 +159,9 @@ func (h *DeadlineHandlers) Update(w http.ResponseWriter, r *http.Request) {
// Complete handles PATCH /api/deadlines/{deadlineID}/complete // Complete handles PATCH /api/deadlines/{deadlineID}/complete
func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) { func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) {
tenantID, err := resolveTenant(r, h.db) tenantID, ok := auth.TenantFromContext(r.Context())
if err != nil { if !ok {
handleTenantError(w, err) writeError(w, http.StatusForbidden, "missing tenant")
return return
} }
@@ -169,9 +171,9 @@ func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) {
return return
} }
deadline, err := h.deadlines.Complete(r.Context(), tenantID, deadlineID) deadline, err := h.deadlines.Complete(tenantID, deadlineID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, "failed to complete deadline") internalError(w, "failed to complete deadline", err)
return return
} }
if deadline == nil { if deadline == nil {
@@ -184,9 +186,9 @@ func (h *DeadlineHandlers) Complete(w http.ResponseWriter, r *http.Request) {
// Delete handles DELETE /api/deadlines/{deadlineID} // Delete handles DELETE /api/deadlines/{deadlineID}
func (h *DeadlineHandlers) Delete(w http.ResponseWriter, r *http.Request) { func (h *DeadlineHandlers) Delete(w http.ResponseWriter, r *http.Request) {
tenantID, err := resolveTenant(r, h.db) tenantID, ok := auth.TenantFromContext(r.Context())
if err != nil { if !ok {
handleTenantError(w, err) writeError(w, http.StatusForbidden, "missing tenant")
return return
} }
@@ -196,9 +198,8 @@ func (h *DeadlineHandlers) Delete(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.deadlines.Delete(r.Context(), tenantID, deadlineID) if err := h.deadlines.Delete(tenantID, deadlineID); err != nil {
if err != nil { writeError(w, http.StatusNotFound, "deadline not found")
writeError(w, http.StatusNotFound, err.Error())
return return
} }

View File

@@ -36,7 +36,7 @@ func (h *DocumentHandler) ListByCase(w http.ResponseWriter, r *http.Request) {
docs, err := h.svc.ListByCase(r.Context(), tenantID, caseID) docs, err := h.svc.ListByCase(r.Context(), tenantID, caseID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to list documents", err)
return return
} }
@@ -98,7 +98,7 @@ func (h *DocumentHandler) Upload(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusNotFound, "case not found") writeError(w, http.StatusNotFound, "case not found")
return return
} }
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to upload document", err)
return return
} }
@@ -121,16 +121,16 @@ func (h *DocumentHandler) Download(w http.ResponseWriter, r *http.Request) {
body, contentType, title, err := h.svc.Download(r.Context(), tenantID, docID) body, contentType, title, err := h.svc.Download(r.Context(), tenantID, docID)
if err != nil { if err != nil {
if err.Error() == "document not found" || err.Error() == "document has no file" { if err.Error() == "document not found" || err.Error() == "document has no file" {
writeError(w, http.StatusNotFound, err.Error()) writeError(w, http.StatusNotFound, "document not found")
return return
} }
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to download document", err)
return return
} }
defer body.Close() defer body.Close()
w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, title)) w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, sanitizeFilename(title)))
io.Copy(w, body) io.Copy(w, body)
} }
@@ -149,7 +149,7 @@ func (h *DocumentHandler) GetMeta(w http.ResponseWriter, r *http.Request) {
doc, err := h.svc.GetByID(r.Context(), tenantID, docID) doc, err := h.svc.GetByID(r.Context(), tenantID, docID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to get document metadata", err)
return return
} }
if doc == nil { if doc == nil {

View File

@@ -2,12 +2,12 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"log/slog"
"net/http" "net/http"
"strings"
"unicode/utf8"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jmoiron/sqlx"
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
) )
func writeJSON(w http.ResponseWriter, status int, v any) { func writeJSON(w http.ResponseWriter, status int, v any) {
@@ -20,62 +20,9 @@ func writeError(w http.ResponseWriter, status int, msg string) {
writeJSON(w, status, map[string]string{"error": msg}) writeJSON(w, status, map[string]string{"error": msg})
} }
// resolveTenant gets the tenant ID for the authenticated user. // internalError logs the real error and returns a generic message to the client.
// Checks X-Tenant-ID header first, then falls back to user's first tenant. func internalError(w http.ResponseWriter, msg string, err error) {
func resolveTenant(r *http.Request, db *sqlx.DB) (uuid.UUID, error) { slog.Error(msg, "error", err)
userID, ok := auth.UserFromContext(r.Context())
if !ok {
return uuid.Nil, errUnauthorized
}
// Check header first
if headerVal := r.Header.Get("X-Tenant-ID"); headerVal != "" {
tenantID, err := uuid.Parse(headerVal)
if err != nil {
return uuid.Nil, errInvalidTenant
}
// Verify user has access to this tenant
var count int
err = db.Get(&count,
`SELECT COUNT(*) FROM user_tenants WHERE user_id = $1 AND tenant_id = $2`,
userID, tenantID)
if err != nil || count == 0 {
return uuid.Nil, errTenantAccess
}
return tenantID, nil
}
// Fall back to user's first tenant
var tenantID uuid.UUID
err := db.Get(&tenantID,
`SELECT tenant_id FROM user_tenants WHERE user_id = $1 ORDER BY created_at LIMIT 1`,
userID)
if err != nil {
return uuid.Nil, errNoTenant
}
return tenantID, nil
}
type apiError struct {
msg string
status int
}
func (e *apiError) Error() string { return e.msg }
var (
errUnauthorized = &apiError{msg: "unauthorized", status: http.StatusUnauthorized}
errInvalidTenant = &apiError{msg: "invalid tenant ID", status: http.StatusBadRequest}
errTenantAccess = &apiError{msg: "no access to tenant", status: http.StatusForbidden}
errNoTenant = &apiError{msg: "no tenant found for user", status: http.StatusBadRequest}
)
// handleTenantError writes the appropriate error response for tenant resolution errors
func handleTenantError(w http.ResponseWriter, err error) {
if ae, ok := err.(*apiError); ok {
writeError(w, ae.status, ae.msg)
return
}
writeError(w, http.StatusInternalServerError, "internal error") writeError(w, http.StatusInternalServerError, "internal error")
} }
@@ -88,3 +35,74 @@ func parsePathUUID(r *http.Request, key string) (uuid.UUID, error) {
func parseUUID(s string) (uuid.UUID, error) { func parseUUID(s string) (uuid.UUID, error) {
return uuid.Parse(s) return uuid.Parse(s)
} }
// --- Input validation helpers ---
const (
maxTitleLen = 500
maxDescriptionLen = 10000
maxCaseNumberLen = 100
maxSearchLen = 200
maxPaginationLimit = 100
)
// validateStringLength checks if a string exceeds the given max length.
func validateStringLength(field, value string, maxLen int) string {
if utf8.RuneCountInString(value) > maxLen {
return field + " exceeds maximum length"
}
return ""
}
// clampPagination enforces sane pagination defaults and limits.
func clampPagination(limit, offset int) (int, int) {
if limit <= 0 {
limit = 20
}
if limit > maxPaginationLimit {
limit = maxPaginationLimit
}
if offset < 0 {
offset = 0
}
return limit, offset
}
// sanitizeFilename removes characters unsafe for Content-Disposition headers.
func sanitizeFilename(name string) string {
// Remove control characters, quotes, and backslashes
var b strings.Builder
for _, r := range name {
if r < 32 || r == '"' || r == '\\' || r == '/' {
b.WriteRune('_')
} else {
b.WriteRune(r)
}
}
return b.String()
}
// maskSettingsPassword masks the CalDAV password in tenant settings JSON before returning to clients.
func maskSettingsPassword(settings json.RawMessage) json.RawMessage {
if len(settings) == 0 {
return settings
}
var m map[string]json.RawMessage
if err := json.Unmarshal(settings, &m); err != nil {
return settings
}
caldavRaw, ok := m["caldav"]
if !ok {
return settings
}
var caldav map[string]json.RawMessage
if err := json.Unmarshal(caldavRaw, &caldav); err != nil {
return settings
}
if _, ok := caldav["password"]; ok {
caldav["password"], _ = json.Marshal("********")
}
m["caldav"], _ = json.Marshal(caldav)
result, _ := json.Marshal(m)
return result
}

View File

@@ -60,6 +60,10 @@ func (h *NoteHandler) Create(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "content is required") writeError(w, http.StatusBadRequest, "content is required")
return return
} }
if msg := validateStringLength("content", input.Content, maxDescriptionLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
var createdBy *uuid.UUID var createdBy *uuid.UUID
if userID != uuid.Nil { if userID != uuid.Nil {
@@ -100,6 +104,10 @@ func (h *NoteHandler) Update(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "content is required") writeError(w, http.StatusBadRequest, "content is required")
return return
} }
if msg := validateStringLength("content", req.Content, maxDescriptionLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
note, err := h.svc.Update(r.Context(), tenantID, noteID, req.Content) note, err := h.svc.Update(r.Context(), tenantID, noteID, req.Content)
if err != nil { if err != nil {

View File

@@ -34,7 +34,7 @@ func (h *PartyHandler) List(w http.ResponseWriter, r *http.Request) {
parties, err := h.svc.ListByCase(r.Context(), tenantID, caseID) parties, err := h.svc.ListByCase(r.Context(), tenantID, caseID)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to list parties", err)
return return
} }
@@ -67,13 +67,18 @@ func (h *PartyHandler) Create(w http.ResponseWriter, r *http.Request) {
return return
} }
if msg := validateStringLength("name", input.Name, maxTitleLen); msg != "" {
writeError(w, http.StatusBadRequest, msg)
return
}
party, err := h.svc.Create(r.Context(), tenantID, caseID, userID, input) party, err := h.svc.Create(r.Context(), tenantID, caseID, userID, input)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
writeError(w, http.StatusNotFound, "case not found") writeError(w, http.StatusNotFound, "case not found")
return return
} }
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to create party", err)
return return
} }
@@ -101,7 +106,7 @@ func (h *PartyHandler) Update(w http.ResponseWriter, r *http.Request) {
updated, err := h.svc.Update(r.Context(), tenantID, partyID, input) updated, err := h.svc.Update(r.Context(), tenantID, partyID, input)
if err != nil { if err != nil {
writeError(w, http.StatusInternalServerError, err.Error()) internalError(w, "failed to update party", err)
return return
} }
if updated == nil { if updated == nil {

View File

@@ -2,6 +2,7 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"log/slog"
"net/http" "net/http"
"github.com/google/uuid" "github.com/google/uuid"
@@ -41,7 +42,8 @@ func (h *TenantHandler) CreateTenant(w http.ResponseWriter, r *http.Request) {
tenant, err := h.svc.Create(r.Context(), userID, req.Name, req.Slug) tenant, err := h.svc.Create(r.Context(), userID, req.Name, req.Slug)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to create tenant", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }
@@ -58,10 +60,16 @@ func (h *TenantHandler) ListTenants(w http.ResponseWriter, r *http.Request) {
tenants, err := h.svc.ListForUser(r.Context(), userID) tenants, err := h.svc.ListForUser(r.Context(), userID)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to list tenants", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }
// Mask CalDAV passwords in tenant settings
for i := range tenants {
tenants[i].Settings = maskSettingsPassword(tenants[i].Settings)
}
jsonResponse(w, tenants, http.StatusOK) jsonResponse(w, tenants, http.StatusOK)
} }
@@ -82,7 +90,8 @@ func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) {
// Verify user has access to this tenant // Verify user has access to this tenant
role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) role, err := h.svc.GetUserRole(r.Context(), userID, tenantID)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to get user role", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }
if role == "" { if role == "" {
@@ -92,7 +101,8 @@ func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) {
tenant, err := h.svc.GetByID(r.Context(), tenantID) tenant, err := h.svc.GetByID(r.Context(), tenantID)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to get tenant", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }
if tenant == nil { if tenant == nil {
@@ -100,6 +110,9 @@ func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) {
return return
} }
// Mask CalDAV password before returning
tenant.Settings = maskSettingsPassword(tenant.Settings)
jsonResponse(w, tenant, http.StatusOK) jsonResponse(w, tenant, http.StatusOK)
} }
@@ -120,7 +133,8 @@ func (h *TenantHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
// Only owners and admins can invite // Only owners and admins can invite
role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) role, err := h.svc.GetUserRole(r.Context(), userID, tenantID)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to get user role", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }
if role != "owner" && role != "admin" { if role != "owner" && role != "admin" {
@@ -150,7 +164,8 @@ func (h *TenantHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
ut, err := h.svc.InviteByEmail(r.Context(), tenantID, req.Email, req.Role) ut, err := h.svc.InviteByEmail(r.Context(), tenantID, req.Email, req.Role)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusBadRequest) // These are user-facing validation errors (user not found, already member)
jsonError(w, "failed to invite user", http.StatusBadRequest)
return return
} }
@@ -180,7 +195,8 @@ func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
// Only owners and admins can remove members (or user removing themselves) // Only owners and admins can remove members (or user removing themselves)
role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) role, err := h.svc.GetUserRole(r.Context(), userID, tenantID)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to get user role", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }
if role != "owner" && role != "admin" && userID != memberID { if role != "owner" && role != "admin" && userID != memberID {
@@ -189,7 +205,8 @@ func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
} }
if err := h.svc.RemoveMember(r.Context(), tenantID, memberID); err != nil { if err := h.svc.RemoveMember(r.Context(), tenantID, memberID); err != nil {
jsonError(w, err.Error(), http.StatusBadRequest) // These are user-facing validation errors (not a member, last owner, etc.)
jsonError(w, "failed to remove member", http.StatusBadRequest)
return return
} }
@@ -213,7 +230,8 @@ func (h *TenantHandler) UpdateSettings(w http.ResponseWriter, r *http.Request) {
// Only owners and admins can update settings // Only owners and admins can update settings
role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) role, err := h.svc.GetUserRole(r.Context(), userID, tenantID)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to get user role", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }
if role != "owner" && role != "admin" { if role != "owner" && role != "admin" {
@@ -229,10 +247,14 @@ func (h *TenantHandler) UpdateSettings(w http.ResponseWriter, r *http.Request) {
tenant, err := h.svc.UpdateSettings(r.Context(), tenantID, settings) tenant, err := h.svc.UpdateSettings(r.Context(), tenantID, settings)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to update settings", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }
// Mask CalDAV password before returning
tenant.Settings = maskSettingsPassword(tenant.Settings)
jsonResponse(w, tenant, http.StatusOK) jsonResponse(w, tenant, http.StatusOK)
} }
@@ -253,7 +275,8 @@ func (h *TenantHandler) ListMembers(w http.ResponseWriter, r *http.Request) {
// Verify user has access // Verify user has access
role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) role, err := h.svc.GetUserRole(r.Context(), userID, tenantID)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to get user role", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }
if role == "" { if role == "" {
@@ -263,7 +286,8 @@ func (h *TenantHandler) ListMembers(w http.ResponseWriter, r *http.Request) {
members, err := h.svc.ListMembers(r.Context(), tenantID) members, err := h.svc.ListMembers(r.Context(), tenantID)
if err != nil { if err != nil {
jsonError(w, err.Error(), http.StatusInternalServerError) slog.Error("failed to list members", "error", err)
jsonError(w, "internal error", http.StatusInternalServerError)
return return
} }

View File

@@ -0,0 +1,49 @@
package middleware
import (
"net/http"
"strings"
)
// SecurityHeaders adds standard security headers to all responses.
func SecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
next.ServeHTTP(w, r)
})
}
// CORS returns middleware that restricts cross-origin requests to the given origin.
// If allowedOrigin is empty, CORS headers are not set (same-origin only).
func CORS(allowedOrigin string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if allowedOrigin != "" && origin != "" && matchOrigin(origin, allowedOrigin) {
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Tenant-ID")
w.Header().Set("Access-Control-Max-Age", "86400")
w.Header().Set("Vary", "Origin")
}
// Handle preflight
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}
}
// matchOrigin checks if the request origin matches the allowed origin.
func matchOrigin(origin, allowed string) bool {
return strings.EqualFold(strings.TrimRight(origin, "/"), strings.TrimRight(allowed, "/"))
}

View File

@@ -1,22 +0,0 @@
package models
import (
"encoding/json"
"time"
"github.com/google/uuid"
)
type AuditLog struct {
ID int64 `db:"id" json:"id"`
TenantID uuid.UUID `db:"tenant_id" json:"tenant_id"`
UserID *uuid.UUID `db:"user_id" json:"user_id,omitempty"`
Action string `db:"action" json:"action"`
EntityType string `db:"entity_type" json:"entity_type"`
EntityID *uuid.UUID `db:"entity_id" json:"entity_id,omitempty"`
OldValues *json.RawMessage `db:"old_values" json:"old_values,omitempty"`
NewValues *json.RawMessage `db:"new_values" json:"new_values,omitempty"`
IPAddress *string `db:"ip_address" json:"ip_address,omitempty"`
UserAgent *string `db:"user_agent" json:"user_agent,omitempty"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
}

View File

@@ -19,38 +19,36 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config, calDAVSvc *se
mux := http.NewServeMux() mux := http.NewServeMux()
// Services // Services
auditSvc := services.NewAuditService(db) tenantSvc := services.NewTenantService(db)
tenantSvc := services.NewTenantService(db, auditSvc) caseSvc := services.NewCaseService(db)
caseSvc := services.NewCaseService(db, auditSvc) partySvc := services.NewPartyService(db)
partySvc := services.NewPartyService(db, auditSvc) appointmentSvc := services.NewAppointmentService(db)
appointmentSvc := services.NewAppointmentService(db, auditSvc)
holidaySvc := services.NewHolidayService(db) holidaySvc := services.NewHolidayService(db)
deadlineSvc := services.NewDeadlineService(db, auditSvc) deadlineSvc := services.NewDeadlineService(db)
deadlineRuleSvc := services.NewDeadlineRuleService(db) deadlineRuleSvc := services.NewDeadlineRuleService(db)
calculator := services.NewDeadlineCalculator(holidaySvc) calculator := services.NewDeadlineCalculator(holidaySvc)
storageCli := services.NewStorageClient(cfg.SupabaseURL, cfg.SupabaseServiceKey) storageCli := services.NewStorageClient(cfg.SupabaseURL, cfg.SupabaseServiceKey)
documentSvc := services.NewDocumentService(db, storageCli, auditSvc) documentSvc := services.NewDocumentService(db, storageCli)
// AI service (optional — only if API key is configured) // AI service (optional — only if API key is configured)
var aiH *handlers.AIHandler var aiH *handlers.AIHandler
if cfg.AnthropicAPIKey != "" { if cfg.AnthropicAPIKey != "" {
aiSvc := services.NewAIService(cfg.AnthropicAPIKey, db) aiSvc := services.NewAIService(cfg.AnthropicAPIKey, db)
aiH = handlers.NewAIHandler(aiSvc, db) aiH = handlers.NewAIHandler(aiSvc)
} }
// Middleware // Middleware
tenantResolver := auth.NewTenantResolver(tenantSvc) tenantResolver := auth.NewTenantResolver(tenantSvc)
noteSvc := services.NewNoteService(db, auditSvc) noteSvc := services.NewNoteService(db)
dashboardSvc := services.NewDashboardService(db) dashboardSvc := services.NewDashboardService(db)
// Handlers // Handlers
auditH := handlers.NewAuditLogHandler(auditSvc)
tenantH := handlers.NewTenantHandler(tenantSvc) tenantH := handlers.NewTenantHandler(tenantSvc)
caseH := handlers.NewCaseHandler(caseSvc) caseH := handlers.NewCaseHandler(caseSvc)
partyH := handlers.NewPartyHandler(partySvc) partyH := handlers.NewPartyHandler(partySvc)
apptH := handlers.NewAppointmentHandler(appointmentSvc) apptH := handlers.NewAppointmentHandler(appointmentSvc)
deadlineH := handlers.NewDeadlineHandlers(deadlineSvc, db) deadlineH := handlers.NewDeadlineHandlers(deadlineSvc)
ruleH := handlers.NewDeadlineRuleHandlers(deadlineRuleSvc) ruleH := handlers.NewDeadlineRuleHandlers(deadlineRuleSvc)
calcH := handlers.NewCalculateHandlers(calculator, deadlineRuleSvc) calcH := handlers.NewCalculateHandlers(calculator, deadlineRuleSvc)
dashboardH := handlers.NewDashboardHandler(dashboardSvc) dashboardH := handlers.NewDashboardHandler(dashboardSvc)
@@ -125,9 +123,6 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config, calDAVSvc *se
// Dashboard // Dashboard
scoped.HandleFunc("GET /api/dashboard", dashboardH.Get) scoped.HandleFunc("GET /api/dashboard", dashboardH.Get)
// Audit log
scoped.HandleFunc("GET /api/audit-log", auditH.List)
// Documents // Documents
scoped.HandleFunc("GET /api/cases/{id}/documents", docH.ListByCase) scoped.HandleFunc("GET /api/cases/{id}/documents", docH.ListByCase)
scoped.HandleFunc("POST /api/cases/{id}/documents", docH.Upload) scoped.HandleFunc("POST /api/cases/{id}/documents", docH.Upload)
@@ -154,14 +149,20 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config, calDAVSvc *se
mux.Handle("/api/", authMW.RequireAuth(api)) mux.Handle("/api/", authMW.RequireAuth(api))
return requestLogger(mux) // Apply security middleware stack: CORS -> Security Headers -> Request Logger -> Routes
var handler http.Handler = mux
handler = requestLogger(handler)
handler = middleware.SecurityHeaders(handler)
handler = middleware.CORS(cfg.FrontendOrigin)(handler)
return handler
} }
func handleHealth(db *sqlx.DB) http.HandlerFunc { func handleHealth(db *sqlx.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if err := db.Ping(); err != nil { if err := db.Ping(); err != nil {
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
json.NewEncoder(w).Encode(map[string]string{"status": "error", "error": err.Error()}) json.NewEncoder(w).Encode(map[string]string{"status": "error"})
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@@ -199,4 +200,3 @@ func requestLogger(next http.Handler) http.Handler {
) )
}) })
} }

View File

@@ -13,11 +13,10 @@ import (
type AppointmentService struct { type AppointmentService struct {
db *sqlx.DB db *sqlx.DB
audit *AuditService
} }
func NewAppointmentService(db *sqlx.DB, audit *AuditService) *AppointmentService { func NewAppointmentService(db *sqlx.DB) *AppointmentService {
return &AppointmentService{db: db, audit: audit} return &AppointmentService{db: db}
} }
type AppointmentFilter struct { type AppointmentFilter struct {
@@ -87,7 +86,6 @@ func (s *AppointmentService) Create(ctx context.Context, a *models.Appointment)
if err != nil { if err != nil {
return fmt.Errorf("creating appointment: %w", err) return fmt.Errorf("creating appointment: %w", err)
} }
s.audit.Log(ctx, "create", "appointment", &a.ID, nil, a)
return nil return nil
} }
@@ -118,7 +116,6 @@ func (s *AppointmentService) Update(ctx context.Context, a *models.Appointment)
if rows == 0 { if rows == 0 {
return fmt.Errorf("appointment not found") return fmt.Errorf("appointment not found")
} }
s.audit.Log(ctx, "update", "appointment", &a.ID, nil, a)
return nil return nil
} }
@@ -134,6 +131,5 @@ func (s *AppointmentService) Delete(ctx context.Context, tenantID, id uuid.UUID)
if rows == 0 { if rows == 0 {
return fmt.Errorf("appointment not found") return fmt.Errorf("appointment not found")
} }
s.audit.Log(ctx, "delete", "appointment", &id, nil, nil)
return nil return nil
} }

View File

@@ -1,141 +0,0 @@
package services
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
"mgit.msbls.de/m/KanzlAI-mGMT/internal/models"
)
type AuditService struct {
db *sqlx.DB
}
func NewAuditService(db *sqlx.DB) *AuditService {
return &AuditService{db: db}
}
// Log records an audit entry. It extracts tenant, user, IP, and user-agent from context.
// Errors are logged but not returned — audit logging must not break business operations.
func (s *AuditService) Log(ctx context.Context, action, entityType string, entityID *uuid.UUID, oldValues, newValues any) {
tenantID, ok := auth.TenantFromContext(ctx)
if !ok {
slog.Warn("audit: missing tenant_id in context", "action", action, "entity_type", entityType)
return
}
var userID *uuid.UUID
if uid, ok := auth.UserFromContext(ctx); ok {
userID = &uid
}
var oldJSON, newJSON *json.RawMessage
if oldValues != nil {
if b, err := json.Marshal(oldValues); err == nil {
raw := json.RawMessage(b)
oldJSON = &raw
}
}
if newValues != nil {
if b, err := json.Marshal(newValues); err == nil {
raw := json.RawMessage(b)
newJSON = &raw
}
}
ip := auth.IPFromContext(ctx)
ua := auth.UserAgentFromContext(ctx)
_, err := s.db.ExecContext(ctx,
`INSERT INTO audit_log (tenant_id, user_id, action, entity_type, entity_id, old_values, new_values, ip_address, user_agent)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
tenantID, userID, action, entityType, entityID, oldJSON, newJSON, ip, ua)
if err != nil {
slog.Error("audit: failed to write log entry",
"error", err,
"action", action,
"entity_type", entityType,
"entity_id", entityID,
)
}
}
// AuditFilter holds query parameters for listing audit log entries.
type AuditFilter struct {
EntityType string
EntityID *uuid.UUID
UserID *uuid.UUID
From string // RFC3339 date
To string // RFC3339 date
Page int
Limit int
}
// List returns paginated audit log entries for a tenant.
func (s *AuditService) List(ctx context.Context, tenantID uuid.UUID, filter AuditFilter) ([]models.AuditLog, int, error) {
if filter.Limit <= 0 {
filter.Limit = 50
}
if filter.Limit > 200 {
filter.Limit = 200
}
if filter.Page <= 0 {
filter.Page = 1
}
offset := (filter.Page - 1) * filter.Limit
where := "WHERE tenant_id = $1"
args := []any{tenantID}
argIdx := 2
if filter.EntityType != "" {
where += fmt.Sprintf(" AND entity_type = $%d", argIdx)
args = append(args, filter.EntityType)
argIdx++
}
if filter.EntityID != nil {
where += fmt.Sprintf(" AND entity_id = $%d", argIdx)
args = append(args, *filter.EntityID)
argIdx++
}
if filter.UserID != nil {
where += fmt.Sprintf(" AND user_id = $%d", argIdx)
args = append(args, *filter.UserID)
argIdx++
}
if filter.From != "" {
where += fmt.Sprintf(" AND created_at >= $%d", argIdx)
args = append(args, filter.From)
argIdx++
}
if filter.To != "" {
where += fmt.Sprintf(" AND created_at <= $%d", argIdx)
args = append(args, filter.To)
argIdx++
}
var total int
if err := s.db.GetContext(ctx, &total, "SELECT COUNT(*) FROM audit_log "+where, args...); err != nil {
return nil, 0, fmt.Errorf("counting audit entries: %w", err)
}
query := fmt.Sprintf("SELECT * FROM audit_log %s ORDER BY created_at DESC LIMIT $%d OFFSET $%d",
where, argIdx, argIdx+1)
args = append(args, filter.Limit, offset)
var entries []models.AuditLog
if err := s.db.SelectContext(ctx, &entries, query, args...); err != nil {
return nil, 0, fmt.Errorf("listing audit entries: %w", err)
}
if entries == nil {
entries = []models.AuditLog{}
}
return entries, total, nil
}

View File

@@ -14,11 +14,10 @@ import (
type CaseService struct { type CaseService struct {
db *sqlx.DB db *sqlx.DB
audit *AuditService
} }
func NewCaseService(db *sqlx.DB, audit *AuditService) *CaseService { func NewCaseService(db *sqlx.DB) *CaseService {
return &CaseService{db: db, audit: audit} return &CaseService{db: db}
} }
type CaseFilter struct { type CaseFilter struct {
@@ -163,9 +162,6 @@ func (s *CaseService) Create(ctx context.Context, tenantID uuid.UUID, userID uui
if err := s.db.GetContext(ctx, &c, "SELECT * FROM cases WHERE id = $1", id); err != nil { if err := s.db.GetContext(ctx, &c, "SELECT * FROM cases WHERE id = $1", id); err != nil {
return nil, fmt.Errorf("fetching created case: %w", err) return nil, fmt.Errorf("fetching created case: %w", err)
} }
s.audit.Log(ctx, "create", "case", &id, nil, c)
return &c, nil return &c, nil
} }
@@ -243,9 +239,6 @@ func (s *CaseService) Update(ctx context.Context, tenantID, caseID uuid.UUID, us
if err := s.db.GetContext(ctx, &updated, "SELECT * FROM cases WHERE id = $1", caseID); err != nil { if err := s.db.GetContext(ctx, &updated, "SELECT * FROM cases WHERE id = $1", caseID); err != nil {
return nil, fmt.Errorf("fetching updated case: %w", err) return nil, fmt.Errorf("fetching updated case: %w", err)
} }
s.audit.Log(ctx, "update", "case", &caseID, current, updated)
return &updated, nil return &updated, nil
} }
@@ -261,7 +254,6 @@ func (s *CaseService) Delete(ctx context.Context, tenantID, caseID uuid.UUID, us
return sql.ErrNoRows return sql.ErrNoRows
} }
createEvent(ctx, s.db, tenantID, caseID, userID, "case_archived", "Case archived", nil) createEvent(ctx, s.db, tenantID, caseID, userID, "case_archived", "Case archived", nil)
s.audit.Log(ctx, "delete", "case", &caseID, map[string]string{"status": "active"}, map[string]string{"status": "archived"})
return nil return nil
} }

View File

@@ -1,7 +1,6 @@
package services package services
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"time" "time"
@@ -15,12 +14,11 @@ import (
// DeadlineService handles CRUD operations for case deadlines // DeadlineService handles CRUD operations for case deadlines
type DeadlineService struct { type DeadlineService struct {
db *sqlx.DB db *sqlx.DB
audit *AuditService
} }
// NewDeadlineService creates a new deadline service // NewDeadlineService creates a new deadline service
func NewDeadlineService(db *sqlx.DB, audit *AuditService) *DeadlineService { func NewDeadlineService(db *sqlx.DB) *DeadlineService {
return &DeadlineService{db: db, audit: audit} return &DeadlineService{db: db}
} }
// ListAll returns all deadlines for a tenant, ordered by due_date // ListAll returns all deadlines for a tenant, ordered by due_date
@@ -89,7 +87,7 @@ type CreateDeadlineInput struct {
} }
// Create inserts a new deadline // Create inserts a new deadline
func (s *DeadlineService) Create(ctx context.Context, tenantID uuid.UUID, input CreateDeadlineInput) (*models.Deadline, error) { func (s *DeadlineService) Create(tenantID uuid.UUID, input CreateDeadlineInput) (*models.Deadline, error) {
id := uuid.New() id := uuid.New()
source := input.Source source := input.Source
if source == "" { if source == "" {
@@ -110,7 +108,6 @@ func (s *DeadlineService) Create(ctx context.Context, tenantID uuid.UUID, input
if err != nil { if err != nil {
return nil, fmt.Errorf("creating deadline: %w", err) return nil, fmt.Errorf("creating deadline: %w", err)
} }
s.audit.Log(ctx, "create", "deadline", &id, nil, d)
return &d, nil return &d, nil
} }
@@ -126,7 +123,7 @@ type UpdateDeadlineInput struct {
} }
// Update modifies an existing deadline // Update modifies an existing deadline
func (s *DeadlineService) Update(ctx context.Context, tenantID, deadlineID uuid.UUID, input UpdateDeadlineInput) (*models.Deadline, error) { func (s *DeadlineService) Update(tenantID, deadlineID uuid.UUID, input UpdateDeadlineInput) (*models.Deadline, error) {
// First check it exists and belongs to tenant // First check it exists and belongs to tenant
existing, err := s.GetByID(tenantID, deadlineID) existing, err := s.GetByID(tenantID, deadlineID)
if err != nil { if err != nil {
@@ -157,12 +154,11 @@ func (s *DeadlineService) Update(ctx context.Context, tenantID, deadlineID uuid.
if err != nil { if err != nil {
return nil, fmt.Errorf("updating deadline: %w", err) return nil, fmt.Errorf("updating deadline: %w", err)
} }
s.audit.Log(ctx, "update", "deadline", &deadlineID, existing, d)
return &d, nil return &d, nil
} }
// Complete marks a deadline as completed // Complete marks a deadline as completed
func (s *DeadlineService) Complete(ctx context.Context, tenantID, deadlineID uuid.UUID) (*models.Deadline, error) { func (s *DeadlineService) Complete(tenantID, deadlineID uuid.UUID) (*models.Deadline, error) {
query := `UPDATE deadlines SET query := `UPDATE deadlines SET
status = 'completed', status = 'completed',
completed_at = $1, completed_at = $1,
@@ -180,12 +176,11 @@ func (s *DeadlineService) Complete(ctx context.Context, tenantID, deadlineID uui
} }
return nil, fmt.Errorf("completing deadline: %w", err) return nil, fmt.Errorf("completing deadline: %w", err)
} }
s.audit.Log(ctx, "update", "deadline", &deadlineID, map[string]string{"status": "pending"}, map[string]string{"status": "completed"})
return &d, nil return &d, nil
} }
// Delete removes a deadline // Delete removes a deadline
func (s *DeadlineService) Delete(ctx context.Context, tenantID, deadlineID uuid.UUID) error { func (s *DeadlineService) Delete(tenantID, deadlineID uuid.UUID) error {
query := `DELETE FROM deadlines WHERE id = $1 AND tenant_id = $2` query := `DELETE FROM deadlines WHERE id = $1 AND tenant_id = $2`
result, err := s.db.Exec(query, deadlineID, tenantID) result, err := s.db.Exec(query, deadlineID, tenantID)
if err != nil { if err != nil {
@@ -198,6 +193,5 @@ func (s *DeadlineService) Delete(ctx context.Context, tenantID, deadlineID uuid.
if rows == 0 { if rows == 0 {
return fmt.Errorf("deadline not found") return fmt.Errorf("deadline not found")
} }
s.audit.Log(ctx, "delete", "deadline", &deadlineID, nil, nil)
return nil return nil
} }

View File

@@ -18,11 +18,10 @@ const documentBucket = "kanzlai-documents"
type DocumentService struct { type DocumentService struct {
db *sqlx.DB db *sqlx.DB
storage *StorageClient storage *StorageClient
audit *AuditService
} }
func NewDocumentService(db *sqlx.DB, storage *StorageClient, audit *AuditService) *DocumentService { func NewDocumentService(db *sqlx.DB, storage *StorageClient) *DocumentService {
return &DocumentService{db: db, storage: storage, audit: audit} return &DocumentService{db: db, storage: storage}
} }
type CreateDocumentInput struct { type CreateDocumentInput struct {
@@ -98,7 +97,6 @@ func (s *DocumentService) Create(ctx context.Context, tenantID, caseID, userID u
if err := s.db.GetContext(ctx, &doc, "SELECT * FROM documents WHERE id = $1", id); err != nil { if err := s.db.GetContext(ctx, &doc, "SELECT * FROM documents WHERE id = $1", id); err != nil {
return nil, fmt.Errorf("fetching created document: %w", err) return nil, fmt.Errorf("fetching created document: %w", err)
} }
s.audit.Log(ctx, "create", "document", &id, nil, doc)
return &doc, nil return &doc, nil
} }
@@ -153,7 +151,6 @@ func (s *DocumentService) Delete(ctx context.Context, tenantID, docID, userID uu
// Log case event // Log case event
createEvent(ctx, s.db, tenantID, doc.CaseID, userID, "document_deleted", createEvent(ctx, s.db, tenantID, doc.CaseID, userID, "document_deleted",
fmt.Sprintf("Document deleted: %s", doc.Title), nil) fmt.Sprintf("Document deleted: %s", doc.Title), nil)
s.audit.Log(ctx, "delete", "document", &docID, doc, nil)
return nil return nil
} }

View File

@@ -14,11 +14,10 @@ import (
type NoteService struct { type NoteService struct {
db *sqlx.DB db *sqlx.DB
audit *AuditService
} }
func NewNoteService(db *sqlx.DB, audit *AuditService) *NoteService { func NewNoteService(db *sqlx.DB) *NoteService {
return &NoteService{db: db, audit: audit} return &NoteService{db: db}
} }
// ListByParent returns all notes for a given parent entity, scoped to tenant. // ListByParent returns all notes for a given parent entity, scoped to tenant.
@@ -69,7 +68,6 @@ func (s *NoteService) Create(ctx context.Context, tenantID uuid.UUID, createdBy
if err != nil { if err != nil {
return nil, fmt.Errorf("creating note: %w", err) return nil, fmt.Errorf("creating note: %w", err)
} }
s.audit.Log(ctx, "create", "note", &id, nil, n)
return &n, nil return &n, nil
} }
@@ -87,7 +85,6 @@ func (s *NoteService) Update(ctx context.Context, tenantID, noteID uuid.UUID, co
} }
return nil, fmt.Errorf("updating note: %w", err) return nil, fmt.Errorf("updating note: %w", err)
} }
s.audit.Log(ctx, "update", "note", &noteID, nil, n)
return &n, nil return &n, nil
} }
@@ -104,7 +101,6 @@ func (s *NoteService) Delete(ctx context.Context, tenantID, noteID uuid.UUID) er
if rows == 0 { if rows == 0 {
return fmt.Errorf("note not found") return fmt.Errorf("note not found")
} }
s.audit.Log(ctx, "delete", "note", &noteID, nil, nil)
return nil return nil
} }

View File

@@ -14,11 +14,10 @@ import (
type PartyService struct { type PartyService struct {
db *sqlx.DB db *sqlx.DB
audit *AuditService
} }
func NewPartyService(db *sqlx.DB, audit *AuditService) *PartyService { func NewPartyService(db *sqlx.DB) *PartyService {
return &PartyService{db: db, audit: audit} return &PartyService{db: db}
} }
type CreatePartyInput struct { type CreatePartyInput struct {
@@ -80,7 +79,6 @@ func (s *PartyService) Create(ctx context.Context, tenantID, caseID uuid.UUID, u
if err := s.db.GetContext(ctx, &party, "SELECT * FROM parties WHERE id = $1", id); err != nil { if err := s.db.GetContext(ctx, &party, "SELECT * FROM parties WHERE id = $1", id); err != nil {
return nil, fmt.Errorf("fetching created party: %w", err) return nil, fmt.Errorf("fetching created party: %w", err)
} }
s.audit.Log(ctx, "create", "party", &id, nil, party)
return &party, nil return &party, nil
} }
@@ -137,7 +135,6 @@ func (s *PartyService) Update(ctx context.Context, tenantID, partyID uuid.UUID,
if err := s.db.GetContext(ctx, &updated, "SELECT * FROM parties WHERE id = $1", partyID); err != nil { if err := s.db.GetContext(ctx, &updated, "SELECT * FROM parties WHERE id = $1", partyID); err != nil {
return nil, fmt.Errorf("fetching updated party: %w", err) return nil, fmt.Errorf("fetching updated party: %w", err)
} }
s.audit.Log(ctx, "update", "party", &partyID, current, updated)
return &updated, nil return &updated, nil
} }
@@ -151,6 +148,5 @@ func (s *PartyService) Delete(ctx context.Context, tenantID, partyID uuid.UUID)
if rows == 0 { if rows == 0 {
return sql.ErrNoRows return sql.ErrNoRows
} }
s.audit.Log(ctx, "delete", "party", &partyID, nil, nil)
return nil return nil
} }

View File

@@ -14,11 +14,10 @@ import (
type TenantService struct { type TenantService struct {
db *sqlx.DB db *sqlx.DB
audit *AuditService
} }
func NewTenantService(db *sqlx.DB, audit *AuditService) *TenantService { func NewTenantService(db *sqlx.DB) *TenantService {
return &TenantService{db: db, audit: audit} return &TenantService{db: db}
} }
// Create creates a new tenant and assigns the creator as owner. // Create creates a new tenant and assigns the creator as owner.
@@ -50,7 +49,6 @@ func (s *TenantService) Create(ctx context.Context, userID uuid.UUID, name, slug
return nil, fmt.Errorf("commit: %w", err) return nil, fmt.Errorf("commit: %w", err)
} }
s.audit.Log(ctx, "create", "tenant", &tenant.ID, nil, tenant)
return &tenant, nil return &tenant, nil
} }
@@ -103,6 +101,19 @@ func (s *TenantService) GetUserRole(ctx context.Context, userID, tenantID uuid.U
return role, nil return role, nil
} }
// VerifyAccess checks if a user has access to a given tenant.
func (s *TenantService) VerifyAccess(ctx context.Context, userID, tenantID uuid.UUID) (bool, error) {
var exists bool
err := s.db.GetContext(ctx, &exists,
`SELECT EXISTS(SELECT 1 FROM user_tenants WHERE user_id = $1 AND tenant_id = $2)`,
userID, tenantID,
)
if err != nil {
return false, fmt.Errorf("verify tenant access: %w", err)
}
return exists, nil
}
// FirstTenantForUser returns the user's first tenant (by name), used as default. // FirstTenantForUser returns the user's first tenant (by name), used as default.
func (s *TenantService) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) { func (s *TenantService) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) {
var tenantID uuid.UUID var tenantID uuid.UUID
@@ -173,7 +184,6 @@ func (s *TenantService) InviteByEmail(ctx context.Context, tenantID uuid.UUID, e
return nil, fmt.Errorf("invite user: %w", err) return nil, fmt.Errorf("invite user: %w", err)
} }
s.audit.Log(ctx, "create", "membership", &tenantID, nil, ut)
return &ut, nil return &ut, nil
} }
@@ -189,7 +199,6 @@ func (s *TenantService) UpdateSettings(ctx context.Context, tenantID uuid.UUID,
if err != nil { if err != nil {
return nil, fmt.Errorf("update settings: %w", err) return nil, fmt.Errorf("update settings: %w", err)
} }
s.audit.Log(ctx, "update", "settings", &tenantID, nil, settings)
return &tenant, nil return &tenant, nil
} }
@@ -227,6 +236,5 @@ func (s *TenantService) RemoveMember(ctx context.Context, tenantID, userID uuid.
return fmt.Errorf("remove member: %w", err) return fmt.Errorf("remove member: %w", err)
} }
s.audit.Log(ctx, "delete", "membership", &tenantID, map[string]any{"user_id": userID, "role": role}, nil)
return nil return nil
} }

View File

@@ -15,7 +15,6 @@ import {
Users, Users,
StickyNote, StickyNote,
AlertTriangle, AlertTriangle,
ScrollText,
} from "lucide-react"; } from "lucide-react";
import { format } from "date-fns"; import { format } from "date-fns";
import { de } from "date-fns/locale"; import { de } from "date-fns/locale";
@@ -45,7 +44,6 @@ const TABS = [
{ segment: "dokumente", label: "Dokumente", icon: FileText }, { segment: "dokumente", label: "Dokumente", icon: FileText },
{ segment: "parteien", label: "Parteien", icon: Users }, { segment: "parteien", label: "Parteien", icon: Users },
{ segment: "notizen", label: "Notizen", icon: StickyNote }, { segment: "notizen", label: "Notizen", icon: StickyNote },
{ segment: "protokoll", label: "Protokoll", icon: ScrollText },
] as const; ] as const;
const TAB_LABELS: Record<string, string> = { const TAB_LABELS: Record<string, string> = {
@@ -54,7 +52,6 @@ const TAB_LABELS: Record<string, string> = {
dokumente: "Dokumente", dokumente: "Dokumente",
parteien: "Parteien", parteien: "Parteien",
notizen: "Notizen", notizen: "Notizen",
protokoll: "Protokoll",
}; };
function CaseDetailSkeleton() { function CaseDetailSkeleton() {

View File

@@ -1,178 +0,0 @@
"use client";
import { useQuery } from "@tanstack/react-query";
import { useParams, useSearchParams } from "next/navigation";
import { api } from "@/lib/api";
import type { AuditLogResponse } from "@/lib/types";
import { format } from "date-fns";
import { de } from "date-fns/locale";
import { Loader2, ChevronLeft, ChevronRight } from "lucide-react";
const ACTION_LABELS: Record<string, string> = {
create: "Erstellt",
update: "Aktualisiert",
delete: "Geloescht",
};
const ACTION_COLORS: Record<string, string> = {
create: "bg-emerald-50 text-emerald-700",
update: "bg-blue-50 text-blue-700",
delete: "bg-red-50 text-red-700",
};
const ENTITY_LABELS: Record<string, string> = {
case: "Akte",
deadline: "Frist",
appointment: "Termin",
document: "Dokument",
party: "Partei",
note: "Notiz",
settings: "Einstellungen",
membership: "Mitgliedschaft",
};
function DiffPreview({
oldValues,
newValues,
}: {
oldValues?: Record<string, unknown>;
newValues?: Record<string, unknown>;
}) {
if (!oldValues && !newValues) return null;
const allKeys = new Set([
...Object.keys(oldValues ?? {}),
...Object.keys(newValues ?? {}),
]);
const changes: { key: string; from?: unknown; to?: unknown }[] = [];
for (const key of allKeys) {
const oldVal = oldValues?.[key];
const newVal = newValues?.[key];
if (JSON.stringify(oldVal) !== JSON.stringify(newVal)) {
changes.push({ key, from: oldVal, to: newVal });
}
}
if (changes.length === 0) return null;
return (
<div className="mt-2 space-y-1">
{changes.slice(0, 5).map((c) => (
<div key={c.key} className="flex items-baseline gap-2 text-xs">
<span className="font-medium text-neutral-500">{c.key}:</span>
{c.from !== undefined && (
<span className="rounded bg-red-50 px-1 text-red-600 line-through">
{String(c.from)}
</span>
)}
{c.to !== undefined && (
<span className="rounded bg-emerald-50 px-1 text-emerald-600">
{String(c.to)}
</span>
)}
</div>
))}
{changes.length > 5 && (
<span className="text-xs text-neutral-400">
+{changes.length - 5} weitere Aenderungen
</span>
)}
</div>
);
}
export default function ProtokollPage() {
const { id } = useParams<{ id: string }>();
const searchParams = useSearchParams();
const page = Number(searchParams.get("page")) || 1;
const { data, isLoading } = useQuery({
queryKey: ["audit-log", id, page],
queryFn: () =>
api.get<AuditLogResponse>(
`/audit-log?entity_id=${id}&page=${page}&limit=50`,
),
});
if (isLoading) {
return (
<div className="flex items-center justify-center py-8">
<Loader2 className="h-5 w-5 animate-spin text-neutral-400" />
</div>
);
}
const entries = data?.entries ?? [];
const total = data?.total ?? 0;
const totalPages = Math.ceil(total / 50);
if (entries.length === 0) {
return (
<div className="py-8 text-center text-sm text-neutral-400">
Keine Protokolleintraege vorhanden.
</div>
);
}
return (
<div>
<div className="space-y-3">
{entries.map((entry) => (
<div
key={entry.id}
className="rounded-md border border-neutral-100 bg-white px-4 py-3"
>
<div className="flex items-start justify-between gap-3">
<div className="flex items-center gap-2">
<span
className={`inline-block rounded-full px-2 py-0.5 text-xs font-medium ${ACTION_COLORS[entry.action] ?? "bg-neutral-100 text-neutral-600"}`}
>
{ACTION_LABELS[entry.action] ?? entry.action}
</span>
<span className="text-sm font-medium text-neutral-700">
{ENTITY_LABELS[entry.entity_type] ?? entry.entity_type}
</span>
</div>
<span className="shrink-0 text-xs text-neutral-400">
{format(new Date(entry.created_at), "d. MMM yyyy, HH:mm", {
locale: de,
})}
</span>
</div>
<DiffPreview
oldValues={entry.old_values}
newValues={entry.new_values}
/>
</div>
))}
</div>
{totalPages > 1 && (
<div className="mt-4 flex items-center justify-between">
<span className="text-xs text-neutral-400">
{total} Eintraege, Seite {page} von {totalPages}
</span>
<div className="flex gap-1">
{page > 1 && (
<a
href={`?page=${page - 1}`}
className="inline-flex items-center gap-1 rounded-md border border-neutral-200 px-2 py-1 text-xs text-neutral-600 hover:bg-neutral-50"
>
<ChevronLeft className="h-3 w-3" /> Zurueck
</a>
)}
{page < totalPages && (
<a
href={`?page=${page + 1}`}
className="inline-flex items-center gap-1 rounded-md border border-neutral-200 px-2 py-1 text-xs text-neutral-600 hover:bg-neutral-50"
>
Weiter <ChevronRight className="h-3 w-3" />
</a>
)}
</div>
</div>
)}
</div>
);
}

View File

@@ -189,27 +189,6 @@ export interface Note {
updated_at: string; updated_at: string;
} }
export interface AuditLogEntry {
id: number;
tenant_id: string;
user_id?: string;
action: string;
entity_type: string;
entity_id?: string;
old_values?: Record<string, unknown>;
new_values?: Record<string, unknown>;
ip_address?: string;
user_agent?: string;
created_at: string;
}
export interface AuditLogResponse {
entries: AuditLogEntry[];
total: number;
page: number;
limit: number;
}
export interface ApiError { export interface ApiError {
error: string; error: string;
status: number; status: number;