Files
KanzlAI-mGMT/backend/internal/auth/middleware.go
m f11c411147 feat: add case + party CRUD with case events (Phase 1B)
- CaseService: list (paginated, filterable), get detail (with parties,
  events, deadline count), create, update, soft-delete (archive)
- PartyService: list by case, create, update, delete
- Auto-create case_events on case creation, status change, party add,
  and case archive
- Auth middleware now resolves tenant_id from user_tenants table
- All operations scoped to tenant_id from auth context
2026-03-25 13:26:50 +01:00

103 lines
2.4 KiB
Go

package auth
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
)
// Middleware authenticates HTTP requests using HMAC-signed JWTs and resolves
// the authenticated user's tenant from the database.
type Middleware struct {
	// jwtSecret is the shared HMAC key used to verify token signatures.
	jwtSecret []byte
	// db is used to look up the user's tenant in the user_tenants table.
	db *sqlx.DB
}
// NewMiddleware returns a Middleware that verifies tokens with the HMAC key
// jwtSecret (stored as raw bytes) and resolves tenants through db.
func NewMiddleware(jwtSecret string, db *sqlx.DB) *Middleware {
	mw := new(Middleware)
	mw.jwtSecret = []byte(jwtSecret)
	mw.db = db
	return mw
}
// RequireAuth wraps next so it only runs for requests that meet two
// conditions: they carry a valid "Authorization: Bearer <JWT>" header, and
// the token's user is linked to a tenant in user_tenants.
//
// On success, the user ID returned by verifyJWT and the resolved tenant ID
// are stored in the request context via ContextWithUserID and
// ContextWithTenantID, and then next is invoked. Otherwise the request is
// rejected with one of these statuses:
//   - 401 when the token is missing or fails verification;
//   - 403 when the user has no row in user_tenants;
//   - 500 when the tenant lookup itself fails.
func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token := extractBearerToken(r)
		if token == "" {
			http.Error(w, "missing authorization token", http.StatusUnauthorized)
			return
		}

		userID, err := m.verifyJWT(token)
		if err != nil {
			http.Error(w, fmt.Sprintf("invalid token: %v", err), http.StatusUnauthorized)
			return
		}

		// Select into a slice instead of Get. An empty result means the user
		// has no tenant (403), while a non-nil error means the query itself
		// failed (500). A single-row Get reports both cases as an error.
		// NOTE(review): LIMIT 1 without ORDER BY picks an arbitrary tenant
		// when a user belongs to several; confirm that is intended.
		var tenantIDs []uuid.UUID
		err = m.db.SelectContext(r.Context(), &tenantIDs,
			"SELECT tenant_id FROM user_tenants WHERE user_id = $1 LIMIT 1", userID)
		if err != nil {
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}
		if len(tenantIDs) == 0 {
			http.Error(w, "no tenant found for user", http.StatusForbidden)
			return
		}

		ctx := ContextWithUserID(r.Context(), userID)
		ctx = ContextWithTenantID(ctx, tenantIDs[0])
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// verifyJWT validates tokenStr as an HMAC-signed JWT keyed with m.jwtSecret
// and returns the UUID carried in its "sub" claim.
//
// An error is returned in any of these cases:
//   - the token does not parse or verify;
//   - it is signed with a non-HMAC algorithm;
//   - its numeric "exp" claim is in the past;
//   - "sub" is missing, is not a string, or is not a valid UUID.
func (m *Middleware) verifyJWT(tokenStr string) (uuid.UUID, error) {
	keyFunc := func(candidate *jwt.Token) (interface{}, error) {
		// Only accept HMAC algorithms so the header cannot switch
		// verification to an unexpected scheme.
		if _, isHMAC := candidate.Method.(*jwt.SigningMethodHMAC); isHMAC {
			return m.jwtSecret, nil
		}
		return nil, fmt.Errorf("unexpected signing method: %v", candidate.Header["alg"])
	}

	token, err := jwt.Parse(tokenStr, keyFunc)
	switch {
	case err != nil:
		return uuid.Nil, fmt.Errorf("parsing JWT: %w", err)
	case !token.Valid:
		return uuid.Nil, fmt.Errorf("invalid JWT token")
	}

	claims, isMap := token.Claims.(jwt.MapClaims)
	if !isMap {
		return uuid.Nil, fmt.Errorf("extracting JWT claims")
	}

	// golang-jwt v5 already validates "exp" during Parse by default; this
	// explicit check is a defensive duplicate.
	if expiry, present := claims["exp"].(float64); present && int64(expiry) < time.Now().Unix() {
		return uuid.Nil, fmt.Errorf("JWT token has expired")
	}

	subject, isString := claims["sub"].(string)
	if !isString {
		return uuid.Nil, fmt.Errorf("missing sub claim in JWT")
	}

	userID, err := uuid.Parse(subject)
	if err != nil {
		return uuid.Nil, fmt.Errorf("invalid user ID format: %w", err)
	}
	return userID, nil
}
// extractBearerToken returns the token from an "Authorization: Bearer <token>"
// request header.
//
// The scheme name is matched case-insensitively. Whitespace around the token
// is trimmed, because RFC 6750 permits one or more spaces after "Bearer".
// The empty string is returned when the header is absent, uses a different
// scheme, or carries no token.
func extractBearerToken(r *http.Request) string {
	scheme, token, found := strings.Cut(r.Header.Get("Authorization"), " ")
	if !found || !strings.EqualFold(scheme, "bearer") {
		return ""
	}
	return strings.TrimSpace(token)
}