diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 6f99f13..7bffddc 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -22,7 +22,7 @@ func main() { } defer database.Close() - authMW := auth.NewMiddleware(cfg.SupabaseJWTSecret) + authMW := auth.NewMiddleware(cfg.SupabaseJWTSecret, database) handler := router.New(database, authMW) log.Printf("Starting KanzlAI API server on :%s", cfg.Port) diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go index 26896e1..4f31eb6 100644 --- a/backend/internal/auth/middleware.go +++ b/backend/internal/auth/middleware.go @@ -8,14 +8,16 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" + "github.com/jmoiron/sqlx" ) type Middleware struct { jwtSecret []byte + db *sqlx.DB } -func NewMiddleware(jwtSecret string) *Middleware { - return &Middleware{jwtSecret: []byte(jwtSecret)} +func NewMiddleware(jwtSecret string, db *sqlx.DB) *Middleware { + return &Middleware{jwtSecret: []byte(jwtSecret), db: db} } func (m *Middleware) RequireAuth(next http.Handler) http.Handler { @@ -33,6 +35,17 @@ func (m *Middleware) RequireAuth(next http.Handler) http.Handler { } ctx := ContextWithUserID(r.Context(), userID) + + // Resolve tenant from user_tenants + var tenantID uuid.UUID + err = m.db.GetContext(r.Context(), &tenantID, + "SELECT tenant_id FROM user_tenants WHERE user_id = $1 LIMIT 1", userID) + if err != nil { + http.Error(w, "no tenant found for user", http.StatusForbidden) + return + } + ctx = ContextWithTenantID(ctx, tenantID) + next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/backend/internal/handlers/cases.go b/backend/internal/handlers/cases.go new file mode 100644 index 0000000..a10d9d5 --- /dev/null +++ b/backend/internal/handlers/cases.go @@ -0,0 +1,158 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "strconv" + + "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" + "mgit.msbls.de/m/KanzlAI-mGMT/internal/services" + + "github.com/google/uuid" +) + +type CaseHandler struct { + svc *services.CaseService +} + +func NewCaseHandler(svc *services.CaseService) *CaseHandler { + return &CaseHandler{svc: svc} +} + +func (h *CaseHandler) List(w http.ResponseWriter, r *http.Request) { + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") + return + } + + limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) + offset, _ := strconv.Atoi(r.URL.Query().Get("offset")) + + filter := services.CaseFilter{ + Status: r.URL.Query().Get("status"), + Type: r.URL.Query().Get("type"), + Search: r.URL.Query().Get("search"), + Limit: limit, + Offset: offset, + } + + cases, total, err := h.svc.List(r.Context(), tenantID, filter) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "cases": cases, + "total": total, + }) +} + +func (h *CaseHandler) Create(w http.ResponseWriter, r *http.Request) { + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") + return + } + userID, _ := auth.UserFromContext(r.Context()) + + var input services.CreateCaseInput + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON body") + return + } + if input.CaseNumber == "" || input.Title == "" { + writeError(w, http.StatusBadRequest, "case_number and title are required") + return + } + + c, err := h.svc.Create(r.Context(), tenantID, userID, input) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + writeJSON(w, http.StatusCreated, c) +} + +func (h *CaseHandler) Get(w http.ResponseWriter, r *http.Request) { + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") + return + } + + caseID, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid case ID") + return + } + + detail, err := h.svc.GetByID(r.Context(), tenantID, caseID) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + if detail == nil { + writeError(w, http.StatusNotFound, "case not found") + return + } + + writeJSON(w, http.StatusOK, detail) +} + +func (h *CaseHandler) Update(w http.ResponseWriter, r *http.Request) { + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") + return + } + userID, _ := auth.UserFromContext(r.Context()) + + caseID, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid case ID") + return + } + + var input services.UpdateCaseInput + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON body") + return + } + + updated, err := h.svc.Update(r.Context(), tenantID, caseID, userID, input) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + if updated == nil { + writeError(w, http.StatusNotFound, "case not found") + return + } + + writeJSON(w, http.StatusOK, updated) +} + +func (h *CaseHandler) Delete(w http.ResponseWriter, r *http.Request) { + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") + return + } + userID, _ := auth.UserFromContext(r.Context()) + + caseID, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid case ID") + return + } + + if err := h.svc.Delete(r.Context(), tenantID, caseID, userID); err != nil { + writeError(w, http.StatusNotFound, "case not found") + return + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "archived"}) +} diff --git a/backend/internal/handlers/helpers.go b/backend/internal/handlers/helpers.go new file mode 100644 index 0000000..ce23c1d --- /dev/null +++ b/backend/internal/handlers/helpers.go @@ -0,0 +1,16 @@ +package handlers + +import ( + "encoding/json" + "net/http" +) + +func writeJSON(w http.ResponseWriter, status int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, status int, message string) { + writeJSON(w, status, map[string]string{"error": message}) +} diff --git a/backend/internal/handlers/parties.go b/backend/internal/handlers/parties.go new file mode 100644 index 0000000..178e06b --- /dev/null +++ b/backend/internal/handlers/parties.go @@ -0,0 +1,134 @@ +package handlers + +import ( + "database/sql" + "encoding/json" + "net/http" + + "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" + "mgit.msbls.de/m/KanzlAI-mGMT/internal/services" + + "github.com/google/uuid" +) + +type PartyHandler struct { + svc *services.PartyService +} + +func NewPartyHandler(svc *services.PartyService) *PartyHandler { + return &PartyHandler{svc: svc} +} + +func (h *PartyHandler) List(w http.ResponseWriter, r *http.Request) { + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") + return + } + + caseID, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid case ID") + return + } + + parties, err := h.svc.ListByCase(r.Context(), tenantID, caseID) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "parties": parties, + }) +} + +func (h *PartyHandler) Create(w http.ResponseWriter, r *http.Request) { + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") + return + } + userID, _ := auth.UserFromContext(r.Context()) + + caseID, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid case ID") + return + } + + var input services.CreatePartyInput + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON body") + return + } + if input.Name == "" { + writeError(w, http.StatusBadRequest, "name is required") + return + } + + party, err := h.svc.Create(r.Context(), tenantID, caseID, userID, input) + if err != nil { + if err == sql.ErrNoRows { + writeError(w, http.StatusNotFound, "case not found") + return + } + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + writeJSON(w, http.StatusCreated, party) +} + +func (h *PartyHandler) Update(w http.ResponseWriter, r *http.Request) { + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") + return + } + + partyID, err := uuid.Parse(r.PathValue("partyId")) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid party ID") + return + } + + var input services.UpdatePartyInput + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON body") + return + } + + updated, err := h.svc.Update(r.Context(), tenantID, partyID, input) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + if updated == nil { + writeError(w, http.StatusNotFound, "party not found") + return + } + + writeJSON(w, http.StatusOK, updated) +} + +func (h *PartyHandler) Delete(w http.ResponseWriter, r *http.Request) { + tenantID, ok := auth.TenantFromContext(r.Context()) + if !ok { + writeError(w, http.StatusForbidden, "missing tenant") + return + } + + partyID, err := uuid.Parse(r.PathValue("partyId")) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid party ID") + return + } + + if err := h.svc.Delete(r.Context(), tenantID, partyID); err != nil { + writeError(w, http.StatusNotFound, "party not found") + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/backend/internal/router/router.go b/backend/internal/router/router.go index 530bc76..4a5c8d6 100644 --- a/backend/internal/router/router.go +++ b/backend/internal/router/router.go @@ -5,6 +5,8 @@ import ( "net/http" "mgit.msbls.de/m/KanzlAI-mGMT/internal/auth" + "mgit.msbls.de/m/KanzlAI-mGMT/internal/handlers" + "mgit.msbls.de/m/KanzlAI-mGMT/internal/services" "github.com/jmoiron/sqlx" ) @@ -12,12 +14,34 @@ import ( func New(db *sqlx.DB, authMW *auth.Middleware) http.Handler { mux := http.NewServeMux() + // Services + caseSvc := services.NewCaseService(db) + partySvc := services.NewPartyService(db) + + // Handlers + caseH := handlers.NewCaseHandler(caseSvc) + partyH := handlers.NewPartyHandler(partySvc) + // Public routes mux.HandleFunc("GET /health", handleHealth(db)) // Authenticated API routes api := http.NewServeMux() - api.HandleFunc("GET /api/cases", placeholder("cases")) + + // Cases + api.HandleFunc("GET /api/cases", caseH.List) + api.HandleFunc("POST /api/cases", caseH.Create) + api.HandleFunc("GET /api/cases/{id}", caseH.Get) + api.HandleFunc("PUT /api/cases/{id}", caseH.Update) + api.HandleFunc("DELETE /api/cases/{id}", caseH.Delete) + + // Parties (nested under cases for creation/listing, top-level for update/delete) + api.HandleFunc("GET /api/cases/{id}/parties", partyH.List) + api.HandleFunc("POST /api/cases/{id}/parties", partyH.Create) + api.HandleFunc("PUT /api/parties/{partyId}", partyH.Update) + api.HandleFunc("DELETE /api/parties/{partyId}", partyH.Delete) + + // Placeholder routes for future phases api.HandleFunc("GET /api/deadlines", placeholder("deadlines")) api.HandleFunc("GET /api/appointments", placeholder("appointments")) api.HandleFunc("GET /api/documents", placeholder("documents")) diff --git a/backend/internal/services/case_service.go b/backend/internal/services/case_service.go new file mode 100644 index 0000000..dbed424 --- /dev/null +++ b/backend/internal/services/case_service.go @@ -0,0 +1,277 @@ +package services + +import ( + "context" + "database/sql" + "fmt" + "time" + + "mgit.msbls.de/m/KanzlAI-mGMT/internal/models" + + "github.com/google/uuid" + "github.com/jmoiron/sqlx" +) + +type CaseService struct { + db *sqlx.DB +} + +func NewCaseService(db *sqlx.DB) *CaseService { + return &CaseService{db: db} +} + +type CaseFilter struct { + Status string + Type string + Search string + Limit int + Offset int +} + +type CaseDetail struct { + models.Case + Parties []models.Party `json:"parties"` + RecentEvents []models.CaseEvent `json:"recent_events"` + DeadlinesCount int `json:"deadlines_count"` +} + +type CreateCaseInput struct { + CaseNumber string `json:"case_number"` + Title string `json:"title"` + CaseType *string `json:"case_type,omitempty"` + Court *string `json:"court,omitempty"` + CourtRef *string `json:"court_ref,omitempty"` + Status string `json:"status"` +} + +type UpdateCaseInput struct { + CaseNumber *string `json:"case_number,omitempty"` + Title *string `json:"title,omitempty"` + CaseType *string `json:"case_type,omitempty"` + Court *string `json:"court,omitempty"` + CourtRef *string `json:"court_ref,omitempty"` + Status *string `json:"status,omitempty"` +} + +func (s *CaseService) List(ctx context.Context, tenantID uuid.UUID, filter CaseFilter) ([]models.Case, int, error) { + if filter.Limit <= 0 { + filter.Limit = 20 + } + if filter.Limit > 100 { + filter.Limit = 100 + } + + // Build WHERE clause + where := "WHERE tenant_id = $1" + args := []interface{}{tenantID} + argIdx := 2 + + if filter.Status != "" { + where += fmt.Sprintf(" AND status = $%d", argIdx) + args = append(args, filter.Status) + argIdx++ + } + if filter.Type != "" { + where += fmt.Sprintf(" AND case_type = $%d", argIdx) + args = append(args, filter.Type) + argIdx++ + } + if filter.Search != "" { + where += fmt.Sprintf(" AND (title ILIKE $%d OR case_number ILIKE $%d)", argIdx, argIdx) + args = append(args, "%"+filter.Search+"%") + argIdx++ + } + + // Count total + var total int + countQuery := "SELECT COUNT(*) FROM cases " + where + if err := s.db.GetContext(ctx, &total, countQuery, args...); err != nil { + return nil, 0, fmt.Errorf("counting cases: %w", err) + } + + // Fetch page + query := fmt.Sprintf("SELECT * FROM cases %s ORDER BY updated_at DESC LIMIT $%d OFFSET $%d", + where, argIdx, argIdx+1) + args = append(args, filter.Limit, filter.Offset) + + var cases []models.Case + if err := s.db.SelectContext(ctx, &cases, query, args...); err != nil { + return nil, 0, fmt.Errorf("listing cases: %w", err) + } + + return cases, total, nil +} + +func (s *CaseService) GetByID(ctx context.Context, tenantID, caseID uuid.UUID) (*CaseDetail, error) { + var c models.Case + err := s.db.GetContext(ctx, &c, + "SELECT * FROM cases WHERE id = $1 AND tenant_id = $2", caseID, tenantID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("getting case: %w", err) + } + + detail := &CaseDetail{Case: c} + + // Parties + if err := s.db.SelectContext(ctx, &detail.Parties, + "SELECT * FROM parties WHERE case_id = $1 AND tenant_id = $2 ORDER BY name", + caseID, tenantID); err != nil { + return nil, fmt.Errorf("getting parties: %w", err) + } + + // Recent events (last 20) + if err := s.db.SelectContext(ctx, &detail.RecentEvents, + "SELECT * FROM case_events WHERE case_id = $1 AND tenant_id = $2 ORDER BY created_at DESC LIMIT 20", + caseID, tenantID); err != nil { + return nil, fmt.Errorf("getting events: %w", err) + } + + // Deadlines count + if err := s.db.GetContext(ctx, &detail.DeadlinesCount, + "SELECT COUNT(*) FROM deadlines WHERE case_id = $1 AND tenant_id = $2", + caseID, tenantID); err != nil { + return nil, fmt.Errorf("counting deadlines: %w", err) + } + + return detail, nil +} + +func (s *CaseService) Create(ctx context.Context, tenantID uuid.UUID, userID uuid.UUID, input CreateCaseInput) (*models.Case, error) { + if input.Status == "" { + input.Status = "active" + } + + id := uuid.New() + now := time.Now() + + _, err := s.db.ExecContext(ctx, + `INSERT INTO cases (id, tenant_id, case_number, title, case_type, court, court_ref, status, metadata, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, '{}', $9, $9)`, + id, tenantID, input.CaseNumber, input.Title, input.CaseType, input.Court, input.CourtRef, input.Status, now) + if err != nil { + return nil, fmt.Errorf("creating case: %w", err) + } + + // Create case_created event + createEvent(ctx, s.db, tenantID, id, userID, "case_created", "Case created", nil) + + var c models.Case + if err := s.db.GetContext(ctx, &c, "SELECT * FROM cases WHERE id = $1", id); err != nil { + return nil, fmt.Errorf("fetching created case: %w", err) + } + return &c, nil +} + +func (s *CaseService) Update(ctx context.Context, tenantID, caseID uuid.UUID, userID uuid.UUID, input UpdateCaseInput) (*models.Case, error) { + // Fetch current to detect status change + var current models.Case + err := s.db.GetContext(ctx, ¤t, + "SELECT * FROM cases WHERE id = $1 AND tenant_id = $2", caseID, tenantID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("fetching case for update: %w", err) + } + + // Build SET clause dynamically + sets := []string{} + args := []interface{}{} + argIdx := 1 + + if input.CaseNumber != nil { + sets = append(sets, fmt.Sprintf("case_number = $%d", argIdx)) + args = append(args, *input.CaseNumber) + argIdx++ + } + if input.Title != nil { + sets = append(sets, fmt.Sprintf("title = $%d", argIdx)) + args = append(args, *input.Title) + argIdx++ + } + if input.CaseType != nil { + sets = append(sets, fmt.Sprintf("case_type = $%d", argIdx)) + args = append(args, *input.CaseType) + argIdx++ + } + if input.Court != nil { + sets = append(sets, fmt.Sprintf("court = $%d", argIdx)) + args = append(args, *input.Court) + argIdx++ + } + if input.CourtRef != nil { + sets = append(sets, fmt.Sprintf("court_ref = $%d", argIdx)) + args = append(args, *input.CourtRef) + argIdx++ + } + if input.Status != nil { + sets = append(sets, fmt.Sprintf("status = $%d", argIdx)) + args = append(args, *input.Status) + argIdx++ + } + + if len(sets) == 0 { + return ¤t, nil + } + + sets = append(sets, fmt.Sprintf("updated_at = $%d", argIdx)) + args = append(args, time.Now()) + argIdx++ + + query := fmt.Sprintf("UPDATE cases SET %s WHERE id = $%d AND tenant_id = $%d", + joinStrings(sets, ", "), argIdx, argIdx+1) + args = append(args, caseID, tenantID) + + if _, err := s.db.ExecContext(ctx, query, args...); err != nil { + return nil, fmt.Errorf("updating case: %w", err) + } + + // Log status change event + if input.Status != nil && *input.Status != current.Status { + desc := fmt.Sprintf("Status changed from %s to %s", current.Status, *input.Status) + createEvent(ctx, s.db, tenantID, caseID, userID, "status_changed", desc, nil) + } + + var updated models.Case + if err := s.db.GetContext(ctx, &updated, "SELECT * FROM cases WHERE id = $1", caseID); err != nil { + return nil, fmt.Errorf("fetching updated case: %w", err) + } + return &updated, nil +} + +func (s *CaseService) Delete(ctx context.Context, tenantID, caseID uuid.UUID, userID uuid.UUID) error { + result, err := s.db.ExecContext(ctx, + "UPDATE cases SET status = 'archived', updated_at = $1 WHERE id = $2 AND tenant_id = $3 AND status != 'archived'", + time.Now(), caseID, tenantID) + if err != nil { + return fmt.Errorf("archiving case: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return sql.ErrNoRows + } + createEvent(ctx, s.db, tenantID, caseID, userID, "case_archived", "Case archived", nil) + return nil +} + +func createEvent(ctx context.Context, db *sqlx.DB, tenantID, caseID uuid.UUID, userID uuid.UUID, eventType, title string, description *string) { + now := time.Now() + db.ExecContext(ctx, + `INSERT INTO case_events (id, tenant_id, case_id, event_type, title, description, event_date, created_by, metadata, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, '{}', $7, $7)`, + uuid.New(), tenantID, caseID, eventType, title, description, now, userID) +} + +func joinStrings(strs []string, sep string) string { + result := "" + for i, s := range strs { + if i > 0 { + result += sep + } + result += s + } + return result +} diff --git a/backend/internal/services/party_service.go b/backend/internal/services/party_service.go new file mode 100644 index 0000000..8aeb3a3 --- /dev/null +++ b/backend/internal/services/party_service.go @@ -0,0 +1,152 @@ +package services + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "mgit.msbls.de/m/KanzlAI-mGMT/internal/models" + + "github.com/google/uuid" + "github.com/jmoiron/sqlx" +) + +type PartyService struct { + db *sqlx.DB +} + +func NewPartyService(db *sqlx.DB) *PartyService { + return &PartyService{db: db} +} + +type CreatePartyInput struct { + Name string `json:"name"` + Role *string `json:"role,omitempty"` + Representative *string `json:"representative,omitempty"` + ContactInfo json.RawMessage `json:"contact_info,omitempty"` +} + +type UpdatePartyInput struct { + Name *string `json:"name,omitempty"` + Role *string `json:"role,omitempty"` + Representative *string `json:"representative,omitempty"` + ContactInfo json.RawMessage `json:"contact_info,omitempty"` +} + +func (s *PartyService) ListByCase(ctx context.Context, tenantID, caseID uuid.UUID) ([]models.Party, error) { + var parties []models.Party + err := s.db.SelectContext(ctx, &parties, + "SELECT * FROM parties WHERE case_id = $1 AND tenant_id = $2 ORDER BY name", + caseID, tenantID) + if err != nil { + return nil, fmt.Errorf("listing parties: %w", err) + } + return parties, nil +} + +func (s *PartyService) Create(ctx context.Context, tenantID, caseID uuid.UUID, userID uuid.UUID, input CreatePartyInput) (*models.Party, error) { + // Verify case exists and belongs to tenant + var exists bool + err := s.db.GetContext(ctx, &exists, + "SELECT EXISTS(SELECT 1 FROM cases WHERE id = $1 AND tenant_id = $2)", caseID, tenantID) + if err != nil { + return nil, fmt.Errorf("checking case: %w", err) + } + if !exists { + return nil, sql.ErrNoRows + } + + id := uuid.New() + contactInfo := input.ContactInfo + if contactInfo == nil { + contactInfo = json.RawMessage("{}") + } + + _, err = s.db.ExecContext(ctx, + `INSERT INTO parties (id, tenant_id, case_id, name, role, representative, contact_info) + VALUES ($1, $2, $3, $4, $5, $6, $7)`, + id, tenantID, caseID, input.Name, input.Role, input.Representative, contactInfo) + if err != nil { + return nil, fmt.Errorf("creating party: %w", err) + } + + // Log event + desc := fmt.Sprintf("Party added: %s", input.Name) + createEvent(ctx, s.db, tenantID, caseID, userID, "party_added", desc, nil) + + var party models.Party + if err := s.db.GetContext(ctx, &party, "SELECT * FROM parties WHERE id = $1", id); err != nil { + return nil, fmt.Errorf("fetching created party: %w", err) + } + return &party, nil +} + +func (s *PartyService) Update(ctx context.Context, tenantID, partyID uuid.UUID, input UpdatePartyInput) (*models.Party, error) { + // Verify party exists and belongs to tenant + var current models.Party + err := s.db.GetContext(ctx, ¤t, + "SELECT * FROM parties WHERE id = $1 AND tenant_id = $2", partyID, tenantID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("fetching party: %w", err) + } + + sets := []string{} + args := []interface{}{} + argIdx := 1 + + if input.Name != nil { + sets = append(sets, fmt.Sprintf("name = $%d", argIdx)) + args = append(args, *input.Name) + argIdx++ + } + if input.Role != nil { + sets = append(sets, fmt.Sprintf("role = $%d", argIdx)) + args = append(args, *input.Role) + argIdx++ + } + if input.Representative != nil { + sets = append(sets, fmt.Sprintf("representative = $%d", argIdx)) + args = append(args, *input.Representative) + argIdx++ + } + if input.ContactInfo != nil { + sets = append(sets, fmt.Sprintf("contact_info = $%d", argIdx)) + args = append(args, input.ContactInfo) + argIdx++ + } + + if len(sets) == 0 { + return ¤t, nil + } + + query := fmt.Sprintf("UPDATE parties SET %s WHERE id = $%d AND tenant_id = $%d", + joinStrings(sets, ", "), argIdx, argIdx+1) + args = append(args, partyID, tenantID) + + if _, err := s.db.ExecContext(ctx, query, args...); err != nil { + return nil, fmt.Errorf("updating party: %w", err) + } + + var updated models.Party + if err := s.db.GetContext(ctx, &updated, "SELECT * FROM parties WHERE id = $1", partyID); err != nil { + return nil, fmt.Errorf("fetching updated party: %w", err) + } + return &updated, nil +} + +func (s *PartyService) Delete(ctx context.Context, tenantID, partyID uuid.UUID) error { + result, err := s.db.ExecContext(ctx, + "DELETE FROM parties WHERE id = $1 AND tenant_id = $2", partyID, tenantID) + if err != nil { + return fmt.Errorf("deleting party: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return sql.ErrNoRows + } + return nil +}