package auth

import (
	"errors"
	"fmt"
	"net"
	"net/http"
	"strings"

	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/jmoiron/sqlx"
)

// Middleware authenticates HTTP requests carrying an HMAC-signed JWT in the
// Authorization header.
type Middleware struct {
	// jwtSecret is the HMAC key used to verify token signatures.
	jwtSecret []byte
	// db is not referenced by the methods in this block; presumably used by
	// other code in the package — TODO confirm.
	db *sqlx.DB
}

// NewMiddleware returns a Middleware that verifies tokens with jwtSecret.
// An empty jwtSecret causes every token to be rejected (see verifyJWT).
func NewMiddleware(jwtSecret string, db *sqlx.DB) *Middleware {
	return &Middleware{jwtSecret: []byte(jwtSecret), db: db}
}

// RequireAuth wraps next so that it only runs for requests with a valid
// Bearer JWT. On success the authenticated user ID and the client's address
// and user agent are stored in the request context before calling next.
// Otherwise it responds 401 with a JSON error body and does not call next.
func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token := extractBearerToken(r)
		if token == "" {
			writeUnauthorizedJSON(w, `Bearer`, `{"error":"missing authorization token"}`)
			return
		}

		userID, err := m.verifyJWT(token)
		if err != nil {
			// The verification error is deliberately not echoed to the client.
			writeUnauthorizedJSON(w, `Bearer error="invalid_token"`, `{"error":"invalid token"}`)
			return
		}

		ctx := ContextWithUserID(r.Context(), userID)
		// Capture IP and user-agent for audit logging.
		ctx = ContextWithRequestInfo(ctx, auditClientIP(r), r.UserAgent())
		// Tenant and role resolution handled by TenantResolver middleware for scoped routes.
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// writeUnauthorizedJSON replies 401 Unauthorized with the given
// WWW-Authenticate challenge and JSON body (followed by a newline).
func writeUnauthorizedJSON(w http.ResponseWriter, challenge, body string) {
	h := w.Header()
	h.Del("Content-Length")
	h.Set("Content-Type", "application/json; charset=utf-8")
	h.Set("X-Content-Type-Options", "nosniff")
	h.Set("WWW-Authenticate", challenge)
	w.WriteHeader(http.StatusUnauthorized)
	// The status is already sent; a failed body write has nowhere to be reported.
	_, _ = fmt.Fprintln(w, body)
}

// auditClientIP returns the client address recorded for audit logging: the
// first entry of X-Forwarded-For when present, otherwise the host part of
// r.RemoteAddr (or RemoteAddr unchanged if it has no port).
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// reverse proxy overwrites it, so this value is spoofable when the service
// is reachable directly — confirm the deployment topology.
func auditClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The header may hold a chain "client, proxy1, proxy2"; keep the first hop.
		first, _, _ := strings.Cut(xff, ",")
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// verifyJWT checks tokenStr's HMAC signature and standard claims and returns
// the user ID parsed from its "sub" claim, which must be a UUID string.
func (m *Middleware) verifyJWT(tokenStr string) (uuid.UUID, error) {
	// Fail closed: with an empty key anyone could mint a "valid" HMAC token.
	if len(m.jwtSecret) == 0 {
		return uuid.Nil, errors.New("JWT secret is not configured")
	}

	// jwt.Parse also validates time-based claims such as "exp" when present.
	parsedToken, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
		// Reject non-HMAC algorithms so a token cannot choose how it is verified.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return m.jwtSecret, nil
	})
	if err != nil {
		return uuid.Nil, fmt.Errorf("parsing JWT: %w", err)
	}
	if !parsedToken.Valid {
		return uuid.Nil, errors.New("invalid JWT token")
	}

	claims, ok := parsedToken.Claims.(jwt.MapClaims)
	if !ok {
		return uuid.Nil, errors.New("extracting JWT claims")
	}

	sub, ok := claims["sub"].(string)
	if !ok {
		return uuid.Nil, errors.New("missing sub claim in JWT")
	}

	userID, err := uuid.Parse(sub)
	if err != nil {
		return uuid.Nil, fmt.Errorf("invalid user ID format: %w", err)
	}
	return userID, nil
}

// extractBearerToken returns the token from an "Authorization: Bearer <token>"
// header, matching the scheme case-insensitively and tolerating extra
// whitespace around the token. It returns "" when the header is absent, uses
// another scheme, or carries no token.
func extractBearerToken(r *http.Request) string {
	scheme, token, ok := strings.Cut(r.Header.Get("Authorization"), " ")
	if !ok || !strings.EqualFold(scheme, "bearer") {
		return ""
	}
	return strings.TrimSpace(token)
}