Compare commits
15 Commits
mai/hermes
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| c2b6f8bf97 | |||
| f8dd5e0736 | |||
| 7caf975335 | |||
| 8435817ce1 | |||
| 623dd290c5 | |||
| 64120c27d7 | |||
| dbe1704f42 | |||
| 2758c5a500 | |||
| cb6656c436 | |||
| e22f286024 | |||
| 2d5896e27d | |||
| b282325663 | |||
| a1d0165445 | |||
| 2a8bd4313b | |||
| 4183d4c55a |
@@ -10,13 +10,17 @@ and lifecycle of its own block in `~/.config/imagen.yaml`.
|
||||
## Architecture
|
||||
|
||||
```
|
||||
cmd/imagen/ CLI shell — generate, backends, config, serve
|
||||
cmd/imagen/ CLI shell — generate, worker, backends, config, serve
|
||||
internal/backend/ Backend interface + Registry + Mock reference impl
|
||||
internal/prompt/ Style preset registry (embedded styles.yaml)
|
||||
internal/output/ Filename templating, image writer, JSON sidecar
|
||||
internal/config/ YAML loader, validation, sample generator
|
||||
internal/cloud/ Supabase Storage + imagen.images writer
|
||||
internal/usage/ mai.imagen_usage cost-tracking sink
|
||||
internal/worker/ imagen.jobs queue consumer (DB-agnostic via Queue interface)
|
||||
internal/server/ HTTP stub (not implemented yet — follow-up issue)
|
||||
docs/ architecture.md, usage.md
|
||||
scripts/ imagen-worker.service + env template, ComfyUI scripts
|
||||
docs/ architecture.md, usage.md, setup-worker-mriver.md
|
||||
```
|
||||
|
||||
Data flow for `imagen generate`:
|
||||
|
||||
@@ -10,6 +10,27 @@ import (
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
)
|
||||
|
||||
// instanceStatus checks adapter-specific preconditions (e.g. the
|
||||
// Replicate API token env var being set) and returns a short
|
||||
// user-facing status string.
|
||||
func instanceStatus(spec config.BackendSpec) string {
|
||||
if !backend.Default.Has(spec.Type) {
|
||||
return fmt.Sprintf("type %q not compiled in", spec.Type)
|
||||
}
|
||||
switch spec.Type {
|
||||
case backend.ReplicateType:
|
||||
envName, _ := spec.Raw["api_token_env"].(string)
|
||||
if envName == "" {
|
||||
envName = "REPLICATE_API_TOKEN"
|
||||
}
|
||||
if os.Getenv(envName) == "" {
|
||||
return fmt.Sprintf("not configured (set %s)", envName)
|
||||
}
|
||||
return "ok"
|
||||
}
|
||||
return "registered"
|
||||
}
|
||||
|
||||
func runBackends(args []string) error {
|
||||
fs := flag.NewFlagSet("backends", flag.ContinueOnError)
|
||||
var configPath string
|
||||
@@ -27,10 +48,7 @@ func runBackends(args []string) error {
|
||||
fmt.Fprintln(tw, "INSTANCE\tTYPE\tSTATUS")
|
||||
if cfg != nil {
|
||||
for name, spec := range cfg.Backends {
|
||||
status := "registered"
|
||||
if !backend.Default.Has(spec.Type) {
|
||||
status = fmt.Sprintf("type %q not compiled in", spec.Type)
|
||||
}
|
||||
status := instanceStatus(spec)
|
||||
marker := ""
|
||||
if name == cfg.DefaultBackend {
|
||||
marker = " (default)"
|
||||
|
||||
386
cmd/imagen/compare.go
Normal file
386
cmd/imagen/compare.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
"image/png"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/image/font"
|
||||
"golang.org/x/image/font/basicfont"
|
||||
"golang.org/x/image/math/fixed"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/backend"
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
"mgit.msbls.de/m/ImaGen/internal/output"
|
||||
"mgit.msbls.de/m/ImaGen/internal/prompt"
|
||||
)
|
||||
|
||||
// runCompare implements `imagen compare "<prompt>" --models a,b,c --output <dir>`.
|
||||
//
|
||||
// Each backend in --models runs sequentially against the same prompt (mRock
|
||||
// has a single GPU; parallelising would just OOM). Each generation lands as
|
||||
// a backend-suffixed file in the output dir; a contact sheet stitches them
|
||||
// together into one PNG with the backend name overlaid on each cell. A
|
||||
// sidecar JSON next to the contact sheet lists every generation with its
|
||||
// per-model metadata (latency, seed, model file, VRAM peak).
|
||||
func runCompare(ctx context.Context, args []string) error {
|
||||
fs := flag.NewFlagSet("compare", flag.ContinueOnError)
|
||||
var (
|
||||
modelsCSV string
|
||||
size string
|
||||
outDir string
|
||||
style string
|
||||
negative string
|
||||
seed int64
|
||||
steps int
|
||||
configPath string
|
||||
noContact bool
|
||||
)
|
||||
fs.StringVar(&modelsCSV, "models", "", "comma-separated backend instance names (required)")
|
||||
fs.StringVar(&size, "size", "1024x1024", "WxH for every backend")
|
||||
fs.StringVar(&outDir, "output", "", "directory to write the images + contact sheet (default: ~/Pictures/imagen/compare)")
|
||||
fs.StringVar(&style, "style", "", "style preset applied to the prompt before dispatching to each backend")
|
||||
fs.StringVar(&negative, "negative", "", "negative prompt (forwarded to every backend that supports it)")
|
||||
fs.Int64Var(&seed, "seed", 0, "deterministic seed for every backend (0 = each backend rolls its own)")
|
||||
fs.IntVar(&steps, "steps", 0, "diffusion steps (0 = each backend's default)")
|
||||
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
|
||||
fs.BoolVar(&noContact, "no-contact-sheet", false, "skip the composite PNG; only write per-backend images + sidecar")
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintln(fs.Output(), `Usage: imagen compare "<prompt>" --models a,b,c [flags]`)
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
leadingPositional, flagArgs := splitLeadingPositional(args)
|
||||
if err := fs.Parse(flagArgs); err != nil {
|
||||
return err
|
||||
}
|
||||
positional := append(leadingPositional, fs.Args()...)
|
||||
if len(positional) == 0 {
|
||||
fs.Usage()
|
||||
return userErr("missing prompt")
|
||||
}
|
||||
rawPrompt := strings.Join(positional, " ")
|
||||
modelNames := splitCSV(modelsCSV)
|
||||
if len(modelNames) == 0 {
|
||||
return userErr("--models is required (comma-separated backend instance names)")
|
||||
}
|
||||
|
||||
w, h, err := parseSize(size)
|
||||
if err != nil {
|
||||
return userErr("bad --size: %v", err)
|
||||
}
|
||||
|
||||
cfg, cfgErr := config.Load(configPath)
|
||||
if cfgErr != nil && !os.IsNotExist(cfgErr) {
|
||||
return cfgErr
|
||||
}
|
||||
|
||||
if outDir == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
outDir = filepath.Join(home, "Pictures", "imagen", "compare")
|
||||
}
|
||||
outDir = config.ExpandPath(outDir)
|
||||
|
||||
finalPrompt, err := prompt.Apply(rawPrompt, style)
|
||||
if err != nil {
|
||||
return userErr("%v", err)
|
||||
}
|
||||
|
||||
runID := time.Now().Format("20060102-150405")
|
||||
runDir := filepath.Join(outDir, runID+"-"+output.Slug(rawPrompt))
|
||||
if err := os.MkdirAll(runDir, 0o755); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", runDir, err)
|
||||
}
|
||||
|
||||
results := make([]compareResult, 0, len(modelNames))
|
||||
for i, name := range modelNames {
|
||||
fmt.Fprintf(os.Stderr, "[%d/%d] %s ...\n", i+1, len(modelNames), name)
|
||||
res, err := generateOne(ctx, cfg, name, finalPrompt, negative, w, h, seed, steps, runDir, rawPrompt)
|
||||
if err != nil {
|
||||
// Don't abort the whole run on a single backend failure — record
|
||||
// the error and continue. flexsiebels-style consumers want to
|
||||
// see N-1 results rather than zero when one model is offline.
|
||||
fmt.Fprintf(os.Stderr, " failed: %v\n", err)
|
||||
results = append(results, compareResult{Backend: name, Error: err.Error()})
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " %s (%d ms)\n", res.ImagePath, res.LatencyMs)
|
||||
results = append(results, res)
|
||||
}
|
||||
|
||||
// Sidecar JSON beside the run dir captures every attempt.
|
||||
sidecar := filepath.Join(runDir, "compare.json")
|
||||
if err := writeCompareSidecar(sidecar, rawPrompt, style, negative, w, h, seed, steps, results); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintln(os.Stderr, "sidecar:", sidecar)
|
||||
|
||||
// Contact sheet stitches the successful results together. If every
|
||||
// backend failed there's nothing to draw, so skip silently.
|
||||
if !noContact {
|
||||
successes := successfulResults(results)
|
||||
if len(successes) > 0 {
|
||||
sheet := filepath.Join(runDir, "contact-sheet.png")
|
||||
if err := writeContactSheet(sheet, rawPrompt, successes); err != nil {
|
||||
return fmt.Errorf("contact sheet: %w", err)
|
||||
}
|
||||
fmt.Println(sheet)
|
||||
} else {
|
||||
fmt.Fprintln(os.Stderr, "imagen compare: all backends failed; no contact sheet written")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// compareResult describes what a single backend produced during a comparison
// run and is serialised verbatim into the compare.json sidecar.
//
// When the backend could not produce an image only Backend and Error are
// filled in by the caller; the remaining fields stay at their zero values.
type compareResult struct {
	// Backend is the instance name as given on --models.
	Backend string `json:"backend"`
	// ImagePath is where the generated image was written.
	ImagePath string `json:"image_path,omitempty"`
	// Seed is always emitted, even when zero.
	Seed int64 `json:"seed"`
	// LatencyMs is taken from the backend metadata key "latency_ms".
	LatencyMs int64 `json:"latency_ms,omitempty"`
	// Model is taken from the backend metadata key "model".
	Model string `json:"model,omitempty"`
	// VRAMUsedMiB is taken from the backend metadata key "vram_used_mib".
	VRAMUsedMiB int64 `json:"vram_used_mib,omitempty"`
	// Metadata is the backend's raw metadata map, passed through untouched.
	Metadata map[string]any `json:"metadata,omitempty"`
	// Error holds the failure message; empty on success.
	Error string `json:"error,omitempty"`
}
|
||||
|
||||
func generateOne(ctx context.Context, cfg *config.Config, name, finalPrompt, negative string, w, h int, seed int64, steps int, runDir, rawPrompt string) (compareResult, error) {
|
||||
be, err := buildBackend(cfg, name)
|
||||
if err != nil {
|
||||
return compareResult{Backend: name}, err
|
||||
}
|
||||
attachUsageSink(be)
|
||||
|
||||
req := backend.Request{
|
||||
Prompt: finalPrompt,
|
||||
NegativePrompt: negative,
|
||||
Width: w,
|
||||
Height: h,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
}
|
||||
res, err := be.Generate(ctx, req)
|
||||
if err != nil {
|
||||
return compareResult{Backend: name}, err
|
||||
}
|
||||
defer res.ImageReader.Close()
|
||||
|
||||
imgBytes, err := io.ReadAll(res.ImageReader)
|
||||
if err != nil {
|
||||
return compareResult{Backend: name}, fmt.Errorf("read image: %w", err)
|
||||
}
|
||||
|
||||
imgPath := filepath.Join(runDir, output.Slug(rawPrompt)+"--"+output.Slug(name)+"."+extFromMime(res.MimeType))
|
||||
if err := os.WriteFile(imgPath, imgBytes, 0o644); err != nil {
|
||||
return compareResult{Backend: name}, fmt.Errorf("write %s: %w", imgPath, err)
|
||||
}
|
||||
|
||||
cr := compareResult{
|
||||
Backend: name,
|
||||
ImagePath: imgPath,
|
||||
Seed: seedFromMetadata(res.Metadata, seed),
|
||||
LatencyMs: metaInt64(res.Metadata, "latency_ms"),
|
||||
Model: metaString(res.Metadata, "model"),
|
||||
Metadata: res.Metadata,
|
||||
}
|
||||
if v, ok := res.Metadata["vram_used_mib"].(int64); ok {
|
||||
cr.VRAMUsedMiB = v
|
||||
}
|
||||
return cr, nil
|
||||
}
|
||||
|
||||
func successfulResults(rs []compareResult) []compareResult {
|
||||
out := make([]compareResult, 0, len(rs))
|
||||
for _, r := range rs {
|
||||
if r.Error == "" && r.ImagePath != "" {
|
||||
out = append(out, r)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func writeCompareSidecar(path, rawPrompt, style, negative string, w, h int, seed int64, steps int, results []compareResult) error {
|
||||
body := map[string]any{
|
||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||
"prompt": rawPrompt,
|
||||
"style": style,
|
||||
"negative": negative,
|
||||
"width": w,
|
||||
"height": h,
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"results": results,
|
||||
"backends": backendNames(results),
|
||||
"successful": len(successfulResults(results)),
|
||||
"total": len(results),
|
||||
}
|
||||
data, err := json.MarshalIndent(body, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sidecar: %w", err)
|
||||
}
|
||||
return os.WriteFile(path, append(data, '\n'), 0o644)
|
||||
}
|
||||
|
||||
func backendNames(rs []compareResult) []string {
|
||||
out := make([]string, len(rs))
|
||||
for i, r := range rs {
|
||||
out[i] = r.Backend
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// writeContactSheet stitches a grid of (image, label) cells into one PNG.
|
||||
// Cells are sized to fit in a target width of ~2400px while keeping each
|
||||
// individual image full-resolution (no downscale) up to the column limit;
|
||||
// past that, images sit at their native size and we just lay them out.
|
||||
//
|
||||
// The grid is a simple horizontal row when N <= 4; otherwise N/2 rows of 2.
|
||||
// This is a contact sheet, not a fancy gallery — readability for side-by-
|
||||
// side eyeballing is the goal.
|
||||
func writeContactSheet(path, prompt string, results []compareResult) error {
|
||||
if len(results) == 0 {
|
||||
return fmt.Errorf("no successful results to lay out")
|
||||
}
|
||||
cells := make([]contactCell, 0, len(results))
|
||||
for _, r := range results {
|
||||
img, err := readPNG(r.ImagePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", r.ImagePath, err)
|
||||
}
|
||||
cells = append(cells, contactCell{
|
||||
Image: img,
|
||||
Label: r.Backend,
|
||||
SubLabel: fmt.Sprintf("%dms · seed %d", r.LatencyMs, r.Seed),
|
||||
})
|
||||
}
|
||||
|
||||
cols := len(cells)
|
||||
if cols > 4 {
|
||||
cols = 2
|
||||
}
|
||||
rows := (len(cells) + cols - 1) / cols
|
||||
|
||||
const labelH = 64
|
||||
const pad = 16
|
||||
|
||||
cellW := cells[0].Image.Bounds().Dx()
|
||||
cellH := cells[0].Image.Bounds().Dy()
|
||||
for _, c := range cells {
|
||||
if w := c.Image.Bounds().Dx(); w > cellW {
|
||||
cellW = w
|
||||
}
|
||||
if h := c.Image.Bounds().Dy(); h > cellH {
|
||||
cellH = h
|
||||
}
|
||||
}
|
||||
|
||||
totalW := cols*cellW + (cols+1)*pad
|
||||
totalH := rows*(cellH+labelH) + (rows+1)*pad + 48 // header band
|
||||
|
||||
canvas := image.NewRGBA(image.Rect(0, 0, totalW, totalH))
|
||||
draw.Draw(canvas, canvas.Bounds(), &image.Uniform{C: color.RGBA{R: 30, G: 30, B: 35, A: 255}}, image.Point{}, draw.Src)
|
||||
|
||||
// Header: show the truncated prompt.
|
||||
headerText := "imagen compare — " + truncate(prompt, 100)
|
||||
drawText(canvas, headerText, pad, 30, color.RGBA{R: 240, G: 240, B: 245, A: 255})
|
||||
|
||||
for i, c := range cells {
|
||||
col := i % cols
|
||||
row := i / cols
|
||||
x0 := pad + col*(cellW+pad)
|
||||
y0 := 48 + pad + row*(cellH+labelH+pad)
|
||||
// Center the image inside the cell when smaller than the max cell size.
|
||||
iw := c.Image.Bounds().Dx()
|
||||
ih := c.Image.Bounds().Dy()
|
||||
offX := (cellW - iw) / 2
|
||||
offY := (cellH - ih) / 2
|
||||
dstRect := image.Rect(x0+offX, y0+offY, x0+offX+iw, y0+offY+ih)
|
||||
draw.Draw(canvas, dstRect, c.Image, c.Image.Bounds().Min, draw.Src)
|
||||
|
||||
// Label band underneath.
|
||||
labelY := y0 + cellH + 20
|
||||
drawText(canvas, c.Label, x0+8, labelY, color.RGBA{R: 250, G: 250, B: 250, A: 255})
|
||||
drawText(canvas, c.SubLabel, x0+8, labelY+22, color.RGBA{R: 180, G: 180, B: 190, A: 255})
|
||||
}
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create %s: %w", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
return png.Encode(f, canvas)
|
||||
}
|
||||
|
||||
// contactCell is one tile of the contact sheet: the decoded image plus the
// two text lines drawn in the label band beneath it.
type contactCell struct {
	// Image is the decoded picture, drawn at its native size.
	Image image.Image
	// Label is the first (brighter) label line.
	Label string
	// SubLabel is the second (dimmer) label line.
	SubLabel string
}
|
||||
|
||||
// readPNG opens the file at path and decodes it into an image.
//
// Decoding goes through image.Decode, so it handles whichever formats have
// been registered with the image package by the importing program. Open and
// decode errors are returned unwrapped.
func readPNG(path string) (image.Image, error) {
	src, openErr := os.Open(path)
	if openErr != nil {
		return nil, openErr
	}
	defer src.Close()

	decoded, _, decodeErr := image.Decode(src)
	return decoded, decodeErr
}
|
||||
|
||||
func drawText(dst *image.RGBA, s string, x, y int, c color.Color) {
|
||||
drawer := &font.Drawer{
|
||||
Dst: dst,
|
||||
Src: &image.Uniform{C: c},
|
||||
Face: basicfont.Face7x13,
|
||||
Dot: fixed.Point26_6{X: fixed.I(x), Y: fixed.I(y)},
|
||||
}
|
||||
drawer.DrawString(s)
|
||||
}
|
||||
|
||||
// truncate shortens s to at most max characters (runes). Strings that
// already fit are returned unchanged; longer ones keep their first max-1
// characters and get a trailing "…".
//
// Counting runes rather than bytes guarantees a multi-byte character is never
// cut in half, so the result is valid UTF-8 whenever s is. A non-positive
// max yields the empty string.
func truncate(s string, max int) string {
	if max <= 0 {
		return ""
	}
	// Fast path: the byte length is an upper bound on the rune count.
	if len(s) <= max {
		return s
	}
	runes := []rune(s)
	if len(runes) <= max {
		return s
	}
	return string(runes[:max-1]) + "…"
}
|
||||
|
||||
// splitCSV breaks s on commas, trims surrounding whitespace from each piece,
// and drops pieces that end up empty. The result is never nil, so an empty
// or all-separator input yields a zero-length slice.
func splitCSV(s string) []string {
	fields := strings.Split(s, ",")
	kept := make([]string, 0, len(fields))
	for i := range fields {
		if item := strings.TrimSpace(fields[i]); item != "" {
			kept = append(kept, item)
		}
	}
	return kept
}
|
||||
|
||||
// metaInt64 looks up key in m and converts the value to int64. It accepts
// int64, int, and float64 (truncated toward zero); a missing key or a value
// of any other type yields 0.
func metaInt64(m map[string]any, key string) int64 {
	// Indexing a missing key gives a nil interface, which falls to default.
	switch n := m[key].(type) {
	case float64:
		return int64(n)
	case int:
		return int64(n)
	case int64:
		return n
	default:
		return 0
	}
}
|
||||
203
cmd/imagen/compare_test.go
Normal file
203
cmd/imagen/compare_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// runCompareWithEnv runs the compare subcommand in a writable tmpdir, with
|
||||
// XDG_CONFIG_HOME pointing somewhere empty so no host imagen.yaml leaks in.
|
||||
func runCompareWithEnv(t *testing.T, args []string) (stderr, stdout *bytes.Buffer, runDir string, err error) {
|
||||
t.Helper()
|
||||
tmp := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", filepath.Join(tmp, "no-config"))
|
||||
t.Setenv("HOME", tmp)
|
||||
|
||||
out := filepath.Join(tmp, "compare")
|
||||
// stdlib flag parsing requires flags after the leading positional. Append
|
||||
// --output at the end so any caller-supplied flags still parse cleanly.
|
||||
args = append(args, "--output", out)
|
||||
|
||||
// Capture stdout/stderr via os pipes since runCompare writes directly.
|
||||
oldStdout := os.Stdout
|
||||
oldStderr := os.Stderr
|
||||
rOut, wOut, _ := os.Pipe()
|
||||
rErr, wErr, _ := os.Pipe()
|
||||
os.Stdout = wOut
|
||||
os.Stderr = wErr
|
||||
defer func() {
|
||||
os.Stdout = oldStdout
|
||||
os.Stderr = oldStderr
|
||||
}()
|
||||
|
||||
cmdErr := runCompare(context.Background(), args)
|
||||
|
||||
_ = wOut.Close()
|
||||
_ = wErr.Close()
|
||||
stdout = &bytes.Buffer{}
|
||||
stderr = &bytes.Buffer{}
|
||||
_, _ = stdout.ReadFrom(rOut)
|
||||
_, _ = stderr.ReadFrom(rErr)
|
||||
|
||||
entries, _ := os.ReadDir(out)
|
||||
if len(entries) == 1 {
|
||||
runDir = filepath.Join(out, entries[0].Name())
|
||||
}
|
||||
return stderr, stdout, runDir, cmdErr
|
||||
}
|
||||
|
||||
func TestCompareHappyPathWithMockBackends(t *testing.T) {
|
||||
// Two mock instances stand in for two different backends. mock ignores
|
||||
// cfg so we can reuse the registered type as the instance name and skip
|
||||
// writing imagen.yaml entirely.
|
||||
stderr, stdout, runDir, err := runCompareWithEnv(t, []string{
|
||||
"a cat in a fishbowl",
|
||||
"--models", "mock,mock",
|
||||
"--size", "64x64",
|
||||
"--seed", "42",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runCompare: %v\nstderr: %s", err, stderr.String())
|
||||
}
|
||||
|
||||
if runDir == "" {
|
||||
t.Fatal("expected a run directory under --output")
|
||||
}
|
||||
// Sidecar JSON
|
||||
sidecar := filepath.Join(runDir, "compare.json")
|
||||
data, err := os.ReadFile(sidecar)
|
||||
if err != nil {
|
||||
t.Fatalf("read sidecar: %v", err)
|
||||
}
|
||||
var body struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Successful int `json:"successful"`
|
||||
Total int `json:"total"`
|
||||
Results []struct {
|
||||
Backend string `json:"backend"`
|
||||
ImagePath string `json:"image_path"`
|
||||
Error string `json:"error"`
|
||||
} `json:"results"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &body); err != nil {
|
||||
t.Fatalf("parse sidecar: %v\n%s", err, data)
|
||||
}
|
||||
if body.Prompt != "a cat in a fishbowl" {
|
||||
t.Errorf("prompt = %q", body.Prompt)
|
||||
}
|
||||
if body.Total != 2 || body.Successful != 2 {
|
||||
t.Errorf("counts = %d successful / %d total", body.Successful, body.Total)
|
||||
}
|
||||
for _, r := range body.Results {
|
||||
if r.Error != "" {
|
||||
t.Errorf("backend %s errored: %s", r.Backend, r.Error)
|
||||
}
|
||||
if _, err := os.Stat(r.ImagePath); err != nil {
|
||||
t.Errorf("image not on disk for %s: %v", r.Backend, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Contact sheet path was printed on stdout.
|
||||
sheet := strings.TrimSpace(stdout.String())
|
||||
if sheet == "" {
|
||||
t.Fatal("expected contact sheet path on stdout")
|
||||
}
|
||||
f, err := os.Open(sheet)
|
||||
if err != nil {
|
||||
t.Fatalf("open contact sheet: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
img, err := png.Decode(f)
|
||||
if err != nil {
|
||||
t.Fatalf("decode contact sheet PNG: %v", err)
|
||||
}
|
||||
if w := img.Bounds().Dx(); w < 100 {
|
||||
t.Errorf("contact sheet looks empty (width %d)", w)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareSkipContactSheet(t *testing.T) {
|
||||
stderr, stdout, runDir, err := runCompareWithEnv(t, []string{
|
||||
"x",
|
||||
"--models", "mock",
|
||||
"--size", "32x32",
|
||||
"--seed", "1",
|
||||
"--no-contact-sheet",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runCompare: %v\nstderr: %s", err, stderr.String())
|
||||
}
|
||||
if got := strings.TrimSpace(stdout.String()); got != "" {
|
||||
t.Errorf("expected no stdout output (no contact sheet), got %q", got)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(runDir, "contact-sheet.png")); err == nil {
|
||||
t.Errorf("contact-sheet.png should not exist with --no-contact-sheet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareRecordsBackendErrors(t *testing.T) {
|
||||
// One real (mock) + one unknown. Unknown should fail but not abort the
|
||||
// run — sidecar records both, contact sheet built from successes only.
|
||||
stderr, _, runDir, err := runCompareWithEnv(t, []string{
|
||||
"y",
|
||||
"--models", "mock,this-instance-does-not-exist",
|
||||
"--size", "32x32",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runCompare: %v\nstderr: %s", err, stderr.String())
|
||||
}
|
||||
sidecar := filepath.Join(runDir, "compare.json")
|
||||
data, _ := os.ReadFile(sidecar)
|
||||
var body struct {
|
||||
Successful int `json:"successful"`
|
||||
Total int `json:"total"`
|
||||
Results []struct {
|
||||
Backend string `json:"backend"`
|
||||
Error string `json:"error"`
|
||||
} `json:"results"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &body); err != nil {
|
||||
t.Fatalf("parse sidecar: %v", err)
|
||||
}
|
||||
if body.Total != 2 {
|
||||
t.Errorf("expected 2 results, got %d", body.Total)
|
||||
}
|
||||
if body.Successful != 1 {
|
||||
t.Errorf("expected 1 success, got %d", body.Successful)
|
||||
}
|
||||
var sawError bool
|
||||
for _, r := range body.Results {
|
||||
if r.Backend == "this-instance-does-not-exist" && r.Error != "" {
|
||||
sawError = true
|
||||
}
|
||||
}
|
||||
if !sawError {
|
||||
t.Errorf("expected an error recorded for the unknown backend")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareNoModelsFails(t *testing.T) {
|
||||
_, _, _, err := runCompareWithEnv(t, []string{"x"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when --models is empty")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--models") {
|
||||
t.Errorf("error should mention --models, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareNoPromptFails(t *testing.T) {
|
||||
_, _, _, err := runCompareWithEnv(t, []string{"--models", "mock"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when prompt is missing")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing prompt") {
|
||||
t.Errorf("error should mention missing prompt, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -39,6 +39,18 @@ func runConfig(args []string) error {
|
||||
}
|
||||
fmt.Fprintf(os.Stdout, "OK — %d backend(s) defined, default=%q\n",
|
||||
len(cfg.Backends), cfg.DefaultBackend)
|
||||
// Soft warnings — surfaced on stderr so they're visible but don't
|
||||
// fail the validate exit code.
|
||||
cloudMode := cfg.Output.CloudSync
|
||||
if cloudMode == "" {
|
||||
cloudMode = "auto"
|
||||
}
|
||||
if cloudMode != "off" && cfg.OwnerUserID == "" {
|
||||
fmt.Fprintln(os.Stderr,
|
||||
"warning: cloud_sync is "+cloudMode+" but owner_user_id is empty — DB inserts will be skipped.")
|
||||
fmt.Fprintln(os.Stderr,
|
||||
" look it up: SELECT id FROM auth.users WHERE email = '<your-supabase-email>';")
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return userErr("unknown config subcommand %q (init|validate|path)", args[0])
|
||||
|
||||
@@ -2,30 +2,39 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/backend"
|
||||
"mgit.msbls.de/m/ImaGen/internal/cloud"
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
"mgit.msbls.de/m/ImaGen/internal/output"
|
||||
"mgit.msbls.de/m/ImaGen/internal/preview"
|
||||
"mgit.msbls.de/m/ImaGen/internal/prompt"
|
||||
"mgit.msbls.de/m/ImaGen/internal/usage"
|
||||
)
|
||||
|
||||
func runGenerate(ctx context.Context, args []string) error {
|
||||
fs := flag.NewFlagSet("generate", flag.ContinueOnError)
|
||||
var (
|
||||
backendName string
|
||||
size string
|
||||
outPath string
|
||||
seed int64
|
||||
steps int
|
||||
style string
|
||||
negative string
|
||||
configPath string
|
||||
noSidecar bool
|
||||
backendName string
|
||||
size string
|
||||
outPath string
|
||||
seed int64
|
||||
steps int
|
||||
style string
|
||||
negative string
|
||||
configPath string
|
||||
noSidecar bool
|
||||
previewOn bool
|
||||
previewOff bool
|
||||
noCloud bool
|
||||
)
|
||||
fs.StringVar(&backendName, "backend", "", "backend instance name (default: config.default_backend)")
|
||||
fs.StringVar(&size, "size", "1024x1024", "WxH, e.g. 1024x1024")
|
||||
@@ -36,6 +45,9 @@ func runGenerate(ctx context.Context, args []string) error {
|
||||
fs.StringVar(&negative, "negative", "", "negative prompt (ignored by backends that don't support it)")
|
||||
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
|
||||
fs.BoolVar(&noSidecar, "no-sidecar", false, "skip the JSON sidecar even if config enables it")
|
||||
fs.BoolVar(&previewOn, "preview", false, "force tmux preview window on (errors outside $TMUX)")
|
||||
fs.BoolVar(&previewOff, "no-preview", false, "skip the tmux preview window")
|
||||
fs.BoolVar(&noCloud, "no-cloud", false, "skip Supabase upload + imagen.images insert for this generation")
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintln(fs.Output(), `Usage: imagen generate "<prompt>" [flags]`)
|
||||
fs.PrintDefaults()
|
||||
@@ -76,6 +88,7 @@ func runGenerate(ctx context.Context, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
attachUsageSink(be)
|
||||
|
||||
finalPrompt, err := prompt.Apply(rawPrompt, style)
|
||||
if err != nil {
|
||||
@@ -118,9 +131,249 @@ func runGenerate(ctx context.Context, args []string) error {
|
||||
if paths.SidecarPath != "" {
|
||||
fmt.Fprintln(os.Stderr, "sidecar:", paths.SidecarPath)
|
||||
}
|
||||
|
||||
if result, err := maybeCloudSync(ctx, cfg, noCloud, "", "", paths, in, res, w, h); err != nil {
|
||||
// cloud-sync failures are warnings — the image already wrote.
|
||||
fmt.Fprintln(os.Stderr, "imagen: cloud sync:", err)
|
||||
} else if result != nil && result.ImageID != "" {
|
||||
fmt.Fprintf(os.Stderr, "cloud: imagen.images.id=%s storage_path=%s\n", result.ImageID, result.StoragePath)
|
||||
}
|
||||
|
||||
if err := maybePreview(cfg, previewOn, previewOff, paths.ImagePath, rawPrompt); err != nil {
|
||||
// preview failures are warnings — the image already wrote.
|
||||
fmt.Fprintln(os.Stderr, "imagen: preview:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveCloudSyncMode applies the precedence chain config -> env -> flag.
|
||||
// Flags win, env beats config, config beats the implicit auto default.
|
||||
// Mirrors resolvePreviewMode shape.
|
||||
func resolveCloudSyncMode(cfg *config.Config, noCloudFlag bool, env string) (string, error) {
|
||||
mode := "auto"
|
||||
if cfg != nil && cfg.Output.CloudSync != "" {
|
||||
mode = cfg.Output.CloudSync
|
||||
}
|
||||
if env != "" {
|
||||
switch env {
|
||||
case "auto", "on", "off":
|
||||
mode = env
|
||||
default:
|
||||
return "", fmt.Errorf("$IMAGEN_CLOUD_SYNC = %q (must be auto|on|off)", env)
|
||||
}
|
||||
}
|
||||
if noCloudFlag {
|
||||
mode = "off"
|
||||
}
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
// maybeCloudSync resolves the effective mode and, if it says yes, uploads
|
||||
// the PNG and inserts the row. Returns the SyncResult on success so callers
|
||||
// that need the imagen.images.id (e.g. the worker linking a job row) can pick
|
||||
// it up. ownerOverride, when non-empty, wins over config + env — the worker
|
||||
// passes the job row's owner_user_id so each job is attributed correctly.
|
||||
// seriesID, when non-empty, lands on imagen.images.series_id so the
|
||||
// list-page query (`WHERE series_id IS NULL`) hides series members from
|
||||
// the flat grid; empty means solo run.
|
||||
func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, ownerOverride, seriesID string, paths *output.Outputs, in output.Inputs, res *backend.Result, width, height int) (*cloud.SyncResult, error) {
|
||||
mode, err := resolveCloudSyncMode(cfg, noCloud, os.Getenv("IMAGEN_CLOUD_SYNC"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if mode == "off" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sink, ok := cloud.NewFromEnv()
|
||||
if !ok {
|
||||
if mode == "on" {
|
||||
return nil, fmt.Errorf("cloud_sync=on but SUPABASE_URL / SUPABASE_SERVICE_KEY not set in env")
|
||||
}
|
||||
// auto + missing env = silent skip.
|
||||
return nil, nil
|
||||
}
|
||||
switch {
|
||||
case ownerOverride != "":
|
||||
sink.OwnerUserID = ownerOverride
|
||||
case cfg != nil && cfg.OwnerUserID != "":
|
||||
// Config-supplied owner_user_id takes precedence over $IMAGEN_OWNER_USER_ID.
|
||||
sink.OwnerUserID = cfg.OwnerUserID
|
||||
}
|
||||
if sink.OwnerUserID == "" {
|
||||
if mode == "on" {
|
||||
return nil, fmt.Errorf("cloud_sync=on but owner_user_id not set in config and $IMAGEN_OWNER_USER_ID is empty")
|
||||
}
|
||||
// auto + missing UUID = silent skip.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
pngBytes, readErr := os.ReadFile(paths.ImagePath)
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("read local image: %w", readErr)
|
||||
}
|
||||
|
||||
// Reuse the writer's date/slug/seed so storage_path mirrors the local
|
||||
// filename's prefix exactly — viewers can join `imagen.images` on
|
||||
// either side without timezone drift.
|
||||
date := paths.Date
|
||||
slug := paths.Slug
|
||||
if date == "" || slug == "" {
|
||||
now := time.Now()
|
||||
date = now.Format("2006-01-02")
|
||||
slug = output.Slug(in.Prompt)
|
||||
}
|
||||
ext := in.Ext
|
||||
if ext == "" {
|
||||
ext = strings.TrimPrefix(filepath.Ext(paths.ImagePath), ".")
|
||||
}
|
||||
if ext == "" {
|
||||
ext = "png"
|
||||
}
|
||||
|
||||
// Snapshot the sidecar (if it exists) so the row carries the same
|
||||
// metadata view a downstream viewer would see on disk.
|
||||
var sidecar map[string]any
|
||||
if paths.SidecarPath != "" {
|
||||
if scBytes, err := os.ReadFile(paths.SidecarPath); err == nil {
|
||||
_ = json.Unmarshal(scBytes, &sidecar)
|
||||
}
|
||||
}
|
||||
|
||||
model := metaString(res.Metadata, "model")
|
||||
steps := metaInt(res.Metadata, "steps")
|
||||
cost := metaFloatPtr(res.Metadata, "cost_usd_estimate")
|
||||
latency := metaInt(res.Metadata, "latency_ms")
|
||||
|
||||
seed := paths.Seed
|
||||
if seed == 0 {
|
||||
seed = in.Seed
|
||||
}
|
||||
syncReq := cloud.SyncRequest{
|
||||
Date: date,
|
||||
Slug: slug,
|
||||
Seed: seed,
|
||||
Ext: ext,
|
||||
PNG: pngBytes,
|
||||
MimeType: res.MimeType,
|
||||
Prompt: in.Prompt,
|
||||
Backend: in.Backend,
|
||||
Model: model,
|
||||
Steps: steps,
|
||||
Width: width,
|
||||
Height: height,
|
||||
LatencyMs: latency,
|
||||
CostUSDEstimate: cost,
|
||||
Sidecar: sidecar,
|
||||
SeriesID: seriesID,
|
||||
}
|
||||
syncCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
defer cancel()
|
||||
return sink.Sync(syncCtx, syncReq)
|
||||
}
|
||||
|
||||
// metaString looks up key in m and returns the value when it is a string.
// A missing key, a nil map, or a value of any other type yields "".
func metaString(m map[string]any, key string) string {
	// The comma-ok assertion on a missing key (nil interface) simply fails.
	text, _ := m[key].(string)
	return text
}
|
||||
|
||||
// metaInt reads key from m as an int. Values stored as int, int64 or float64
// are converted (floats are truncated toward zero by the int conversion);
// a missing key or any other value type yields 0.
func metaInt(m map[string]any, key string) int {
	// A missing key produces a nil interface, which falls through to default.
	switch num := m[key].(type) {
	case int:
		return num
	case int64:
		return int(num)
	case float64:
		return int(num)
	default:
		return 0
	}
}
|
||||
|
||||
// metaFloatPtr reads key from m as an optional float64. It returns a pointer
// to the converted value for float64, float32, int and int64 entries, and nil
// when the key is missing or holds any other type, so callers can tell
// "absent" apart from an explicit zero.
func metaFloatPtr(m map[string]any, key string) *float64 {
	var value float64
	switch num := m[key].(type) {
	case float64:
		value = num
	case float32:
		value = float64(num)
	case int:
		value = float64(num)
	case int64:
		value = float64(num)
	default:
		return nil
	}
	return &value
}
|
||||
|
||||
// resolvePreviewMode applies the precedence chain config -> env -> flag.
|
||||
// Flags win, env beats config, config beats the implicit auto default.
|
||||
func resolvePreviewMode(cfg *config.Config, flagOn, flagOff bool, env string) (preview.Mode, error) {
|
||||
mode := preview.ModeAuto
|
||||
if cfg != nil && cfg.Output.Preview != "" {
|
||||
m, err := preview.ParseMode(cfg.Output.Preview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("config output.preview: %w", err)
|
||||
}
|
||||
mode = m
|
||||
}
|
||||
if env != "" {
|
||||
m, err := preview.ParseMode(env)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("$IMAGEN_PREVIEW: %w", err)
|
||||
}
|
||||
mode = m
|
||||
}
|
||||
if flagOn && flagOff {
|
||||
return "", userErr("--preview and --no-preview are mutually exclusive")
|
||||
}
|
||||
if flagOn {
|
||||
mode = preview.ModeOn
|
||||
}
|
||||
if flagOff {
|
||||
mode = preview.ModeOff
|
||||
}
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
// maybePreview resolves the effective preview mode and, if it says yes,
|
||||
// spawns a tmux window via tmux-img. Always non-fatal.
|
||||
func maybePreview(cfg *config.Config, flagOn, flagOff bool, imagePath, rawPrompt string) error {
|
||||
mode, err := resolvePreviewMode(cfg, flagOn, flagOff, os.Getenv("IMAGEN_PREVIEW"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decision, err := preview.Resolve(mode, os.Getenv("TMUX") != "", stdoutIsTTY())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !decision.ShouldPreview {
|
||||
return nil
|
||||
}
|
||||
spawner := &preview.Spawner{}
|
||||
return spawner.Spawn(imagePath, output.Slug(rawPrompt))
|
||||
}
|
||||
|
||||
// stdoutIsTTY reports whether os.Stdout is a character device. If stdout
// cannot be stat'ed the answer is false.
func stdoutIsTTY() bool {
	if info, err := os.Stdout.Stat(); err == nil {
		return info.Mode()&os.ModeCharDevice != 0
	}
	return false
}
|
||||
|
||||
// splitLeadingPositional separates the positional args at the start of args
|
||||
// from the rest (which begins with the first flag). A literal "--" terminator
|
||||
// pushes everything after it into the positional list and out of flag parsing.
|
||||
@@ -153,6 +406,21 @@ func parseSize(s string) (int, int, error) {
|
||||
return w, h, nil
|
||||
}
|
||||
|
||||
// attachUsageSink wires a Supabase cost-tracking sink into the backend
|
||||
// when it accepts one and the env is configured. Adapters that record
|
||||
// usage expose a public Sink field of type backend.UsageSink.
|
||||
func attachUsageSink(be backend.Backend) {
|
||||
r, ok := be.(*backend.Replicate)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sink, ok := usage.NewSupabaseSinkFromEnv()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
r.Sink = sink
|
||||
}
|
||||
|
||||
func buildBackend(cfg *config.Config, name string) (backend.Backend, error) {
|
||||
if cfg != nil {
|
||||
spec, ok := cfg.Backends[name]
|
||||
|
||||
87
cmd/imagen/generate_test.go
Normal file
87
cmd/imagen/generate_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
"mgit.msbls.de/m/ImaGen/internal/preview"
|
||||
)
|
||||
|
||||
func TestResolvePreviewMode(t *testing.T) {
|
||||
type tc struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
flagOn bool
|
||||
flagOff bool
|
||||
env string
|
||||
want preview.Mode
|
||||
wantError bool
|
||||
}
|
||||
cases := []tc{
|
||||
{name: "all-empty-defaults-to-auto", want: preview.ModeAuto},
|
||||
{name: "config-on", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, want: preview.ModeOn},
|
||||
{name: "config-off", cfg: &config.Config{Output: config.OutputConfig{Preview: "off"}}, want: preview.ModeOff},
|
||||
{name: "config-auto-explicit", cfg: &config.Config{Output: config.OutputConfig{Preview: "auto"}}, want: preview.ModeAuto},
|
||||
{name: "env-overrides-config", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, env: "off", want: preview.ModeOff},
|
||||
{name: "flag-on-overrides-env-off", env: "off", flagOn: true, want: preview.ModeOn},
|
||||
{name: "flag-off-overrides-env-on", env: "on", flagOff: true, want: preview.ModeOff},
|
||||
{name: "flag-off-overrides-config-on", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, flagOff: true, want: preview.ModeOff},
|
||||
{name: "both-flags-error", flagOn: true, flagOff: true, wantError: true},
|
||||
{name: "bad-env-errors", env: "yes", wantError: true},
|
||||
{name: "bad-config-errors", cfg: &config.Config{Output: config.OutputConfig{Preview: "yes"}}, wantError: true},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
got, err := resolvePreviewMode(c.cfg, c.flagOn, c.flagOff, c.env)
|
||||
if c.wantError {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got mode %q", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != c.want {
|
||||
t.Errorf("mode = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudSyncMode(t *testing.T) {
|
||||
type tc struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
noCloud bool
|
||||
env string
|
||||
want string
|
||||
wantError bool
|
||||
}
|
||||
cases := []tc{
|
||||
{name: "all-empty-defaults-to-auto", want: "auto"},
|
||||
{name: "config-on", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, want: "on"},
|
||||
{name: "config-off", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "off"}}, want: "off"},
|
||||
{name: "env-overrides-config", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, env: "off", want: "off"},
|
||||
{name: "flag-overrides-env-and-config", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, env: "on", noCloud: true, want: "off"},
|
||||
{name: "flag-overrides-config-on", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, noCloud: true, want: "off"},
|
||||
{name: "bad-env-errors", env: "yes", wantError: true},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
got, err := resolveCloudSyncMode(c.cfg, c.noCloud, c.env)
|
||||
if c.wantError {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got mode %q", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != c.want {
|
||||
t.Errorf("mode = %q, want %q", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,14 +14,18 @@ import (
|
||||
_ "mgit.msbls.de/m/ImaGen/internal/backend"
|
||||
)
|
||||
|
||||
const usage = `imagen — model-agnostic image generation
|
||||
const helpText = `imagen — model-agnostic image generation
|
||||
|
||||
Usage:
|
||||
imagen generate <prompt> [flags] generate one image
|
||||
imagen compare <prompt> --models a,b,c [flags]
|
||||
run one prompt across N backends + contact sheet
|
||||
imagen worker [flags] consume the imagen.jobs queue (daemon)
|
||||
imagen backends list registered backend types
|
||||
imagen config init print a sample imagen.yaml on stdout
|
||||
imagen config validate validate the active config
|
||||
imagen serve [--addr :8080] (stub) start the HTTP server
|
||||
imagen usage [--since DATE] show cost-tracking rows
|
||||
imagen version print version
|
||||
imagen help show this help
|
||||
|
||||
@@ -33,7 +37,7 @@ var Version = "dev"
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprint(os.Stderr, usage)
|
||||
fmt.Fprint(os.Stderr, helpText)
|
||||
os.Exit(2)
|
||||
}
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
@@ -44,18 +48,24 @@ func main() {
|
||||
switch os.Args[1] {
|
||||
case "generate":
|
||||
err = runGenerate(ctx, args)
|
||||
case "compare":
|
||||
err = runCompare(ctx, args)
|
||||
case "worker":
|
||||
err = runWorker(ctx, args)
|
||||
case "backends":
|
||||
err = runBackends(args)
|
||||
case "config":
|
||||
err = runConfig(args)
|
||||
case "serve":
|
||||
err = runServe(args)
|
||||
case "usage":
|
||||
err = runUsage(ctx, args)
|
||||
case "version", "-v", "--version":
|
||||
fmt.Println(Version)
|
||||
case "help", "-h", "--help":
|
||||
fmt.Print(usage)
|
||||
fmt.Print(helpText)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], usage)
|
||||
fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], helpText)
|
||||
os.Exit(2)
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
189
cmd/imagen/usage.go
Normal file
189
cmd/imagen/usage.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/usage"
|
||||
)
|
||||
|
||||
// runUsage handles `imagen usage [--since DATE]`. Reads mai.imagen_usage
|
||||
// via Supabase REST and prints a tab-aligned table grouped by week +
|
||||
// backend + model + caller, with totals at the bottom.
|
||||
func runUsage(ctx context.Context, args []string) error {
|
||||
fs := flag.NewFlagSet("usage", flag.ContinueOnError)
|
||||
var (
|
||||
since string
|
||||
raw bool
|
||||
)
|
||||
fs.StringVar(&since, "since", "", "ISO date (YYYY-MM-DD) — only rows on/after this UTC date")
|
||||
fs.BoolVar(&raw, "raw", false, "print one line per row instead of grouped")
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintln(fs.Output(), "Usage: imagen usage [--since YYYY-MM-DD] [--raw]")
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sinceT time.Time
|
||||
if since != "" {
|
||||
t, err := time.Parse("2006-01-02", since)
|
||||
if err != nil {
|
||||
return userErr("--since must be YYYY-MM-DD: %v", err)
|
||||
}
|
||||
sinceT = t
|
||||
}
|
||||
|
||||
sink, ok := usage.NewSupabaseSinkFromEnv()
|
||||
if !ok {
|
||||
return userErr("SUPABASE_URL and SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) must be set to read mai.imagen_usage")
|
||||
}
|
||||
rows, err := sink.Query(ctx, sinceT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if raw {
|
||||
printRawRows(rows)
|
||||
return nil
|
||||
}
|
||||
printGroupedRows(rows)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printRawRows(rows []usage.Row) {
|
||||
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(tw, "TIME\tBACKEND\tMODEL\tCALLER\tLATENCY_MS\tCOST_USD")
|
||||
var totalCost float64
|
||||
for _, r := range rows {
|
||||
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
r.CreatedAt.Local().Format("2006-01-02 15:04"),
|
||||
r.Backend,
|
||||
r.Model,
|
||||
derefString(r.Caller),
|
||||
intOrDash(r.LatencyMs),
|
||||
costOrDash(r.CostUSDEstimate),
|
||||
)
|
||||
if r.CostUSDEstimate != nil {
|
||||
totalCost += *r.CostUSDEstimate
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(tw, "\t\t\t\t%d rows\t%.4f USD\n", len(rows), totalCost)
|
||||
_ = tw.Flush()
|
||||
}
|
||||
|
||||
// group is one aggregated line of the grouped usage report: the identifying
// week/backend/model/caller plus the running totals for the rows that share
// them.
type group struct {
	// Identity of the bucket.
	week, backend, model, caller string

	count   int     // number of rows folded into this bucket
	cost    float64 // sum of the cost estimates that were present
	costSet bool    // true once at least one row carried a cost estimate
}
|
||||
|
||||
// groupKey is the comparable map key used to bucket usage rows: rows with the
// same week, backend, model and caller land in the same group.
type groupKey struct {
	week    string
	backend string
	model   string
	caller  string
}
|
||||
|
||||
func printGroupedRows(rows []usage.Row) {
|
||||
groups := map[groupKey]*group{}
|
||||
for _, r := range rows {
|
||||
caller := derefString(r.Caller)
|
||||
k := groupKey{
|
||||
week: weekStart(r.CreatedAt).Format("2006-01-02"),
|
||||
backend: r.Backend,
|
||||
model: r.Model,
|
||||
caller: caller,
|
||||
}
|
||||
g, ok := groups[k]
|
||||
if !ok {
|
||||
g = &group{week: k.week, backend: r.Backend, model: r.Model, caller: caller}
|
||||
groups[k] = g
|
||||
}
|
||||
g.count++
|
||||
if r.CostUSDEstimate != nil {
|
||||
g.cost += *r.CostUSDEstimate
|
||||
g.costSet = true
|
||||
}
|
||||
}
|
||||
|
||||
keys := make([]groupKey, 0, len(groups))
|
||||
for k := range groups {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
if keys[i].week != keys[j].week {
|
||||
return keys[i].week > keys[j].week // newest first
|
||||
}
|
||||
if keys[i].backend != keys[j].backend {
|
||||
return keys[i].backend < keys[j].backend
|
||||
}
|
||||
if keys[i].model != keys[j].model {
|
||||
return keys[i].model < keys[j].model
|
||||
}
|
||||
return keys[i].caller < keys[j].caller
|
||||
})
|
||||
|
||||
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(tw, "WEEK_OF\tBACKEND\tMODEL\tCALLER\tCOUNT\tCOST_USD")
|
||||
var totalCount int
|
||||
var totalCost float64
|
||||
for _, k := range keys {
|
||||
g := groups[k]
|
||||
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%s\n",
|
||||
g.week, g.backend, g.model, g.caller, g.count, costStr(g.cost, g.costSet),
|
||||
)
|
||||
totalCount += g.count
|
||||
totalCost += g.cost
|
||||
}
|
||||
fmt.Fprintf(tw, "\t\t\tTOTAL\t%d\t%.4f USD\n", totalCount, totalCost)
|
||||
_ = tw.Flush()
|
||||
}
|
||||
|
||||
// weekStart returns midnight UTC on the Monday of the week that contains t.
// The input is converted to UTC first, so the week boundary is a UTC one
// regardless of t's location.
func weekStart(t time.Time) time.Time {
	utc := t.UTC()
	// Days elapsed since Monday: Mon=0 … Sat=5, Sun=6.
	sinceMonday := (int(utc.Weekday()) + 6) % 7
	year, month, day := utc.Date()
	midnight := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
	return midnight.Add(-time.Duration(sinceMonday) * 24 * time.Hour)
}
|
||||
|
||||
// derefString returns the string s points to, or "" when s is nil.
func derefString(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}
|
||||
|
||||
// intOrDash formats the integer p points to in decimal, or returns "-" when
// p is nil.
func intOrDash(p *int) string {
	if p != nil {
		return fmt.Sprint(*p)
	}
	return "-"
}
|
||||
|
||||
// costOrDash formats the cost p points to with four decimal places, or
// returns "-" when p is nil.
func costOrDash(p *float64) string {
	if p != nil {
		return fmt.Sprintf("%.4f", *p)
	}
	return "-"
}
|
||||
|
||||
// costStr renders an aggregated cost with four decimal places. When set is
// false (no row in the bucket carried a cost) it returns "-" instead, so an
// unknown cost is not displayed as 0.0000.
func costStr(v float64, set bool) string {
	if set {
		return strings.TrimSpace(fmt.Sprintf("%.4f", v))
	}
	return "-"
}
|
||||
292
cmd/imagen/worker.go
Normal file
292
cmd/imagen/worker.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/backend"
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
"mgit.msbls.de/m/ImaGen/internal/output"
|
||||
"mgit.msbls.de/m/ImaGen/internal/prompt"
|
||||
"mgit.msbls.de/m/ImaGen/internal/worker"
|
||||
)
|
||||
|
||||
// runWorker is the `imagen worker` subcommand: a long-running daemon that
|
||||
// consumes the imagen.jobs queue and writes results into imagen.images via
|
||||
// the same cloud-sync path generate uses.
|
||||
func runWorker(ctx context.Context, args []string) error {
|
||||
fs := flag.NewFlagSet("worker", flag.ContinueOnError)
|
||||
var (
|
||||
configPath string
|
||||
pollInterval time.Duration
|
||||
jobTimeout time.Duration
|
||||
)
|
||||
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
|
||||
fs.DurationVar(&pollInterval, "poll-interval", 5*time.Second, "safety-poll cadence between LISTEN wakeups")
|
||||
fs.DurationVar(&jobTimeout, "job-timeout", 5*time.Minute, "max wall-time per job before the worker marks it failed")
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintln(fs.Output(), `Usage: imagen worker [flags]
|
||||
|
||||
Long-running daemon. LISTENs on the Postgres 'imagen_jobs' channel and polls
|
||||
imagen.jobs every --poll-interval as a safety net, claims pending rows, runs
|
||||
the generation pipeline, then updates the row with status + image_id.
|
||||
|
||||
Env:
|
||||
IMAGEN_WORKER_DATABASE_URL Postgres DSN for direct LISTEN + UPDATE.
|
||||
Required (PostgREST cannot LISTEN).
|
||||
SUPABASE_URL, SUPABASE_SERVICE_KEY, IMAGEN_OWNER_USER_ID
|
||||
Reused from generate's cloud-sync path; the
|
||||
worker writes imagen.images rows through the
|
||||
same code path. Per-job owner_user_id from the
|
||||
job row overrides IMAGEN_OWNER_USER_ID.`)
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg, cfgErr := config.Load(configPath)
|
||||
if cfgErr != nil && !os.IsNotExist(cfgErr) {
|
||||
return cfgErr
|
||||
}
|
||||
|
||||
dsn := os.Getenv("IMAGEN_WORKER_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
return userErr("IMAGEN_WORKER_DATABASE_URL not set; the worker needs a direct Postgres DSN for LISTEN/NOTIFY")
|
||||
}
|
||||
|
||||
q, err := dialQueue(ctx, dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("queue: %w", err)
|
||||
}
|
||||
defer q.Close()
|
||||
|
||||
p := &workerPipeline{cfg: cfg}
|
||||
w := worker.New(q, p, worker.Config{
|
||||
PollInterval: pollInterval,
|
||||
JobTimeout: jobTimeout,
|
||||
Logger: func(format string, a ...any) { fmt.Fprintf(os.Stderr, format+"\n", a...) },
|
||||
})
|
||||
fmt.Fprintln(os.Stderr, "imagen worker: ready (poll-interval", pollInterval, "job-timeout", jobTimeout, ")")
|
||||
return w.Run(ctx)
|
||||
}
|
||||
|
||||
// pgxQueue is the production Queue. It opens one dedicated connection used
|
||||
// for both LISTEN (long-lived) and UPDATE operations. A second connection
|
||||
// would split state needlessly — a single worker process processes one job
|
||||
// at a time so the connection is never contended.
|
||||
type pgxQueue struct {
|
||||
conn *pgx.Conn
|
||||
}
|
||||
|
||||
func dialQueue(ctx context.Context, dsn string) (*pgxQueue, error) {
|
||||
conn, err := pgx.Connect(ctx, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgx.Connect: %w", err)
|
||||
}
|
||||
if _, err := conn.Exec(ctx, "LISTEN imagen_jobs"); err != nil {
|
||||
conn.Close(ctx)
|
||||
return nil, fmt.Errorf("LISTEN imagen_jobs: %w", err)
|
||||
}
|
||||
return &pgxQueue{conn: conn}, nil
|
||||
}
|
||||
|
||||
func (q *pgxQueue) Close() {
|
||||
if q == nil || q.conn == nil {
|
||||
return
|
||||
}
|
||||
// Best-effort: a 5s budget is enough to send a polite TerminateMessage.
|
||||
shutdown, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = q.conn.Close(shutdown)
|
||||
}
|
||||
|
||||
// ClaimNextPending atomically marks the oldest pending row 'running' and
|
||||
// returns it. FOR UPDATE SKIP LOCKED is belt + braces against a second worker
|
||||
// process — out of scope for v1 but cheap insurance.
|
||||
func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) {
|
||||
// series_id is nullable on imagen.jobs (solo run when NULL); cast to text
|
||||
// with COALESCE so pgx scans into a plain Go string. Empty string =
|
||||
// solo run; the pipeline skips series propagation in that case.
|
||||
const stmt = `
|
||||
UPDATE imagen.jobs
|
||||
SET status='running', started_at=now()
|
||||
WHERE id = (
|
||||
SELECT id FROM imagen.jobs
|
||||
WHERE status='pending'
|
||||
ORDER BY created_at
|
||||
LIMIT 1
|
||||
FOR UPDATE SKIP LOCKED
|
||||
)
|
||||
RETURNING id, owner_user_id, prompt, backend,
|
||||
COALESCE(model,''),
|
||||
COALESCE(width, 0), COALESCE(height, 0),
|
||||
COALESCE(steps, 0), COALESCE(seed, 0),
|
||||
COALESCE(style,''),
|
||||
COALESCE(series_id::text, '')`
|
||||
var j worker.Job
|
||||
err := q.conn.QueryRow(ctx, stmt).Scan(
|
||||
&j.ID, &j.OwnerUserID, &j.Prompt, &j.Backend,
|
||||
&j.Model, &j.Width, &j.Height, &j.Steps, &j.Seed, &j.Style,
|
||||
&j.SeriesID,
|
||||
)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &j, nil
|
||||
}
|
||||
|
||||
func (q *pgxQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
|
||||
_, err := q.conn.Exec(ctx,
|
||||
`UPDATE imagen.jobs SET status='done', image_id=$2, completed_at=now() WHERE id=$1`,
|
||||
jobID, imageID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *pgxQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
|
||||
// Trim outrageously long error text so a 10MB stack-trace doesn't end up
|
||||
// in the row (callers see a summary, full text goes to stderr / logs).
|
||||
const maxLen = 2000
|
||||
if len(msg) > maxLen {
|
||||
msg = msg[:maxLen] + "... [truncated]"
|
||||
}
|
||||
_, err := q.conn.Exec(ctx,
|
||||
`UPDATE imagen.jobs SET status='failed', error=$2, completed_at=now() WHERE id=$1`,
|
||||
jobID, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
// WaitForJob blocks until a NOTIFY arrives on imagen_jobs, the timeout fires,
|
||||
// or ctx is cancelled. Notifications during a previous processJob are queued
|
||||
// by pgx and delivered on the next call — we don't lose wake-ups even when
|
||||
// processing took longer than poll-interval.
|
||||
func (q *pgxQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
_, err := q.conn.WaitForNotification(waitCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil // poll cadence fired
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return context.Canceled
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetStaleRunning bumps any rows stuck in 'running' back to 'pending' so
|
||||
// they get re-claimed. Called once at startup. A row stuck in 'running' came
|
||||
// from a previous worker crash; without this, flexsiebels would poll
|
||||
// forever on a job nobody is processing.
|
||||
func (q *pgxQueue) ResetStaleRunning(ctx context.Context) error {
|
||||
_, err := q.conn.Exec(ctx,
|
||||
`UPDATE imagen.jobs SET status='pending', started_at=NULL WHERE status='running'`)
|
||||
return err
|
||||
}
|
||||
|
||||
// workerPipeline is the Pipeline implementation that drives a single job
|
||||
// through buildBackend → prompt enrichment → generate → write disk →
|
||||
// cloud-sync, then returns the imagen.images.id back to the worker so it
|
||||
// can link the row.
|
||||
type workerPipeline struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func (p *workerPipeline) Run(ctx context.Context, job worker.Job) worker.Outcome {
|
||||
if job.OwnerUserID == "" {
|
||||
return worker.Outcome{Err: fmt.Errorf("job %s: missing owner_user_id", job.ID)}
|
||||
}
|
||||
if job.Prompt == "" {
|
||||
return worker.Outcome{Err: fmt.Errorf("job %s: empty prompt", job.ID)}
|
||||
}
|
||||
if job.Backend == "" {
|
||||
return worker.Outcome{Err: fmt.Errorf("job %s: missing backend", job.ID)}
|
||||
}
|
||||
|
||||
be, err := buildBackend(p.cfg, job.Backend)
|
||||
if err != nil {
|
||||
return worker.Outcome{Err: fmt.Errorf("backend %q: %w", job.Backend, err)}
|
||||
}
|
||||
attachUsageSink(be)
|
||||
|
||||
finalPrompt, err := prompt.Apply(job.Prompt, job.Style)
|
||||
if err != nil {
|
||||
return worker.Outcome{Err: fmt.Errorf("style: %w", err)}
|
||||
}
|
||||
|
||||
req := backend.Request{
|
||||
Prompt: finalPrompt,
|
||||
Width: job.Width,
|
||||
Height: job.Height,
|
||||
Steps: job.Steps,
|
||||
Seed: job.Seed,
|
||||
Style: job.Style,
|
||||
}
|
||||
res, err := be.Generate(ctx, req)
|
||||
if err != nil {
|
||||
return worker.Outcome{Err: fmt.Errorf("generate: %w", err)}
|
||||
}
|
||||
defer res.ImageReader.Close()
|
||||
|
||||
writer := buildWriter(p.cfg, false)
|
||||
in := output.Inputs{
|
||||
Prompt: job.Prompt,
|
||||
Backend: be.Name(),
|
||||
Seed: seedFromMetadata(res.Metadata, job.Seed),
|
||||
Ext: extFromMime(res.MimeType),
|
||||
Metadata: res.Metadata,
|
||||
}
|
||||
paths, err := writer.Write(res.ImageReader, in)
|
||||
if err != nil {
|
||||
return worker.Outcome{Err: fmt.Errorf("write disk: %w", err)}
|
||||
}
|
||||
|
||||
// Worker is queue-driven: cloud-sync is mandatory because flexsiebels
|
||||
// needs imagen.images.id to render the result. Pass cloud_sync=on via
|
||||
// the override path (third arg = ownerUserID); we set the mode by
|
||||
// disallowing the 'off' branch through the cfg later if the user
|
||||
// explicitly turned it off in config.
|
||||
if cloudModeOff(p.cfg) {
|
||||
// We refuse to silently drop a queued job. If cloud sync is off in
|
||||
// config, the worker can't serve flexsiebels at all.
|
||||
return worker.Outcome{Err: fmt.Errorf("output.cloud_sync=off in config; the worker requires cloud_sync=on or auto")}
|
||||
}
|
||||
syncRes, syncErr := maybeCloudSync(ctx, p.cfg, false, job.OwnerUserID, job.SeriesID, paths, in, res, dimOrFallback(job.Width, res, "width"), dimOrFallback(job.Height, res, "height"))
|
||||
if syncErr != nil {
|
||||
return worker.Outcome{Err: fmt.Errorf("cloud sync: %w", syncErr)}
|
||||
}
|
||||
if syncRes == nil || syncRes.ImageID == "" {
|
||||
return worker.Outcome{Err: fmt.Errorf("cloud sync returned no imagen.images id (check SUPABASE_URL + SUPABASE_SERVICE_KEY)")}
|
||||
}
|
||||
return worker.Outcome{ImageID: syncRes.ImageID}
|
||||
}
|
||||
|
||||
func cloudModeOff(cfg *config.Config) bool {
|
||||
if cfg == nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(cfg.Output.CloudSync, "off")
|
||||
}
|
||||
|
||||
// dimOrFallback returns job.<dim> when the job specified one, otherwise the
|
||||
// dimension reported by the backend's metadata. Some backends (Replicate
|
||||
// when given an aspect ratio) round the requested size to their nearest
|
||||
// supported value; this keeps the row honest about what was actually generated.
|
||||
func dimOrFallback(jobDim int, res *backend.Result, key string) int {
|
||||
if jobDim > 0 {
|
||||
return jobDim
|
||||
}
|
||||
return metaInt(res.Metadata, key)
|
||||
}
|
||||
129
cmd/imagen/worker_integration_test.go
Normal file
129
cmd/imagen/worker_integration_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
"mgit.msbls.de/m/ImaGen/internal/worker"
|
||||
)
|
||||
|
||||
// TestWorker_Integration_EndToEnd runs the full pipeline against a real
|
||||
// msupabase instance: insert a row into imagen.jobs, let the worker claim
|
||||
// it, generate via the mock backend (no Replicate spend, no ComfyUI
|
||||
// dependency), write to Supabase Storage + imagen.images, then flip the job
|
||||
// to 'done' with the linked image_id.
|
||||
//
|
||||
// Guarded by IMAGEN_WORKER_INTEGRATION=1. Required env beyond that:
|
||||
//
|
||||
// IMAGEN_WORKER_DATABASE_URL postgres DSN (direct, not PostgREST)
|
||||
// SUPABASE_URL e.g. https://supa.flexsiebels.de
|
||||
// SUPABASE_SERVICE_KEY service-role JWT
|
||||
// IMAGEN_OWNER_USER_ID UUID of an auth.users row (RLS fallback)
|
||||
//
|
||||
// The test creates and later deletes its own job row so repeated runs don't
|
||||
// leave debris.
|
||||
func TestWorker_Integration_EndToEnd(t *testing.T) {
|
||||
if os.Getenv("IMAGEN_WORKER_INTEGRATION") != "1" {
|
||||
t.Skip("set IMAGEN_WORKER_INTEGRATION=1 to run the integration test")
|
||||
}
|
||||
dsn := os.Getenv("IMAGEN_WORKER_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Fatal("IMAGEN_WORKER_DATABASE_URL must be set for the integration test")
|
||||
}
|
||||
if os.Getenv("SUPABASE_URL") == "" || os.Getenv("SUPABASE_SERVICE_KEY") == "" {
|
||||
t.Fatal("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set for the integration test")
|
||||
}
|
||||
owner := os.Getenv("IMAGEN_OWNER_USER_ID")
|
||||
if owner == "" {
|
||||
t.Fatal("IMAGEN_OWNER_USER_ID must be set for the integration test")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
q, err := dialQueue(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("dialQueue: %v", err)
|
||||
}
|
||||
defer q.Close()
|
||||
|
||||
// Insert the test job on a separate connection (the worker's conn is
|
||||
// busy LISTENing). Mock backend = no external dependency.
|
||||
insertConn, err := pgx.Connect(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("insert conn: %v", err)
|
||||
}
|
||||
defer insertConn.Close(ctx)
|
||||
|
||||
var jobID string
|
||||
prompt := fmt.Sprintf("imagen integration test %d", time.Now().UnixNano())
|
||||
err = insertConn.QueryRow(ctx, `
|
||||
INSERT INTO imagen.jobs (owner_user_id, prompt, backend, width, height)
|
||||
VALUES ($1, $2, 'mock', 64, 64)
|
||||
RETURNING id`,
|
||||
owner, prompt).Scan(&jobID)
|
||||
if err != nil {
|
||||
t.Fatalf("insert job: %v", err)
|
||||
}
|
||||
t.Logf("inserted imagen.jobs id=%s", jobID)
|
||||
// Tidy up at the end of the test so a re-run starts clean.
|
||||
defer func() {
|
||||
cleanup, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, _ = insertConn.Exec(cleanup, `DELETE FROM imagen.jobs WHERE id=$1`, jobID)
|
||||
}()
|
||||
|
||||
// Use a per-test temp dir so the generated PNG doesn't litter the repo.
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{Output: config.OutputConfig{Directory: tmpDir}}
|
||||
p := &workerPipeline{cfg: cfg}
|
||||
w := worker.New(q, p, worker.Config{
|
||||
PollInterval: 1 * time.Second,
|
||||
JobTimeout: 30 * time.Second,
|
||||
Logger: func(format string, a ...any) { t.Logf("worker: "+format, a...) },
|
||||
})
|
||||
|
||||
// Run the worker until it processes one job (the one we just inserted)
|
||||
// or the test context times out.
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = w.Run(runCtx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Poll for completion.
|
||||
deadline := time.Now().Add(60 * time.Second)
|
||||
var status, imageID string
|
||||
for time.Now().Before(deadline) {
|
||||
err = insertConn.QueryRow(ctx,
|
||||
`SELECT status, COALESCE(image_id::text,'') FROM imagen.jobs WHERE id=$1`,
|
||||
jobID).Scan(&status, &imageID)
|
||||
if err != nil {
|
||||
t.Fatalf("poll: %v", err)
|
||||
}
|
||||
if status == "done" || status == "failed" {
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
runCancel()
|
||||
<-done
|
||||
|
||||
if status != "done" {
|
||||
var errText string
|
||||
_ = insertConn.QueryRow(ctx,
|
||||
`SELECT COALESCE(error,'') FROM imagen.jobs WHERE id=$1`, jobID).Scan(&errText)
|
||||
t.Fatalf("job not done within timeout: status=%q error=%q", status, errText)
|
||||
}
|
||||
if imageID == "" {
|
||||
t.Fatalf("job done but image_id is empty")
|
||||
}
|
||||
t.Logf("job done: image_id=%s", imageID)
|
||||
}
|
||||
@@ -7,7 +7,7 @@ upstream API. Each adapter only ever sees its own slice of `imagen.yaml`.
|
||||
|
||||
```
|
||||
┌───────────────────────┐
|
||||
│ cmd/imagen │ CLI dispatch
|
||||
│ cmd/imagen │ CLI dispatch (generate / worker / …)
|
||||
│ (or HTTP server) │
|
||||
└──────────┬────────────┘
|
||||
│
|
||||
@@ -15,6 +15,10 @@ upstream API. Each adapter only ever sees its own slice of `imagen.yaml`.
|
||||
│ internal/prompt │ style preset → prompt suffix
|
||||
│ internal/output │ filename templating, sidecar
|
||||
│ internal/config │ YAML loader, validation
|
||||
│ internal/preview │ tmux-img window spawner
|
||||
│ internal/cloud │ Supabase Storage + imagen.images
|
||||
│ internal/usage │ mai.imagen_usage cost-tracking
|
||||
│ internal/worker │ imagen.jobs queue consumer
|
||||
└──────────┬────────────┘
|
||||
│
|
||||
┌──────────▼────────────┐
|
||||
@@ -102,9 +106,37 @@ contains the prompt, backend instance name, seed, ISO timestamp, and the
|
||||
- Network errors during `Generate` — wrap and return; no retry policy yet
|
||||
(decide per-adapter, or move to a shared retry helper if a pattern emerges).
|
||||
|
||||
## Async write path: `imagen worker` + `imagen.jobs`
|
||||
|
||||
`imagen generate` is the synchronous CLI. For web callers (flexsiebels'
|
||||
owner-mode UI), `imagen worker` runs as a daemon that consumes the
|
||||
`imagen.jobs` table.
|
||||
|
||||
```
|
||||
flexsiebels POST imagen worker (mRiver, systemd)
|
||||
→ INSERT INTO LISTEN imagen_jobs ◄── pg_notify trigger
|
||||
imagen.jobs(pending) claim row (UPDATE … RETURNING)
|
||||
dispatch through internal/backend
|
||||
write disk + cloud-sync via internal/cloud
|
||||
UPDATE imagen.jobs SET status='done', image_id=…
|
||||
```
|
||||
|
||||
The queue table lives next to `imagen.images` in the same `imagen` schema.
|
||||
Owner-scoped RLS lets the flexsiebels user INSERT + read their own rows;
|
||||
the worker writes (status updates + image_id link) via service-role which
|
||||
bypasses RLS. A 5-second safety poll fires on every wake-up to cover
|
||||
dropped NOTIFY events and worker cold starts with a non-empty queue. See
|
||||
`docs/setup-worker-mriver.md` for the systemd installation.
|
||||
|
||||
The worker reuses `internal/backend`, `internal/output`, and
|
||||
`internal/cloud` unchanged — it is purely an orchestration layer around
|
||||
the same pipeline `imagen generate` drives.
|
||||
|
||||
## Out of scope (today)
|
||||
|
||||
- Image post-processing (cropping, watermarking).
|
||||
- Cost-tracking (lands with the Replicate adapter, since only API backends bill).
|
||||
- Multi-image `n>1` per request — backends that support it can expose it via
|
||||
`BackendOpts`; the framework doesn't have a first-class field yet.
|
||||
- Job cancellation / kill switch — separate follow-up issue.
|
||||
- Concurrent workers / multi-host scale-out — `FOR UPDATE SKIP LOCKED` in
|
||||
the claim query makes it cheap to add, but a single worker is the v1 setup.
|
||||
|
||||
310
docs/backends.md
Normal file
310
docs/backends.md
Normal file
@@ -0,0 +1,310 @@
|
||||
# ImaGen backends
|
||||
|
||||
This document covers the local-ComfyUI backend plug-in story: how adapters
|
||||
are layered, how to add a new model without touching Go, and the per-model
|
||||
setup steps for the bundled templates.
|
||||
|
||||
For the host-side ComfyUI install (mRock — venv, weights for the default
|
||||
FLUX.1-schnell, systemd, VRAM coexistence with Ollama, smoke test against
|
||||
the raw HTTP API), see [`setup-comfyui-mrock.md`](setup-comfyui-mrock.md).
|
||||
|
||||
## Architecture: Path 1 — workflow-template adapter
|
||||
|
||||
`imagen generate` and `imagen compare` dispatch through the `comfyui`
|
||||
adapter, which holds the HTTP plumbing (`/prompt`, `/history/{id}`, `/view`,
|
||||
`/system_stats`) and treats the workflow itself as data. Each backend
|
||||
instance in `imagen.yaml` picks a workflow JSON via the `workflow:` key.
|
||||
Adding a new model is yaml + JSON, never Go:
|
||||
|
||||
```
|
||||
internal/backend/
|
||||
comfyui.go # one adapter, all ComfyUI models
|
||||
workflow_template.go # loader + token-substitution
|
||||
workflows/
|
||||
flux1-schnell.json # bundled templates (embedded with //go:embed)
|
||||
flux2-klein.json
|
||||
sd35-medium.json
|
||||
```
|
||||
|
||||
### Why Path 1 over per-family adapters (`comfyui-flux.go`, `comfyui-sd3.go`…)
|
||||
|
||||
- **Workflow JSON is the natural exchange format**. ComfyUI users export
|
||||
workflows from its GUI as JSON. Anything else means rebuilding the graph
|
||||
by hand in Go for every new model.
|
||||
- **Adding a model is a config change, not a build change**. With Path 2 (per-family adapters),
|
||||
every new family is a Go file, a new test file, a registry entry, a new
|
||||
worker binary, a redeploy. Path 1 lets us land a new model with one yaml
|
||||
block + one JSON file + one section in this doc.
|
||||
- **The HTTP plumbing is identical across families**. `/prompt`,
|
||||
`/history`, `/view`, the retry policy, the "value not in list" hint, VRAM
|
||||
reporting — none of it depends on the workflow shape. Path 2 would
|
||||
duplicate that across files.
|
||||
- **Failure isolation stays clean**. The workflow loader fails at adapter
|
||||
construction (`imagen backends` surfaces the error), the HTTP layer
|
||||
fails at `Generate`, and ComfyUI's own validation surfaces missing-model
|
||||
hints. Each layer's error message points at the right config knob.
|
||||
|
||||
Path 2's argument was "each family owns its quirks (samplers, schedulers,
|
||||
dual-stage etc.)". That argument doesn't survive contact with the
|
||||
substitution-map design: per-family knobs are just key/value fields in the
|
||||
yaml block and `${shift}`/`${guidance}`/`${cfg}` placeholders in the
|
||||
template. No code duplication, no inheritance to debug.
|
||||
|
||||
### Token substitution
|
||||
|
||||
`workflow_template.SubstituteWorkflow` walks the parsed JSON and replaces
|
||||
every whole-value string of the form `"${key}"` with the typed value from
|
||||
the substitution map. Numbers stay numbers, strings stay strings — no
|
||||
round-tripping through `strings.Replace`.
|
||||
|
||||
The substitution map is built per call from:
|
||||
|
||||
1. **Request fields** (always present): `${prompt}`, `${negative}`,
|
||||
`${width}`, `${height}`, `${seed}`, `${steps}`, `${sampler}`,
|
||||
`${scheduler}`, `${cfg}`.
|
||||
2. **Every scalar field from the yaml block** (string / int / int64 /
|
||||
float64 / bool), minus framework keys (`type`, `base_url`, `workflow`,
|
||||
`default_*`). So `${vae}`, `${clip}`, `${clip_l}`, `${clip_t5}`,
|
||||
`${dtype}`, `${shift}`, `${guidance}` all become substitutable just by
|
||||
being in yaml.
|
||||
3. **Sensible defaults** for the common optional knobs above, so a
|
||||
workflow that references `${dtype}` without the user setting one in
|
||||
yaml still substitutes cleanly (`fp8_e4m3fn` for FLUX, `3.0` for SD3
|
||||
shift, etc.). Extra defaults are ignored by workflows that don't
|
||||
reference them.
|
||||
|
||||
Partial matches (e.g. `"prefix ${prompt} suffix"`) are deliberately **not**
|
||||
substituted — the placeholder must be the entire value so we can preserve
|
||||
its JSON type. This prevents a prompt containing literal `${seed}` text
|
||||
from corrupting the workflow.
|
||||
|
||||
Unknown placeholders (referenced in JSON but missing from the substitution
|
||||
map) error out before the workflow leaves the binary.
|
||||
|
||||
### Back-compat
|
||||
|
||||
The `workflow:` field defaults to `flux1-schnell` if omitted. Existing
|
||||
yaml blocks like the pre-#10 FLUX.1-schnell instance:
|
||||
|
||||
```yaml
|
||||
flux-schnell-local:
|
||||
type: comfyui
|
||||
base_url: http://mrock:8188
|
||||
model: flux1-schnell.safetensors
|
||||
```
|
||||
|
||||
still work unchanged — they implicitly pick up the migrated
|
||||
`flux1-schnell.json` template, which keeps the same node IDs (6, 8, 9, 10,
|
||||
11, 12, 13, 27, 30, 31) as the historical hardcoded workflow.
|
||||
|
||||
## Bundled workflows
|
||||
|
||||
### FLUX.1-schnell — the back-compat default
|
||||
|
||||
| Field | Default | Notes |
|
||||
|---|---|---|
|
||||
| `model` | `flux1-schnell.safetensors` | drop in `models/unet/` |
|
||||
| `vae` | `ae.safetensors` | `models/vae/` |
|
||||
| `clip_l` | `clip_l.safetensors` | `models/clip/` |
|
||||
| `clip_t5` | `t5xxl_fp8_e4m3fn.safetensors` | `models/clip/` |
|
||||
| `dtype` | `fp8_e4m3fn` | weight dtype for the UNet loader |
|
||||
| `default_steps` / `default_cfg` | 4 / 1.0 | schnell is distilled to ~4 steps |
|
||||
|
||||
VRAM peak ~10–12 GB at 1024×1024. Install path:
|
||||
[`setup-comfyui-mrock.md`](setup-comfyui-mrock.md). Already shipping.
|
||||
|
||||
### FLUX.2 [klein] 4B — direct upgrade
|
||||
|
||||
Released by Black Forest Labs in late 2025 / early 2026 under the BFL non-commercial
|
||||
license. The distilled 4B "klein" variant lands sub-second on the RTX
|
||||
4070 Ti SUPER and shares the new Qwen-based text encoder + a re-trained
|
||||
VAE with the larger family.
|
||||
|
||||
```yaml
|
||||
flux2-klein-local:
|
||||
type: comfyui
|
||||
base_url: http://mrock:8188
|
||||
workflow: flux2-klein
|
||||
model: flux-2-klein-base-4b-fp8.safetensors # models/unet/
|
||||
vae: flux2-vae.safetensors # models/vae/
|
||||
clip: qwen_3_4b.safetensors # models/text_encoders/
|
||||
dtype: fp8_e4m3fn
|
||||
default_steps: 4
|
||||
default_cfg: 1.0
|
||||
guidance: 4.0
|
||||
```
|
||||
|
||||
**Model downloads** (on mRock, ungated mirrors when available):
|
||||
|
||||
```bash
|
||||
cd ~/dev/comfyui/models
|
||||
curl -L -o unet/flux-2-klein-base-4b-fp8.safetensors \
|
||||
https://huggingface.co/black-forest-labs/FLUX.2-klein/resolve/main/flux-2-klein-base-4b-fp8.safetensors
|
||||
curl -L -o vae/flux2-vae.safetensors \
|
||||
https://huggingface.co/black-forest-labs/FLUX.2-klein/resolve/main/flux2-vae.safetensors
|
||||
mkdir -p text_encoders
|
||||
curl -L -o text_encoders/qwen_3_4b.safetensors \
|
||||
https://huggingface.co/black-forest-labs/FLUX.2-klein/resolve/main/qwen_3_4b.safetensors
|
||||
```
|
||||
|
||||
BFL's primary repo is gated; if `curl` returns 401, configure an HF token
|
||||
in `~/.cache/huggingface/token` or use one of the community mirrors
|
||||
(check the official model card for the current list). The filenames the
|
||||
template references match BFL's canonical names — rename downloads to
|
||||
match if a mirror uses different ones.
|
||||
|
||||
VRAM peak: ~8.5 GB (4B fp8). With Ollama parked at ~8 GB this still fits;
|
||||
unlike FLUX.1-schnell, klein doesn't require stopping Ollama on mRock.
|
||||
|
||||
### SD3.5-medium — single-checkpoint variant
|
||||
|
||||
Stability AI's 2.5B mid-size model with bundled text encoders. The
|
||||
`incl_clips_t5xxlfp8scaled` variant ships clip_g + clip_l + t5xxl_fp8 all
|
||||
in one `.safetensors`, so the workflow uses `CheckpointLoaderSimple`
|
||||
instead of separate UNet/VAE/CLIP loaders.
|
||||
|
||||
```yaml
|
||||
sd35-medium-local:
|
||||
type: comfyui
|
||||
base_url: http://mrock:8188
|
||||
workflow: sd35-medium
|
||||
model: sd3.5_medium_incl_clips_t5xxlfp8scaled.safetensors # models/checkpoints/
|
||||
default_steps: 28
|
||||
default_sampler: dpmpp_2m
|
||||
default_scheduler: sgm_uniform
|
||||
default_cfg: 4.5
|
||||
shift: 3.0
|
||||
```
|
||||
|
||||
**Model download** (on mRock):
|
||||
|
||||
```bash
|
||||
cd ~/dev/comfyui/models
|
||||
curl -L -o checkpoints/sd3.5_medium_incl_clips_t5xxlfp8scaled.safetensors \
|
||||
https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/resolve/main/sd3.5_medium_incl_clips_t5xxlfp8scaled.safetensors
|
||||
```
|
||||
|
||||
VRAM peak: ~9.9 GB at 1024×1024. Same envelope as FLUX.1-schnell — stop
|
||||
Ollama before generating, restart after.
|
||||
|
||||
## Adding a new bundled workflow
|
||||
|
||||
1. **Export from ComfyUI**: load the model in the ComfyUI GUI, build a
|
||||
text-to-image workflow that produces what you want, "Save (API
|
||||
Format)" — the file you get is the right shape.
|
||||
2. **Sprinkle placeholders**: open the JSON and replace per-call values
|
||||
with `${name}` tokens. Whole-value substitution only:
|
||||
|
||||
```json
|
||||
"inputs": {
|
||||
"text": "${prompt}", // was "a cat sitting on a chair"
|
||||
"seed": "${seed}", // was 1234567
|
||||
"steps": "${steps}", // was 28
|
||||
"cfg": "${cfg}",
|
||||
"sampler_name": "${sampler}",
|
||||
"scheduler": "${scheduler}",
|
||||
"width": "${width}",
|
||||
"height": "${height}"
|
||||
}
|
||||
```
|
||||
|
||||
Use `${model}` for the checkpoint / unet filename and any per-template
|
||||
knobs (`${vae}`, `${shift}`, `${guidance}`, `${clip}` …).
|
||||
3. **Drop it into `internal/backend/workflows/<name>.json`**. The
|
||||
`//go:embed workflows/*.json` directive in `workflow_template.go`
|
||||
picks it up at build time — no registry entry needed.
|
||||
4. **Add a yaml instance** in `internal/config/config.go`'s `Sample` block
|
||||
for `imagen config init` (and `~/.config/imagen.yaml`) so users
|
||||
discover the new backend.
|
||||
5. **Document the model files + HF download URLs** in this doc.
|
||||
6. **Smoke test**: `imagen generate "test" --backend <new-instance>
|
||||
--size 1024x1024` should produce an image.
|
||||
|
||||
Per-call overrides go via `--steps`, `--seed`,
|
||||
and (programmatic) `backend.Request.BackendOpts["sampler"]` /
|
||||
`["scheduler"]` / `["cfg"]`. The compare harness forwards the
|
||||
constant-across-backends knobs verbatim.
|
||||
|
||||
## Loading a workflow from disk (one-off)
|
||||
|
||||
Pass an absolute filesystem path as `workflow:` and the adapter reads it
|
||||
from disk instead of the embedded FS. Handy for prototyping a new model
|
||||
before committing it:
|
||||
|
||||
```yaml
|
||||
my-experimental:
|
||||
type: comfyui
|
||||
base_url: http://mrock:8188
|
||||
workflow: /home/m/dev/comfyui/workflows/my-test.json
|
||||
model: my-test-model.safetensors
|
||||
```
|
||||
|
||||
The fallback chain is: filesystem path (if the string looks like a path
|
||||
or ends in `.json`), then bundled lookup by name, then bundled lookup
|
||||
with `.json` appended.
|
||||
|
||||
## `imagen compare`: cross-backend evaluation
|
||||
|
||||
```bash
|
||||
imagen compare "a wizard casting a spell" \
|
||||
--models flux-schnell-local,flux2-klein-local,sd35-medium-local \
|
||||
--size 1024x1024 \
|
||||
--output ~/Pictures/imagen/compare
|
||||
```
|
||||
|
||||
Per run, `compare`:
|
||||
|
||||
- creates `<output>/<YYYYMMDD-HHMMSS>-<prompt-slug>/`
|
||||
- dispatches each named backend sequentially (mRock has one GPU; parallel
|
||||
would OOM) — one backend's failure doesn't abort the run
|
||||
- writes per-backend PNGs as `<prompt-slug>--<backend-slug>.png`
|
||||
- writes `compare.json` listing every attempt (success + failure) with
|
||||
per-model `seed`, `latency_ms`, `model`, `vram_used_mib`, full
|
||||
`metadata` map, and the error string for any failure
|
||||
- composites a `contact-sheet.png` with the prompt as header and each
|
||||
cell labelled `<backend>` / `<latency>ms · seed <n>`
|
||||
|
||||
Flags mirror `generate`: `--seed`, `--steps`, `--style`, `--negative`,
|
||||
`--size` are shared across all backends. `--no-contact-sheet` skips the
|
||||
composite when only the per-image PNGs and sidecar matter (e.g. for a
|
||||
worker script that builds its own diff view).
|
||||
|
||||
## Diagnostics
|
||||
|
||||
`imagen backends` shows every instance with its registration state. For
|
||||
local ComfyUI, the status is currently just `registered` (we don't probe
|
||||
the upstream HTTP endpoint at startup — the boot-helper hint kicks in on
|
||||
first generation if mRock is asleep).
|
||||
|
||||
Per-backend errors fall into three kinds:
|
||||
|
||||
1. **Adapter construction failure** (e.g. workflow JSON not found,
|
||||
missing required yaml field). Caught at `buildBackend` time:
|
||||
`imagen: backend "<name>": <err>`.
|
||||
2. **HTTP / runtime failure during Generate**. Wrapped with the boot
|
||||
helper for `connection refused`/`no such host`/timeouts pointing at
|
||||
`boot-whitetower mrock` so a sleeping mRock has an obvious next step.
|
||||
3. **ComfyUI workflow-validation failure** (200-with-node_errors or 400).
|
||||
Surfaces with a model-not-found hint (matching `value_not_in_list` +
|
||||
`unet_name`/`ckpt_name`) when applicable, pointing back at this doc.
|
||||
|
||||
## Worker daemon notes
|
||||
|
||||
`imagen worker` (the `imagen.jobs` queue consumer) uses the same adapter
|
||||
+ workflow lookup as the synchronous CLI — flexsiebels' `/imagine` UI
|
||||
INSERTs a `backend = <instance>` row, the worker claims it, and the
|
||||
underlying ComfyUI HTTP calls are identical to what `generate` makes. No
|
||||
worker-specific changes are required when a new backend lands; the
|
||||
config + workflow are the only state that has to be present on the
|
||||
worker host.
|
||||
|
||||
After merging a new template or yaml block:
|
||||
|
||||
```bash
|
||||
# On the worker host (mRiver today):
|
||||
systemctl --user restart imagen-worker
|
||||
```
|
||||
|
||||
The daemon-rebuild trap from issue #9 still applies: if you build the
|
||||
imagen binary on the dev machine and `scp` it over, restart the unit so
|
||||
systemd picks up the new ELF.
|
||||
97
docs/setup-worker-mriver.md
Normal file
97
docs/setup-worker-mriver.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# `imagen worker` on mRiver
|
||||
|
||||
The worker is a long-running daemon that consumes the `imagen.jobs` queue
|
||||
(written by flexsiebels' owner-mode UI) and writes the resulting image to
|
||||
Supabase Storage + `imagen.images` via the same cloud-sync path the CLI
|
||||
`imagen generate` uses.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
flexsiebels (owner UI)
|
||||
|
|
||||
v INSERT INTO imagen.jobs (...)
|
||||
|
|
||||
msupabase Postgres
|
||||
|
|
||||
| AFTER INSERT trigger:
|
||||
| pg_notify('imagen_jobs', NEW.id)
|
||||
v
|
||||
imagen worker (mRiver) ── LISTEN imagen_jobs
|
||||
|
|
||||
| 1. claim oldest 'pending' row (sets status='running')
|
||||
| 2. dispatch to backend (FLUX schnell local / FLUX dev replicate / …)
|
||||
| 3. write PNG to disk
|
||||
| 4. upload to Storage + INSERT into imagen.images
|
||||
| 5. UPDATE imagen.jobs SET status='done', image_id=...
|
||||
v
|
||||
flexsiebels polls GET .../jobs/<id> → shows the rendered card
|
||||
```
|
||||
|
||||
A 5-second safety poll covers dropped NOTIFY events and worker cold starts
|
||||
with a non-empty queue.
|
||||
|
||||
## One-time setup
|
||||
|
||||
```bash
|
||||
# 1. Build the binary (or `task build`).
|
||||
cd ~/dev/ImaGen
|
||||
go build -o bin/imagen ./cmd/imagen
|
||||
|
||||
# 2. Write the environment file.
|
||||
cp scripts/imagen-worker.env.example ~/.dotfiles/.env.imagen-worker
|
||||
chmod 600 ~/.dotfiles/.env.imagen-worker
|
||||
$EDITOR ~/.dotfiles/.env.imagen-worker # fill in real DSN, service key
|
||||
|
||||
# 3. Install the user systemd unit.
|
||||
mkdir -p ~/.config/systemd/user
|
||||
cp scripts/imagen-worker.service ~/.config/systemd/user/imagen-worker.service
|
||||
systemctl --user daemon-reload
|
||||
systemctl --user enable --now imagen-worker.service
|
||||
|
||||
# 4. Tail the logs.
|
||||
journalctl --user -u imagen-worker -f
|
||||
```
|
||||
|
||||
## Required env vars
|
||||
|
||||
See `scripts/imagen-worker.env.example` for the canonical list. Required:
|
||||
|
||||
- `IMAGEN_WORKER_DATABASE_URL` — direct Postgres DSN. PostgREST cannot LISTEN.
|
||||
- `SUPABASE_URL`, `SUPABASE_SERVICE_KEY` — same pair `imagen generate`
|
||||
reads for the cloud-sync writer.
|
||||
- `IMAGEN_OWNER_USER_ID` — fallback owner UUID; per-job row's
|
||||
`owner_user_id` overrides this.
|
||||
|
||||
Optional, depending on enabled backends:
|
||||
|
||||
- `REPLICATE_API_TOKEN` if any job will request a Replicate-typed backend.
|
||||
|
||||
## Operating
|
||||
|
||||
```bash
|
||||
systemctl --user status imagen-worker # health
|
||||
systemctl --user restart imagen-worker # pick up a new binary
|
||||
journalctl --user -u imagen-worker -n 200 # recent log lines
|
||||
```
|
||||
|
||||
On startup the worker calls `ResetStaleRunning` once, flipping any rows
|
||||
left in `'running'` from a previous crash back to `'pending'` so they get
|
||||
re-claimed by the 5-second poll.
|
||||
|
||||
## Smoke test
|
||||
|
||||
With the worker running, INSERT a test job:
|
||||
|
||||
```sql
|
||||
INSERT INTO imagen.jobs (owner_user_id, prompt, backend, width, height)
|
||||
VALUES (
|
||||
'ac6c9501-3757-4a6d-8b97-2cff4288382b',
|
||||
'a tiny owl wearing wire-rim glasses, photo',
|
||||
'flux-schnell-local', 1024, 1024
|
||||
);
|
||||
```
|
||||
|
||||
Within ~10 seconds the row should show `status='done'`, a populated
|
||||
`image_id` linking to a real `imagen.images` row, and a Storage object at
|
||||
`<YYYY-MM-DD>/<slug>-<seed>.png` in the `imagen-generated` bucket.
|
||||
@@ -4,14 +4,21 @@
|
||||
|
||||
```
|
||||
imagen generate <prompt> [flags] generate one image
|
||||
imagen compare <prompt> --models a,b,c [flags]
|
||||
run one prompt across N backends + contact sheet
|
||||
imagen worker [flags] consume the imagen.jobs queue (daemon)
|
||||
imagen backends list configured + registered backends
|
||||
imagen config init print a sample imagen.yaml on stdout
|
||||
imagen config validate parse + validate the active config
|
||||
imagen config path print the resolved config path
|
||||
imagen serve [--addr :8080] (stub) start the HTTP server
|
||||
imagen usage [--since DATE] show cost-tracking rows
|
||||
imagen version print version
|
||||
```
|
||||
|
||||
For the per-backend setup (FLUX.1, FLUX.2 [klein], SD3.5 medium, …) and
|
||||
the architecture rationale, see [`backends.md`](backends.md).
|
||||
|
||||
## `generate` flags
|
||||
|
||||
| Flag | Default | Notes |
|
||||
@@ -24,8 +31,29 @@ imagen version print version
|
||||
| `--negative` | empty | Negative prompt (ignored by some adapters) |
|
||||
| `--output` | empty (= use naming template) | Explicit path |
|
||||
| `--no-sidecar` | `false` | Skip the JSON sidecar even if config enables it |
|
||||
| `--preview` | (auto) | Force open a tmux preview window via `tmux-img` |
|
||||
| `--no-preview` | (auto) | Suppress the preview window (use for batch / CI callers) |
|
||||
| `--no-cloud` | `false` | Skip Supabase upload + `imagen.images` insert for this call |
|
||||
| `--config` | `~/.config/imagen.yaml` | Override config path |
|
||||
|
||||
### Preview window
|
||||
|
||||
After a successful generate, imagen optionally opens a sibling tmux window
|
||||
named `img:<slug>` running `tmux-img --hold <path>`. The new window is
|
||||
spawned in the background (`tmux new-window -d`) so the generating pane
|
||||
keeps focus and its terminal output.
|
||||
|
||||
Resolution order is **config → `$IMAGEN_PREVIEW` → flag** (later wins):
|
||||
|
||||
- `output.preview` in `imagen.yaml`: `auto` (default) | `on` | `off`
|
||||
- `IMAGEN_PREVIEW=auto|on|off` overrides config
|
||||
- `--preview` / `--no-preview` override env
|
||||
|
||||
`auto` previews iff stdout is a TTY *and* `$TMUX` is set. `on` previews
|
||||
unconditionally and errors outside a tmux session. `off` never previews.
|
||||
|
||||
Preview failures are non-fatal — the image has already been written.
|
||||
|
||||
## Examples
|
||||
|
||||
```sh
|
||||
@@ -71,3 +99,53 @@ API-backed adapters read tokens from env vars referenced by the config
|
||||
export REPLICATE_API_TOKEN=...
|
||||
imagen generate "a cat" --backend flux-dev-replicate
|
||||
```
|
||||
|
||||
## Cost-tracking (Replicate)
|
||||
|
||||
Successful generations through the Replicate adapter write one row to
|
||||
`mai.imagen_usage` on Supabase: backend, model, latency, per-image cost
|
||||
estimate, prompt sha256 hash (never the prompt itself), and the caller
|
||||
identity (resolved from `MAI_FROM_ID` or the tmux pane's `@mai-name`).
|
||||
|
||||
The writer is best-effort. If `SUPABASE_URL` / `SUPABASE_SERVICE_KEY` are
|
||||
unset, or the database write fails, the image still lands and the CLI
|
||||
prints a warning to stderr.
|
||||
|
||||
Inspect spend:
|
||||
|
||||
```sh
|
||||
imagen usage # all rows, grouped by week + backend + model + caller
|
||||
imagen usage --since 2026-05-01 # only rows on/after a UTC date
|
||||
imagen usage --since 2026-05-01 --raw
|
||||
```
|
||||
|
||||
Per-model rates live in `internal/backend/replicate_pricing.go` — they
|
||||
are snapshotted from <https://replicate.com/pricing> and refreshed on a
|
||||
quarterly cadence.
|
||||
|
||||
## Cloud-sync (Supabase)
|
||||
|
||||
Successful generations also upload the PNG to the private Supabase
|
||||
Storage bucket `imagen-generated` (path: `<YYYY-MM-DD>/<slug>-<seed>.png`)
|
||||
and insert a row into `imagen.images`. The row carries the prompt,
|
||||
sha256-hashed prompt, backend, model, seed/steps/width/height, latency,
|
||||
cost estimate, the full local sidecar JSON, and an empty `tags` array
|
||||
ready for the flexsiebels viewer to fill in.
|
||||
|
||||
Configuration:
|
||||
|
||||
- `owner_user_id` in `imagen.yaml` — m's `auth.users.id`. Empty disables
|
||||
inserts (the column is `NOT NULL`).
|
||||
- `output.cloud_sync` in `imagen.yaml`: `auto` (default — on iff
|
||||
SUPABASE creds + `owner_user_id` are set), `on` (errors if either is
|
||||
missing), `off`.
|
||||
- `IMAGEN_CLOUD_SYNC=auto|on|off` overrides config.
|
||||
- `--no-cloud` overrides everything for one call.
|
||||
|
||||
Reuses the same Supabase env (`SUPABASE_URL` + `SUPABASE_SERVICE_KEY` or
|
||||
`MAI_SUPABASE_KEY`) as cost-tracking. Service-role bypasses RLS for
|
||||
inserts; the `owner_user_id = auth.uid()` policy on the table gates the
|
||||
read path the flexsiebels viewer hits.
|
||||
|
||||
Failures (Storage 5xx, DB unreachable) emit `imagen: cloud sync: <err>`
|
||||
to stderr and the local PNG + sidecar stay put. Exit code is unchanged.
|
||||
|
||||
16
go.mod
16
go.mod
@@ -1,5 +1,17 @@
|
||||
module mgit.msbls.de/m/ImaGen
|
||||
|
||||
go 1.24
|
||||
go 1.25.0
|
||||
|
||||
require gopkg.in/yaml.v3 v3.0.1
|
||||
require (
|
||||
github.com/jackc/pgx/v5 v5.9.2
|
||||
golang.org/x/image v0.40.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
)
|
||||
|
||||
35
go.sum
35
go.sum
@@ -1,4 +1,37 @@
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
|
||||
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/image v0.40.0 h1:Tw4GyDXMo+daZN1znreBRC3VayR1aLFUyUEOLUdW1a8=
|
||||
golang.org/x/image v0.40.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -20,24 +20,29 @@ import (
|
||||
const ComfyType = "comfyui"
|
||||
|
||||
// Comfy is the ComfyUI adapter. It speaks the public `/prompt` + `/history`
|
||||
// + `/view` HTTP API and submits a fixed FLUX.1 schnell workflow built from
|
||||
// the values in Request.
|
||||
// + `/view` HTTP API and submits a workflow built by substituting Request
|
||||
// values into a JSON template (bundled under internal/backend/workflows/ or
|
||||
// loaded from a filesystem path).
|
||||
//
|
||||
// Concurrency: a single Comfy is safe to share across goroutines as long as
|
||||
// the underlying http.Client is. Generate does not hold long-lived state.
|
||||
type Comfy struct {
|
||||
instance string
|
||||
|
||||
base string
|
||||
model string
|
||||
vae string
|
||||
clipL string
|
||||
clipT5 string
|
||||
dtype string
|
||||
base string
|
||||
workflow string
|
||||
|
||||
// rawCfg keeps the original yaml block (minus framework keys) so we can
|
||||
// expose every user-defined string/number as a workflow substitution
|
||||
// without enumerating each per-model knob in Go. Empty values still get
|
||||
// a substitution entry so a template can reference ${negative} when the
|
||||
// request didn't pass one.
|
||||
rawCfg map[string]any
|
||||
|
||||
defaultSteps int
|
||||
defaultSampler string
|
||||
defaultScheduler string
|
||||
defaultCFG float64
|
||||
|
||||
httpClient *http.Client
|
||||
pollInterval time.Duration
|
||||
@@ -49,12 +54,20 @@ type Comfy struct {
|
||||
}
|
||||
|
||||
// NewComfy is the registry constructor. cfg is the adapter's slice of
|
||||
// imagen.yaml. Required keys: base_url, model. The rest have sensible FLUX
|
||||
// schnell defaults.
|
||||
// imagen.yaml.
|
||||
//
|
||||
// Required keys: base_url, model.
|
||||
// Optional keys: workflow (defaults to "flux1-schnell" for back-compat with
|
||||
// existing configs), default_steps, default_sampler, default_scheduler,
|
||||
// default_cfg, plus any template-specific knobs (vae, clip, clip_l,
|
||||
// clip_t5, dtype, shift, guidance, …) the chosen workflow references.
|
||||
func NewComfy(name string, cfg map[string]any) (Backend, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("comfyui: empty instance name")
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = map[string]any{}
|
||||
}
|
||||
base := strings.TrimRight(getString(cfg, "base_url", ""), "/")
|
||||
if base == "" {
|
||||
return nil, fmt.Errorf("comfyui[%s]: base_url is required", name)
|
||||
@@ -67,23 +80,27 @@ func NewComfy(name string, cfg map[string]any) (Backend, error) {
|
||||
return nil, fmt.Errorf("comfyui[%s]: model is required", name)
|
||||
}
|
||||
|
||||
workflow := getString(cfg, "workflow", "flux1-schnell")
|
||||
// Fail fast on a bad workflow ref so users see the error at startup,
|
||||
// not on first /prompt submission.
|
||||
if _, err := LoadWorkflowTemplate(workflow); err != nil {
|
||||
return nil, fmt.Errorf("comfyui[%s]: %w", name, err)
|
||||
}
|
||||
|
||||
c := &Comfy{
|
||||
instance: name,
|
||||
base: base,
|
||||
model: model,
|
||||
|
||||
vae: getString(cfg, "vae", "ae.safetensors"),
|
||||
clipL: getString(cfg, "clip_l", "clip_l.safetensors"),
|
||||
clipT5: getString(cfg, "clip_t5", "t5xxl_fp8_e4m3fn.safetensors"),
|
||||
dtype: getString(cfg, "weight_dtype", "fp8_e4m3fn"),
|
||||
workflow: workflow,
|
||||
rawCfg: cfg,
|
||||
|
||||
defaultSteps: getInt(cfg, "default_steps", 4),
|
||||
defaultSampler: getString(cfg, "default_sampler", "euler"),
|
||||
defaultScheduler: getString(cfg, "default_scheduler", "simple"),
|
||||
defaultCFG: getFloat(cfg, "default_cfg", 1.0),
|
||||
|
||||
httpClient: &http.Client{Timeout: 60 * time.Second},
|
||||
pollInterval: 250 * time.Millisecond,
|
||||
pollTimeout: 120 * time.Second,
|
||||
pollTimeout: 300 * time.Second,
|
||||
|
||||
randSeed: cryptoSeed,
|
||||
clientIDFn: randClientID,
|
||||
@@ -103,19 +120,26 @@ func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) {
|
||||
|
||||
sampler := c.defaultSampler
|
||||
scheduler := c.defaultScheduler
|
||||
cfg := c.defaultCFG
|
||||
if v, ok := req.BackendOpts["sampler"].(string); ok && v != "" {
|
||||
sampler = v
|
||||
}
|
||||
if v, ok := req.BackendOpts["scheduler"].(string); ok && v != "" {
|
||||
scheduler = v
|
||||
}
|
||||
if v, ok := req.BackendOpts["cfg"].(float64); ok && v > 0 {
|
||||
cfg = v
|
||||
}
|
||||
|
||||
seed := req.Seed
|
||||
if seed == 0 {
|
||||
seed = c.randSeed()
|
||||
}
|
||||
|
||||
workflow := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler)
|
||||
workflow, err := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("comfyui[%s]: build workflow: %w", c.instance, err)
|
||||
}
|
||||
clientID := c.clientIDFn()
|
||||
|
||||
start := time.Now()
|
||||
@@ -133,14 +157,17 @@ func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) {
|
||||
}
|
||||
latencyMs := time.Since(start).Milliseconds()
|
||||
|
||||
model := getString(c.rawCfg, "model", "")
|
||||
meta := map[string]any{
|
||||
"backend": c.instance,
|
||||
"backend_type": ComfyType,
|
||||
"model": c.model,
|
||||
"workflow": c.workflow,
|
||||
"model": model,
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"sampler": sampler,
|
||||
"scheduler": scheduler,
|
||||
"cfg": cfg,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"latency_ms": latencyMs,
|
||||
@@ -173,6 +200,7 @@ func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clien
|
||||
return "", fmt.Errorf("comfyui: marshal workflow: %w", err)
|
||||
}
|
||||
|
||||
model := getString(c.rawCfg, "model", "")
|
||||
var lastErr error
|
||||
for attempt := range 2 {
|
||||
if attempt > 0 {
|
||||
@@ -196,7 +224,7 @@ func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clien
|
||||
_ = resp.Body.Close()
|
||||
switch {
|
||||
case resp.StatusCode >= 200 && resp.StatusCode < 300:
|
||||
return parsePromptID(respBody, c.model)
|
||||
return parsePromptID(respBody, model)
|
||||
case resp.StatusCode >= 500:
|
||||
lastErr = fmt.Errorf("comfyui /prompt %d: %s", resp.StatusCode, snip(respBody))
|
||||
continue
|
||||
@@ -333,98 +361,74 @@ func (c *Comfy) connError(err error) error {
|
||||
// workflow-validation failures and put the diagnostics in node_errors; older
|
||||
// builds use 200 + node_errors. This handles the 4xx flavour.
|
||||
func (c *Comfy) classifyBadRequest(status int, body []byte) error {
|
||||
if hint, ok := missingModelHint(body, c.model); ok {
|
||||
return fmt.Errorf("comfyui /prompt %d: %s — see docs/setup-comfyui-mrock.md", status, hint)
|
||||
model := getString(c.rawCfg, "model", "")
|
||||
if hint, ok := missingModelHint(body, model); ok {
|
||||
return fmt.Errorf("comfyui /prompt %d: %s — see docs/backends.md", status, hint)
|
||||
}
|
||||
return fmt.Errorf("comfyui /prompt %d: %s", status, snip(body))
|
||||
}
|
||||
|
||||
// buildWorkflow assembles the canonical FLUX.1 schnell ComfyUI workflow,
|
||||
// node-IDs matching the upstream "flux-schnell" template so anyone debugging
|
||||
// in the ComfyUI UI sees a familiar shape.
|
||||
func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string) map[string]any {
|
||||
return map[string]any{
|
||||
"6": map[string]any{
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": map[string]any{
|
||||
"text": prompt,
|
||||
"clip": []any{"11", 0},
|
||||
},
|
||||
},
|
||||
"8": map[string]any{
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": map[string]any{
|
||||
"samples": []any{"31", 0},
|
||||
"vae": []any{"10", 0},
|
||||
},
|
||||
},
|
||||
"9": map[string]any{
|
||||
"class_type": "SaveImage",
|
||||
"inputs": map[string]any{
|
||||
"filename_prefix": "imagen",
|
||||
"images": []any{"8", 0},
|
||||
},
|
||||
},
|
||||
"10": map[string]any{
|
||||
"class_type": "VAELoader",
|
||||
"inputs": map[string]any{"vae_name": c.vae},
|
||||
},
|
||||
"11": map[string]any{
|
||||
"class_type": "DualCLIPLoader",
|
||||
"inputs": map[string]any{
|
||||
"clip_name1": c.clipT5,
|
||||
"clip_name2": c.clipL,
|
||||
"type": "flux",
|
||||
},
|
||||
},
|
||||
"12": map[string]any{
|
||||
"class_type": "UNETLoader",
|
||||
"inputs": map[string]any{
|
||||
"unet_name": c.model,
|
||||
"weight_dtype": c.dtype,
|
||||
},
|
||||
},
|
||||
"13": map[string]any{
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": map[string]any{
|
||||
"text": negative,
|
||||
"clip": []any{"11", 0},
|
||||
},
|
||||
},
|
||||
"27": map[string]any{
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"inputs": map[string]any{
|
||||
"width": w,
|
||||
"height": h,
|
||||
"batch_size": 1,
|
||||
},
|
||||
},
|
||||
"30": map[string]any{
|
||||
"class_type": "ModelSamplingFlux",
|
||||
"inputs": map[string]any{
|
||||
"model": []any{"12", 0},
|
||||
"max_shift": 1.15,
|
||||
"base_shift": 0.5,
|
||||
"width": w,
|
||||
"height": h,
|
||||
},
|
||||
},
|
||||
"31": map[string]any{
|
||||
"class_type": "KSampler",
|
||||
"inputs": map[string]any{
|
||||
"model": []any{"30", 0},
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"cfg": 1.0,
|
||||
"sampler_name": sampler,
|
||||
"scheduler": scheduler,
|
||||
"denoise": 1.0,
|
||||
"positive": []any{"6", 0},
|
||||
"negative": []any{"13", 0},
|
||||
"latent_image": []any{"27", 0},
|
||||
},
|
||||
},
|
||||
// buildWorkflow loads the configured workflow template and substitutes the
|
||||
// per-call placeholders (prompt, seed, sampler, …) plus any string/number
|
||||
// fields the user defined in the yaml block. The set of placeholder keys
|
||||
// that aren't in `subs` produces an error from SubstituteWorkflow.
|
||||
func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string, cfg float64) (map[string]any, error) {
|
||||
wf, err := LoadWorkflowTemplate(c.workflow)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subs := map[string]any{
|
||||
"prompt": prompt,
|
||||
"negative": negative,
|
||||
"width": w,
|
||||
"height": h,
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"sampler": sampler,
|
||||
"scheduler": scheduler,
|
||||
"cfg": cfg,
|
||||
}
|
||||
// Surface every scalar field from the yaml block so per-template knobs
|
||||
// (vae, clip, clip_l, clip_t5, dtype, shift, guidance, …) work without
|
||||
// adapter-code changes. Framework keys are excluded.
|
||||
for k, v := range c.rawCfg {
|
||||
switch k {
|
||||
case "type", "base_url", "workflow",
|
||||
"default_steps", "default_sampler", "default_scheduler", "default_cfg":
|
||||
continue
|
||||
}
|
||||
if _, alreadySet := subs[k]; alreadySet {
|
||||
// A per-call var (e.g. ${prompt}) beats anything yaml put under
|
||||
// the same key — yaml can't shadow request-derived values.
|
||||
continue
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case string, int, int64, float64, bool:
|
||||
subs[k] = v
|
||||
}
|
||||
}
|
||||
// Provide sensible defaults for common optional knobs so a workflow that
|
||||
// references one of these doesn't fail substitution when the user
|
||||
// didn't override it in yaml. Extra keys are ignored if the workflow
|
||||
// doesn't reference them, so it's safe to always set the lot.
|
||||
defaults := map[string]any{
|
||||
"vae": "ae.safetensors",
|
||||
"clip_l": "clip_l.safetensors",
|
||||
"clip_t5": "t5xxl_fp8_e4m3fn.safetensors",
|
||||
"clip": "qwen_3_4b.safetensors",
|
||||
"dtype": "fp8_e4m3fn",
|
||||
"guidance": 4.0,
|
||||
"shift": 3.0,
|
||||
}
|
||||
for k, v := range defaults {
|
||||
if _, ok := subs[k]; !ok {
|
||||
subs[k] = v
|
||||
}
|
||||
}
|
||||
if _, err := SubstituteWorkflow(wf, subs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wf, nil
|
||||
}
|
||||
|
||||
// parsePromptID handles the 2xx /prompt response. ComfyUI sometimes 200s a
|
||||
@@ -432,8 +436,8 @@ func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, ste
|
||||
// turns that into the same user-facing error as a 4xx with the same body.
|
||||
func parsePromptID(body []byte, model string) (string, error) {
|
||||
var resp struct {
|
||||
PromptID string `json:"prompt_id"`
|
||||
NodeErrors map[string]any `json:"node_errors"`
|
||||
PromptID string `json:"prompt_id"`
|
||||
NodeErrors map[string]any `json:"node_errors"`
|
||||
Error json.RawMessage `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
@@ -441,7 +445,7 @@ func parsePromptID(body []byte, model string) (string, error) {
|
||||
}
|
||||
if len(resp.NodeErrors) > 0 || len(resp.Error) > 0 {
|
||||
if hint, ok := missingModelHint(body, model); ok {
|
||||
return "", fmt.Errorf("comfyui /prompt: %s — see docs/setup-comfyui-mrock.md", hint)
|
||||
return "", fmt.Errorf("comfyui /prompt: %s — see docs/backends.md", hint)
|
||||
}
|
||||
return "", fmt.Errorf("comfyui /prompt rejected workflow: %s", snip(body))
|
||||
}
|
||||
@@ -489,15 +493,21 @@ func parseHistory(body []byte, promptID string) (string, bool, error) {
|
||||
}
|
||||
|
||||
// missingModelHint returns a user-actionable message when the response body
|
||||
// indicates the configured unet model isn't loaded on the server. ComfyUI
|
||||
// uses both the human-readable "Value not in list" message and the enum
|
||||
// "value_not_in_list" type — match either.
|
||||
// indicates the configured unet/checkpoint model isn't loaded on the server.
|
||||
// ComfyUI uses both the human-readable "Value not in list" message and the
|
||||
// enum "value_not_in_list" type — match either.
|
||||
func missingModelHint(body []byte, model string) (string, bool) {
|
||||
s := string(body)
|
||||
hasMarker := strings.Contains(s, "Value not in list") || strings.Contains(s, "value_not_in_list")
|
||||
if hasMarker && strings.Contains(s, "unet_name") {
|
||||
if !hasMarker {
|
||||
return "", false
|
||||
}
|
||||
if strings.Contains(s, "unet_name") {
|
||||
return fmt.Sprintf("model %q not present in the ComfyUI server's models/unet/", model), true
|
||||
}
|
||||
if strings.Contains(s, "ckpt_name") {
|
||||
return fmt.Sprintf("checkpoint %q not present in the ComfyUI server's models/checkpoints/", model), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -536,6 +546,22 @@ func getInt(m map[string]any, k string, def int) int {
|
||||
return def
|
||||
}
|
||||
|
||||
// getFloat looks up key k in m and returns it as a float64. Values stored as
// float64, float32, int or int64 are converted; a missing key or a value of
// any other type yields def.
func getFloat(m map[string]any, k string, def float64) float64 {
	raw, present := m[k]
	if !present {
		return def
	}
	switch num := raw.(type) {
	case int:
		return float64(num)
	case int64:
		return float64(num)
	case float32:
		return float64(num)
	case float64:
		return num
	default:
		return def
	}
}
|
||||
|
||||
func orDefaultInt(v, def int) int {
|
||||
if v == 0 {
|
||||
return def
|
||||
|
||||
@@ -312,7 +312,7 @@ func TestComfyMissingModelHintsAtSetupDoc(t *testing.T) {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "docs/setup-comfyui-mrock.md") {
|
||||
if !strings.Contains(msg, "docs/backends.md") {
|
||||
t.Errorf("error should point at the setup doc, got %v", err)
|
||||
}
|
||||
if !strings.Contains(msg, "flux1-schnell.safetensors") {
|
||||
@@ -331,7 +331,7 @@ func TestComfyMissingModelOn200WithNodeErrors(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("expected error for node_errors on 200")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "docs/setup-comfyui-mrock.md") {
|
||||
if !strings.Contains(err.Error(), "docs/backends.md") {
|
||||
t.Errorf("error should point at the setup doc, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
567
internal/backend/replicate.go
Normal file
567
internal/backend/replicate.go
Normal file
@@ -0,0 +1,567 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ReplicateType is the backend type name under which the Replicate adapter
// is registered.
const ReplicateType = "replicate"
|
||||
|
||||
// Replicate is the Replicate API adapter. It speaks the public REST API
|
||||
// — POST /v1/predictions or POST /v1/models/{owner}/{name}/predictions
|
||||
// to submit, then polls /v1/predictions/{id} until the prediction
|
||||
// settles, then downloads the produced image.
|
||||
//
|
||||
// Concurrency: a single Replicate is safe to share across goroutines.
|
||||
type Replicate struct {
|
||||
instance string
|
||||
|
||||
apiBase string
|
||||
apiToken string
|
||||
tokenEnv string
|
||||
model string // "owner/name" or "owner/name:version-hash"
|
||||
owner string
|
||||
name string
|
||||
version string // optional; empty means "use the model-based predictions endpoint"
|
||||
defaultSteps int
|
||||
defaultAspect string
|
||||
|
||||
httpClient *http.Client
|
||||
pollInterval time.Duration
|
||||
pollTimeout time.Duration
|
||||
|
||||
// Hooks for tests; production paths use the package-level defaults.
|
||||
randSeed func() int64
|
||||
initialBackoff time.Duration
|
||||
|
||||
// Sink is where successful generations are recorded for cost-tracking.
|
||||
// nil means "do not record". The framework wires this up in the CLI;
|
||||
// adapter tests inject a fake.
|
||||
Sink UsageSink
|
||||
}
|
||||
|
||||
// UsageSink writes one row per successful generation. Implementations
|
||||
// should treat write failures as warnings, not errors — the image has
|
||||
// already landed on disk; failing the call would lose the artefact.
|
||||
type UsageSink interface {
|
||||
Record(ctx context.Context, row UsageRow) error
|
||||
}
|
||||
|
||||
// UsageRow is one cost-tracking record destined for mai.imagen_usage.
//
// The prompt text is deliberately absent; only its sha256 digest is kept in
// PromptHash. Seed and CostUSDEstimate are pointers so "not known" can be
// told apart from a zero value.
type UsageRow struct {
	Backend         string
	Model           string
	Seed            *int64
	PromptHash      string
	LatencyMs       int
	CostUSDEstimate *float64
	Caller          string
}
|
||||
|
||||
// NewReplicate is the registry constructor. cfg is the adapter's slice
|
||||
// of imagen.yaml.
|
||||
func NewReplicate(name string, cfg map[string]any) (Backend, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("replicate: empty instance name")
|
||||
}
|
||||
tokenEnv := getString(cfg, "api_token_env", "REPLICATE_API_TOKEN")
|
||||
model := getString(cfg, "model", "")
|
||||
if model == "" {
|
||||
return nil, fmt.Errorf("replicate[%s]: model is required (e.g. black-forest-labs/flux-schnell)", name)
|
||||
}
|
||||
owner, modelName, version, err := parseModelRef(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("replicate[%s]: %w", name, err)
|
||||
}
|
||||
apiBase := strings.TrimRight(getString(cfg, "api_base", "https://api.replicate.com"), "/")
|
||||
pollTimeout := timeoutForModel(modelName)
|
||||
r := &Replicate{
|
||||
instance: name,
|
||||
apiBase: apiBase,
|
||||
tokenEnv: tokenEnv,
|
||||
apiToken: os.Getenv(tokenEnv),
|
||||
model: model,
|
||||
owner: owner,
|
||||
name: modelName,
|
||||
version: version,
|
||||
defaultSteps: getInt(cfg, "default_steps", 0),
|
||||
defaultAspect: getString(cfg, "default_aspect_ratio", "1:1"),
|
||||
httpClient: &http.Client{Timeout: 30 * time.Second},
|
||||
pollInterval: 500 * time.Millisecond,
|
||||
pollTimeout: pollTimeout,
|
||||
randSeed: cryptoSeed,
|
||||
initialBackoff: time.Second,
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Name returns the instance name from imagen.yaml.
|
||||
func (r *Replicate) Name() string { return r.instance }
|
||||
|
||||
// Generate submits one prediction and returns the resulting PNG.
|
||||
func (r *Replicate) Generate(ctx context.Context, req Request) (*Result, error) {
|
||||
if r.apiToken == "" {
|
||||
return nil, fmt.Errorf("replicate[%s]: API token missing — export %s with a Replicate API token", r.instance, r.tokenEnv)
|
||||
}
|
||||
|
||||
width := orDefaultInt(req.Width, 1024)
|
||||
height := orDefaultInt(req.Height, 1024)
|
||||
aspect := computeAspectRatio(width, height, r.defaultAspect)
|
||||
|
||||
steps := orDefaultInt(req.Steps, r.defaultSteps)
|
||||
|
||||
seed := req.Seed
|
||||
if seed == 0 {
|
||||
seed = r.randSeed()
|
||||
}
|
||||
|
||||
input := map[string]any{
|
||||
"prompt": req.Prompt,
|
||||
"aspect_ratio": aspect,
|
||||
"num_outputs": 1,
|
||||
"output_format": "png",
|
||||
"seed": seed,
|
||||
}
|
||||
if steps > 0 {
|
||||
input["num_inference_steps"] = steps
|
||||
}
|
||||
if req.NegativePrompt != "" {
|
||||
input["negative_prompt"] = req.NegativePrompt
|
||||
}
|
||||
maps.Copy(input, req.BackendOpts)
|
||||
|
||||
start := time.Now()
|
||||
pred, err := r.submitPrediction(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final, err := r.waitForCompletion(ctx, pred.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
|
||||
}
|
||||
imgURL, err := pickFirstOutputURL(final.Output)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
|
||||
}
|
||||
imgBytes, mime, err := r.fetchImage(ctx, imgURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
latencyMs := int(time.Since(start).Milliseconds())
|
||||
|
||||
predictTime := final.Metrics.PredictTime
|
||||
costEst, costKnown := replicatePerImageUSD(r.model)
|
||||
|
||||
meta := map[string]any{
|
||||
"backend": r.instance,
|
||||
"backend_type": ReplicateType,
|
||||
"model": r.model,
|
||||
"model_version": final.Version,
|
||||
"prediction_id": pred.ID,
|
||||
"seed": seed,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"aspect_ratio": aspect,
|
||||
"latency_ms": int64(latencyMs),
|
||||
}
|
||||
if predictTime > 0 {
|
||||
meta["predict_time_seconds"] = predictTime
|
||||
}
|
||||
if costKnown {
|
||||
meta["cost_usd_estimate"] = costEst
|
||||
}
|
||||
|
||||
if r.Sink != nil {
|
||||
row := UsageRow{
|
||||
Backend: r.instance,
|
||||
Model: r.model,
|
||||
PromptHash: hashPrompt(req.Prompt),
|
||||
LatencyMs: latencyMs,
|
||||
Caller: ResolveCaller(),
|
||||
}
|
||||
row.Seed = new(int64)
|
||||
*row.Seed = seed
|
||||
if costKnown {
|
||||
c := costEst
|
||||
row.CostUSDEstimate = &c
|
||||
}
|
||||
if err := r.Sink.Record(ctx, row); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "imagen: cost-tracking write failed (continuing): %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &Result{
|
||||
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
|
||||
MimeType: mime,
|
||||
Metadata: meta,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// replicatePrediction mirrors the JSON document returned by
// /v1/predictions and /v1/models/{owner}/{name}/predictions. It declares
// only the fields this adapter reads. Error and Output stay raw so their
// shape can be inspected later.
type replicatePrediction struct {
	ID      string          `json:"id"`
	Status  string          `json:"status"`
	Version string          `json:"version"`
	Error   json.RawMessage `json:"error"`
	Output  json.RawMessage `json:"output"`
	Metrics struct {
		PredictTime float64 `json:"predict_time"`
	} `json:"metrics"`
}
|
||||
|
||||
// submitPrediction creates a prediction and returns it. Picks the
|
||||
// model-based endpoint when no version was given (recommended for
|
||||
// Replicate's official models), and the legacy /v1/predictions otherwise.
|
||||
func (r *Replicate) submitPrediction(ctx context.Context, input map[string]any) (*replicatePrediction, error) {
|
||||
var (
|
||||
reqURL string
|
||||
body []byte
|
||||
err error
|
||||
)
|
||||
if r.version == "" {
|
||||
reqURL = fmt.Sprintf("%s/v1/models/%s/%s/predictions", r.apiBase, r.owner, r.name)
|
||||
body, err = json.Marshal(map[string]any{"input": input})
|
||||
} else {
|
||||
reqURL = r.apiBase + "/v1/predictions"
|
||||
body, err = json.Marshal(map[string]any{
|
||||
"version": r.version,
|
||||
"input": input,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("replicate: marshal prediction body: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := r.doWithRetry(ctx, http.MethodPost, reqURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var pred replicatePrediction
|
||||
if err := json.Unmarshal(respBody, &pred); err != nil {
|
||||
return nil, fmt.Errorf("replicate: parse predictions response: %w (body: %s)", err, snip(respBody))
|
||||
}
|
||||
if pred.ID == "" {
|
||||
return nil, fmt.Errorf("replicate: empty prediction id (body: %s)", snip(respBody))
|
||||
}
|
||||
return &pred, nil
|
||||
}
|
||||
|
||||
// waitForCompletion polls /v1/predictions/{id} until the status is a
|
||||
// terminal value (succeeded, failed, canceled) or the timeout fires.
|
||||
func (r *Replicate) waitForCompletion(ctx context.Context, id string) (*replicatePrediction, error) {
|
||||
deadline := time.Now().Add(r.pollTimeout)
|
||||
getURL := r.apiBase + "/v1/predictions/" + url.PathEscape(id)
|
||||
start := time.Now()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("did not complete within %s (waited %s)", r.pollTimeout, time.Since(start).Round(time.Millisecond))
|
||||
}
|
||||
body, err := r.doWithRetry(ctx, http.MethodGet, getURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var pred replicatePrediction
|
||||
if err := json.Unmarshal(body, &pred); err != nil {
|
||||
return nil, fmt.Errorf("replicate: parse poll response: %w (body: %s)", err, snip(body))
|
||||
}
|
||||
switch pred.Status {
|
||||
case "succeeded":
|
||||
return &pred, nil
|
||||
case "failed":
|
||||
return nil, fmt.Errorf("status=failed: %s", snip(pred.Error))
|
||||
case "canceled":
|
||||
return nil, fmt.Errorf("status=canceled")
|
||||
case "starting", "processing", "":
|
||||
default:
|
||||
return nil, fmt.Errorf("status=%q (unknown): %s", pred.Status, snip(body))
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(r.pollInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// doWithRetry executes a Replicate API request and applies the resilience
|
||||
// policy: 401 surfaces a clean message naming the env var; 429 retries
|
||||
// with exponential backoff up to three times; 5xx retries once; other
|
||||
// errors surface unchanged.
|
||||
func (r *Replicate) doWithRetry(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
|
||||
const max429Retries = 3
|
||||
backoff := r.initialBackoff
|
||||
if backoff <= 0 {
|
||||
backoff = time.Second
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; ; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, method, reqURL, bytesReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Token "+r.apiToken)
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, err
|
||||
}
|
||||
if attempt >= 1 {
|
||||
return nil, fmt.Errorf("replicate %s %s: %w", method, shortPath(reqURL), err)
|
||||
}
|
||||
lastErr = err
|
||||
if !sleepCtx(ctx, backoff) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
switch {
|
||||
case resp.StatusCode >= 200 && resp.StatusCode < 300:
|
||||
return respBody, nil
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
return nil, fmt.Errorf("replicate[%s] %d: API token missing or invalid; export %s with a valid Replicate API token", r.instance, resp.StatusCode, r.tokenEnv)
|
||||
case resp.StatusCode == http.StatusTooManyRequests:
|
||||
if attempt >= max429Retries {
|
||||
return nil, fmt.Errorf("replicate %s %s 429 after %d retries: %s", method, shortPath(reqURL), max429Retries, snip(respBody))
|
||||
}
|
||||
wait := backoffFor429(resp.Header.Get("Retry-After"), backoff)
|
||||
lastErr = fmt.Errorf("429 (retry %d): %s", attempt+1, snip(respBody))
|
||||
if !sleepCtx(ctx, wait) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
backoff *= 2
|
||||
continue
|
||||
case resp.StatusCode >= 500:
|
||||
if attempt >= 1 {
|
||||
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
|
||||
}
|
||||
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(respBody))
|
||||
if !sleepCtx(ctx, backoff) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
default:
|
||||
_ = lastErr
|
||||
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fetchImage downloads the rendered image from the Replicate-provided
|
||||
// CDN URL. One retry on a generic network error.
|
||||
func (r *Replicate) fetchImage(ctx context.Context, imgURL string) ([]byte, string, error) {
|
||||
var lastErr error
|
||||
for attempt := range 2 {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imgURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, "", err
|
||||
}
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
lastErr = readErr
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
if resp.StatusCode >= 500 && attempt == 0 {
|
||||
lastErr = fmt.Errorf("image download %d: %s", resp.StatusCode, snip(body))
|
||||
continue
|
||||
}
|
||||
return nil, "", fmt.Errorf("replicate image download %d: %s", resp.StatusCode, snip(body))
|
||||
}
|
||||
mime := resp.Header.Get("Content-Type")
|
||||
if mime == "" {
|
||||
mime = "image/png"
|
||||
}
|
||||
return body, mime, nil
|
||||
}
|
||||
return nil, "", fmt.Errorf("replicate image download failed: %w", lastErr)
|
||||
}
|
||||
|
||||
// parseModelRef splits a model reference of the form "owner/name" or
// "owner/name:version-hash" into its parts. version is empty when the
// reference carries no ":version" suffix.
//
// It returns an error when owner or name is empty, when name contains a
// further "/" (which would corrupt the predictions URL path), or when a ":"
// is present but nothing follows it.
func parseModelRef(ref string) (owner, name, version string, err error) {
	repo, ver, hasVersion := strings.Cut(ref, ":")
	o, n, hasSlash := strings.Cut(repo, "/")

	switch {
	case !hasSlash, o == "", n == "":
	case strings.Contains(n, "/"):
	case hasVersion && ver == "":
		// A trailing ":" signals an intended but missing version; do not
		// silently fall back to the unversioned endpoint.
	default:
		return o, n, ver, nil
	}
	return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref)
}
|
||||
|
||||
// timeoutForModel returns the polling timeout for the given model name:
// 60 seconds for "flux-schnell" (matched case-insensitively) and 120 seconds
// for every other model, since FLUX dev takes notably longer than schnell.
func timeoutForModel(name string) time.Duration {
	if strings.ToLower(name) == "flux-schnell" {
		return 60 * time.Second
	}
	return 120 * time.Second
}
|
||||
|
||||
// pickFirstOutputURL extracts the first image URL from the output field.
|
||||
// Replicate returns either a single string or an array of strings for image
|
||||
// models — we accept both.
|
||||
func pickFirstOutputURL(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", fmt.Errorf("output is empty")
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil && s != "" {
|
||||
return s, nil
|
||||
}
|
||||
var arr []string
|
||||
if err := json.Unmarshal(raw, &arr); err == nil && len(arr) > 0 && arr[0] != "" {
|
||||
return arr[0], nil
|
||||
}
|
||||
return "", fmt.Errorf("output is not a string or non-empty string array (got: %s)", snip(raw))
|
||||
}
|
||||
|
||||
// computeAspectRatio reduces width:height to a Replicate-supported aspect
|
||||
// ratio when the reduction lands on one of the canonical values; otherwise
|
||||
// returns fallback.
|
||||
func computeAspectRatio(w, h int, fallback string) string {
|
||||
if w <= 0 || h <= 0 {
|
||||
return fallback
|
||||
}
|
||||
g := gcd(w, h)
|
||||
a, b := w/g, h/g
|
||||
s := fmt.Sprintf("%d:%d", a, b)
|
||||
if isReplicateAspectRatio(s) {
|
||||
return s
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// isReplicateAspectRatio reports whether s is one of the canonical
// "width:height" aspect ratio strings this adapter sends to Replicate.
func isReplicateAspectRatio(s string) bool {
	supported := [...]string{
		"1:1", "16:9", "21:9", "3:2", "2:3", "4:5",
		"5:4", "3:4", "4:3", "9:16", "9:21",
	}
	for _, ratio := range supported {
		if s == ratio {
			return true
		}
	}
	return false
}
|
||||
|
||||
// gcd returns the greatest common divisor of a and b using Euclid's
// algorithm. The result is never negative; gcd(0, 0) is 0.
func gcd(a, b int) int {
	if b != 0 {
		return gcd(b, a%b)
	}
	if a < 0 {
		return -a
	}
	return a
}
|
||||
|
||||
// hashPrompt returns the lowercase hex sha256 digest of p. Only this digest
// is written to the cost-tracking table; the raw prompt never is.
func hashPrompt(p string) string {
	hasher := sha256.New()
	hasher.Write([]byte(p))
	return hex.EncodeToString(hasher.Sum(nil))
}
|
||||
|
||||
// ResolveCaller returns the agent identity a cost-tracking row is attributed
// to. Resolution order (mirroring the maimcp identity logic):
//
//  1. the MAI_FROM_ID environment variable, trimmed, when non-empty;
//  2. the @mai-name option of the tmux pane named by TMUX_PANE, obtained by
//     running `tmux display-message`, when the command succeeds and prints a
//     non-empty value;
//  3. the literal "unknown".
func ResolveCaller() string {
	const anonymous = "unknown"

	if id := strings.TrimSpace(os.Getenv("MAI_FROM_ID")); id != "" {
		return id
	}

	pane := os.Getenv("TMUX_PANE")
	if pane == "" {
		return anonymous
	}
	out, err := exec.Command("tmux", "display-message", "-p", "-t", pane, "#{@mai-name}").Output()
	if err != nil {
		return anonymous
	}
	if paneName := strings.TrimSpace(string(out)); paneName != "" {
		return paneName
	}
	return anonymous
}
|
||||
|
||||
// backoffFor429 picks how long to wait after an HTTP 429 response.
//
// retryAfter is the raw Retry-After header value. Both forms defined by
// RFC 9110 are understood:
//
//   - delay-seconds, e.g. "5" (a fractional value such as "1.5" is tolerated);
//   - an HTTP-date, e.g. "Wed, 21 Oct 2015 07:28:00 GMT", interpreted as the
//     time remaining until that instant.
//
// The resulting wait is capped at 30 seconds. fallback is returned when the
// header is empty, unparseable, or does not describe a positive wait.
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
	const (
		maxWait  = 30 * time.Second
		httpDate = "Mon, 02 Jan 2006 15:04:05 GMT" // same layout as net/http.TimeFormat
	)

	value := strings.TrimSpace(retryAfter)
	if value == "" {
		return fallback
	}

	var wait time.Duration
	if strings.Trim(value, "0123456789.") == "" {
		// Purely numeric: treat as seconds. Checking the character set first
		// keeps a unit-suffixed value such as "5m" from being read as "5ms"
		// once the "s" is appended.
		parsed, err := time.ParseDuration(value + "s")
		if err != nil {
			return fallback
		}
		wait = parsed
	} else if at, err := time.Parse(httpDate, value); err == nil {
		wait = time.Until(at)
	} else {
		return fallback
	}

	if wait <= 0 {
		return fallback
	}
	if wait > maxWait {
		return maxWait
	}
	return wait
}
|
||||
|
||||
// sleepCtx blocks for d or until ctx is done, whichever comes first.
//
// It returns true when the full duration elapsed (or d is non-positive, in
// which case it returns immediately without consulting ctx) and false when
// ctx was cancelled or timed out first.
func sleepCtx(ctx context.Context, d time.Duration) bool {
	if d <= 0 {
		return true
	}
	// An explicit timer is used instead of time.After so it can be stopped
	// when ctx wins the race; before Go 1.23 an abandoned time.After timer
	// stays allocated until it fires.
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
|
||||
|
||||
// bytesReader wraps b in an io.Reader. A nil slice yields a nil interface
// value (not a reader over zero bytes), so callers can tell "no body" apart
// from "empty body"; an empty but non-nil slice still yields a reader.
func bytesReader(b []byte) io.Reader {
	var r io.Reader
	if b != nil {
		r = bytes.NewReader(b)
	}
	return r
}
|
||||
|
||||
// shortPath trims everything before the first "/v1/" in u, so error messages
// show only the API path and not the base URL. A string without "/v1/" is
// returned as is.
func shortPath(u string) string {
	const marker = "/v1/"
	if _, rest, found := strings.Cut(u, marker); found {
		return marker + rest
	}
	return u
}
|
||||
|
||||
// init hands the NewReplicate constructor to Register under ReplicateType
// when the package is loaded.
// NOTE(review): presumably this is what makes Default.Has(ReplicateType)
// true — confirm against Register's implementation.
func init() {
	Register(ReplicateType, NewReplicate)
}
|
||||
42
internal/backend/replicate_pricing.go
Normal file
42
internal/backend/replicate_pricing.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
import "strings"
|
||||
|
||||
// Replicate pricing snapshot.
|
||||
//
|
||||
// Source: https://replicate.com/pricing and the per-model "Run" tab on
|
||||
// each model page. Replicate bills per second of GPU time, but the
|
||||
// black-forest-labs FLUX models also publish a flat per-image price for
|
||||
// the typical settings — that flat number is what we hardcode here.
|
||||
//
|
||||
// Snapshot date: 2026-05-08. TODO(refresh): re-check quarterly. If the
|
||||
// rates drift more than ~10%, update the table and bump replicatePricingSnapshotDate.
|
||||
const replicatePricingSnapshotDate = "2026-05-08"
|
||||
|
||||
// replicatePerImageUSD is the per-image cost estimate keyed by Replicate
|
||||
// model identifier ("owner/name", with any ":version" trimmed). Returns
|
||||
// the rate and true if the model is known, 0 and false otherwise — an
|
||||
// unknown model writes a row with NULL cost rather than a wrong number.
|
||||
func replicatePerImageUSD(model string) (float64, bool) {
|
||||
key := normalisePricingKey(model)
|
||||
switch key {
|
||||
case "black-forest-labs/flux-schnell":
|
||||
return 0.003, true
|
||||
case "black-forest-labs/flux-dev":
|
||||
return 0.025, true
|
||||
case "black-forest-labs/flux-pro":
|
||||
return 0.055, true
|
||||
case "black-forest-labs/flux-1.1-pro":
|
||||
return 0.040, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// normalisePricingKey reduces a model reference to its pricing-table key:
// everything from the first ':' onward is discarded and the remainder is
// lowercased, e.g. "Owner/Name:hash" becomes "owner/name".
func normalisePricingKey(model string) string {
	ref, _, _ := strings.Cut(model, ":")
	return strings.ToLower(ref)
}
|
||||
675
internal/backend/replicate_test.go
Normal file
675
internal/backend/replicate_test.go
Normal file
@@ -0,0 +1,675 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeReplicate is a programmable mock of the Replicate REST API. Tests set
// the policy fields below before issuing requests and read the *Calls
// counters afterwards.
type fakeReplicate struct {
	t *testing.T

	// NOTE(review): mu is not referenced by any method visible in this
	// chunk — confirm whether it is still needed.
	mu sync.Mutex

	// Response for the prediction submission. The model-based path
	// (/v1/models/.../predictions) and /v1/predictions are both served by
	// handleCreate, and createCalls counts every submission regardless of
	// which path was hit.
	createStatus int
	createBody   []byte
	createCalls  atomic.Int32

	// Auth-fail policy: if true, every create request returns 401 with a
	// stock body.
	auth401 bool

	// 429-then-OK policy: the first create429Until create calls return 429,
	// with a Retry-After header when retryAfter is non-empty.
	create429Until int32
	retryAfter     string

	// Sequence of /predictions/{id} responses, walked in order; once
	// exhausted, the last entry is returned indefinitely.
	pollResponses []string
	pollIdx       atomic.Int32
	pollCalls     atomic.Int32

	// Image download policy: the first image5xxFirst fetches return 502,
	// later ones return imageStatus with imageBody.
	imageStatus   int
	imageBody     []byte
	image5xxFirst int32
	imageCalls    atomic.Int32

	server *httptest.Server
}
|
||||
|
||||
func (f *fakeReplicate) handler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1/models/") && strings.HasSuffix(r.URL.Path, "/predictions"):
|
||||
f.handleCreate(w, r)
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
|
||||
f.handleCreate(w, r)
|
||||
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
|
||||
f.handlePoll(w, r)
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/img":
|
||||
f.handleImage(w, r)
|
||||
default:
|
||||
f.t.Errorf("fakeReplicate: unexpected %s %s", r.Method, r.URL.Path)
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) handleCreate(w http.ResponseWriter, _ *http.Request) {
|
||||
n := f.createCalls.Add(1)
|
||||
if f.auth401 {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"detail":"Invalid token"}`))
|
||||
return
|
||||
}
|
||||
if n <= f.create429Until {
|
||||
if f.retryAfter != "" {
|
||||
w.Header().Set("Retry-After", f.retryAfter)
|
||||
}
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"detail":"too many requests"}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(f.createStatus)
|
||||
_, _ = w.Write(f.createBody)
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) handlePoll(w http.ResponseWriter, _ *http.Request) {
|
||||
f.pollCalls.Add(1)
|
||||
idx := int(f.pollIdx.Add(1)) - 1
|
||||
if idx >= len(f.pollResponses) {
|
||||
idx = len(f.pollResponses) - 1
|
||||
}
|
||||
if idx < 0 {
|
||||
http.Error(w, "no poll response configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(f.pollResponses[idx]))
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) handleImage(w http.ResponseWriter, _ *http.Request) {
|
||||
n := f.imageCalls.Add(1)
|
||||
if n <= f.image5xxFirst {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = w.Write([]byte("upstream unavailable"))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
w.WriteHeader(f.imageStatus)
|
||||
_, _ = w.Write(f.imageBody)
|
||||
}
|
||||
|
||||
func (f *fakeReplicate) start() {
|
||||
f.server = httptest.NewServer(f.handler())
|
||||
f.t.Cleanup(f.server.Close)
|
||||
}
|
||||
|
||||
// imageURL is the absolute URL of the fake's image endpoint; it is only
// valid after start has run.
func (f *fakeReplicate) imageURL() string {
	base := f.server.URL
	return base + "/img"
}
|
||||
|
||||
func newFakeReplicate(t *testing.T) *fakeReplicate {
|
||||
t.Helper()
|
||||
f := &fakeReplicate{
|
||||
t: t,
|
||||
createStatus: http.StatusCreated,
|
||||
imageStatus: http.StatusOK,
|
||||
imageBody: mustPNG(t, 8, 8),
|
||||
}
|
||||
f.start()
|
||||
// Default happy-path responses now that the server URL is known.
|
||||
f.createBody = []byte(`{"id":"pred-abc","status":"starting","version":"v1","output":null}`)
|
||||
f.pollResponses = []string{
|
||||
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
|
||||
`{"id":"pred-abc","status":"processing","version":"v1","output":null}`,
|
||||
fmt.Sprintf(`{"id":"pred-abc","status":"succeeded","version":"v1","output":"%s","metrics":{"predict_time":1.23}}`, f.imageURL()),
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func newReplicate(t *testing.T, f *fakeReplicate, model string) *Replicate {
|
||||
t.Helper()
|
||||
be, err := NewReplicate("flux-test", map[string]any{
|
||||
"api_token_env": "TEST_REPLICATE_TOKEN",
|
||||
"model": model,
|
||||
"api_base": f.server.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewReplicate: %v", err)
|
||||
}
|
||||
r := be.(*Replicate)
|
||||
r.apiToken = "fake-token"
|
||||
r.pollInterval = time.Millisecond
|
||||
r.pollTimeout = 5 * time.Second
|
||||
r.initialBackoff = time.Millisecond
|
||||
r.randSeed = func() int64 { return 42 }
|
||||
return r
|
||||
}
|
||||
|
||||
func TestReplicateConstructorRejectsBadInputs(t *testing.T) {
|
||||
if _, err := NewReplicate("", map[string]any{"model": "owner/name"}); err == nil {
|
||||
t.Errorf("expected error for empty instance name")
|
||||
}
|
||||
if _, err := NewReplicate("x", map[string]any{}); err == nil {
|
||||
t.Errorf("expected error for missing model")
|
||||
}
|
||||
if _, err := NewReplicate("x", map[string]any{"model": "no-slash"}); err == nil {
|
||||
t.Errorf("expected error for malformed model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateMissingTokenSurfacesEnvName(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.apiToken = "" // simulate the env var being unset
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when token is missing")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "TEST_REPLICATE_TOKEN") {
|
||||
t.Errorf("error should name the env var: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateHappyPathSchnellUsesModelEndpoint(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
sink := &recordingSink{}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.Sink = sink
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{
|
||||
Prompt: "a tiny dragon",
|
||||
Width: 1024, Height: 1024,
|
||||
Seed: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
defer res.ImageReader.Close()
|
||||
body, _ := io.ReadAll(res.ImageReader)
|
||||
if !bytes.Equal(body, f.imageBody) {
|
||||
t.Errorf("image body did not round-trip")
|
||||
}
|
||||
|
||||
if mime := res.MimeType; mime != "image/png" {
|
||||
t.Errorf("mime = %q", mime)
|
||||
}
|
||||
if got, _ := res.Metadata["model"].(string); got != "black-forest-labs/flux-schnell" {
|
||||
t.Errorf("metadata model = %v", got)
|
||||
}
|
||||
if got, _ := res.Metadata["model_version"].(string); got != "v1" {
|
||||
t.Errorf("metadata model_version = %v", got)
|
||||
}
|
||||
if got, _ := res.Metadata["predict_time_seconds"].(float64); got != 1.23 {
|
||||
t.Errorf("metadata predict_time_seconds = %v", got)
|
||||
}
|
||||
if got, ok := res.Metadata["cost_usd_estimate"].(float64); !ok || got != 0.003 {
|
||||
t.Errorf("metadata cost_usd_estimate = %v (ok=%v)", got, ok)
|
||||
}
|
||||
if got, _ := res.Metadata["aspect_ratio"].(string); got != "1:1" {
|
||||
t.Errorf("aspect_ratio = %q", got)
|
||||
}
|
||||
|
||||
if got := f.pollCalls.Load(); got < 3 {
|
||||
t.Errorf("expected at least 3 poll calls (starting → processing → succeeded), got %d", got)
|
||||
}
|
||||
|
||||
if len(sink.rows) != 1 {
|
||||
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
|
||||
}
|
||||
row := sink.rows[0]
|
||||
if row.Backend != "flux-test" || row.Model != "black-forest-labs/flux-schnell" {
|
||||
t.Errorf("sink row backend/model = %q/%q", row.Backend, row.Model)
|
||||
}
|
||||
if row.PromptHash == "" || row.PromptHash == "a tiny dragon" {
|
||||
t.Errorf("sink row should have sha256 hash, got %q", row.PromptHash)
|
||||
}
|
||||
if len(row.PromptHash) != 64 {
|
||||
t.Errorf("expected 64-char sha256 hex, got %d chars", len(row.PromptHash))
|
||||
}
|
||||
if row.CostUSDEstimate == nil || *row.CostUSDEstimate != 0.003 {
|
||||
t.Errorf("sink row cost = %v", row.CostUSDEstimate)
|
||||
}
|
||||
if row.LatencyMs <= 0 {
|
||||
t.Errorf("sink row latency_ms should be > 0, got %d", row.LatencyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateVersionPinUsesPredictionsEndpoint(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
|
||||
var sentBody []byte
|
||||
var sentPath string
|
||||
mu := sync.Mutex{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
|
||||
mu.Lock()
|
||||
sentBody, _ = io.ReadAll(r.Body)
|
||||
sentPath = r.URL.Path
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_, _ = w.Write([]byte(`{"id":"pred-vp","status":"starting","version":"abc123","output":null}`))
|
||||
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
|
||||
_, _ = fmt.Fprintf(w, `{"id":"pred-vp","status":"succeeded","version":"abc123","output":"%s"}`, f.imageURL())
|
||||
default:
|
||||
f.t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
be, err := NewReplicate("flux-test", map[string]any{
|
||||
"model": "black-forest-labs/flux-dev:abc123",
|
||||
"api_base": srv.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewReplicate: %v", err)
|
||||
}
|
||||
r := be.(*Replicate)
|
||||
r.apiToken = "fake"
|
||||
r.pollInterval = time.Millisecond
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "x"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if sentPath != "/v1/predictions" {
|
||||
t.Errorf("expected version-pinned model to hit /v1/predictions, got %q", sentPath)
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(sentBody, &body); err != nil {
|
||||
t.Fatalf("unmarshal sent body: %v", err)
|
||||
}
|
||||
if body["version"] != "abc123" {
|
||||
t.Errorf("expected version=abc123 in body, got %v", body["version"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicate401SurfacesEnvHint(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.auth401 = true
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected 401 to surface as error")
|
||||
}
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "TEST_REPLICATE_TOKEN") {
|
||||
t.Errorf("error should name env var: %v", err)
|
||||
}
|
||||
if !strings.Contains(msg, "401") {
|
||||
t.Errorf("error should mention 401: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicate429RetriesThenSucceeds(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.create429Until = 2 // first two calls 429 then succeed
|
||||
f.retryAfter = "" // force the adapter's exp-backoff path
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.pollInterval = time.Millisecond
|
||||
|
||||
// Squash the backoff so the test is fast.
|
||||
r.httpClient = &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
res, err := r.Generate(ctx, Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate after 429s: %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
if got := f.createCalls.Load(); got != 3 {
|
||||
t.Errorf("expected 3 create calls (2x 429 + 1 OK), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicate429GivesUpAfterMaxRetries(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.create429Until = 99 // every call 429
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
_, err := r.Generate(ctx, Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error after sustained 429s")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "429") {
|
||||
t.Errorf("expected 429 in error: %v", err)
|
||||
}
|
||||
// max429Retries=3 → 1 initial + 3 retries = 4 total
|
||||
if got := f.createCalls.Load(); got != 4 {
|
||||
t.Errorf("expected 4 create calls (1+3 retries), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateFailedPredictionSurfacesError(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.pollResponses = []string{
|
||||
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
|
||||
`{"id":"pred-abc","status":"failed","version":"v1","output":null,"error":"NSFW filtered"}`,
|
||||
}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for failed prediction")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed") {
|
||||
t.Errorf("error should mention failure: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "NSFW") {
|
||||
t.Errorf("error should include the API error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicatePollTimeoutSurfacesPartialLatency(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
// Always return processing → adapter times out.
|
||||
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.pollInterval = 5 * time.Millisecond
|
||||
r.pollTimeout = 30 * time.Millisecond
|
||||
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected timeout error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "did not complete") {
|
||||
t.Errorf("expected 'did not complete' in error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "waited") {
|
||||
t.Errorf("expected partial latency ('waited X') for diagnostics, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateImageDownloadRetriesOnce5xx(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.image5xxFirst = 1 // first download 502, second OK
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate (download retry): %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
if got := f.imageCalls.Load(); got != 2 {
|
||||
t.Errorf("expected 2 image fetches (1 fail + 1 retry), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateImageDownload5xxGivesUpAfterRetry(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.image5xxFirst = 99 // every download fails
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
|
||||
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error after sustained image-download 5xx")
|
||||
}
|
||||
if got := f.imageCalls.Load(); got != 2 {
|
||||
t.Errorf("expected 2 image fetches (no further retries), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateContextCancelStopsPolling(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.pollInterval = 5 * time.Millisecond
|
||||
r.pollTimeout = 5 * time.Second
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
||||
defer cancel()
|
||||
_, err := r.Generate(ctx, Request{Prompt: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected ctx error")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") {
|
||||
t.Errorf("expected deadline exceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateBackendOptsMergedIntoInput(t *testing.T) {
|
||||
var captured map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var top struct {
|
||||
Input map[string]any `json:"input"`
|
||||
}
|
||||
_ = json.Unmarshal(body, &top)
|
||||
captured = top.Input
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
be, err := NewReplicate("flux-test", map[string]any{
|
||||
"model": "black-forest-labs/flux-schnell",
|
||||
"api_base": srv.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewReplicate: %v", err)
|
||||
}
|
||||
r := be.(*Replicate)
|
||||
r.apiToken = "fake"
|
||||
r.pollInterval = time.Millisecond
|
||||
|
||||
// We expect this to fail at "pickFirstOutputURL" (empty output) but the
|
||||
// captured input should still have been recorded.
|
||||
_, _ = r.Generate(context.Background(), Request{
|
||||
Prompt: "p",
|
||||
Width: 1024, Height: 1024,
|
||||
BackendOpts: map[string]any{
|
||||
"output_quality": 90,
|
||||
"go_fast": true,
|
||||
},
|
||||
})
|
||||
if captured == nil {
|
||||
t.Fatal("create endpoint not hit")
|
||||
}
|
||||
if captured["output_quality"] != float64(90) {
|
||||
t.Errorf("output_quality not threaded: %v", captured["output_quality"])
|
||||
}
|
||||
if captured["go_fast"] != true {
|
||||
t.Errorf("go_fast not threaded: %v", captured["go_fast"])
|
||||
}
|
||||
if captured["prompt"] != "p" {
|
||||
t.Errorf("prompt not threaded: %v", captured["prompt"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateOutputArrayAccepted(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
f.pollResponses = []string{
|
||||
fmt.Sprintf(`{"id":"pid","status":"succeeded","version":"v1","output":["%s"],"metrics":{"predict_time":0.5}}`, f.imageURL()),
|
||||
}
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate (output as array): %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
}
|
||||
|
||||
func TestReplicateUnknownModelLeavesCostUnsetButGenerates(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
r := newReplicate(t, f, "stability-ai/sdxl")
|
||||
sink := &recordingSink{}
|
||||
r.Sink = sink
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate (unknown model): %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
if _, present := res.Metadata["cost_usd_estimate"]; present {
|
||||
t.Errorf("unknown-model meta should not include cost_usd_estimate; got %v", res.Metadata["cost_usd_estimate"])
|
||||
}
|
||||
if len(sink.rows) != 1 {
|
||||
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
|
||||
}
|
||||
if sink.rows[0].CostUSDEstimate != nil {
|
||||
t.Errorf("expected nil cost in sink row for unknown model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateSinkFailureIsWarningNotError(t *testing.T) {
|
||||
f := newFakeReplicate(t)
|
||||
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
|
||||
r.Sink = sinkFunc(func(context.Context, UsageRow) error { return errors.New("db unreachable") })
|
||||
|
||||
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if err != nil {
|
||||
t.Fatalf("sink failure should not fail Generate: %v", err)
|
||||
}
|
||||
res.ImageReader.Close()
|
||||
}
|
||||
|
||||
func TestReplicateDefaultStepsApplied(t *testing.T) {
|
||||
var captured map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var top struct {
|
||||
Input map[string]any `json:"input"`
|
||||
}
|
||||
_ = json.Unmarshal(body, &top)
|
||||
captured = top.Input
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
be, err := NewReplicate("flux-test", map[string]any{
|
||||
"model": "black-forest-labs/flux-dev",
|
||||
"api_base": srv.URL,
|
||||
"default_steps": 28,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewReplicate: %v", err)
|
||||
}
|
||||
r := be.(*Replicate)
|
||||
r.apiToken = "fake"
|
||||
r.pollInterval = time.Millisecond
|
||||
_, _ = r.Generate(context.Background(), Request{Prompt: "p"})
|
||||
if captured == nil {
|
||||
t.Fatal("create endpoint not hit")
|
||||
}
|
||||
if captured["num_inference_steps"] != float64(28) {
|
||||
t.Errorf("expected num_inference_steps=28 from default_steps, got %v", captured["num_inference_steps"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeAspectRatio(t *testing.T) {
|
||||
cases := []struct {
|
||||
w, h int
|
||||
fallback string
|
||||
want string
|
||||
}{
|
||||
{1024, 1024, "1:1", "1:1"},
|
||||
{1920, 1080, "1:1", "16:9"},
|
||||
{2560, 1440, "1:1", "16:9"},
|
||||
{1024, 768, "1:1", "4:3"},
|
||||
{1024, 1280, "1:1", "4:5"},
|
||||
{1000, 1234, "1:1", "1:1"}, // weird ratio falls back
|
||||
{0, 1024, "1:1", "1:1"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := computeAspectRatio(c.w, c.h, c.fallback)
|
||||
if got != c.want {
|
||||
t.Errorf("computeAspectRatio(%d,%d,%q)=%q, want %q", c.w, c.h, c.fallback, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef(t *testing.T) {
|
||||
owner, name, ver, err := parseModelRef("black-forest-labs/flux-schnell")
|
||||
if err != nil || owner != "black-forest-labs" || name != "flux-schnell" || ver != "" {
|
||||
t.Errorf("parseModelRef plain: o=%q n=%q v=%q err=%v", owner, name, ver, err)
|
||||
}
|
||||
owner, name, ver, err = parseModelRef("owner/name:hash123")
|
||||
if err != nil || owner != "owner" || name != "name" || ver != "hash123" {
|
||||
t.Errorf("parseModelRef versioned: o=%q n=%q v=%q err=%v", owner, name, ver, err)
|
||||
}
|
||||
if _, _, _, err := parseModelRef("noslash"); err == nil {
|
||||
t.Errorf("expected error for malformed ref")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPromptStable(t *testing.T) {
|
||||
a := hashPrompt("hello")
|
||||
b := hashPrompt("hello")
|
||||
c := hashPrompt("hello!")
|
||||
if a != b {
|
||||
t.Errorf("hashPrompt should be deterministic")
|
||||
}
|
||||
if a == c {
|
||||
t.Errorf("different prompts should hash differently")
|
||||
}
|
||||
if len(a) != 64 {
|
||||
t.Errorf("sha256 hex should be 64 chars, got %d", len(a))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicatePricingKnownModels(t *testing.T) {
|
||||
if v, ok := replicatePerImageUSD("black-forest-labs/flux-schnell"); !ok || v != 0.003 {
|
||||
t.Errorf("schnell rate = %v (ok=%v)", v, ok)
|
||||
}
|
||||
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev"); !ok || v != 0.025 {
|
||||
t.Errorf("dev rate = %v (ok=%v)", v, ok)
|
||||
}
|
||||
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev:hashabc"); !ok || v != 0.025 {
|
||||
t.Errorf("versioned ref should resolve to base price: %v %v", v, ok)
|
||||
}
|
||||
if _, ok := replicatePerImageUSD("nobody/unknown-model"); ok {
|
||||
t.Errorf("unknown model should report ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateTypeIsRegistered(t *testing.T) {
|
||||
if !Default.Has(ReplicateType) {
|
||||
t.Errorf("replicate type not registered in Default")
|
||||
}
|
||||
}
|
||||
|
||||
// recordingSink captures rows for assertion.
|
||||
type recordingSink struct {
|
||||
mu sync.Mutex
|
||||
rows []UsageRow
|
||||
}
|
||||
|
||||
func (s *recordingSink) Record(_ context.Context, row UsageRow) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.rows = append(s.rows, row)
|
||||
return nil
|
||||
}
|
||||
|
||||
type sinkFunc func(context.Context, UsageRow) error
|
||||
|
||||
func (f sinkFunc) Record(ctx context.Context, row UsageRow) error { return f(ctx, row) }
|
||||
156
internal/backend/workflow_template.go
Normal file
156
internal/backend/workflow_template.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed workflows/*.json
|
||||
var bundledWorkflows embed.FS
|
||||
|
||||
// placeholderRE matches a single-token placeholder like "${prompt}" — the
|
||||
// whole string value must be the placeholder, leading/trailing whitespace
|
||||
// allowed. This lets us preserve types (a numeric substitution becomes a
|
||||
// JSON number, not a stringified one) instead of round-tripping through
|
||||
// strings.Replace which would force everything into a string.
|
||||
var placeholderRE = regexp.MustCompile(`^\s*\$\{([a-zA-Z][a-zA-Z0-9_]*)\}\s*$`)
|
||||
|
||||
// LoadWorkflowTemplate returns the parsed JSON for a workflow template.
|
||||
// `name` is resolved in this order:
|
||||
//
|
||||
// 1. exact filesystem path that exists on disk (absolute or relative);
|
||||
// 2. one of the bundled templates under internal/backend/workflows/
|
||||
// (with or without the .json suffix).
|
||||
//
|
||||
// The returned map is a fresh deep copy of the template; callers can mutate
|
||||
// it freely.
|
||||
func LoadWorkflowTemplate(name string) (map[string]any, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("workflow template name is empty")
|
||||
}
|
||||
raw, err := readWorkflowBytes(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var wf map[string]any
|
||||
if err := json.Unmarshal(raw, &wf); err != nil {
|
||||
return nil, fmt.Errorf("workflow %s: parse: %w", name, err)
|
||||
}
|
||||
return wf, nil
|
||||
}
|
||||
|
||||
// BundledWorkflowNames returns the names of templates compiled into the
|
||||
// binary, sorted. Each name is the basename without the .json suffix.
|
||||
func BundledWorkflowNames() []string {
|
||||
entries, err := fs.ReadDir(bundledWorkflows, "workflows")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
n := e.Name()
|
||||
if !strings.HasSuffix(n, ".json") {
|
||||
continue
|
||||
}
|
||||
out = append(out, strings.TrimSuffix(n, ".json"))
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func readWorkflowBytes(name string) ([]byte, error) {
|
||||
// Filesystem path wins if it points at a real file. Lets a user override
|
||||
// a bundled template by passing an absolute path in yaml.
|
||||
if strings.ContainsRune(name, os.PathSeparator) || strings.HasSuffix(name, ".json") {
|
||||
if b, err := os.ReadFile(name); err == nil {
|
||||
return b, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("workflow %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
// Bundled lookup. Try the literal name as a file inside workflows/, then
|
||||
// with the .json suffix appended.
|
||||
candidates := []string{
|
||||
filepath.Join("workflows", name),
|
||||
filepath.Join("workflows", name+".json"),
|
||||
}
|
||||
for _, c := range candidates {
|
||||
if b, err := bundledWorkflows.ReadFile(c); err == nil {
|
||||
return b, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("workflow %q not found (bundled templates: %v)", name, BundledWorkflowNames())
|
||||
}
|
||||
|
||||
// SubstituteWorkflow walks wf and replaces every "${key}" string with the
|
||||
// matching value from subs, preserving JSON types. Returns the set of
|
||||
// placeholder keys it actually touched, so the caller can detect missing
|
||||
// substitutions even when a key is defined in subs but never referenced in
|
||||
// the workflow (typical when a yaml block sets a knob a different template
|
||||
// would consume).
|
||||
//
|
||||
// Unknown placeholders (referenced in the workflow but absent from subs)
|
||||
// produce an error so we never submit a workflow with raw "${foo}" tokens.
|
||||
func SubstituteWorkflow(wf map[string]any, subs map[string]any) (used map[string]struct{}, err error) {
|
||||
used = make(map[string]struct{})
|
||||
walked, err := substituteValue(wf, subs, used)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// substituteValue returns the replacement for the top-level value, which
|
||||
// should still be the same map (just with mutated children).
|
||||
if m, ok := walked.(map[string]any); ok {
|
||||
// Copy back into wf so the caller's reference reflects the result.
|
||||
for k := range wf {
|
||||
delete(wf, k)
|
||||
}
|
||||
maps.Copy(wf, m)
|
||||
}
|
||||
return used, nil
|
||||
}
|
||||
|
||||
func substituteValue(v any, subs map[string]any, used map[string]struct{}) (any, error) {
|
||||
switch x := v.(type) {
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(x))
|
||||
for k, child := range x {
|
||||
replaced, err := substituteValue(child, subs, used)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out[k] = replaced
|
||||
}
|
||||
return out, nil
|
||||
case []any:
|
||||
out := make([]any, len(x))
|
||||
for i, child := range x {
|
||||
replaced, err := substituteValue(child, subs, used)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out[i] = replaced
|
||||
}
|
||||
return out, nil
|
||||
case string:
|
||||
if m := placeholderRE.FindStringSubmatch(x); m != nil {
|
||||
key := m[1]
|
||||
val, ok := subs[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("workflow placeholder ${%s} has no substitution", key)
|
||||
}
|
||||
used[key] = struct{}{}
|
||||
return val, nil
|
||||
}
|
||||
return x, nil
|
||||
default:
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
153
internal/backend/workflow_template_test.go
Normal file
153
internal/backend/workflow_template_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBundledWorkflowsParseable(t *testing.T) {
|
||||
names := BundledWorkflowNames()
|
||||
if len(names) == 0 {
|
||||
t.Fatal("expected at least one bundled workflow")
|
||||
}
|
||||
mustHave := []string{"flux1-schnell", "flux2-klein", "sd35-medium"}
|
||||
for _, want := range mustHave {
|
||||
if !slices.Contains(names, want) {
|
||||
t.Errorf("bundled workflows missing %q (have: %v)", want, names)
|
||||
}
|
||||
}
|
||||
// Every bundled template must parse and contain at least one node.
|
||||
for _, n := range names {
|
||||
wf, err := LoadWorkflowTemplate(n)
|
||||
if err != nil {
|
||||
t.Errorf("LoadWorkflowTemplate(%q): %v", n, err)
|
||||
continue
|
||||
}
|
||||
if len(wf) == 0 {
|
||||
t.Errorf("workflow %q has zero nodes", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadWorkflowFromFilesystem(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "custom.json")
|
||||
body := `{"1":{"class_type":"X","inputs":{"v":"${prompt}"}}}`
|
||||
if err := os.WriteFile(path, []byte(body), 0o644); err != nil {
|
||||
t.Fatalf("write tmp workflow: %v", err)
|
||||
}
|
||||
wf, err := LoadWorkflowTemplate(path)
|
||||
if err != nil {
|
||||
t.Fatalf("load from path: %v", err)
|
||||
}
|
||||
if _, ok := wf["1"]; !ok {
|
||||
t.Errorf("custom workflow missing node 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadWorkflowUnknownNameErrors(t *testing.T) {
|
||||
_, err := LoadWorkflowTemplate("definitely-not-a-real-workflow")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown workflow name")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("error should say not found, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubstituteWorkflowPreservesTypes(t *testing.T) {
|
||||
wf := map[string]any{
|
||||
"31": map[string]any{
|
||||
"class_type": "KSampler",
|
||||
"inputs": map[string]any{
|
||||
"seed": "${seed}",
|
||||
"steps": "${steps}",
|
||||
"text": "${prompt}",
|
||||
"cfg": "${cfg}",
|
||||
},
|
||||
},
|
||||
}
|
||||
subs := map[string]any{
|
||||
"seed": int64(42),
|
||||
"steps": 11,
|
||||
"prompt": "a cat",
|
||||
"cfg": 4.5,
|
||||
}
|
||||
used, err := SubstituteWorkflow(wf, subs)
|
||||
if err != nil {
|
||||
t.Fatalf("Substitute: %v", err)
|
||||
}
|
||||
if len(used) != 4 {
|
||||
t.Errorf("used = %v, want all four", used)
|
||||
}
|
||||
inputs := wf["31"].(map[string]any)["inputs"].(map[string]any)
|
||||
if seed, ok := inputs["seed"].(int64); !ok || seed != 42 {
|
||||
t.Errorf("seed = %T %v, want int64 42", inputs["seed"], inputs["seed"])
|
||||
}
|
||||
if steps, ok := inputs["steps"].(int); !ok || steps != 11 {
|
||||
t.Errorf("steps = %T %v, want int 11", inputs["steps"], inputs["steps"])
|
||||
}
|
||||
if text, ok := inputs["text"].(string); !ok || text != "a cat" {
|
||||
t.Errorf("text = %T %v, want string", inputs["text"], inputs["text"])
|
||||
}
|
||||
if cfg, ok := inputs["cfg"].(float64); !ok || cfg != 4.5 {
|
||||
t.Errorf("cfg = %T %v, want float64 4.5", inputs["cfg"], inputs["cfg"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubstituteWorkflowMissingPlaceholderErrors(t *testing.T) {
|
||||
wf := map[string]any{
|
||||
"1": map[string]any{"inputs": map[string]any{"v": "${missing}"}},
|
||||
}
|
||||
_, err := SubstituteWorkflow(wf, map[string]any{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing placeholder")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "${missing}") {
|
||||
t.Errorf("error should name the placeholder, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubstituteWorkflowOnlyWholeTokens(t *testing.T) {
|
||||
// Partial-match strings ("prefix ${prompt} suffix") are NOT substituted —
|
||||
// the placeholder must be the whole value so we can preserve types.
|
||||
wf := map[string]any{
|
||||
"1": map[string]any{"inputs": map[string]any{
|
||||
"keep_string": "stuff with ${prompt} inside",
|
||||
"replace_full": "${prompt}",
|
||||
}},
|
||||
}
|
||||
used, err := SubstituteWorkflow(wf, map[string]any{"prompt": "x"})
|
||||
if err != nil {
|
||||
t.Fatalf("Substitute: %v", err)
|
||||
}
|
||||
inputs := wf["1"].(map[string]any)["inputs"].(map[string]any)
|
||||
if inputs["keep_string"].(string) != "stuff with ${prompt} inside" {
|
||||
t.Errorf("partial match should be left alone, got %q", inputs["keep_string"])
|
||||
}
|
||||
if inputs["replace_full"].(string) != "x" {
|
||||
t.Errorf("full-value match should substitute, got %q", inputs["replace_full"])
|
||||
}
|
||||
if _, ok := used["prompt"]; !ok {
|
||||
t.Errorf("used should track keys that fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlux1SchnellTemplateMatchesLegacyShape(t *testing.T) {
|
||||
// Regression guard against the historical hardcoded workflow: every
|
||||
// node ID the old Comfy.buildWorkflow used must still exist in the
|
||||
// migrated template.
|
||||
wf, err := LoadWorkflowTemplate("flux1-schnell")
|
||||
if err != nil {
|
||||
t.Fatalf("load flux1-schnell: %v", err)
|
||||
}
|
||||
legacyNodes := []string{"6", "8", "9", "10", "11", "12", "13", "27", "30", "31"}
|
||||
for _, id := range legacyNodes {
|
||||
if _, ok := wf[id]; !ok {
|
||||
t.Errorf("flux1-schnell template missing node %q (legacy parity)", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
84
internal/backend/workflows/flux1-schnell.json
Normal file
84
internal/backend/workflows/flux1-schnell.json
Normal file
@@ -0,0 +1,84 @@
|
||||
{
|
||||
"6": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"text": "${prompt}",
|
||||
"clip": ["11", 0]
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": {
|
||||
"samples": ["31", 0],
|
||||
"vae": ["10", 0]
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"class_type": "SaveImage",
|
||||
"inputs": {
|
||||
"filename_prefix": "imagen",
|
||||
"images": ["8", 0]
|
||||
}
|
||||
},
|
||||
"10": {
|
||||
"class_type": "VAELoader",
|
||||
"inputs": {
|
||||
"vae_name": "${vae}"
|
||||
}
|
||||
},
|
||||
"11": {
|
||||
"class_type": "DualCLIPLoader",
|
||||
"inputs": {
|
||||
"clip_name1": "${clip_t5}",
|
||||
"clip_name2": "${clip_l}",
|
||||
"type": "flux"
|
||||
}
|
||||
},
|
||||
"12": {
|
||||
"class_type": "UNETLoader",
|
||||
"inputs": {
|
||||
"unet_name": "${model}",
|
||||
"weight_dtype": "${dtype}"
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"text": "${negative}",
|
||||
"clip": ["11", 0]
|
||||
}
|
||||
},
|
||||
"27": {
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"inputs": {
|
||||
"width": "${width}",
|
||||
"height": "${height}",
|
||||
"batch_size": 1
|
||||
}
|
||||
},
|
||||
"30": {
|
||||
"class_type": "ModelSamplingFlux",
|
||||
"inputs": {
|
||||
"model": ["12", 0],
|
||||
"max_shift": 1.15,
|
||||
"base_shift": 0.5,
|
||||
"width": "${width}",
|
||||
"height": "${height}"
|
||||
}
|
||||
},
|
||||
"31": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": ["30", 0],
|
||||
"seed": "${seed}",
|
||||
"steps": "${steps}",
|
||||
"cfg": "${cfg}",
|
||||
"sampler_name": "${sampler}",
|
||||
"scheduler": "${scheduler}",
|
||||
"denoise": 1.0,
|
||||
"positive": ["6", 0],
|
||||
"negative": ["13", 0],
|
||||
"latent_image": ["27", 0]
|
||||
}
|
||||
}
|
||||
}
|
||||
79
internal/backend/workflows/flux2-klein.json
Normal file
79
internal/backend/workflows/flux2-klein.json
Normal file
@@ -0,0 +1,79 @@
|
||||
{
|
||||
"6": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"text": "${prompt}",
|
||||
"clip": ["11", 0]
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": {
|
||||
"samples": ["31", 0],
|
||||
"vae": ["10", 0]
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"class_type": "SaveImage",
|
||||
"inputs": {
|
||||
"filename_prefix": "imagen",
|
||||
"images": ["8", 0]
|
||||
}
|
||||
},
|
||||
"10": {
|
||||
"class_type": "VAELoader",
|
||||
"inputs": {
|
||||
"vae_name": "${vae}"
|
||||
}
|
||||
},
|
||||
"11": {
|
||||
"class_type": "CLIPLoader",
|
||||
"inputs": {
|
||||
"clip_name": "${clip}",
|
||||
"type": "flux2"
|
||||
}
|
||||
},
|
||||
"12": {
|
||||
"class_type": "UNETLoader",
|
||||
"inputs": {
|
||||
"unet_name": "${model}",
|
||||
"weight_dtype": "${dtype}"
|
||||
}
|
||||
},
|
||||
"14": {
|
||||
"class_type": "FluxGuidance",
|
||||
"inputs": {
|
||||
"conditioning": ["6", 0],
|
||||
"guidance": "${guidance}"
|
||||
}
|
||||
},
|
||||
"15": {
|
||||
"class_type": "ConditioningZeroOut",
|
||||
"inputs": {
|
||||
"conditioning": ["6", 0]
|
||||
}
|
||||
},
|
||||
"27": {
|
||||
"class_type": "EmptyFlux2LatentImage",
|
||||
"inputs": {
|
||||
"width": "${width}",
|
||||
"height": "${height}",
|
||||
"batch_size": 1
|
||||
}
|
||||
},
|
||||
"31": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": ["12", 0],
|
||||
"seed": "${seed}",
|
||||
"steps": "${steps}",
|
||||
"cfg": "${cfg}",
|
||||
"sampler_name": "${sampler}",
|
||||
"scheduler": "${scheduler}",
|
||||
"denoise": 1.0,
|
||||
"positive": ["14", 0],
|
||||
"negative": ["15", 0],
|
||||
"latent_image": ["27", 0]
|
||||
}
|
||||
}
|
||||
}
|
||||
66
internal/backend/workflows/sd35-medium.json
Normal file
66
internal/backend/workflows/sd35-medium.json
Normal file
@@ -0,0 +1,66 @@
|
||||
{
|
||||
"4": {
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {
|
||||
"ckpt_name": "${model}"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"text": "${prompt}",
|
||||
"clip": ["4", 1]
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"text": "${negative}",
|
||||
"clip": ["4", 1]
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": {
|
||||
"samples": ["31", 0],
|
||||
"vae": ["4", 2]
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"class_type": "SaveImage",
|
||||
"inputs": {
|
||||
"filename_prefix": "imagen",
|
||||
"images": ["8", 0]
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"inputs": {
|
||||
"model": ["4", 0],
|
||||
"shift": "${shift}"
|
||||
}
|
||||
},
|
||||
"27": {
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"inputs": {
|
||||
"width": "${width}",
|
||||
"height": "${height}",
|
||||
"batch_size": 1
|
||||
}
|
||||
},
|
||||
"31": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": ["13", 0],
|
||||
"seed": "${seed}",
|
||||
"steps": "${steps}",
|
||||
"cfg": "${cfg}",
|
||||
"sampler_name": "${sampler}",
|
||||
"scheduler": "${scheduler}",
|
||||
"denoise": 1.0,
|
||||
"positive": ["6", 0],
|
||||
"negative": ["7", 0],
|
||||
"latent_image": ["27", 0]
|
||||
}
|
||||
}
|
||||
}
|
||||
365
internal/cloud/cloud.go
Normal file
365
internal/cloud/cloud.go
Normal file
@@ -0,0 +1,365 @@
|
||||
// Package cloud syncs a generated image to Supabase Storage and inserts
|
||||
// a row into imagen.images. Both steps are best-effort: callers log the
|
||||
// returned error and proceed, because the local PNG + sidecar are already
|
||||
// on disk by the time Sync runs and a cloud blip should not lose the
|
||||
// artefact.
|
||||
//
|
||||
// The single source of truth for the row schema is the imagen_schema_init
|
||||
// migration — see internal docs in the issue body for #7.
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// supabaseSchema is sent as the Accept-Profile / Content-Profile header
	// value so PostgREST targets the imagen schema (see ALTER ROLE
	// authenticator SET pgrst.db_schemas).
	supabaseSchema = "imagen"

	// bucketName is the Supabase Storage bucket that receives every
	// generated image.
	bucketName = "imagen-generated"
)
|
||||
|
||||
// Sink uploads one image object and inserts one metadata row per
// generation. The methods in this file only read its fields, so a
// configured Sink can be shared across goroutines.
type Sink struct {
	// URL is the Supabase base URL (SUPABASE_URL), for example
	// https://supa.flexsiebels.de. NewFromEnv strips trailing slashes.
	URL string
	// APIKey is the service-role key (SUPABASE_SERVICE_KEY). It is sent on
	// both storage uploads and DB inserts, which therefore bypass RLS; the
	// policies on the table and bucket are the contract for the read side.
	APIKey string
	// OwnerUserID is m's auth.users.id and fills owner_user_id on every
	// row. When it is empty Sync refuses to insert, because the column is
	// NOT NULL and the user-mode reader needs it for the RLS policy.
	OwnerUserID string
	// HTTP is the client used for every request; tests inject one that
	// points at an httptest server.
	HTTP *http.Client
	// MaxRetries is how many extra attempts follow a first failure that is
	// retryable (5xx or transport error). Zero means a single attempt.
	MaxRetries int
	// InitialBackoff is the wait before the first retry; it doubles on each
	// further attempt. Tests set it very small.
	InitialBackoff time.Duration
}
|
||||
|
||||
// NewFromEnv returns a sink populated from SUPABASE_URL +
|
||||
// SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) + IMAGEN_OWNER_USER_ID.
|
||||
// Returns ok=false if the URL or key are missing — the caller treats that
|
||||
// as "cloud-sync disabled by environment".
|
||||
func NewFromEnv() (*Sink, bool) {
|
||||
u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")
|
||||
if u == "" {
|
||||
return nil, false
|
||||
}
|
||||
key := os.Getenv("SUPABASE_SERVICE_KEY")
|
||||
if key == "" {
|
||||
key = os.Getenv("MAI_SUPABASE_KEY")
|
||||
}
|
||||
if key == "" {
|
||||
return nil, false
|
||||
}
|
||||
return &Sink{
|
||||
URL: u,
|
||||
APIKey: key,
|
||||
OwnerUserID: os.Getenv("IMAGEN_OWNER_USER_ID"),
|
||||
HTTP: &http.Client{Timeout: 30 * time.Second},
|
||||
MaxRetries: 2,
|
||||
InitialBackoff: time.Second,
|
||||
}, true
|
||||
}
|
||||
|
||||
// SyncRequest carries everything Sync needs for one image, independent of
// the backend that produced it.
type SyncRequest struct {
	// Date, Slug, Seed and Ext form the storage path
	// "<Date>/<Slug>-<Seed>.<Ext>". Date is formatted as YYYY-MM-DD; Slug
	// and Seed are reused from the local filename so the storage path
	// mirrors the layout on disk.
	Date string
	Slug string
	Seed int64
	Ext  string // "png", "jpg", "webp" — no leading dot; empty means "png"
	// PNG holds the image bytes that are uploaded, and MimeType the
	// Content-Type sent with them (empty means "image/png").
	PNG      []byte
	MimeType string

	// Row metadata. Prompt and Backend are always written; the remaining
	// fields are written only when non-zero (non-nil for the cost, non-empty
	// for the sidecar).
	Prompt          string
	Backend         string
	Model           string
	Steps           int
	Width           int
	Height          int
	LatencyMs       int
	CostUSDEstimate *float64
	Sidecar         map[string]any

	// SeriesID is the parent imagen.series row when this image is one of
	// N tries in a batch. Empty means a solo run: the column stays NULL,
	// which keeps the row visible on the main list-page query
	// (`WHERE series_id IS NULL`).
	SeriesID string
}
|
||||
|
||||
// SyncResult reports where a synced image ended up.
type SyncResult struct {
	// StoragePath is the object path inside the bucket,
	// e.g. "2026-05-11/lighthouse-42.png".
	StoragePath string
	// ImageID is the id of the inserted imagen.images row (a UUID).
	ImageID string
}
|
||||
|
||||
// Sync uploads the bytes and inserts the metadata row. Returns the row's
|
||||
// id and storage_path on success; any non-nil error is what the caller
|
||||
// surfaces as "imagen: cloud sync: <err>" and otherwise ignores.
|
||||
func (s *Sink) Sync(ctx context.Context, req SyncRequest) (*SyncResult, error) {
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("cloud sink not configured")
|
||||
}
|
||||
if s.OwnerUserID == "" {
|
||||
return nil, fmt.Errorf("owner_user_id not set (config or $IMAGEN_OWNER_USER_ID); refusing to insert NULL into imagen.images")
|
||||
}
|
||||
if req.Date == "" || req.Slug == "" {
|
||||
return nil, fmt.Errorf("date and slug are required for storage_path")
|
||||
}
|
||||
ext := req.Ext
|
||||
if ext == "" {
|
||||
ext = "png"
|
||||
}
|
||||
storagePath := fmt.Sprintf("%s/%s-%d.%s", req.Date, req.Slug, req.Seed, ext)
|
||||
|
||||
if err := s.upload(ctx, storagePath, req.PNG, req.MimeType); err != nil {
|
||||
return nil, fmt.Errorf("storage upload: %w", err)
|
||||
}
|
||||
|
||||
id, err := s.insertRow(ctx, storagePath, req)
|
||||
if err != nil {
|
||||
return &SyncResult{StoragePath: storagePath}, fmt.Errorf("db insert: %w", err)
|
||||
}
|
||||
return &SyncResult{StoragePath: storagePath, ImageID: id}, nil
|
||||
}
|
||||
|
||||
// upload PUTs the PNG into the imagen-generated bucket. We use
|
||||
// Content-Type so signed URLs render in the browser without a download
|
||||
// prompt. POST would error on second-write; PUT (with x-upsert: true) is
|
||||
// idempotent for re-runs of the same date+slug+seed.
|
||||
func (s *Sink) upload(ctx context.Context, storagePath string, body []byte, mime string) error {
|
||||
if mime == "" {
|
||||
mime = "image/png"
|
||||
}
|
||||
endpoint := fmt.Sprintf("%s/storage/v1/object/%s/%s", s.URL, bucketName, pathEscape(storagePath))
|
||||
return s.doRetry(ctx, func(ctx context.Context) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("apikey", s.APIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+s.APIKey)
|
||||
req.Header.Set("Content-Type", mime)
|
||||
req.Header.Set("x-upsert", "true")
|
||||
return s.HTTP.Do(req)
|
||||
})
|
||||
}
|
||||
|
||||
// insertRow POSTs to PostgREST against the imagen schema. Prefer:
|
||||
// return=representation gives us the inserted id back without a second
|
||||
// round-trip.
|
||||
func (s *Sink) insertRow(ctx context.Context, storagePath string, req SyncRequest) (string, error) {
|
||||
row := map[string]any{
|
||||
"owner_user_id": s.OwnerUserID,
|
||||
"prompt": req.Prompt,
|
||||
"prompt_hash": hashPrompt(req.Prompt),
|
||||
"backend": req.Backend,
|
||||
"storage_path": storagePath,
|
||||
}
|
||||
if req.Model != "" {
|
||||
row["model"] = req.Model
|
||||
}
|
||||
if req.Seed != 0 {
|
||||
row["seed"] = req.Seed
|
||||
}
|
||||
if req.Steps != 0 {
|
||||
row["steps"] = req.Steps
|
||||
}
|
||||
if req.Width != 0 {
|
||||
row["width"] = req.Width
|
||||
}
|
||||
if req.Height != 0 {
|
||||
row["height"] = req.Height
|
||||
}
|
||||
if req.LatencyMs != 0 {
|
||||
row["latency_ms"] = req.LatencyMs
|
||||
}
|
||||
if req.CostUSDEstimate != nil {
|
||||
row["cost_usd_estimate"] = *req.CostUSDEstimate
|
||||
}
|
||||
if len(req.Sidecar) > 0 {
|
||||
row["sidecar"] = req.Sidecar
|
||||
}
|
||||
if req.SeriesID != "" {
|
||||
row["series_id"] = req.SeriesID
|
||||
}
|
||||
|
||||
body, err := json.Marshal(row)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal row: %w", err)
|
||||
}
|
||||
|
||||
endpoint := s.URL + "/rest/v1/images"
|
||||
|
||||
respBody, err := s.doRetryRead(ctx, func(ctx context.Context) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("apikey", s.APIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+s.APIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept-Profile", supabaseSchema)
|
||||
req.Header.Set("Content-Profile", supabaseSchema)
|
||||
req.Header.Set("Prefer", "return=representation")
|
||||
return s.HTTP.Do(req)
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var rows []struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &rows); err != nil {
|
||||
return "", fmt.Errorf("parse insert response: %w (body: %s)", err, snip(respBody))
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return "", fmt.Errorf("insert returned 0 rows (body: %s)", snip(respBody))
|
||||
}
|
||||
return rows[0].ID, nil
|
||||
}
|
||||
|
||||
// SignedURL asks the Storage API for a time-limited URL. ttlSeconds is
|
||||
// the validity window. Returned URL is host-qualified and ready to hand
|
||||
// to a browser.
|
||||
func (s *Sink) SignedURL(ctx context.Context, storagePath string, ttlSeconds int) (string, error) {
|
||||
if s == nil {
|
||||
return "", fmt.Errorf("cloud sink not configured")
|
||||
}
|
||||
if ttlSeconds <= 0 {
|
||||
ttlSeconds = 3600
|
||||
}
|
||||
endpoint := fmt.Sprintf("%s/storage/v1/object/sign/%s/%s", s.URL, bucketName, pathEscape(storagePath))
|
||||
body, err := json.Marshal(map[string]any{"expiresIn": ttlSeconds})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("apikey", s.APIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+s.APIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := s.HTTP.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("sign %d: %s", resp.StatusCode, snip(respBody))
|
||||
}
|
||||
var parsed struct {
|
||||
SignedURL string `json:"signedURL"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return "", fmt.Errorf("parse sign response: %w (body: %s)", err, snip(respBody))
|
||||
}
|
||||
if parsed.SignedURL == "" {
|
||||
return "", fmt.Errorf("empty signedURL in response: %s", snip(respBody))
|
||||
}
|
||||
full := parsed.SignedURL
|
||||
if strings.HasPrefix(full, "/") {
|
||||
full = s.URL + full
|
||||
}
|
||||
return full, nil
|
||||
}
|
||||
|
||||
// doRetry runs op up to MaxRetries+1 times. 5xx and transport errors are
|
||||
// retried with exponential backoff; 4xx surfaces immediately as a
|
||||
// permanent error (caller's bug in the row, not a network blip).
|
||||
func (s *Sink) doRetry(ctx context.Context, op func(context.Context) (*http.Response, error)) error {
|
||||
_, err := s.doRetryRead(ctx, op)
|
||||
return err
|
||||
}
|
||||
|
||||
// doRetryRead is the read-the-body variant. Returns the 2xx response
|
||||
// body bytes; non-2xx is wrapped in an error. Same retry semantics as
|
||||
// doRetry: 5xx/transport retries with exponential backoff, 4xx is fatal.
|
||||
func (s *Sink) doRetryRead(ctx context.Context, op func(context.Context) (*http.Response, error)) ([]byte, error) {
|
||||
backoff := s.InitialBackoff
|
||||
if backoff == 0 {
|
||||
backoff = time.Second
|
||||
}
|
||||
attempts := s.MaxRetries + 1
|
||||
if attempts < 1 {
|
||||
attempts = 1
|
||||
}
|
||||
var lastErr error
|
||||
for i := 0; i < attempts; i++ {
|
||||
if i > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff *= 2
|
||||
}
|
||||
resp, err := op(ctx)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if readErr != nil {
|
||||
lastErr = fmt.Errorf("read body: %w", readErr)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return body, nil
|
||||
}
|
||||
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, snip(body))
|
||||
}
|
||||
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(body))
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// hashPrompt returns the SHA-256 digest of p as a 64-character lowercase
// hex string.
func hashPrompt(p string) string {
	hasher := sha256.New()
	hasher.Write([]byte(p)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
|
||||
|
||||
// pathEscape percent-encodes every "/"-separated segment of p and joins
// the segments with "/" again, so the Storage API still sees the part
// after the bucket name as a virtual file path with directory separators.
//
// The segments are rejoined with path.Join, which also cleans the result:
// empty segments are dropped and "." / ".." elements are resolved.
func pathEscape(p string) string {
	segments := strings.Split(p, "/")
	escaped := make([]string, 0, len(segments))
	for _, segment := range segments {
		escaped = append(escaped, url.PathEscape(segment))
	}
	return path.Join(escaped...)
}
|
||||
|
||||
// snip renders a response body for inclusion in an error message. It
// trims surrounding whitespace and, if the result is longer than 500
// bytes, truncates it and appends "...".
//
// The cut never lands inside a multi-byte UTF-8 sequence: it is moved back
// (by at most three bytes) to the start of the rune that straddles the
// limit, so valid input always yields a valid string.
func snip(b []byte) string {
	// Named limit rather than max so the builtin max is not shadowed.
	const limit = 500
	s := strings.TrimSpace(string(b))
	if len(s) <= limit {
		return s
	}
	cut := limit
	// A byte of the form 10xxxxxx is a UTF-8 continuation byte. While
	// s[cut] is one, s[:cut] would end mid-rune, so step back. The bound
	// keeps the walk short on input that is not valid UTF-8.
	for cut > limit-4 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
|
||||
374
internal/cloud/cloud_test.go
Normal file
374
internal/cloud/cloud_test.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeSupabase is a minimal stand-in for Supabase Storage + PostgREST. Its
// handlers record what they received and answer with canned responses
// chosen by request path.
type fakeSupabase struct {
	t      *testing.T
	mux    *http.ServeMux
	server *httptest.Server

	// Call counters, updated with sync/atomic by the handlers.
	uploadCalls int32
	insertCalls int32

	// Body and headers of the most recent storage upload.
	uploadBytes []byte
	uploadHdr   http.Header

	// Body and headers of the most recent PostgREST insert.
	insertBody []byte
	insertHdr  http.Header
}
|
||||
|
||||
func newFakeSupabase(t *testing.T, opts ...func(*fakeSupabase)) *fakeSupabase {
|
||||
f := &fakeSupabase{t: t}
|
||||
f.mux = http.NewServeMux()
|
||||
// Storage upload — anything under /storage/v1/object/<bucket>/...
|
||||
f.mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&f.uploadCalls, 1)
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
f.uploadBytes = body
|
||||
f.uploadHdr = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"Key":"imagen-generated/somepath"}`))
|
||||
})
|
||||
// Storage sign URL
|
||||
f.mux.HandleFunc("/storage/v1/object/sign/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"signedURL":"/storage/v1/object/sign/imagen-generated/some.png?token=abc"}`))
|
||||
})
|
||||
// PostgREST insert
|
||||
f.mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&f.insertCalls, 1)
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
f.insertBody = body
|
||||
f.insertHdr = r.Header.Clone()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Write([]byte(`[{"id":"00000000-0000-0000-0000-000000000abc"}]`))
|
||||
})
|
||||
for _, opt := range opts {
|
||||
opt(f)
|
||||
}
|
||||
f.server = httptest.NewServer(f.mux)
|
||||
t.Cleanup(f.server.Close)
|
||||
return f
|
||||
}
|
||||
|
||||
func newSink(server *httptest.Server) *Sink {
|
||||
return &Sink{
|
||||
URL: server.URL,
|
||||
APIKey: "fake-service-key",
|
||||
OwnerUserID: "00000000-0000-0000-0000-000000000001",
|
||||
HTTP: server.Client(),
|
||||
MaxRetries: 2,
|
||||
InitialBackoff: time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncHappyPath(t *testing.T) {
|
||||
f := newFakeSupabase(t)
|
||||
s := newSink(f.server)
|
||||
|
||||
cost := 0.003
|
||||
res, err := s.Sync(context.Background(), SyncRequest{
|
||||
Date: "2026-05-11",
|
||||
Slug: "lighthouse",
|
||||
Seed: 42,
|
||||
Ext: "png",
|
||||
PNG: []byte("PNGbytes"),
|
||||
MimeType: "image/png",
|
||||
Prompt: "a tiny lighthouse on a stormy cliff",
|
||||
Backend: "flux-schnell-local",
|
||||
Model: "flux1-schnell",
|
||||
Steps: 4,
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
LatencyMs: 1500,
|
||||
CostUSDEstimate: &cost,
|
||||
Sidecar: map[string]any{
|
||||
"timestamp": "2026-05-11T01:30:00Z",
|
||||
"backend": "flux-schnell-local",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Sync: %v", err)
|
||||
}
|
||||
if res.StoragePath != "2026-05-11/lighthouse-42.png" {
|
||||
t.Errorf("storage_path = %q", res.StoragePath)
|
||||
}
|
||||
if res.ImageID != "00000000-0000-0000-0000-000000000abc" {
|
||||
t.Errorf("image_id = %q", res.ImageID)
|
||||
}
|
||||
if got := atomic.LoadInt32(&f.uploadCalls); got != 1 {
|
||||
t.Errorf("upload calls = %d, want 1", got)
|
||||
}
|
||||
if got := atomic.LoadInt32(&f.insertCalls); got != 1 {
|
||||
t.Errorf("insert calls = %d, want 1", got)
|
||||
}
|
||||
if !bytes.Equal(f.uploadBytes, []byte("PNGbytes")) {
|
||||
t.Errorf("uploaded bytes = %q", f.uploadBytes)
|
||||
}
|
||||
|
||||
// Verify the row payload carries the prompt + computed hash + non-zero
|
||||
// metadata. Empty fields should be omitted from the JSON body so RLS
|
||||
// won't see surprise keys.
|
||||
var row map[string]any
|
||||
if err := json.Unmarshal(f.insertBody, &row); err != nil {
|
||||
t.Fatalf("insert body parse: %v\n%s", err, f.insertBody)
|
||||
}
|
||||
if row["prompt"] != "a tiny lighthouse on a stormy cliff" {
|
||||
t.Errorf("row.prompt = %v", row["prompt"])
|
||||
}
|
||||
if row["owner_user_id"] != "00000000-0000-0000-0000-000000000001" {
|
||||
t.Errorf("row.owner_user_id = %v", row["owner_user_id"])
|
||||
}
|
||||
if row["storage_path"] != "2026-05-11/lighthouse-42.png" {
|
||||
t.Errorf("row.storage_path = %v", row["storage_path"])
|
||||
}
|
||||
hash, _ := row["prompt_hash"].(string)
|
||||
if len(hash) != 64 {
|
||||
t.Errorf("prompt_hash should be 64-char sha256 hex, got %q", hash)
|
||||
}
|
||||
if row["backend"] != "flux-schnell-local" {
|
||||
t.Errorf("row.backend = %v", row["backend"])
|
||||
}
|
||||
if row["seed"].(float64) != 42 {
|
||||
t.Errorf("row.seed = %v", row["seed"])
|
||||
}
|
||||
if row["latency_ms"].(float64) != 1500 {
|
||||
t.Errorf("row.latency_ms = %v", row["latency_ms"])
|
||||
}
|
||||
if row["cost_usd_estimate"].(float64) != 0.003 {
|
||||
t.Errorf("row.cost = %v", row["cost_usd_estimate"])
|
||||
}
|
||||
if row["sidecar"] == nil {
|
||||
t.Errorf("row.sidecar missing")
|
||||
}
|
||||
|
||||
// PostgREST schema headers — hardcoded to "imagen".
|
||||
if got := f.insertHdr.Get("Accept-Profile"); got != "imagen" {
|
||||
t.Errorf("Accept-Profile = %q", got)
|
||||
}
|
||||
if got := f.insertHdr.Get("Content-Profile"); got != "imagen" {
|
||||
t.Errorf("Content-Profile = %q", got)
|
||||
}
|
||||
if got := f.insertHdr.Get("Authorization"); !strings.HasPrefix(got, "Bearer ") {
|
||||
t.Errorf("Authorization = %q", got)
|
||||
}
|
||||
|
||||
// Storage upsert should be set so re-runs of the same date+slug+seed
|
||||
// don't fail with 409.
|
||||
if got := f.uploadHdr.Get("x-upsert"); got != "true" {
|
||||
t.Errorf("x-upsert = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncRetryOn5xx(t *testing.T) {
|
||||
var uploadAttempts int32
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
|
||||
n := atomic.AddInt32(&uploadAttempts, 1)
|
||||
// Two 503s, then OK.
|
||||
if n < 3 {
|
||||
http.Error(w, "service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Write([]byte(`[{"id":"row-id"}]`))
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
s := newSink(srv)
|
||||
|
||||
res, err := s.Sync(context.Background(), SyncRequest{
|
||||
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
|
||||
PNG: []byte("p"), Prompt: "p", Backend: "b",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Sync (with retry): %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&uploadAttempts); got != 3 {
|
||||
t.Errorf("upload attempts = %d, want 3", got)
|
||||
}
|
||||
if res.ImageID != "row-id" {
|
||||
t.Errorf("image_id = %q", res.ImageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncNoRetryOn4xx(t *testing.T) {
|
||||
var uploadAttempts int32
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&uploadAttempts, 1)
|
||||
http.Error(w, `{"message":"bad request"}`, http.StatusBadRequest)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
s := newSink(srv)
|
||||
|
||||
_, err := s.Sync(context.Background(), SyncRequest{
|
||||
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
|
||||
PNG: []byte("p"), Prompt: "p", Backend: "b",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error on 400")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "400") {
|
||||
t.Errorf("error should mention 400 status: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&uploadAttempts); got != 1 {
|
||||
t.Errorf("upload attempts = %d, want 1 (no retry on 4xx)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMissingOwnerUserID(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NewServeMux())
|
||||
defer srv.Close()
|
||||
s := &Sink{
|
||||
URL: srv.URL,
|
||||
APIKey: "k",
|
||||
// OwnerUserID intentionally empty.
|
||||
HTTP: srv.Client(),
|
||||
InitialBackoff: time.Millisecond,
|
||||
}
|
||||
_, err := s.Sync(context.Background(), SyncRequest{
|
||||
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
|
||||
PNG: []byte("p"), Prompt: "p", Backend: "b",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when owner_user_id unset")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "owner_user_id") {
|
||||
t.Errorf("error should mention owner_user_id: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncRequiresDateAndSlug(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NewServeMux())
|
||||
defer srv.Close()
|
||||
s := newSink(srv)
|
||||
_, err := s.Sync(context.Background(), SyncRequest{
|
||||
Slug: "x", Seed: 1, Ext: "png",
|
||||
PNG: []byte("p"), Prompt: "p", Backend: "b",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing date")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignedURL(t *testing.T) {
|
||||
f := newFakeSupabase(t)
|
||||
s := newSink(f.server)
|
||||
got, err := s.SignedURL(context.Background(), "2026-05-11/x.png", 60)
|
||||
if err != nil {
|
||||
t.Fatalf("SignedURL: %v", err)
|
||||
}
|
||||
want := f.server.URL + "/storage/v1/object/sign/imagen-generated/some.png?token=abc"
|
||||
if got != want {
|
||||
t.Errorf("signed URL = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncDBFailureSurfacesPathOnError(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "schema cache miss", http.StatusInternalServerError)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
s := newSink(srv)
|
||||
res, err := s.Sync(context.Background(), SyncRequest{
|
||||
Date: "2026-05-11", Slug: "x", Seed: 9, Ext: "png",
|
||||
PNG: []byte("p"), Prompt: "p", Backend: "b",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from DB insert failure")
|
||||
}
|
||||
// Storage upload succeeded — caller can still see the upload landed.
|
||||
if res == nil || res.StoragePath != "2026-05-11/x-9.png" {
|
||||
t.Errorf("expected storage_path on partial success, got %+v", res)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSyncWritesSeriesID is the second half of the ImaGen#9 propagation
|
||||
// contract: when SeriesID is non-empty, the POST body to imagen.images
|
||||
// carries `series_id`. When empty, the key is omitted entirely so the
|
||||
// row's series_id stays NULL (solo-run path, list-page query
|
||||
// `WHERE series_id IS NULL` keeps showing it).
|
||||
func TestSyncWritesSeriesID(t *testing.T) {
|
||||
const seriesID = "22222222-2222-2222-2222-222222222222"
|
||||
f := newFakeSupabase(t)
|
||||
s := newSink(f.server)
|
||||
|
||||
_, err := s.Sync(context.Background(), SyncRequest{
|
||||
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
|
||||
PNG: []byte("p"), Prompt: "p", Backend: "b",
|
||||
SeriesID: seriesID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Sync: %v", err)
|
||||
}
|
||||
var row map[string]any
|
||||
if err := json.Unmarshal(f.insertBody, &row); err != nil {
|
||||
t.Fatalf("parse insert body: %v\n%s", err, f.insertBody)
|
||||
}
|
||||
if row["series_id"] != seriesID {
|
||||
t.Fatalf("row.series_id = %v want %q", row["series_id"], seriesID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncOmitsSeriesIDWhenEmpty(t *testing.T) {
|
||||
f := newFakeSupabase(t)
|
||||
s := newSink(f.server)
|
||||
|
||||
_, err := s.Sync(context.Background(), SyncRequest{
|
||||
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
|
||||
PNG: []byte("p"), Prompt: "p", Backend: "b",
|
||||
// SeriesID intentionally empty.
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Sync: %v", err)
|
||||
}
|
||||
var row map[string]any
|
||||
if err := json.Unmarshal(f.insertBody, &row); err != nil {
|
||||
t.Fatalf("parse insert body: %v\n%s", err, f.insertBody)
|
||||
}
|
||||
if _, present := row["series_id"]; present {
|
||||
t.Fatalf("solo run should omit series_id from POST body, got %v", row["series_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathEscape(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"2026-05-11/lighthouse-42.png": "2026-05-11/lighthouse-42.png",
|
||||
"2026-05-11/two words.png": "2026-05-11/two%20words.png",
|
||||
"with#hash/and?query.png": "with%23hash/and%3Fquery.png",
|
||||
}
|
||||
for in, want := range cases {
|
||||
got := pathEscape(in)
|
||||
if got != want {
|
||||
t.Errorf("pathEscape(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
// Sanity: every part should round-trip via url.PathUnescape.
|
||||
for _, seg := range strings.Split(got, "/") {
|
||||
if _, err := url.PathUnescape(seg); err != nil {
|
||||
t.Errorf("segment %q failed unescape: %v", seg, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,15 +15,29 @@ import (
|
||||
// Config is the top-level shape of imagen.yaml.
|
||||
type Config struct {
|
||||
DefaultBackend string `yaml:"default_backend"`
|
||||
// OwnerUserID is m's auth.users.id on msupabase. The cloud-sync writer
|
||||
// uses it to populate imagen.images.owner_user_id (NOT NULL, owns RLS).
|
||||
// Empty disables DB inserts even when cloud_sync is on.
|
||||
OwnerUserID string `yaml:"owner_user_id"`
|
||||
Output OutputConfig `yaml:"output"`
|
||||
Backends map[string]BackendSpec `yaml:"backends"`
|
||||
}
|
||||
|
||||
// OutputConfig controls where generated images and metadata sidecars land.
|
||||
// OutputConfig controls where generated images and metadata sidecars land,
|
||||
// and whether `imagen generate` opens a tmux preview window.
|
||||
type OutputConfig struct {
|
||||
Directory string `yaml:"directory"`
|
||||
Naming string `yaml:"naming"`
|
||||
WriteMetadataJSON bool `yaml:"write_metadata_json"`
|
||||
// Preview is the tri-state preview mode: "auto" (default), "on", "off".
|
||||
// Empty / unset is treated as "auto". $IMAGEN_PREVIEW and the
|
||||
// --preview/--no-preview flags override this in turn.
|
||||
Preview string `yaml:"preview"`
|
||||
// CloudSync controls whether successful generations also upload to
|
||||
// Supabase Storage and insert into imagen.images. Tri-state mirroring
|
||||
// Preview: "auto" (default — on when SUPABASE_URL + SUPABASE_SERVICE_KEY
|
||||
// are set), "on" (errors if env unset), "off". --no-cloud overrides.
|
||||
CloudSync string `yaml:"cloud_sync"`
|
||||
}
|
||||
|
||||
// BackendSpec is one entry under `backends:`. Type identifies the adapter;
|
||||
@@ -78,6 +92,16 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("default_backend %q is not defined under backends:", c.DefaultBackend)
|
||||
}
|
||||
}
|
||||
switch c.Output.Preview {
|
||||
case "", "auto", "on", "off":
|
||||
default:
|
||||
return fmt.Errorf("output.preview = %q (must be auto|on|off)", c.Output.Preview)
|
||||
}
|
||||
switch c.Output.CloudSync {
|
||||
case "", "auto", "on", "off":
|
||||
default:
|
||||
return fmt.Errorf("output.cloud_sync = %q (must be auto|on|off)", c.Output.CloudSync)
|
||||
}
|
||||
for name, spec := range c.Backends {
|
||||
if name == "" {
|
||||
return errors.New("empty backend name")
|
||||
@@ -97,30 +121,94 @@ const Sample = `# imagen.yaml — config for the imagen CLI.
|
||||
|
||||
default_backend: flux-schnell-local
|
||||
|
||||
# Owner UUID for the cloud-sync row in imagen.images. Look up via:
|
||||
# SELECT id FROM auth.users WHERE email = '<your-supabase-email>';
|
||||
# Empty disables imagen.images inserts even when cloud_sync is on.
|
||||
owner_user_id: ""
|
||||
|
||||
output:
|
||||
directory: ~/Pictures/imagen
|
||||
naming: "{date}-{slug}-{seed}.png"
|
||||
write_metadata_json: true
|
||||
# Open a tmux window with tmux-img after a successful generation.
|
||||
# auto (default): preview iff stdout is a TTY and $TMUX is set.
|
||||
# on: always preview (errors outside a tmux session).
|
||||
# off: never preview (use this for batch / CI callers).
|
||||
preview: auto
|
||||
# Sync the PNG to Supabase Storage (bucket: imagen-generated) and insert
|
||||
# a row into imagen.images. Reads SUPABASE_URL + SUPABASE_SERVICE_KEY
|
||||
# from env (same as mai.imagen_usage cost-tracking).
|
||||
# auto (default): on iff env is configured AND owner_user_id is set.
|
||||
# on: always upload (errors if env or owner_user_id is missing).
|
||||
# off: never upload. --no-cloud also forces off per-call.
|
||||
cloud_sync: auto
|
||||
|
||||
backends:
|
||||
# FLUX.1-schnell on the local ComfyUI server. The "workflow" key picks the
|
||||
# bundled template under internal/backend/workflows/; omit it for back-compat
|
||||
# (defaults to flux1-schnell). See docs/backends.md for the per-model setup.
|
||||
flux-schnell-local:
|
||||
type: comfyui
|
||||
base_url: http://mrock:8188
|
||||
workflow: flux1-schnell
|
||||
# Filename of the unet checkpoint inside the ComfyUI server's
|
||||
# models/unet/ directory. See docs/setup-comfyui-mrock.md.
|
||||
# models/unet/ directory.
|
||||
model: flux1-schnell.safetensors
|
||||
vae: ae.safetensors
|
||||
clip_l: clip_l.safetensors
|
||||
clip_t5: t5xxl_fp8_e4m3fn.safetensors
|
||||
dtype: fp8_e4m3fn
|
||||
default_steps: 4
|
||||
default_sampler: euler
|
||||
default_scheduler: simple
|
||||
default_cfg: 1.0
|
||||
|
||||
# FLUX.2 [klein] 4B distilled — sub-second on RTX 4070 Ti SUPER.
|
||||
# Weights: BFL non-commercial; flux-2-klein-base-4b-fp8 in models/unet/,
|
||||
# qwen_3_4b in models/text_encoders/, flux2-vae in models/vae/.
|
||||
flux2-klein-local:
|
||||
type: comfyui
|
||||
base_url: http://mrock:8188
|
||||
workflow: flux2-klein
|
||||
model: flux-2-klein-base-4b-fp8.safetensors
|
||||
vae: flux2-vae.safetensors
|
||||
clip: qwen_3_4b.safetensors
|
||||
dtype: fp8_e4m3fn
|
||||
default_steps: 4
|
||||
default_sampler: euler
|
||||
default_scheduler: simple
|
||||
default_cfg: 1.0
|
||||
guidance: 4.0
|
||||
|
||||
# SD3.5 medium — single-checkpoint variant that bundles the three text
|
||||
# encoders inside the .safetensors. Drop into models/checkpoints/.
|
||||
sd35-medium-local:
|
||||
type: comfyui
|
||||
base_url: http://mrock:8188
|
||||
workflow: sd35-medium
|
||||
model: sd3.5_medium_incl_clips_t5xxlfp8scaled.safetensors
|
||||
default_steps: 28
|
||||
default_sampler: dpmpp_2m
|
||||
default_scheduler: sgm_uniform
|
||||
default_cfg: 4.5
|
||||
shift: 3.0
|
||||
|
||||
mock:
|
||||
type: mock
|
||||
|
||||
flux-schnell-replicate:
|
||||
type: replicate
|
||||
api_token_env: REPLICATE_API_TOKEN
|
||||
model: black-forest-labs/flux-schnell
|
||||
default_steps: 4
|
||||
default_aspect_ratio: "1:1"
|
||||
|
||||
flux-dev-replicate:
|
||||
type: replicate
|
||||
api_token_env: REPLICATE_API_TOKEN
|
||||
model: black-forest-labs/flux-dev
|
||||
default_steps: 28
|
||||
default_aspect_ratio: "1:1"
|
||||
|
||||
dalle3:
|
||||
type: openai
|
||||
|
||||
@@ -60,6 +60,67 @@ func TestValidateRejectsMissingType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePreviewMode(t *testing.T) {
|
||||
for _, mode := range []string{"", "auto", "on", "off"} {
|
||||
c := &Config{Output: OutputConfig{Preview: mode}}
|
||||
if err := c.Validate(); err != nil {
|
||||
t.Errorf("preview=%q: unexpected error %v", mode, err)
|
||||
}
|
||||
}
|
||||
bad := &Config{Output: OutputConfig{Preview: "yes"}}
|
||||
if err := bad.Validate(); err == nil {
|
||||
t.Errorf("expected error for invalid preview value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCloudSyncMode(t *testing.T) {
|
||||
for _, mode := range []string{"", "auto", "on", "off"} {
|
||||
c := &Config{Output: OutputConfig{CloudSync: mode}}
|
||||
if err := c.Validate(); err != nil {
|
||||
t.Errorf("cloud_sync=%q: unexpected error %v", mode, err)
|
||||
}
|
||||
}
|
||||
bad := &Config{Output: OutputConfig{CloudSync: "yes"}}
|
||||
if err := bad.Validate(); err == nil {
|
||||
t.Errorf("expected error for invalid cloud_sync value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleParsesCloudSyncAuto(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "imagen.yaml")
|
||||
if err := os.WriteFile(path, []byte(Sample), 0o644); err != nil {
|
||||
t.Fatalf("write sample: %v", err)
|
||||
}
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if cfg.Output.CloudSync != "auto" {
|
||||
t.Errorf("Output.CloudSync = %q, want auto", cfg.Output.CloudSync)
|
||||
}
|
||||
// owner_user_id is intentionally empty in the sample — operators fill
|
||||
// it in after looking up their auth.users.id.
|
||||
if cfg.OwnerUserID != "" {
|
||||
t.Errorf("Sample OwnerUserID should be empty, got %q", cfg.OwnerUserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleParsesPreviewAuto(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "imagen.yaml")
|
||||
if err := os.WriteFile(path, []byte(Sample), 0o644); err != nil {
|
||||
t.Fatalf("write sample: %v", err)
|
||||
}
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if cfg.Output.Preview != "auto" {
|
||||
t.Errorf("Output.Preview = %q, want auto", cfg.Output.Preview)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandPath(t *testing.T) {
|
||||
home, _ := os.UserHomeDir()
|
||||
cases := map[string]string{
|
||||
|
||||
@@ -35,6 +35,13 @@ type Inputs struct {
|
||||
type Outputs struct {
|
||||
ImagePath string
|
||||
SidecarPath string
|
||||
// Date is the YYYY-MM-DD the writer used for the filename. Cloud sync
|
||||
// reuses this so storage_path matches the local filename's date.
|
||||
Date string
|
||||
// Slug is the filename-safe prompt fragment the writer used.
|
||||
Slug string
|
||||
// Seed is the seed value baked into the filename.
|
||||
Seed int64
|
||||
}
|
||||
|
||||
// Write streams img to disk and, if enabled, writes a sidecar. The image
|
||||
@@ -50,10 +57,12 @@ func (w *Writer) Write(img io.Reader, in Inputs) (*Outputs, error) {
|
||||
if tmpl == "" {
|
||||
tmpl = "{date}-{slug}-{seed}.{ext}"
|
||||
}
|
||||
date := now.Format("2006-01-02")
|
||||
slug := Slug(in.Prompt)
|
||||
name := renderTemplate(tmpl, map[string]string{
|
||||
"date": now.Format("2006-01-02"),
|
||||
"date": date,
|
||||
"time": now.Format("150405"),
|
||||
"slug": Slug(in.Prompt),
|
||||
"slug": slug,
|
||||
"seed": fmt.Sprintf("%d", in.Seed),
|
||||
"backend": in.Backend,
|
||||
"ext": strings.TrimPrefix(ext, "."),
|
||||
@@ -80,7 +89,7 @@ func (w *Writer) Write(img io.Reader, in Inputs) (*Outputs, error) {
|
||||
return nil, fmt.Errorf("close %s: %w", imagePath, err)
|
||||
}
|
||||
|
||||
out := &Outputs{ImagePath: imagePath}
|
||||
out := &Outputs{ImagePath: imagePath, Date: date, Slug: slug, Seed: in.Seed}
|
||||
|
||||
if w.WriteSidecar {
|
||||
sidecar := imagePath + ".json"
|
||||
@@ -122,7 +131,7 @@ func (w *Writer) WriteToPath(img io.Reader, path string, in Inputs) (*Outputs, e
|
||||
if err := f.Close(); err != nil {
|
||||
return nil, fmt.Errorf("close %s: %w", path, err)
|
||||
}
|
||||
out := &Outputs{ImagePath: path}
|
||||
out := &Outputs{ImagePath: path, Date: now.Format("2006-01-02"), Slug: Slug(in.Prompt), Seed: in.Seed}
|
||||
if w.WriteSidecar {
|
||||
sidecar := path + ".json"
|
||||
body := map[string]any{
|
||||
|
||||
119
internal/preview/tmux.go
Normal file
119
internal/preview/tmux.go
Normal file
@@ -0,0 +1,119 @@
|
||||
// Package preview opens a tmux window showing a generated image via tmux-img.
|
||||
// Mode resolution and the actual spawn are kept separate so the CLI can
|
||||
// decide-then-act and tests can drive each half independently.
|
||||
package preview
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Mode selects when a preview window is opened. It is a tri-state setting:
// auto (the default), on (force) or off.
type Mode string

// The three recognised preview modes.
const (
	ModeAuto Mode = "auto"
	ModeOn   Mode = "on"
	ModeOff  Mode = "off"
)

// ParseMode converts s into a Mode, ignoring letter case and surrounding
// whitespace. The empty string maps to ModeAuto so that unset config or
// environment values can be passed straight through. Any other unrecognised
// value yields an error that quotes the input as given.
func ParseMode(s string) (Mode, error) {
	normalised := strings.ToLower(strings.TrimSpace(s))
	if normalised == "" {
		return ModeAuto, nil
	}
	for _, m := range []Mode{ModeAuto, ModeOn, ModeOff} {
		if normalised == string(m) {
			return m, nil
		}
	}
	return "", fmt.Errorf("invalid preview mode %q (auto|on|off)", s)
}
|
||||
|
||||
// Decision is the answer to "should we preview, and why".
|
||||
type Decision struct {
|
||||
ShouldPreview bool
|
||||
Reason string
|
||||
}
|
||||
|
||||
// Resolve maps (mode, runtime context) to a Decision.
|
||||
//
|
||||
// - off -> never preview
|
||||
// - on -> preview, but error if not in tmux (forced on outside tmux)
|
||||
// - auto -> preview iff inTmux && stdoutTTY
|
||||
func Resolve(mode Mode, inTmux, stdoutTTY bool) (Decision, error) {
|
||||
switch mode {
|
||||
case ModeOff:
|
||||
return Decision{ShouldPreview: false, Reason: "preview=off"}, nil
|
||||
case ModeOn:
|
||||
if !inTmux {
|
||||
return Decision{}, ErrNoTmuxForced
|
||||
}
|
||||
return Decision{ShouldPreview: true, Reason: "preview=on"}, nil
|
||||
case ModeAuto, "":
|
||||
if !inTmux {
|
||||
return Decision{ShouldPreview: false, Reason: "auto: $TMUX unset"}, nil
|
||||
}
|
||||
if !stdoutTTY {
|
||||
return Decision{ShouldPreview: false, Reason: "auto: stdout not a tty"}, nil
|
||||
}
|
||||
return Decision{ShouldPreview: true, Reason: "auto"}, nil
|
||||
}
|
||||
return Decision{}, fmt.Errorf("invalid preview mode %q", mode)
|
||||
}
|
||||
|
||||
// ErrTmuxMissing is returned by Spawn when the tmux binary cannot be
// located.
var ErrTmuxMissing = errors.New("tmux: binary not found on $PATH (required for image preview)")

// ErrTmuxImgMissing is returned by Spawn when the tmux-img binary cannot be
// located; the message includes the expected install location.
var ErrTmuxImgMissing = errors.New("tmux-img: binary not found on $PATH (install at ~/.local/bin/tmux-img)")

// ErrNoTmuxForced is returned by Resolve when preview is forced on but the
// process is not running inside a tmux session.
var ErrNoTmuxForced = errors.New("--preview requires $TMUX (are you in a tmux session?)")
|
||||
|
||||
// Spawner spawns the tmux preview window. The exec.LookPath / cmd.Run hooks
|
||||
// exist so tests can inject fakes without touching $PATH.
|
||||
type Spawner struct {
|
||||
LookPath func(string) (string, error)
|
||||
Run func(*exec.Cmd) error
|
||||
}
|
||||
|
||||
// Spawn opens a new tmux window named img:<slug> running tmux-img --hold
|
||||
// <imagePath>. -d keeps focus in the current pane. Caller is expected to
|
||||
// have already verified that we are inside a tmux session.
|
||||
func (s *Spawner) Spawn(imagePath, slug string) error {
|
||||
look := s.LookPath
|
||||
if look == nil {
|
||||
look = exec.LookPath
|
||||
}
|
||||
run := s.Run
|
||||
if run == nil {
|
||||
run = func(c *exec.Cmd) error { return c.Run() }
|
||||
}
|
||||
|
||||
tmuxBin, err := look("tmux")
|
||||
if err != nil {
|
||||
return ErrTmuxMissing
|
||||
}
|
||||
tmuxImgBin, err := look("tmux-img")
|
||||
if err != nil {
|
||||
return ErrTmuxImgMissing
|
||||
}
|
||||
|
||||
name := "img:" + slug
|
||||
shellCmd := fmt.Sprintf("%s --hold %s",
|
||||
shellQuote(tmuxImgBin), shellQuote(imagePath))
|
||||
cmd := exec.Command(tmuxBin, "new-window", "-d", "-n", name, shellCmd)
|
||||
if err := run(cmd); err != nil {
|
||||
return fmt.Errorf("tmux new-window: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shellQuote wraps s in single quotes for /bin/sh, because tmux hands the
// trailing new-window argument to a shell. Each embedded apostrophe is
// written as '\'' (close the quote, emit an escaped apostrophe, reopen).
func shellQuote(s string) string {
	pieces := strings.Split(s, "'")
	return "'" + strings.Join(pieces, `'\''`) + "'"
}
|
||||
170
internal/preview/tmux_test.go
Normal file
170
internal/preview/tmux_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package preview
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseMode(t *testing.T) {
|
||||
cases := map[string]Mode{
|
||||
"": ModeAuto,
|
||||
"auto": ModeAuto,
|
||||
"AUTO": ModeAuto,
|
||||
"on": ModeOn,
|
||||
" on ": ModeOn,
|
||||
"off": ModeOff,
|
||||
}
|
||||
for in, want := range cases {
|
||||
got, err := ParseMode(in)
|
||||
if err != nil {
|
||||
t.Errorf("ParseMode(%q) err = %v", in, err)
|
||||
continue
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("ParseMode(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
if _, err := ParseMode("nope"); err == nil {
|
||||
t.Errorf("ParseMode(nope) should have errored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve(t *testing.T) {
|
||||
type tc struct {
|
||||
mode Mode
|
||||
inTmux bool
|
||||
stdoutTTY bool
|
||||
want bool
|
||||
wantErr error
|
||||
}
|
||||
cases := map[string]tc{
|
||||
"off-anywhere": {ModeOff, false, false, false, nil},
|
||||
"off-in-tmux-tty": {ModeOff, true, true, false, nil},
|
||||
"on-in-tmux": {ModeOn, true, false, true, nil},
|
||||
"on-outside-tmux-errs": {ModeOn, false, true, false, ErrNoTmuxForced},
|
||||
"auto-no-tmux": {ModeAuto, false, true, false, nil},
|
||||
"auto-tmux-no-tty": {ModeAuto, true, false, false, nil},
|
||||
"auto-tmux-and-tty": {ModeAuto, true, true, true, nil},
|
||||
}
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d, err := Resolve(c.mode, c.inTmux, c.stdoutTTY)
|
||||
if c.wantErr != nil {
|
||||
if !errors.Is(err, c.wantErr) {
|
||||
t.Fatalf("err = %v, want %v", err, c.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v", err)
|
||||
}
|
||||
if d.ShouldPreview != c.want {
|
||||
t.Errorf("ShouldPreview = %v, want %v (reason: %s)", d.ShouldPreview, c.want, d.Reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawn_BuildsCorrectCommand(t *testing.T) {
|
||||
var captured *exec.Cmd
|
||||
s := &Spawner{
|
||||
LookPath: func(name string) (string, error) {
|
||||
switch name {
|
||||
case "tmux":
|
||||
return "/usr/bin/tmux", nil
|
||||
case "tmux-img":
|
||||
return "/home/m/.local/bin/tmux-img", nil
|
||||
}
|
||||
return "", exec.ErrNotFound
|
||||
},
|
||||
Run: func(c *exec.Cmd) error {
|
||||
captured = c
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if err := s.Spawn("/tmp/imagen/cat.png", "cat-in-a-fishbowl"); err != nil {
|
||||
t.Fatalf("Spawn: %v", err)
|
||||
}
|
||||
if captured == nil {
|
||||
t.Fatal("Run was not called")
|
||||
}
|
||||
if captured.Path != "/usr/bin/tmux" {
|
||||
t.Errorf("Path = %q, want /usr/bin/tmux", captured.Path)
|
||||
}
|
||||
args := captured.Args
|
||||
if len(args) < 6 {
|
||||
t.Fatalf("args = %v (need at least 6)", args)
|
||||
}
|
||||
// tmux new-window -d -n img:<slug> '<shell-cmd>'
|
||||
if args[1] != "new-window" {
|
||||
t.Errorf("args[1] = %q, want new-window", args[1])
|
||||
}
|
||||
if args[2] != "-d" {
|
||||
t.Errorf("args[2] = %q, want -d", args[2])
|
||||
}
|
||||
if args[3] != "-n" {
|
||||
t.Errorf("args[3] = %q, want -n", args[3])
|
||||
}
|
||||
if args[4] != "img:cat-in-a-fishbowl" {
|
||||
t.Errorf("args[4] = %q, want img:cat-in-a-fishbowl", args[4])
|
||||
}
|
||||
shellCmd := args[5]
|
||||
if !strings.Contains(shellCmd, "tmux-img") || !strings.Contains(shellCmd, "--hold") || !strings.Contains(shellCmd, "/tmp/imagen/cat.png") {
|
||||
t.Errorf("shell cmd %q missing expected pieces", shellCmd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawn_PathWithSpacesAndQuotes(t *testing.T) {
|
||||
var captured *exec.Cmd
|
||||
s := &Spawner{
|
||||
LookPath: func(name string) (string, error) {
|
||||
if name == "tmux" {
|
||||
return "/usr/bin/tmux", nil
|
||||
}
|
||||
if name == "tmux-img" {
|
||||
return "/usr/local/bin/tmux-img", nil
|
||||
}
|
||||
return "", exec.ErrNotFound
|
||||
},
|
||||
Run: func(c *exec.Cmd) error { captured = c; return nil },
|
||||
}
|
||||
weird := "/tmp/imagen/o'malley's cat.png"
|
||||
if err := s.Spawn(weird, "slug"); err != nil {
|
||||
t.Fatalf("Spawn: %v", err)
|
||||
}
|
||||
shellCmd := captured.Args[5]
|
||||
// Single-quoted with the embedded apostrophe escaped via the
|
||||
// '\'' shell idiom — confirm we did not just splice the raw path.
|
||||
if strings.Contains(shellCmd, "o'malley's") {
|
||||
t.Errorf("shell cmd %q contains unescaped apostrophes", shellCmd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawn_MissingTmux(t *testing.T) {
|
||||
s := &Spawner{
|
||||
LookPath: func(string) (string, error) { return "", exec.ErrNotFound },
|
||||
Run: func(*exec.Cmd) error { return nil },
|
||||
}
|
||||
err := s.Spawn("/x.png", "s")
|
||||
if !errors.Is(err, ErrTmuxMissing) {
|
||||
t.Errorf("err = %v, want ErrTmuxMissing", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawn_MissingTmuxImg(t *testing.T) {
|
||||
s := &Spawner{
|
||||
LookPath: func(name string) (string, error) {
|
||||
if name == "tmux" {
|
||||
return "/usr/bin/tmux", nil
|
||||
}
|
||||
return "", exec.ErrNotFound
|
||||
},
|
||||
Run: func(*exec.Cmd) error { return nil },
|
||||
}
|
||||
err := s.Spawn("/x.png", "s")
|
||||
if !errors.Is(err, ErrTmuxImgMissing) {
|
||||
t.Errorf("err = %v, want ErrTmuxImgMissing", err)
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func TestApplyToEmptyPromptUsesPresetOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStylesContainsAllExpected(t *testing.T) {
|
||||
want := []string{"blog-header", "diagram", "illustration", "photo", "sketch"}
|
||||
want := []string{"3d-render", "anime", "blog-header", "cinematic", "diagram", "illustration", "isometric", "line-art", "photo", "sketch", "watercolor"}
|
||||
got := Styles()
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("Styles() = %v, want %v", got, want)
|
||||
|
||||
@@ -4,3 +4,9 @@ styles:
|
||||
diagram: "minimal technical diagram, isometric, white background, line-art"
|
||||
sketch: "rough pencil sketch, hand-drawn, monochrome"
|
||||
blog-header: "wide aspect, conceptual, soft palette, editorial illustration"
|
||||
cinematic: "cinematic still, 35mm film, shallow depth of field, dramatic lighting, color graded"
|
||||
watercolor: "watercolor painting, soft washes, paper texture, loose brushwork"
|
||||
anime: "anime illustration, cel-shaded, expressive linework, vibrant flat colors"
|
||||
3d-render: "3d render, octane, soft global illumination, subtle ambient occlusion, physically based materials"
|
||||
line-art: "clean line art, black ink on white, no shading, even stroke weight"
|
||||
isometric: "isometric illustration, 30 degree projection, flat colors, crisp geometric shapes"
|
||||
|
||||
160
internal/usage/usage.go
Normal file
160
internal/usage/usage.go
Normal file
@@ -0,0 +1,160 @@
|
||||
// Package usage records per-call cost-tracking rows for the imagen CLI
|
||||
// to mai.imagen_usage on Supabase. The writer is best-effort by design —
|
||||
// the calling adapter logs failures and proceeds, because the image
|
||||
// itself has already landed on disk by the time we record.
|
||||
package usage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/backend"
|
||||
)
|
||||
|
||||
// supabaseSchema names the PostgREST schema that holds mai.imagen_usage.
const supabaseSchema = "mai"

// SupabaseSink writes usage rows through PostgREST. Requests carry
// Accept-Profile/Content-Profile headers so they target the mai schema
// rather than public.
type SupabaseSink struct {
	URL    string // SUPABASE_URL — e.g. https://msup.msbls.de
	APIKey string // SUPABASE_SERVICE_KEY
	HTTP   *http.Client
}

// NewSupabaseSinkFromEnv builds a sink from the environment. The base URL
// comes from SUPABASE_URL with trailing slashes removed; the key comes from
// SUPABASE_SERVICE_KEY, or from MAI_SUPABASE_KEY when the former is empty.
// It returns (nil, false) when either value is missing, which the CLI uses
// to skip cost-tracking gracefully. The returned sink has an HTTP client
// with a 10 second timeout.
func NewSupabaseSinkFromEnv() (*SupabaseSink, bool) {
	baseURL := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")

	apiKey := os.Getenv("SUPABASE_SERVICE_KEY")
	if apiKey == "" {
		apiKey = os.Getenv("MAI_SUPABASE_KEY")
	}

	if baseURL == "" || apiKey == "" {
		return nil, false
	}

	sink := &SupabaseSink{
		URL:    baseURL,
		APIKey: apiKey,
		HTTP:   &http.Client{Timeout: 10 * time.Second},
	}
	return sink, true
}
|
||||
|
||||
type supabaseRow struct {
|
||||
Backend string `json:"backend"`
|
||||
Model string `json:"model"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
PromptHash string `json:"prompt_hash"`
|
||||
LatencyMs int `json:"latency_ms"`
|
||||
CostUSDEstimate *float64 `json:"cost_usd_estimate,omitempty"`
|
||||
Caller string `json:"caller,omitempty"`
|
||||
}
|
||||
|
||||
// Record inserts one row into mai.imagen_usage.
|
||||
func (s *SupabaseSink) Record(ctx context.Context, row backend.UsageRow) error {
|
||||
body, err := json.Marshal(supabaseRow{
|
||||
Backend: row.Backend,
|
||||
Model: row.Model,
|
||||
Seed: row.Seed,
|
||||
PromptHash: row.PromptHash,
|
||||
LatencyMs: row.LatencyMs,
|
||||
CostUSDEstimate: row.CostUSDEstimate,
|
||||
Caller: row.Caller,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("usage: marshal: %w", err)
|
||||
}
|
||||
|
||||
endpoint := s.URL + "/rest/v1/imagen_usage"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("apikey", s.APIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+s.APIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept-Profile", supabaseSchema)
|
||||
req.Header.Set("Content-Profile", supabaseSchema)
|
||||
req.Header.Set("Prefer", "return=minimal")
|
||||
|
||||
resp, err := s.HTTP.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("usage: POST: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("usage: POST %d: %s", resp.StatusCode, snip(respBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Row mirrors the columns Query selects from the imagen_usage table. It
// carries only what the CLI needs. Pointer-typed fields decode a JSON null
// as nil.
type Row struct {
	// CreatedAt is the row's created_at column.
	CreatedAt time.Time `json:"created_at"`
	// Backend and Model identify what produced the image.
	Backend string `json:"backend"`
	Model   string `json:"model"`
	// Seed is nil when the column is null.
	Seed *int64 `json:"seed"`
	// PromptHash is the prompt_hash column.
	PromptHash string `json:"prompt_hash"`
	// LatencyMs, CostUSDEstimate and Caller are nil when null.
	LatencyMs       *int     `json:"latency_ms"`
	CostUSDEstimate *float64 `json:"cost_usd_estimate"`
	Caller          *string  `json:"caller"`
}
|
||||
|
||||
// Query returns rows from mai.imagen_usage filtered by created_at >= since.
|
||||
// Pass zero time to fetch the full table (capped server-side by PostgREST
|
||||
// — we set a hard 5000-row limit here too).
|
||||
func (s *SupabaseSink) Query(ctx context.Context, since time.Time) ([]Row, error) {
|
||||
q := url.Values{}
|
||||
q.Set("select", "created_at,backend,model,seed,prompt_hash,latency_ms,cost_usd_estimate,caller")
|
||||
q.Set("order", "created_at.desc")
|
||||
q.Set("limit", "5000")
|
||||
if !since.IsZero() {
|
||||
q.Set("created_at", "gte."+since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
endpoint := s.URL + "/rest/v1/imagen_usage?" + q.Encode()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("apikey", s.APIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+s.APIKey)
|
||||
req.Header.Set("Accept-Profile", supabaseSchema)
|
||||
|
||||
resp, err := s.HTTP.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("usage: GET: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("usage: GET %d: %s", resp.StatusCode, snip(body))
|
||||
}
|
||||
var rows []Row
|
||||
if err := json.Unmarshal(body, &rows); err != nil {
|
||||
return nil, fmt.Errorf("usage: parse rows: %w (body: %s)", err, snip(body))
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// snip renders b as a short, whitespace-trimmed string for use in error
// messages. Text longer than 500 bytes is cut and suffixed with "...".
// The cut never lands inside a multi-byte UTF-8 sequence: if byte 500 is a
// continuation byte the cut moves back (by at most three bytes) to the start
// of that rune, so valid input always yields valid output.
func snip(b []byte) string {
	const limit = 500
	s := strings.TrimSpace(string(b))
	if len(s) <= limit {
		return s
	}

	// A UTF-8 rune is at most four bytes: one lead byte followed by up to
	// three continuation bytes of the form 10xxxxxx.
	cut := limit
	for cut > limit-3 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	if s[cut]&0xC0 == 0x80 {
		// Still on a continuation byte: the input is not valid UTF-8 here,
		// so there is no rune boundary worth preserving.
		cut = limit
	}
	return s[:cut] + "..."
}
|
||||
217
internal/worker/worker.go
Normal file
217
internal/worker/worker.go
Normal file
@@ -0,0 +1,217 @@
|
||||
// Package worker consumes the imagen.jobs queue. It claims pending rows via
|
||||
// an UPDATE-returning lock (single source of truth, no double-claim window),
|
||||
// runs the supplied generation pipeline, then writes status + image_id back.
|
||||
//
|
||||
// The package is DB-agnostic: it talks to two small interfaces (Queue +
|
||||
// Pipeline) so unit tests can drive the claim/transition logic with no real
|
||||
// Postgres connection. cmd/imagen wires the pgx implementation.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Job carries the subset of an imagen.jobs row that the worker needs in
// order to drive one generation. Columns that are NULL in the database
// arrive as Go zero values, and the pipeline reads a zero value as "use the
// backend default" (the same convention backend.Request follows).
type Job struct {
	// ID and OwnerUserID identify the job row and its owner.
	ID          string
	OwnerUserID string

	// Generation parameters.
	Prompt  string
	Backend string
	Model   string
	Width   int
	Height  int
	Steps   int
	Seed    int64
	Style   string

	// SeriesID references the parent imagen.series row when the job is one
	// of N tries in a batch. It is empty for a solo run, in which case the
	// pipeline must not put a series_id on the resulting imagen.images row.
	SeriesID string
}
|
||||
|
||||
// Outcome is the per-job result a Pipeline hands back to the worker.
//
// ImageID is the imagen.images.id created by cloud-sync. A nil Err together
// with an empty ImageID means cloud-sync was skipped (config off); the
// worker counts that as a failure, because flexsiebels needs the image_id
// to render the result.
type Outcome struct {
	ImageID string
	Err     error
}
|
||||
|
||||
// Queue is the persistence layer for the imagen.jobs table. Implementations
|
||||
// must be safe for serialised single-worker use (concurrent claim across
|
||||
// multiple worker processes is out of scope for v1 — the FOR UPDATE SKIP
|
||||
// LOCKED clause in the pgx claim query covers it cheaply anyway).
|
||||
type Queue interface {
|
||||
// ClaimNextPending atomically marks the oldest pending row 'running' and
|
||||
// returns it. Returns (nil, nil) when the queue is empty.
|
||||
ClaimNextPending(ctx context.Context) (*Job, error)
|
||||
// MarkDone records success: status='done', image_id, completed_at=now().
|
||||
MarkDone(ctx context.Context, jobID, imageID string) error
|
||||
// MarkFailed records failure: status='failed', error=msg, completed_at=now().
|
||||
MarkFailed(ctx context.Context, jobID, errMsg string) error
|
||||
// WaitForJob blocks until either a NOTIFY arrives on imagen_jobs, the
|
||||
// timeout expires, or ctx is cancelled. Returns nil on notification or
|
||||
// timeout; returns ctx.Err() on cancellation. Transient connection errors
|
||||
// are returned so the caller can decide to reconnect.
|
||||
WaitForJob(ctx context.Context, timeout time.Duration) error
|
||||
// ResetStaleRunning marks any rows stuck in 'running' (e.g. left over
|
||||
// from a crash before this process started) back to 'pending'. Called
|
||||
// once at worker startup so the cold-start safety poll can pick them up.
|
||||
ResetStaleRunning(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Pipeline runs one generation and reports back the imagen.images.id (or an
|
||||
// error). The implementation owns backend dispatch, prompt enrichment, disk
|
||||
// write, and cloud-sync; the worker only orchestrates queue state.
|
||||
type Pipeline interface {
|
||||
Run(ctx context.Context, job Job) Outcome
|
||||
}
|
||||
|
||||
// Config bundles the runtime settings of the worker loop. New replaces
// non-positive durations with defaults.
type Config struct {
	// PollInterval is how long the worker waits between LISTEN wake-ups
	// before polling anyway. Too low wastes DB round-trips; too high lets a
	// dropped NOTIFY stall the queue. The default is 5s.
	PollInterval time.Duration

	// JobTimeout bounds a single Pipeline.Run so that a hanging backend
	// cannot freeze the queue forever. The default is 5 minutes.
	JobTimeout time.Duration

	// Logger receives one-line status events; nil keeps the worker silent.
	Logger func(format string, args ...any)
}
|
||||
|
||||
// Worker is the orchestration loop. It is not reusable across Run calls.
|
||||
type Worker struct {
|
||||
q Queue
|
||||
p Pipeline
|
||||
cfg Config
|
||||
|
||||
// processingMu guards the in-flight job so SIGTERM-triggered shutdown
|
||||
// waits for it to complete before returning.
|
||||
processingMu sync.Mutex
|
||||
}
|
||||
|
||||
// New constructs a Worker.
|
||||
func New(q Queue, p Pipeline, cfg Config) *Worker {
|
||||
if cfg.PollInterval <= 0 {
|
||||
cfg.PollInterval = 5 * time.Second
|
||||
}
|
||||
if cfg.JobTimeout <= 0 {
|
||||
cfg.JobTimeout = 5 * time.Minute
|
||||
}
|
||||
return &Worker{q: q, p: p, cfg: cfg}
|
||||
}
|
||||
|
||||
// Run drives the consume loop until ctx is cancelled or a fatal queue error
|
||||
// (e.g. unrecoverable DB drop) is returned. A LISTEN wait can fail with a
|
||||
// transient transport error; the worker logs and continues so a temporary
|
||||
// network blip doesn't take it down.
|
||||
func (w *Worker) Run(ctx context.Context) error {
|
||||
if err := w.q.ResetStaleRunning(ctx); err != nil {
|
||||
w.log("worker: reset stale running rows: %v", err)
|
||||
// Don't return — a stale row will eventually be visible to the poll
|
||||
// path once flexsiebels gives up and resubmits, and we'd rather keep
|
||||
// serving fresh jobs than crash here.
|
||||
}
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil
|
||||
}
|
||||
// Drain the queue: claim and process until empty.
|
||||
if err := w.drain(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
w.log("worker: drain: %v", err)
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil
|
||||
}
|
||||
// Wait for the next wake. WaitForJob covers both LISTEN and the
|
||||
// timeout-based poll fallback; either returns nil and we loop.
|
||||
if err := w.q.WaitForJob(ctx, w.cfg.PollInterval); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
w.log("worker: wait: %v (continuing)", err)
|
||||
// Pace the retries so a totally-broken DB doesn't busy-spin.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(w.cfg.PollInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// drain claims and processes every currently-pending job. The job-scoped
|
||||
// context is derived from context.Background() so that a SIGTERM mid-job
|
||||
// still lets the pipeline finish — that's the "no half-state on shutdown"
|
||||
// guarantee the issue calls for.
|
||||
func (w *Worker) drain(ctx context.Context) error {
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
job, err := w.q.ClaimNextPending(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("claim: %w", err)
|
||||
}
|
||||
if job == nil {
|
||||
return nil
|
||||
}
|
||||
w.processOne(*job)
|
||||
}
|
||||
}
|
||||
|
||||
// processOne runs the pipeline for one already-claimed job and writes the
|
||||
// outcome back to the queue. The job context is independent of the outer
|
||||
// ctx so an in-flight job can finish even after SIGTERM.
|
||||
func (w *Worker) processOne(job Job) {
|
||||
w.processingMu.Lock()
|
||||
defer w.processingMu.Unlock()
|
||||
|
||||
w.log("worker: processing job %s backend=%s", job.ID, job.Backend)
|
||||
jobCtx, cancel := context.WithTimeout(context.Background(), w.cfg.JobTimeout)
|
||||
defer cancel()
|
||||
out := w.p.Run(jobCtx, job)
|
||||
|
||||
// Status-update uses Background ctx with a short timeout — we must
|
||||
// always be able to record the outcome, otherwise the row sits in
|
||||
// 'running' forever.
|
||||
updCtx, updCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer updCancel()
|
||||
if out.Err != nil {
|
||||
w.log("worker: job %s failed: %v", job.ID, out.Err)
|
||||
if err := w.q.MarkFailed(updCtx, job.ID, out.Err.Error()); err != nil {
|
||||
w.log("worker: mark failed for %s: %v", job.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if out.ImageID == "" {
|
||||
// Pipeline reported success but no imagen.images row — treat as
|
||||
// failure because flexsiebels has nothing to link.
|
||||
const msg = "pipeline did not return an imagen.images id (cloud sync misconfigured?)"
|
||||
w.log("worker: job %s: %s", job.ID, msg)
|
||||
if err := w.q.MarkFailed(updCtx, job.ID, msg); err != nil {
|
||||
w.log("worker: mark failed for %s: %v", job.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := w.q.MarkDone(updCtx, job.ID, out.ImageID); err != nil {
|
||||
w.log("worker: mark done for %s: %v", job.ID, err)
|
||||
return
|
||||
}
|
||||
w.log("worker: job %s done image_id=%s", job.ID, out.ImageID)
|
||||
}
|
||||
|
||||
func (w *Worker) log(format string, args ...any) {
|
||||
if w.cfg.Logger != nil {
|
||||
w.cfg.Logger(format, args...)
|
||||
}
|
||||
}
|
||||
376
internal/worker/worker_test.go
Normal file
376
internal/worker/worker_test.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeQueue is a hand-rolled in-memory queue that mirrors the contract of a
|
||||
// real Postgres-backed implementation: ClaimNextPending atomically takes one
|
||||
// pending row and flips its status to "running", MarkDone/MarkFailed are
|
||||
// idempotent terminal transitions, WaitForJob blocks until notified or until
|
||||
// the timeout elapses.
|
||||
type fakeQueue struct {
|
||||
mu sync.Mutex
|
||||
pending []Job
|
||||
state map[string]string // jobID -> status
|
||||
last map[string]string // jobID -> error msg or image_id
|
||||
notify chan struct{}
|
||||
|
||||
claimErr error
|
||||
doneErr error
|
||||
failErr error
|
||||
resetErr error
|
||||
|
||||
claimed int
|
||||
done int
|
||||
failed int
|
||||
resets int
|
||||
}
|
||||
|
||||
func newFakeQueue(jobs ...Job) *fakeQueue {
|
||||
q := &fakeQueue{
|
||||
state: make(map[string]string),
|
||||
last: make(map[string]string),
|
||||
notify: make(chan struct{}, 16),
|
||||
}
|
||||
for _, j := range jobs {
|
||||
q.pending = append(q.pending, j)
|
||||
q.state[j.ID] = "pending"
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *fakeQueue) ClaimNextPending(ctx context.Context) (*Job, error) {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
if q.claimErr != nil {
|
||||
return nil, q.claimErr
|
||||
}
|
||||
if len(q.pending) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
j := q.pending[0]
|
||||
q.pending = q.pending[1:]
|
||||
q.state[j.ID] = "running"
|
||||
q.claimed++
|
||||
return &j, nil
|
||||
}
|
||||
|
||||
func (q *fakeQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
if q.doneErr != nil {
|
||||
return q.doneErr
|
||||
}
|
||||
q.state[jobID] = "done"
|
||||
q.last[jobID] = imageID
|
||||
q.done++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *fakeQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
if q.failErr != nil {
|
||||
return q.failErr
|
||||
}
|
||||
q.state[jobID] = "failed"
|
||||
q.last[jobID] = msg
|
||||
q.failed++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *fakeQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-q.notify:
|
||||
return nil
|
||||
case <-time.After(timeout):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (q *fakeQueue) ResetStaleRunning(ctx context.Context) error {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
q.resets++
|
||||
return q.resetErr
|
||||
}
|
||||
|
||||
// pingNotify simulates an INSERT-trigger NOTIFY by waking WaitForJob.
|
||||
func (q *fakeQueue) pingNotify() {
|
||||
select {
|
||||
case q.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// stub pipeline.
|
||||
type fakePipeline struct {
|
||||
mu sync.Mutex
|
||||
results map[string]Outcome // by job.ID; "" key = default outcome
|
||||
calls int
|
||||
delay time.Duration
|
||||
lastJob Job
|
||||
}
|
||||
|
||||
func (p *fakePipeline) Run(ctx context.Context, job Job) Outcome {
|
||||
p.mu.Lock()
|
||||
p.calls++
|
||||
p.lastJob = job
|
||||
delay := p.delay
|
||||
out, ok := p.results[job.ID]
|
||||
if !ok {
|
||||
out = p.results[""]
|
||||
}
|
||||
p.mu.Unlock()
|
||||
if delay > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return Outcome{Err: ctx.Err()}
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestWorker_DonePath(t *testing.T) {
|
||||
q := newFakeQueue(
|
||||
Job{ID: "j1", Prompt: "a", Backend: "mock"},
|
||||
)
|
||||
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img-1"}}}
|
||||
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
if err := w.Run(ctx); err != nil {
|
||||
t.Fatalf("Run: %v", err)
|
||||
}
|
||||
if got := q.state["j1"]; got != "done" {
|
||||
t.Fatalf("state=%q want done", got)
|
||||
}
|
||||
if got := q.last["j1"]; got != "img-1" {
|
||||
t.Fatalf("image_id=%q want img-1", got)
|
||||
}
|
||||
if q.done != 1 || q.failed != 0 {
|
||||
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
|
||||
}
|
||||
if p.calls != 1 {
|
||||
t.Fatalf("pipeline calls=%d want 1", p.calls)
|
||||
}
|
||||
if q.resets != 1 {
|
||||
t.Fatalf("ResetStaleRunning calls=%d want 1", q.resets)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_FailedPath_RecordsErrorText(t *testing.T) {
|
||||
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
|
||||
p := &fakePipeline{results: map[string]Outcome{"j1": {Err: errors.New("backend unreachable")}}}
|
||||
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
||||
_ = w.Run(ctx)
|
||||
|
||||
if got := q.state["j1"]; got != "failed" {
|
||||
t.Fatalf("state=%q want failed", got)
|
||||
}
|
||||
if got := q.last["j1"]; got != "backend unreachable" {
|
||||
t.Fatalf("error=%q want %q", got, "backend unreachable")
|
||||
}
|
||||
if q.done != 0 || q.failed != 1 {
|
||||
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_MissingImageID_TreatedAsFailure(t *testing.T) {
|
||||
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
|
||||
// Outcome has neither Err nor ImageID — pipeline silently swallowed
|
||||
// cloud-sync. flexsiebels needs the image_id; without it, fail the job.
|
||||
p := &fakePipeline{results: map[string]Outcome{"j1": {}}}
|
||||
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
||||
_ = w.Run(ctx)
|
||||
|
||||
if got := q.state["j1"]; got != "failed" {
|
||||
t.Fatalf("state=%q want failed", got)
|
||||
}
|
||||
if q.last["j1"] == "" {
|
||||
t.Fatalf("expected non-empty error explanation for missing image_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_DrainsMultipleBeforeWaiting(t *testing.T) {
|
||||
q := newFakeQueue(
|
||||
Job{ID: "j1", Backend: "mock"},
|
||||
Job{ID: "j2", Backend: "mock"},
|
||||
Job{ID: "j3", Backend: "mock"},
|
||||
)
|
||||
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
|
||||
w := New(q, p, Config{PollInterval: 200 * time.Millisecond, JobTimeout: time.Second})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() { time.Sleep(60 * time.Millisecond); cancel() }()
|
||||
_ = w.Run(ctx)
|
||||
|
||||
for _, id := range []string{"j1", "j2", "j3"} {
|
||||
if got := q.state[id]; got != "done" {
|
||||
t.Fatalf("%s state=%q want done", id, got)
|
||||
}
|
||||
}
|
||||
if q.done != 3 {
|
||||
t.Fatalf("done=%d want 3", q.done)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_NotifyWakesEarlierThanPoll(t *testing.T) {
|
||||
q := newFakeQueue()
|
||||
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
|
||||
// Set poll interval high so a working LISTEN is required to see the job
|
||||
// promptly. Without NOTIFY plumbing this test would time out the worker
|
||||
// before drain ever runs.
|
||||
w := New(q, p, Config{PollInterval: 5 * time.Second, JobTimeout: time.Second})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = w.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
// Append a job and ping the wake channel.
|
||||
q.mu.Lock()
|
||||
q.pending = append(q.pending, Job{ID: "late", Backend: "mock"})
|
||||
q.state["late"] = "pending"
|
||||
q.mu.Unlock()
|
||||
q.pingNotify()
|
||||
|
||||
// Give the worker a beat to claim + process.
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
q.mu.Lock()
|
||||
s := q.state["late"]
|
||||
q.mu.Unlock()
|
||||
if s == "done" {
|
||||
cancel()
|
||||
<-done
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("worker did not pick up the late job within the 500ms window — NOTIFY wake-up path is broken")
|
||||
}
|
||||
|
||||
func TestWorker_HonoursContextCancellation(t *testing.T) {
|
||||
q := newFakeQueue()
|
||||
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
|
||||
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
||||
defer cancel()
|
||||
start := time.Now()
|
||||
if err := w.Run(ctx); err != nil {
|
||||
t.Fatalf("Run: %v", err)
|
||||
}
|
||||
if dur := time.Since(start); dur > 200*time.Millisecond {
|
||||
t.Fatalf("worker did not exit promptly on ctx cancel: %v", dur)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_InflightJobFinishesAfterShutdown(t *testing.T) {
|
||||
q := newFakeQueue(Job{ID: "long", Backend: "mock"})
|
||||
p := &fakePipeline{
|
||||
results: map[string]Outcome{"long": {ImageID: "img-long"}},
|
||||
delay: 120 * time.Millisecond,
|
||||
}
|
||||
// Short JobTimeout would also kill the in-flight job; give it enough
|
||||
// budget so the test exercises the shutdown-during-job path.
|
||||
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: 5 * time.Second})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
// Let the job start, then cancel mid-flight.
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
_ = w.Run(ctx)
|
||||
if got := q.state["long"]; got != "done" {
|
||||
t.Fatalf("state=%q want done (in-flight job should finish even on shutdown)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWorker_PropagatesSeriesIDToPipeline verifies the worker hands the
|
||||
// Job's SeriesID through to the pipeline unchanged. The pipeline owns the
|
||||
// cloud-sync side of the propagation (cloud.SyncRequest.SeriesID lands on
|
||||
// imagen.images.series_id) — see cloud_test.go for that half — so the
|
||||
// worker contract is simply: don't drop or rewrite SeriesID between
|
||||
// claim and Run.
|
||||
func TestWorker_PropagatesSeriesIDToPipeline(t *testing.T) {
|
||||
const seriesID = "11111111-1111-1111-1111-111111111111"
|
||||
q := newFakeQueue(Job{
|
||||
ID: "j-series",
|
||||
Prompt: "p",
|
||||
Backend: "mock",
|
||||
SeriesID: seriesID,
|
||||
})
|
||||
p := &fakePipeline{results: map[string]Outcome{"j-series": {ImageID: "img-series"}}}
|
||||
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
||||
if err := w.Run(ctx); err != nil {
|
||||
t.Fatalf("Run: %v", err)
|
||||
}
|
||||
if got := p.lastJob.SeriesID; got != seriesID {
|
||||
t.Fatalf("pipeline saw SeriesID=%q want %q", got, seriesID)
|
||||
}
|
||||
if got := q.state["j-series"]; got != "done" {
|
||||
t.Fatalf("state=%q want done", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWorker_SoloJobLeavesSeriesIDEmpty is the negative case — a job
|
||||
// claimed with no series row keeps the field empty all the way to the
|
||||
// pipeline so cloud-sync writes NULL into imagen.images.series_id.
|
||||
func TestWorker_SoloJobLeavesSeriesIDEmpty(t *testing.T) {
|
||||
q := newFakeQueue(Job{ID: "j-solo", Prompt: "p", Backend: "mock"})
|
||||
p := &fakePipeline{results: map[string]Outcome{"j-solo": {ImageID: "img-solo"}}}
|
||||
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
|
||||
_ = w.Run(ctx)
|
||||
if got := p.lastJob.SeriesID; got != "" {
|
||||
t.Fatalf("solo job pipeline.lastJob.SeriesID=%q want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) {
|
||||
// First claim returns an error; the loop should log and try again on the
|
||||
// next wake — it must not propagate the error and exit.
|
||||
q := newFakeQueue(Job{ID: "j1", Backend: "mock"})
|
||||
q.claimErr = fmt.Errorf("transient: connection reset")
|
||||
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img"}}}
|
||||
w := New(q, p, Config{PollInterval: 20 * time.Millisecond, JobTimeout: time.Second})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Heal the claim error after a beat so the second drain succeeds.
|
||||
go func() {
|
||||
time.Sleep(40 * time.Millisecond)
|
||||
q.mu.Lock()
|
||||
q.claimErr = nil
|
||||
q.mu.Unlock()
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
if err := w.Run(ctx); err != nil {
|
||||
t.Fatalf("Run returned: %v (transient claim errors should not kill the loop)", err)
|
||||
}
|
||||
if got := q.state["j1"]; got != "done" {
|
||||
t.Fatalf("state=%q want done", got)
|
||||
}
|
||||
}
|
||||
22
scripts/imagen-worker.env.example
Normal file
22
scripts/imagen-worker.env.example
Normal file
@@ -0,0 +1,22 @@
|
||||
# Environment for the imagen-worker.service systemd unit.
|
||||
# Copy to ~/.dotfiles/.env.imagen-worker and fill in real values.
|
||||
# Never commit the populated file — it carries the Supabase service-role key.
|
||||
|
||||
# Direct Postgres DSN for LISTEN/NOTIFY + imagen.jobs UPDATE statements.
|
||||
# PostgREST cannot LISTEN, so the worker connects to Postgres directly.
|
||||
# Host + port + password come from the msupabase compose env on mlake.
|
||||
IMAGEN_WORKER_DATABASE_URL=postgres://postgres:CHANGE_ME@100.99.98.201:6789/postgres?sslmode=disable
|
||||
|
||||
# PostgREST endpoint for the imagen.images cloud-sync writer (same as
|
||||
# `imagen generate`'s cloud-sync code path).
|
||||
SUPABASE_URL=https://supa.flexsiebels.de
|
||||
SUPABASE_SERVICE_KEY=CHANGE_ME
|
||||
|
||||
# Default owner_user_id. Per-job owner from the imagen.jobs row overrides
|
||||
# this, so it's only used as a fallback when a job arrives with a NULL
|
||||
# owner_user_id — which the schema disallows. Keep it set for safety.
|
||||
IMAGEN_OWNER_USER_ID=ac6c9501-3757-4a6d-8b97-2cff4288382b
|
||||
|
||||
# Optional: REPLICATE_API_TOKEN if any imagen.jobs.backend may resolve to
|
||||
# a Replicate adapter instance.
|
||||
# REPLICATE_API_TOKEN=CHANGE_ME
|
||||
19
scripts/imagen-worker.service
Normal file
19
scripts/imagen-worker.service
Normal file
@@ -0,0 +1,19 @@
|
||||
[Unit]
|
||||
Description=ImaGen worker (consumes imagen.jobs queue)
|
||||
Documentation=https://mgit.msbls.de/m/ImaGen/issues/8
|
||||
Wants=network-online.target
|
||||
After=network-online.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart=%h/dev/ImaGen/bin/imagen worker
|
||||
WorkingDirectory=%h/dev/ImaGen
|
||||
EnvironmentFile=%h/.dotfiles/.env.imagen-worker
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
# Give the worker time to finish an in-flight generation on shutdown
|
||||
# (FLUX dev up to ~30s, plus the cloud-sync write-back).
|
||||
TimeoutStopSec=60
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
Reference in New Issue
Block a user