package middleware

import (
	"net/http"
	"strings"
)

// SecurityHeaders adds standard security headers to all responses.
//
// The headers are set before next runs, so next may override any of them.
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-Content-Type-Options", "nosniff")
		// "1; mode=block" turns on the legacy XSS auditor, which modern
		// browsers have removed and which could be abused in older ones.
		// Current guidance is to disable it explicitly and rely on
		// Content-Security-Policy instead.
		h.Set("X-XSS-Protection", "0")
		h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		next.ServeHTTP(w, r)
	})
}

// CORS returns middleware that restricts cross-origin requests to the given origin.
//
// If allowedOrigin is empty, CORS headers are never set (same-origin only).
// Otherwise, a request whose Origin header matches allowedOrigin (see
// matchOrigin) receives Access-Control-* headers. Access-Control-Allow-Origin
// echoes the request's own Origin value, because browsers compare that header
// byte-for-byte against the origin they sent. "Vary: Origin" is added to every
// response while allowedOrigin is configured, because the response then
// depends on the request's Origin header; this keeps shared caches from
// replaying a response without CORS headers to the allowed origin.
//
// CORS preflight requests (OPTIONS carrying both Origin and
// Access-Control-Request-Method headers) are answered with 204 No Content and
// are not passed to next. All other requests, including non-preflight OPTIONS
// requests, are passed to next.
func CORS(allowedOrigin string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			origin := r.Header.Get("Origin")
			if allowedOrigin != "" {
				h := w.Header()
				// Add, not Set: preserve Vary values set by other handlers.
				h.Add("Vary", "Origin")
				if origin != "" && matchOrigin(origin, allowedOrigin) {
					h.Set("Access-Control-Allow-Origin", origin)
					h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
					h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Tenant-ID")
					h.Set("Access-Control-Max-Age", "86400")
				}
			}

			// Only short-circuit genuine preflights; a disallowed origin gets a
			// 204 without CORS headers, which the browser treats as a failure.
			isPreflight := r.Method == http.MethodOptions &&
				origin != "" &&
				r.Header.Get("Access-Control-Request-Method") != ""
			if isPreflight {
				w.WriteHeader(http.StatusNoContent)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// matchOrigin reports whether the request origin matches the allowed origin,
// comparing case-insensitively and ignoring any trailing slashes on either side.
func matchOrigin(origin, allowed string) bool {
	return strings.EqualFold(strings.TrimRight(origin, "/"), strings.TrimRight(allowed, "/"))
}