package handlers import ( "encoding/json" "net/http" "github.com/google/uuid" "github.com/jmoiron/sqlx" "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" ) func writeJSON(w http.ResponseWriter, status int, v any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) json.NewEncoder(w).Encode(v) } func writeError(w http.ResponseWriter, status int, msg string) { writeJSON(w, status, map[string]string{"error": msg}) } // resolveTenant gets the tenant ID for the authenticated user. // Checks X-Tenant-ID header first, then falls back to user's first tenant. func resolveTenant(r *http.Request, db *sqlx.DB) (uuid.UUID, error) { userID, ok := auth.UserFromContext(r.Context()) if !ok { return uuid.Nil, errUnauthorized } // Check header first if headerVal := r.Header.Get("X-Tenant-ID"); headerVal != "" { tenantID, err := uuid.Parse(headerVal) if err != nil { return uuid.Nil, errInvalidTenant } // Verify user has access to this tenant var count int err = db.Get(&count, `SELECT COUNT(*) FROM user_tenants WHERE user_id = $1 AND tenant_id = $2`, userID, tenantID) if err != nil || count == 0 { return uuid.Nil, errTenantAccess } return tenantID, nil } // Fall back to user's first tenant var tenantID uuid.UUID err := db.Get(&tenantID, `SELECT tenant_id FROM user_tenants WHERE user_id = $1 ORDER BY created_at LIMIT 1`, userID) if err != nil { return uuid.Nil, errNoTenant } return tenantID, nil } type apiError struct { msg string status int } func (e *apiError) Error() string { return e.msg } var ( errUnauthorized = &apiError{msg: "unauthorized", status: http.StatusUnauthorized} errInvalidTenant = &apiError{msg: "invalid tenant ID", status: http.StatusBadRequest} errTenantAccess = &apiError{msg: "no access to tenant", status: http.StatusForbidden} errNoTenant = &apiError{msg: "no tenant found for user", status: http.StatusBadRequest} ) // handleTenantError writes the appropriate error response for tenant resolution errors func handleTenantError(w http.ResponseWriter, err error) { if ae, ok := err.(*apiError); ok { writeError(w, ae.status, ae.msg) return } writeError(w, http.StatusInternalServerError, "internal error") } // parsePathUUID extracts a UUID from the URL path using PathValue func parsePathUUID(r *http.Request, key string) (uuid.UUID, error) { return uuid.Parse(r.PathValue(key)) } // parseUUID parses a UUID string func parseUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) }