Compare commits
1 Commits
mai/knuth/
...
mai/ritchi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f11c411147 |
@@ -22,7 +22,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
defer database.Close()
|
defer database.Close()
|
||||||
|
|
||||||
authMW := auth.NewMiddleware(cfg.SupabaseJWTSecret)
|
authMW := auth.NewMiddleware(cfg.SupabaseJWTSecret, database)
|
||||||
handler := router.New(database, authMW)
|
handler := router.New(database, authMW)
|
||||||
|
|
||||||
log.Printf("Starting KanzlAI API server on :%s", cfg.Port)
|
log.Printf("Starting KanzlAI API server on :%s", cfg.Port)
|
||||||
|
|||||||
@@ -8,14 +8,16 @@ import (
|
|||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Middleware struct {
|
type Middleware struct {
|
||||||
jwtSecret []byte
|
jwtSecret []byte
|
||||||
|
db *sqlx.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMiddleware(jwtSecret string) *Middleware {
|
func NewMiddleware(jwtSecret string, db *sqlx.DB) *Middleware {
|
||||||
return &Middleware{jwtSecret: []byte(jwtSecret)}
|
return &Middleware{jwtSecret: []byte(jwtSecret), db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
|
func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
|
||||||
@@ -33,6 +35,17 @@ func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := ContextWithUserID(r.Context(), userID)
|
ctx := ContextWithUserID(r.Context(), userID)
|
||||||
|
|
||||||
|
// Resolve tenant from user_tenants
|
||||||
|
var tenantID uuid.UUID
|
||||||
|
err = m.db.GetContext(r.Context(), &tenantID,
|
||||||
|
"SELECT tenant_id FROM user_tenants WHERE user_id = $1 LIMIT 1", userID)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "no tenant found for user", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = ContextWithTenantID(ctx, tenantID)
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TenantLookup resolves the default tenant for a user.
|
|
||||||
// Defined as an interface to avoid circular dependency with services.
|
|
||||||
type TenantLookup interface {
|
|
||||||
FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TenantResolver is middleware that resolves the tenant from X-Tenant-ID header
|
|
||||||
// or defaults to the user's first tenant.
|
|
||||||
type TenantResolver struct {
|
|
||||||
lookup TenantLookup
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTenantResolver(lookup TenantLookup) *TenantResolver {
|
|
||||||
return &TenantResolver{lookup: lookup}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tr *TenantResolver) Resolve(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userID, ok := UserFromContext(r.Context())
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var tenantID uuid.UUID
|
|
||||||
|
|
||||||
if header := r.Header.Get("X-Tenant-ID"); header != "" {
|
|
||||||
parsed, err := uuid.Parse(header)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("invalid X-Tenant-ID: %v", err), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tenantID = parsed
|
|
||||||
} else {
|
|
||||||
// Default to user's first tenant
|
|
||||||
first, err := tr.lookup.FirstTenantForUser(r.Context(), userID)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("resolving tenant: %v", err), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if first == nil {
|
|
||||||
http.Error(w, "no tenant found for user", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tenantID = *first
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := ContextWithTenantID(r.Context(), tenantID)
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockTenantLookup struct {
|
|
||||||
tenantID *uuid.UUID
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockTenantLookup) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) {
|
|
||||||
return m.tenantID, m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTenantResolver_FromHeader(t *testing.T) {
|
|
||||||
tenantID := uuid.New()
|
|
||||||
tr := NewTenantResolver(&mockTenantLookup{})
|
|
||||||
|
|
||||||
var gotTenantID uuid.UUID
|
|
||||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
id, ok := TenantFromContext(r.Context())
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("tenant ID not in context")
|
|
||||||
}
|
|
||||||
gotTenantID = id
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
r := httptest.NewRequest("GET", "/api/cases", nil)
|
|
||||||
r.Header.Set("X-Tenant-ID", tenantID.String())
|
|
||||||
r = r.WithContext(ContextWithUserID(r.Context(), uuid.New()))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
tr.Resolve(next).ServeHTTP(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d", w.Code)
|
|
||||||
}
|
|
||||||
if gotTenantID != tenantID {
|
|
||||||
t.Errorf("expected tenant %s, got %s", tenantID, gotTenantID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTenantResolver_DefaultsToFirst(t *testing.T) {
|
|
||||||
tenantID := uuid.New()
|
|
||||||
tr := NewTenantResolver(&mockTenantLookup{tenantID: &tenantID})
|
|
||||||
|
|
||||||
var gotTenantID uuid.UUID
|
|
||||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
id, _ := TenantFromContext(r.Context())
|
|
||||||
gotTenantID = id
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
r := httptest.NewRequest("GET", "/api/cases", nil)
|
|
||||||
r = r.WithContext(ContextWithUserID(r.Context(), uuid.New()))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
tr.Resolve(next).ServeHTTP(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d", w.Code)
|
|
||||||
}
|
|
||||||
if gotTenantID != tenantID {
|
|
||||||
t.Errorf("expected tenant %s, got %s", tenantID, gotTenantID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTenantResolver_NoUser(t *testing.T) {
|
|
||||||
tr := NewTenantResolver(&mockTenantLookup{})
|
|
||||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
t.Fatal("next should not be called")
|
|
||||||
})
|
|
||||||
|
|
||||||
r := httptest.NewRequest("GET", "/api/cases", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
tr.Resolve(next).ServeHTTP(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusUnauthorized {
|
|
||||||
t.Errorf("expected 401, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTenantResolver_InvalidHeader(t *testing.T) {
|
|
||||||
tr := NewTenantResolver(&mockTenantLookup{})
|
|
||||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
t.Fatal("next should not be called")
|
|
||||||
})
|
|
||||||
|
|
||||||
r := httptest.NewRequest("GET", "/api/cases", nil)
|
|
||||||
r.Header.Set("X-Tenant-ID", "not-a-uuid")
|
|
||||||
r = r.WithContext(ContextWithUserID(r.Context(), uuid.New()))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
tr.Resolve(next).ServeHTTP(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Errorf("expected 400, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTenantResolver_NoTenantForUser(t *testing.T) {
|
|
||||||
tr := NewTenantResolver(&mockTenantLookup{tenantID: nil})
|
|
||||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
t.Fatal("next should not be called")
|
|
||||||
})
|
|
||||||
|
|
||||||
r := httptest.NewRequest("GET", "/api/cases", nil)
|
|
||||||
r = r.WithContext(ContextWithUserID(r.Context(), uuid.New()))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
tr.Resolve(next).ServeHTTP(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Errorf("expected 400, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
158
backend/internal/handlers/cases.go
Normal file
158
backend/internal/handlers/cases.go
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
|
||||||
|
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CaseHandler struct {
|
||||||
|
svc *services.CaseService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCaseHandler(svc *services.CaseService) *CaseHandler {
|
||||||
|
return &CaseHandler{svc: svc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *CaseHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusForbidden, "missing tenant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
||||||
|
offset, _ := strconv.Atoi(r.URL.Query().Get("offset"))
|
||||||
|
|
||||||
|
filter := services.CaseFilter{
|
||||||
|
Status: r.URL.Query().Get("status"),
|
||||||
|
Type: r.URL.Query().Get("type"),
|
||||||
|
Search: r.URL.Query().Get("search"),
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
|
||||||
|
cases, total, err := h.svc.List(r.Context(), tenantID, filter)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"cases": cases,
|
||||||
|
"total": total,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *CaseHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusForbidden, "missing tenant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID, _ := auth.UserFromContext(r.Context())
|
||||||
|
|
||||||
|
var input services.CreateCaseInput
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid JSON body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if input.CaseNumber == "" || input.Title == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "case_number and title are required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := h.svc.Create(r.Context(), tenantID, userID, input)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusCreated, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *CaseHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusForbidden, "missing tenant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
caseID, err := uuid.Parse(r.PathValue("id"))
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid case ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
detail, err := h.svc.GetByID(r.Context(), tenantID, caseID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if detail == nil {
|
||||||
|
writeError(w, http.StatusNotFound, "case not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *CaseHandler) Update(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusForbidden, "missing tenant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID, _ := auth.UserFromContext(r.Context())
|
||||||
|
|
||||||
|
caseID, err := uuid.Parse(r.PathValue("id"))
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid case ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var input services.UpdateCaseInput
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid JSON body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := h.svc.Update(r.Context(), tenantID, caseID, userID, input)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if updated == nil {
|
||||||
|
writeError(w, http.StatusNotFound, "case not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *CaseHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusForbidden, "missing tenant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID, _ := auth.UserFromContext(r.Context())
|
||||||
|
|
||||||
|
caseID, err := uuid.Parse(r.PathValue("id"))
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid case ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.svc.Delete(r.Context(), tenantID, caseID, userID); err != nil {
|
||||||
|
writeError(w, http.StatusNotFound, "case not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"status": "archived"})
|
||||||
|
}
|
||||||
16
backend/internal/handlers/helpers.go
Normal file
16
backend/internal/handlers/helpers.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
json.NewEncoder(w).Encode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeError(w http.ResponseWriter, status int, message string) {
|
||||||
|
writeJSON(w, status, map[string]string{"error": message})
|
||||||
|
}
|
||||||
134
backend/internal/handlers/parties.go
Normal file
134
backend/internal/handlers/parties.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
|
||||||
|
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PartyHandler struct {
|
||||||
|
svc *services.PartyService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPartyHandler(svc *services.PartyService) *PartyHandler {
|
||||||
|
return &PartyHandler{svc: svc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PartyHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusForbidden, "missing tenant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
caseID, err := uuid.Parse(r.PathValue("id"))
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid case ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parties, err := h.svc.ListByCase(r.Context(), tenantID, caseID)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"parties": parties,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PartyHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusForbidden, "missing tenant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID, _ := auth.UserFromContext(r.Context())
|
||||||
|
|
||||||
|
caseID, err := uuid.Parse(r.PathValue("id"))
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid case ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var input services.CreatePartyInput
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid JSON body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if input.Name == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "name is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
party, err := h.svc.Create(r.Context(), tenantID, caseID, userID, input)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
writeError(w, http.StatusNotFound, "case not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusCreated, party)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PartyHandler) Update(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusForbidden, "missing tenant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
partyID, err := uuid.Parse(r.PathValue("partyId"))
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid party ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var input services.UpdatePartyInput
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid JSON body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := h.svc.Update(r.Context(), tenantID, partyID, input)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if updated == nil {
|
||||||
|
writeError(w, http.StatusNotFound, "party not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PartyHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusForbidden, "missing tenant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
partyID, err := uuid.Parse(r.PathValue("partyId"))
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid party ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.svc.Delete(r.Context(), tenantID, partyID); err != nil {
|
||||||
|
writeError(w, http.StatusNotFound, "party not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
@@ -1,243 +0,0 @@
|
|||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
|
|
||||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
|
|
||||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TenantHandler struct {
|
|
||||||
svc *services.TenantService
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTenantHandler(svc *services.TenantService) *TenantHandler {
|
|
||||||
return &TenantHandler{svc: svc}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTenant handles POST /api/tenants
|
|
||||||
func (h *TenantHandler) CreateTenant(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userID, ok := auth.UserFromContext(r.Context())
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Slug string `json:"slug"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
jsonError(w, "invalid request body", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if req.Name == "" || req.Slug == "" {
|
|
||||||
jsonError(w, "name and slug are required", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tenant, err := h.svc.Create(r.Context(), userID, req.Name, req.Slug)
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse(w, tenant, http.StatusCreated)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListTenants handles GET /api/tenants
|
|
||||||
func (h *TenantHandler) ListTenants(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userID, ok := auth.UserFromContext(r.Context())
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tenants, err := h.svc.ListForUser(r.Context(), userID)
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse(w, tenants, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTenant handles GET /api/tenants/{id}
|
|
||||||
func (h *TenantHandler) GetTenant(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userID, ok := auth.UserFromContext(r.Context())
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tenantID, err := uuid.Parse(r.PathValue("id"))
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, "invalid tenant ID", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify user has access to this tenant
|
|
||||||
role, err := h.svc.GetUserRole(r.Context(), userID, tenantID)
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if role == "" {
|
|
||||||
jsonError(w, "not found", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tenant, err := h.svc.GetByID(r.Context(), tenantID)
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tenant == nil {
|
|
||||||
jsonError(w, "not found", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse(w, tenant, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InviteUser handles POST /api/tenants/{id}/invite
|
|
||||||
func (h *TenantHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userID, ok := auth.UserFromContext(r.Context())
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tenantID, err := uuid.Parse(r.PathValue("id"))
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, "invalid tenant ID", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only owners and admins can invite
|
|
||||||
role, err := h.svc.GetUserRole(r.Context(), userID, tenantID)
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if role != "owner" && role != "admin" {
|
|
||||||
jsonError(w, "only owners and admins can invite users", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
Role string `json:"role"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
jsonError(w, "invalid request body", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if req.Email == "" {
|
|
||||||
jsonError(w, "email is required", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if req.Role == "" {
|
|
||||||
req.Role = "member"
|
|
||||||
}
|
|
||||||
if req.Role != "member" && req.Role != "admin" {
|
|
||||||
jsonError(w, "role must be member or admin", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ut, err := h.svc.InviteByEmail(r.Context(), tenantID, req.Email, req.Role)
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse(w, ut, http.StatusCreated)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveMember handles DELETE /api/tenants/{id}/members/{uid}
|
|
||||||
func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userID, ok := auth.UserFromContext(r.Context())
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tenantID, err := uuid.Parse(r.PathValue("id"))
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, "invalid tenant ID", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
memberID, err := uuid.Parse(r.PathValue("uid"))
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, "invalid member ID", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only owners and admins can remove members (or user removing themselves)
|
|
||||||
role, err := h.svc.GetUserRole(r.Context(), userID, tenantID)
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if role != "owner" && role != "admin" && userID != memberID {
|
|
||||||
jsonError(w, "insufficient permissions", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.svc.RemoveMember(r.Context(), tenantID, memberID); err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse(w, map[string]string{"status": "removed"}, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListMembers handles GET /api/tenants/{id}/members
|
|
||||||
func (h *TenantHandler) ListMembers(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userID, ok := auth.UserFromContext(r.Context())
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tenantID, err := uuid.Parse(r.PathValue("id"))
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, "invalid tenant ID", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify user has access
|
|
||||||
role, err := h.svc.GetUserRole(r.Context(), userID, tenantID)
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if role == "" {
|
|
||||||
jsonError(w, "not found", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
members, err := h.svc.ListMembers(r.Context(), tenantID)
|
|
||||||
if err != nil {
|
|
||||||
jsonError(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse(w, members, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
func jsonResponse(w http.ResponseWriter, data interface{}, status int) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
json.NewEncoder(w).Encode(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func jsonError(w http.ResponseWriter, msg string, status int) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": msg})
|
|
||||||
}
|
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
|
|
||||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCreateTenant_MissingFields(t *testing.T) {
|
|
||||||
h := &TenantHandler{} // no service needed for validation
|
|
||||||
|
|
||||||
// Build request with auth context
|
|
||||||
body := `{"name":"","slug":""}`
|
|
||||||
r := httptest.NewRequest("POST", "/api/tenants", bytes.NewBufferString(body))
|
|
||||||
r = r.WithContext(auth.ContextWithUserID(r.Context(), uuid.New()))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
h.CreateTenant(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Errorf("expected 400, got %d", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp map[string]string
|
|
||||||
json.NewDecoder(w.Body).Decode(&resp)
|
|
||||||
if resp["error"] != "name and slug are required" {
|
|
||||||
t.Errorf("unexpected error: %s", resp["error"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateTenant_NoAuth(t *testing.T) {
|
|
||||||
h := &TenantHandler{}
|
|
||||||
r := httptest.NewRequest("POST", "/api/tenants", bytes.NewBufferString(`{}`))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
h.CreateTenant(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusUnauthorized {
|
|
||||||
t.Errorf("expected 401, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetTenant_InvalidID(t *testing.T) {
|
|
||||||
h := &TenantHandler{}
|
|
||||||
r := httptest.NewRequest("GET", "/api/tenants/not-a-uuid", nil)
|
|
||||||
r.SetPathValue("id", "not-a-uuid")
|
|
||||||
r = r.WithContext(auth.ContextWithUserID(r.Context(), uuid.New()))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
h.GetTenant(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Errorf("expected 400, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInviteUser_InvalidTenantID(t *testing.T) {
|
|
||||||
h := &TenantHandler{}
|
|
||||||
body := `{"email":"test@example.com","role":"member"}`
|
|
||||||
r := httptest.NewRequest("POST", "/api/tenants/bad/invite", bytes.NewBufferString(body))
|
|
||||||
r.SetPathValue("id", "bad")
|
|
||||||
r = r.WithContext(auth.ContextWithUserID(r.Context(), uuid.New()))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
h.InviteUser(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Errorf("expected 400, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInviteUser_NoAuth(t *testing.T) {
|
|
||||||
h := &TenantHandler{}
|
|
||||||
body := `{"email":"test@example.com"}`
|
|
||||||
r := httptest.NewRequest("POST", "/api/tenants/"+uuid.New().String()+"/invite", bytes.NewBufferString(body))
|
|
||||||
r.SetPathValue("id", uuid.New().String())
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
h.InviteUser(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusUnauthorized {
|
|
||||||
t.Errorf("expected 401, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRemoveMember_InvalidIDs(t *testing.T) {
|
|
||||||
h := &TenantHandler{}
|
|
||||||
r := httptest.NewRequest("DELETE", "/api/tenants/bad/members/bad", nil)
|
|
||||||
r.SetPathValue("id", "bad")
|
|
||||||
r.SetPathValue("uid", "bad")
|
|
||||||
r = r.WithContext(auth.ContextWithUserID(r.Context(), uuid.New()))
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
h.RemoveMember(w, r)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Errorf("expected 400, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJsonResponse(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
jsonResponse(w, map[string]string{"key": "value"}, http.StatusOK)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected 200, got %d", w.Code)
|
|
||||||
}
|
|
||||||
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
|
||||||
t.Errorf("expected application/json, got %s", ct)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJsonError(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
jsonError(w, "something went wrong", http.StatusBadRequest)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Errorf("expected 400, got %d", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp map[string]string
|
|
||||||
json.NewDecoder(w.Body).Decode(&resp)
|
|
||||||
if resp["error"] != "something went wrong" {
|
|
||||||
t.Errorf("unexpected error: %s", resp["error"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -22,9 +22,3 @@ type UserTenant struct {
|
|||||||
Role string `db:"role" json:"role"`
|
Role string `db:"role" json:"role"`
|
||||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TenantWithRole is a Tenant joined with the user's role in that tenant.
|
|
||||||
type TenantWithRole struct {
|
|
||||||
Tenant
|
|
||||||
Role string `db:"role" json:"role"`
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,24 +4,23 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
|
|
||||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
|
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
|
||||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/handlers"
|
"mgit.msbls.de/m/KanzlAI-mGMT/internal/handlers"
|
||||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
|
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
|
||||||
|
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(db *sqlx.DB, authMW *auth.Middleware) http.Handler {
|
func New(db *sqlx.DB, authMW *auth.Middleware) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
// Services
|
// Services
|
||||||
tenantSvc := services.NewTenantService(db)
|
caseSvc := services.NewCaseService(db)
|
||||||
|
partySvc := services.NewPartyService(db)
|
||||||
// Middleware
|
|
||||||
tenantResolver := auth.NewTenantResolver(tenantSvc)
|
|
||||||
|
|
||||||
// Handlers
|
// Handlers
|
||||||
tenantH := handlers.NewTenantHandler(tenantSvc)
|
caseH := handlers.NewCaseHandler(caseSvc)
|
||||||
|
partyH := handlers.NewPartyHandler(partySvc)
|
||||||
|
|
||||||
// Public routes
|
// Public routes
|
||||||
mux.HandleFunc("GET /health", handleHealth(db))
|
mux.HandleFunc("GET /health", handleHealth(db))
|
||||||
@@ -29,23 +28,23 @@ func New(db *sqlx.DB, authMW *auth.Middleware) http.Handler {
|
|||||||
// Authenticated API routes
|
// Authenticated API routes
|
||||||
api := http.NewServeMux()
|
api := http.NewServeMux()
|
||||||
|
|
||||||
// Tenant management (no tenant resolver — these operate across tenants)
|
// Cases
|
||||||
api.HandleFunc("POST /api/tenants", tenantH.CreateTenant)
|
api.HandleFunc("GET /api/cases", caseH.List)
|
||||||
api.HandleFunc("GET /api/tenants", tenantH.ListTenants)
|
api.HandleFunc("POST /api/cases", caseH.Create)
|
||||||
api.HandleFunc("GET /api/tenants/{id}", tenantH.GetTenant)
|
api.HandleFunc("GET /api/cases/{id}", caseH.Get)
|
||||||
api.HandleFunc("POST /api/tenants/{id}/invite", tenantH.InviteUser)
|
api.HandleFunc("PUT /api/cases/{id}", caseH.Update)
|
||||||
api.HandleFunc("DELETE /api/tenants/{id}/members/{uid}", tenantH.RemoveMember)
|
api.HandleFunc("DELETE /api/cases/{id}", caseH.Delete)
|
||||||
api.HandleFunc("GET /api/tenants/{id}/members", tenantH.ListMembers)
|
|
||||||
|
|
||||||
// Tenant-scoped routes (require tenant context)
|
// Parties (nested under cases for creation/listing, top-level for update/delete)
|
||||||
scoped := http.NewServeMux()
|
api.HandleFunc("GET /api/cases/{id}/parties", partyH.List)
|
||||||
scoped.HandleFunc("GET /api/cases", placeholder("cases"))
|
api.HandleFunc("POST /api/cases/{id}/parties", partyH.Create)
|
||||||
scoped.HandleFunc("GET /api/deadlines", placeholder("deadlines"))
|
api.HandleFunc("PUT /api/parties/{partyId}", partyH.Update)
|
||||||
scoped.HandleFunc("GET /api/appointments", placeholder("appointments"))
|
api.HandleFunc("DELETE /api/parties/{partyId}", partyH.Delete)
|
||||||
scoped.HandleFunc("GET /api/documents", placeholder("documents"))
|
|
||||||
|
|
||||||
// Wire: auth -> tenant routes go directly, scoped routes get tenant resolver
|
// Placeholder routes for future phases
|
||||||
api.Handle("/api/", tenantResolver.Resolve(scoped))
|
api.HandleFunc("GET /api/deadlines", placeholder("deadlines"))
|
||||||
|
api.HandleFunc("GET /api/appointments", placeholder("appointments"))
|
||||||
|
api.HandleFunc("GET /api/documents", placeholder("documents"))
|
||||||
|
|
||||||
mux.Handle("/api/", authMW.RequireAuth(api))
|
mux.Handle("/api/", authMW.RequireAuth(api))
|
||||||
|
|
||||||
|
|||||||
277
backend/internal/services/case_service.go
Normal file
277
backend/internal/services/case_service.go
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"mgit.msbls.de/m/KanzlAI-mGMT/internal/models"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CaseService struct {
|
||||||
|
db *sqlx.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCaseService(db *sqlx.DB) *CaseService {
|
||||||
|
return &CaseService{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CaseFilter struct {
|
||||||
|
Status string
|
||||||
|
Type string
|
||||||
|
Search string
|
||||||
|
Limit int
|
||||||
|
Offset int
|
||||||
|
}
|
||||||
|
|
||||||
|
type CaseDetail struct {
|
||||||
|
models.Case
|
||||||
|
Parties []models.Party `json:"parties"`
|
||||||
|
RecentEvents []models.CaseEvent `json:"recent_events"`
|
||||||
|
DeadlinesCount int `json:"deadlines_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateCaseInput struct {
|
||||||
|
CaseNumber string `json:"case_number"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
CaseType *string `json:"case_type,omitempty"`
|
||||||
|
Court *string `json:"court,omitempty"`
|
||||||
|
CourtRef *string `json:"court_ref,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateCaseInput struct {
|
||||||
|
CaseNumber *string `json:"case_number,omitempty"`
|
||||||
|
Title *string `json:"title,omitempty"`
|
||||||
|
CaseType *string `json:"case_type,omitempty"`
|
||||||
|
Court *string `json:"court,omitempty"`
|
||||||
|
CourtRef *string `json:"court_ref,omitempty"`
|
||||||
|
Status *string `json:"status,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CaseService) List(ctx context.Context, tenantID uuid.UUID, filter CaseFilter) ([]models.Case, int, error) {
|
||||||
|
if filter.Limit <= 0 {
|
||||||
|
filter.Limit = 20
|
||||||
|
}
|
||||||
|
if filter.Limit > 100 {
|
||||||
|
filter.Limit = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build WHERE clause
|
||||||
|
where := "WHERE tenant_id = $1"
|
||||||
|
args := []interface{}{tenantID}
|
||||||
|
argIdx := 2
|
||||||
|
|
||||||
|
if filter.Status != "" {
|
||||||
|
where += fmt.Sprintf(" AND status = $%d", argIdx)
|
||||||
|
args = append(args, filter.Status)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if filter.Type != "" {
|
||||||
|
where += fmt.Sprintf(" AND case_type = $%d", argIdx)
|
||||||
|
args = append(args, filter.Type)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if filter.Search != "" {
|
||||||
|
where += fmt.Sprintf(" AND (title ILIKE $%d OR case_number ILIKE $%d)", argIdx, argIdx)
|
||||||
|
args = append(args, "%"+filter.Search+"%")
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count total
|
||||||
|
var total int
|
||||||
|
countQuery := "SELECT COUNT(*) FROM cases " + where
|
||||||
|
if err := s.db.GetContext(ctx, &total, countQuery, args...); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("counting cases: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch page
|
||||||
|
query := fmt.Sprintf("SELECT * FROM cases %s ORDER BY updated_at DESC LIMIT $%d OFFSET $%d",
|
||||||
|
where, argIdx, argIdx+1)
|
||||||
|
args = append(args, filter.Limit, filter.Offset)
|
||||||
|
|
||||||
|
var cases []models.Case
|
||||||
|
if err := s.db.SelectContext(ctx, &cases, query, args...); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("listing cases: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cases, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CaseService) GetByID(ctx context.Context, tenantID, caseID uuid.UUID) (*CaseDetail, error) {
|
||||||
|
var c models.Case
|
||||||
|
err := s.db.GetContext(ctx, &c,
|
||||||
|
"SELECT * FROM cases WHERE id = $1 AND tenant_id = $2", caseID, tenantID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("getting case: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
detail := &CaseDetail{Case: c}
|
||||||
|
|
||||||
|
// Parties
|
||||||
|
if err := s.db.SelectContext(ctx, &detail.Parties,
|
||||||
|
"SELECT * FROM parties WHERE case_id = $1 AND tenant_id = $2 ORDER BY name",
|
||||||
|
caseID, tenantID); err != nil {
|
||||||
|
return nil, fmt.Errorf("getting parties: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recent events (last 20)
|
||||||
|
if err := s.db.SelectContext(ctx, &detail.RecentEvents,
|
||||||
|
"SELECT * FROM case_events WHERE case_id = $1 AND tenant_id = $2 ORDER BY created_at DESC LIMIT 20",
|
||||||
|
caseID, tenantID); err != nil {
|
||||||
|
return nil, fmt.Errorf("getting events: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deadlines count
|
||||||
|
if err := s.db.GetContext(ctx, &detail.DeadlinesCount,
|
||||||
|
"SELECT COUNT(*) FROM deadlines WHERE case_id = $1 AND tenant_id = $2",
|
||||||
|
caseID, tenantID); err != nil {
|
||||||
|
return nil, fmt.Errorf("counting deadlines: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return detail, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CaseService) Create(ctx context.Context, tenantID uuid.UUID, userID uuid.UUID, input CreateCaseInput) (*models.Case, error) {
|
||||||
|
if input.Status == "" {
|
||||||
|
input.Status = "active"
|
||||||
|
}
|
||||||
|
|
||||||
|
id := uuid.New()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`INSERT INTO cases (id, tenant_id, case_number, title, case_type, court, court_ref, status, metadata, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, '{}', $9, $9)`,
|
||||||
|
id, tenantID, input.CaseNumber, input.Title, input.CaseType, input.Court, input.CourtRef, input.Status, now)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating case: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create case_created event
|
||||||
|
createEvent(ctx, s.db, tenantID, id, userID, "case_created", "Case created", nil)
|
||||||
|
|
||||||
|
var c models.Case
|
||||||
|
if err := s.db.GetContext(ctx, &c, "SELECT * FROM cases WHERE id = $1", id); err != nil {
|
||||||
|
return nil, fmt.Errorf("fetching created case: %w", err)
|
||||||
|
}
|
||||||
|
return &c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CaseService) Update(ctx context.Context, tenantID, caseID uuid.UUID, userID uuid.UUID, input UpdateCaseInput) (*models.Case, error) {
|
||||||
|
// Fetch current to detect status change
|
||||||
|
var current models.Case
|
||||||
|
err := s.db.GetContext(ctx, ¤t,
|
||||||
|
"SELECT * FROM cases WHERE id = $1 AND tenant_id = $2", caseID, tenantID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("fetching case for update: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build SET clause dynamically
|
||||||
|
sets := []string{}
|
||||||
|
args := []interface{}{}
|
||||||
|
argIdx := 1
|
||||||
|
|
||||||
|
if input.CaseNumber != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("case_number = $%d", argIdx))
|
||||||
|
args = append(args, *input.CaseNumber)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if input.Title != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("title = $%d", argIdx))
|
||||||
|
args = append(args, *input.Title)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if input.CaseType != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("case_type = $%d", argIdx))
|
||||||
|
args = append(args, *input.CaseType)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if input.Court != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("court = $%d", argIdx))
|
||||||
|
args = append(args, *input.Court)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if input.CourtRef != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("court_ref = $%d", argIdx))
|
||||||
|
args = append(args, *input.CourtRef)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if input.Status != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("status = $%d", argIdx))
|
||||||
|
args = append(args, *input.Status)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sets) == 0 {
|
||||||
|
return ¤t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sets = append(sets, fmt.Sprintf("updated_at = $%d", argIdx))
|
||||||
|
args = append(args, time.Now())
|
||||||
|
argIdx++
|
||||||
|
|
||||||
|
query := fmt.Sprintf("UPDATE cases SET %s WHERE id = $%d AND tenant_id = $%d",
|
||||||
|
joinStrings(sets, ", "), argIdx, argIdx+1)
|
||||||
|
args = append(args, caseID, tenantID)
|
||||||
|
|
||||||
|
if _, err := s.db.ExecContext(ctx, query, args...); err != nil {
|
||||||
|
return nil, fmt.Errorf("updating case: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log status change event
|
||||||
|
if input.Status != nil && *input.Status != current.Status {
|
||||||
|
desc := fmt.Sprintf("Status changed from %s to %s", current.Status, *input.Status)
|
||||||
|
createEvent(ctx, s.db, tenantID, caseID, userID, "status_changed", desc, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
var updated models.Case
|
||||||
|
if err := s.db.GetContext(ctx, &updated, "SELECT * FROM cases WHERE id = $1", caseID); err != nil {
|
||||||
|
return nil, fmt.Errorf("fetching updated case: %w", err)
|
||||||
|
}
|
||||||
|
return &updated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CaseService) Delete(ctx context.Context, tenantID, caseID uuid.UUID, userID uuid.UUID) error {
|
||||||
|
result, err := s.db.ExecContext(ctx,
|
||||||
|
"UPDATE cases SET status = 'archived', updated_at = $1 WHERE id = $2 AND tenant_id = $3 AND status != 'archived'",
|
||||||
|
time.Now(), caseID, tenantID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("archiving case: %w", err)
|
||||||
|
}
|
||||||
|
rows, _ := result.RowsAffected()
|
||||||
|
if rows == 0 {
|
||||||
|
return sql.ErrNoRows
|
||||||
|
}
|
||||||
|
createEvent(ctx, s.db, tenantID, caseID, userID, "case_archived", "Case archived", nil)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createEvent(ctx context.Context, db *sqlx.DB, tenantID, caseID uuid.UUID, userID uuid.UUID, eventType, title string, description *string) {
|
||||||
|
now := time.Now()
|
||||||
|
db.ExecContext(ctx,
|
||||||
|
`INSERT INTO case_events (id, tenant_id, case_id, event_type, title, description, event_date, created_by, metadata, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, '{}', $7, $7)`,
|
||||||
|
uuid.New(), tenantID, caseID, eventType, title, description, now, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinStrings(strs []string, sep string) string {
|
||||||
|
result := ""
|
||||||
|
for i, s := range strs {
|
||||||
|
if i > 0 {
|
||||||
|
result += sep
|
||||||
|
}
|
||||||
|
result += s
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
152
backend/internal/services/party_service.go
Normal file
152
backend/internal/services/party_service.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"mgit.msbls.de/m/KanzlAI-mGMT/internal/models"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PartyService struct {
|
||||||
|
db *sqlx.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPartyService(db *sqlx.DB) *PartyService {
|
||||||
|
return &PartyService{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreatePartyInput struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Role *string `json:"role,omitempty"`
|
||||||
|
Representative *string `json:"representative,omitempty"`
|
||||||
|
ContactInfo json.RawMessage `json:"contact_info,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdatePartyInput struct {
|
||||||
|
Name *string `json:"name,omitempty"`
|
||||||
|
Role *string `json:"role,omitempty"`
|
||||||
|
Representative *string `json:"representative,omitempty"`
|
||||||
|
ContactInfo json.RawMessage `json:"contact_info,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PartyService) ListByCase(ctx context.Context, tenantID, caseID uuid.UUID) ([]models.Party, error) {
|
||||||
|
var parties []models.Party
|
||||||
|
err := s.db.SelectContext(ctx, &parties,
|
||||||
|
"SELECT * FROM parties WHERE case_id = $1 AND tenant_id = $2 ORDER BY name",
|
||||||
|
caseID, tenantID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listing parties: %w", err)
|
||||||
|
}
|
||||||
|
return parties, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PartyService) Create(ctx context.Context, tenantID, caseID uuid.UUID, userID uuid.UUID, input CreatePartyInput) (*models.Party, error) {
|
||||||
|
// Verify case exists and belongs to tenant
|
||||||
|
var exists bool
|
||||||
|
err := s.db.GetContext(ctx, &exists,
|
||||||
|
"SELECT EXISTS(SELECT 1 FROM cases WHERE id = $1 AND tenant_id = $2)", caseID, tenantID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("checking case: %w", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
return nil, sql.ErrNoRows
|
||||||
|
}
|
||||||
|
|
||||||
|
id := uuid.New()
|
||||||
|
contactInfo := input.ContactInfo
|
||||||
|
if contactInfo == nil {
|
||||||
|
contactInfo = json.RawMessage("{}")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.db.ExecContext(ctx,
|
||||||
|
`INSERT INTO parties (id, tenant_id, case_id, name, role, representative, contact_info)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
||||||
|
id, tenantID, caseID, input.Name, input.Role, input.Representative, contactInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating party: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log event
|
||||||
|
desc := fmt.Sprintf("Party added: %s", input.Name)
|
||||||
|
createEvent(ctx, s.db, tenantID, caseID, userID, "party_added", desc, nil)
|
||||||
|
|
||||||
|
var party models.Party
|
||||||
|
if err := s.db.GetContext(ctx, &party, "SELECT * FROM parties WHERE id = $1", id); err != nil {
|
||||||
|
return nil, fmt.Errorf("fetching created party: %w", err)
|
||||||
|
}
|
||||||
|
return &party, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PartyService) Update(ctx context.Context, tenantID, partyID uuid.UUID, input UpdatePartyInput) (*models.Party, error) {
|
||||||
|
// Verify party exists and belongs to tenant
|
||||||
|
var current models.Party
|
||||||
|
err := s.db.GetContext(ctx, ¤t,
|
||||||
|
"SELECT * FROM parties WHERE id = $1 AND tenant_id = $2", partyID, tenantID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("fetching party: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sets := []string{}
|
||||||
|
args := []interface{}{}
|
||||||
|
argIdx := 1
|
||||||
|
|
||||||
|
if input.Name != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("name = $%d", argIdx))
|
||||||
|
args = append(args, *input.Name)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if input.Role != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("role = $%d", argIdx))
|
||||||
|
args = append(args, *input.Role)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if input.Representative != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("representative = $%d", argIdx))
|
||||||
|
args = append(args, *input.Representative)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
if input.ContactInfo != nil {
|
||||||
|
sets = append(sets, fmt.Sprintf("contact_info = $%d", argIdx))
|
||||||
|
args = append(args, input.ContactInfo)
|
||||||
|
argIdx++
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sets) == 0 {
|
||||||
|
return ¤t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf("UPDATE parties SET %s WHERE id = $%d AND tenant_id = $%d",
|
||||||
|
joinStrings(sets, ", "), argIdx, argIdx+1)
|
||||||
|
args = append(args, partyID, tenantID)
|
||||||
|
|
||||||
|
if _, err := s.db.ExecContext(ctx, query, args...); err != nil {
|
||||||
|
return nil, fmt.Errorf("updating party: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var updated models.Party
|
||||||
|
if err := s.db.GetContext(ctx, &updated, "SELECT * FROM parties WHERE id = $1", partyID); err != nil {
|
||||||
|
return nil, fmt.Errorf("fetching updated party: %w", err)
|
||||||
|
}
|
||||||
|
return &updated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PartyService) Delete(ctx context.Context, tenantID, partyID uuid.UUID) error {
|
||||||
|
result, err := s.db.ExecContext(ctx,
|
||||||
|
"DELETE FROM parties WHERE id = $1 AND tenant_id = $2", partyID, tenantID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("deleting party: %w", err)
|
||||||
|
}
|
||||||
|
rows, _ := result.RowsAffected()
|
||||||
|
if rows == 0 {
|
||||||
|
return sql.ErrNoRows
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,211 +0,0 @@
|
|||||||
package services
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
|
|
||||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/models"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TenantService struct {
|
|
||||||
db *sqlx.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTenantService(db *sqlx.DB) *TenantService {
|
|
||||||
return &TenantService{db: db}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create creates a new tenant and assigns the creator as owner.
|
|
||||||
func (s *TenantService) Create(ctx context.Context, userID uuid.UUID, name, slug string) (*models.Tenant, error) {
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("begin transaction: %w", err)
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
var tenant models.Tenant
|
|
||||||
err = tx.QueryRowxContext(ctx,
|
|
||||||
`INSERT INTO tenants (name, slug) VALUES ($1, $2) RETURNING id, name, slug, settings, created_at, updated_at`,
|
|
||||||
name, slug,
|
|
||||||
).StructScan(&tenant)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("insert tenant: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = tx.ExecContext(ctx,
|
|
||||||
`INSERT INTO user_tenants (user_id, tenant_id, role) VALUES ($1, $2, 'owner')`,
|
|
||||||
userID, tenant.ID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("assign owner: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return nil, fmt.Errorf("commit: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &tenant, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListForUser returns all tenants the user belongs to.
|
|
||||||
func (s *TenantService) ListForUser(ctx context.Context, userID uuid.UUID) ([]models.TenantWithRole, error) {
|
|
||||||
var tenants []models.TenantWithRole
|
|
||||||
err := s.db.SelectContext(ctx, &tenants,
|
|
||||||
`SELECT t.id, t.name, t.slug, t.settings, t.created_at, t.updated_at, ut.role
|
|
||||||
FROM tenants t
|
|
||||||
JOIN user_tenants ut ON ut.tenant_id = t.id
|
|
||||||
WHERE ut.user_id = $1
|
|
||||||
ORDER BY t.name`,
|
|
||||||
userID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list tenants: %w", err)
|
|
||||||
}
|
|
||||||
return tenants, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetByID returns a single tenant. The caller must verify the user has access.
|
|
||||||
func (s *TenantService) GetByID(ctx context.Context, tenantID uuid.UUID) (*models.Tenant, error) {
|
|
||||||
var tenant models.Tenant
|
|
||||||
err := s.db.GetContext(ctx, &tenant,
|
|
||||||
`SELECT id, name, slug, settings, created_at, updated_at FROM tenants WHERE id = $1`,
|
|
||||||
tenantID,
|
|
||||||
)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get tenant: %w", err)
|
|
||||||
}
|
|
||||||
return &tenant, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserRole returns the user's role in a tenant, or empty string if not a member.
|
|
||||||
func (s *TenantService) GetUserRole(ctx context.Context, userID, tenantID uuid.UUID) (string, error) {
|
|
||||||
var role string
|
|
||||||
err := s.db.GetContext(ctx, &role,
|
|
||||||
`SELECT role FROM user_tenants WHERE user_id = $1 AND tenant_id = $2`,
|
|
||||||
userID, tenantID,
|
|
||||||
)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("get user role: %w", err)
|
|
||||||
}
|
|
||||||
return role, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FirstTenantForUser returns the user's first tenant (by name), used as default.
|
|
||||||
func (s *TenantService) FirstTenantForUser(ctx context.Context, userID uuid.UUID) (*uuid.UUID, error) {
|
|
||||||
var tenantID uuid.UUID
|
|
||||||
err := s.db.GetContext(ctx, &tenantID,
|
|
||||||
`SELECT t.id FROM tenants t
|
|
||||||
JOIN user_tenants ut ON ut.tenant_id = t.id
|
|
||||||
WHERE ut.user_id = $1
|
|
||||||
ORDER BY t.name LIMIT 1`,
|
|
||||||
userID,
|
|
||||||
)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("first tenant: %w", err)
|
|
||||||
}
|
|
||||||
return &tenantID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListMembers returns all members of a tenant.
|
|
||||||
func (s *TenantService) ListMembers(ctx context.Context, tenantID uuid.UUID) ([]models.UserTenant, error) {
|
|
||||||
var members []models.UserTenant
|
|
||||||
err := s.db.SelectContext(ctx, &members,
|
|
||||||
`SELECT user_id, tenant_id, role, created_at FROM user_tenants WHERE tenant_id = $1 ORDER BY created_at`,
|
|
||||||
tenantID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list members: %w", err)
|
|
||||||
}
|
|
||||||
return members, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InviteByEmail looks up a user by email in auth.users and adds them to the tenant.
|
|
||||||
func (s *TenantService) InviteByEmail(ctx context.Context, tenantID uuid.UUID, email, role string) (*models.UserTenant, error) {
|
|
||||||
// Look up user in Supabase auth.users
|
|
||||||
var userID uuid.UUID
|
|
||||||
err := s.db.GetContext(ctx, &userID,
|
|
||||||
`SELECT id FROM auth.users WHERE email = $1`,
|
|
||||||
email,
|
|
||||||
)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, fmt.Errorf("no user found with email %s", email)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("lookup user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if already a member
|
|
||||||
var exists bool
|
|
||||||
err = s.db.GetContext(ctx, &exists,
|
|
||||||
`SELECT EXISTS(SELECT 1 FROM user_tenants WHERE user_id = $1 AND tenant_id = $2)`,
|
|
||||||
userID, tenantID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check membership: %w", err)
|
|
||||||
}
|
|
||||||
if exists {
|
|
||||||
return nil, fmt.Errorf("user is already a member of this tenant")
|
|
||||||
}
|
|
||||||
|
|
||||||
var ut models.UserTenant
|
|
||||||
err = s.db.QueryRowxContext(ctx,
|
|
||||||
`INSERT INTO user_tenants (user_id, tenant_id, role) VALUES ($1, $2, $3)
|
|
||||||
RETURNING user_id, tenant_id, role, created_at`,
|
|
||||||
userID, tenantID, role,
|
|
||||||
).StructScan(&ut)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invite user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ut, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveMember removes a user from a tenant. Cannot remove the last owner.
|
|
||||||
func (s *TenantService) RemoveMember(ctx context.Context, tenantID, userID uuid.UUID) error {
|
|
||||||
// Check if the user being removed is an owner
|
|
||||||
role, err := s.GetUserRole(ctx, userID, tenantID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("check role: %w", err)
|
|
||||||
}
|
|
||||||
if role == "" {
|
|
||||||
return fmt.Errorf("user is not a member of this tenant")
|
|
||||||
}
|
|
||||||
|
|
||||||
if role == "owner" {
|
|
||||||
// Count owners — prevent removing the last one
|
|
||||||
var ownerCount int
|
|
||||||
err := s.db.GetContext(ctx, &ownerCount,
|
|
||||||
`SELECT COUNT(*) FROM user_tenants WHERE tenant_id = $1 AND role = 'owner'`,
|
|
||||||
tenantID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("count owners: %w", err)
|
|
||||||
}
|
|
||||||
if ownerCount <= 1 {
|
|
||||||
return fmt.Errorf("cannot remove the last owner of a tenant")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = s.db.ExecContext(ctx,
|
|
||||||
`DELETE FROM user_tenants WHERE user_id = $1 AND tenant_id = $2`,
|
|
||||||
userID, tenantID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("remove member: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user