Compare commits
1 Commits
mai/linus/
...
mai/pike/p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5758e2c37f |
@@ -7,7 +7,6 @@ PORT=8080
|
||||
# Supabase (required for database access)
|
||||
SUPABASE_URL=
|
||||
SUPABASE_ANON_KEY=
|
||||
SUPABASE_SERVICE_KEY=
|
||||
|
||||
# Claude API (required for AI features)
|
||||
ANTHROPIC_API_KEY=
|
||||
|
||||
@@ -23,7 +23,7 @@ func main() {
|
||||
defer database.Close()
|
||||
|
||||
authMW := auth.NewMiddleware(cfg.SupabaseJWTSecret, database)
|
||||
handler := router.New(database, authMW, cfg)
|
||||
handler := router.New(database, authMW, cfg.AnthropicAPIKey)
|
||||
|
||||
log.Printf("Starting KanzlAI API server on :%s", cfg.Port)
|
||||
if err := http.ListenAndServe(":"+cfg.Port, handler); err != nil {
|
||||
|
||||
@@ -3,8 +3,14 @@ module mgit.msbls.de/m/KanzlAI-mGMT
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
github.com/anthropics/anthropic-sdk-go v1.27.1 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/jmoiron/sqlx v1.4.0 // indirect
|
||||
github.com/lib/pq v1.12.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/anthropics/anthropic-sdk-go v1.27.1 h1:7DgMZ2Ng3C2mPzJGHA30NXQTZolcF07mHd0tGaLwfzk=
|
||||
github.com/anthropics/anthropic-sdk-go v1.27.1/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
@@ -10,3 +12,15 @@ github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lib/pq v1.12.0 h1:mC1zeiNamwKBecjHarAr26c/+d8V5w/u4J0I/yASbJo=
|
||||
github.com/lib/pq v1.12.0/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
|
||||
@@ -6,13 +6,12 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Port string
|
||||
DatabaseURL string
|
||||
SupabaseURL string
|
||||
SupabaseAnonKey string
|
||||
SupabaseServiceKey string
|
||||
Port string
|
||||
DatabaseURL string
|
||||
SupabaseURL string
|
||||
SupabaseAnonKey string
|
||||
SupabaseJWTSecret string
|
||||
AnthropicAPIKey string
|
||||
AnthropicAPIKey string
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
@@ -20,8 +19,7 @@ func Load() (*Config, error) {
|
||||
Port: getEnv("PORT", "8080"),
|
||||
DatabaseURL: os.Getenv("DATABASE_URL"),
|
||||
SupabaseURL: os.Getenv("SUPABASE_URL"),
|
||||
SupabaseAnonKey: os.Getenv("SUPABASE_ANON_KEY"),
|
||||
SupabaseServiceKey: os.Getenv("SUPABASE_SERVICE_KEY"),
|
||||
SupabaseAnonKey: os.Getenv("SUPABASE_ANON_KEY"),
|
||||
SupabaseJWTSecret: os.Getenv("SUPABASE_JWT_SECRET"),
|
||||
AnthropicAPIKey: os.Getenv("ANTHROPIC_API_KEY"),
|
||||
}
|
||||
|
||||
115
backend/internal/handlers/ai.go
Normal file
115
backend/internal/handlers/ai.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
|
||||
)
|
||||
|
||||
type AIHandler struct {
|
||||
ai *services.AIService
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
func NewAIHandler(ai *services.AIService, db *sqlx.DB) *AIHandler {
|
||||
return &AIHandler{ai: ai, db: db}
|
||||
}
|
||||
|
||||
// ExtractDeadlines handles POST /api/ai/extract-deadlines
|
||||
// Accepts either multipart/form-data with a "file" PDF field, or JSON {"text": "..."}.
|
||||
func (h *AIHandler) ExtractDeadlines(w http.ResponseWriter, r *http.Request) {
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
var pdfData []byte
|
||||
var text string
|
||||
|
||||
// Check if multipart (PDF upload)
|
||||
if len(contentType) >= 9 && contentType[:9] == "multipart" {
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil { // 32MB max
|
||||
writeError(w, http.StatusBadRequest, "failed to parse multipart form")
|
||||
return
|
||||
}
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "missing 'file' field in multipart form")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
pdfData, err = io.ReadAll(file)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "failed to read uploaded file")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Assume JSON body
|
||||
var body struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
text = body.Text
|
||||
}
|
||||
|
||||
if len(pdfData) == 0 && text == "" {
|
||||
writeError(w, http.StatusBadRequest, "provide either a PDF file or text")
|
||||
return
|
||||
}
|
||||
|
||||
deadlines, err := h.ai.ExtractDeadlines(r.Context(), pdfData, text)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "AI extraction failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"deadlines": deadlines,
|
||||
"count": len(deadlines),
|
||||
})
|
||||
}
|
||||
|
||||
// SummarizeCase handles POST /api/ai/summarize-case
|
||||
// Accepts JSON {"case_id": "uuid"}.
|
||||
func (h *AIHandler) SummarizeCase(w http.ResponseWriter, r *http.Request) {
|
||||
tenantID, err := resolveTenant(r, h.db)
|
||||
if err != nil {
|
||||
handleTenantError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
CaseID string `json:"case_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if body.CaseID == "" {
|
||||
writeError(w, http.StatusBadRequest, "case_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
caseID, err := parseUUID(body.CaseID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid case_id")
|
||||
return
|
||||
}
|
||||
|
||||
summary, err := h.ai.SummarizeCase(r.Context(), tenantID, caseID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "AI summarization failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{
|
||||
"case_id": caseID.String(),
|
||||
"summary": summary,
|
||||
})
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
|
||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const maxUploadSize = 50 << 20 // 50 MB
|
||||
|
||||
type DocumentHandler struct {
|
||||
svc *services.DocumentService
|
||||
}
|
||||
|
||||
func NewDocumentHandler(svc *services.DocumentService) *DocumentHandler {
|
||||
return &DocumentHandler{svc: svc}
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) ListByCase(w http.ResponseWriter, r *http.Request) {
|
||||
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusForbidden, "missing tenant")
|
||||
return
|
||||
}
|
||||
|
||||
caseID, err := uuid.Parse(r.PathValue("id"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid case ID")
|
||||
return
|
||||
}
|
||||
|
||||
docs, err := h.svc.ListByCase(r.Context(), tenantID, caseID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"documents": docs,
|
||||
"total": len(docs),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusForbidden, "missing tenant")
|
||||
return
|
||||
}
|
||||
userID, _ := auth.UserFromContext(r.Context())
|
||||
|
||||
caseID, err := uuid.Parse(r.PathValue("id"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid case ID")
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxUploadSize)
|
||||
if err := r.ParseMultipartForm(maxUploadSize); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "file too large or invalid multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "missing file field")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
title := r.FormValue("title")
|
||||
if title == "" {
|
||||
title = header.Filename
|
||||
}
|
||||
|
||||
contentType := header.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
input := services.CreateDocumentInput{
|
||||
Title: title,
|
||||
DocType: r.FormValue("doc_type"),
|
||||
Filename: header.Filename,
|
||||
ContentType: contentType,
|
||||
Size: int(header.Size),
|
||||
Data: file,
|
||||
}
|
||||
|
||||
doc, err := h.svc.Create(r.Context(), tenantID, caseID, userID, input)
|
||||
if err != nil {
|
||||
if err.Error() == "case not found" {
|
||||
writeError(w, http.StatusNotFound, "case not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, doc)
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusForbidden, "missing tenant")
|
||||
return
|
||||
}
|
||||
|
||||
docID, err := uuid.Parse(r.PathValue("docId"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid document ID")
|
||||
return
|
||||
}
|
||||
|
||||
body, contentType, title, err := h.svc.Download(r.Context(), tenantID, docID)
|
||||
if err != nil {
|
||||
if err.Error() == "document not found" || err.Error() == "document has no file" {
|
||||
writeError(w, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, title))
|
||||
io.Copy(w, body)
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) GetMeta(w http.ResponseWriter, r *http.Request) {
|
||||
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusForbidden, "missing tenant")
|
||||
return
|
||||
}
|
||||
|
||||
docID, err := uuid.Parse(r.PathValue("docId"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid document ID")
|
||||
return
|
||||
}
|
||||
|
||||
doc, err := h.svc.GetByID(r.Context(), tenantID, docID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if doc == nil {
|
||||
writeError(w, http.StatusNotFound, "document not found")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, doc)
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
tenantID, ok := auth.TenantFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusForbidden, "missing tenant")
|
||||
return
|
||||
}
|
||||
userID, _ := auth.UserFromContext(r.Context())
|
||||
|
||||
docID, err := uuid.Parse(r.PathValue("docId"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid document ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Delete(r.Context(), tenantID, docID, userID); err != nil {
|
||||
writeError(w, http.StatusNotFound, "document not found")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"})
|
||||
}
|
||||
@@ -83,3 +83,8 @@ func handleTenantError(w http.ResponseWriter, err error) {
|
||||
func parsePathUUID(r *http.Request, key string) (uuid.UUID, error) {
|
||||
return uuid.Parse(r.PathValue(key))
|
||||
}
|
||||
|
||||
// parseUUID parses a UUID string
|
||||
func parseUUID(s string) (uuid.UUID, error) {
|
||||
return uuid.Parse(s)
|
||||
}
|
||||
|
||||
@@ -7,12 +7,11 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/auth"
|
||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/config"
|
||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/handlers"
|
||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/services"
|
||||
)
|
||||
|
||||
func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config) http.Handler {
|
||||
func New(db *sqlx.DB, authMW *auth.Middleware, anthropicAPIKey string) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Services
|
||||
@@ -24,8 +23,13 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config) http.Handler
|
||||
deadlineSvc := services.NewDeadlineService(db)
|
||||
deadlineRuleSvc := services.NewDeadlineRuleService(db)
|
||||
calculator := services.NewDeadlineCalculator(holidaySvc)
|
||||
storageCli := services.NewStorageClient(cfg.SupabaseURL, cfg.SupabaseServiceKey)
|
||||
documentSvc := services.NewDocumentService(db, storageCli)
|
||||
|
||||
// AI service (optional — only if API key is configured)
|
||||
var aiH *handlers.AIHandler
|
||||
if anthropicAPIKey != "" {
|
||||
aiSvc := services.NewAIService(anthropicAPIKey, db)
|
||||
aiH = handlers.NewAIHandler(aiSvc, db)
|
||||
}
|
||||
|
||||
// Middleware
|
||||
tenantResolver := auth.NewTenantResolver(tenantSvc)
|
||||
@@ -38,7 +42,6 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config) http.Handler
|
||||
deadlineH := handlers.NewDeadlineHandlers(deadlineSvc, db)
|
||||
ruleH := handlers.NewDeadlineRuleHandlers(deadlineRuleSvc)
|
||||
calcH := handlers.NewCalculateHandlers(calculator, deadlineRuleSvc)
|
||||
docH := handlers.NewDocumentHandler(documentSvc)
|
||||
|
||||
// Public routes
|
||||
mux.HandleFunc("GET /health", handleHealth(db))
|
||||
@@ -90,12 +93,14 @@ func New(db *sqlx.DB, authMW *auth.Middleware, cfg *config.Config) http.Handler
|
||||
scoped.HandleFunc("PUT /api/appointments/{id}", apptH.Update)
|
||||
scoped.HandleFunc("DELETE /api/appointments/{id}", apptH.Delete)
|
||||
|
||||
// Documents
|
||||
scoped.HandleFunc("GET /api/cases/{id}/documents", docH.ListByCase)
|
||||
scoped.HandleFunc("POST /api/cases/{id}/documents", docH.Upload)
|
||||
scoped.HandleFunc("GET /api/documents/{docId}", docH.Download)
|
||||
scoped.HandleFunc("GET /api/documents/{docId}/meta", docH.GetMeta)
|
||||
scoped.HandleFunc("DELETE /api/documents/{docId}", docH.Delete)
|
||||
// AI endpoints
|
||||
if aiH != nil {
|
||||
scoped.HandleFunc("POST /api/ai/extract-deadlines", aiH.ExtractDeadlines)
|
||||
scoped.HandleFunc("POST /api/ai/summarize-case", aiH.SummarizeCase)
|
||||
}
|
||||
|
||||
// Placeholder routes for future phases
|
||||
scoped.HandleFunc("GET /api/documents", placeholder("documents"))
|
||||
|
||||
// Wire: auth -> tenant routes go directly, scoped routes get tenant resolver
|
||||
api.Handle("/api/", tenantResolver.Resolve(scoped))
|
||||
@@ -117,3 +122,12 @@ func handleHealth(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func placeholder(resource string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "not_implemented",
|
||||
"resource": resource,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
283
backend/internal/services/ai_service.go
Normal file
283
backend/internal/services/ai_service.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/models"
|
||||
)
|
||||
|
||||
type AIService struct {
|
||||
client anthropic.Client
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
func NewAIService(apiKey string, db *sqlx.DB) *AIService {
|
||||
client := anthropic.NewClient(option.WithAPIKey(apiKey))
|
||||
return &AIService{client: client, db: db}
|
||||
}
|
||||
|
||||
// ExtractedDeadline represents a deadline extracted by AI from a document.
|
||||
type ExtractedDeadline struct {
|
||||
Title string `json:"title"`
|
||||
DueDate *string `json:"due_date"`
|
||||
DurationValue int `json:"duration_value"`
|
||||
DurationUnit string `json:"duration_unit"`
|
||||
Timing string `json:"timing"`
|
||||
TriggerEvent string `json:"trigger_event"`
|
||||
RuleReference string `json:"rule_reference"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
SourceQuote string `json:"source_quote"`
|
||||
}
|
||||
|
||||
type extractDeadlinesToolInput struct {
|
||||
Deadlines []ExtractedDeadline `json:"deadlines"`
|
||||
}
|
||||
|
||||
var deadlineExtractionTool = anthropic.ToolParam{
|
||||
Name: "extract_deadlines",
|
||||
Description: anthropic.String("Extract all legal deadlines found in the document. Return each deadline with its details."),
|
||||
InputSchema: anthropic.ToolInputSchemaParam{
|
||||
Properties: map[string]any{
|
||||
"deadlines": map[string]any{
|
||||
"type": "array",
|
||||
"description": "List of extracted deadlines",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"title": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Short title describing the deadline (e.g. 'Statement of Defence', 'Reply to Counterclaim')",
|
||||
},
|
||||
"due_date": map[string]any{
|
||||
"type": []string{"string", "null"},
|
||||
"description": "Absolute due date in YYYY-MM-DD format if determinable, null otherwise",
|
||||
},
|
||||
"duration_value": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Numeric duration value (e.g. 3 for '3 months')",
|
||||
},
|
||||
"duration_unit": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"days", "weeks", "months"},
|
||||
"description": "Unit of the duration period",
|
||||
},
|
||||
"timing": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"after", "before"},
|
||||
"description": "Whether the deadline is before or after the trigger event",
|
||||
},
|
||||
"trigger_event": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The event that triggers this deadline (e.g. 'service of the Statement of Claim')",
|
||||
},
|
||||
"rule_reference": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Legal rule reference (e.g. 'Rule 23 RoP', 'Rule 222 RoP', '§ 276 ZPO')",
|
||||
},
|
||||
"confidence": map[string]any{
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
"description": "Confidence score from 0.0 to 1.0",
|
||||
},
|
||||
"source_quote": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The exact quote from the document where this deadline was found",
|
||||
},
|
||||
},
|
||||
"required": []string{"title", "duration_value", "duration_unit", "timing", "trigger_event", "rule_reference", "confidence", "source_quote"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"deadlines"},
|
||||
},
|
||||
}
|
||||
|
||||
const extractionSystemPrompt = `You are a legal deadline extraction assistant for German and UPC (Unified Patent Court) patent litigation.
|
||||
|
||||
Your task is to extract all legal deadlines, time limits, and procedural time periods from the provided document.
|
||||
|
||||
For each deadline found, extract:
|
||||
- A clear title describing the deadline
|
||||
- The absolute due date if it can be determined from the document
|
||||
- The duration (value + unit: days/weeks/months)
|
||||
- Whether it runs before or after a trigger event
|
||||
- The trigger event that starts the deadline
|
||||
- The legal rule reference (e.g. Rule 23 RoP, § 276 ZPO)
|
||||
- Your confidence level (0.0-1.0) in the extraction
|
||||
- The exact source quote from the document
|
||||
|
||||
Be thorough: extract every deadline mentioned, including conditional ones. If a deadline references another deadline (e.g. "within 2 months of the defence"), capture that relationship in the trigger_event field.
|
||||
|
||||
If the document contains no deadlines, return an empty list.`
|
||||
|
||||
// ExtractDeadlines sends a document (PDF or text) to Claude for deadline extraction.
|
||||
func (s *AIService) ExtractDeadlines(ctx context.Context, pdfData []byte, text string) ([]ExtractedDeadline, error) {
|
||||
var contentBlocks []anthropic.ContentBlockParamUnion
|
||||
|
||||
if len(pdfData) > 0 {
|
||||
encoded := base64.StdEncoding.EncodeToString(pdfData)
|
||||
contentBlocks = append(contentBlocks, anthropic.ContentBlockParamUnion{
|
||||
OfDocument: &anthropic.DocumentBlockParam{
|
||||
Source: anthropic.DocumentBlockParamSourceUnion{
|
||||
OfBase64: &anthropic.Base64PDFSourceParam{
|
||||
Data: encoded,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
contentBlocks = append(contentBlocks, anthropic.NewTextBlock("Extract all legal deadlines from this document."))
|
||||
} else if text != "" {
|
||||
contentBlocks = append(contentBlocks, anthropic.NewTextBlock("Extract all legal deadlines from the following text:\n\n"+text))
|
||||
} else {
|
||||
return nil, fmt.Errorf("either pdf_data or text must be provided")
|
||||
}
|
||||
|
||||
msg, err := s.client.Messages.New(ctx, anthropic.MessageNewParams{
|
||||
Model: anthropic.ModelClaudeSonnet4_5,
|
||||
MaxTokens: 4096,
|
||||
System: []anthropic.TextBlockParam{
|
||||
{Text: extractionSystemPrompt},
|
||||
},
|
||||
Messages: []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(contentBlocks...),
|
||||
},
|
||||
Tools: []anthropic.ToolUnionParam{
|
||||
{OfTool: &deadlineExtractionTool},
|
||||
},
|
||||
ToolChoice: anthropic.ToolChoiceParamOfTool("extract_deadlines"),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude API call: %w", err)
|
||||
}
|
||||
|
||||
// Find the tool_use block in the response
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "tool_use" && block.Name == "extract_deadlines" {
|
||||
var input extractDeadlinesToolInput
|
||||
if err := json.Unmarshal(block.Input, &input); err != nil {
|
||||
return nil, fmt.Errorf("parsing tool output: %w", err)
|
||||
}
|
||||
return input.Deadlines, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no tool_use block in response")
|
||||
}
|
||||
|
||||
const summarizeSystemPrompt = `You are a legal case summary assistant for German and UPC patent litigation case management.
|
||||
|
||||
Given a case's details, recent events, and deadlines, produce a concise 2-3 sentence summary of what matters right now. Focus on:
|
||||
- The most urgent upcoming deadline
|
||||
- Recent significant events
|
||||
- The current procedural stage
|
||||
|
||||
Write in clear, professional language suitable for a lawyer reviewing their case list. Be specific about dates and deadlines.`
|
||||
|
||||
// SummarizeCase generates an AI summary for a case and caches it in the database.
|
||||
func (s *AIService) SummarizeCase(ctx context.Context, tenantID, caseID uuid.UUID) (string, error) {
|
||||
// Load case
|
||||
var c models.Case
|
||||
err := s.db.GetContext(ctx, &c,
|
||||
"SELECT * FROM cases WHERE id = $1 AND tenant_id = $2", caseID, tenantID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("loading case: %w", err)
|
||||
}
|
||||
|
||||
// Load recent events
|
||||
var events []models.CaseEvent
|
||||
if err := s.db.SelectContext(ctx, &events,
|
||||
"SELECT * FROM case_events WHERE case_id = $1 AND tenant_id = $2 ORDER BY created_at DESC LIMIT 10",
|
||||
caseID, tenantID); err != nil {
|
||||
return "", fmt.Errorf("loading events: %w", err)
|
||||
}
|
||||
|
||||
// Load active deadlines
|
||||
var deadlines []models.Deadline
|
||||
if err := s.db.SelectContext(ctx, &deadlines,
|
||||
"SELECT * FROM deadlines WHERE case_id = $1 AND tenant_id = $2 AND status = 'active' ORDER BY due_date ASC LIMIT 10",
|
||||
caseID, tenantID); err != nil {
|
||||
return "", fmt.Errorf("loading deadlines: %w", err)
|
||||
}
|
||||
|
||||
// Build context text
|
||||
caseInfo := fmt.Sprintf("Case: %s — %s\nStatus: %s", c.CaseNumber, c.Title, c.Status)
|
||||
if c.Court != nil {
|
||||
caseInfo += fmt.Sprintf("\nCourt: %s", *c.Court)
|
||||
}
|
||||
if c.CourtRef != nil {
|
||||
caseInfo += fmt.Sprintf("\nCourt Reference: %s", *c.CourtRef)
|
||||
}
|
||||
if c.CaseType != nil {
|
||||
caseInfo += fmt.Sprintf("\nType: %s", *c.CaseType)
|
||||
}
|
||||
|
||||
eventText := "\n\nRecent Events:"
|
||||
if len(events) == 0 {
|
||||
eventText += "\nNo events recorded."
|
||||
}
|
||||
for _, e := range events {
|
||||
eventText += fmt.Sprintf("\n- [%s] %s", e.CreatedAt.Format("2006-01-02"), e.Title)
|
||||
if e.Description != nil {
|
||||
eventText += fmt.Sprintf(": %s", *e.Description)
|
||||
}
|
||||
}
|
||||
|
||||
deadlineText := "\n\nUpcoming Deadlines:"
|
||||
if len(deadlines) == 0 {
|
||||
deadlineText += "\nNo active deadlines."
|
||||
}
|
||||
for _, d := range deadlines {
|
||||
deadlineText += fmt.Sprintf("\n- %s: due %s (status: %s)", d.Title, d.DueDate, d.Status)
|
||||
if d.Description != nil {
|
||||
deadlineText += fmt.Sprintf(" — %s", *d.Description)
|
||||
}
|
||||
}
|
||||
|
||||
prompt := caseInfo + eventText + deadlineText
|
||||
|
||||
msg, err := s.client.Messages.New(ctx, anthropic.MessageNewParams{
|
||||
Model: anthropic.ModelClaudeSonnet4_5,
|
||||
MaxTokens: 512,
|
||||
System: []anthropic.TextBlockParam{
|
||||
{Text: summarizeSystemPrompt},
|
||||
},
|
||||
Messages: []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("Summarize the current state of this case:\n\n" + prompt)),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("claude API call: %w", err)
|
||||
}
|
||||
|
||||
// Extract text from response
|
||||
var summary string
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "text" {
|
||||
summary += block.Text
|
||||
}
|
||||
}
|
||||
|
||||
if summary == "" {
|
||||
return "", fmt.Errorf("empty response from Claude")
|
||||
}
|
||||
|
||||
// Cache summary in database
|
||||
_, err = s.db.ExecContext(ctx,
|
||||
"UPDATE cases SET ai_summary = $1, updated_at = $2 WHERE id = $3 AND tenant_id = $4",
|
||||
summary, time.Now(), caseID, tenantID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("caching summary: %w", err)
|
||||
}
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
109
backend/internal/services/ai_service_test.go
Normal file
109
backend/internal/services/ai_service_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDeadlineExtractionToolSchema(t *testing.T) {
|
||||
// Verify the tool schema serializes correctly
|
||||
data, err := json.Marshal(deadlineExtractionTool)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal tool: %v", err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("failed to unmarshal tool JSON: %v", err)
|
||||
}
|
||||
|
||||
if parsed["name"] != "extract_deadlines" {
|
||||
t.Errorf("expected name 'extract_deadlines', got %v", parsed["name"])
|
||||
}
|
||||
|
||||
schema, ok := parsed["input_schema"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("input_schema is not a map")
|
||||
}
|
||||
|
||||
if schema["type"] != "object" {
|
||||
t.Errorf("expected schema type 'object', got %v", schema["type"])
|
||||
}
|
||||
|
||||
props, ok := schema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("properties is not a map")
|
||||
}
|
||||
|
||||
deadlines, ok := props["deadlines"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("deadlines property is not a map")
|
||||
}
|
||||
|
||||
if deadlines["type"] != "array" {
|
||||
t.Errorf("expected deadlines type 'array', got %v", deadlines["type"])
|
||||
}
|
||||
|
||||
items, ok := deadlines["items"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("items is not a map")
|
||||
}
|
||||
|
||||
itemProps, ok := items["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("item properties is not a map")
|
||||
}
|
||||
|
||||
expectedFields := []string{"title", "due_date", "duration_value", "duration_unit", "timing", "trigger_event", "rule_reference", "confidence", "source_quote"}
|
||||
for _, field := range expectedFields {
|
||||
if _, ok := itemProps[field]; !ok {
|
||||
t.Errorf("missing expected field %q in item properties", field)
|
||||
}
|
||||
}
|
||||
|
||||
required, ok := items["required"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("required is not a list")
|
||||
}
|
||||
if len(required) != 8 {
|
||||
t.Errorf("expected 8 required fields, got %d", len(required))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractedDeadlineJSON(t *testing.T) {
|
||||
dueDate := "2026-04-15"
|
||||
d := ExtractedDeadline{
|
||||
Title: "Statement of Defence",
|
||||
DueDate: &dueDate,
|
||||
DurationValue: 3,
|
||||
DurationUnit: "months",
|
||||
Timing: "after",
|
||||
TriggerEvent: "service of the Statement of Claim",
|
||||
RuleReference: "Rule 23 RoP",
|
||||
Confidence: 0.95,
|
||||
SourceQuote: "The defendant shall file a defence within 3 months",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var parsed ExtractedDeadline
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if parsed.Title != d.Title {
|
||||
t.Errorf("title mismatch: %q != %q", parsed.Title, d.Title)
|
||||
}
|
||||
if *parsed.DueDate != *d.DueDate {
|
||||
t.Errorf("due_date mismatch: %q != %q", *parsed.DueDate, *d.DueDate)
|
||||
}
|
||||
if parsed.DurationValue != d.DurationValue {
|
||||
t.Errorf("duration_value mismatch: %d != %d", parsed.DurationValue, d.DurationValue)
|
||||
}
|
||||
if parsed.Confidence != d.Confidence {
|
||||
t.Errorf("confidence mismatch: %f != %f", parsed.Confidence, d.Confidence)
|
||||
}
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/KanzlAI-mGMT/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
const documentBucket = "kanzlai-documents"
|
||||
|
||||
type DocumentService struct {
|
||||
db *sqlx.DB
|
||||
storage *StorageClient
|
||||
}
|
||||
|
||||
func NewDocumentService(db *sqlx.DB, storage *StorageClient) *DocumentService {
|
||||
return &DocumentService{db: db, storage: storage}
|
||||
}
|
||||
|
||||
type CreateDocumentInput struct {
|
||||
Title string `json:"title"`
|
||||
DocType string `json:"doc_type"`
|
||||
Filename string
|
||||
ContentType string
|
||||
Size int
|
||||
Data io.Reader
|
||||
}
|
||||
|
||||
func (s *DocumentService) ListByCase(ctx context.Context, tenantID, caseID uuid.UUID) ([]models.Document, error) {
|
||||
var docs []models.Document
|
||||
err := s.db.SelectContext(ctx, &docs,
|
||||
"SELECT * FROM documents WHERE tenant_id = $1 AND case_id = $2 ORDER BY created_at DESC",
|
||||
tenantID, caseID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing documents: %w", err)
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) GetByID(ctx context.Context, tenantID, docID uuid.UUID) (*models.Document, error) {
|
||||
var doc models.Document
|
||||
err := s.db.GetContext(ctx, &doc,
|
||||
"SELECT * FROM documents WHERE id = $1 AND tenant_id = $2", docID, tenantID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("getting document: %w", err)
|
||||
}
|
||||
return &doc, nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) Create(ctx context.Context, tenantID, caseID, userID uuid.UUID, input CreateDocumentInput) (*models.Document, error) {
|
||||
// Verify case belongs to tenant
|
||||
var caseExists int
|
||||
if err := s.db.GetContext(ctx, &caseExists,
|
||||
"SELECT COUNT(*) FROM cases WHERE id = $1 AND tenant_id = $2",
|
||||
caseID, tenantID); err != nil {
|
||||
return nil, fmt.Errorf("verifying case: %w", err)
|
||||
}
|
||||
if caseExists == 0 {
|
||||
return nil, fmt.Errorf("case not found")
|
||||
}
|
||||
|
||||
id := uuid.New()
|
||||
storagePath := fmt.Sprintf("%s/%s/%s_%s", tenantID, caseID, id, input.Filename)
|
||||
|
||||
// Upload to Supabase Storage
|
||||
if err := s.storage.Upload(ctx, documentBucket, storagePath, input.ContentType, input.Data); err != nil {
|
||||
return nil, fmt.Errorf("uploading file: %w", err)
|
||||
}
|
||||
|
||||
// Insert metadata record
|
||||
now := time.Now()
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO documents (id, tenant_id, case_id, title, doc_type, file_path, file_size, mime_type, uploaded_by, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $10)`,
|
||||
id, tenantID, caseID, input.Title, nilIfEmpty(input.DocType), storagePath, input.Size, input.ContentType, userID, now)
|
||||
if err != nil {
|
||||
// Best effort: clean up uploaded file
|
||||
_ = s.storage.Delete(ctx, documentBucket, []string{storagePath})
|
||||
return nil, fmt.Errorf("inserting document record: %w", err)
|
||||
}
|
||||
|
||||
// Log case event
|
||||
createEvent(ctx, s.db, tenantID, caseID, userID, "document_uploaded",
|
||||
fmt.Sprintf("Document uploaded: %s", input.Title), nil)
|
||||
|
||||
var doc models.Document
|
||||
if err := s.db.GetContext(ctx, &doc, "SELECT * FROM documents WHERE id = $1", id); err != nil {
|
||||
return nil, fmt.Errorf("fetching created document: %w", err)
|
||||
}
|
||||
return &doc, nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) Download(ctx context.Context, tenantID, docID uuid.UUID) (io.ReadCloser, string, string, error) {
|
||||
doc, err := s.GetByID(ctx, tenantID, docID)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
if doc == nil {
|
||||
return nil, "", "", fmt.Errorf("document not found")
|
||||
}
|
||||
if doc.FilePath == nil {
|
||||
return nil, "", "", fmt.Errorf("document has no file")
|
||||
}
|
||||
|
||||
body, contentType, err := s.storage.Download(ctx, documentBucket, *doc.FilePath)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("downloading file: %w", err)
|
||||
}
|
||||
|
||||
// Use stored mime_type if available, fall back to storage response
|
||||
if doc.MimeType != nil && *doc.MimeType != "" {
|
||||
contentType = *doc.MimeType
|
||||
}
|
||||
|
||||
return body, contentType, doc.Title, nil
|
||||
}
|
||||
|
||||
func (s *DocumentService) Delete(ctx context.Context, tenantID, docID, userID uuid.UUID) error {
|
||||
doc, err := s.GetByID(ctx, tenantID, docID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if doc == nil {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
// Delete from storage
|
||||
if doc.FilePath != nil {
|
||||
if err := s.storage.Delete(ctx, documentBucket, []string{*doc.FilePath}); err != nil {
|
||||
return fmt.Errorf("deleting file from storage: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete database record
|
||||
_, err = s.db.ExecContext(ctx,
|
||||
"DELETE FROM documents WHERE id = $1 AND tenant_id = $2", docID, tenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting document record: %w", err)
|
||||
}
|
||||
|
||||
// Log case event
|
||||
createEvent(ctx, s.db, tenantID, doc.CaseID, userID, "document_deleted",
|
||||
fmt.Sprintf("Document deleted: %s", doc.Title), nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func nilIfEmpty(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// StorageClient interacts with Supabase Storage via REST API.
|
||||
type StorageClient struct {
|
||||
baseURL string
|
||||
serviceKey string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewStorageClient(supabaseURL, serviceKey string) *StorageClient {
|
||||
return &StorageClient{
|
||||
baseURL: supabaseURL,
|
||||
serviceKey: serviceKey,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// Upload stores a file in the given bucket at the specified path.
|
||||
func (s *StorageClient) Upload(ctx context.Context, bucket, path, contentType string, data io.Reader) error {
|
||||
url := fmt.Sprintf("%s/storage/v1/object/%s/%s", s.baseURL, bucket, path)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating upload request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+s.serviceKey)
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
req.Header.Set("x-upsert", "true")
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("uploading to storage: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("storage upload failed (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Download retrieves a file from storage. Caller must close the returned ReadCloser.
|
||||
func (s *StorageClient) Download(ctx context.Context, bucket, path string) (io.ReadCloser, string, error) {
|
||||
url := fmt.Sprintf("%s/storage/v1/object/%s/%s", s.baseURL, bucket, path)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("creating download request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+s.serviceKey)
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("downloading from storage: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, "", fmt.Errorf("file not found in storage")
|
||||
}
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, "", fmt.Errorf("storage download failed (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
return resp.Body, ct, nil
|
||||
}
|
||||
|
||||
// Delete removes files from storage by their paths.
|
||||
func (s *StorageClient) Delete(ctx context.Context, bucket string, paths []string) error {
|
||||
url := fmt.Sprintf("%s/storage/v1/object/%s", s.baseURL, bucket)
|
||||
|
||||
body, err := json.Marshal(map[string][]string{"prefixes": paths})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling delete request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "DELETE", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating delete request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+s.serviceKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting from storage: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("storage delete failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user