From 0b6bab8512c98de86a32bd3f681338501758ab26 Mon Sep 17 00:00:00 2001 From: m Date: Wed, 25 Mar 2026 13:27:39 +0100 Subject: [PATCH] feat: add tenant + auth backend endpoints (Phase 1A) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tenant management: - POST /api/tenants — create tenant (creator becomes owner) - GET /api/tenants — list tenants for authenticated user - GET /api/tenants/:id — tenant details with access check - POST /api/tenants/:id/invite — invite user by email (owner/admin) - DELETE /api/tenants/:id/members/:uid — remove member - GET /api/tenants/:id/members — list members New packages: - internal/services/tenant_service.go — CRUD on tenants + user_tenants - internal/handlers/tenant_handler.go — HTTP handlers with auth checks - internal/auth/tenant_resolver.go — X-Tenant-ID header middleware, defaults to user's first tenant for scoped routes Authorization: owners/admins can invite and remove members. Cannot remove the last owner. Users can remove themselves. TenantResolver applies to resource routes (cases, deadlines, etc.) but not tenant management routes. --- backend/internal/auth/tenant_resolver.go | 61 +++++ backend/internal/auth/tenant_resolver_test.go | 124 +++++++++ backend/internal/handlers/tenant_handler.go | 243 ++++++++++++++++++ .../internal/handlers/tenant_handler_test.go | 132 ++++++++++ backend/internal/models/tenant.go | 6 + backend/internal/router/router.go | 37 ++- backend/internal/services/tenant_service.go | 211 +++++++++++++++ 7 files changed, 808 insertions(+), 6 deletions(-) create mode 100644 backend/internal/auth/tenant_resolver.go create mode 100644 backend/internal/auth/tenant_resolver_test.go create mode 100644 backend/internal/handlers/tenant_handler.go create mode 100644 backend/internal/handlers/tenant_handler_test.go create mode 100644 backend/internal/services/tenant_service.go diff --git a/backend/internal/auth/tenant_resolver.go b/backend/internal/auth/tenant_resolver.go new file mode 100644 index 0000000..6358d4d --- /dev/null +++ b/backend/internal/auth/tenant_resolver.go @@ -0,0 +1,61 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + + "github.com/google/uuid" +) + +// TenantLookup resolves the default tenant for a user. +// Defined as an interface to avoid circular dependency with services. +type TenantLookup interface { + FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) +} + +// TenantResolver is middleware that resolves the tenant from X-Tenant-ID header +// or defaults to the user's first tenant. +type TenantResolver struct { + lookup TenantLookup +} + +func NewTenantResolver(lookup TenantLookup) *TenantResolver { + return &TenantResolver{lookup: lookup} +} + +func (tr *TenantResolver) Resolve(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userID, ok := UserFromContext(r.Context()) + if !ok { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var tenantID uuid.UUID + + if header := r.Header.Get("X-Tenant-ID"); header != "" { + parsed, err := uuid.Parse(header) + if err != nil { + http.Error(w, fmt.Sprintf("invalid X-Tenant-ID: %v", err), http.StatusBadRequest) + return + } + tenantID = parsed + } else { + // Default to user's first tenant + first, err := tr.lookup.FirstTenantForUser(r.Context(), userID) + if err != nil { + http.Error(w, fmt.Sprintf("resolving tenant: %v", err), http.StatusInternalServerError) + return + } + if first == nil { + http.Error(w, "no tenant found for user", http.StatusBadRequest) + return + } + tenantID = *first + } + + ctx := ContextWithTenantID(r.Context(), tenantID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/backend/internal/auth/tenant_resolver_test.go b/backend/internal/auth/tenant_resolver_test.go new file mode 100644 index 0000000..dfb8e2d --- /dev/null +++ b/backend/internal/auth/tenant_resolver_test.go @@ -0,0 +1,124 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" +) + +type mockTenantLookup struct { + tenantID *uuid.UUID + err error +} + +func (m *mockTenantLookup) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) { + return m.tenantID, m.err +} + +func TestTenantResolver_FromHeader(t *testing.T) { + tenantID := uuid.New() + tr := NewTenantResolver(&mockTenantLookup{}) + + var gotTenantID uuid.UUID + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id, ok := TenantFromContext(r.Context()) + if !ok { + t.Fatal("tenant ID not in context") + } + gotTenantID = id + w.WriteHeader(http.StatusOK) + }) + + r := httptest.NewRequest("GET", "/api/cases", nil) + r.Header.Set("X-Tenant-ID", tenantID.String()) + r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + tr.Resolve(next).ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if gotTenantID != tenantID { + t.Errorf("expected tenant %s, got %s", tenantID, gotTenantID) + } +} + +func TestTenantResolver_DefaultsToFirst(t *testing.T) { + tenantID := uuid.New() + tr := NewTenantResolver(&mockTenantLookup{tenantID: &tenantID}) + + var gotTenantID uuid.UUID + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id, _ := TenantFromContext(r.Context()) + gotTenantID = id + w.WriteHeader(http.StatusOK) + }) + + r := httptest.NewRequest("GET", "/api/cases", nil) + r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + tr.Resolve(next).ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if gotTenantID != tenantID { + t.Errorf("expected tenant %s, got %s", tenantID, gotTenantID) + } +} + +func TestTenantResolver_NoUser(t *testing.T) { + tr := NewTenantResolver(&mockTenantLookup{}) + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("next should not be called") + }) + + r := httptest.NewRequest("GET", "/api/cases", nil) + w := httptest.NewRecorder() + + tr.Resolve(next).ServeHTTP(w, r) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestTenantResolver_InvalidHeader(t *testing.T) { + tr := NewTenantResolver(&mockTenantLookup{}) + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("next should not be called") + }) + + r := httptest.NewRequest("GET", "/api/cases", nil) + r.Header.Set("X-Tenant-ID", "not-a-uuid") + r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + tr.Resolve(next).ServeHTTP(w, r) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestTenantResolver_NoTenantForUser(t *testing.T) { + tr := NewTenantResolver(&mockTenantLookup{tenantID: nil}) + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("next should not be called") + }) + + r := httptest.NewRequest("GET", "/api/cases", nil) + r = r.WithContext(ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + tr.Resolve(next).ServeHTTP(w, r) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} diff --git a/backend/internal/handlers/tenant_handler.go b/backend/internal/handlers/tenant_handler.go new file mode 100644 index 0000000..3351d36 --- /dev/null +++ b/backend/internal/handlers/tenant_handler.go @@ -0,0 +1,243 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "github.com/google/uuid" + + "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" + "mgit.msbls.de/m/KanzlAI-mGMT/internal/services" +) + +type TenantHandler struct { + svc *services.TenantService +} + +func NewTenantHandler(svc *services.TenantService) *TenantHandler { + return &TenantHandler{svc: svc} +} + +// CreateTenant handles POST /api/tenants +func (h *TenantHandler) CreateTenant(w http.ResponseWriter, r *http.Request) { + userID, ok := auth.UserFromContext(r.Context()) + if !ok { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var req struct { + Name string `json:"name"` + Slug string `json:"slug"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonError(w, "invalid request body", http.StatusBadRequest) + return + } + if req.Name == "" || req.Slug == "" { + jsonError(w, "name and slug are required", http.StatusBadRequest) + return + } + + tenant, err := h.svc.Create(r.Context(), userID, req.Name, req.Slug) + if err != nil { + jsonError(w, err.Error(), http.StatusInternalServerError) + return + } + + jsonResponse(w, tenant, http.StatusCreated) +} + +// ListTenants handles GET /api/tenants +func (h *TenantHandler) ListTenants(w http.ResponseWriter, r *http.Request) { + userID, ok := auth.UserFromContext(r.Context()) + if !ok { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + tenants, err := h.svc.ListForUser(r.Context(), userID) + if err != nil { + jsonError(w, err.Error(), http.StatusInternalServerError) + return + } + + jsonResponse(w, tenants, http.StatusOK) +} + +// GetTenant handles GET /api/tenants/{id} +func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) { + userID, ok := auth.UserFromContext(r.Context()) + if !ok { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + tenantID, err := uuid.Parse(r.PathValue("id")) + if err != nil { + jsonError(w, "invalid tenant ID", http.StatusBadRequest) + return + } + + // Verify user has access to this tenant + role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) + if err != nil { + jsonError(w, err.Error(), http.StatusInternalServerError) + return + } + if role == "" { + jsonError(w, "not found", http.StatusNotFound) + return + } + + tenant, err := h.svc.GetByID(r.Context(), tenantID) + if err != nil { + jsonError(w, err.Error(), http.StatusInternalServerError) + return + } + if tenant == nil { + jsonError(w, "not found", http.StatusNotFound) + return + } + + jsonResponse(w, tenant, http.StatusOK) +} + +// InviteUser handles POST /api/tenants/{id}/invite +func (h *TenantHandler) InviteUser(w http.ResponseWriter, r *http.Request) { + userID, ok := auth.UserFromContext(r.Context()) + if !ok { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + tenantID, err := uuid.Parse(r.PathValue("id")) + if err != nil { + jsonError(w, "invalid tenant ID", http.StatusBadRequest) + return + } + + // Only owners and admins can invite + role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) + if err != nil { + jsonError(w, err.Error(), http.StatusInternalServerError) + return + } + if role != "owner" && role != "admin" { + jsonError(w, "only owners and admins can invite users", http.StatusForbidden) + return + } + + var req struct { + Email string `json:"email"` + Role string `json:"role"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonError(w, "invalid request body", http.StatusBadRequest) + return + } + if req.Email == "" { + jsonError(w, "email is required", http.StatusBadRequest) + return + } + if req.Role == "" { + req.Role = "member" + } + if req.Role != "member" && req.Role != "admin" { + jsonError(w, "role must be member or admin", http.StatusBadRequest) + return + } + + ut, err := h.svc.InviteByEmail(r.Context(), tenantID, req.Email, req.Role) + if err != nil { + jsonError(w, err.Error(), http.StatusBadRequest) + return + } + + jsonResponse(w, ut, http.StatusCreated) +} + +// RemoveMember handles DELETE /api/tenants/{id}/members/{uid} +func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { + userID, ok := auth.UserFromContext(r.Context()) + if !ok { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + tenantID, err := uuid.Parse(r.PathValue("id")) + if err != nil { + jsonError(w, "invalid tenant ID", http.StatusBadRequest) + return + } + + memberID, err := uuid.Parse(r.PathValue("uid")) + if err != nil { + jsonError(w, "invalid member ID", http.StatusBadRequest) + return + } + + // Only owners and admins can remove members (or user removing themselves) + role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) + if err != nil { + jsonError(w, err.Error(), http.StatusInternalServerError) + return + } + if role != "owner" && role != "admin" && userID != memberID { + jsonError(w, "insufficient permissions", http.StatusForbidden) + return + } + + if err := h.svc.RemoveMember(r.Context(), tenantID, memberID); err != nil { + jsonError(w, err.Error(), http.StatusBadRequest) + return + } + + jsonResponse(w, map[string]string{"status": "removed"}, http.StatusOK) +} + +// ListMembers handles GET /api/tenants/{id}/members +func (h *TenantHandler) ListMembers(w http.ResponseWriter, r *http.Request) { + userID, ok := auth.UserFromContext(r.Context()) + if !ok { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + tenantID, err := uuid.Parse(r.PathValue("id")) + if err != nil { + jsonError(w, "invalid tenant ID", http.StatusBadRequest) + return + } + + // Verify user has access + role, err := h.svc.GetUserRole(r.Context(), userID, tenantID) + if err != nil { + jsonError(w, err.Error(), http.StatusInternalServerError) + return + } + if role == "" { + jsonError(w, "not found", http.StatusNotFound) + return + } + + members, err := h.svc.ListMembers(r.Context(), tenantID) + if err != nil { + jsonError(w, err.Error(), http.StatusInternalServerError) + return + } + + jsonResponse(w, members, http.StatusOK) +} + +func jsonResponse(w http.ResponseWriter, data interface{}, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(data) +} + +func jsonError(w http.ResponseWriter, msg string, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(map[string]string{"error": msg}) +} diff --git a/backend/internal/handlers/tenant_handler_test.go b/backend/internal/handlers/tenant_handler_test.go new file mode 100644 index 0000000..7461bea --- /dev/null +++ b/backend/internal/handlers/tenant_handler_test.go @@ -0,0 +1,132 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + + "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" +) + +func TestCreateTenant_MissingFields(t *testing.T) { + h := &TenantHandler{} // no service needed for validation + + // Build request with auth context + body := `{"name":"","slug":""}` + r := httptest.NewRequest("POST", "/api/tenants", bytes.NewBufferString(body)) + r = r.WithContext(auth.ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + h.CreateTenant(w, r) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + if resp["error"] != "name and slug are required" { + t.Errorf("unexpected error: %s", resp["error"]) + } +} + +func TestCreateTenant_NoAuth(t *testing.T) { + h := &TenantHandler{} + r := httptest.NewRequest("POST", "/api/tenants", bytes.NewBufferString(`{}`)) + w := httptest.NewRecorder() + + h.CreateTenant(w, r) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestGetTenant_InvalidID(t *testing.T) { + h := &TenantHandler{} + r := httptest.NewRequest("GET", "/api/tenants/not-a-uuid", nil) + r.SetPathValue("id", "not-a-uuid") + r = r.WithContext(auth.ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + h.GetTenant(w, r) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestInviteUser_InvalidTenantID(t *testing.T) { + h := &TenantHandler{} + body := `{"email":"test@example.com","role":"member"}` + r := httptest.NewRequest("POST", "/api/tenants/bad/invite", bytes.NewBufferString(body)) + r.SetPathValue("id", "bad") + r = r.WithContext(auth.ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + h.InviteUser(w, r) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestInviteUser_NoAuth(t *testing.T) { + h := &TenantHandler{} + body := `{"email":"test@example.com"}` + r := httptest.NewRequest("POST", "/api/tenants/"+uuid.New().String()+"/invite", bytes.NewBufferString(body)) + r.SetPathValue("id", uuid.New().String()) + w := httptest.NewRecorder() + + h.InviteUser(w, r) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestRemoveMember_InvalidIDs(t *testing.T) { + h := &TenantHandler{} + r := httptest.NewRequest("DELETE", "/api/tenants/bad/members/bad", nil) + r.SetPathValue("id", "bad") + r.SetPathValue("uid", "bad") + r = r.WithContext(auth.ContextWithUserID(r.Context(), uuid.New())) + w := httptest.NewRecorder() + + h.RemoveMember(w, r) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestJsonResponse(t *testing.T) { + w := httptest.NewRecorder() + jsonResponse(w, map[string]string{"key": "value"}, http.StatusOK) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } + if ct := w.Header().Get("Content-Type"); ct != "application/json" { + t.Errorf("expected application/json, got %s", ct) + } +} + +func TestJsonError(t *testing.T) { + w := httptest.NewRecorder() + jsonError(w, "something went wrong", http.StatusBadRequest) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + + var resp map[string]string + json.NewDecoder(w.Body).Decode(&resp) + if resp["error"] != "something went wrong" { + t.Errorf("unexpected error: %s", resp["error"]) + } +} diff --git a/backend/internal/models/tenant.go b/backend/internal/models/tenant.go index 5b928a4..adc720d 100644 --- a/backend/internal/models/tenant.go +++ b/backend/internal/models/tenant.go @@ -22,3 +22,9 @@ type UserTenant struct { Role string `db:"role" json:"role"` CreatedAt time.Time `db:"created_at" json:"created_at"` } + +// TenantWithRole is a Tenant joined with the user's role in that tenant. +type TenantWithRole struct { + Tenant + Role string `db:"role" json:"role"` +} diff --git a/backend/internal/router/router.go b/backend/internal/router/router.go index 530bc76..296dccf 100644 --- a/backend/internal/router/router.go +++ b/backend/internal/router/router.go @@ -4,23 +4,48 @@ import ( "encoding/json" "net/http" - "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" - "github.com/jmoiron/sqlx" + + "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" + "mgit.msbls.de/m/KanzlAI-mGMT/internal/handlers" + "mgit.msbls.de/m/KanzlAI-mGMT/internal/services" ) func New(db *sqlx.DB, authMW *auth.Middleware) http.Handler { mux := http.NewServeMux() + // Services + tenantSvc := services.NewTenantService(db) + + // Middleware + tenantResolver := auth.NewTenantResolver(tenantSvc) + + // Handlers + tenantH := handlers.NewTenantHandler(tenantSvc) + // Public routes mux.HandleFunc("GET /health", handleHealth(db)) // Authenticated API routes api := http.NewServeMux() - api.HandleFunc("GET /api/cases", placeholder("cases")) - api.HandleFunc("GET /api/deadlines", placeholder("deadlines")) - api.HandleFunc("GET /api/appointments", placeholder("appointments")) - api.HandleFunc("GET /api/documents", placeholder("documents")) + + // Tenant management (no tenant resolver — these operate across tenants) + api.HandleFunc("POST /api/tenants", tenantH.CreateTenant) + api.HandleFunc("GET /api/tenants", tenantH.ListTenants) + api.HandleFunc("GET /api/tenants/{id}", tenantH.GetTenant) + api.HandleFunc("POST /api/tenants/{id}/invite", tenantH.InviteUser) + api.HandleFunc("DELETE /api/tenants/{id}/members/{uid}", tenantH.RemoveMember) + api.HandleFunc("GET /api/tenants/{id}/members", tenantH.ListMembers) + + // Tenant-scoped routes (require tenant context) + scoped := http.NewServeMux() + scoped.HandleFunc("GET /api/cases", placeholder("cases")) + scoped.HandleFunc("GET /api/deadlines", placeholder("deadlines")) + scoped.HandleFunc("GET /api/appointments", placeholder("appointments")) + scoped.HandleFunc("GET /api/documents", placeholder("documents")) + + // Wire: auth -> tenant routes go directly, scoped routes get tenant resolver + api.Handle("/api/", tenantResolver.Resolve(scoped)) mux.Handle("/api/", authMW.RequireAuth(api)) diff --git a/backend/internal/services/tenant_service.go b/backend/internal/services/tenant_service.go new file mode 100644 index 0000000..5085831 --- /dev/null +++ b/backend/internal/services/tenant_service.go @@ -0,0 +1,211 @@ +package services + +import ( + "context" + "database/sql" + "fmt" + + "github.com/google/uuid" + "github.com/jmoiron/sqlx" + + "mgit.msbls.de/m/KanzlAI-mGMT/internal/models" +) + +type TenantService struct { + db *sqlx.DB +} + +func NewTenantService(db *sqlx.DB) *TenantService { + return &TenantService{db: db} +} + +// Create creates a new tenant and assigns the creator as owner. +func (s *TenantService) Create(ctx context.Context, userID uuid.UUID, name, slug string) (*models.Tenant, error) { + tx, err := s.db.BeginTxx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() + + var tenant models.Tenant + err = tx.QueryRowxContext(ctx, + `INSERT INTO tenants (name, slug) VALUES ($1, $2) RETURNING id, name, slug, settings, created_at, updated_at`, + name, slug, + ).StructScan(&tenant) + if err != nil { + return nil, fmt.Errorf("insert tenant: %w", err) + } + + _, err = tx.ExecContext(ctx, + `INSERT INTO user_tenants (user_id, tenant_id, role) VALUES ($1, $2, 'owner')`, + userID, tenant.ID, + ) + if err != nil { + return nil, fmt.Errorf("assign owner: %w", err) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit: %w", err) + } + + return &tenant, nil +} + +// ListForUser returns all tenants the user belongs to. +func (s *TenantService) ListForUser(ctx context.Context, userID uuid.UUID) ([]models.TenantWithRole, error) { + var tenants []models.TenantWithRole + err := s.db.SelectContext(ctx, &tenants, + `SELECT t.id, t.name, t.slug, t.settings, t.created_at, t.updated_at, ut.role + FROM tenants t + JOIN user_tenants ut ON ut.tenant_id = t.id + WHERE ut.user_id = $1 + ORDER BY t.name`, + userID, + ) + if err != nil { + return nil, fmt.Errorf("list tenants: %w", err) + } + return tenants, nil +} + +// GetByID returns a single tenant. The caller must verify the user has access. +func (s *TenantService) GetByID(ctx context.Context, tenantID uuid.UUID) (*models.Tenant, error) { + var tenant models.Tenant + err := s.db.GetContext(ctx, &tenant, + `SELECT id, name, slug, settings, created_at, updated_at FROM tenants WHERE id = $1`, + tenantID, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get tenant: %w", err) + } + return &tenant, nil +} + +// GetUserRole returns the user's role in a tenant, or empty string if not a member. +func (s *TenantService) GetUserRole(ctx context.Context, userID, tenantID uuid.UUID) (string, error) { + var role string + err := s.db.GetContext(ctx, &role, + `SELECT role FROM user_tenants WHERE user_id = $1 AND tenant_id = $2`, + userID, tenantID, + ) + if err == sql.ErrNoRows { + return "", nil + } + if err != nil { + return "", fmt.Errorf("get user role: %w", err) + } + return role, nil +} + +// FirstTenantForUser returns the user's first tenant (by name), used as default. +func (s *TenantService) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) { + var tenantID uuid.UUID + err := s.db.GetContext(ctx, &tenantID, + `SELECT t.id FROM tenants t + JOIN user_tenants ut ON ut.tenant_id = t.id + WHERE ut.user_id = $1 + ORDER BY t.name LIMIT 1`, + userID, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("first tenant: %w", err) + } + return &tenantID, nil +} + +// ListMembers returns all members of a tenant. +func (s *TenantService) ListMembers(ctx context.Context, tenantID uuid.UUID) ([]models.UserTenant, error) { + var members []models.UserTenant + err := s.db.SelectContext(ctx, &members, + `SELECT user_id, tenant_id, role, created_at FROM user_tenants WHERE tenant_id = $1 ORDER BY created_at`, + tenantID, + ) + if err != nil { + return nil, fmt.Errorf("list members: %w", err) + } + return members, nil +} + +// InviteByEmail looks up a user by email in auth.users and adds them to the tenant. +func (s *TenantService) InviteByEmail(ctx context.Context, tenantID uuid.UUID, email, role string) (*models.UserTenant, error) { + // Look up user in Supabase auth.users + var userID uuid.UUID + err := s.db.GetContext(ctx, &userID, + `SELECT id FROM auth.users WHERE email = $1`, + email, + ) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("no user found with email %s", email) + } + if err != nil { + return nil, fmt.Errorf("lookup user: %w", err) + } + + // Check if already a member + var exists bool + err = s.db.GetContext(ctx, &exists, + `SELECT EXISTS(SELECT 1 FROM user_tenants WHERE user_id = $1 AND tenant_id = $2)`, + userID, tenantID, + ) + if err != nil { + return nil, fmt.Errorf("check membership: %w", err) + } + if exists { + return nil, fmt.Errorf("user is already a member of this tenant") + } + + var ut models.UserTenant + err = s.db.QueryRowxContext(ctx, + `INSERT INTO user_tenants (user_id, tenant_id, role) VALUES ($1, $2, $3) + RETURNING user_id, tenant_id, role, created_at`, + userID, tenantID, role, + ).StructScan(&ut) + if err != nil { + return nil, fmt.Errorf("invite user: %w", err) + } + + return &ut, nil +} + +// RemoveMember removes a user from a tenant. Cannot remove the last owner. +func (s *TenantService) RemoveMember(ctx context.Context, tenantID, userID uuid.UUID) error { + // Check if the user being removed is an owner + role, err := s.GetUserRole(ctx, userID, tenantID) + if err != nil { + return fmt.Errorf("check role: %w", err) + } + if role == "" { + return fmt.Errorf("user is not a member of this tenant") + } + + if role == "owner" { + // Count owners — prevent removing the last one + var ownerCount int + err := s.db.GetContext(ctx, &ownerCount, + `SELECT COUNT(*) FROM user_tenants WHERE tenant_id = $1 AND role = 'owner'`, + tenantID, + ) + if err != nil { + return fmt.Errorf("count owners: %w", err) + } + if ownerCount <= 1 { + return fmt.Errorf("cannot remove the last owner of a tenant") + } + } + + _, err = s.db.ExecContext(ctx, + `DELETE FROM user_tenants WHERE user_id = $1 AND tenant_id = $2`, + userID, tenantID, + ) + if err != nil { + return fmt.Errorf("remove member: %w", err) + } + + return nil +}