15 Commits

Author SHA1 Message Date
mAi
c2b6f8bf97 Merge mai/hades/styles-expansion: add 6 style presets (cross-coord with flexsiebels) 2026-05-11 21:15:34 +02:00
mAi
f8dd5e0736 mAi: add 6 style presets — cinematic, watercolor, anime, 3d-render, line-art, isometric
Cross-coordination with flexsiebels/head (paul). m wants more style options
on /imagine/new; flexsiebels has the UI side ready to bump IMAGEN_STYLES
in lib/server/imagen.ts + schemas.ts as soon as the worker accepts them.

styles.yaml: 6 new entries with FLUX-friendly prompt fragments. No code
changes — Apply() and Styles() consume the embedded YAML directly, the
"enum" is dynamic.

prompt_test.go: extend TestStylesContainsAllExpected expectation list
(alphabetical, '3' < 'a' so 3d-render leads).

Total enum: 11 (5 existing + 6 new). flexsiebels delegation message 1669.
2026-05-11 21:15:23 +02:00
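
The ordering note in the commit above ('3' < 'a', so 3d-render leads) follows from byte-wise string comparison. A minimal sketch, assuming the test sorts names with `sort.Strings`; `sortedStyles` is a hypothetical helper, and only the six presets named in this commit are listed because the five older names do not appear here:

```go
package main

import (
	"fmt"
	"sort"
)

// sortedStyles returns a sorted copy of names. sort.Strings compares
// byte-wise, so ASCII digits order before lowercase letters.
// Hypothetical helper, not part of the repository.
func sortedStyles(names []string) []string {
	out := append([]string(nil), names...)
	sort.Strings(out)
	return out
}

func main() {
	// The six presets added by this commit, in commit-message order.
	added := []string{"cinematic", "watercolor", "anime", "3d-render", "line-art", "isometric"}
	fmt.Println(sortedStyles(added))
}
```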
mAi
7caf975335 Merge mai/hermes/issue-10-multi-model: multi-model backend expansion + compare harness (#10) 2026-05-11 17:32:36 +02:00
mAi
8435817ce1 mAi: #10 - multi-model backend expansion (workflow templates + compare harness)
Path 1 architecture: one comfyui adapter, workflows as data.

- workflow_template.go: embed.FS + token substitution with type-preserving
  whole-value placeholders. ${prompt} → string, ${seed} → int64,
  ${cfg} → float64 — no JSON round-tripping. Partial matches ignored.
- comfyui.go: refactored to load workflow from embedded FS or filesystem
  path. Back-compat preserved: workflow: defaults to flux1-schnell.
- workflows/{flux1-schnell,flux2-klein,sd35-medium}.json — bundled
  templates. flux1-schnell migrated from hardcoded with identical node IDs.
- compare.go: new `imagen compare` subcommand. Sequential N-backend run
  (one GPU on mRock — parallel would OOM), per-backend PNG, sidecar JSON
  with per-model metadata + errors, composite contact sheet via Go image
  package (no ImageMagick dep).
- Sample config gains flux2-klein-local + sd35-medium-local instances.
- docs/backends.md: architecture rationale + per-model HF download paths
  + how to add a new bundled workflow + compare-harness reference.

Live smoke verified: compare mock + flux-schnell-local at 768×768 →
both PNGs written, sidecar JSON has workflow="flux1-schnell" + full
metadata, contact sheet renders. Worker contract (Request → Generate)
unchanged, so flexsiebels /imagine UI API surface preserved.

Tests: 11 existing comfyui + 6 new workflow_template + 5 new compare
tests, all green.

Adding a new model is now yaml + JSON, never Go.
2026-05-11 17:29:57 +02:00
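
The "type-preserving whole-value placeholders" described in the commit above can be sketched roughly as follows. This is not the code in workflow_template.go: it assumes the template has been decoded once into `map[string]any`, and `substitute`, the node IDs, and the input names are all made up for illustration. A string that is exactly `${name}` is swapped for the typed value; a string that merely contains a placeholder is left alone.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// substitute walks a decoded workflow and replaces every string that is
// exactly "${name}" with vars[name], keeping the replacement's Go type.
// Strings that only contain a placeholder (partial matches) are untouched.
func substitute(node any, vars map[string]any) any {
	switch v := node.(type) {
	case map[string]any:
		out := make(map[string]any, len(v))
		for k, child := range v {
			out[k] = substitute(child, vars)
		}
		return out
	case []any:
		out := make([]any, len(v))
		for i, child := range v {
			out[i] = substitute(child, vars)
		}
		return out
	case string:
		if strings.HasPrefix(v, "${") && strings.HasSuffix(v, "}") {
			if val, ok := vars[v[2:len(v)-1]]; ok {
				return val
			}
		}
		return v
	default:
		return v
	}
}

func main() {
	// Hypothetical two-node workflow; IDs and input names are invented.
	raw := `{"3": {"inputs": {"seed": "${seed}", "cfg": "${cfg}"}},
	         "6": {"inputs": {"text": "${prompt}", "note": "seed=${seed}"}}}`
	var tmpl map[string]any
	if err := json.Unmarshal([]byte(raw), &tmpl); err != nil {
		panic(err)
	}
	vars := map[string]any{"prompt": "a cat", "seed": int64(42), "cfg": 3.5}
	out, err := json.Marshal(substitute(tmpl, vars))
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
```

In the printed result the seed and cfg values are JSON numbers rather than quoted strings, while the `note` field keeps its literal `seed=${seed}` text.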
mAi
623dd290c5 Merge mai/hermes/issue-9-imagen-9-imagen: imagen.series + series_id propagation (#9) 2026-05-11 10:50:54 +02:00
mAi
64120c27d7 mAi: #9 - imagen.series (batch tries 1-10 + selection)
Schema (applied via migration imagen_series_init):
- imagen.series parent table (prompt + params + count CHECK 1..10 + selected_image_id)
- imagen.jobs += series_id (FK) + series_idx
- imagen.images += series_id (FK)
- Owner-scoped RLS on series (SELECT/INSERT/UPDATE) + grants
- Partial indexes WHERE series_id IS NOT NULL on both child tables

Worker pipeline:
- worker.Job += SeriesID, populated from imagen.jobs.series_id via the
  claim query.
- cloud.SyncRequest += SeriesID; insertRow writes series_id when non-empty,
  omits the key when empty so solo runs leave the column NULL.
- maybeCloudSync threads seriesID from job.SeriesID through to the cloud
  sink. generate.go (CLI) always passes "" — solo path unchanged.

Tests:
- worker: SeriesID propagates from Job to fakePipeline.lastJob unchanged,
  solo job keeps it empty.
- cloud: SyncRequest.SeriesID lands as row.series_id in the POST body;
  empty SeriesID omits the key entirely.

Refs ImaGen#9.
2026-05-11 10:48:12 +02:00
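
The "omit the key when empty" behaviour of insertRow described above can be illustrated with a small sketch. `buildRow` is a hypothetical stand-in, and the payload is reduced to two columns; the point is only that a solo run never sends `series_id`, so the column stays NULL.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// buildRow assembles a reduced insert payload. series_id is set only for
// series runs; leaving the key out lets the column default to NULL.
// Hypothetical helper, not the repository's insertRow.
func buildRow(storagePath, seriesID string) map[string]any {
	row := map[string]any{"storage_path": storagePath}
	if seriesID != "" {
		row["series_id"] = seriesID
	}
	return row
}

func main() {
	// First iteration is a solo run, second a series run with a made-up ID.
	for _, id := range []string{"", "hypothetical-series-id"} {
		body, err := json.Marshal(buildRow("2026-05-11/cat.png", id))
		if err != nil {
			panic(err)
		}
		fmt.Println(string(body))
	}
}
```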
mAi
dbe1704f42 Merge mai/hermes/issue-8-imagen-8-imagen: jobs queue + worker subcommand (#8) 2026-05-11 10:24:34 +02:00
mAi
2758c5a500 mAi: #8 - imagen.jobs queue + worker subcommand (flexsiebels write path)
Async write path for the flexsiebels owner-mode UI: flexsiebels INSERTs into
imagen.jobs, the worker on mRiver claims pending rows via LISTEN/NOTIFY +
5s safety poll, runs the same generate pipeline imagen generate uses, and
writes the result through internal/cloud into imagen.images.

- Schema migration imagen_jobs_init: table + status CHECK + two indexes +
  owner-scoped RLS + grants + AFTER INSERT trigger publishing on the
  imagen_jobs channel via pg_notify.
- internal/worker: DB-agnostic loop over a Queue interface. Drains the
  whole pending backlog on each wake. Job-scoped contexts are derived
  from Background so SIGTERM lets the in-flight generation finish (no
  half-state). ResetStaleRunning at startup unsticks rows left over from
  a previous crash. Eight unit tests cover the done / failed / missing-id /
  drain / NOTIFY-wake / shutdown / transient-error paths against a fake
  queue (no real Postgres in CI).
- cmd/imagen/worker.go: pgx-backed Queue (one dedicated conn for LISTEN +
  UPDATE), plus the workerPipeline that reuses buildBackend +
  attachUsageSink + prompt.Apply + buildWriter + maybeCloudSync. The
  per-job owner_user_id overrides the env-level fallback so each row in
  imagen.images is attributed correctly.
- maybeCloudSync now returns (*cloud.SyncResult, error) so the worker can
  link imagen.jobs.image_id to the inserted imagen.images row. The CLI
  generate path keeps printing its stderr summary unchanged.
- scripts/imagen-worker.service + .env.example for the systemd --user unit
  on mRiver. EnvironmentFile lives in ~/.dotfiles and is never committed.
- docs/setup-worker-mriver.md walks through installation + the spec's
  SQL-INSERT smoke; docs/architecture.md grows an "async write path"
  section.
- worker_integration_test.go (env-guarded by IMAGEN_WORKER_INTEGRATION=1)
  drives one real job through the full pipeline against msupabase using
  the mock backend, then verifies imagen.images + Storage object landed
  and the row flipped to done with image_id linked. Verified end-to-end:
  pickup latency ~7ms, total 74ms, failure path captures error text.
2026-05-11 10:23:33 +02:00
mAi
cb6656c436 Merge mai/hermes/issue-7-imagen-7-cloud: Supabase cloud-sync for flexsiebels viewer (#7) 2026-05-11 01:53:12 +02:00
mAi
e22f286024 mAi: #7 - cloud-sync to Supabase Storage + imagen.images
Every successful imagen generate now (a) uploads the PNG to the private
imagen-generated bucket and (b) inserts a row into imagen.images, the
data plane the flexsiebels owner-mode viewer reads from.

Schema, RLS, indexes, bucket and PostgREST exposure landed via four
applied migrations on msupabase: imagen_schema_init,
imagen_schema_grants, imagen_storage_policies, imagen_pgrst_expose
(authenticator role-level ALTER + reload). Owner UUID for m:
ac6c9501-3757-4a6d-8b97-2cff4288382b — documented in the config sample.

Code: new internal/cloud/ package mirroring the internal/usage/ shape.
PostgREST POST against the imagen schema (Accept-Profile + Content-
Profile headers), Storage upload via PUT with x-upsert, retry on 5xx /
transport but not 4xx, owner_user_id required (the column is NOT NULL
and the read-side RLS policy needs it).

Wiring in cmd/imagen/generate.go: --no-cloud flag, output.cloud_sync
config knob (auto|on|off mirroring --preview), $IMAGEN_CLOUD_SYNC env
override. The hook reads the just-written PNG + sidecar from disk and
calls cloud.Sync; failures emit "imagen: cloud sync: <err>" to stderr
without changing exit code, so a Supabase blip never loses the artefact.
output.Outputs grew Date/Slug/Seed fields so storage_path mirrors the
local filename's prefix exactly (no UTC-vs-local drift).

Config: owner_user_id field added; sample comment points at the
auth.users lookup. imagen config validate warns on stderr when
cloud_sync is on/auto but owner_user_id is empty.

Tests: cloud_test.go covers happy path, retry-on-5xx, no-retry-on-4xx,
missing-owner-uuid, missing-date-or-slug, signed URL, and the partial-
success case where the upload landed but the DB insert failed.
generate_test.go covers the precedence chain for cloud-sync mode
resolution. Build + tests clean across the tree.

Real smoke against mRock: generation through flux-schnell-local writes
the local PNG + sidecar AND uploads to imagen-generated/2026-05-11/...
AND inserts into imagen.images. Signed URL round-trips the same bytes.
--no-cloud verified to skip both Storage and DB.
2026-05-11 01:51:09 +02:00
mAi
2d5896e27d Merge mai/hermes/issue-3-imagen-3: Replicate API backend + cost-tracking + usage CLI (#3) 2026-05-08 17:32:09 +02:00
mAi
b282325663 mAi: #3 - Replicate adapter, mai.imagen_usage cost-tracking, usage CLI
Implements the Replicate API backend (FLUX schnell / FLUX dev) per ImaGen
issue #3:

- internal/backend/replicate.go — Backend adapter. Supports model
  refs as "owner/name" (uses /v1/models/{owner}/{name}/predictions) and
  "owner/name:hash" (uses /v1/predictions with explicit version). Polls
  /v1/predictions/{id} every 500ms with model-aware timeout (60s schnell,
  120s dev). Resilience: 401 names api_token_env, 429 with exp backoff
  up to 3 retries (honours Retry-After), 5xx retries once, image
  download retries once on transient failure.
- internal/backend/replicate_pricing.go — hardcoded per-image USD rates
  for known FLUX models, snapshotted from replicate.com/pricing with a
  refresh TODO.
- internal/backend/replicate_test.go — mocked-HTTP unit tests covering
  happy path (model + version-pinned), 401, 429 retry policy, failed
  prediction, poll timeout, image-download retry, ctx cancel, BackendOpts
  passthrough, default_steps, aspect-ratio reduction, sha256 prompt hash.
- internal/usage/usage.go — Supabase REST sink + read-side query for
  mai.imagen_usage. Adapter writes are best-effort: failures warn but
  the image still lands.
- cmd/imagen/usage.go — `imagen usage [--since DATE] [--raw]` reads
  the table and prints a tab-aligned grouped or raw table with totals.
- cmd/imagen/backends.go — instances of type=replicate now report
  "ok" or "not configured (set REPLICATE_API_TOKEN)" depending on env.
- internal/config/config.go — sample adds flux-schnell-replicate +
  flux-dev-replicate; default_backend stays flux-schnell-local.
- Supabase migration mai.imagen_usage (id, created_at, backend, model,
  seed, prompt_hash, latency_ms, cost_usd_estimate, caller) + indexes
  on (created_at DESC) and (caller). The raw prompt is never stored.

Caller identity resolves from MAI_FROM_ID, then the tmux pane's
@mai-name option, mirroring the maimcp identity logic. Prompt hash is
sha256 of the user-facing prompt; raw prompt never reaches the table.
2026-05-08 17:28:29 +02:00
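
The two model-ref shapes named in the commit above select different endpoints. A minimal sketch of that mapping, assuming the ref is split on the first `:` and the first `/`; `predictionTarget` is a hypothetical name and the example refs are invented:

```go
package main

import (
	"fmt"
	"strings"
)

// predictionTarget maps a model ref to the request path and, for pinned
// refs, the version hash that would go into the request body.
// Hypothetical helper, not the adapter's actual code.
func predictionTarget(ref string) (string, string, error) {
	model, version, pinned := strings.Cut(ref, ":")
	owner, name, ok := strings.Cut(model, "/")
	if !ok || owner == "" || name == "" || (pinned && version == "") {
		return "", "", fmt.Errorf("bad model ref %q", ref)
	}
	if pinned {
		// "owner/name:hash" posts to the generic endpoint with a version.
		return "/v1/predictions", version, nil
	}
	// "owner/name" posts to the model-scoped endpoint.
	return "/v1/models/" + owner + "/" + name + "/predictions", "", nil
}

func main() {
	// Made-up refs, used only to show both branches and the error case.
	for _, ref := range []string{"acme/flux-demo", "acme/flux-demo:abc123", "noslash"} {
		path, version, err := predictionTarget(ref)
		fmt.Printf("%-24s path=%q version=%q err=%v\n", ref, path, version, err)
	}
}
```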
mAi
a1d0165445 Merge mai/hermes/issue-5-imagen-5-tmux: tmux-window preview for generate (#5) 2026-05-08 17:12:57 +02:00
mAi
2a8bd4313b mAi: #5 - tmux-window preview for generate
Adds an optional `imagen generate` post-step that opens a sibling
tmux window running tmux-img --hold <path>.

- internal/preview: Mode (auto|on|off), Resolve, and a Spawner that
  shells out to tmux new-window. Typed errors for missing tmux,
  missing tmux-img, and "preview forced on outside $TMUX".
- cmd/imagen/generate: --preview / --no-preview flags plus
  $IMAGEN_PREVIEW. Resolution chain: config -> env -> flag.
  auto requires both stdout-is-tty and $TMUX. Failures are
  warnings - the image is already on disk.
- internal/config: output.preview field, validated to auto|on|off,
  threaded into the sample.
- Tests for ParseMode, Resolve, Spawn argv (incl. shell quoting of
  paths with apostrophes), missing-binary errors, and the CLI
  resolution table.
- Docs (usage + architecture) updated.

/imagine SKILL.md edit lives in dotfiles - deferred to coordinate
with #4.
2026-05-08 17:09:59 +02:00
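
The preview resolution described above can be sketched as two small functions. This reads "config -> env -> flag" as "each later source overrides the earlier one when set", which is an assumption; `resolveMode` and `shouldPreview` are hypothetical names and do not reproduce internal/preview.

```go
package main

import "fmt"

// resolveMode applies the chain config -> env -> flag: a later source wins
// when it is non-empty. With nothing set, the mode is "auto".
func resolveMode(cfgVal, envVal, flagVal string) string {
	mode := "auto"
	for _, v := range []string{cfgVal, envVal, flagVal} {
		if v != "" {
			mode = v
		}
	}
	return mode
}

// shouldPreview decides whether to spawn the tmux window. "auto" needs both
// a TTY on stdout and a tmux session; "on" outside tmux is an error.
func shouldPreview(mode string, stdoutIsTTY, inTmux bool) (bool, error) {
	switch mode {
	case "off":
		return false, nil
	case "on":
		if !inTmux {
			return false, fmt.Errorf("preview forced on outside $TMUX")
		}
		return true, nil
	case "auto":
		return stdoutIsTTY && inTmux, nil
	}
	return false, fmt.Errorf("unknown preview mode %q", mode)
}

func main() {
	mode := resolveMode("off", "", "on") // flag overrides config
	show, err := shouldPreview(mode, true, true)
	fmt.Println(mode, show, err)

	show, err = shouldPreview(resolveMode("", "", ""), true, false) // auto, no tmux
	fmt.Println(show, err)
}
```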
mAi
4183d4c55a Merge mai/hermes/issue-2-imagen-2-comfyui: ComfyUI/FLUX schnell on mRock + Go adapter (#2) 2026-05-08 17:01:02 +02:00
41 changed files with 6141 additions and 147 deletions


@@ -10,13 +10,17 @@ and lifecycle of its own block in `~/.config/imagen.yaml`.
## Architecture
```
-cmd/imagen/ CLI shell — generate, backends, config, serve
+cmd/imagen/ CLI shell — generate, worker, backends, config, serve
internal/backend/ Backend interface + Registry + Mock reference impl
internal/prompt/ Style preset registry (embedded styles.yaml)
internal/output/ Filename templating, image writer, JSON sidecar
internal/config/ YAML loader, validation, sample generator
+internal/cloud/ Supabase Storage + imagen.images writer
+internal/usage/ mai.imagen_usage cost-tracking sink
+internal/worker/ imagen.jobs queue consumer (DB-agnostic via Queue interface)
internal/server/ HTTP stub (not implemented yet — follow-up issue)
-docs/ architecture.md, usage.md
+scripts/ imagen-worker.service + env template, ComfyUI scripts
+docs/ architecture.md, usage.md, setup-worker-mriver.md
```
Data flow for `imagen generate`:


@@ -10,6 +10,27 @@ import (
"mgit.msbls.de/m/ImaGen/internal/config"
)
// instanceStatus checks adapter-specific preconditions (e.g. the
// Replicate API token env var being set) and returns a short
// user-facing status string.
func instanceStatus(spec config.BackendSpec) string {
if !backend.Default.Has(spec.Type) {
return fmt.Sprintf("type %q not compiled in", spec.Type)
}
switch spec.Type {
case backend.ReplicateType:
envName, _ := spec.Raw["api_token_env"].(string)
if envName == "" {
envName = "REPLICATE_API_TOKEN"
}
if os.Getenv(envName) == "" {
return fmt.Sprintf("not configured (set %s)", envName)
}
return "ok"
}
return "registered"
}
func runBackends(args []string) error {
fs := flag.NewFlagSet("backends", flag.ContinueOnError)
var configPath string
@@ -27,10 +48,7 @@ func runBackends(args []string) error {
fmt.Fprintln(tw, "INSTANCE\tTYPE\tSTATUS")
if cfg != nil {
for name, spec := range cfg.Backends {
-status := "registered"
-if !backend.Default.Has(spec.Type) {
-status = fmt.Sprintf("type %q not compiled in", spec.Type)
-}
+status := instanceStatus(spec)
marker := ""
if name == cfg.DefaultBackend {
marker = " (default)"

cmd/imagen/compare.go (new file, 386 lines)

@@ -0,0 +1,386 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"image"
"image/color"
"image/draw"
"image/png"
"io"
"os"
"path/filepath"
"sort"
"strings"
"time"
"golang.org/x/image/font"
"golang.org/x/image/font/basicfont"
"golang.org/x/image/math/fixed"
"mgit.msbls.de/m/ImaGen/internal/backend"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/output"
"mgit.msbls.de/m/ImaGen/internal/prompt"
)
// runCompare implements `imagen compare "<prompt>" --models a,b,c --output <dir>`.
//
// Each backend in --models runs sequentially against the same prompt (mRock
// has a single GPU; parallelising would just OOM). Each generation lands as
// a backend-suffixed file in the output dir; a contact sheet stitches them
// together into one PNG with the backend name overlaid on each cell. A
// sidecar JSON next to the contact sheet lists every generation with its
// per-model metadata (latency, seed, model file, VRAM peak).
func runCompare(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("compare", flag.ContinueOnError)
var (
modelsCSV string
size string
outDir string
style string
negative string
seed int64
steps int
configPath string
noContact bool
)
fs.StringVar(&modelsCSV, "models", "", "comma-separated backend instance names (required)")
fs.StringVar(&size, "size", "1024x1024", "WxH for every backend")
fs.StringVar(&outDir, "output", "", "directory to write the images + contact sheet (default: ~/Pictures/imagen/compare)")
fs.StringVar(&style, "style", "", "style preset applied to the prompt before dispatching to each backend")
fs.StringVar(&negative, "negative", "", "negative prompt (forwarded to every backend that supports it)")
fs.Int64Var(&seed, "seed", 0, "deterministic seed for every backend (0 = each backend rolls its own)")
fs.IntVar(&steps, "steps", 0, "diffusion steps (0 = each backend's default)")
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
fs.BoolVar(&noContact, "no-contact-sheet", false, "skip the composite PNG; only write per-backend images + sidecar")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), `Usage: imagen compare "<prompt>" --models a,b,c [flags]`)
fs.PrintDefaults()
}
leadingPositional, flagArgs := splitLeadingPositional(args)
if err := fs.Parse(flagArgs); err != nil {
return err
}
positional := append(leadingPositional, fs.Args()...)
if len(positional) == 0 {
fs.Usage()
return userErr("missing prompt")
}
rawPrompt := strings.Join(positional, " ")
modelNames := splitCSV(modelsCSV)
if len(modelNames) == 0 {
return userErr("--models is required (comma-separated backend instance names)")
}
w, h, err := parseSize(size)
if err != nil {
return userErr("bad --size: %v", err)
}
cfg, cfgErr := config.Load(configPath)
if cfgErr != nil && !os.IsNotExist(cfgErr) {
return cfgErr
}
if outDir == "" {
home, _ := os.UserHomeDir()
outDir = filepath.Join(home, "Pictures", "imagen", "compare")
}
outDir = config.ExpandPath(outDir)
finalPrompt, err := prompt.Apply(rawPrompt, style)
if err != nil {
return userErr("%v", err)
}
runID := time.Now().Format("20060102-150405")
runDir := filepath.Join(outDir, runID+"-"+output.Slug(rawPrompt))
if err := os.MkdirAll(runDir, 0o755); err != nil {
return fmt.Errorf("mkdir %s: %w", runDir, err)
}
results := make([]compareResult, 0, len(modelNames))
for i, name := range modelNames {
fmt.Fprintf(os.Stderr, "[%d/%d] %s ...\n", i+1, len(modelNames), name)
res, err := generateOne(ctx, cfg, name, finalPrompt, negative, w, h, seed, steps, runDir, rawPrompt)
if err != nil {
// Don't abort the whole run on a single backend failure — record
// the error and continue. flexsiebels-style consumers want to
// see N-1 results rather than zero when one model is offline.
fmt.Fprintf(os.Stderr, " failed: %v\n", err)
results = append(results, compareResult{Backend: name, Error: err.Error()})
continue
}
fmt.Fprintf(os.Stderr, " %s (%d ms)\n", res.ImagePath, res.LatencyMs)
results = append(results, res)
}
// Sidecar JSON inside the run dir captures every attempt, failures included.
sidecar := filepath.Join(runDir, "compare.json")
if err := writeCompareSidecar(sidecar, rawPrompt, style, negative, w, h, seed, steps, results); err != nil {
return err
}
fmt.Fprintln(os.Stderr, "sidecar:", sidecar)
// Contact sheet stitches the successful results together. If every
// backend failed there's nothing to draw, so skip it and say so on stderr.
if !noContact {
successes := successfulResults(results)
if len(successes) > 0 {
sheet := filepath.Join(runDir, "contact-sheet.png")
if err := writeContactSheet(sheet, rawPrompt, successes); err != nil {
return fmt.Errorf("contact sheet: %w", err)
}
fmt.Println(sheet)
} else {
fmt.Fprintln(os.Stderr, "imagen compare: all backends failed; no contact sheet written")
}
}
return nil
}
// compareResult is one backend's output in a comparison run. Error is set
// when Generate failed for this backend; ImagePath + Metadata are empty in
// that case.
type compareResult struct {
Backend string `json:"backend"`
ImagePath string `json:"image_path,omitempty"`
Seed int64 `json:"seed"`
LatencyMs int64 `json:"latency_ms,omitempty"`
Model string `json:"model,omitempty"`
VRAMUsedMiB int64 `json:"vram_used_mib,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Error string `json:"error,omitempty"`
}
func generateOne(ctx context.Context, cfg *config.Config, name, finalPrompt, negative string, w, h int, seed int64, steps int, runDir, rawPrompt string) (compareResult, error) {
be, err := buildBackend(cfg, name)
if err != nil {
return compareResult{Backend: name}, err
}
attachUsageSink(be)
req := backend.Request{
Prompt: finalPrompt,
NegativePrompt: negative,
Width: w,
Height: h,
Steps: steps,
Seed: seed,
}
res, err := be.Generate(ctx, req)
if err != nil {
return compareResult{Backend: name}, err
}
defer res.ImageReader.Close()
imgBytes, err := io.ReadAll(res.ImageReader)
if err != nil {
return compareResult{Backend: name}, fmt.Errorf("read image: %w", err)
}
imgPath := filepath.Join(runDir, output.Slug(rawPrompt)+"--"+output.Slug(name)+"."+extFromMime(res.MimeType))
if err := os.WriteFile(imgPath, imgBytes, 0o644); err != nil {
return compareResult{Backend: name}, fmt.Errorf("write %s: %w", imgPath, err)
}
cr := compareResult{
Backend: name,
ImagePath: imgPath,
Seed: seedFromMetadata(res.Metadata, seed),
LatencyMs: metaInt64(res.Metadata, "latency_ms"),
Model: metaString(res.Metadata, "model"),
Metadata: res.Metadata,
}
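// metaInt64 also accepts int and float64 (the type JSON decoding yields),
// so the VRAM figure is not dropped when a backend reports it as a non-int64 number.
cr.VRAMUsedMiB = metaInt64(res.Metadata, "vram_used_mib")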
return cr, nil
}
func successfulResults(rs []compareResult) []compareResult {
out := make([]compareResult, 0, len(rs))
for _, r := range rs {
if r.Error == "" && r.ImagePath != "" {
out = append(out, r)
}
}
return out
}
func writeCompareSidecar(path, rawPrompt, style, negative string, w, h int, seed int64, steps int, results []compareResult) error {
body := map[string]any{
"timestamp": time.Now().UTC().Format(time.RFC3339),
"prompt": rawPrompt,
"style": style,
"negative": negative,
"width": w,
"height": h,
"seed": seed,
"steps": steps,
"results": results,
"backends": backendNames(results),
"successful": len(successfulResults(results)),
"total": len(results),
}
data, err := json.MarshalIndent(body, "", " ")
if err != nil {
return fmt.Errorf("marshal sidecar: %w", err)
}
return os.WriteFile(path, append(data, '\n'), 0o644)
}
func backendNames(rs []compareResult) []string {
out := make([]string, len(rs))
for i, r := range rs {
out[i] = r.Backend
}
sort.Strings(out)
return out
}
// writeContactSheet stitches a grid of (image, label) cells into one PNG.
// Every cell is as wide and as tall as the largest input image; each image
// is drawn at its native resolution (no scaling) and centered in its cell.
//
// The grid is a single horizontal row when N <= 4; otherwise it has 2
// columns and ceil(N/2) rows. This is a contact sheet, not a fancy
// gallery: readability for side-by-side eyeballing is the goal.
func writeContactSheet(path, prompt string, results []compareResult) error {
if len(results) == 0 {
return fmt.Errorf("no successful results to lay out")
}
cells := make([]contactCell, 0, len(results))
for _, r := range results {
img, err := readPNG(r.ImagePath)
if err != nil {
return fmt.Errorf("read %s: %w", r.ImagePath, err)
}
cells = append(cells, contactCell{
Image: img,
Label: r.Backend,
SubLabel: fmt.Sprintf("%dms · seed %d", r.LatencyMs, r.Seed),
})
}
cols := len(cells)
if cols > 4 {
cols = 2
}
rows := (len(cells) + cols - 1) / cols
const labelH = 64
const pad = 16
cellW := cells[0].Image.Bounds().Dx()
cellH := cells[0].Image.Bounds().Dy()
for _, c := range cells {
if w := c.Image.Bounds().Dx(); w > cellW {
cellW = w
}
if h := c.Image.Bounds().Dy(); h > cellH {
cellH = h
}
}
totalW := cols*cellW + (cols+1)*pad
totalH := rows*(cellH+labelH) + (rows+1)*pad + 48 // header band
canvas := image.NewRGBA(image.Rect(0, 0, totalW, totalH))
draw.Draw(canvas, canvas.Bounds(), &image.Uniform{C: color.RGBA{R: 30, G: 30, B: 35, A: 255}}, image.Point{}, draw.Src)
// Header: show the truncated prompt.
headerText := "imagen compare — " + truncate(prompt, 100)
drawText(canvas, headerText, pad, 30, color.RGBA{R: 240, G: 240, B: 245, A: 255})
for i, c := range cells {
col := i % cols
row := i / cols
x0 := pad + col*(cellW+pad)
y0 := 48 + pad + row*(cellH+labelH+pad)
// Center the image inside the cell when smaller than the max cell size.
iw := c.Image.Bounds().Dx()
ih := c.Image.Bounds().Dy()
offX := (cellW - iw) / 2
offY := (cellH - ih) / 2
dstRect := image.Rect(x0+offX, y0+offY, x0+offX+iw, y0+offY+ih)
draw.Draw(canvas, dstRect, c.Image, c.Image.Bounds().Min, draw.Src)
// Label band underneath.
labelY := y0 + cellH + 20
drawText(canvas, c.Label, x0+8, labelY, color.RGBA{R: 250, G: 250, B: 250, A: 255})
drawText(canvas, c.SubLabel, x0+8, labelY+22, color.RGBA{R: 180, G: 180, B: 190, A: 255})
}
f, err := os.Create(path)
if err != nil {
return fmt.Errorf("create %s: %w", path, err)
}
defer f.Close()
return png.Encode(f, canvas)
}
type contactCell struct {
Image image.Image
Label string
SubLabel string
}
func readPNG(path string) (image.Image, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
img, _, err := image.Decode(f)
return img, err
}
func drawText(dst *image.RGBA, s string, x, y int, c color.Color) {
drawer := &font.Drawer{
Dst: dst,
Src: &image.Uniform{C: c},
Face: basicfont.Face7x13,
Dot: fixed.Point26_6{X: fixed.I(x), Y: fixed.I(y)},
}
drawer.DrawString(s)
}
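// truncate shortens s to at most max runes, ending in an ellipsis when cut.
// It counts runes rather than bytes so a multi-byte character is never split.
func truncate(s string, max int) string {
r := []rune(s)
if len(r) <= max {
return s
}
return string(r[:max-1]) + "…"
}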
func splitCSV(s string) []string {
parts := strings.Split(s, ",")
out := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
}
func metaInt64(m map[string]any, key string) int64 {
v, ok := m[key]
if !ok {
return 0
}
switch n := v.(type) {
case int64:
return n
case int:
return int64(n)
case float64:
return int64(n)
}
return 0
}
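
The grid arithmetic used by writeContactSheet can be checked in isolation. The sketch below copies the layout rule and the padding, label, and header constants from the function into two hypothetical helpers, `gridDims` and `canvasSize`; the 768-pixel cell size in `main` is only an example input.

```go
package main

import "fmt"

// gridDims applies the layout rule: up to four cells share one row; beyond
// that the sheet switches to two columns. Hypothetical helper.
func gridDims(n int) (cols, rows int) {
	if n <= 0 {
		return 0, 0
	}
	cols = n
	if cols > 4 {
		cols = 2
	}
	rows = (n + cols - 1) / cols // ceiling division
	return cols, rows
}

// canvasSize reuses the constants from writeContactSheet: a 64px label band
// per row, 16px padding around every cell, and a 48px header band.
func canvasSize(n, cellW, cellH int) (w, h int) {
	const labelH, pad, header = 64, 16, 48
	cols, rows := gridDims(n)
	w = cols*cellW + (cols+1)*pad
	h = rows*(cellH+labelH) + (rows+1)*pad + header
	return w, h
}

func main() {
	for _, n := range []int{1, 3, 4, 5, 6} {
		cols, rows := gridDims(n)
		w, h := canvasSize(n, 768, 768)
		fmt.Printf("n=%d: %d cols x %d rows, canvas %dx%d px\n", n, cols, rows, w, h)
	}
}
```

Note the jump between four and five results: four images sit in one wide row, while five fall back to two columns and three rows, leaving one empty cell.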

cmd/imagen/compare_test.go (new file, 203 lines)

@@ -0,0 +1,203 @@
package main
import (
"bytes"
"context"
"encoding/json"
"image/png"
"os"
"path/filepath"
"strings"
"testing"
)
// runCompareWithEnv runs the compare subcommand in a writable tmpdir, with
// XDG_CONFIG_HOME pointing somewhere empty so no host imagen.yaml leaks in.
func runCompareWithEnv(t *testing.T, args []string) (stderr, stdout *bytes.Buffer, runDir string, err error) {
t.Helper()
tmp := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", filepath.Join(tmp, "no-config"))
t.Setenv("HOME", tmp)
out := filepath.Join(tmp, "compare")
// stdlib flag parsing requires flags after the leading positional. Append
// --output at the end so any caller-supplied flags still parse cleanly.
args = append(args, "--output", out)
// Capture stdout/stderr via os pipes since runCompare writes directly.
oldStdout := os.Stdout
oldStderr := os.Stderr
rOut, wOut, _ := os.Pipe()
rErr, wErr, _ := os.Pipe()
os.Stdout = wOut
os.Stderr = wErr
defer func() {
os.Stdout = oldStdout
os.Stderr = oldStderr
}()
cmdErr := runCompare(context.Background(), args)
_ = wOut.Close()
_ = wErr.Close()
stdout = &bytes.Buffer{}
stderr = &bytes.Buffer{}
_, _ = stdout.ReadFrom(rOut)
_, _ = stderr.ReadFrom(rErr)
entries, _ := os.ReadDir(out)
if len(entries) == 1 {
runDir = filepath.Join(out, entries[0].Name())
}
return stderr, stdout, runDir, cmdErr
}
func TestCompareHappyPathWithMockBackends(t *testing.T) {
// Two mock instances stand in for two different backends. mock ignores
// cfg so we can reuse the registered type as the instance name and skip
// writing imagen.yaml entirely.
stderr, stdout, runDir, err := runCompareWithEnv(t, []string{
"a cat in a fishbowl",
"--models", "mock,mock",
"--size", "64x64",
"--seed", "42",
})
if err != nil {
t.Fatalf("runCompare: %v\nstderr: %s", err, stderr.String())
}
if runDir == "" {
t.Fatal("expected a run directory under --output")
}
// Sidecar JSON
sidecar := filepath.Join(runDir, "compare.json")
data, err := os.ReadFile(sidecar)
if err != nil {
t.Fatalf("read sidecar: %v", err)
}
var body struct {
Prompt string `json:"prompt"`
Successful int `json:"successful"`
Total int `json:"total"`
Results []struct {
Backend string `json:"backend"`
ImagePath string `json:"image_path"`
Error string `json:"error"`
} `json:"results"`
}
if err := json.Unmarshal(data, &body); err != nil {
t.Fatalf("parse sidecar: %v\n%s", err, data)
}
if body.Prompt != "a cat in a fishbowl" {
t.Errorf("prompt = %q", body.Prompt)
}
if body.Total != 2 || body.Successful != 2 {
t.Errorf("counts = %d successful / %d total", body.Successful, body.Total)
}
for _, r := range body.Results {
if r.Error != "" {
t.Errorf("backend %s errored: %s", r.Backend, r.Error)
}
if _, err := os.Stat(r.ImagePath); err != nil {
t.Errorf("image not on disk for %s: %v", r.Backend, err)
}
}
// Contact sheet path was printed on stdout.
sheet := strings.TrimSpace(stdout.String())
if sheet == "" {
t.Fatal("expected contact sheet path on stdout")
}
f, err := os.Open(sheet)
if err != nil {
t.Fatalf("open contact sheet: %v", err)
}
defer f.Close()
img, err := png.Decode(f)
if err != nil {
t.Fatalf("decode contact sheet PNG: %v", err)
}
if w := img.Bounds().Dx(); w < 100 {
t.Errorf("contact sheet looks empty (width %d)", w)
}
}
func TestCompareSkipContactSheet(t *testing.T) {
stderr, stdout, runDir, err := runCompareWithEnv(t, []string{
"x",
"--models", "mock",
"--size", "32x32",
"--seed", "1",
"--no-contact-sheet",
})
if err != nil {
t.Fatalf("runCompare: %v\nstderr: %s", err, stderr.String())
}
if got := strings.TrimSpace(stdout.String()); got != "" {
t.Errorf("expected no stdout output (no contact sheet), got %q", got)
}
if _, err := os.Stat(filepath.Join(runDir, "contact-sheet.png")); err == nil {
t.Errorf("contact-sheet.png should not exist with --no-contact-sheet")
}
}
func TestCompareRecordsBackendErrors(t *testing.T) {
// One real (mock) + one unknown. Unknown should fail but not abort the
// run — sidecar records both, contact sheet built from successes only.
stderr, _, runDir, err := runCompareWithEnv(t, []string{
"y",
"--models", "mock,this-instance-does-not-exist",
"--size", "32x32",
})
if err != nil {
t.Fatalf("runCompare: %v\nstderr: %s", err, stderr.String())
}
sidecar := filepath.Join(runDir, "compare.json")
data, err := os.ReadFile(sidecar)
if err != nil {
t.Fatalf("read sidecar: %v", err)
}
var body struct {
Successful int `json:"successful"`
Total int `json:"total"`
Results []struct {
Backend string `json:"backend"`
Error string `json:"error"`
} `json:"results"`
}
if err := json.Unmarshal(data, &body); err != nil {
t.Fatalf("parse sidecar: %v", err)
}
if body.Total != 2 {
t.Errorf("expected 2 results, got %d", body.Total)
}
if body.Successful != 1 {
t.Errorf("expected 1 success, got %d", body.Successful)
}
var sawError bool
for _, r := range body.Results {
if r.Backend == "this-instance-does-not-exist" && r.Error != "" {
sawError = true
}
}
if !sawError {
t.Errorf("expected an error recorded for the unknown backend")
}
}
func TestCompareNoModelsFails(t *testing.T) {
_, _, _, err := runCompareWithEnv(t, []string{"x"})
if err == nil {
t.Fatal("expected error when --models is empty")
}
if !strings.Contains(err.Error(), "--models") {
t.Errorf("error should mention --models, got %v", err)
}
}
func TestCompareNoPromptFails(t *testing.T) {
_, _, _, err := runCompareWithEnv(t, []string{"--models", "mock"})
if err == nil {
t.Fatal("expected error when prompt is missing")
}
if !strings.Contains(err.Error(), "missing prompt") {
t.Errorf("error should mention missing prompt, got %v", err)
}
}


@@ -39,6 +39,18 @@ func runConfig(args []string) error {
}
fmt.Fprintf(os.Stdout, "OK — %d backend(s) defined, default=%q\n",
len(cfg.Backends), cfg.DefaultBackend)
// Soft warnings — surfaced on stderr so they're visible but don't
// fail the validate exit code.
cloudMode := cfg.Output.CloudSync
if cloudMode == "" {
cloudMode = "auto"
}
if cloudMode != "off" && cfg.OwnerUserID == "" {
fmt.Fprintln(os.Stderr,
"warning: cloud_sync is "+cloudMode+" but owner_user_id is empty — DB inserts will be skipped.")
fmt.Fprintln(os.Stderr,
" look it up: SELECT id FROM auth.users WHERE email = '<your-supabase-email>';")
}
return nil
default:
return userErr("unknown config subcommand %q (init|validate|path)", args[0])


@@ -2,30 +2,39 @@ package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"mgit.msbls.de/m/ImaGen/internal/backend"
"mgit.msbls.de/m/ImaGen/internal/cloud"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/output"
"mgit.msbls.de/m/ImaGen/internal/preview"
"mgit.msbls.de/m/ImaGen/internal/prompt"
"mgit.msbls.de/m/ImaGen/internal/usage"
)
func runGenerate(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("generate", flag.ContinueOnError)
var (
-backendName string
-size string
-outPath string
-seed int64
-steps int
-style string
-negative string
-configPath string
-noSidecar bool
+backendName string
+size string
+outPath string
+seed int64
+steps int
+style string
+negative string
+configPath string
+noSidecar bool
+previewOn bool
+previewOff bool
+noCloud bool
)
fs.StringVar(&backendName, "backend", "", "backend instance name (default: config.default_backend)")
fs.StringVar(&size, "size", "1024x1024", "WxH, e.g. 1024x1024")
@@ -36,6 +45,9 @@ func runGenerate(ctx context.Context, args []string) error {
fs.StringVar(&negative, "negative", "", "negative prompt (ignored by backends that don't support it)")
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
fs.BoolVar(&noSidecar, "no-sidecar", false, "skip the JSON sidecar even if config enables it")
fs.BoolVar(&previewOn, "preview", false, "force tmux preview window on (errors outside $TMUX)")
fs.BoolVar(&previewOff, "no-preview", false, "skip the tmux preview window")
fs.BoolVar(&noCloud, "no-cloud", false, "skip Supabase upload + imagen.images insert for this generation")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), `Usage: imagen generate "<prompt>" [flags]`)
fs.PrintDefaults()
@@ -76,6 +88,7 @@ func runGenerate(ctx context.Context, args []string) error {
if err != nil {
return err
}
attachUsageSink(be)
finalPrompt, err := prompt.Apply(rawPrompt, style)
if err != nil {
@@ -118,9 +131,249 @@ func runGenerate(ctx context.Context, args []string) error {
if paths.SidecarPath != "" {
fmt.Fprintln(os.Stderr, "sidecar:", paths.SidecarPath)
}
if result, err := maybeCloudSync(ctx, cfg, noCloud, "", "", paths, in, res, w, h); err != nil {
// cloud-sync failures are warnings — the image already wrote.
fmt.Fprintln(os.Stderr, "imagen: cloud sync:", err)
} else if result != nil && result.ImageID != "" {
fmt.Fprintf(os.Stderr, "cloud: imagen.images.id=%s storage_path=%s\n", result.ImageID, result.StoragePath)
}
if err := maybePreview(cfg, previewOn, previewOff, paths.ImagePath, rawPrompt); err != nil {
// preview failures are warnings — the image already wrote.
fmt.Fprintln(os.Stderr, "imagen: preview:", err)
}
return nil
}
// resolveCloudSyncMode applies the precedence chain config -> env -> flag.
// Flags win, env beats config, config beats the implicit auto default.
// Mirrors resolvePreviewMode shape.
func resolveCloudSyncMode(cfg *config.Config, noCloudFlag bool, env string) (string, error) {
mode := "auto"
if cfg != nil && cfg.Output.CloudSync != "" {
mode = cfg.Output.CloudSync
}
if env != "" {
switch env {
case "auto", "on", "off":
mode = env
default:
return "", fmt.Errorf("$IMAGEN_CLOUD_SYNC = %q (must be auto|on|off)", env)
}
}
if noCloudFlag {
mode = "off"
}
return mode, nil
}
// maybeCloudSync resolves the effective mode and, if it says yes, uploads
// the PNG and inserts the row. Returns the SyncResult on success so callers
// that need the imagen.images.id (e.g. the worker linking a job row) can pick
// it up. ownerOverride, when non-empty, wins over config + env — the worker
// passes the job row's owner_user_id so each job is attributed correctly.
// seriesID, when non-empty, lands on imagen.images.series_id so the
// list-page query (`WHERE series_id IS NULL`) hides series members from
// the flat grid; empty means solo run.
func maybeCloudSync(ctx context.Context, cfg *config.Config, noCloud bool, ownerOverride, seriesID string, paths *output.Outputs, in output.Inputs, res *backend.Result, width, height int) (*cloud.SyncResult, error) {
mode, err := resolveCloudSyncMode(cfg, noCloud, os.Getenv("IMAGEN_CLOUD_SYNC"))
if err != nil {
return nil, err
}
if mode == "off" {
return nil, nil
}
sink, ok := cloud.NewFromEnv()
if !ok {
if mode == "on" {
return nil, fmt.Errorf("cloud_sync=on but SUPABASE_URL / SUPABASE_SERVICE_KEY not set in env")
}
// auto + missing env = silent skip.
return nil, nil
}
switch {
case ownerOverride != "":
sink.OwnerUserID = ownerOverride
case cfg != nil && cfg.OwnerUserID != "":
// Config-supplied owner_user_id takes precedence over $IMAGEN_OWNER_USER_ID.
sink.OwnerUserID = cfg.OwnerUserID
}
if sink.OwnerUserID == "" {
if mode == "on" {
return nil, fmt.Errorf("cloud_sync=on but owner_user_id not set in config and $IMAGEN_OWNER_USER_ID is empty")
}
// auto + missing UUID = silent skip.
return nil, nil
}
pngBytes, readErr := os.ReadFile(paths.ImagePath)
if readErr != nil {
return nil, fmt.Errorf("read local image: %w", readErr)
}
// Reuse the writer's date/slug/seed so storage_path mirrors the local
// filename's prefix exactly — viewers can join `imagen.images` on
// either side without timezone drift.
date := paths.Date
slug := paths.Slug
if date == "" || slug == "" {
now := time.Now()
date = now.Format("2006-01-02")
slug = output.Slug(in.Prompt)
}
ext := in.Ext
if ext == "" {
ext = strings.TrimPrefix(filepath.Ext(paths.ImagePath), ".")
}
if ext == "" {
ext = "png"
}
// Snapshot the sidecar (if it exists) so the row carries the same
// metadata view a downstream viewer would see on disk.
var sidecar map[string]any
if paths.SidecarPath != "" {
if scBytes, err := os.ReadFile(paths.SidecarPath); err == nil {
_ = json.Unmarshal(scBytes, &sidecar)
}
}
model := metaString(res.Metadata, "model")
steps := metaInt(res.Metadata, "steps")
cost := metaFloatPtr(res.Metadata, "cost_usd_estimate")
latency := metaInt(res.Metadata, "latency_ms")
seed := paths.Seed
if seed == 0 {
seed = in.Seed
}
syncReq := cloud.SyncRequest{
Date: date,
Slug: slug,
Seed: seed,
Ext: ext,
PNG: pngBytes,
MimeType: res.MimeType,
Prompt: in.Prompt,
Backend: in.Backend,
Model: model,
Steps: steps,
Width: width,
Height: height,
LatencyMs: latency,
CostUSDEstimate: cost,
Sidecar: sidecar,
SeriesID: seriesID,
}
syncCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
return sink.Sync(syncCtx, syncReq)
}
func metaString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func metaInt(m map[string]any, key string) int {
v, ok := m[key]
if !ok {
return 0
}
switch n := v.(type) {
case int:
return n
case int64:
return int(n)
case float64:
return int(n)
}
return 0
}
func metaFloatPtr(m map[string]any, key string) *float64 {
v, ok := m[key]
if !ok {
return nil
}
switch n := v.(type) {
case float64:
return &n
case float32:
f := float64(n)
return &f
case int:
f := float64(n)
return &f
case int64:
f := float64(n)
return &f
}
return nil
}
// resolvePreviewMode applies the precedence chain config -> env -> flag.
// Flags win, env beats config, config beats the implicit auto default.
func resolvePreviewMode(cfg *config.Config, flagOn, flagOff bool, env string) (preview.Mode, error) {
mode := preview.ModeAuto
if cfg != nil && cfg.Output.Preview != "" {
m, err := preview.ParseMode(cfg.Output.Preview)
if err != nil {
return "", fmt.Errorf("config output.preview: %w", err)
}
mode = m
}
if env != "" {
m, err := preview.ParseMode(env)
if err != nil {
return "", fmt.Errorf("$IMAGEN_PREVIEW: %w", err)
}
mode = m
}
if flagOn && flagOff {
return "", userErr("--preview and --no-preview are mutually exclusive")
}
if flagOn {
mode = preview.ModeOn
}
if flagOff {
mode = preview.ModeOff
}
return mode, nil
}
// maybePreview resolves the effective preview mode and, if it says yes,
// spawns a tmux window via tmux-img. Always non-fatal.
func maybePreview(cfg *config.Config, flagOn, flagOff bool, imagePath, rawPrompt string) error {
mode, err := resolvePreviewMode(cfg, flagOn, flagOff, os.Getenv("IMAGEN_PREVIEW"))
if err != nil {
return err
}
decision, err := preview.Resolve(mode, os.Getenv("TMUX") != "", stdoutIsTTY())
if err != nil {
return err
}
if !decision.ShouldPreview {
return nil
}
spawner := &preview.Spawner{}
return spawner.Spawn(imagePath, output.Slug(rawPrompt))
}
func stdoutIsTTY() bool {
fi, err := os.Stdout.Stat()
if err != nil {
return false
}
return fi.Mode()&os.ModeCharDevice != 0
}
// splitLeadingPositional separates the positional args at the start of args
// from the rest (which begins with the first flag). A literal "--" terminator
// pushes everything after it into the positional list and out of flag parsing.
@@ -153,6 +406,21 @@ func parseSize(s string) (int, int, error) {
return w, h, nil
}
// attachUsageSink wires a Supabase cost-tracking sink into the backend
// when it accepts one and the env is configured. Adapters that record
// usage expose a public Sink field of type backend.UsageSink.
func attachUsageSink(be backend.Backend) {
r, ok := be.(*backend.Replicate)
if !ok {
return
}
sink, ok := usage.NewSupabaseSinkFromEnv()
if !ok {
return
}
r.Sink = sink
}
func buildBackend(cfg *config.Config, name string) (backend.Backend, error) {
if cfg != nil {
spec, ok := cfg.Backends[name]


@@ -0,0 +1,87 @@
package main
import (
"testing"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/preview"
)
func TestResolvePreviewMode(t *testing.T) {
type tc struct {
name string
cfg *config.Config
flagOn bool
flagOff bool
env string
want preview.Mode
wantError bool
}
cases := []tc{
{name: "all-empty-defaults-to-auto", want: preview.ModeAuto},
{name: "config-on", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, want: preview.ModeOn},
{name: "config-off", cfg: &config.Config{Output: config.OutputConfig{Preview: "off"}}, want: preview.ModeOff},
{name: "config-auto-explicit", cfg: &config.Config{Output: config.OutputConfig{Preview: "auto"}}, want: preview.ModeAuto},
{name: "env-overrides-config", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, env: "off", want: preview.ModeOff},
{name: "flag-on-overrides-env-off", env: "off", flagOn: true, want: preview.ModeOn},
{name: "flag-off-overrides-env-on", env: "on", flagOff: true, want: preview.ModeOff},
{name: "flag-off-overrides-config-on", cfg: &config.Config{Output: config.OutputConfig{Preview: "on"}}, flagOff: true, want: preview.ModeOff},
{name: "both-flags-error", flagOn: true, flagOff: true, wantError: true},
{name: "bad-env-errors", env: "yes", wantError: true},
{name: "bad-config-errors", cfg: &config.Config{Output: config.OutputConfig{Preview: "yes"}}, wantError: true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := resolvePreviewMode(c.cfg, c.flagOn, c.flagOff, c.env)
if c.wantError {
if err == nil {
t.Fatalf("expected error, got mode %q", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != c.want {
t.Errorf("mode = %q, want %q", got, c.want)
}
})
}
}
func TestResolveCloudSyncMode(t *testing.T) {
type tc struct {
name string
cfg *config.Config
noCloud bool
env string
want string
wantError bool
}
cases := []tc{
{name: "all-empty-defaults-to-auto", want: "auto"},
{name: "config-on", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, want: "on"},
{name: "config-off", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "off"}}, want: "off"},
{name: "env-overrides-config", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, env: "off", want: "off"},
{name: "flag-overrides-env-and-config", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, env: "on", noCloud: true, want: "off"},
{name: "flag-overrides-config-on", cfg: &config.Config{Output: config.OutputConfig{CloudSync: "on"}}, noCloud: true, want: "off"},
{name: "bad-env-errors", env: "yes", wantError: true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := resolveCloudSyncMode(c.cfg, c.noCloud, c.env)
if c.wantError {
if err == nil {
t.Fatalf("expected error, got mode %q", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != c.want {
t.Errorf("mode = %q, want %q", got, c.want)
}
})
}
}


@@ -14,14 +14,18 @@ import (
_ "mgit.msbls.de/m/ImaGen/internal/backend"
)
-const usage = `imagen — model-agnostic image generation
+const helpText = `imagen — model-agnostic image generation
Usage:
imagen generate <prompt> [flags] generate one image
imagen compare <prompt> --models a,b,c [flags]
run one prompt across N backends + contact sheet
imagen worker [flags] consume the imagen.jobs queue (daemon)
imagen backends list registered backend types
imagen config init print a sample imagen.yaml on stdout
imagen config validate validate the active config
imagen serve [--addr :8080] (stub) start the HTTP server
imagen usage [--since DATE] show cost-tracking rows
imagen version print version
imagen help show this help
@@ -33,7 +37,7 @@ var Version = "dev"
func main() {
if len(os.Args) < 2 {
-fmt.Fprint(os.Stderr, usage)
+fmt.Fprint(os.Stderr, helpText)
os.Exit(2)
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
@@ -44,18 +48,24 @@ func main() {
switch os.Args[1] {
case "generate":
err = runGenerate(ctx, args)
case "compare":
err = runCompare(ctx, args)
case "worker":
err = runWorker(ctx, args)
case "backends":
err = runBackends(args)
case "config":
err = runConfig(args)
case "serve":
err = runServe(args)
case "usage":
err = runUsage(ctx, args)
case "version", "-v", "--version":
fmt.Println(Version)
case "help", "-h", "--help":
-fmt.Print(usage)
+fmt.Print(helpText)
default:
-fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], usage)
+fmt.Fprintf(os.Stderr, "imagen: unknown subcommand %q\n\n%s", os.Args[1], helpText)
os.Exit(2)
}
if err != nil {

cmd/imagen/usage.go (new file, 189 lines)

@@ -0,0 +1,189 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"sort"
"strings"
"text/tabwriter"
"time"
"mgit.msbls.de/m/ImaGen/internal/usage"
)
// runUsage handles `imagen usage [--since DATE]`. Reads mai.imagen_usage
// via Supabase REST and prints a tab-aligned table grouped by week +
// backend + model + caller, with totals at the bottom.
func runUsage(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("usage", flag.ContinueOnError)
var (
since string
raw bool
)
fs.StringVar(&since, "since", "", "ISO date (YYYY-MM-DD) — only rows on/after this UTC date")
fs.BoolVar(&raw, "raw", false, "print one line per row instead of grouped")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), "Usage: imagen usage [--since YYYY-MM-DD] [--raw]")
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil {
return err
}
var sinceT time.Time
if since != "" {
t, err := time.Parse("2006-01-02", since)
if err != nil {
return userErr("--since must be YYYY-MM-DD: %v", err)
}
sinceT = t
}
sink, ok := usage.NewSupabaseSinkFromEnv()
if !ok {
return userErr("SUPABASE_URL and SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) must be set to read mai.imagen_usage")
}
rows, err := sink.Query(ctx, sinceT)
if err != nil {
return err
}
if raw {
printRawRows(rows)
return nil
}
printGroupedRows(rows)
return nil
}
func printRawRows(rows []usage.Row) {
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, "TIME\tBACKEND\tMODEL\tCALLER\tLATENCY_MS\tCOST_USD")
var totalCost float64
for _, r := range rows {
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n",
r.CreatedAt.Local().Format("2006-01-02 15:04"),
r.Backend,
r.Model,
derefString(r.Caller),
intOrDash(r.LatencyMs),
costOrDash(r.CostUSDEstimate),
)
if r.CostUSDEstimate != nil {
totalCost += *r.CostUSDEstimate
}
}
fmt.Fprintf(tw, "\t\t\t\t%d rows\t%.4f USD\n", len(rows), totalCost)
_ = tw.Flush()
}
type group struct {
week string
backend string
model string
caller string
count int
cost float64
costSet bool
}
type groupKey struct {
week, backend, model, caller string
}
func printGroupedRows(rows []usage.Row) {
groups := map[groupKey]*group{}
for _, r := range rows {
caller := derefString(r.Caller)
k := groupKey{
week: weekStart(r.CreatedAt).Format("2006-01-02"),
backend: r.Backend,
model: r.Model,
caller: caller,
}
g, ok := groups[k]
if !ok {
g = &group{week: k.week, backend: r.Backend, model: r.Model, caller: caller}
groups[k] = g
}
g.count++
if r.CostUSDEstimate != nil {
g.cost += *r.CostUSDEstimate
g.costSet = true
}
}
keys := make([]groupKey, 0, len(groups))
for k := range groups {
keys = append(keys, k)
}
sort.Slice(keys, func(i, j int) bool {
if keys[i].week != keys[j].week {
return keys[i].week > keys[j].week // newest first
}
if keys[i].backend != keys[j].backend {
return keys[i].backend < keys[j].backend
}
if keys[i].model != keys[j].model {
return keys[i].model < keys[j].model
}
return keys[i].caller < keys[j].caller
})
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, "WEEK_OF\tBACKEND\tMODEL\tCALLER\tCOUNT\tCOST_USD")
var totalCount int
var totalCost float64
for _, k := range keys {
g := groups[k]
fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%s\n",
g.week, g.backend, g.model, g.caller, g.count, costStr(g.cost, g.costSet),
)
totalCount += g.count
totalCost += g.cost
}
fmt.Fprintf(tw, "\t\t\tTOTAL\t%d\t%.4f USD\n", totalCount, totalCost)
_ = tw.Flush()
}
// weekStart returns the Monday of the week containing t (UTC).
func weekStart(t time.Time) time.Time {
t = t.UTC()
wd := int(t.Weekday())
if wd == 0 {
wd = 7 // shift Sunday to end-of-week
}
delta := time.Duration(wd-1) * -24 * time.Hour
d := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
return d.Add(delta)
}
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}
func intOrDash(p *int) string {
if p == nil {
return "-"
}
return fmt.Sprintf("%d", *p)
}
func costOrDash(p *float64) string {
if p == nil {
return "-"
}
return fmt.Sprintf("%.4f", *p)
}
func costStr(v float64, set bool) string {
if !set {
return "-"
}
return strings.TrimSpace(fmt.Sprintf("%.4f", v))
}

cmd/imagen/worker.go (new file, 292 lines)

@@ -0,0 +1,292 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"os"
"strings"
"time"
"github.com/jackc/pgx/v5"
"mgit.msbls.de/m/ImaGen/internal/backend"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/output"
"mgit.msbls.de/m/ImaGen/internal/prompt"
"mgit.msbls.de/m/ImaGen/internal/worker"
)
// runWorker is the `imagen worker` subcommand: a long-running daemon that
// consumes the imagen.jobs queue and writes results into imagen.images via
// the same cloud-sync path generate uses.
func runWorker(ctx context.Context, args []string) error {
fs := flag.NewFlagSet("worker", flag.ContinueOnError)
var (
configPath string
pollInterval time.Duration
jobTimeout time.Duration
)
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
fs.DurationVar(&pollInterval, "poll-interval", 5*time.Second, "safety-poll cadence between LISTEN wakeups")
fs.DurationVar(&jobTimeout, "job-timeout", 5*time.Minute, "max wall-time per job before the worker marks it failed")
fs.Usage = func() {
fmt.Fprintln(fs.Output(), `Usage: imagen worker [flags]
Long-running daemon. LISTENs on the Postgres 'imagen_jobs' channel and polls
imagen.jobs every --poll-interval as a safety net, claims pending rows, runs
the generation pipeline, then updates the row with status + image_id.
Env:
IMAGEN_WORKER_DATABASE_URL Postgres DSN for direct LISTEN + UPDATE.
Required (PostgREST cannot LISTEN).
SUPABASE_URL, SUPABASE_SERVICE_KEY, IMAGEN_OWNER_USER_ID
Reused from generate's cloud-sync path; the
worker writes imagen.images rows through the
same code path. Per-job owner_user_id from the
job row overrides IMAGEN_OWNER_USER_ID.`)
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil {
return err
}
cfg, cfgErr := config.Load(configPath)
if cfgErr != nil && !os.IsNotExist(cfgErr) {
return cfgErr
}
dsn := os.Getenv("IMAGEN_WORKER_DATABASE_URL")
if dsn == "" {
return userErr("IMAGEN_WORKER_DATABASE_URL not set; the worker needs a direct Postgres DSN for LISTEN/NOTIFY")
}
q, err := dialQueue(ctx, dsn)
if err != nil {
return fmt.Errorf("queue: %w", err)
}
defer q.Close()
p := &workerPipeline{cfg: cfg}
w := worker.New(q, p, worker.Config{
PollInterval: pollInterval,
JobTimeout: jobTimeout,
Logger: func(format string, a ...any) { fmt.Fprintf(os.Stderr, format+"\n", a...) },
})
fmt.Fprintln(os.Stderr, "imagen worker: ready (poll-interval", pollInterval, "job-timeout", jobTimeout, ")")
return w.Run(ctx)
}
// pgxQueue is the production Queue. It opens one dedicated connection used
// for both LISTEN (long-lived) and UPDATE operations. A second connection
// would split state needlessly — a single worker process processes one job
// at a time so the connection is never contended.
type pgxQueue struct {
conn *pgx.Conn
}
func dialQueue(ctx context.Context, dsn string) (*pgxQueue, error) {
conn, err := pgx.Connect(ctx, dsn)
if err != nil {
return nil, fmt.Errorf("pgx.Connect: %w", err)
}
if _, err := conn.Exec(ctx, "LISTEN imagen_jobs"); err != nil {
conn.Close(ctx)
return nil, fmt.Errorf("LISTEN imagen_jobs: %w", err)
}
return &pgxQueue{conn: conn}, nil
}
func (q *pgxQueue) Close() {
if q == nil || q.conn == nil {
return
}
// Best-effort: a 5s budget is enough to send a polite TerminateMessage.
shutdown, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = q.conn.Close(shutdown)
}
// ClaimNextPending atomically marks the oldest pending row 'running' and
// returns it. FOR UPDATE SKIP LOCKED is belt + braces against a second worker
// process — out of scope for v1 but cheap insurance.
func (q *pgxQueue) ClaimNextPending(ctx context.Context) (*worker.Job, error) {
// series_id is nullable on imagen.jobs (solo run when NULL); cast to text
// with COALESCE so pgx scans into a plain Go string. Empty string =
// solo run; the pipeline skips series propagation in that case.
const stmt = `
UPDATE imagen.jobs
SET status='running', started_at=now()
WHERE id = (
SELECT id FROM imagen.jobs
WHERE status='pending'
ORDER BY created_at
LIMIT 1
FOR UPDATE SKIP LOCKED
)
RETURNING id, owner_user_id, prompt, backend,
COALESCE(model,''),
COALESCE(width, 0), COALESCE(height, 0),
COALESCE(steps, 0), COALESCE(seed, 0),
COALESCE(style,''),
COALESCE(series_id::text, '')`
var j worker.Job
err := q.conn.QueryRow(ctx, stmt).Scan(
&j.ID, &j.OwnerUserID, &j.Prompt, &j.Backend,
&j.Model, &j.Width, &j.Height, &j.Steps, &j.Seed, &j.Style,
&j.SeriesID,
)
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
if err != nil {
return nil, err
}
return &j, nil
}
func (q *pgxQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
_, err := q.conn.Exec(ctx,
`UPDATE imagen.jobs SET status='done', image_id=$2, completed_at=now() WHERE id=$1`,
jobID, imageID)
return err
}
func (q *pgxQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
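// Trim outrageously long error text so a 10MB stack-trace doesn't end up
// in the row (callers see a summary, full text goes to stderr / logs).
// The cut is byte-based, so drop any rune it splits: a UTF-8 database
// rejects invalid byte sequences in text columns.
const maxLen = 2000
if len(msg) > maxLen {
msg = strings.ToValidUTF8(msg[:maxLen], "") + "... [truncated]"
}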
_, err := q.conn.Exec(ctx,
`UPDATE imagen.jobs SET status='failed', error=$2, completed_at=now() WHERE id=$1`,
jobID, msg)
return err
}
// WaitForJob blocks until a NOTIFY arrives on imagen_jobs, the timeout fires,
// or ctx is cancelled. Notifications during a previous processJob are queued
// by pgx and delivered on the next call — we don't lose wake-ups even when
// processing took longer than poll-interval.
func (q *pgxQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
_, err := q.conn.WaitForNotification(waitCtx)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return nil // poll cadence fired
}
if errors.Is(err, context.Canceled) {
return context.Canceled
}
return err
}
return nil
}
// ResetStaleRunning bumps any rows stuck in 'running' back to 'pending' so
// they get re-claimed. Called once at startup. A row stuck in 'running' came
// from a previous worker crash; without this, flexsiebels would poll
// forever on a job nobody is processing.
func (q *pgxQueue) ResetStaleRunning(ctx context.Context) error {
_, err := q.conn.Exec(ctx,
`UPDATE imagen.jobs SET status='pending', started_at=NULL WHERE status='running'`)
return err
}
// workerPipeline is the Pipeline implementation that drives a single job
// through buildBackend → prompt enrichment → generate → write disk →
// cloud-sync, then returns the imagen.images.id back to the worker so it
// can link the row.
type workerPipeline struct {
cfg *config.Config
}
func (p *workerPipeline) Run(ctx context.Context, job worker.Job) worker.Outcome {
if job.OwnerUserID == "" {
return worker.Outcome{Err: fmt.Errorf("job %s: missing owner_user_id", job.ID)}
}
if job.Prompt == "" {
return worker.Outcome{Err: fmt.Errorf("job %s: empty prompt", job.ID)}
}
if job.Backend == "" {
return worker.Outcome{Err: fmt.Errorf("job %s: missing backend", job.ID)}
}
be, err := buildBackend(p.cfg, job.Backend)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("backend %q: %w", job.Backend, err)}
}
attachUsageSink(be)
finalPrompt, err := prompt.Apply(job.Prompt, job.Style)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("style: %w", err)}
}
req := backend.Request{
Prompt: finalPrompt,
Width: job.Width,
Height: job.Height,
Steps: job.Steps,
Seed: job.Seed,
Style: job.Style,
}
res, err := be.Generate(ctx, req)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("generate: %w", err)}
}
defer res.ImageReader.Close()
writer := buildWriter(p.cfg, false)
in := output.Inputs{
Prompt: job.Prompt,
Backend: be.Name(),
Seed: seedFromMetadata(res.Metadata, job.Seed),
Ext: extFromMime(res.MimeType),
Metadata: res.Metadata,
}
paths, err := writer.Write(res.ImageReader, in)
if err != nil {
return worker.Outcome{Err: fmt.Errorf("write disk: %w", err)}
}
// Worker is queue-driven: cloud-sync is mandatory because flexsiebels
// needs imagen.images.id to render the result. The worker has no
// --no-cloud flag, so noCloud is always false, and the job row's
// owner_user_id is passed as ownerOverride (the fourth argument of
// maybeCloudSync). A config that explicitly sets cloud_sync=off is
// rejected here instead of silently skipping the upload.
if cloudModeOff(p.cfg) {
// We refuse to silently drop a queued job. If cloud sync is off in
// config, the worker can't serve flexsiebels at all.
return worker.Outcome{Err: fmt.Errorf("output.cloud_sync=off in config; the worker requires cloud_sync=on or auto")}
}
syncRes, syncErr := maybeCloudSync(ctx, p.cfg, false, job.OwnerUserID, job.SeriesID, paths, in, res, dimOrFallback(job.Width, res, "width"), dimOrFallback(job.Height, res, "height"))
if syncErr != nil {
return worker.Outcome{Err: fmt.Errorf("cloud sync: %w", syncErr)}
}
if syncRes == nil || syncRes.ImageID == "" {
return worker.Outcome{Err: fmt.Errorf("cloud sync returned no imagen.images id (check SUPABASE_URL + SUPABASE_SERVICE_KEY)")}
}
return worker.Outcome{ImageID: syncRes.ImageID}
}
func cloudModeOff(cfg *config.Config) bool {
if cfg == nil {
return false
}
return strings.EqualFold(cfg.Output.CloudSync, "off")
}
// dimOrFallback returns job.<dim> when the job specified one, otherwise the
// dimension reported by the backend's metadata. Some backends (Replicate
// when given an aspect ratio) round the requested size to their nearest
// supported value; this keeps the row honest about what was actually generated.
func dimOrFallback(jobDim int, res *backend.Result, key string) int {
if jobDim > 0 {
return jobDim
}
return metaInt(res.Metadata, key)
}


@@ -0,0 +1,129 @@
package main
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/worker"
)
// TestWorker_Integration_EndToEnd runs the full pipeline against a real
// msupabase instance: insert a row into imagen.jobs, let the worker claim
// it, generate via the mock backend (no Replicate spend, no ComfyUI
// dependency), write to Supabase Storage + imagen.images, then flip the job
// to 'done' with the linked image_id.
//
// Guarded by IMAGEN_WORKER_INTEGRATION=1. Required env beyond that:
//
// IMAGEN_WORKER_DATABASE_URL postgres DSN (direct, not PostgREST)
// SUPABASE_URL e.g. https://supa.flexsiebels.de
// SUPABASE_SERVICE_KEY service-role JWT
// IMAGEN_OWNER_USER_ID UUID of an auth.users row (RLS fallback)
//
// The test creates and later deletes its own job row so repeated runs don't
// leave debris.
func TestWorker_Integration_EndToEnd(t *testing.T) {
if os.Getenv("IMAGEN_WORKER_INTEGRATION") != "1" {
t.Skip("set IMAGEN_WORKER_INTEGRATION=1 to run the integration test")
}
dsn := os.Getenv("IMAGEN_WORKER_DATABASE_URL")
if dsn == "" {
t.Fatal("IMAGEN_WORKER_DATABASE_URL must be set for the integration test")
}
if os.Getenv("SUPABASE_URL") == "" || os.Getenv("SUPABASE_SERVICE_KEY") == "" {
t.Fatal("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set for the integration test")
}
owner := os.Getenv("IMAGEN_OWNER_USER_ID")
if owner == "" {
t.Fatal("IMAGEN_OWNER_USER_ID must be set for the integration test")
}
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()
q, err := dialQueue(ctx, dsn)
if err != nil {
t.Fatalf("dialQueue: %v", err)
}
defer q.Close()
// Insert the test job on a separate connection (the worker's conn is
// busy LISTENing). Mock backend = no external dependency.
insertConn, err := pgx.Connect(ctx, dsn)
if err != nil {
t.Fatalf("insert conn: %v", err)
}
defer insertConn.Close(ctx)
var jobID string
prompt := fmt.Sprintf("imagen integration test %d", time.Now().UnixNano())
err = insertConn.QueryRow(ctx, `
INSERT INTO imagen.jobs (owner_user_id, prompt, backend, width, height)
VALUES ($1, $2, 'mock', 64, 64)
RETURNING id`,
owner, prompt).Scan(&jobID)
if err != nil {
t.Fatalf("insert job: %v", err)
}
t.Logf("inserted imagen.jobs id=%s", jobID)
// Tidy up at the end of the test so a re-run starts clean.
defer func() {
cleanup, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _ = insertConn.Exec(cleanup, `DELETE FROM imagen.jobs WHERE id=$1`, jobID)
}()
// Use a per-test temp dir so the generated PNG doesn't litter the repo.
tmpDir := t.TempDir()
cfg := &config.Config{Output: config.OutputConfig{Directory: tmpDir}}
p := &workerPipeline{cfg: cfg}
w := worker.New(q, p, worker.Config{
PollInterval: 1 * time.Second,
JobTimeout: 30 * time.Second,
Logger: func(format string, a ...any) { t.Logf("worker: "+format, a...) },
})
// Run the worker until it processes one job (the one we just inserted)
// or the test context times out.
runCtx, runCancel := context.WithCancel(ctx)
done := make(chan struct{})
go func() {
_ = w.Run(runCtx)
close(done)
}()
// Poll for completion.
deadline := time.Now().Add(60 * time.Second)
var status, imageID string
for time.Now().Before(deadline) {
err = insertConn.QueryRow(ctx,
`SELECT status, COALESCE(image_id::text,'') FROM imagen.jobs WHERE id=$1`,
jobID).Scan(&status, &imageID)
if err != nil {
t.Fatalf("poll: %v", err)
}
if status == "done" || status == "failed" {
break
}
time.Sleep(500 * time.Millisecond)
}
runCancel()
<-done
if status != "done" {
var errText string
_ = insertConn.QueryRow(ctx,
`SELECT COALESCE(error,'') FROM imagen.jobs WHERE id=$1`, jobID).Scan(&errText)
t.Fatalf("job not done within timeout: status=%q error=%q", status, errText)
}
if imageID == "" {
t.Fatalf("job done but image_id is empty")
}
t.Logf("job done: image_id=%s", imageID)
}


@@ -7,7 +7,7 @@ upstream API. Each adapter only ever sees its own slice of `imagen.yaml`.
```
┌───────────────────────┐
-│ cmd/imagen │ CLI dispatch
+│ cmd/imagen │ CLI dispatch (generate / worker / …)
│ (or HTTP server) │
└──────────┬────────────┘
@@ -15,6 +15,10 @@ upstream API. Each adapter only ever sees its own slice of `imagen.yaml`.
│ internal/prompt │ style preset → prompt suffix
│ internal/output │ filename templating, sidecar
│ internal/config │ YAML loader, validation
│ internal/preview │ tmux-img window spawner
│ internal/cloud │ Supabase Storage + imagen.images
│ internal/usage │ mai.imagen_usage cost-tracking
│ internal/worker │ imagen.jobs queue consumer
└──────────┬────────────┘
┌──────────▼────────────┐
@@ -102,9 +106,37 @@ contains the prompt, backend instance name, seed, ISO timestamp, and the
- Network errors during `Generate` — wrap and return; no retry policy yet
(decide per-adapter, or move to a shared retry helper if a pattern emerges).
## Async write path: `imagen worker` + `imagen.jobs`
`imagen generate` is the synchronous CLI. For web callers (flexsiebels'
owner-mode UI), `imagen worker` runs as a daemon that consumes the
`imagen.jobs` table.
```
flexsiebels POST imagen worker (mRiver, systemd)
→ INSERT INTO LISTEN imagen_jobs ◄── pg_notify trigger
imagen.jobs(pending) claim row (UPDATE … RETURNING)
dispatch through internal/backend
write disk + cloud-sync via internal/cloud
UPDATE imagen.jobs SET status='done', image_id=…
```
The queue table lives next to `imagen.images` in the same `imagen` schema.
Owner-scoped RLS lets the flexsiebels user INSERT + read their own rows;
the worker writes (status updates + image_id link) via service-role which
bypasses RLS. A 5-second safety poll fires on every wake-up to cover
dropped NOTIFY events and worker cold starts with a non-empty queue. See
`docs/setup-worker-mriver.md` for the systemd installation.
The worker reuses `internal/backend`, `internal/output`, and
`internal/cloud` unchanged — it is purely an orchestration layer around
the same pipeline `imagen generate` drives.
## Out of scope (today)
- Image post-processing (cropping, watermarking).
- Cost-tracking (lands with the Replicate adapter, since only API backends bill).
- Multi-image `n>1` per request — backends that support it can expose it via
`BackendOpts`; the framework doesn't have a first-class field yet.
- Job cancellation / kill switch — separate follow-up issue.
- Concurrent workers / multi-host scale-out — `FOR UPDATE SKIP LOCKED` in
the claim query makes it cheap to add, but a single worker is the v1 setup.

docs/backends.md

@@ -0,0 +1,310 @@
# ImaGen backends
This document covers the local-ComfyUI backend plug-in story: how adapters
are layered, how to add a new model without touching Go, and the per-model
setup steps for the bundled templates.
For the host-side ComfyUI install (mRock — venv, weights for the default
FLUX.1-schnell, systemd, VRAM coexistence with Ollama, smoke test against
the raw HTTP API), see [`setup-comfyui-mrock.md`](setup-comfyui-mrock.md).
## Architecture: Path 1 — workflow-template adapter
`imagen generate` and `imagen compare` dispatch through the `comfyui`
adapter, which holds the HTTP plumbing (`/prompt`, `/history/{id}`, `/view`,
`/system_stats`) and treats the workflow itself as data. Each backend
instance in `imagen.yaml` picks a workflow JSON via the `workflow:` key.
Adding a new model is yaml + JSON, never Go:
```
internal/backend/
comfyui.go # one adapter, all ComfyUI models
workflow_template.go # loader + token-substitution
workflows/
flux1-schnell.json # bundled templates (embedded with //go:embed)
flux2-klein.json
sd35-medium.json
```
### Why Path 1 over per-family adapters (`comfyui-flux.go`, `comfyui-sd3.go`…)
- **Workflow JSON is the natural exchange format**. ComfyUI users export
workflows from its GUI as JSON. Anything else means rebuilding the graph
by hand in Go for every new model.
- **Adding a model is a config change, not a build change**. With Path 2,
every new family is a Go file, a new test file, a registry entry, a new
worker binary, a redeploy. Path 1 lets us land a new model with one yaml
block + one JSON file + one section in this doc.
- **The HTTP plumbing is identical across families**. `/prompt`,
`/history`, `/view`, the retry policy, the "value not in list" hint, VRAM
reporting — none of it depends on the workflow shape. Path 2 would
duplicate that across files.
- **Failure isolation stays clean**. The workflow loader fails at adapter
construction (`imagen backends` surfaces the error), the HTTP layer
fails at `Generate`, and ComfyUI's own validation surfaces missing-model
hints. Each layer's error message points at the right config knob.
Path 2's argument was "each family owns its quirks (samplers, schedulers,
dual-stage etc.)". That argument doesn't survive contact with the
substitution-map design: per-family knobs are just key/value fields in the
yaml block and `${shift}`/`${guidance}`/`${cfg}` placeholders in the
template. No code duplication, no inheritance to debug.
### Token substitution
`workflow_template.SubstituteWorkflow` walks the parsed JSON and replaces
every whole-value string of the form `"${key}"` with the typed value from
the substitution map. Numbers stay numbers, strings stay strings — no
round-tripping through `strings.Replace`.
The substitution map is built per call from:
1. **Request fields** (always present): `${prompt}`, `${negative}`,
`${width}`, `${height}`, `${seed}`, `${steps}`, `${sampler}`,
`${scheduler}`, `${cfg}`.
2. **Every scalar field from the yaml block** (string / int / int64 /
float64 / bool), minus framework keys (`type`, `base_url`, `workflow`,
`default_*`). So `${vae}`, `${clip}`, `${clip_l}`, `${clip_t5}`,
`${dtype}`, `${shift}`, `${guidance}` all become substitutable just by
being in yaml.
3. **Sensible defaults** for the common optional knobs above, so a
workflow that references `${dtype}` without the user setting one in
yaml still substitutes cleanly (`fp8_e4m3fn` for FLUX, `3.0` for SD3
shift, etc.). Extra defaults are ignored by workflows that don't
reference them.
Partial matches (e.g. `"prefix ${prompt} suffix"`) are deliberately **not**
substituted — the placeholder must be the entire value so we can preserve
its JSON type. This prevents a prompt containing literal `${seed}` text
from corrupting the workflow.
Unknown placeholders (referenced in JSON but missing from the substitution
map) error out before the workflow leaves the binary.
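
As a rough illustration of the whole-value rule, the sketch below walks a decoded JSON tree and swaps in typed values. It is not the code in `workflow_template.go`: `substitute` and `placeholderKey` are hypothetical names, and the real `SubstituteWorkflow` may differ in signature and details.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// substitute walks a decoded JSON value and replaces every string that is
// exactly "${key}" with the typed entry from subs. Hypothetical sketch only.
func substitute(node any, subs map[string]any) (any, error) {
	switch v := node.(type) {
	case map[string]any:
		out := make(map[string]any, len(v))
		for k, child := range v {
			r, err := substitute(child, subs)
			if err != nil {
				return nil, err
			}
			out[k] = r
		}
		return out, nil
	case []any:
		out := make([]any, len(v))
		for i, child := range v {
			r, err := substitute(child, subs)
			if err != nil {
				return nil, err
			}
			out[i] = r
		}
		return out, nil
	case string:
		key, ok := placeholderKey(v)
		if !ok {
			return v, nil // plain string or partial match: left untouched
		}
		val, found := subs[key]
		if !found {
			return nil, fmt.Errorf("unknown placeholder ${%s}", key)
		}
		return val, nil // typed value: int64 stays a number, string stays a string
	default:
		return v, nil
	}
}

// placeholderKey returns the key if s is exactly "${key}" and nothing else.
func placeholderKey(s string) (string, bool) {
	if !strings.HasPrefix(s, "${") || !strings.HasSuffix(s, "}") {
		return "", false
	}
	key := s[2 : len(s)-1]
	if key == "" || strings.ContainsAny(key, "${} ") {
		return "", false
	}
	return key, true
}

func main() {
	raw := `{"31":{"inputs":{"text":"${prompt}","seed":"${seed}","note":"keep ${seed} literal"}}}`
	var wf any
	if err := json.Unmarshal([]byte(raw), &wf); err != nil {
		panic(err)
	}
	out, err := substitute(wf, map[string]any{"prompt": "a cat", "seed": int64(42)})
	if err != nil {
		panic(err)
	}
	b, _ := json.Marshal(out)
	fmt.Println(string(b))
}
```

Because the replacement happens on the decoded tree rather than on the raw text, the seed reaches ComfyUI as a JSON number, and the `note` string with an embedded `${seed}` passes through unchanged.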
### Back-compat
The `workflow:` field defaults to `flux1-schnell` if omitted. Existing
yaml blocks like the pre-#10 FLUX.1-schnell instance:
```yaml
flux-schnell-local:
type: comfyui
base_url: http://mrock:8188
model: flux1-schnell.safetensors
```
still work unchanged — they implicitly pick up the migrated
`flux1-schnell.json` template, which keeps the same node IDs (6, 8, 9, 10,
11, 12, 13, 27, 30, 31) as the historical hardcoded workflow.
## Bundled workflows
### FLUX.1-schnell — the back-compat default
| Field | Default | Notes |
|---|---|---|
| `model` | `flux1-schnell.safetensors` | drop in `models/unet/` |
| `vae` | `ae.safetensors` | `models/vae/` |
| `clip_l` | `clip_l.safetensors` | `models/clip/` |
| `clip_t5` | `t5xxl_fp8_e4m3fn.safetensors` | `models/clip/` |
| `dtype` | `fp8_e4m3fn` | weight dtype for the UNet loader |
| `default_steps` / `default_cfg` | 4 / 1.0 | schnell is distilled to ~4 steps |
VRAM peak ~10–12 GB at 1024×1024. Install path:
[`setup-comfyui-mrock.md`](setup-comfyui-mrock.md). Already shipping.
### FLUX.2 [klein] 4B — direct upgrade
Released by Black Forest Labs late 2025 / early 2026, BFL non-commercial
license. The distilled 4B "klein" variant lands sub-second on the RTX
4070 Ti SUPER and shares the new Qwen-based text encoder + a re-trained
VAE with the larger family.
```yaml
flux2-klein-local:
type: comfyui
base_url: http://mrock:8188
workflow: flux2-klein
model: flux-2-klein-base-4b-fp8.safetensors # models/unet/
vae: flux2-vae.safetensors # models/vae/
clip: qwen_3_4b.safetensors # models/text_encoders/
dtype: fp8_e4m3fn
default_steps: 4
default_cfg: 1.0
guidance: 4.0
```
**Model downloads** (on mRock, ungated mirrors when available):
```bash
cd ~/dev/comfyui/models
curl -L -o unet/flux-2-klein-base-4b-fp8.safetensors \
https://huggingface.co/black-forest-labs/FLUX.2-klein/resolve/main/flux-2-klein-base-4b-fp8.safetensors
curl -L -o vae/flux2-vae.safetensors \
https://huggingface.co/black-forest-labs/FLUX.2-klein/resolve/main/flux2-vae.safetensors
mkdir -p text_encoders
curl -L -o text_encoders/qwen_3_4b.safetensors \
https://huggingface.co/black-forest-labs/FLUX.2-klein/resolve/main/qwen_3_4b.safetensors
```
BFL's primary repo is gated; if `curl` returns 401, pass a Hugging Face
token explicitly (for example `-H "Authorization: Bearer $HF_TOKEN"`; `curl`
does not read `~/.cache/huggingface/token` on its own) or use one of the
community mirrors (check the official model card for the current list).
The filenames the template references match BFL's canonical names, so
rename downloads to match if a mirror uses different ones.
VRAM peak: ~8.5 GB (4B fp8). With Ollama parked at ~8 GB this still fits;
unlike FLUX.1-schnell, klein doesn't require stopping Ollama on mRock.
### SD3.5-medium — single-checkpoint variant
Stability AI's 2.5B mid-size model with bundled text encoders. The
`incl_clips_t5xxlfp8scaled` variant ships clip_g + clip_l + t5xxl_fp8 all
in one `.safetensors`, so the workflow uses `CheckpointLoaderSimple`
instead of separate UNet/VAE/CLIP loaders.
```yaml
sd35-medium-local:
type: comfyui
base_url: http://mrock:8188
workflow: sd35-medium
model: sd3.5_medium_incl_clips_t5xxlfp8scaled.safetensors # models/checkpoints/
default_steps: 28
default_sampler: dpmpp_2m
default_scheduler: sgm_uniform
default_cfg: 4.5
shift: 3.0
```
**Model download** (on mRock):
```bash
cd ~/dev/comfyui/models
curl -L -o checkpoints/sd3.5_medium_incl_clips_t5xxlfp8scaled.safetensors \
https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/resolve/main/sd3.5_medium_incl_clips_t5xxlfp8scaled.safetensors
```
VRAM peak: ~9.9 GB at 1024×1024. Same envelope as FLUX.1-schnell — stop
Ollama before generating, restart after.
## Adding a new bundled workflow
1. **Export from ComfyUI**: load the model in the ComfyUI GUI, build a
text-to-image workflow that produces what you want, "Save (API
Format)" — the file you get is the right shape.
2. **Sprinkle placeholders**: open the JSON and replace per-call values
with `${name}` tokens. Whole-value substitution only:
```json
"inputs": {
"text": "${prompt}", // was "a cat sitting on a chair"
"seed": "${seed}", // was 1234567
"steps": "${steps}", // was 28
"cfg": "${cfg}",
"sampler_name": "${sampler}",
"scheduler": "${scheduler}",
"width": "${width}",
"height": "${height}"
}
```
Use `${model}` for the checkpoint / unet filename and any per-template
knobs (`${vae}`, `${shift}`, `${guidance}`, `${clip}` …).
3. **Drop it into `internal/backend/workflows/<name>.json`**. The
`//go:embed workflows/*.json` directive in `workflow_template.go`
picks it up at build time — no registry entry needed.
4. **Add a yaml instance** in `internal/config/config.go`'s `Sample` block
for `imagen config init` (and `~/.config/imagen.yaml`) so users
discover the new backend.
5. **Document the model files + HF download URLs** in this doc.
6. **Smoke test**: `imagen generate "test" --backend <new-instance>
--size 1024x1024` should produce an image.
Per-call overrides: steps and seed go via `--steps` and `--seed`; sampler,
scheduler, and cfg are set programmatically via
`backend.Request.BackendOpts["sampler"]` / `["scheduler"]` / `["cfg"]`.
The compare harness forwards the constant-across-backends knobs verbatim.
## Loading a workflow from disk (one-off)
Pass an absolute filesystem path as `workflow:` and the adapter reads it
from disk instead of the embedded FS. Handy for prototyping a new model
before committing it:
```yaml
my-experimental:
type: comfyui
base_url: http://mrock:8188
workflow: /home/m/dev/comfyui/workflows/my-test.json
model: my-test-model.safetensors
```
The fallback chain is: filesystem path (if the string looks like a path
or ends in `.json`), then bundled lookup by name, then bundled lookup
with `.json` appended.
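
The lookup order above could be sketched like this. It is an assumption-laden illustration, not the real `LoadWorkflowTemplate`: `loadTemplate`, the "looks like a path" heuristic, and the in-memory `demoFS` (standing in for the embedded FS) are all hypothetical.

```go
package main

import (
	"fmt"
	"io/fs"
	"os"
	"strings"
	"testing/fstest"
)

// demoFS stands in for the embedded workflows directory in this sketch.
var demoFS = fstest.MapFS{
	"workflows/flux1-schnell.json": {Data: []byte(`{}`)},
}

// loadTemplate tries the filesystem first when ref looks like a path, then
// the bundled templates by name, then by name with ".json" appended.
// It returns the bytes plus a label saying where they came from.
func loadTemplate(bundled fs.FS, ref string) ([]byte, string, error) {
	// Assumed heuristic: a path separator or a .json suffix means "path".
	if strings.ContainsAny(ref, `/\`) || strings.HasSuffix(ref, ".json") {
		if b, err := os.ReadFile(ref); err == nil {
			return b, "disk:" + ref, nil
		}
	}
	for _, name := range []string{ref, ref + ".json"} {
		if b, err := fs.ReadFile(bundled, "workflows/"+name); err == nil {
			return b, "bundled:" + name, nil
		}
	}
	return nil, "", fmt.Errorf("workflow %q not found on disk or in bundled templates", ref)
}

func main() {
	for _, ref := range []string{"flux1-schnell", "/no/such/file.json"} {
		_, src, err := loadTemplate(demoFS, ref)
		fmt.Printf("%s -> src=%q err=%v\n", ref, src, err)
	}
}
```

Trying the disk first keeps one-off experiments from being shadowed by a bundled template that happens to share a name.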
## `imagen compare`: cross-backend evaluation
```bash
imagen compare "a wizard casting a spell" \
--models flux-schnell-local,flux2-klein-local,sd35-medium-local \
--size 1024x1024 \
--output ~/Pictures/imagen/compare
```
Per run, `compare`:
- creates `<output>/<YYYYMMDD-HHMMSS>-<prompt-slug>/`
- dispatches each named backend sequentially (mRock has one GPU; parallel
would OOM) — one backend's failure doesn't abort the run
- writes per-backend PNGs as `<prompt-slug>--<backend-slug>.png`
- writes `compare.json` listing every attempt (success + failure) with
per-model `seed`, `latency_ms`, `model`, `vram_used_mib`, full
`metadata` map, and the error string for any failure
- composites a `contact-sheet.png` with the prompt as header and each
cell labelled `<backend>` / `<latency>ms · seed <n>`
Flags mirror `generate`: `--seed`, `--steps`, `--style`, `--negative`,
`--size` are shared across all backends. `--no-contact-sheet` skips the
composite when only the per-image PNGs and sidecar matter (e.g. for a
worker script that builds its own diff view).
## Diagnostics
`imagen backends` shows every instance with its registration state. For
local ComfyUI, the status is currently just `registered` (we don't probe
the upstream HTTP endpoint at startup — the boot-helper hint kicks in on
first generation if mRock is asleep).
Per-backend errors fall into three kinds:
1. **Adapter construction failure** (e.g. workflow JSON not found,
missing required yaml field). Caught at `buildBackend` time:
`imagen: backend "<name>": <err>`.
2. **HTTP / runtime failure during Generate**. Wrapped with the boot
helper for `connection refused`/`no such host`/timeouts pointing at
`boot-whitetower mrock` so a sleeping mRock has an obvious next step.
3. **ComfyUI workflow-validation failure** (200-with-node_errors or 400).
Surfaces with a model-not-found hint (matching `value_not_in_list` +
`unet_name`/`ckpt_name`) when applicable, pointing back at this doc.
## Worker daemon notes
`imagen worker` (the `imagen.jobs` queue consumer) uses the same adapter
+ workflow lookup as the synchronous CLI — flexsiebels' `/imagine` UI
INSERTs a `backend = <instance>` row, the worker claims it, and the
underlying ComfyUI HTTP calls are identical to what `generate` makes. No
worker-specific changes are required when a new backend lands; the
config + workflow are the only state that has to be present on the
worker host.
After merging a new template or yaml block:
```bash
# On the worker host (mRiver today):
systemctl --user restart imagen-worker
```
The daemon-rebuild trap from issue #9 still applies: if you build the
imagen binary on the dev machine and `scp` it over, restart the unit so
systemd picks up the new ELF.


@@ -0,0 +1,97 @@
# `imagen worker` on mRiver
The worker is a long-running daemon that consumes the `imagen.jobs` queue
(written by flexsiebels' owner-mode UI) and writes the resulting image to
Supabase Storage + `imagen.images` via the same cloud-sync path the CLI
`imagen generate` uses.
## Architecture
```
flexsiebels (owner UI)
|
v INSERT INTO imagen.jobs (...)
|
msupabase Postgres
|
| AFTER INSERT trigger:
| pg_notify('imagen_jobs', NEW.id)
v
imagen worker (mRiver) ── LISTEN imagen_jobs
|
| 1. claim oldest 'pending' row (status='running')
| 2. dispatch to backend (FLUX schnell local / FLUX dev replicate / …)
| 3. write PNG to disk
| 4. upload to Storage + INSERT into imagen.images
| 5. UPDATE imagen.jobs SET status='done', image_id=...
v
flexsiebels polls GET .../jobs/<id> → renders the finished card
```
A 5-second safety poll covers dropped NOTIFY events and worker cold starts
with a non-empty queue.
## One-time setup
```bash
# 1. Build the binary (or `task build`).
cd ~/dev/ImaGen
go build -o bin/imagen ./cmd/imagen
# 2. Write the environment file.
cp scripts/imagen-worker.env.example ~/.dotfiles/.env.imagen-worker
chmod 600 ~/.dotfiles/.env.imagen-worker
$EDITOR ~/.dotfiles/.env.imagen-worker # fill in real DSN, service key
# 3. Install the user systemd unit.
mkdir -p ~/.config/systemd/user
cp scripts/imagen-worker.service ~/.config/systemd/user/imagen-worker.service
systemctl --user daemon-reload
systemctl --user enable --now imagen-worker.service
# 4. Tail the logs.
journalctl --user -u imagen-worker -f
```
## Required env vars
See `scripts/imagen-worker.env.example` for the canonical list. Required:
- `IMAGEN_WORKER_DATABASE_URL` — direct Postgres DSN. PostgREST cannot LISTEN.
- `SUPABASE_URL`, `SUPABASE_SERVICE_KEY` — same pair `imagen generate`
reads for the cloud-sync writer.
- `IMAGEN_OWNER_USER_ID` — fallback owner UUID; per-job row's
`owner_user_id` overrides this.
Optional, depending on enabled backends:
- `REPLICATE_API_TOKEN` if any job will request a Replicate-typed backend.
## Operating
```bash
systemctl --user status imagen-worker # health
systemctl --user restart imagen-worker # pick up a new binary
journalctl --user -u imagen-worker -n 200 # recent log lines
```
On startup the worker calls `ResetStaleRunning` once, flipping any rows
left in `'running'` from a previous crash back to `'pending'` so they get
re-claimed by the 5-second poll.
## Smoke test
With the worker running, INSERT a test job:
```sql
INSERT INTO imagen.jobs (owner_user_id, prompt, backend, width, height)
VALUES (
'ac6c9501-3757-4a6d-8b97-2cff4288382b',
'a tiny owl wearing wire-rim glasses, photo',
'flux-schnell-local', 1024, 1024
);
```
Within ~10 seconds the row should show `status='done'`, a populated
`image_id` linking to a real `imagen.images` row, and a Storage object at
`<YYYY-MM-DD>/<slug>-<seed>.png` in the `imagen-generated` bucket.


@@ -4,14 +4,21 @@
```
imagen generate <prompt> [flags] generate one image
imagen compare <prompt> --models a,b,c [flags]
run one prompt across N backends + contact sheet
imagen worker [flags] consume the imagen.jobs queue (daemon)
imagen backends list configured + registered backends
imagen config init print a sample imagen.yaml on stdout
imagen config validate parse + validate the active config
imagen config path print the resolved config path
imagen serve [--addr :8080] (stub) start the HTTP server
imagen usage [--since DATE] show cost-tracking rows
imagen version print version
```
For the per-backend setup (FLUX.1, FLUX.2 [klein], SD3.5 medium, …) and
the architecture rationale, see [`backends.md`](backends.md).
## `generate` flags
| Flag | Default | Notes |
@@ -24,8 +31,29 @@ imagen version print version
| `--negative` | empty | Negative prompt (ignored by some adapters) |
| `--output` | empty (= use naming template) | Explicit path |
| `--no-sidecar` | `false` | Skip the JSON sidecar even if config enables it |
| `--preview` | (auto) | Force open a tmux preview window via `tmux-img` |
| `--no-preview` | (auto) | Suppress the preview window (use for batch / CI callers) |
| `--no-cloud` | `false` | Skip Supabase upload + `imagen.images` insert for this call |
| `--config` | `~/.config/imagen.yaml` | Override config path |
### Preview window
After a successful generate, imagen optionally opens a sibling tmux window
named `img:<slug>` running `tmux-img --hold <path>`. The new window is
spawned in the background (`tmux new-window -d`) so the generating pane
keeps focus and its terminal output.
Resolution order is **config → `$IMAGEN_PREVIEW` → flag** (later wins):
- `output.preview` in `imagen.yaml`: `auto` (default) | `on` | `off`
- `IMAGEN_PREVIEW=auto|on|off` overrides config
- `--preview` / `--no-preview` override env
`auto` previews iff stdout is a TTY *and* `$TMUX` is set. `on` previews
unconditionally and errors outside a tmux session. `off` never previews.
Preview failures are non-fatal: the image has already been written.
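
The precedence and the `auto` rule can be illustrated with a small sketch. The function names are hypothetical, and representing the flag as a string (`"on"` for `--preview`, `"off"` for `--no-preview`, empty when neither was passed) is an assumption made for brevity, not necessarily how the CLI models it.

```go
package main

import (
	"errors"
	"fmt"
)

// resolveMode applies config, then env, then flag; later non-empty values win.
// An empty string means "not set at this level".
func resolveMode(config, env, flag string) string {
	mode := "auto" // documented default
	for _, v := range []string{config, env, flag} {
		if v != "" {
			mode = v
		}
	}
	return mode
}

// shouldPreview turns the resolved mode into a decision.
func shouldPreview(mode string, stdoutIsTTY, inTmux bool) (bool, error) {
	switch mode {
	case "on":
		if !inTmux {
			return false, errors.New("preview forced on, but not inside a tmux session")
		}
		return true, nil
	case "off":
		return false, nil
	case "auto":
		return stdoutIsTTY && inTmux, nil
	default:
		return false, fmt.Errorf("unknown preview mode %q", mode)
	}
}

func main() {
	// Config says off, env says on, no flag: env wins.
	mode := resolveMode("off", "on", "")
	ok, err := shouldPreview(mode, true, true)
	fmt.Println(mode, ok, err)

	// --no-preview on the command line beats both.
	fmt.Println(resolveMode("off", "on", "off"))
}
```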
## Examples
```sh
@@ -71,3 +99,53 @@ API-backed adapters read tokens from env vars referenced by the config
export REPLICATE_API_TOKEN=...
imagen generate "a cat" --backend flux-dev-replicate
```
## Cost-tracking (Replicate)
Successful generations through the Replicate adapter write one row to
`mai.imagen_usage` on Supabase: backend, model, latency, per-image cost
estimate, prompt sha256 hash (never the prompt itself), and the caller
identity (resolved from `MAI_FROM_ID` or the tmux pane's `@mai-name`).
The writer is best-effort. If `SUPABASE_URL` / `SUPABASE_SERVICE_KEY` are
unset, or the database write fails, the image still lands and the CLI
prints a warning to stderr.
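
Storing a hash instead of the prompt might look like the sketch below. `promptHash` is a hypothetical helper; hashing the raw UTF-8 bytes and hex-encoding the digest are assumptions, since the text above only states that a sha256 hash is stored.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// promptHash returns the hex-encoded SHA-256 digest of the prompt text, so
// identical prompts can be grouped without the text itself being stored.
func promptHash(prompt string) string {
	sum := sha256.Sum256([]byte(prompt))
	return hex.EncodeToString(sum[:])
}

func main() {
	fmt.Println(promptHash("a tiny owl wearing wire-rim glasses, photo"))
}
```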
Inspect spend:
```sh
imagen usage # all rows, grouped by week + backend + model + caller
imagen usage --since 2026-05-01 # only rows on/after a UTC date
imagen usage --since 2026-05-01 --raw
```
Per-model rates live in `internal/backend/replicate_pricing.go` — they
are snapshotted from <https://replicate.com/pricing> and refreshed on a
quarterly cadence.
## Cloud-sync (Supabase)
Successful generations also upload the PNG to the private Supabase
Storage bucket `imagen-generated` (path: `<YYYY-MM-DD>/<slug>-<seed>.png`)
and insert a row into `imagen.images`. The row carries the prompt,
sha256-hashed prompt, backend, model, seed/steps/width/height, latency,
cost estimate, the full local sidecar JSON, and an empty `tags` array
ready for the flexsiebels viewer to fill in.
Configuration:
- `owner_user_id` in `imagen.yaml` — m's `auth.users.id`. Empty disables
inserts (the column is `NOT NULL`).
- `output.cloud_sync` in `imagen.yaml`: `auto` (default — on iff
SUPABASE creds + `owner_user_id` are set), `on` (errors if either is
missing), `off`.
- `IMAGEN_CLOUD_SYNC=auto|on|off` overrides config.
- `--no-cloud` overrides everything for one call.
Reuses the same Supabase env (`SUPABASE_URL` + `SUPABASE_SERVICE_KEY` or
`MAI_SUPABASE_KEY`) as cost-tracking. Service-role bypasses RLS for
inserts; the `owner_user_id = auth.uid()` policy on the table gates the
read path the flexsiebels viewer hits.
Failures (Storage 5xx, DB unreachable) emit `imagen: cloud sync: <err>`
to stderr and the local PNG + sidecar stay put. Exit code is unchanged.

go.mod

@@ -1,5 +1,17 @@
module mgit.msbls.de/m/ImaGen
go 1.24
go 1.25.0
require gopkg.in/yaml.v3 v3.0.1
require (
github.com/jackc/pgx/v5 v5.9.2
golang.org/x/image v0.40.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
golang.org/x/text v0.37.0 // indirect
)

go.sum

@@ -1,4 +1,37 @@
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/image v0.40.0 h1:Tw4GyDXMo+daZN1znreBRC3VayR1aLFUyUEOLUdW1a8=
golang.org/x/image v0.40.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@@ -20,24 +20,29 @@ import (
const ComfyType = "comfyui"
// Comfy is the ComfyUI adapter. It speaks the public `/prompt` + `/history`
// + `/view` HTTP API and submits a fixed FLUX.1 schnell workflow built from
// the values in Request.
// + `/view` HTTP API and submits a workflow built by substituting Request
// values into a JSON template (bundled under internal/backend/workflows/ or
// loaded from a filesystem path).
//
// Concurrency: a single Comfy is safe to share across goroutines as long as
// the underlying http.Client is. Generate does not hold long-lived state.
type Comfy struct {
instance string
base string
model string
vae string
clipL string
clipT5 string
dtype string
base string
workflow string
// rawCfg keeps the original yaml block (minus framework keys) so we can
// expose every user-defined string/number as a workflow substitution
// without enumerating each per-model knob in Go. Empty values still get
// a substitution entry so a template can reference ${negative} when the
// request didn't pass one.
rawCfg map[string]any
defaultSteps int
defaultSampler string
defaultScheduler string
defaultCFG float64
httpClient *http.Client
pollInterval time.Duration
@@ -49,12 +54,20 @@ type Comfy struct {
}
// NewComfy is the registry constructor. cfg is the adapter's slice of
// imagen.yaml. Required keys: base_url, model. The rest have sensible FLUX
// schnell defaults.
// imagen.yaml.
//
// Required keys: base_url, model.
// Optional keys: workflow (defaults to "flux1-schnell" for back-compat with
// existing configs), default_steps, default_sampler, default_scheduler,
// default_cfg, plus any template-specific knobs (vae, clip, clip_l,
// clip_t5, dtype, shift, guidance, …) the chosen workflow references.
func NewComfy(name string, cfg map[string]any) (Backend, error) {
if name == "" {
return nil, fmt.Errorf("comfyui: empty instance name")
}
if cfg == nil {
cfg = map[string]any{}
}
base := strings.TrimRight(getString(cfg, "base_url", ""), "/")
if base == "" {
return nil, fmt.Errorf("comfyui[%s]: base_url is required", name)
@@ -67,23 +80,27 @@ func NewComfy(name string, cfg map[string]any) (Backend, error) {
return nil, fmt.Errorf("comfyui[%s]: model is required", name)
}
workflow := getString(cfg, "workflow", "flux1-schnell")
// Fail fast on a bad workflow ref so users see the error at startup,
// not on first /prompt submission.
if _, err := LoadWorkflowTemplate(workflow); err != nil {
return nil, fmt.Errorf("comfyui[%s]: %w", name, err)
}
c := &Comfy{
instance: name,
base: base,
model: model,
vae: getString(cfg, "vae", "ae.safetensors"),
clipL: getString(cfg, "clip_l", "clip_l.safetensors"),
clipT5: getString(cfg, "clip_t5", "t5xxl_fp8_e4m3fn.safetensors"),
dtype: getString(cfg, "weight_dtype", "fp8_e4m3fn"),
workflow: workflow,
rawCfg: cfg,
defaultSteps: getInt(cfg, "default_steps", 4),
defaultSampler: getString(cfg, "default_sampler", "euler"),
defaultScheduler: getString(cfg, "default_scheduler", "simple"),
defaultCFG: getFloat(cfg, "default_cfg", 1.0),
httpClient: &http.Client{Timeout: 60 * time.Second},
pollInterval: 250 * time.Millisecond,
pollTimeout: 120 * time.Second,
pollTimeout: 300 * time.Second,
randSeed: cryptoSeed,
clientIDFn: randClientID,
@@ -103,19 +120,26 @@ func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) {
sampler := c.defaultSampler
scheduler := c.defaultScheduler
cfg := c.defaultCFG
if v, ok := req.BackendOpts["sampler"].(string); ok && v != "" {
sampler = v
}
if v, ok := req.BackendOpts["scheduler"].(string); ok && v != "" {
scheduler = v
}
if v, ok := req.BackendOpts["cfg"].(float64); ok && v > 0 {
cfg = v
}
seed := req.Seed
if seed == 0 {
seed = c.randSeed()
}
workflow := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler)
workflow, err := c.buildWorkflow(req.Prompt, req.NegativePrompt, width, height, seed, steps, sampler, scheduler, cfg)
if err != nil {
return nil, fmt.Errorf("comfyui[%s]: build workflow: %w", c.instance, err)
}
clientID := c.clientIDFn()
start := time.Now()
@@ -133,14 +157,17 @@ func (c *Comfy) Generate(ctx context.Context, req Request) (*Result, error) {
}
latencyMs := time.Since(start).Milliseconds()
model := getString(c.rawCfg, "model", "")
meta := map[string]any{
"backend": c.instance,
"backend_type": ComfyType,
"model": c.model,
"workflow": c.workflow,
"model": model,
"seed": seed,
"steps": steps,
"sampler": sampler,
"scheduler": scheduler,
"cfg": cfg,
"width": width,
"height": height,
"latency_ms": latencyMs,
@@ -173,6 +200,7 @@ func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clien
return "", fmt.Errorf("comfyui: marshal workflow: %w", err)
}
model := getString(c.rawCfg, "model", "")
var lastErr error
for attempt := range 2 {
if attempt > 0 {
@@ -196,7 +224,7 @@ func (c *Comfy) submitPrompt(ctx context.Context, workflow map[string]any, clien
_ = resp.Body.Close()
switch {
case resp.StatusCode >= 200 && resp.StatusCode < 300:
return parsePromptID(respBody, c.model)
return parsePromptID(respBody, model)
case resp.StatusCode >= 500:
lastErr = fmt.Errorf("comfyui /prompt %d: %s", resp.StatusCode, snip(respBody))
continue
@@ -333,98 +361,74 @@ func (c *Comfy) connError(err error) error {
// workflow-validation failures and put the diagnostics in node_errors; older
// builds use 200 + node_errors. This handles the 4xx flavour.
func (c *Comfy) classifyBadRequest(status int, body []byte) error {
if hint, ok := missingModelHint(body, c.model); ok {
return fmt.Errorf("comfyui /prompt %d: %s — see docs/setup-comfyui-mrock.md", status, hint)
model := getString(c.rawCfg, "model", "")
if hint, ok := missingModelHint(body, model); ok {
return fmt.Errorf("comfyui /prompt %d: %s — see docs/backends.md", status, hint)
}
return fmt.Errorf("comfyui /prompt %d: %s", status, snip(body))
}
// buildWorkflow assembles the canonical FLUX.1 schnell ComfyUI workflow,
// node-IDs matching the upstream "flux-schnell" template so anyone debugging
// in the ComfyUI UI sees a familiar shape.
func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string) map[string]any {
return map[string]any{
"6": map[string]any{
"class_type": "CLIPTextEncode",
"inputs": map[string]any{
"text": prompt,
"clip": []any{"11", 0},
},
},
"8": map[string]any{
"class_type": "VAEDecode",
"inputs": map[string]any{
"samples": []any{"31", 0},
"vae": []any{"10", 0},
},
},
"9": map[string]any{
"class_type": "SaveImage",
"inputs": map[string]any{
"filename_prefix": "imagen",
"images": []any{"8", 0},
},
},
"10": map[string]any{
"class_type": "VAELoader",
"inputs": map[string]any{"vae_name": c.vae},
},
"11": map[string]any{
"class_type": "DualCLIPLoader",
"inputs": map[string]any{
"clip_name1": c.clipT5,
"clip_name2": c.clipL,
"type": "flux",
},
},
"12": map[string]any{
"class_type": "UNETLoader",
"inputs": map[string]any{
"unet_name": c.model,
"weight_dtype": c.dtype,
},
},
"13": map[string]any{
"class_type": "CLIPTextEncode",
"inputs": map[string]any{
"text": negative,
"clip": []any{"11", 0},
},
},
"27": map[string]any{
"class_type": "EmptySD3LatentImage",
"inputs": map[string]any{
"width": w,
"height": h,
"batch_size": 1,
},
},
"30": map[string]any{
"class_type": "ModelSamplingFlux",
"inputs": map[string]any{
"model": []any{"12", 0},
"max_shift": 1.15,
"base_shift": 0.5,
"width": w,
"height": h,
},
},
"31": map[string]any{
"class_type": "KSampler",
"inputs": map[string]any{
"model": []any{"30", 0},
"seed": seed,
"steps": steps,
"cfg": 1.0,
"sampler_name": sampler,
"scheduler": scheduler,
"denoise": 1.0,
"positive": []any{"6", 0},
"negative": []any{"13", 0},
"latent_image": []any{"27", 0},
},
},
// buildWorkflow loads the configured workflow template and substitutes the
// per-call placeholders (prompt, seed, sampler, …) plus any scalar fields
// the user defined in the yaml block. If the template references a
// placeholder key that is missing from `subs`, SubstituteWorkflow returns
// an error.
func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, steps int, sampler, scheduler string, cfg float64) (map[string]any, error) {
wf, err := LoadWorkflowTemplate(c.workflow)
if err != nil {
return nil, err
}
subs := map[string]any{
"prompt": prompt,
"negative": negative,
"width": w,
"height": h,
"seed": seed,
"steps": steps,
"sampler": sampler,
"scheduler": scheduler,
"cfg": cfg,
}
// Surface every scalar field from the yaml block so per-template knobs
// (vae, clip, clip_l, clip_t5, dtype, shift, guidance, …) work without
// adapter-code changes. Framework keys are excluded.
for k, v := range c.rawCfg {
switch k {
case "type", "base_url", "workflow",
"default_steps", "default_sampler", "default_scheduler", "default_cfg":
continue
}
if _, alreadySet := subs[k]; alreadySet {
// A per-call var (e.g. ${prompt}) beats anything yaml put under
// the same key — yaml can't shadow request-derived values.
continue
}
switch v := v.(type) {
case string, int, int64, float64, bool:
subs[k] = v
}
}
// Provide sensible defaults for common optional knobs so a workflow that
// references one of these doesn't fail substitution when the user
// didn't override it in yaml. Extra keys are ignored if the workflow
// doesn't reference them, so it's safe to always set the lot.
defaults := map[string]any{
"vae": "ae.safetensors",
"clip_l": "clip_l.safetensors",
"clip_t5": "t5xxl_fp8_e4m3fn.safetensors",
"clip": "qwen_3_4b.safetensors",
"dtype": "fp8_e4m3fn",
"guidance": 4.0,
"shift": 3.0,
}
for k, v := range defaults {
if _, ok := subs[k]; !ok {
subs[k] = v
}
}
if _, err := SubstituteWorkflow(wf, subs); err != nil {
return nil, err
}
return wf, nil
}
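The precedence that buildWorkflow applies (per-call values first, then scalar yaml fields, then built-in defaults) can be sketched in isolation. This is a minimal illustration, not the adapter's code: `mergeSubs` and its `framework` parameter are hypothetical names, and the sample values are made up.

```go
package main

import "fmt"

// mergeSubs is a hypothetical helper (not part of the diff). It layers three
// sources of substitution values with the precedence
// per-call > yaml > defaults, skipping framework keys and non-scalar values.
func mergeSubs(perCall, yamlCfg, defaults map[string]any, framework map[string]bool) map[string]any {
	subs := make(map[string]any, len(perCall)+len(yamlCfg)+len(defaults))
	for k, v := range perCall {
		subs[k] = v
	}
	for k, v := range yamlCfg {
		if framework[k] {
			continue // adapter settings such as base_url are not template knobs
		}
		if _, set := subs[k]; set {
			continue // request-derived values cannot be shadowed by yaml
		}
		switch v.(type) {
		case string, int, int64, float64, bool:
			subs[k] = v // only scalars are forwarded to the template
		}
	}
	for k, v := range defaults {
		if _, set := subs[k]; !set {
			subs[k] = v // fill in only what nobody else provided
		}
	}
	return subs
}

func main() {
	subs := mergeSubs(
		map[string]any{"prompt": "a tiny dragon", "cfg": 1.0},
		map[string]any{
			"prompt":   "ignored",
			"vae":      "custom.safetensors",
			"base_url": "http://localhost:8188",
			"extra":    []any{1},
		},
		map[string]any{"vae": "ae.safetensors", "shift": 3.0},
		map[string]bool{"base_url": true},
	)
	fmt.Println(subs["prompt"], subs["vae"], subs["shift"])
}
```

Because defaults are applied last and only for missing keys, a yaml override always wins over a default, while a yaml key named `prompt` never replaces the request's prompt.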
// parsePromptID handles the 2xx /prompt response. ComfyUI sometimes 200s a
@@ -432,8 +436,8 @@ func (c *Comfy) buildWorkflow(prompt, negative string, w, h int, seed int64, ste
// turns that into the same user-facing error as a 4xx with the same body.
func parsePromptID(body []byte, model string) (string, error) {
var resp struct {
PromptID string `json:"prompt_id"`
NodeErrors map[string]any `json:"node_errors"`
PromptID string `json:"prompt_id"`
NodeErrors map[string]any `json:"node_errors"`
Error json.RawMessage `json:"error"`
}
if err := json.Unmarshal(body, &resp); err != nil {
@@ -441,7 +445,7 @@ func parsePromptID(body []byte, model string) (string, error) {
}
if len(resp.NodeErrors) > 0 || len(resp.Error) > 0 {
if hint, ok := missingModelHint(body, model); ok {
return "", fmt.Errorf("comfyui /prompt: %s — see docs/setup-comfyui-mrock.md", hint)
return "", fmt.Errorf("comfyui /prompt: %s — see docs/backends.md", hint)
}
return "", fmt.Errorf("comfyui /prompt rejected workflow: %s", snip(body))
}
@@ -489,15 +493,21 @@ func parseHistory(body []byte, promptID string) (string, bool, error) {
}
// missingModelHint returns a user-actionable message when the response body
// indicates the configured unet model isn't loaded on the server. ComfyUI
// uses both the human-readable "Value not in list" message and the enum
// "value_not_in_list" type — match either.
// indicates the configured unet/checkpoint model isn't loaded on the server.
// ComfyUI uses both the human-readable "Value not in list" message and the
// enum "value_not_in_list" type — match either.
func missingModelHint(body []byte, model string) (string, bool) {
s := string(body)
hasMarker := strings.Contains(s, "Value not in list") || strings.Contains(s, "value_not_in_list")
if hasMarker && strings.Contains(s, "unet_name") {
if !hasMarker {
return "", false
}
if strings.Contains(s, "unet_name") {
return fmt.Sprintf("model %q not present in the ComfyUI server's models/unet/", model), true
}
if strings.Contains(s, "ckpt_name") {
return fmt.Sprintf("checkpoint %q not present in the ComfyUI server's models/checkpoints/", model), true
}
return "", false
}
@@ -536,6 +546,22 @@ func getInt(m map[string]any, k string, def int) int {
return def
}
func getFloat(m map[string]any, k string, def float64) float64 {
if v, ok := m[k]; ok {
switch n := v.(type) {
case float64:
return n
case float32:
return float64(n)
case int:
return float64(n)
case int64:
return float64(n)
}
}
return def
}
func orDefaultInt(v, def int) int {
if v == 0 {
return def


@@ -312,7 +312,7 @@ func TestComfyMissingModelHintsAtSetupDoc(t *testing.T) {
t.Fatal("expected error")
}
msg := err.Error()
if !strings.Contains(msg, "docs/setup-comfyui-mrock.md") {
if !strings.Contains(msg, "docs/backends.md") {
t.Errorf("error should point at the setup doc, got %v", err)
}
if !strings.Contains(msg, "flux1-schnell.safetensors") {
@@ -331,7 +331,7 @@ func TestComfyMissingModelOn200WithNodeErrors(t *testing.T) {
if err == nil {
t.Fatal("expected error for node_errors on 200")
}
if !strings.Contains(err.Error(), "docs/setup-comfyui-mrock.md") {
if !strings.Contains(err.Error(), "docs/backends.md") {
t.Errorf("error should point at the setup doc, got %v", err)
}
}


@@ -0,0 +1,567 @@
package backend
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"time"
)
// ReplicateType is the type name under which the Replicate adapter is
// registered.
const ReplicateType = "replicate"
// Replicate is the Replicate API adapter. It speaks the public REST API
// — POST /v1/predictions or POST /v1/models/{owner}/{name}/predictions
// to submit, then polls /v1/predictions/{id} until the prediction
// settles, then downloads the produced image.
//
// Concurrency: a single Replicate is safe to share across goroutines.
type Replicate struct {
instance string
apiBase string
apiToken string
tokenEnv string
model string // "owner/name" or "owner/name:version-hash"
owner string
name string
version string // optional; empty means "use the model-based predictions endpoint"
defaultSteps int
defaultAspect string
httpClient *http.Client
pollInterval time.Duration
pollTimeout time.Duration
// Hooks for tests; production paths use the package-level defaults.
randSeed func() int64
initialBackoff time.Duration
// Sink is where successful generations are recorded for cost-tracking.
// nil means "do not record". The framework wires this up in the CLI;
// adapter tests inject a fake.
Sink UsageSink
}
// UsageSink writes one row per successful generation. Implementations
// should treat write failures as warnings, not errors — the image has
// already landed on disk; failing the call would lose the artefact.
type UsageSink interface {
Record(ctx context.Context, row UsageRow) error
}
// UsageRow is the cost-tracking row stored in mai.imagen_usage. Note the
// prompt itself is intentionally NOT included — only the sha256 hash.
type UsageRow struct {
Backend string
Model string
Seed *int64
PromptHash string
LatencyMs int
CostUSDEstimate *float64
Caller string
}
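For tests or local runs, a UsageSink can be as small as a mutex-guarded slice. The sketch below is an assumption about how such a sink might look: `memorySink`, `Len`, and `demoSink` are hypothetical names, and the types are re-declared only to keep the example self-contained.

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// UsageRow and UsageSink are re-declared here so the sketch compiles alone.
type UsageRow struct {
	Backend         string
	Model           string
	Seed            *int64
	PromptHash      string
	LatencyMs       int
	CostUSDEstimate *float64
	Caller          string
}

type UsageSink interface {
	Record(ctx context.Context, row UsageRow) error
}

// memorySink is a hypothetical in-memory UsageSink. It is safe for
// concurrent use because the adapter may be shared across goroutines.
type memorySink struct {
	mu   sync.Mutex
	rows []UsageRow
}

func (s *memorySink) Record(_ context.Context, row UsageRow) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.rows = append(s.rows, row)
	return nil
}

// Len reports how many rows have been recorded so far.
func (s *memorySink) Len() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return len(s.rows)
}

// demoSink records two rows through the interface and returns the count.
func demoSink() int {
	sink := &memorySink{}
	var iface UsageSink = sink
	seed := int64(42)
	_ = iface.Record(context.Background(), UsageRow{Backend: "flux-test", Seed: &seed})
	_ = iface.Record(context.Background(), UsageRow{Backend: "flux-test"}) // cost left nil: unknown
	return sink.Len()
}

func main() {
	fmt.Println("rows recorded:", demoSink())
}
```

The pointer fields let a row distinguish "unknown" (nil) from a real zero, which is presumably why Seed and CostUSDEstimate are pointers in the struct above.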
// NewReplicate is the registry constructor. cfg is the adapter's slice
// of imagen.yaml.
func NewReplicate(name string, cfg map[string]any) (Backend, error) {
if name == "" {
return nil, fmt.Errorf("replicate: empty instance name")
}
tokenEnv := getString(cfg, "api_token_env", "REPLICATE_API_TOKEN")
model := getString(cfg, "model", "")
if model == "" {
return nil, fmt.Errorf("replicate[%s]: model is required (e.g. black-forest-labs/flux-schnell)", name)
}
owner, modelName, version, err := parseModelRef(model)
if err != nil {
return nil, fmt.Errorf("replicate[%s]: %w", name, err)
}
apiBase := strings.TrimRight(getString(cfg, "api_base", "https://api.replicate.com"), "/")
pollTimeout := timeoutForModel(modelName)
r := &Replicate{
instance: name,
apiBase: apiBase,
tokenEnv: tokenEnv,
apiToken: os.Getenv(tokenEnv),
model: model,
owner: owner,
name: modelName,
version: version,
defaultSteps: getInt(cfg, "default_steps", 0),
defaultAspect: getString(cfg, "default_aspect_ratio", "1:1"),
httpClient: &http.Client{Timeout: 30 * time.Second},
pollInterval: 500 * time.Millisecond,
pollTimeout: pollTimeout,
randSeed: cryptoSeed,
initialBackoff: time.Second,
}
return r, nil
}
// Name returns the instance name from imagen.yaml.
func (r *Replicate) Name() string { return r.instance }
// Generate submits one prediction and returns the resulting PNG.
func (r *Replicate) Generate(ctx context.Context, req Request) (*Result, error) {
if r.apiToken == "" {
return nil, fmt.Errorf("replicate[%s]: API token missing — export %s with a Replicate API token", r.instance, r.tokenEnv)
}
width := orDefaultInt(req.Width, 1024)
height := orDefaultInt(req.Height, 1024)
aspect := computeAspectRatio(width, height, r.defaultAspect)
steps := orDefaultInt(req.Steps, r.defaultSteps)
seed := req.Seed
if seed == 0 {
seed = r.randSeed()
}
input := map[string]any{
"prompt": req.Prompt,
"aspect_ratio": aspect,
"num_outputs": 1,
"output_format": "png",
"seed": seed,
}
if steps > 0 {
input["num_inference_steps"] = steps
}
if req.NegativePrompt != "" {
input["negative_prompt"] = req.NegativePrompt
}
maps.Copy(input, req.BackendOpts)
start := time.Now()
pred, err := r.submitPrediction(ctx, input)
if err != nil {
return nil, err
}
final, err := r.waitForCompletion(ctx, pred.ID)
if err != nil {
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
}
imgURL, err := pickFirstOutputURL(final.Output)
if err != nil {
return nil, fmt.Errorf("replicate[%s] prediction %s: %w", r.instance, pred.ID, err)
}
imgBytes, mime, err := r.fetchImage(ctx, imgURL)
if err != nil {
return nil, err
}
latencyMs := int(time.Since(start).Milliseconds())
predictTime := final.Metrics.PredictTime
costEst, costKnown := replicatePerImageUSD(r.model)
meta := map[string]any{
"backend": r.instance,
"backend_type": ReplicateType,
"model": r.model,
"model_version": final.Version,
"prediction_id": pred.ID,
"seed": seed,
"width": width,
"height": height,
"aspect_ratio": aspect,
"latency_ms": int64(latencyMs),
}
if predictTime > 0 {
meta["predict_time_seconds"] = predictTime
}
if costKnown {
meta["cost_usd_estimate"] = costEst
}
if r.Sink != nil {
row := UsageRow{
Backend: r.instance,
Model: r.model,
PromptHash: hashPrompt(req.Prompt),
LatencyMs: latencyMs,
Caller: ResolveCaller(),
}
row.Seed = new(int64)
*row.Seed = seed
if costKnown {
c := costEst
row.CostUSDEstimate = &c
}
if err := r.Sink.Record(ctx, row); err != nil {
fmt.Fprintf(os.Stderr, "imagen: cost-tracking write failed (continuing): %v\n", err)
}
}
return &Result{
ImageReader: io.NopCloser(bytes.NewReader(imgBytes)),
MimeType: mime,
Metadata: meta,
}, nil
}
// replicatePrediction is what the REST API returns under /v1/predictions
// and /v1/models/{owner}/{name}/predictions. Only the fields we use are
// declared.
type replicatePrediction struct {
ID string `json:"id"`
Status string `json:"status"`
Version string `json:"version"`
Error json.RawMessage `json:"error"`
Output json.RawMessage `json:"output"`
Metrics struct {
PredictTime float64 `json:"predict_time"`
} `json:"metrics"`
}
// submitPrediction creates a prediction and returns it. Picks the
// model-based endpoint when no version was given (recommended for
// Replicate's official models), and the legacy /v1/predictions otherwise.
func (r *Replicate) submitPrediction(ctx context.Context, input map[string]any) (*replicatePrediction, error) {
var (
reqURL string
body []byte
err error
)
if r.version == "" {
reqURL = fmt.Sprintf("%s/v1/models/%s/%s/predictions", r.apiBase, r.owner, r.name)
body, err = json.Marshal(map[string]any{"input": input})
} else {
reqURL = r.apiBase + "/v1/predictions"
body, err = json.Marshal(map[string]any{
"version": r.version,
"input": input,
})
}
if err != nil {
return nil, fmt.Errorf("replicate: marshal prediction body: %w", err)
}
respBody, err := r.doWithRetry(ctx, http.MethodPost, reqURL, body)
if err != nil {
return nil, err
}
var pred replicatePrediction
if err := json.Unmarshal(respBody, &pred); err != nil {
return nil, fmt.Errorf("replicate: parse predictions response: %w (body: %s)", err, snip(respBody))
}
if pred.ID == "" {
return nil, fmt.Errorf("replicate: empty prediction id (body: %s)", snip(respBody))
}
return &pred, nil
}
// waitForCompletion polls /v1/predictions/{id} until the status is a
// terminal value (succeeded, failed, canceled) or the timeout fires.
func (r *Replicate) waitForCompletion(ctx context.Context, id string) (*replicatePrediction, error) {
deadline := time.Now().Add(r.pollTimeout)
getURL := r.apiBase + "/v1/predictions/" + url.PathEscape(id)
start := time.Now()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
if time.Now().After(deadline) {
return nil, fmt.Errorf("did not complete within %s (waited %s)", r.pollTimeout, time.Since(start).Round(time.Millisecond))
}
body, err := r.doWithRetry(ctx, http.MethodGet, getURL, nil)
if err != nil {
return nil, err
}
var pred replicatePrediction
if err := json.Unmarshal(body, &pred); err != nil {
return nil, fmt.Errorf("replicate: parse poll response: %w (body: %s)", err, snip(body))
}
switch pred.Status {
case "succeeded":
return &pred, nil
case "failed":
return nil, fmt.Errorf("status=failed: %s", snip(pred.Error))
case "canceled":
return nil, fmt.Errorf("status=canceled")
case "starting", "processing", "":
default:
return nil, fmt.Errorf("status=%q (unknown): %s", pred.Status, snip(body))
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(r.pollInterval):
}
}
}
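The shape of the polling loop (check the context, check the deadline, poll, then sleep in a way that cancellation can interrupt) can be sketched without any HTTP. `pollUntil` and `demoPoll` are hypothetical names used for illustration; the status sequence is simulated.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// pollUntil is a hypothetical, generic version of the loop. It calls check
// every interval until check reports done, check fails, ctx is cancelled,
// or timeout elapses.
func pollUntil(ctx context.Context, interval, timeout time.Duration, check func() (bool, error)) error {
	deadline := time.Now().Add(timeout)
	for {
		if err := ctx.Err(); err != nil {
			return err
		}
		if time.Now().After(deadline) {
			return fmt.Errorf("did not complete within %s", timeout)
		}
		done, err := check()
		if err != nil {
			return err
		}
		if done {
			return nil
		}
		// Sleep between polls, but wake up immediately on cancellation.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(interval):
		}
	}
}

// demoPoll simulates a prediction that reports "starting", "processing",
// then "succeeded", and returns how many polls were needed.
func demoPoll() (int, error) {
	statuses := []string{"starting", "processing", "succeeded"}
	calls := 0
	err := pollUntil(context.Background(), time.Millisecond, time.Second, func() (bool, error) {
		status := statuses[calls]
		calls++
		switch status {
		case "succeeded":
			return true, nil
		case "failed", "canceled":
			return false, errors.New("status=" + status)
		}
		return false, nil // still starting or processing
	})
	return calls, err
}

func main() {
	calls, err := demoPoll()
	fmt.Println(calls, err)
}
```

Selecting on both `ctx.Done()` and the timer keeps a cancelled caller from waiting out a full poll interval.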
// doWithRetry executes a Replicate API request and applies the resilience
// policy: 401 surfaces a clean message naming the env var; 429 retries
// with exponential backoff up to three times; 5xx and transport errors
// retry once; context cancellation and all other statuses surface
// immediately.
func (r *Replicate) doWithRetry(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
const max429Retries = 3
backoff := r.initialBackoff
if backoff <= 0 {
backoff = time.Second
}
var lastErr error
for attempt := 0; ; attempt++ {
req, err := http.NewRequestWithContext(ctx, method, reqURL, bytesReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Token "+r.apiToken)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := r.httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
if attempt >= 1 {
return nil, fmt.Errorf("replicate %s %s: %w", method, shortPath(reqURL), err)
}
lastErr = err
if !sleepCtx(ctx, backoff) {
return nil, ctx.Err()
}
continue
}
respBody, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
switch {
case resp.StatusCode >= 200 && resp.StatusCode < 300:
return respBody, nil
case resp.StatusCode == http.StatusUnauthorized:
return nil, fmt.Errorf("replicate[%s] %d: API token missing or invalid; export %s with a valid Replicate API token", r.instance, resp.StatusCode, r.tokenEnv)
case resp.StatusCode == http.StatusTooManyRequests:
if attempt >= max429Retries {
return nil, fmt.Errorf("replicate %s %s 429 after %d retries: %s", method, shortPath(reqURL), max429Retries, snip(respBody))
}
wait := backoffFor429(resp.Header.Get("Retry-After"), backoff)
lastErr = fmt.Errorf("429 (retry %d): %s", attempt+1, snip(respBody))
if !sleepCtx(ctx, wait) {
return nil, ctx.Err()
}
backoff *= 2
continue
case resp.StatusCode >= 500:
if attempt >= 1 {
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
}
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(respBody))
if !sleepCtx(ctx, backoff) {
return nil, ctx.Err()
}
continue
default:
_ = lastErr
return nil, fmt.Errorf("replicate %s %s %d: %s", method, shortPath(reqURL), resp.StatusCode, snip(respBody))
}
}
}
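The status handling in doWithRetry reduces to a small decision table. The sketch below restates it as a pure function so the policy can be read and tested in one place; `decide` and the `action` type are hypothetical and do not appear in the diff.

```go
package main

import "fmt"

type action string

const (
	actSucceed action = "succeed"
	actRetry   action = "retry"
	actFail    action = "fail"
)

// decide is a hypothetical summary of the status switch in doWithRetry.
// attempt is the zero-based loop counter, shared by every branch.
func decide(status, attempt int) action {
	const max429Retries = 3
	switch {
	case status >= 200 && status < 300:
		return actSucceed
	case status == 401:
		return actFail // auth errors are never retried
	case status == 429:
		if attempt >= max429Retries {
			return actFail
		}
		return actRetry
	case status >= 500:
		if attempt >= 1 {
			return actFail
		}
		return actRetry
	default:
		return actFail
	}
}

func main() {
	for _, c := range []struct{ status, attempt int }{
		{201, 0}, {401, 0}, {429, 2}, {429, 3}, {502, 0}, {502, 1}, {404, 0},
	} {
		fmt.Printf("status=%d attempt=%d -> %s\n", c.status, c.attempt, decide(c.status, c.attempt))
	}
}
```

One consequence of the shared counter, as read from the loop above: a 5xx that arrives after a 429 retry has already advanced `attempt` is not retried.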
// fetchImage downloads the rendered image from the Replicate-provided
// CDN URL. It makes at most two attempts, retrying once after a network
// error, a body read failure, or a 5xx response.
func (r *Replicate) fetchImage(ctx context.Context, imgURL string) ([]byte, string, error) {
var lastErr error
for attempt := range 2 {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imgURL, nil)
if err != nil {
return nil, "", err
}
resp, err := r.httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, "", err
}
lastErr = err
continue
}
body, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
lastErr = readErr
continue
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
if resp.StatusCode >= 500 && attempt == 0 {
lastErr = fmt.Errorf("image download %d: %s", resp.StatusCode, snip(body))
continue
}
return nil, "", fmt.Errorf("replicate image download %d: %s", resp.StatusCode, snip(body))
}
mime := resp.Header.Get("Content-Type")
if mime == "" {
mime = "image/png"
}
return body, mime, nil
}
return nil, "", fmt.Errorf("replicate image download failed: %w", lastErr)
}
// parseModelRef accepts "owner/name" or "owner/name:version-hash".
func parseModelRef(ref string) (owner, name, version string, err error) {
rest := ref
if i := strings.IndexByte(rest, ':'); i >= 0 {
version = rest[i+1:]
rest = rest[:i]
}
parts := strings.SplitN(rest, "/", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref)
}
return parts[0], parts[1], version, nil
}
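A short usage example of the two accepted forms, with the parser copied so the snippet runs on its own. The version hash `abc123` is a placeholder, as in the tests.

```go
package main

import (
	"fmt"
	"strings"
)

// parseModelRef is copied from the adapter so the example is self-contained.
func parseModelRef(ref string) (owner, name, version string, err error) {
	rest := ref
	if i := strings.IndexByte(rest, ':'); i >= 0 {
		version = rest[i+1:]
		rest = rest[:i]
	}
	parts := strings.SplitN(rest, "/", 2)
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return "", "", "", fmt.Errorf("model %q must be of the form owner/name or owner/name:version", ref)
	}
	return parts[0], parts[1], version, nil
}

func main() {
	for _, ref := range []string{
		"black-forest-labs/flux-schnell",    // no version: model-based endpoint
		"black-forest-labs/flux-dev:abc123", // pinned: /v1/predictions
		"no-slash",                          // rejected
	} {
		owner, name, version, err := parseModelRef(ref)
		fmt.Printf("%q -> owner=%q name=%q version=%q err=%v\n", ref, owner, name, version, err)
	}
}
```

An empty version is what later routes the request to the model-based predictions endpoint in submitPrediction.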
// timeoutForModel picks the polling timeout. FLUX dev takes notably longer
// than schnell; everything else gets the dev timeout for safety.
func timeoutForModel(name string) time.Duration {
switch strings.ToLower(name) {
case "flux-schnell":
return 60 * time.Second
default:
return 120 * time.Second
}
}
// pickFirstOutputURL extracts the first image URL from the output field.
// Replicate returns either a single string or an array of strings for image
// models — we accept both.
func pickFirstOutputURL(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", fmt.Errorf("output is empty")
}
var s string
if err := json.Unmarshal(raw, &s); err == nil && s != "" {
return s, nil
}
var arr []string
if err := json.Unmarshal(raw, &arr); err == nil && len(arr) > 0 && arr[0] != "" {
return arr[0], nil
}
return "", fmt.Errorf("output is not a string or non-empty string array (got: %s)", snip(raw))
}
// computeAspectRatio reduces width:height to a Replicate-supported aspect
// ratio when the reduction lands on one of the canonical values; otherwise
// returns fallback.
func computeAspectRatio(w, h int, fallback string) string {
if w <= 0 || h <= 0 {
return fallback
}
g := gcd(w, h)
a, b := w/g, h/g
s := fmt.Sprintf("%d:%d", a, b)
if isReplicateAspectRatio(s) {
return s
}
return fallback
}
func isReplicateAspectRatio(s string) bool {
switch s {
case "1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3", "9:16", "9:21":
return true
}
return false
}
func gcd(a, b int) int {
for b != 0 {
a, b = b, a%b
}
if a < 0 {
return -a
}
return a
}
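A worked example of the reduction: 1920x1080 has gcd 120, giving 16:9; 1024x768 has gcd 256, giving 4:3; 1344x768 has gcd 192, giving 7:4, which is not in the accepted list and therefore falls back. The functions are copied from above so the snippet runs standalone; the sizes are illustrative.

```go
package main

import "fmt"

func gcd(a, b int) int {
	for b != 0 {
		a, b = b, a%b
	}
	if a < 0 {
		return -a
	}
	return a
}

func isReplicateAspectRatio(s string) bool {
	switch s {
	case "1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3", "9:16", "9:21":
		return true
	}
	return false
}

// computeAspectRatio reduces w:h by their gcd and returns the result only
// if it is one of the accepted ratios; otherwise it returns fallback.
func computeAspectRatio(w, h int, fallback string) string {
	if w <= 0 || h <= 0 {
		return fallback
	}
	g := gcd(w, h)
	s := fmt.Sprintf("%d:%d", w/g, h/g)
	if isReplicateAspectRatio(s) {
		return s
	}
	return fallback
}

func main() {
	for _, size := range [][2]int{{1920, 1080}, {1024, 768}, {1344, 768}, {0, 512}} {
		fmt.Printf("%dx%d -> %s\n", size[0], size[1], computeAspectRatio(size[0], size[1], "1:1"))
	}
}
```

Note that the match is exact: a size that is only close to a supported ratio silently uses the fallback, so the generated image may not have the requested proportions.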
// hashPrompt returns the sha256 hex digest of the prompt. The raw prompt
// is intentionally never written to the cost-tracking table.
func hashPrompt(p string) string {
sum := sha256.Sum256([]byte(p))
return hex.EncodeToString(sum[:])
}
// ResolveCaller returns the agent identity the cost-tracking row is
// attributed to. Order of resolution mirrors the maimcp identity logic:
// MAI_FROM_ID env var first, then the tmux pane's @mai-name option, then
// "unknown".
func ResolveCaller() string {
if v := strings.TrimSpace(os.Getenv("MAI_FROM_ID")); v != "" {
return v
}
if pane := os.Getenv("TMUX_PANE"); pane != "" {
out, err := exec.Command("tmux", "display-message", "-p", "-t", pane, "#{@mai-name}").Output()
if err == nil {
if name := strings.TrimSpace(string(out)); name != "" {
return name
}
}
}
return "unknown"
}
// backoffFor429 honours a Retry-After header given in seconds, capped at
// 30s. Empty, non-positive, or unparseable values (including the HTTP-date
// form) fall back to the caller's backoff.
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
if retryAfter == "" {
return fallback
}
d, err := time.ParseDuration(retryAfter + "s")
if err != nil || d <= 0 {
return fallback
}
if d > 30*time.Second {
return 30 * time.Second
}
return d
}
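How different Retry-After values are treated is easiest to see by running a few through the helper. The function is copied from above; `waitSeconds` is a hypothetical wrapper that fixes the fallback at one second for the demo.

```go
package main

import (
	"fmt"
	"time"
)

// backoffFor429 is copied from the adapter so the example is self-contained.
func backoffFor429(retryAfter string, fallback time.Duration) time.Duration {
	if retryAfter == "" {
		return fallback
	}
	d, err := time.ParseDuration(retryAfter + "s")
	if err != nil || d <= 0 {
		return fallback
	}
	if d > 30*time.Second {
		return 30 * time.Second
	}
	return d
}

// waitSeconds is a hypothetical convenience for the demo: it applies a 1s
// fallback and reports the chosen wait in seconds.
func waitSeconds(retryAfter string) float64 {
	return backoffFor429(retryAfter, time.Second).Seconds()
}

func main() {
	headers := []string{
		"",    // header absent: fallback
		"7",   // plain seconds: honoured
		"120", // too long: capped at 30s
		"0",   // non-positive: fallback
		"Wed, 21 Oct 2015 07:28:00 GMT", // HTTP-date form: unparseable here, fallback
	}
	for _, h := range headers {
		fmt.Printf("Retry-After=%q -> wait %vs\n", h, waitSeconds(h))
	}
}
```

The cap bounds how long a single 429 can stall a call, at the cost of possibly retrying earlier than the server asked.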
func sleepCtx(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(d):
return true
}
}
func bytesReader(b []byte) io.Reader {
if b == nil {
return nil
}
return bytes.NewReader(b)
}
// shortPath strips the host so error messages don't leak the API base.
func shortPath(u string) string {
if i := strings.Index(u, "/v1/"); i >= 0 {
return u[i:]
}
return u
}
func init() {
Register(ReplicateType, NewReplicate)
}


@@ -0,0 +1,42 @@
package backend
import "strings"
// Replicate pricing snapshot.
//
// Source: https://replicate.com/pricing and the per-model "Run" tab on
// each model page. Replicate bills per second of GPU time, but the
// black-forest-labs FLUX models also publish a flat per-image price for
// the typical settings — that flat number is what we hardcode here.
//
// Snapshot date: 2026-05-08. TODO(refresh): re-check quarterly. If the
// rates drift more than ~10%, update the table and bump
// replicatePricingSnapshotDate.
const replicatePricingSnapshotDate = "2026-05-08"
// replicatePerImageUSD is the per-image cost estimate keyed by Replicate
// model identifier ("owner/name", with any ":version" trimmed). Returns
// the rate and true if the model is known, 0 and false otherwise — an
// unknown model writes a row with NULL cost rather than a wrong number.
func replicatePerImageUSD(model string) (float64, bool) {
key := normalisePricingKey(model)
switch key {
case "black-forest-labs/flux-schnell":
return 0.003, true
case "black-forest-labs/flux-dev":
return 0.025, true
case "black-forest-labs/flux-pro":
return 0.055, true
case "black-forest-labs/flux-1.1-pro":
return 0.040, true
}
return 0, false
}
// normalisePricingKey strips the optional ":version" suffix and lowercases
// the owner/name pair. "Owner/Name:hash" → "owner/name".
func normalisePricingKey(model string) string {
if i := strings.IndexByte(model, ':'); i >= 0 {
model = model[:i]
}
return strings.ToLower(model)
}


@@ -0,0 +1,675 @@
package backend
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// fakeReplicate is a programmable mock of the Replicate REST API.
type fakeReplicate struct {
t *testing.T
mu sync.Mutex
// Response for the prediction submission. Both the model-based path and
// /v1/predictions are served by handleCreate, so createCalls counts every
// submission regardless of endpoint.
createStatus int
createBody []byte
createCalls atomic.Int32
// Auth-fail policy: if true, every request to the prediction endpoints
// returns 401 with a stock body.
auth401 bool
// 429-then-OK policy: first N create calls return 429.
create429Until int32
retryAfter string
// Sequence of /predictions/{id} responses, walked in order; once
// exhausted, the last entry is returned indefinitely.
pollResponses []string
pollIdx atomic.Int32
pollCalls atomic.Int32
// Image download policy.
imageStatus int
imageBody []byte
image5xxFirst int32
imageCalls atomic.Int32
server *httptest.Server
}
func (f *fakeReplicate) handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1/models/") && strings.HasSuffix(r.URL.Path, "/predictions"):
f.handleCreate(w, r)
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
f.handleCreate(w, r)
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
f.handlePoll(w, r)
case r.Method == http.MethodGet && r.URL.Path == "/img":
f.handleImage(w, r)
default:
f.t.Errorf("fakeReplicate: unexpected %s %s", r.Method, r.URL.Path)
http.NotFound(w, r)
}
})
}
func (f *fakeReplicate) handleCreate(w http.ResponseWriter, _ *http.Request) {
n := f.createCalls.Add(1)
if f.auth401 {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"detail":"Invalid token"}`))
return
}
if n <= f.create429Until {
if f.retryAfter != "" {
w.Header().Set("Retry-After", f.retryAfter)
}
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"detail":"too many requests"}`))
return
}
w.WriteHeader(f.createStatus)
_, _ = w.Write(f.createBody)
}
func (f *fakeReplicate) handlePoll(w http.ResponseWriter, _ *http.Request) {
f.pollCalls.Add(1)
idx := int(f.pollIdx.Add(1)) - 1
if idx >= len(f.pollResponses) {
idx = len(f.pollResponses) - 1
}
if idx < 0 {
http.Error(w, "no poll response configured", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(f.pollResponses[idx]))
}
func (f *fakeReplicate) handleImage(w http.ResponseWriter, _ *http.Request) {
n := f.imageCalls.Add(1)
if n <= f.image5xxFirst {
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write([]byte("upstream unavailable"))
return
}
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(f.imageStatus)
_, _ = w.Write(f.imageBody)
}
func (f *fakeReplicate) start() {
f.server = httptest.NewServer(f.handler())
f.t.Cleanup(f.server.Close)
}
func (f *fakeReplicate) imageURL() string { return f.server.URL + "/img" }
func newFakeReplicate(t *testing.T) *fakeReplicate {
t.Helper()
f := &fakeReplicate{
t: t,
createStatus: http.StatusCreated,
imageStatus: http.StatusOK,
imageBody: mustPNG(t, 8, 8),
}
f.start()
// Default happy-path responses now that the server URL is known.
f.createBody = []byte(`{"id":"pred-abc","status":"starting","version":"v1","output":null}`)
f.pollResponses = []string{
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
`{"id":"pred-abc","status":"processing","version":"v1","output":null}`,
fmt.Sprintf(`{"id":"pred-abc","status":"succeeded","version":"v1","output":"%s","metrics":{"predict_time":1.23}}`, f.imageURL()),
}
return f
}
func newReplicate(t *testing.T, f *fakeReplicate, model string) *Replicate {
t.Helper()
be, err := NewReplicate("flux-test", map[string]any{
"api_token_env": "TEST_REPLICATE_TOKEN",
"model": model,
"api_base": f.server.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake-token"
r.pollInterval = time.Millisecond
r.pollTimeout = 5 * time.Second
r.initialBackoff = time.Millisecond
r.randSeed = func() int64 { return 42 }
return r
}
func TestReplicateConstructorRejectsBadInputs(t *testing.T) {
if _, err := NewReplicate("", map[string]any{"model": "owner/name"}); err == nil {
t.Errorf("expected error for empty instance name")
}
if _, err := NewReplicate("x", map[string]any{}); err == nil {
t.Errorf("expected error for missing model")
}
if _, err := NewReplicate("x", map[string]any{"model": "no-slash"}); err == nil {
t.Errorf("expected error for malformed model")
}
}
func TestReplicateMissingTokenSurfacesEnvName(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.apiToken = "" // simulate the env var being unset
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error when token is missing")
}
if !strings.Contains(err.Error(), "TEST_REPLICATE_TOKEN") {
t.Errorf("error should name the env var: %v", err)
}
}
func TestReplicateHappyPathSchnellUsesModelEndpoint(t *testing.T) {
f := newFakeReplicate(t)
sink := &recordingSink{}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.Sink = sink
res, err := r.Generate(context.Background(), Request{
Prompt: "a tiny dragon",
Width: 1024, Height: 1024,
Seed: 0,
})
if err != nil {
t.Fatalf("Generate: %v", err)
}
defer res.ImageReader.Close()
body, _ := io.ReadAll(res.ImageReader)
if !bytes.Equal(body, f.imageBody) {
t.Errorf("image body did not round-trip")
}
if mime := res.MimeType; mime != "image/png" {
t.Errorf("mime = %q", mime)
}
if got, _ := res.Metadata["model"].(string); got != "black-forest-labs/flux-schnell" {
t.Errorf("metadata model = %v", got)
}
if got, _ := res.Metadata["model_version"].(string); got != "v1" {
t.Errorf("metadata model_version = %v", got)
}
if got, _ := res.Metadata["predict_time_seconds"].(float64); got != 1.23 {
t.Errorf("metadata predict_time_seconds = %v", got)
}
if got, ok := res.Metadata["cost_usd_estimate"].(float64); !ok || got != 0.003 {
t.Errorf("metadata cost_usd_estimate = %v (ok=%v)", got, ok)
}
if got, _ := res.Metadata["aspect_ratio"].(string); got != "1:1" {
t.Errorf("aspect_ratio = %q", got)
}
if got := f.pollCalls.Load(); got < 3 {
t.Errorf("expected at least 3 poll calls (starting → processing → succeeded), got %d", got)
}
if len(sink.rows) != 1 {
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
}
row := sink.rows[0]
if row.Backend != "flux-test" || row.Model != "black-forest-labs/flux-schnell" {
t.Errorf("sink row backend/model = %q/%q", row.Backend, row.Model)
}
if row.PromptHash == "" || row.PromptHash == "a tiny dragon" {
t.Errorf("sink row should have sha256 hash, got %q", row.PromptHash)
}
if len(row.PromptHash) != 64 {
t.Errorf("expected 64-char sha256 hex, got %d chars", len(row.PromptHash))
}
if row.CostUSDEstimate == nil || *row.CostUSDEstimate != 0.003 {
t.Errorf("sink row cost = %v", row.CostUSDEstimate)
}
if row.LatencyMs <= 0 {
t.Errorf("sink row latency_ms should be > 0, got %d", row.LatencyMs)
}
}
func TestReplicateVersionPinUsesPredictionsEndpoint(t *testing.T) {
f := newFakeReplicate(t)
var sentBody []byte
var sentPath string
mu := sync.Mutex{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && r.URL.Path == "/v1/predictions":
mu.Lock()
sentBody, _ = io.ReadAll(r.Body)
sentPath = r.URL.Path
mu.Unlock()
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pred-vp","status":"starting","version":"abc123","output":null}`))
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/predictions/"):
_, _ = fmt.Fprintf(w, `{"id":"pred-vp","status":"succeeded","version":"abc123","output":"%s"}`, f.imageURL())
default:
f.t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-dev:abc123",
"api_base": srv.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
res, err := r.Generate(context.Background(), Request{Prompt: "x"})
if err != nil {
t.Fatalf("Generate: %v", err)
}
res.ImageReader.Close()
mu.Lock()
defer mu.Unlock()
if sentPath != "/v1/predictions" {
t.Errorf("expected version-pinned model to hit /v1/predictions, got %q", sentPath)
}
var body map[string]any
if err := json.Unmarshal(sentBody, &body); err != nil {
t.Fatalf("unmarshal sent body: %v", err)
}
if body["version"] != "abc123" {
t.Errorf("expected version=abc123 in body, got %v", body["version"])
}
}
func TestReplicate401SurfacesEnvHint(t *testing.T) {
f := newFakeReplicate(t)
f.auth401 = true
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected 401 to surface as error")
}
msg := err.Error()
if !strings.Contains(msg, "TEST_REPLICATE_TOKEN") {
t.Errorf("error should name env var: %v", err)
}
if !strings.Contains(msg, "401") {
t.Errorf("error should mention 401: %v", err)
}
}
func TestReplicate429RetriesThenSucceeds(t *testing.T) {
f := newFakeReplicate(t)
f.create429Until = 2 // first two calls 429 then succeed
f.retryAfter = "" // force the adapter's exp-backoff path
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = time.Millisecond
// Cap each HTTP request at 5s. This does not shorten the adapter's
// backoff; the context deadline below bounds the whole test.
r.httpClient = &http.Client{Timeout: 5 * time.Second}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
res, err := r.Generate(ctx, Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate after 429s: %v", err)
}
res.ImageReader.Close()
if got := f.createCalls.Load(); got != 3 {
t.Errorf("expected 3 create calls (2x 429 + 1 OK), got %d", got)
}
}
func TestReplicate429GivesUpAfterMaxRetries(t *testing.T) {
f := newFakeReplicate(t)
f.create429Until = 99 // every call 429
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
_, err := r.Generate(ctx, Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error after sustained 429s")
}
if !strings.Contains(err.Error(), "429") {
t.Errorf("expected 429 in error: %v", err)
}
// max429Retries=3 → 1 initial + 3 retries = 4 total
if got := f.createCalls.Load(); got != 4 {
t.Errorf("expected 4 create calls (1+3 retries), got %d", got)
}
}
func TestReplicateFailedPredictionSurfacesError(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{
`{"id":"pred-abc","status":"starting","version":"v1","output":null}`,
`{"id":"pred-abc","status":"failed","version":"v1","output":null,"error":"NSFW filtered"}`,
}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error for failed prediction")
}
if !strings.Contains(err.Error(), "failed") {
t.Errorf("error should mention failure: %v", err)
}
if !strings.Contains(err.Error(), "NSFW") {
t.Errorf("error should include the API error message: %v", err)
}
}
func TestReplicatePollTimeoutSurfacesPartialLatency(t *testing.T) {
f := newFakeReplicate(t)
// Always return processing → adapter times out.
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = 5 * time.Millisecond
r.pollTimeout = 30 * time.Millisecond
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected timeout error")
}
if !strings.Contains(err.Error(), "did not complete") {
t.Errorf("expected 'did not complete' in error, got %v", err)
}
if !strings.Contains(err.Error(), "waited") {
t.Errorf("expected partial latency ('waited X') for diagnostics, got %v", err)
}
}
func TestReplicateImageDownloadRetriesOnce5xx(t *testing.T) {
f := newFakeReplicate(t)
f.image5xxFirst = 1 // first download 502, second OK
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (download retry): %v", err)
}
res.ImageReader.Close()
if got := f.imageCalls.Load(); got != 2 {
t.Errorf("expected 2 image fetches (1 fail + 1 retry), got %d", got)
}
}
func TestReplicateImageDownload5xxGivesUpAfterRetry(t *testing.T) {
f := newFakeReplicate(t)
f.image5xxFirst = 99 // every download fails
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
_, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err == nil {
t.Fatal("expected error after sustained image-download 5xx")
}
if got := f.imageCalls.Load(); got != 2 {
t.Errorf("expected 2 image fetches (no further retries), got %d", got)
}
}
func TestReplicateContextCancelStopsPolling(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{`{"id":"pred-abc","status":"processing","version":"v1","output":null}`}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.pollInterval = 5 * time.Millisecond
r.pollTimeout = 5 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
_, err := r.Generate(ctx, Request{Prompt: "p"})
if err == nil {
t.Fatal("expected ctx error")
}
if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context deadline exceeded") {
t.Errorf("expected deadline exceeded, got %v", err)
}
}
func TestReplicateBackendOptsMergedIntoInput(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
body, _ := io.ReadAll(r.Body)
var top struct {
Input map[string]any `json:"input"`
}
_ = json.Unmarshal(body, &top)
captured = top.Input
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-schnell",
"api_base": srv.URL,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
// We expect this to fail at "pickFirstOutputURL" (empty output) but the
// captured input should still have been recorded.
_, _ = r.Generate(context.Background(), Request{
Prompt: "p",
Width: 1024, Height: 1024,
BackendOpts: map[string]any{
"output_quality": 90,
"go_fast": true,
},
})
if captured == nil {
t.Fatal("create endpoint not hit")
}
if captured["output_quality"] != float64(90) {
t.Errorf("output_quality not threaded: %v", captured["output_quality"])
}
if captured["go_fast"] != true {
t.Errorf("go_fast not threaded: %v", captured["go_fast"])
}
if captured["prompt"] != "p" {
t.Errorf("prompt not threaded: %v", captured["prompt"])
}
}
func TestReplicateOutputArrayAccepted(t *testing.T) {
f := newFakeReplicate(t)
f.pollResponses = []string{
fmt.Sprintf(`{"id":"pid","status":"succeeded","version":"v1","output":["%s"],"metrics":{"predict_time":0.5}}`, f.imageURL()),
}
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (output as array): %v", err)
}
res.ImageReader.Close()
}
func TestReplicateUnknownModelLeavesCostUnsetButGenerates(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "stability-ai/sdxl")
sink := &recordingSink{}
r.Sink = sink
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("Generate (unknown model): %v", err)
}
res.ImageReader.Close()
if _, present := res.Metadata["cost_usd_estimate"]; present {
t.Errorf("unknown-model meta should not include cost_usd_estimate; got %v", res.Metadata["cost_usd_estimate"])
}
if len(sink.rows) != 1 {
t.Fatalf("expected 1 sink row, got %d", len(sink.rows))
}
if sink.rows[0].CostUSDEstimate != nil {
t.Errorf("expected nil cost in sink row for unknown model")
}
}
func TestReplicateSinkFailureIsWarningNotError(t *testing.T) {
f := newFakeReplicate(t)
r := newReplicate(t, f, "black-forest-labs/flux-schnell")
r.Sink = sinkFunc(func(context.Context, UsageRow) error { return errors.New("db unreachable") })
res, err := r.Generate(context.Background(), Request{Prompt: "p"})
if err != nil {
t.Fatalf("sink failure should not fail Generate: %v", err)
}
res.ImageReader.Close()
}
func TestReplicateDefaultStepsApplied(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/predictions"):
body, _ := io.ReadAll(r.Body)
var top struct {
Input map[string]any `json:"input"`
}
_ = json.Unmarshal(body, &top)
captured = top.Input
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"id":"pid","status":"succeeded","version":"v","output":""}`))
}
}))
t.Cleanup(srv.Close)
be, err := NewReplicate("flux-test", map[string]any{
"model": "black-forest-labs/flux-dev",
"api_base": srv.URL,
"default_steps": 28,
})
if err != nil {
t.Fatalf("NewReplicate: %v", err)
}
r := be.(*Replicate)
r.apiToken = "fake"
r.pollInterval = time.Millisecond
_, _ = r.Generate(context.Background(), Request{Prompt: "p"})
if captured == nil {
t.Fatal("create endpoint not hit")
}
if captured["num_inference_steps"] != float64(28) {
t.Errorf("expected num_inference_steps=28 from default_steps, got %v", captured["num_inference_steps"])
}
}
func TestComputeAspectRatio(t *testing.T) {
cases := []struct {
w, h int
fallback string
want string
}{
{1024, 1024, "1:1", "1:1"},
{1920, 1080, "1:1", "16:9"},
{2560, 1440, "1:1", "16:9"},
{1024, 768, "1:1", "4:3"},
{1024, 1280, "1:1", "4:5"},
{1000, 1234, "1:1", "1:1"}, // weird ratio falls back
{0, 1024, "1:1", "1:1"},
}
for _, c := range cases {
got := computeAspectRatio(c.w, c.h, c.fallback)
if got != c.want {
t.Errorf("computeAspectRatio(%d,%d,%q)=%q, want %q", c.w, c.h, c.fallback, got, c.want)
}
}
}
func TestParseModelRef(t *testing.T) {
owner, name, ver, err := parseModelRef("black-forest-labs/flux-schnell")
if err != nil || owner != "black-forest-labs" || name != "flux-schnell" || ver != "" {
t.Errorf("parseModelRef plain: o=%q n=%q v=%q err=%v", owner, name, ver, err)
}
owner, name, ver, err = parseModelRef("owner/name:hash123")
if err != nil || owner != "owner" || name != "name" || ver != "hash123" {
t.Errorf("parseModelRef versioned: o=%q n=%q v=%q err=%v", owner, name, ver, err)
}
if _, _, _, err := parseModelRef("noslash"); err == nil {
t.Errorf("expected error for malformed ref")
}
}
func TestHashPromptStable(t *testing.T) {
a := hashPrompt("hello")
b := hashPrompt("hello")
c := hashPrompt("hello!")
if a != b {
t.Errorf("hashPrompt should be deterministic")
}
if a == c {
t.Errorf("different prompts should hash differently")
}
if len(a) != 64 {
t.Errorf("sha256 hex should be 64 chars, got %d", len(a))
}
}
func TestReplicatePricingKnownModels(t *testing.T) {
if v, ok := replicatePerImageUSD("black-forest-labs/flux-schnell"); !ok || v != 0.003 {
t.Errorf("schnell rate = %v (ok=%v)", v, ok)
}
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev"); !ok || v != 0.025 {
t.Errorf("dev rate = %v (ok=%v)", v, ok)
}
if v, ok := replicatePerImageUSD("black-forest-labs/flux-dev:hashabc"); !ok || v != 0.025 {
t.Errorf("versioned ref should resolve to base price: %v %v", v, ok)
}
if _, ok := replicatePerImageUSD("nobody/unknown-model"); ok {
t.Errorf("unknown model should report ok=false")
}
}
func TestReplicateTypeIsRegistered(t *testing.T) {
if !Default.Has(ReplicateType) {
t.Errorf("replicate type not registered in Default")
}
}
// recordingSink captures rows for assertion.
type recordingSink struct {
mu sync.Mutex
rows []UsageRow
}
func (s *recordingSink) Record(_ context.Context, row UsageRow) error {
s.mu.Lock()
defer s.mu.Unlock()
s.rows = append(s.rows, row)
return nil
}
type sinkFunc func(context.Context, UsageRow) error
func (f sinkFunc) Record(ctx context.Context, row UsageRow) error { return f(ctx, row) }
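TestParseModelRef and the version-pinned prediction test above fix how a model reference of the form owner/name[:version] is expected to split. The adapter's own parseModelRef is not part of this diff, so the following is only a minimal sketch of a parser that would satisfy those assertions. The name splitModelRef and its error text are assumptions made for illustration.

```go
package main

import (
	"fmt"
	"strings"
)

// splitModelRef is a hypothetical stand-in for parseModelRef. It splits
// "owner/name" or "owner/name:version" into its parts; version is empty
// when the reference is not pinned.
func splitModelRef(ref string) (string, string, string, error) {
	// Everything after the first ':' is the version (may be absent).
	base, version, _ := strings.Cut(ref, ":")
	// The base must contain a '/' separating owner and model name.
	owner, name, ok := strings.Cut(base, "/")
	if !ok || owner == "" || name == "" {
		return "", "", "", fmt.Errorf("model ref %q: want owner/name[:version]", ref)
	}
	return owner, name, version, nil
}

func main() {
	for _, ref := range []string{
		"black-forest-labs/flux-schnell",
		"black-forest-labs/flux-dev:abc123",
		"noslash",
	} {
		owner, name, version, err := splitModelRef(ref)
		fmt.Printf("ref=%q owner=%q name=%q version=%q err=%v\n", ref, owner, name, version, err)
	}
}
```

A parser of this shape also explains why the pricing test expects a versioned reference to resolve to the base model's rate: the owner/name pair is the same once the version suffix is cut off.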


@@ -0,0 +1,156 @@
package backend
import (
"embed"
"encoding/json"
"fmt"
"io/fs"
"maps"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
)
//go:embed workflows/*.json
var bundledWorkflows embed.FS
// placeholderRE matches a single-token placeholder like "${prompt}" — the
// whole string value must be the placeholder, leading/trailing whitespace
// allowed. This lets us preserve types (a numeric substitution becomes a
// JSON number, not a stringified one) instead of round-tripping through
// strings.Replace which would force everything into a string.
var placeholderRE = regexp.MustCompile(`^\s*\$\{([a-zA-Z][a-zA-Z0-9_]*)\}\s*$`)
// LoadWorkflowTemplate returns the parsed JSON for a workflow template.
// `name` is resolved in this order:
//
// 1. exact filesystem path that exists on disk (absolute or relative);
// 2. one of the bundled templates under internal/backend/workflows/
// (with or without the .json suffix).
//
// The returned map is a fresh deep copy of the template; callers can mutate
// it freely.
func LoadWorkflowTemplate(name string) (map[string]any, error) {
if name == "" {
return nil, fmt.Errorf("workflow template name is empty")
}
raw, err := readWorkflowBytes(name)
if err != nil {
return nil, err
}
var wf map[string]any
if err := json.Unmarshal(raw, &wf); err != nil {
return nil, fmt.Errorf("workflow %s: parse: %w", name, err)
}
return wf, nil
}
// BundledWorkflowNames returns the names of templates compiled into the
// binary, sorted. Each name is the basename without the .json suffix.
func BundledWorkflowNames() []string {
entries, err := fs.ReadDir(bundledWorkflows, "workflows")
if err != nil {
return nil
}
out := make([]string, 0, len(entries))
for _, e := range entries {
n := e.Name()
if !strings.HasSuffix(n, ".json") {
continue
}
out = append(out, strings.TrimSuffix(n, ".json"))
}
sort.Strings(out)
return out
}
func readWorkflowBytes(name string) ([]byte, error) {
// Filesystem path wins if it points at a real file. Lets a user override
// a bundled template by passing a path (absolute or relative) in yaml.
if strings.ContainsRune(name, os.PathSeparator) || strings.HasSuffix(name, ".json") {
if b, err := os.ReadFile(name); err == nil {
return b, nil
} else if !os.IsNotExist(err) {
return nil, fmt.Errorf("workflow %s: %w", name, err)
}
}
// Bundled lookup. Try the literal name as a file inside workflows/, then
// with the .json suffix appended.
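// embed.FS paths always use forward slashes, so normalise the joined
// path; filepath.Join alone would emit backslashes on Windows.
candidates := []string{
filepath.ToSlash(filepath.Join("workflows", name)),
filepath.ToSlash(filepath.Join("workflows", name+".json")),
}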
for _, c := range candidates {
if b, err := bundledWorkflows.ReadFile(c); err == nil {
return b, nil
}
}
return nil, fmt.Errorf("workflow %q not found (bundled templates: %v)", name, BundledWorkflowNames())
}
// SubstituteWorkflow walks wf and replaces every string value that consists
// solely of a "${key}" placeholder with the matching value from subs,
// preserving the substituted value's type. It returns the set of placeholder
// keys it actually replaced, so the caller can detect unused substitutions:
// keys defined in subs but never referenced in the workflow (typical when a
// yaml block sets a knob that a different template would consume).
//
// Unknown placeholders (referenced in the workflow but absent from subs)
// produce an error so we never submit a workflow with raw "${foo}" tokens.
// On error wf is left unmodified.
func SubstituteWorkflow(wf map[string]any, subs map[string]any) (used map[string]struct{}, err error) {
used = make(map[string]struct{})
walked, err := substituteValue(wf, subs, used)
if err != nil {
return nil, err
}
// substituteValue returns the replacement for the top-level value, which
// should still be the same map (just with mutated children).
if m, ok := walked.(map[string]any); ok {
// Copy back into wf so the caller's reference reflects the result.
for k := range wf {
delete(wf, k)
}
maps.Copy(wf, m)
}
return used, nil
}
func substituteValue(v any, subs map[string]any, used map[string]struct{}) (any, error) {
switch x := v.(type) {
case map[string]any:
out := make(map[string]any, len(x))
for k, child := range x {
replaced, err := substituteValue(child, subs, used)
if err != nil {
return nil, err
}
out[k] = replaced
}
return out, nil
case []any:
out := make([]any, len(x))
for i, child := range x {
replaced, err := substituteValue(child, subs, used)
if err != nil {
return nil, err
}
out[i] = replaced
}
return out, nil
case string:
if m := placeholderRE.FindStringSubmatch(x); m != nil {
key := m[1]
val, ok := subs[key]
if !ok {
return nil, fmt.Errorf("workflow placeholder ${%s} has no substitution", key)
}
used[key] = struct{}{}
return val, nil
}
return x, nil
default:
return v, nil
}
}
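The effect of whole-value substitution is easiest to see after re-encoding to JSON. The following is a condensed, standalone restatement of the walk above rather than the package code itself; the names wholeToken, substitute and render are illustrative, and the template string is made up for the demonstration.

```go
package main

import (
	"encoding/json"
	"fmt"
	"regexp"
)

// wholeToken mirrors placeholderRE from the diff: the entire string value
// must be a single "${name}" token.
var wholeToken = regexp.MustCompile(`^\s*\$\{([a-zA-Z][a-zA-Z0-9_]*)\}\s*$`)

// substitute returns a copy of v in which every whole-value placeholder is
// replaced by the typed value from subs. Other strings are kept verbatim.
func substitute(v any, subs map[string]any) (any, error) {
	switch x := v.(type) {
	case map[string]any:
		out := make(map[string]any, len(x))
		for k, child := range x {
			r, err := substitute(child, subs)
			if err != nil {
				return nil, err
			}
			out[k] = r
		}
		return out, nil
	case []any:
		out := make([]any, len(x))
		for i, child := range x {
			r, err := substitute(child, subs)
			if err != nil {
				return nil, err
			}
			out[i] = r
		}
		return out, nil
	case string:
		m := wholeToken.FindStringSubmatch(x)
		if m == nil {
			return x, nil // partial matches are left alone
		}
		val, ok := subs[m[1]]
		if !ok {
			return nil, fmt.Errorf("placeholder ${%s} has no substitution", m[1])
		}
		return val, nil
	default:
		return v, nil
	}
}

// render parses a template, substitutes, and re-encodes it as JSON.
func render(tmpl string, subs map[string]any) (string, error) {
	var wf any
	if err := json.Unmarshal([]byte(tmpl), &wf); err != nil {
		return "", err
	}
	out, err := substitute(wf, subs)
	if err != nil {
		return "", err
	}
	b, err := json.Marshal(out)
	return string(b), err
}

func main() {
	tmpl := `{"31":{"inputs":{"seed":"${seed}","text":"${prompt}","note":"about ${prompt}"}}}`
	out, err := render(tmpl, map[string]any{"seed": int64(42), "prompt": "a cat"})
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	// seed is emitted as a JSON number, text as a string, note is untouched.
	fmt.Println(out)
}
```

Because the seed is inserted as an int64 rather than spliced into a string, the encoder emits it as a bare JSON number, which is the property the type-preserving design is after.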


@@ -0,0 +1,153 @@
package backend
import (
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
func TestBundledWorkflowsParseable(t *testing.T) {
names := BundledWorkflowNames()
if len(names) == 0 {
t.Fatal("expected at least one bundled workflow")
}
mustHave := []string{"flux1-schnell", "flux2-klein", "sd35-medium"}
for _, want := range mustHave {
if !slices.Contains(names, want) {
t.Errorf("bundled workflows missing %q (have: %v)", want, names)
}
}
// Every bundled template must parse and contain at least one node.
for _, n := range names {
wf, err := LoadWorkflowTemplate(n)
if err != nil {
t.Errorf("LoadWorkflowTemplate(%q): %v", n, err)
continue
}
if len(wf) == 0 {
t.Errorf("workflow %q has zero nodes", n)
}
}
}
func TestLoadWorkflowFromFilesystem(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "custom.json")
body := `{"1":{"class_type":"X","inputs":{"v":"${prompt}"}}}`
if err := os.WriteFile(path, []byte(body), 0o644); err != nil {
t.Fatalf("write tmp workflow: %v", err)
}
wf, err := LoadWorkflowTemplate(path)
if err != nil {
t.Fatalf("load from path: %v", err)
}
if _, ok := wf["1"]; !ok {
t.Errorf("custom workflow missing node 1")
}
}
func TestLoadWorkflowUnknownNameErrors(t *testing.T) {
_, err := LoadWorkflowTemplate("definitely-not-a-real-workflow")
if err == nil {
t.Fatal("expected error for unknown workflow name")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should say not found, got %v", err)
}
}
func TestSubstituteWorkflowPreservesTypes(t *testing.T) {
wf := map[string]any{
"31": map[string]any{
"class_type": "KSampler",
"inputs": map[string]any{
"seed": "${seed}",
"steps": "${steps}",
"text": "${prompt}",
"cfg": "${cfg}",
},
},
}
subs := map[string]any{
"seed": int64(42),
"steps": 11,
"prompt": "a cat",
"cfg": 4.5,
}
used, err := SubstituteWorkflow(wf, subs)
if err != nil {
t.Fatalf("Substitute: %v", err)
}
if len(used) != 4 {
t.Errorf("used = %v, want all four", used)
}
inputs := wf["31"].(map[string]any)["inputs"].(map[string]any)
if seed, ok := inputs["seed"].(int64); !ok || seed != 42 {
t.Errorf("seed = %T %v, want int64 42", inputs["seed"], inputs["seed"])
}
if steps, ok := inputs["steps"].(int); !ok || steps != 11 {
t.Errorf("steps = %T %v, want int 11", inputs["steps"], inputs["steps"])
}
if text, ok := inputs["text"].(string); !ok || text != "a cat" {
t.Errorf("text = %T %v, want string", inputs["text"], inputs["text"])
}
if cfg, ok := inputs["cfg"].(float64); !ok || cfg != 4.5 {
t.Errorf("cfg = %T %v, want float64 4.5", inputs["cfg"], inputs["cfg"])
}
}
func TestSubstituteWorkflowMissingPlaceholderErrors(t *testing.T) {
wf := map[string]any{
"1": map[string]any{"inputs": map[string]any{"v": "${missing}"}},
}
_, err := SubstituteWorkflow(wf, map[string]any{})
if err == nil {
t.Fatal("expected error for missing placeholder")
}
if !strings.Contains(err.Error(), "${missing}") {
t.Errorf("error should name the placeholder, got %v", err)
}
}
func TestSubstituteWorkflowOnlyWholeTokens(t *testing.T) {
// Partial-match strings ("prefix ${prompt} suffix") are NOT substituted —
// the placeholder must be the whole value so we can preserve types.
wf := map[string]any{
"1": map[string]any{"inputs": map[string]any{
"keep_string": "stuff with ${prompt} inside",
"replace_full": "${prompt}",
}},
}
used, err := SubstituteWorkflow(wf, map[string]any{"prompt": "x"})
if err != nil {
t.Fatalf("Substitute: %v", err)
}
inputs := wf["1"].(map[string]any)["inputs"].(map[string]any)
if inputs["keep_string"].(string) != "stuff with ${prompt} inside" {
t.Errorf("partial match should be left alone, got %q", inputs["keep_string"])
}
if inputs["replace_full"].(string) != "x" {
t.Errorf("full-value match should substitute, got %q", inputs["replace_full"])
}
if _, ok := used["prompt"]; !ok {
t.Errorf("used should track keys that fired")
}
}
func TestFlux1SchnellTemplateMatchesLegacyShape(t *testing.T) {
// Regression guard against the historical hardcoded workflow: every
// node ID the old Comfy.buildWorkflow used must still exist in the
// migrated template.
wf, err := LoadWorkflowTemplate("flux1-schnell")
if err != nil {
t.Fatalf("load flux1-schnell: %v", err)
}
legacyNodes := []string{"6", "8", "9", "10", "11", "12", "13", "27", "30", "31"}
for _, id := range legacyNodes {
if _, ok := wf[id]; !ok {
t.Errorf("flux1-schnell template missing node %q (legacy parity)", id)
}
}
}
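The tests above check that the returned used set tracks which placeholders fired. One way a caller might act on that set is to report substitutions that were configured but never consumed by the chosen template. The sketch below assumes a helper named unusedKeys, which does not appear in the diff.

```go
package main

import (
	"fmt"
	"sort"
)

// unusedKeys is a hypothetical helper: it returns, sorted, every key in
// subs that is absent from the used set reported by a substitution pass.
func unusedKeys(subs map[string]any, used map[string]struct{}) []string {
	var out []string
	for k := range subs {
		if _, ok := used[k]; !ok {
			out = append(out, k)
		}
	}
	sort.Strings(out)
	return out
}

func main() {
	// "shift" is a knob only some templates reference; here it goes unused.
	subs := map[string]any{"prompt": "a cat", "seed": int64(42), "shift": 3.0}
	used := map[string]struct{}{"prompt": {}, "seed": {}}
	fmt.Println(unusedKeys(subs, used))
}
```

Whether an unused key should be a warning or an error is a policy decision for the caller; the diff only states that the set is returned so the caller can detect the situation.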


@@ -0,0 +1,84 @@
{
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": "${prompt}",
"clip": ["11", 0]
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["31", 0],
"vae": ["10", 0]
}
},
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "imagen",
"images": ["8", 0]
}
},
"10": {
"class_type": "VAELoader",
"inputs": {
"vae_name": "${vae}"
}
},
"11": {
"class_type": "DualCLIPLoader",
"inputs": {
"clip_name1": "${clip_t5}",
"clip_name2": "${clip_l}",
"type": "flux"
}
},
"12": {
"class_type": "UNETLoader",
"inputs": {
"unet_name": "${model}",
"weight_dtype": "${dtype}"
}
},
"13": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": "${negative}",
"clip": ["11", 0]
}
},
"27": {
"class_type": "EmptySD3LatentImage",
"inputs": {
"width": "${width}",
"height": "${height}",
"batch_size": 1
}
},
"30": {
"class_type": "ModelSamplingFlux",
"inputs": {
"model": ["12", 0],
"max_shift": 1.15,
"base_shift": 0.5,
"width": "${width}",
"height": "${height}"
}
},
"31": {
"class_type": "KSampler",
"inputs": {
"model": ["30", 0],
"seed": "${seed}",
"steps": "${steps}",
"cfg": "${cfg}",
"sampler_name": "${sampler}",
"scheduler": "${scheduler}",
"denoise": 1.0,
"positive": ["6", 0],
"negative": ["13", 0],
"latent_image": ["27", 0]
}
}
}
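In the template above, inputs such as ["11", 0] appear to follow the ComfyUI API convention of [node id, output index]. Assuming that reading is right, a hand-edited template can be sanity-checked by confirming that every such link points at a node that exists. The function danglingLinks below is a hypothetical helper, not something the diff provides.

```go
package main

import (
	"encoding/json"
	"fmt"
	"sort"
)

// danglingLinks parses an API-format workflow and returns a sorted list of
// "node.input -> target" entries whose target node id is not defined.
// Assumption: a two-element array [string, number] is a node link.
func danglingLinks(raw []byte) ([]string, error) {
	var wf map[string]struct {
		Inputs map[string]any `json:"inputs"`
	}
	if err := json.Unmarshal(raw, &wf); err != nil {
		return nil, err
	}
	var bad []string
	for id, node := range wf {
		for name, v := range node.Inputs {
			link, ok := v.([]any)
			if !ok || len(link) != 2 {
				continue
			}
			target, isID := link[0].(string)
			_, isIdx := link[1].(float64) // JSON numbers decode as float64
			if !isID || !isIdx {
				continue
			}
			if _, exists := wf[target]; !exists {
				bad = append(bad, fmt.Sprintf("%s.%s -> %s", id, name, target))
			}
		}
	}
	sort.Strings(bad)
	return bad, nil
}

func main() {
	// Excerpt with node 31 deliberately left out.
	raw := []byte(`{
		"8":  {"class_type": "VAEDecode", "inputs": {"samples": ["31", 0], "vae": ["10", 0]}},
		"10": {"class_type": "VAELoader", "inputs": {"vae_name": "${vae}"}}
	}`)
	bad, err := danglingLinks(raw)
	fmt.Println(bad, err)
}
```

A check like this complements the legacy-parity test, which only asserts that specific node ids are present and does not look at the wiring between them.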


@@ -0,0 +1,79 @@
{
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": "${prompt}",
"clip": ["11", 0]
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["31", 0],
"vae": ["10", 0]
}
},
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "imagen",
"images": ["8", 0]
}
},
"10": {
"class_type": "VAELoader",
"inputs": {
"vae_name": "${vae}"
}
},
"11": {
"class_type": "CLIPLoader",
"inputs": {
"clip_name": "${clip}",
"type": "flux2"
}
},
"12": {
"class_type": "UNETLoader",
"inputs": {
"unet_name": "${model}",
"weight_dtype": "${dtype}"
}
},
"14": {
"class_type": "FluxGuidance",
"inputs": {
"conditioning": ["6", 0],
"guidance": "${guidance}"
}
},
"15": {
"class_type": "ConditioningZeroOut",
"inputs": {
"conditioning": ["6", 0]
}
},
"27": {
"class_type": "EmptyFlux2LatentImage",
"inputs": {
"width": "${width}",
"height": "${height}",
"batch_size": 1
}
},
"31": {
"class_type": "KSampler",
"inputs": {
"model": ["12", 0],
"seed": "${seed}",
"steps": "${steps}",
"cfg": "${cfg}",
"sampler_name": "${sampler}",
"scheduler": "${scheduler}",
"denoise": 1.0,
"positive": ["14", 0],
"negative": ["15", 0],
"latent_image": ["27", 0]
}
}
}
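The bundled templates do not share one set of knobs: the template above references ${clip} and ${guidance}, while the flux1-schnell template uses ${clip_t5}, ${clip_l} and ${negative}. Listing the placeholders a template needs before building the substitution map can make a config mismatch visible early. The sketch below is one way to do that; the names token, collect and placeholders are assumptions, and the regular expression is copied from placeholderRE in the diff.

```go
package main

import (
	"encoding/json"
	"fmt"
	"regexp"
	"sort"
)

// token matches a string that is exactly one "${name}" placeholder.
var token = regexp.MustCompile(`^\s*\$\{([a-zA-Z][a-zA-Z0-9_]*)\}\s*$`)

// collect walks a decoded JSON value and records every placeholder name.
func collect(v any, seen map[string]struct{}) {
	switch x := v.(type) {
	case map[string]any:
		for _, child := range x {
			collect(child, seen)
		}
	case []any:
		for _, child := range x {
			collect(child, seen)
		}
	case string:
		if m := token.FindStringSubmatch(x); m != nil {
			seen[m[1]] = struct{}{}
		}
	}
}

// placeholders returns the sorted, de-duplicated placeholder names that a
// workflow template references.
func placeholders(raw []byte) ([]string, error) {
	var wf any
	if err := json.Unmarshal(raw, &wf); err != nil {
		return nil, err
	}
	seen := make(map[string]struct{})
	collect(wf, seen)
	out := make([]string, 0, len(seen))
	for k := range seen {
		out = append(out, k)
	}
	sort.Strings(out)
	return out, nil
}

func main() {
	// Two nodes excerpted from the template above.
	raw := []byte(`{
		"11": {"class_type": "CLIPLoader", "inputs": {"clip_name": "${clip}", "type": "flux2"}},
		"14": {"class_type": "FluxGuidance", "inputs": {"conditioning": ["6", 0], "guidance": "${guidance}"}}
	}`)
	names, err := placeholders(raw)
	fmt.Println(names, err)
}
```

Comparing this list against the keys of the substitution map gives both directions of the check: placeholders with no value (which SubstituteWorkflow rejects) and values with no placeholder.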


@@ -0,0 +1,66 @@
{
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "${model}"
}
},
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": "${prompt}",
"clip": ["4", 1]
}
},
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": "${negative}",
"clip": ["4", 1]
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["31", 0],
"vae": ["4", 2]
}
},
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "imagen",
"images": ["8", 0]
}
},
"13": {
"class_type": "ModelSamplingSD3",
"inputs": {
"model": ["4", 0],
"shift": "${shift}"
}
},
"27": {
"class_type": "EmptySD3LatentImage",
"inputs": {
"width": "${width}",
"height": "${height}",
"batch_size": 1
}
},
"31": {
"class_type": "KSampler",
"inputs": {
"model": ["13", 0],
"seed": "${seed}",
"steps": "${steps}",
"cfg": "${cfg}",
"sampler_name": "${sampler}",
"scheduler": "${scheduler}",
"denoise": 1.0,
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["27", 0]
}
}
}

365
internal/cloud/cloud.go Normal file

@@ -0,0 +1,365 @@
// Package cloud syncs a generated image to Supabase Storage and inserts
// a row into imagen.images. Both steps are best-effort: callers log the
// returned error and proceed, because the local PNG + sidecar are already
// on disk by the time Sync runs and a cloud blip should not lose the
// artefact.
//
// The single source of truth for the row schema is the imagen_schema_init
// migration — see internal docs in the issue body for #7.
package cloud
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
"time"
)
// supabaseSchema is the PostgREST profile header value the imagen schema
// is exposed under (see ALTER ROLE authenticator SET pgrst.db_schemas).
const supabaseSchema = "imagen"
// bucketName is the Supabase Storage bucket all generated images land in.
const bucketName = "imagen-generated"
// Sink writes one PNG + one row per generation. It is safe to share
// across goroutines.
type Sink struct {
// URL is SUPABASE_URL — e.g. https://supa.flexsiebels.de.
URL string
// APIKey is the service-role key (SUPABASE_SERVICE_KEY). Storage uploads
// and DB inserts both bypass RLS with this key — the policies on the
// table + bucket are the contract for the read side.
APIKey string
// OwnerUserID is m's auth.users.id. It populates owner_user_id on every
// row. Empty means the sink refuses to insert (the column is NOT NULL
// and the user-mode reader needs it for the RLS policy).
OwnerUserID string
// HTTP is the http client; tests inject one pointing at httptest.
HTTP *http.Client
// MaxRetries is the number of additional attempts after the first
// failure for retryable (5xx) responses. Zero means single-shot.
MaxRetries int
// InitialBackoff is the wait before the first retry; doubles per attempt.
// Set very small in tests.
InitialBackoff time.Duration
}
// NewFromEnv returns a sink populated from SUPABASE_URL,
// SUPABASE_SERVICE_KEY (or MAI_SUPABASE_KEY) and IMAGEN_OWNER_USER_ID.
// It returns ok=false if the URL or the key is missing; the caller treats
// that as "cloud-sync disabled by environment".
func NewFromEnv() (*Sink, bool) {
u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")
if u == "" {
return nil, false
}
key := os.Getenv("SUPABASE_SERVICE_KEY")
if key == "" {
key = os.Getenv("MAI_SUPABASE_KEY")
}
if key == "" {
return nil, false
}
return &Sink{
URL: u,
APIKey: key,
OwnerUserID: os.Getenv("IMAGEN_OWNER_USER_ID"),
HTTP: &http.Client{Timeout: 30 * time.Second},
MaxRetries: 2,
InitialBackoff: time.Second,
}, true
}
// SyncRequest is the cross-backend ingredient set Sync needs. Date is
// formatted as YYYY-MM-DD; Slug + Seed are reused from the local
// filename so storage_path mirrors disk layout.
type SyncRequest struct {
Date string
Slug string
Seed int64
Ext string // "png", "jpg", "webp" — no leading dot
PNG []byte
MimeType string
Prompt string
Backend string
Model string
Steps int
Width int
Height int
LatencyMs int
CostUSDEstimate *float64
Sidecar map[string]any
// SeriesID is the parent imagen.series row when this image is one of
// N tries in a batch. Empty means a solo run — the column stays NULL,
// which keeps the row visible on the main list-page query
// (`WHERE series_id IS NULL`).
SeriesID string
}
// SyncResult tells the caller what landed where.
type SyncResult struct {
StoragePath string // e.g. "2026-05-11/lighthouse-42.png"
ImageID string // imagen.images.id (UUID)
}
// Sync uploads the bytes and inserts the metadata row. Returns the row's
// id and storage_path on success; any non-nil error is what the caller
// surfaces as "imagen: cloud sync: <err>" and otherwise ignores.
func (s *Sink) Sync(ctx context.Context, req SyncRequest) (*SyncResult, error) {
if s == nil {
return nil, fmt.Errorf("cloud sink not configured")
}
if s.OwnerUserID == "" {
return nil, fmt.Errorf("owner_user_id not set (config or $IMAGEN_OWNER_USER_ID); refusing to insert NULL into imagen.images")
}
if req.Date == "" || req.Slug == "" {
return nil, fmt.Errorf("date and slug are required for storage_path")
}
ext := req.Ext
if ext == "" {
ext = "png"
}
storagePath := fmt.Sprintf("%s/%s-%d.%s", req.Date, req.Slug, req.Seed, ext)
if err := s.upload(ctx, storagePath, req.PNG, req.MimeType); err != nil {
return nil, fmt.Errorf("storage upload: %w", err)
}
id, err := s.insertRow(ctx, storagePath, req)
if err != nil {
return &SyncResult{StoragePath: storagePath}, fmt.Errorf("db insert: %w", err)
}
return &SyncResult{StoragePath: storagePath, ImageID: id}, nil
}
// upload PUTs the image bytes into the imagen-generated bucket. We set
// Content-Type so signed URLs render in the browser without a download
// prompt. POST would error on a second write; PUT (with x-upsert: true) is
// idempotent for re-runs of the same date+slug+seed.
func (s *Sink) upload(ctx context.Context, storagePath string, body []byte, mime string) error {
if mime == "" {
mime = "image/png"
}
endpoint := fmt.Sprintf("%s/storage/v1/object/%s/%s", s.URL, bucketName, pathEscape(storagePath))
return s.doRetry(ctx, func(ctx context.Context) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", mime)
req.Header.Set("x-upsert", "true")
return s.HTTP.Do(req)
})
}
// insertRow POSTs to PostgREST against the imagen schema. Prefer:
// return=representation gives us the inserted id back without a second
// round-trip.
func (s *Sink) insertRow(ctx context.Context, storagePath string, req SyncRequest) (string, error) {
row := map[string]any{
"owner_user_id": s.OwnerUserID,
"prompt": req.Prompt,
"prompt_hash": hashPrompt(req.Prompt),
"backend": req.Backend,
"storage_path": storagePath,
}
if req.Model != "" {
row["model"] = req.Model
}
if req.Seed != 0 {
row["seed"] = req.Seed
}
if req.Steps != 0 {
row["steps"] = req.Steps
}
if req.Width != 0 {
row["width"] = req.Width
}
if req.Height != 0 {
row["height"] = req.Height
}
if req.LatencyMs != 0 {
row["latency_ms"] = req.LatencyMs
}
if req.CostUSDEstimate != nil {
row["cost_usd_estimate"] = *req.CostUSDEstimate
}
if len(req.Sidecar) > 0 {
row["sidecar"] = req.Sidecar
}
if req.SeriesID != "" {
row["series_id"] = req.SeriesID
}
body, err := json.Marshal(row)
if err != nil {
return "", fmt.Errorf("marshal row: %w", err)
}
endpoint := s.URL + "/rest/v1/images"
respBody, err := s.doRetryRead(ctx, func(ctx context.Context) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Profile", supabaseSchema)
req.Header.Set("Content-Profile", supabaseSchema)
req.Header.Set("Prefer", "return=representation")
return s.HTTP.Do(req)
})
if err != nil {
return "", err
}
var rows []struct {
ID string `json:"id"`
}
if err := json.Unmarshal(respBody, &rows); err != nil {
return "", fmt.Errorf("parse insert response: %w (body: %s)", err, snip(respBody))
}
if len(rows) == 0 {
return "", fmt.Errorf("insert returned 0 rows (body: %s)", snip(respBody))
}
return rows[0].ID, nil
}
// SignedURL asks the Storage API for a time-limited URL. ttlSeconds is
// the validity window. Returned URL is host-qualified and ready to hand
// to a browser.
func (s *Sink) SignedURL(ctx context.Context, storagePath string, ttlSeconds int) (string, error) {
if s == nil {
return "", fmt.Errorf("cloud sink not configured")
}
if ttlSeconds <= 0 {
ttlSeconds = 3600
}
endpoint := fmt.Sprintf("%s/storage/v1/object/sign/%s/%s", s.URL, bucketName, pathEscape(storagePath))
body, err := json.Marshal(map[string]any{"expiresIn": ttlSeconds})
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", "application/json")
resp, err := s.HTTP.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read sign response: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("sign %d: %s", resp.StatusCode, snip(respBody))
}
var parsed struct {
SignedURL string `json:"signedURL"`
}
if err := json.Unmarshal(respBody, &parsed); err != nil {
return "", fmt.Errorf("parse sign response: %w (body: %s)", err, snip(respBody))
}
if parsed.SignedURL == "" {
return "", fmt.Errorf("empty signedURL in response: %s", snip(respBody))
}
full := parsed.SignedURL
if strings.HasPrefix(full, "/") {
full = s.URL + full
}
return full, nil
}
// doRetry runs op up to MaxRetries+1 times. 5xx and transport errors are
// retried with exponential backoff; 4xx surfaces immediately as a
// permanent error (caller's bug in the row, not a network blip).
func (s *Sink) doRetry(ctx context.Context, op func(context.Context) (*http.Response, error)) error {
_, err := s.doRetryRead(ctx, op)
return err
}
// doRetryRead is the read-the-body variant. Returns the 2xx response
// body bytes; non-2xx is wrapped in an error. Same retry semantics as
// doRetry: 5xx/transport retries with exponential backoff, 4xx is fatal.
func (s *Sink) doRetryRead(ctx context.Context, op func(context.Context) (*http.Response, error)) ([]byte, error) {
backoff := s.InitialBackoff
if backoff == 0 {
backoff = time.Second
}
attempts := s.MaxRetries + 1
if attempts < 1 {
attempts = 1
}
var lastErr error
for i := 0; i < attempts; i++ {
if i > 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(backoff):
}
backoff *= 2
}
resp, err := op(ctx)
if err != nil {
lastErr = err
continue
}
body, readErr := io.ReadAll(resp.Body)
resp.Body.Close()
if readErr != nil {
lastErr = fmt.Errorf("read body: %w", readErr)
continue
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return body, nil
}
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, snip(body))
}
lastErr = fmt.Errorf("%d: %s", resp.StatusCode, snip(body))
}
return nil, lastErr
}
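The retry rules described above (retry on 5xx and transport errors with a doubling wait, stop at once on 4xx) can be exercised without a network. The following is a minimal sketch, not the code from this change: `retry`, `attemptsFor` and `errPermanent` are hypothetical names, and the operation returns a bare status code instead of an `*http.Response`.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// errPermanent marks a failure that must not be retried (the 4xx case).
var errPermanent = errors.New("permanent failure")

// retry runs op up to maxRetries+1 times and doubles the wait between
// attempts. op returns an HTTP-like status code, or a non-nil error for a
// transport failure. The first return value is the number of attempts made.
func retry(ctx context.Context, maxRetries int, initial time.Duration, op func() (int, error)) (int, error) {
	if maxRetries < 0 {
		maxRetries = 0
	}
	backoff := initial
	attempts := 0
	var lastErr error
	for i := 0; i <= maxRetries; i++ {
		if i > 0 {
			select {
			case <-ctx.Done():
				return attempts, ctx.Err()
			case <-time.After(backoff):
			}
			backoff *= 2
		}
		attempts++
		status, err := op()
		switch {
		case err != nil:
			lastErr = err // transport error: try again
		case status >= 200 && status < 300:
			return attempts, nil
		case status >= 400 && status < 500:
			return attempts, fmt.Errorf("%d: %w", status, errPermanent)
		default:
			lastErr = fmt.Errorf("status %d", status) // 5xx and the rest: try again
		}
	}
	return attempts, lastErr
}

// attemptsFor is a demo helper: op answers `status` for the first
// `failures` calls and 200 afterwards.
func attemptsFor(status, failures, maxRetries int) (int, error) {
	calls := 0
	op := func() (int, error) {
		calls++
		if calls <= failures {
			return status, nil
		}
		return 200, nil
	}
	return retry(context.Background(), maxRetries, time.Millisecond, op)
}

func main() {
	n, err := attemptsFor(503, 2, 2)
	fmt.Println(n, err) // two 503s, then success on the third attempt
	n, err = attemptsFor(400, 1, 2)
	fmt.Println(n, errors.Is(err, errPermanent)) // a 400 stops after one attempt
}
```

With `maxRetries = 2` the worst case waits `initial + 2*initial` in total before giving up, which is why the tests in this change set `InitialBackoff` to one millisecond.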
func hashPrompt(p string) string {
sum := sha256.Sum256([]byte(p))
return hex.EncodeToString(sum[:])
}
// pathEscape encodes each path segment but keeps the slashes — the
// Storage API treats the part after the bucket name as a virtual file
// path with directory separators.
func pathEscape(p string) string {
parts := strings.Split(p, "/")
for i, seg := range parts {
parts[i] = url.PathEscape(seg)
}
return path.Join(parts...)
}
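A short sketch of the per-segment escaping above, with one behaviour worth knowing: `path.Join` cleans its result, so empty segments and `..` are collapsed after escaping. `escapeJoin` mirrors the function above; `escapeKeep` is a hypothetical variant that rejoins with `strings.Join` and leaves the segment layout untouched. Which behaviour is wanted for storage paths is a design choice this sketch does not settle.

```go
package main

import (
	"fmt"
	"net/url"
	"path"
	"strings"
)

// escapeSegments percent-encodes every "/"-separated segment of p.
func escapeSegments(p string) []string {
	parts := strings.Split(p, "/")
	for i, seg := range parts {
		parts[i] = url.PathEscape(seg)
	}
	return parts
}

// escapeJoin mirrors pathEscape above: escape, then path.Join (which also
// runs path.Clean on the result).
func escapeJoin(p string) string {
	return path.Join(escapeSegments(p)...)
}

// escapeKeep is a hypothetical variant: escape, then rejoin verbatim.
func escapeKeep(p string) string {
	return strings.Join(escapeSegments(p), "/")
}

func main() {
	inputs := []string{
		"2026-05-11/two words.png",
		"with#hash/and?query.png",
		"a//b.png",
		"a/../b.png",
	}
	for _, in := range inputs {
		fmt.Printf("%q join=%q keep=%q\n", in, escapeJoin(in), escapeKeep(in))
	}
}
```

For ordinary `date/slug-seed.ext` paths both variants agree; they differ only on inputs with empty or dot segments.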
func snip(b []byte) string {
const max = 500
s := strings.TrimSpace(string(b))
if len(s) > max {
s = s[:max] + "..."
}
return s
}


@@ -0,0 +1,374 @@
package cloud
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"
)
// fakeSupabase is a tiny stand-in for Supabase Storage + PostgREST. It
// records what came in and returns canned responses based on path.
type fakeSupabase struct {
t *testing.T
mux *http.ServeMux
server *httptest.Server
uploadCalls int32
insertCalls int32
uploadBytes []byte
uploadHdr http.Header
insertBody []byte
insertHdr http.Header
}
func newFakeSupabase(t *testing.T, opts ...func(*fakeSupabase)) *fakeSupabase {
f := &fakeSupabase{t: t}
f.mux = http.NewServeMux()
// Storage upload — anything under /storage/v1/object/<bucket>/...
f.mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&f.uploadCalls, 1)
body, _ := io.ReadAll(r.Body)
f.uploadBytes = body
f.uploadHdr = r.Header.Clone()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"Key":"imagen-generated/somepath"}`))
})
// Storage sign URL
f.mux.HandleFunc("/storage/v1/object/sign/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"signedURL":"/storage/v1/object/sign/imagen-generated/some.png?token=abc"}`))
})
// PostgREST insert
f.mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&f.insertCalls, 1)
body, _ := io.ReadAll(r.Body)
f.insertBody = body
f.insertHdr = r.Header.Clone()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`[{"id":"00000000-0000-0000-0000-000000000abc"}]`))
})
for _, opt := range opts {
opt(f)
}
f.server = httptest.NewServer(f.mux)
t.Cleanup(f.server.Close)
return f
}
func newSink(server *httptest.Server) *Sink {
return &Sink{
URL: server.URL,
APIKey: "fake-service-key",
OwnerUserID: "00000000-0000-0000-0000-000000000001",
HTTP: server.Client(),
MaxRetries: 2,
InitialBackoff: time.Millisecond,
}
}
func TestSyncHappyPath(t *testing.T) {
f := newFakeSupabase(t)
s := newSink(f.server)
cost := 0.003
res, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11",
Slug: "lighthouse",
Seed: 42,
Ext: "png",
PNG: []byte("PNGbytes"),
MimeType: "image/png",
Prompt: "a tiny lighthouse on a stormy cliff",
Backend: "flux-schnell-local",
Model: "flux1-schnell",
Steps: 4,
Width: 1024,
Height: 1024,
LatencyMs: 1500,
CostUSDEstimate: &cost,
Sidecar: map[string]any{
"timestamp": "2026-05-11T01:30:00Z",
"backend": "flux-schnell-local",
},
})
if err != nil {
t.Fatalf("Sync: %v", err)
}
if res.StoragePath != "2026-05-11/lighthouse-42.png" {
t.Errorf("storage_path = %q", res.StoragePath)
}
if res.ImageID != "00000000-0000-0000-0000-000000000abc" {
t.Errorf("image_id = %q", res.ImageID)
}
if got := atomic.LoadInt32(&f.uploadCalls); got != 1 {
t.Errorf("upload calls = %d, want 1", got)
}
if got := atomic.LoadInt32(&f.insertCalls); got != 1 {
t.Errorf("insert calls = %d, want 1", got)
}
if !bytes.Equal(f.uploadBytes, []byte("PNGbytes")) {
t.Errorf("uploaded bytes = %q", f.uploadBytes)
}
// Verify the row payload carries the prompt + computed hash + non-zero
// metadata. Empty fields should be omitted from the JSON body so RLS
// won't see surprise keys.
var row map[string]any
if err := json.Unmarshal(f.insertBody, &row); err != nil {
t.Fatalf("insert body parse: %v\n%s", err, f.insertBody)
}
if row["prompt"] != "a tiny lighthouse on a stormy cliff" {
t.Errorf("row.prompt = %v", row["prompt"])
}
if row["owner_user_id"] != "00000000-0000-0000-0000-000000000001" {
t.Errorf("row.owner_user_id = %v", row["owner_user_id"])
}
if row["storage_path"] != "2026-05-11/lighthouse-42.png" {
t.Errorf("row.storage_path = %v", row["storage_path"])
}
hash, _ := row["prompt_hash"].(string)
if len(hash) != 64 {
t.Errorf("prompt_hash should be 64-char sha256 hex, got %q", hash)
}
if row["backend"] != "flux-schnell-local" {
t.Errorf("row.backend = %v", row["backend"])
}
if row["seed"].(float64) != 42 {
t.Errorf("row.seed = %v", row["seed"])
}
if row["latency_ms"].(float64) != 1500 {
t.Errorf("row.latency_ms = %v", row["latency_ms"])
}
if row["cost_usd_estimate"].(float64) != 0.003 {
t.Errorf("row.cost = %v", row["cost_usd_estimate"])
}
if row["sidecar"] == nil {
t.Errorf("row.sidecar missing")
}
// PostgREST schema headers — hardcoded to "imagen".
if got := f.insertHdr.Get("Accept-Profile"); got != "imagen" {
t.Errorf("Accept-Profile = %q", got)
}
if got := f.insertHdr.Get("Content-Profile"); got != "imagen" {
t.Errorf("Content-Profile = %q", got)
}
if got := f.insertHdr.Get("Authorization"); !strings.HasPrefix(got, "Bearer ") {
t.Errorf("Authorization = %q", got)
}
// Storage upsert should be set so re-runs of the same date+slug+seed
// don't fail with 409.
if got := f.uploadHdr.Get("x-upsert"); got != "true" {
t.Errorf("x-upsert = %q", got)
}
}
func TestSyncRetryOn5xx(t *testing.T) {
var uploadAttempts int32
mux := http.NewServeMux()
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&uploadAttempts, 1)
// Two 503s, then OK.
if n < 3 {
http.Error(w, "service unavailable", http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`[{"id":"row-id"}]`))
})
srv := httptest.NewServer(mux)
defer srv.Close()
s := newSink(srv)
res, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err != nil {
t.Fatalf("Sync (with retry): %v", err)
}
if got := atomic.LoadInt32(&uploadAttempts); got != 3 {
t.Errorf("upload attempts = %d, want 3", got)
}
if res.ImageID != "row-id" {
t.Errorf("image_id = %q", res.ImageID)
}
}
func TestSyncNoRetryOn4xx(t *testing.T) {
var uploadAttempts int32
mux := http.NewServeMux()
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&uploadAttempts, 1)
http.Error(w, `{"message":"bad request"}`, http.StatusBadRequest)
})
srv := httptest.NewServer(mux)
defer srv.Close()
s := newSink(srv)
_, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error on 400")
}
if !strings.Contains(err.Error(), "400") {
t.Errorf("error should mention 400 status: %v", err)
}
if got := atomic.LoadInt32(&uploadAttempts); got != 1 {
t.Errorf("upload attempts = %d, want 1 (no retry on 4xx)", got)
}
}
func TestSyncMissingOwnerUserID(t *testing.T) {
srv := httptest.NewServer(http.NewServeMux())
defer srv.Close()
s := &Sink{
URL: srv.URL,
APIKey: "k",
// OwnerUserID intentionally empty.
HTTP: srv.Client(),
InitialBackoff: time.Millisecond,
}
_, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error when owner_user_id unset")
}
if !strings.Contains(err.Error(), "owner_user_id") {
t.Errorf("error should mention owner_user_id: %v", err)
}
}
func TestSyncRequiresDateAndSlug(t *testing.T) {
srv := httptest.NewServer(http.NewServeMux())
defer srv.Close()
s := newSink(srv)
_, err := s.Sync(context.Background(), SyncRequest{
Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error for missing date")
}
}
func TestSignedURL(t *testing.T) {
f := newFakeSupabase(t)
s := newSink(f.server)
got, err := s.SignedURL(context.Background(), "2026-05-11/x.png", 60)
if err != nil {
t.Fatalf("SignedURL: %v", err)
}
want := f.server.URL + "/storage/v1/object/sign/imagen-generated/some.png?token=abc"
if got != want {
t.Errorf("signed URL = %q, want %q", got, want)
}
}
func TestSyncDBFailureSurfacesPathOnError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/storage/v1/object/imagen-generated/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/rest/v1/images", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "schema cache miss", http.StatusInternalServerError)
})
srv := httptest.NewServer(mux)
defer srv.Close()
s := newSink(srv)
res, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 9, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
})
if err == nil {
t.Fatal("expected error from DB insert failure")
}
// Storage upload succeeded — caller can still see the upload landed.
if res == nil || res.StoragePath != "2026-05-11/x-9.png" {
t.Errorf("expected storage_path on partial success, got %+v", res)
}
}
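The test above relies on a partial-success contract: when the upload lands but the row insert fails, the caller gets both a non-nil result (with the storage path) and a non-nil error. A minimal sketch of that shape follows, inferred from the test only; `result`, `syncImage` and `demo` are hypothetical names, and the upload and insert steps are plain function values.

```go
package main

import (
	"errors"
	"fmt"
)

// result carries what is known so far about a sync.
type result struct {
	StoragePath string
	ImageID     string
}

// syncImage uploads first, then inserts. If the insert fails after a
// successful upload it returns the partial result together with the error.
func syncImage(storagePath string, upload func() error, insert func() (string, error)) (*result, error) {
	if err := upload(); err != nil {
		return nil, fmt.Errorf("upload: %w", err)
	}
	res := &result{StoragePath: storagePath}
	id, err := insert()
	if err != nil {
		return res, fmt.Errorf("insert: %w", err)
	}
	res.ImageID = id
	return res, nil
}

// demo runs syncImage with a successful upload and an insert that fails
// when insertFails is true.
func demo(insertFails bool) (*result, error) {
	upload := func() error { return nil }
	insert := func() (string, error) {
		if insertFails {
			return "", errors.New("500: schema cache miss")
		}
		return "row-id", nil
	}
	return syncImage("2026-05-11/x-9.png", upload, insert)
}

func main() {
	res, err := demo(true)
	fmt.Println(res.StoragePath, err) // path is still reported next to the error
}
```

Callers of such a function have to check the result for nil separately from the error, since the usual "non-nil error means ignore the value" convention does not hold here.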
// TestSyncWritesSeriesID is the second half of the ImaGen#9 propagation
// contract: when SeriesID is non-empty, the POST body to imagen.images
// carries `series_id`. When empty, the key is omitted entirely so the
// row's series_id stays NULL (solo-run path, list-page query
// `WHERE series_id IS NULL` keeps showing it).
func TestSyncWritesSeriesID(t *testing.T) {
const seriesID = "22222222-2222-2222-2222-222222222222"
f := newFakeSupabase(t)
s := newSink(f.server)
_, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
SeriesID: seriesID,
})
if err != nil {
t.Fatalf("Sync: %v", err)
}
var row map[string]any
if err := json.Unmarshal(f.insertBody, &row); err != nil {
t.Fatalf("parse insert body: %v\n%s", err, f.insertBody)
}
if row["series_id"] != seriesID {
t.Fatalf("row.series_id = %v want %q", row["series_id"], seriesID)
}
}
func TestSyncOmitsSeriesIDWhenEmpty(t *testing.T) {
f := newFakeSupabase(t)
s := newSink(f.server)
_, err := s.Sync(context.Background(), SyncRequest{
Date: "2026-05-11", Slug: "x", Seed: 1, Ext: "png",
PNG: []byte("p"), Prompt: "p", Backend: "b",
// SeriesID intentionally empty.
})
if err != nil {
t.Fatalf("Sync: %v", err)
}
var row map[string]any
if err := json.Unmarshal(f.insertBody, &row); err != nil {
t.Fatalf("parse insert body: %v\n%s", err, f.insertBody)
}
if _, present := row["series_id"]; present {
t.Fatalf("solo run should omit series_id from POST body, got %v", row["series_id"])
}
}
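One common way to get the "omit the key when empty" behaviour that these two tests pin down is the `omitempty` option of `encoding/json`. The sketch below assumes a struct-based row; the row type actually used by `Sync` is not visible in this part of the diff, so `imageRow` and `encodeRow` are hypothetical.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// imageRow is a hypothetical, trimmed-down insert payload.
type imageRow struct {
	Prompt   string `json:"prompt"`
	SeriesID string `json:"series_id,omitempty"` // dropped from the JSON when ""
}

// encodeRow marshals the row and returns the JSON text.
func encodeRow(r imageRow) string {
	b, err := json.Marshal(r)
	if err != nil {
		panic(err)
	}
	return string(b)
}

func main() {
	fmt.Println(encodeRow(imageRow{Prompt: "p"}))
	fmt.Println(encodeRow(imageRow{Prompt: "p", SeriesID: "22222222-2222-2222-2222-222222222222"}))
}
```

An omitted key lets the column default apply (NULL here), whereas sending `"series_id": ""` would try to store an empty string in what is presumably a UUID column.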
func TestPathEscape(t *testing.T) {
cases := map[string]string{
"2026-05-11/lighthouse-42.png": "2026-05-11/lighthouse-42.png",
"2026-05-11/two words.png": "2026-05-11/two%20words.png",
"with#hash/and?query.png": "with%23hash/and%3Fquery.png",
}
for in, want := range cases {
got := pathEscape(in)
if got != want {
t.Errorf("pathEscape(%q) = %q, want %q", in, got, want)
}
// Sanity: every part should round-trip via url.PathUnescape.
for _, seg := range strings.Split(got, "/") {
if _, err := url.PathUnescape(seg); err != nil {
t.Errorf("segment %q failed unescape: %v", seg, err)
}
}
}
}


@@ -15,15 +15,29 @@ import (
// Config is the top-level shape of imagen.yaml.
type Config struct {
DefaultBackend string `yaml:"default_backend"`
// OwnerUserID is m's auth.users.id on msupabase. The cloud-sync writer
// uses it to populate imagen.images.owner_user_id (NOT NULL, owns RLS).
// Empty disables DB inserts even when cloud_sync is on.
OwnerUserID string `yaml:"owner_user_id"`
Output OutputConfig `yaml:"output"`
Backends map[string]BackendSpec `yaml:"backends"`
}
// OutputConfig controls where generated images and metadata sidecars land.
// OutputConfig controls where generated images and metadata sidecars land,
// and whether `imagen generate` opens a tmux preview window.
type OutputConfig struct {
Directory string `yaml:"directory"`
Naming string `yaml:"naming"`
WriteMetadataJSON bool `yaml:"write_metadata_json"`
// Preview is the tri-state preview mode: "auto" (default), "on", "off".
// Empty / unset is treated as "auto". $IMAGEN_PREVIEW and the
// --preview/--no-preview flags override this in turn.
Preview string `yaml:"preview"`
// CloudSync controls whether successful generations also upload to
// Supabase Storage and insert into imagen.images. Tri-state mirroring
// Preview: "auto" (default — on when SUPABASE_URL + SUPABASE_SERVICE_KEY
// are set), "on" (errors if env unset), "off". --no-cloud overrides.
CloudSync string `yaml:"cloud_sync"`
}
// BackendSpec is one entry under `backends:`. Type identifies the adapter;
@@ -78,6 +92,16 @@ func (c *Config) Validate() error {
return fmt.Errorf("default_backend %q is not defined under backends:", c.DefaultBackend)
}
}
switch c.Output.Preview {
case "", "auto", "on", "off":
default:
return fmt.Errorf("output.preview = %q (must be auto|on|off)", c.Output.Preview)
}
switch c.Output.CloudSync {
case "", "auto", "on", "off":
default:
return fmt.Errorf("output.cloud_sync = %q (must be auto|on|off)", c.Output.CloudSync)
}
for name, spec := range c.Backends {
if name == "" {
return errors.New("empty backend name")
@@ -97,30 +121,94 @@ const Sample = `# imagen.yaml — config for the imagen CLI.
default_backend: flux-schnell-local
# Owner UUID for the cloud-sync row in imagen.images. Look up via:
# SELECT id FROM auth.users WHERE email = '<your-supabase-email>';
# Empty disables imagen.images inserts even when cloud_sync is on.
owner_user_id: ""
output:
directory: ~/Pictures/imagen
naming: "{date}-{slug}-{seed}.png"
write_metadata_json: true
# Open a tmux window with tmux-img after a successful generation.
# auto (default): preview iff stdout is a TTY and $TMUX is set.
# on: always preview (errors outside a tmux session).
# off: never preview (use this for batch / CI callers).
preview: auto
# Sync the PNG to Supabase Storage (bucket: imagen-generated) and insert
# a row into imagen.images. Reads SUPABASE_URL + SUPABASE_SERVICE_KEY
# from env (same as mai.imagen_usage cost-tracking).
# auto (default): on iff env is configured AND owner_user_id is set.
# on: always upload (errors if env or owner_user_id is missing).
# off: never upload. --no-cloud also forces off per-call.
cloud_sync: auto
backends:
# FLUX.1-schnell on the local ComfyUI server. The "workflow" key picks the
# bundled template under internal/backend/workflows/; omit it for back-compat
# (defaults to flux1-schnell). See docs/backends.md for the per-model setup.
flux-schnell-local:
type: comfyui
base_url: http://mrock:8188
workflow: flux1-schnell
# Filename of the unet checkpoint inside the ComfyUI server's
# models/unet/ directory. See docs/setup-comfyui-mrock.md.
# models/unet/ directory.
model: flux1-schnell.safetensors
vae: ae.safetensors
clip_l: clip_l.safetensors
clip_t5: t5xxl_fp8_e4m3fn.safetensors
dtype: fp8_e4m3fn
default_steps: 4
default_sampler: euler
default_scheduler: simple
default_cfg: 1.0
# FLUX.2 [klein] 4B distilled — sub-second on RTX 4070 Ti SUPER.
# Weights: BFL non-commercial; flux-2-klein-base-4b-fp8 in models/unet/,
# qwen_3_4b in models/text_encoders/, flux2-vae in models/vae/.
flux2-klein-local:
type: comfyui
base_url: http://mrock:8188
workflow: flux2-klein
model: flux-2-klein-base-4b-fp8.safetensors
vae: flux2-vae.safetensors
clip: qwen_3_4b.safetensors
dtype: fp8_e4m3fn
default_steps: 4
default_sampler: euler
default_scheduler: simple
default_cfg: 1.0
guidance: 4.0
# SD3.5 medium — single-checkpoint variant that bundles the three text
# encoders inside the .safetensors. Drop into models/checkpoints/.
sd35-medium-local:
type: comfyui
base_url: http://mrock:8188
workflow: sd35-medium
model: sd3.5_medium_incl_clips_t5xxlfp8scaled.safetensors
default_steps: 28
default_sampler: dpmpp_2m
default_scheduler: sgm_uniform
default_cfg: 4.5
shift: 3.0
mock:
type: mock
flux-schnell-replicate:
type: replicate
api_token_env: REPLICATE_API_TOKEN
model: black-forest-labs/flux-schnell
default_steps: 4
default_aspect_ratio: "1:1"
flux-dev-replicate:
type: replicate
api_token_env: REPLICATE_API_TOKEN
model: black-forest-labs/flux-dev
default_steps: 28
default_aspect_ratio: "1:1"
dalle3:
type: openai


@@ -60,6 +60,67 @@ func TestValidateRejectsMissingType(t *testing.T) {
}
}
func TestValidatePreviewMode(t *testing.T) {
for _, mode := range []string{"", "auto", "on", "off"} {
c := &Config{Output: OutputConfig{Preview: mode}}
if err := c.Validate(); err != nil {
t.Errorf("preview=%q: unexpected error %v", mode, err)
}
}
bad := &Config{Output: OutputConfig{Preview: "yes"}}
if err := bad.Validate(); err == nil {
t.Errorf("expected error for invalid preview value")
}
}
func TestValidateCloudSyncMode(t *testing.T) {
for _, mode := range []string{"", "auto", "on", "off"} {
c := &Config{Output: OutputConfig{CloudSync: mode}}
if err := c.Validate(); err != nil {
t.Errorf("cloud_sync=%q: unexpected error %v", mode, err)
}
}
bad := &Config{Output: OutputConfig{CloudSync: "yes"}}
if err := bad.Validate(); err == nil {
t.Errorf("expected error for invalid cloud_sync value")
}
}
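Validate only checks that `cloud_sync` is one of the three spellings; turning the mode into a yes/no decision happens elsewhere. The sketch below follows the rules stated in the sample config comments (auto: on iff env and owner_user_id are set; on: error if either is missing; off and --no-cloud: never). `resolveCloudSync` and its parameters are hypothetical, and the real CLI may order the checks differently.

```go
package main

import (
	"errors"
	"fmt"
)

// resolveCloudSync decides whether to upload.
//   mode          value of output.cloud_sync ("" counts as "auto")
//   envConfigured SUPABASE_URL and SUPABASE_SERVICE_KEY are both set
//   ownerSet      owner_user_id is non-empty
//   noCloud       the --no-cloud flag was given
func resolveCloudSync(mode string, envConfigured, ownerSet, noCloud bool) (bool, error) {
	if noCloud {
		return false, nil // per-call override wins over everything
	}
	switch mode {
	case "off":
		return false, nil
	case "on":
		if !envConfigured {
			return false, errors.New("cloud_sync=on but SUPABASE_URL / SUPABASE_SERVICE_KEY are unset")
		}
		if !ownerSet {
			return false, errors.New("cloud_sync=on but owner_user_id is empty")
		}
		return true, nil
	case "", "auto":
		return envConfigured && ownerSet, nil
	}
	return false, fmt.Errorf("output.cloud_sync = %q (must be auto|on|off)", mode)
}

func main() {
	fmt.Println(resolveCloudSync("auto", true, true, false))
	fmt.Println(resolveCloudSync("auto", true, false, false))
	fmt.Println(resolveCloudSync("on", false, true, false))
	fmt.Println(resolveCloudSync("on", true, true, true))
}
```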
func TestSampleParsesCloudSyncAuto(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "imagen.yaml")
if err := os.WriteFile(path, []byte(Sample), 0o644); err != nil {
t.Fatalf("write sample: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.Output.CloudSync != "auto" {
t.Errorf("Output.CloudSync = %q, want auto", cfg.Output.CloudSync)
}
// owner_user_id is intentionally empty in the sample — operators fill
// it in after looking up their auth.users.id.
if cfg.OwnerUserID != "" {
t.Errorf("Sample OwnerUserID should be empty, got %q", cfg.OwnerUserID)
}
}
func TestSampleParsesPreviewAuto(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "imagen.yaml")
if err := os.WriteFile(path, []byte(Sample), 0o644); err != nil {
t.Fatalf("write sample: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.Output.Preview != "auto" {
t.Errorf("Output.Preview = %q, want auto", cfg.Output.Preview)
}
}
func TestExpandPath(t *testing.T) {
home, _ := os.UserHomeDir()
cases := map[string]string{


@@ -35,6 +35,13 @@ type Inputs struct {
type Outputs struct {
ImagePath string
SidecarPath string
// Date is the YYYY-MM-DD the writer used for the filename. Cloud sync
// reuses this so storage_path matches the local filename's date.
Date string
// Slug is the filename-safe prompt fragment the writer used.
Slug string
// Seed is the seed value baked into the filename.
Seed int64
}
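The three new fields exist so that cloud sync can rebuild a storage key that lines up with the local filename. A minimal sketch, assuming the `date/slug-seed.ext` layout that the cloud tests expect (for example `2026-05-11/lighthouse-42.png`); `storagePath` is a hypothetical helper, not a function from this change.

```go
package main

import (
	"fmt"
	"strings"
)

// storagePath builds the object key from the values the writer already
// computed. The layout is an assumption taken from the test expectations.
func storagePath(date, slug string, seed int64, ext string) string {
	return fmt.Sprintf("%s/%s-%d.%s", date, slug, seed, strings.TrimPrefix(ext, "."))
}

func main() {
	fmt.Println(storagePath("2026-05-11", "lighthouse", 42, "png"))
	fmt.Println(storagePath("2026-05-11", "x", 9, ".png"))
}
```

Reusing the writer's date instead of calling `time.Now()` again avoids a mismatch when a generation straddles midnight.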
// Write streams img to disk and, if enabled, writes a sidecar. The image
@@ -50,10 +57,12 @@ func (w *Writer) Write(img io.Reader, in Inputs) (*Outputs, error) {
if tmpl == "" {
tmpl = "{date}-{slug}-{seed}.{ext}"
}
date := now.Format("2006-01-02")
slug := Slug(in.Prompt)
name := renderTemplate(tmpl, map[string]string{
"date": now.Format("2006-01-02"),
"date": date,
"time": now.Format("150405"),
"slug": Slug(in.Prompt),
"slug": slug,
"seed": fmt.Sprintf("%d", in.Seed),
"backend": in.Backend,
"ext": strings.TrimPrefix(ext, "."),
@@ -80,7 +89,7 @@ func (w *Writer) Write(img io.Reader, in Inputs) (*Outputs, error) {
return nil, fmt.Errorf("close %s: %w", imagePath, err)
}
out := &Outputs{ImagePath: imagePath}
out := &Outputs{ImagePath: imagePath, Date: date, Slug: slug, Seed: in.Seed}
if w.WriteSidecar {
sidecar := imagePath + ".json"
@@ -122,7 +131,7 @@ func (w *Writer) WriteToPath(img io.Reader, path string, in Inputs) (*Outputs, e
if err := f.Close(); err != nil {
return nil, fmt.Errorf("close %s: %w", path, err)
}
out := &Outputs{ImagePath: path}
out := &Outputs{ImagePath: path, Date: now.Format("2006-01-02"), Slug: Slug(in.Prompt), Seed: in.Seed}
if w.WriteSidecar {
sidecar := path + ".json"
body := map[string]any{

119
internal/preview/tmux.go Normal file

@@ -0,0 +1,119 @@
// Package preview opens a tmux window showing a generated image via tmux-img.
// Mode resolution and the actual spawn are kept separate so the CLI can
// decide-then-act and tests can drive each half independently.
package preview
import (
"errors"
"fmt"
"os/exec"
"strings"
)
// Mode is the tri-state preview setting: auto (default), on (force), off.
type Mode string
const (
ModeAuto Mode = "auto"
ModeOn Mode = "on"
ModeOff Mode = "off"
)
// ParseMode normalises a string into a Mode. Empty parses to ModeAuto so
// callers can pass through unset config / env values.
func ParseMode(s string) (Mode, error) {
switch strings.ToLower(strings.TrimSpace(s)) {
case "", "auto":
return ModeAuto, nil
case "on":
return ModeOn, nil
case "off":
return ModeOff, nil
}
return "", fmt.Errorf("invalid preview mode %q (auto|on|off)", s)
}
// Decision is the answer to "should we preview, and why".
type Decision struct {
ShouldPreview bool
Reason string
}
// Resolve maps (mode, runtime context) to a Decision.
//
// - off -> never preview
// - on -> preview, but error if not in tmux (forced on outside tmux)
// - auto -> preview iff inTmux && stdoutTTY
func Resolve(mode Mode, inTmux, stdoutTTY bool) (Decision, error) {
switch mode {
case ModeOff:
return Decision{ShouldPreview: false, Reason: "preview=off"}, nil
case ModeOn:
if !inTmux {
return Decision{}, ErrNoTmuxForced
}
return Decision{ShouldPreview: true, Reason: "preview=on"}, nil
case ModeAuto, "":
if !inTmux {
return Decision{ShouldPreview: false, Reason: "auto: $TMUX unset"}, nil
}
if !stdoutTTY {
return Decision{ShouldPreview: false, Reason: "auto: stdout not a tty"}, nil
}
return Decision{ShouldPreview: true, Reason: "auto"}, nil
}
return Decision{}, fmt.Errorf("invalid preview mode %q", mode)
}
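Resolve takes an already-chosen mode. The config comments say that `$IMAGEN_PREVIEW` and the `--preview`/`--no-preview` flags override the YAML value in turn; a minimal sketch of that layering follows. `effectiveMode` is a hypothetical helper, and two details are assumptions: that `--preview` maps to "on" and `--no-preview` to "off", and that `--no-preview` wins if both are given.

```go
package main

import "fmt"

// effectiveMode layers the three sources from weakest to strongest:
// config file, then environment, then command-line flags.
func effectiveMode(cfg, env string, flagPreview, flagNoPreview bool) string {
	mode := "auto" // default when nothing is set
	if cfg != "" {
		mode = cfg
	}
	if env != "" {
		mode = env
	}
	if flagPreview {
		mode = "on"
	}
	if flagNoPreview {
		mode = "off" // assumption: the negative flag wins on conflict
	}
	return mode
}

func main() {
	fmt.Println(effectiveMode("", "", false, false))
	fmt.Println(effectiveMode("off", "on", false, false))
	fmt.Println(effectiveMode("off", "", true, false))
	fmt.Println(effectiveMode("on", "on", false, true))
}
```

The returned string would still go through `ParseMode` before being handed to `Resolve`, so a bad value in the environment is reported as an error and not treated as "auto".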
// Errors returned by Spawn and Resolve. Each names the missing piece and,
// where relevant, where to install it.
var (
ErrTmuxMissing = errors.New("tmux: binary not found on $PATH (required for image preview)")
ErrTmuxImgMissing = errors.New("tmux-img: binary not found on $PATH (install at ~/.local/bin/tmux-img)")
ErrNoTmuxForced = errors.New("--preview requires $TMUX (are you in a tmux session?)")
)
// Spawner spawns the tmux preview window. The exec.LookPath / cmd.Run hooks
// exist so tests can inject fakes without touching $PATH.
type Spawner struct {
LookPath func(string) (string, error)
Run func(*exec.Cmd) error
}
// Spawn opens a new tmux window named img:<slug> running tmux-img --hold
// <imagePath>. -d keeps focus in the current pane. Caller is expected to
// have already verified that we are inside a tmux session.
func (s *Spawner) Spawn(imagePath, slug string) error {
look := s.LookPath
if look == nil {
look = exec.LookPath
}
run := s.Run
if run == nil {
run = func(c *exec.Cmd) error { return c.Run() }
}
tmuxBin, err := look("tmux")
if err != nil {
return ErrTmuxMissing
}
tmuxImgBin, err := look("tmux-img")
if err != nil {
return ErrTmuxImgMissing
}
name := "img:" + slug
shellCmd := fmt.Sprintf("%s --hold %s",
shellQuote(tmuxImgBin), shellQuote(imagePath))
cmd := exec.Command(tmuxBin, "new-window", "-d", "-n", name, shellCmd)
if err := run(cmd); err != nil {
return fmt.Errorf("tmux new-window: %w", err)
}
return nil
}
// shellQuote single-quotes s for /bin/sh — tmux passes the trailing arg of
// new-window through a shell.
func shellQuote(s string) string {
return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}
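To see what the quoting does to an awkward path, the function can be run on its own. The sketch copies `shellQuote` from above and prints the result for the path used in the tests; every embedded apostrophe becomes the four-character sequence `'\''` (close quote, escaped quote, reopen quote).

```go
package main

import (
	"fmt"
	"strings"
)

// shellQuote wraps s in single quotes for /bin/sh and escapes embedded
// single quotes with the '\'' idiom.
func shellQuote(s string) string {
	return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}

func main() {
	fmt.Println(shellQuote("/tmp/imagen/cat.png"))
	fmt.Println(shellQuote("/tmp/imagen/o'malley's cat.png"))
}
```

Inside single quotes the shell treats everything literally, so spaces, `$` and backticks in a filename need no further handling.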


@@ -0,0 +1,170 @@
package preview
import (
"errors"
"os/exec"
"strings"
"testing"
)
func TestParseMode(t *testing.T) {
cases := map[string]Mode{
"": ModeAuto,
"auto": ModeAuto,
"AUTO": ModeAuto,
"on": ModeOn,
" on ": ModeOn,
"off": ModeOff,
}
for in, want := range cases {
got, err := ParseMode(in)
if err != nil {
t.Errorf("ParseMode(%q) err = %v", in, err)
continue
}
if got != want {
t.Errorf("ParseMode(%q) = %q, want %q", in, got, want)
}
}
if _, err := ParseMode("nope"); err == nil {
t.Errorf("ParseMode(nope) should have errored")
}
}
func TestResolve(t *testing.T) {
type tc struct {
mode Mode
inTmux bool
stdoutTTY bool
want bool
wantErr error
}
cases := map[string]tc{
"off-anywhere": {ModeOff, false, false, false, nil},
"off-in-tmux-tty": {ModeOff, true, true, false, nil},
"on-in-tmux": {ModeOn, true, false, true, nil},
"on-outside-tmux-errs": {ModeOn, false, true, false, ErrNoTmuxForced},
"auto-no-tmux": {ModeAuto, false, true, false, nil},
"auto-tmux-no-tty": {ModeAuto, true, false, false, nil},
"auto-tmux-and-tty": {ModeAuto, true, true, true, nil},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
d, err := Resolve(c.mode, c.inTmux, c.stdoutTTY)
if c.wantErr != nil {
if !errors.Is(err, c.wantErr) {
t.Fatalf("err = %v, want %v", err, c.wantErr)
}
return
}
if err != nil {
t.Fatalf("err = %v", err)
}
if d.ShouldPreview != c.want {
t.Errorf("ShouldPreview = %v, want %v (reason: %s)", d.ShouldPreview, c.want, d.Reason)
}
})
}
}
func TestSpawn_BuildsCorrectCommand(t *testing.T) {
var captured *exec.Cmd
s := &Spawner{
LookPath: func(name string) (string, error) {
switch name {
case "tmux":
return "/usr/bin/tmux", nil
case "tmux-img":
return "/home/m/.local/bin/tmux-img", nil
}
return "", exec.ErrNotFound
},
Run: func(c *exec.Cmd) error {
captured = c
return nil
},
}
if err := s.Spawn("/tmp/imagen/cat.png", "cat-in-a-fishbowl"); err != nil {
t.Fatalf("Spawn: %v", err)
}
if captured == nil {
t.Fatal("Run was not called")
}
if captured.Path != "/usr/bin/tmux" {
t.Errorf("Path = %q, want /usr/bin/tmux", captured.Path)
}
args := captured.Args
if len(args) < 6 {
t.Fatalf("args = %v (need at least 6)", args)
}
// tmux new-window -d -n img:<slug> '<shell-cmd>'
if args[1] != "new-window" {
t.Errorf("args[1] = %q, want new-window", args[1])
}
if args[2] != "-d" {
t.Errorf("args[2] = %q, want -d", args[2])
}
if args[3] != "-n" {
t.Errorf("args[3] = %q, want -n", args[3])
}
if args[4] != "img:cat-in-a-fishbowl" {
t.Errorf("args[4] = %q, want img:cat-in-a-fishbowl", args[4])
}
shellCmd := args[5]
if !strings.Contains(shellCmd, "tmux-img") || !strings.Contains(shellCmd, "--hold") || !strings.Contains(shellCmd, "/tmp/imagen/cat.png") {
t.Errorf("shell cmd %q missing expected pieces", shellCmd)
}
}
func TestSpawn_PathWithSpacesAndQuotes(t *testing.T) {
var captured *exec.Cmd
s := &Spawner{
LookPath: func(name string) (string, error) {
if name == "tmux" {
return "/usr/bin/tmux", nil
}
if name == "tmux-img" {
return "/usr/local/bin/tmux-img", nil
}
return "", exec.ErrNotFound
},
Run: func(c *exec.Cmd) error { captured = c; return nil },
}
weird := "/tmp/imagen/o'malley's cat.png"
if err := s.Spawn(weird, "slug"); err != nil {
t.Fatalf("Spawn: %v", err)
}
shellCmd := captured.Args[5]
// Single-quoted with the embedded apostrophe escaped via the
// '\'' shell idiom — confirm we did not just splice the raw path.
if strings.Contains(shellCmd, "o'malley's") {
t.Errorf("shell cmd %q contains unescaped apostrophes", shellCmd)
}
}
func TestSpawn_MissingTmux(t *testing.T) {
s := &Spawner{
LookPath: func(string) (string, error) { return "", exec.ErrNotFound },
Run: func(*exec.Cmd) error { return nil },
}
err := s.Spawn("/x.png", "s")
if !errors.Is(err, ErrTmuxMissing) {
t.Errorf("err = %v, want ErrTmuxMissing", err)
}
}
func TestSpawn_MissingTmuxImg(t *testing.T) {
s := &Spawner{
LookPath: func(name string) (string, error) {
if name == "tmux" {
return "/usr/bin/tmux", nil
}
return "", exec.ErrNotFound
},
Run: func(*exec.Cmd) error { return nil },
}
err := s.Spawn("/x.png", "s")
if !errors.Is(err, ErrTmuxImgMissing) {
t.Errorf("err = %v, want ErrTmuxImgMissing", err)
}
}


@@ -37,7 +37,7 @@ func TestApplyToEmptyPromptUsesPresetOnly(t *testing.T) {
}
func TestStylesContainsAllExpected(t *testing.T) {
want := []string{"blog-header", "diagram", "illustration", "photo", "sketch"}
want := []string{"3d-render", "anime", "blog-header", "cinematic", "diagram", "illustration", "isometric", "line-art", "photo", "sketch", "watercolor"}
got := Styles()
if len(got) != len(want) {
t.Fatalf("Styles() = %v, want %v", got, want)


@@ -4,3 +4,9 @@ styles:
diagram: "minimal technical diagram, isometric, white background, line-art"
sketch: "rough pencil sketch, hand-drawn, monochrome"
blog-header: "wide aspect, conceptual, soft palette, editorial illustration"
cinematic: "cinematic still, 35mm film, shallow depth of field, dramatic lighting, color graded"
watercolor: "watercolor painting, soft washes, paper texture, loose brushwork"
anime: "anime illustration, cel-shaded, expressive linework, vibrant flat colors"
3d-render: "3d render, octane, soft global illumination, subtle ambient occlusion, physically based materials"
line-art: "clean line art, black ink on white, no shading, even stroke weight"
isometric: "isometric illustration, 30 degree projection, flat colors, crisp geometric shapes"
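Because the style list is read from this YAML at run time, adding entries here is enough to extend the set of accepted styles. The sketch below illustrates the idea with an in-memory map in place of the embedded file. `styleNames` and `applyStyle` are hypothetical stand-ins for `Styles()` and `Apply()`, and the way a fragment is appended to the prompt (comma-separated) is an assumption.

```go
package main

import (
	"fmt"
	"sort"
)

// presets stands in for the parsed styles.yaml (three entries only).
var presets = map[string]string{
	"sketch":    "rough pencil sketch, hand-drawn, monochrome",
	"anime":     "anime illustration, cel-shaded, expressive linework, vibrant flat colors",
	"3d-render": "3d render, octane, soft global illumination, subtle ambient occlusion, physically based materials",
}

// styleNames returns the preset names in byte order, so names starting
// with a digit sort before names starting with a letter.
func styleNames() []string {
	names := make([]string, 0, len(presets))
	for name := range presets {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// applyStyle appends the preset fragment to the prompt. An empty prompt
// yields the fragment alone; an unknown style is an error.
func applyStyle(prompt, style string) (string, error) {
	frag, ok := presets[style]
	if !ok {
		return "", fmt.Errorf("unknown style %q (known: %v)", style, styleNames())
	}
	if prompt == "" {
		return frag, nil
	}
	return prompt + ", " + frag, nil
}

func main() {
	fmt.Println(styleNames())
	out, err := applyStyle("a tiny lighthouse", "sketch")
	fmt.Println(out, err)
}
```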

160
internal/usage/usage.go Normal file

@@ -0,0 +1,160 @@
// Package usage records per-call cost-tracking rows for the imagen CLI
// to mai.imagen_usage on Supabase. The writer is best-effort by design —
// the calling adapter logs failures and proceeds, because the image
// itself has already landed on disk by the time we record.
package usage
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
"mgit.msbls.de/m/ImaGen/internal/backend"
)
// Default REST schema is the mai schema where mai.imagen_usage lives.
const supabaseSchema = "mai"
// SupabaseSink writes rows via PostgREST. It uses Accept-Profile/
// Content-Profile headers to target the mai schema instead of public.
type SupabaseSink struct {
URL string // SUPABASE_URL — e.g. https://msup.msbls.de
APIKey string // SUPABASE_SERVICE_KEY
HTTP *http.Client
}
// NewSupabaseSinkFromEnv reads SUPABASE_URL and SUPABASE_SERVICE_KEY
// (falling back to MAI_SUPABASE_KEY) and returns a sink ready to use.
// Returns nil + ok=false if the env vars are not configured — the CLI
// uses that to skip cost-tracking gracefully.
func NewSupabaseSinkFromEnv() (*SupabaseSink, bool) {
u := strings.TrimRight(os.Getenv("SUPABASE_URL"), "/")
if u == "" {
return nil, false
}
key := os.Getenv("SUPABASE_SERVICE_KEY")
if key == "" {
key = os.Getenv("MAI_SUPABASE_KEY")
}
if key == "" {
return nil, false
}
return &SupabaseSink{
URL: u,
APIKey: key,
HTTP: &http.Client{Timeout: 10 * time.Second},
}, true
}
type supabaseRow struct {
Backend string `json:"backend"`
Model string `json:"model"`
Seed *int64 `json:"seed,omitempty"`
PromptHash string `json:"prompt_hash"`
LatencyMs int `json:"latency_ms"`
CostUSDEstimate *float64 `json:"cost_usd_estimate,omitempty"`
Caller string `json:"caller,omitempty"`
}
// Record inserts one row into mai.imagen_usage.
func (s *SupabaseSink) Record(ctx context.Context, row backend.UsageRow) error {
body, err := json.Marshal(supabaseRow{
Backend: row.Backend,
Model: row.Model,
Seed: row.Seed,
PromptHash: row.PromptHash,
LatencyMs: row.LatencyMs,
CostUSDEstimate: row.CostUSDEstimate,
Caller: row.Caller,
})
if err != nil {
return fmt.Errorf("usage: marshal: %w", err)
}
endpoint := s.URL + "/rest/v1/imagen_usage"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("usage: build request: %w", err)
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Profile", supabaseSchema)
req.Header.Set("Content-Profile", supabaseSchema)
req.Header.Set("Prefer", "return=minimal")
resp, err := s.HTTP.Do(req)
if err != nil {
return fmt.Errorf("usage: POST: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("usage: POST %d: %s", resp.StatusCode, snip(respBody))
}
return nil
}
// Row is the read-side row shape (only the fields the CLI needs).
type Row struct {
CreatedAt time.Time `json:"created_at"`
Backend string `json:"backend"`
Model string `json:"model"`
Seed *int64 `json:"seed"`
PromptHash string `json:"prompt_hash"`
LatencyMs *int `json:"latency_ms"`
CostUSDEstimate *float64 `json:"cost_usd_estimate"`
Caller *string `json:"caller"`
}
// Query returns rows from mai.imagen_usage filtered by created_at >= since,
// newest first. Pass the zero time to skip the filter. Results are capped at
// 5000 rows here, in addition to any server-side PostgREST limit.
func (s *SupabaseSink) Query(ctx context.Context, since time.Time) ([]Row, error) {
q := url.Values{}
q.Set("select", "created_at,backend,model,seed,prompt_hash,latency_ms,cost_usd_estimate,caller")
q.Set("order", "created_at.desc")
q.Set("limit", "5000")
if !since.IsZero() {
q.Set("created_at", "gte."+since.UTC().Format(time.RFC3339))
}
endpoint := s.URL + "/rest/v1/imagen_usage?" + q.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("usage: build request: %w", err)
}
req.Header.Set("apikey", s.APIKey)
req.Header.Set("Authorization", "Bearer "+s.APIKey)
req.Header.Set("Accept-Profile", supabaseSchema)
resp, err := s.HTTP.Do(req)
if err != nil {
return nil, fmt.Errorf("usage: GET: %w", err)
}
defer resp.Body.Close()
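body, readErr := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("usage: GET %d: %s", resp.StatusCode, snip(body))
}
// A truncated body would otherwise surface as a confusing JSON parse error.
if readErr != nil {
return nil, fmt.Errorf("usage: read body: %w", readErr)
}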
var rows []Row
if err := json.Unmarshal(body, &rows); err != nil {
return nil, fmt.Errorf("usage: parse rows: %w (body: %s)", err, snip(body))
}
return rows, nil
}
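// snip trims b and caps it at 500 bytes for use in error messages.
func snip(b []byte) string {
const limit = 500
s := strings.TrimSpace(string(b))
if len(s) > limit {
// Cut at the byte limit, then drop any partial trailing rune so the
// snippet stays valid UTF-8.
s = strings.ToValidUTF8(s[:limit], "") + "..."
}
return s
}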

internal/worker/worker.go Normal file

@@ -0,0 +1,217 @@
// Package worker consumes the imagen.jobs queue. It claims pending rows via
// an UPDATE-returning lock (single source of truth, no double-claim window),
// runs the supplied generation pipeline, then writes status + image_id back.
//
// The package is DB-agnostic: it talks to two small interfaces (Queue +
// Pipeline) so unit tests can drive the claim/transition logic with no real
// Postgres connection. cmd/imagen wires the pgx implementation.
package worker
import (
"context"
"errors"
"fmt"
"sync"
"time"
)
// Job is the slice of an imagen.jobs row the worker needs to drive a
// generation. Null columns from the DB are represented as zero values; the
// pipeline treats zero values as "use backend default" (same convention as
// backend.Request).
type Job struct {
ID string
OwnerUserID string
Prompt string
Backend string
Model string
Width int
Height int
Steps int
Seed int64
Style string
// SeriesID is the parent imagen.series row when this job is one of N
// tries in a batch. Empty means a solo run — the pipeline must not
// propagate a series_id onto the resulting imagen.images row.
SeriesID string
}
// Outcome is what the pipeline reports back per job. ImageID is the
// imagen.images.id the cloud-sync produced. Empty ImageID with nil Err means
// the cloud-sync was skipped (config off) — we treat that as a failure for
// the worker since flexsiebels needs the image_id to render the result.
type Outcome struct {
ImageID string
Err error
}
// Queue is the persistence layer for the imagen.jobs table. Implementations
// must be safe for serialised single-worker use (concurrent claim across
// multiple worker processes is out of scope for v1 — the FOR UPDATE SKIP
// LOCKED clause in the pgx claim query covers it cheaply anyway).
type Queue interface {
// ClaimNextPending atomically marks the oldest pending row 'running' and
// returns it. Returns (nil, nil) when the queue is empty.
ClaimNextPending(ctx context.Context) (*Job, error)
// MarkDone records success: status='done', image_id, completed_at=now().
MarkDone(ctx context.Context, jobID, imageID string) error
// MarkFailed records failure: status='failed', error=msg, completed_at=now().
MarkFailed(ctx context.Context, jobID, errMsg string) error
// WaitForJob blocks until either a NOTIFY arrives on imagen_jobs, the
// timeout expires, or ctx is cancelled. Returns nil on notification or
// timeout; returns ctx.Err() on cancellation. Transient connection errors
// are returned so the caller can decide to reconnect.
WaitForJob(ctx context.Context, timeout time.Duration) error
// ResetStaleRunning marks any rows stuck in 'running' (e.g. left over
// from a crash before this process started) back to 'pending'. Called
// once at worker startup so the cold-start safety poll can pick them up.
ResetStaleRunning(ctx context.Context) error
}
// Pipeline runs one generation and reports back the imagen.images.id (or an
// error). The implementation owns backend dispatch, prompt enrichment, disk
// write, and cloud-sync; the worker only orchestrates queue state.
type Pipeline interface {
Run(ctx context.Context, job Job) Outcome
}
// Config is the runtime knob set for the worker loop.
type Config struct {
// PollInterval is the safety-poll cadence between LISTEN wakeups. Picking
// this too low wastes DB roundtrips; too high lets a dropped NOTIFY
// stall the queue. 5s is the spec'd default.
PollInterval time.Duration
// JobTimeout caps any single Pipeline.Run. A backend hang shouldn't
// freeze the queue forever.
JobTimeout time.Duration
// Logger receives one-line status events. nil means silent.
Logger func(format string, args ...any)
}
// Worker is the orchestration loop. It is not reusable across Run calls.
type Worker struct {
q Queue
p Pipeline
cfg Config
// processingMu guards the in-flight job so SIGTERM-triggered shutdown
// waits for it to complete before returning.
processingMu sync.Mutex
}
// New constructs a Worker.
func New(q Queue, p Pipeline, cfg Config) *Worker {
if cfg.PollInterval <= 0 {
cfg.PollInterval = 5 * time.Second
}
if cfg.JobTimeout <= 0 {
cfg.JobTimeout = 5 * time.Minute
}
return &Worker{q: q, p: p, cfg: cfg}
}
// Run drives the consume loop until ctx is cancelled, then returns nil.
// Queue errors from drain and WaitForJob are treated as transient: the
// worker logs them and continues, so a temporary network blip doesn't take
// it down.
func (w *Worker) Run(ctx context.Context) error {
if err := w.q.ResetStaleRunning(ctx); err != nil {
w.log("worker: reset stale running rows: %v", err)
// Don't return — a stale row will eventually be visible to the poll
// path once flexsiebels gives up and resubmits, and we'd rather keep
// serving fresh jobs than crash here.
}
for {
if err := ctx.Err(); err != nil {
return nil
}
// Drain the queue: claim and process until empty. Errors that only
// reflect shutdown (ctx already done) are not logged.
if err := w.drain(ctx); err != nil && ctx.Err() == nil {
w.log("worker: drain: %v", err)
}
if err := ctx.Err(); err != nil {
return nil
}
// Wait for the next wake. WaitForJob covers both LISTEN and the
// timeout-based poll fallback; either returns nil and we loop.
if err := w.q.WaitForJob(ctx, w.cfg.PollInterval); err != nil {
// Covers cancellation and deadline expiry of the outer ctx alike.
if ctx.Err() != nil || errors.Is(err, context.Canceled) {
return nil
}
w.log("worker: wait: %v (continuing)", err)
// Pace the retries so a totally-broken DB doesn't busy-spin.
select {
case <-ctx.Done():
return nil
case <-time.After(w.cfg.PollInterval):
}
}
}
}
// drain claims and processes every currently-pending job. The job-scoped
// context is derived from context.Background() so that a SIGTERM mid-job
// still lets the pipeline finish — that's the "no half-state on shutdown"
// guarantee the issue calls for.
func (w *Worker) drain(ctx context.Context) error {
for {
if err := ctx.Err(); err != nil {
return err
}
job, err := w.q.ClaimNextPending(ctx)
if err != nil {
return fmt.Errorf("claim: %w", err)
}
if job == nil {
return nil
}
w.processOne(*job)
}
}
// processOne runs the pipeline for one already-claimed job and writes the
// outcome back to the queue. The job context is independent of the outer
// ctx so an in-flight job can finish even after SIGTERM.
func (w *Worker) processOne(job Job) {
w.processingMu.Lock()
defer w.processingMu.Unlock()
w.log("worker: processing job %s backend=%s", job.ID, job.Backend)
jobCtx, cancel := context.WithTimeout(context.Background(), w.cfg.JobTimeout)
defer cancel()
out := w.p.Run(jobCtx, job)
// Status-update uses Background ctx with a short timeout — we must
// always be able to record the outcome, otherwise the row sits in
// 'running' forever.
updCtx, updCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer updCancel()
if out.Err != nil {
w.log("worker: job %s failed: %v", job.ID, out.Err)
if err := w.q.MarkFailed(updCtx, job.ID, out.Err.Error()); err != nil {
w.log("worker: mark failed for %s: %v", job.ID, err)
}
return
}
if out.ImageID == "" {
// Pipeline reported success but no imagen.images row — treat as
// failure because flexsiebels has nothing to link.
const msg = "pipeline did not return an imagen.images id (cloud sync misconfigured?)"
w.log("worker: job %s: %s", job.ID, msg)
if err := w.q.MarkFailed(updCtx, job.ID, msg); err != nil {
w.log("worker: mark failed for %s: %v", job.ID, err)
}
return
}
if err := w.q.MarkDone(updCtx, job.ID, out.ImageID); err != nil {
w.log("worker: mark done for %s: %v", job.ID, err)
return
}
w.log("worker: job %s done image_id=%s", job.ID, out.ImageID)
}
func (w *Worker) log(format string, args ...any) {
if w.cfg.Logger != nil {
w.cfg.Logger(format, args...)
}
}


@@ -0,0 +1,376 @@
package worker
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
)
// fakeQueue is a hand-rolled in-memory queue that mirrors the contract of a
// real Postgres-backed implementation: ClaimNextPending atomically takes one
// pending row and flips its status to "running", MarkDone/MarkFailed are
// idempotent terminal transitions, WaitForJob blocks until notified or until
// the timeout elapses.
type fakeQueue struct {
mu sync.Mutex
pending []Job
state map[string]string // jobID -> status
last map[string]string // jobID -> error msg or image_id
notify chan struct{}
claimErr error
doneErr error
failErr error
resetErr error
claimed int
done int
failed int
resets int
}
func newFakeQueue(jobs ...Job) *fakeQueue {
q := &fakeQueue{
state: make(map[string]string),
last: make(map[string]string),
notify: make(chan struct{}, 16),
}
for _, j := range jobs {
q.pending = append(q.pending, j)
q.state[j.ID] = "pending"
}
return q
}
func (q *fakeQueue) ClaimNextPending(ctx context.Context) (*Job, error) {
q.mu.Lock()
defer q.mu.Unlock()
if q.claimErr != nil {
return nil, q.claimErr
}
if len(q.pending) == 0 {
return nil, nil
}
j := q.pending[0]
q.pending = q.pending[1:]
q.state[j.ID] = "running"
q.claimed++
return &j, nil
}
func (q *fakeQueue) MarkDone(ctx context.Context, jobID, imageID string) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.doneErr != nil {
return q.doneErr
}
q.state[jobID] = "done"
q.last[jobID] = imageID
q.done++
return nil
}
func (q *fakeQueue) MarkFailed(ctx context.Context, jobID, msg string) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.failErr != nil {
return q.failErr
}
q.state[jobID] = "failed"
q.last[jobID] = msg
q.failed++
return nil
}
func (q *fakeQueue) WaitForJob(ctx context.Context, timeout time.Duration) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-q.notify:
return nil
case <-time.After(timeout):
return nil
}
}
func (q *fakeQueue) ResetStaleRunning(ctx context.Context) error {
q.mu.Lock()
defer q.mu.Unlock()
q.resets++
return q.resetErr
}
// pingNotify simulates an INSERT-trigger NOTIFY by waking WaitForJob.
func (q *fakeQueue) pingNotify() {
select {
case q.notify <- struct{}{}:
default:
}
}
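
Review note: pingNotify uses a non-blocking send so a burst of notifications never stalls the sender. The sketch below isolates that pattern. It is a variation, not the fake's code: the notifier type is hypothetical and uses a buffer of 1 instead of 16, which coalesces a burst into a single pending wake-up.

```go
package main

import "fmt"

// notifier (hypothetical) coalesces wake-ups in a capacity-1 channel.
type notifier struct{ ch chan struct{} }

func newNotifier() *notifier { return &notifier{ch: make(chan struct{}, 1)} }

// ping never blocks. It reports true if a wake-up was queued and false if
// one was already pending, in which case the new ping is dropped.
func (n *notifier) ping() bool {
	select {
	case n.ch <- struct{}{}:
		return true
	default:
		return false
	}
}

func main() {
	n := newNotifier()
	fmt.Println(n.ping(), n.ping()) // true false
	<-n.ch                          // the waiter consumes the pending wake-up
	fmt.Println(n.ping())           // true
}
```

Dropping the extra ping is safe for a drain-style consumer because one wake-up causes it to process everything that is pending.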
// fakePipeline is a stub Pipeline that returns a canned Outcome per job ID.
type fakePipeline struct {
mu sync.Mutex
results map[string]Outcome // by job.ID; "" key = default outcome
calls int
delay time.Duration
lastJob Job
}
func (p *fakePipeline) Run(ctx context.Context, job Job) Outcome {
p.mu.Lock()
p.calls++
p.lastJob = job
delay := p.delay
out, ok := p.results[job.ID]
if !ok {
out = p.results[""]
}
p.mu.Unlock()
if delay > 0 {
select {
case <-ctx.Done():
return Outcome{Err: ctx.Err()}
case <-time.After(delay):
}
}
return out
}
func TestWorker_DonePath(t *testing.T) {
q := newFakeQueue(
Job{ID: "j1", Prompt: "a", Backend: "mock"},
)
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img-1"}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(80 * time.Millisecond)
cancel()
}()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run: %v", err)
}
if got := q.state["j1"]; got != "done" {
t.Fatalf("state=%q want done", got)
}
if got := q.last["j1"]; got != "img-1" {
t.Fatalf("image_id=%q want img-1", got)
}
if q.done != 1 || q.failed != 0 {
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
}
if p.calls != 1 {
t.Fatalf("pipeline calls=%d want 1", p.calls)
}
if q.resets != 1 {
t.Fatalf("ResetStaleRunning calls=%d want 1", q.resets)
}
}
func TestWorker_FailedPath_RecordsErrorText(t *testing.T) {
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
p := &fakePipeline{results: map[string]Outcome{"j1": {Err: errors.New("backend unreachable")}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
if got := q.state["j1"]; got != "failed" {
t.Fatalf("state=%q want failed", got)
}
if got := q.last["j1"]; got != "backend unreachable" {
t.Fatalf("error=%q want %q", got, "backend unreachable")
}
if q.done != 0 || q.failed != 1 {
t.Fatalf("counts: done=%d failed=%d", q.done, q.failed)
}
}
func TestWorker_MissingImageID_TreatedAsFailure(t *testing.T) {
q := newFakeQueue(Job{ID: "j1", Prompt: "a", Backend: "mock"})
// Outcome has neither Err nor ImageID — pipeline silently swallowed
// cloud-sync. flexsiebels needs the image_id; without it, fail the job.
p := &fakePipeline{results: map[string]Outcome{"j1": {}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
if got := q.state["j1"]; got != "failed" {
t.Fatalf("state=%q want failed", got)
}
if q.last["j1"] == "" {
t.Fatalf("expected non-empty error explanation for missing image_id")
}
}
func TestWorker_DrainsMultipleBeforeWaiting(t *testing.T) {
q := newFakeQueue(
Job{ID: "j1", Backend: "mock"},
Job{ID: "j2", Backend: "mock"},
Job{ID: "j3", Backend: "mock"},
)
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
w := New(q, p, Config{PollInterval: 200 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(60 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
for _, id := range []string{"j1", "j2", "j3"} {
if got := q.state[id]; got != "done" {
t.Fatalf("%s state=%q want done", id, got)
}
}
if q.done != 3 {
t.Fatalf("done=%d want 3", q.done)
}
}
func TestWorker_NotifyWakesEarlierThanPoll(t *testing.T) {
q := newFakeQueue()
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
// Set the poll interval high so a prompt pickup depends on the NOTIFY
// wake-up rather than the safety poll. If the wake-up were lost after the
// first drain, the worker would sleep for the full 5s and miss the 500ms
// deadline below.
w := New(q, p, Config{PollInterval: 5 * time.Second, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
go func() {
_ = w.Run(ctx)
close(done)
}()
// Append a job and ping the wake channel.
q.mu.Lock()
q.pending = append(q.pending, Job{ID: "late", Backend: "mock"})
q.state["late"] = "pending"
q.mu.Unlock()
q.pingNotify()
// Give the worker a beat to claim + process.
deadline := time.Now().Add(500 * time.Millisecond)
for time.Now().Before(deadline) {
q.mu.Lock()
s := q.state["late"]
q.mu.Unlock()
if s == "done" {
cancel()
<-done
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("worker did not pick up the late job within the 500ms window — NOTIFY wake-up path is broken")
}
func TestWorker_HonoursContextCancellation(t *testing.T) {
q := newFakeQueue()
p := &fakePipeline{results: map[string]Outcome{"": {ImageID: "img"}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
start := time.Now()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run: %v", err)
}
if dur := time.Since(start); dur > 200*time.Millisecond {
t.Fatalf("worker did not exit promptly on ctx cancel: %v", dur)
}
}
func TestWorker_InflightJobFinishesAfterShutdown(t *testing.T) {
q := newFakeQueue(Job{ID: "long", Backend: "mock"})
p := &fakePipeline{
results: map[string]Outcome{"long": {ImageID: "img-long"}},
delay: 120 * time.Millisecond,
}
// Short JobTimeout would also kill the in-flight job; give it enough
// budget so the test exercises the shutdown-during-job path.
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: 5 * time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() {
// Let the job start, then cancel mid-flight.
time.Sleep(30 * time.Millisecond)
cancel()
}()
_ = w.Run(ctx)
if got := q.state["long"]; got != "done" {
t.Fatalf("state=%q want done (in-flight job should finish even on shutdown)", got)
}
}
// TestWorker_PropagatesSeriesIDToPipeline verifies the worker hands the
// Job's SeriesID through to the pipeline unchanged. The pipeline owns the
// cloud-sync side of the propagation (cloud.SyncRequest.SeriesID lands on
// imagen.images.series_id) — see cloud_test.go for that half — so the
// worker contract is simply: don't drop or rewrite SeriesID between
// claim and Run.
func TestWorker_PropagatesSeriesIDToPipeline(t *testing.T) {
const seriesID = "11111111-1111-1111-1111-111111111111"
q := newFakeQueue(Job{
ID: "j-series",
Prompt: "p",
Backend: "mock",
SeriesID: seriesID,
})
p := &fakePipeline{results: map[string]Outcome{"j-series": {ImageID: "img-series"}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run: %v", err)
}
if got := p.lastJob.SeriesID; got != seriesID {
t.Fatalf("pipeline saw SeriesID=%q want %q", got, seriesID)
}
if got := q.state["j-series"]; got != "done" {
t.Fatalf("state=%q want done", got)
}
}
// TestWorker_SoloJobLeavesSeriesIDEmpty is the negative case — a job
// claimed with no series row keeps the field empty all the way to the
// pipeline so cloud-sync writes NULL into imagen.images.series_id.
func TestWorker_SoloJobLeavesSeriesIDEmpty(t *testing.T) {
q := newFakeQueue(Job{ID: "j-solo", Prompt: "p", Backend: "mock"})
p := &fakePipeline{results: map[string]Outcome{"j-solo": {ImageID: "img-solo"}}}
w := New(q, p, Config{PollInterval: 10 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
go func() { time.Sleep(80 * time.Millisecond); cancel() }()
_ = w.Run(ctx)
if got := p.lastJob.SeriesID; got != "" {
t.Fatalf("solo job pipeline.lastJob.SeriesID=%q want empty", got)
}
}
func TestWorker_TransientClaimErrorDoesNotKillLoop(t *testing.T) {
// First claim returns an error; the loop should log and try again on the
// next wake — it must not propagate the error and exit.
q := newFakeQueue(Job{ID: "j1", Backend: "mock"})
q.claimErr = fmt.Errorf("transient: connection reset")
p := &fakePipeline{results: map[string]Outcome{"j1": {ImageID: "img"}}}
w := New(q, p, Config{PollInterval: 20 * time.Millisecond, JobTimeout: time.Second})
ctx, cancel := context.WithCancel(context.Background())
// Heal the claim error after a beat so the second drain succeeds.
go func() {
time.Sleep(40 * time.Millisecond)
q.mu.Lock()
q.claimErr = nil
q.mu.Unlock()
}()
go func() {
time.Sleep(200 * time.Millisecond)
cancel()
}()
if err := w.Run(ctx); err != nil {
t.Fatalf("Run returned: %v (transient claim errors should not kill the loop)", err)
}
if got := q.state["j1"]; got != "done" {
t.Fatalf("state=%q want done", got)
}
}


@@ -0,0 +1,22 @@
# Environment for the imagen-worker.service systemd unit.
# Copy to ~/.dotfiles/.env.imagen-worker and fill in real values.
# Never commit the populated file: it carries the Supabase service-role key.

# Direct Postgres DSN for LISTEN/NOTIFY + imagen.jobs UPDATE statements.
# PostgREST cannot LISTEN, so the worker connects to Postgres directly.
# Host + port + password come from the msupabase compose env on mlake.
IMAGEN_WORKER_DATABASE_URL=postgres://postgres:CHANGE_ME@100.99.98.201:6789/postgres?sslmode=disable

# PostgREST endpoint for the imagen.images cloud-sync writer (same as
# `imagen generate`'s cloud-sync code path).
SUPABASE_URL=https://supa.flexsiebels.de
SUPABASE_SERVICE_KEY=CHANGE_ME

# Default owner_user_id. Per-job owner from the imagen.jobs row overrides
# this, so it's only used as a fallback when a job arrives with a NULL
# owner_user_id, which the schema disallows. Keep it set for safety.
IMAGEN_OWNER_USER_ID=ac6c9501-3757-4a6d-8b97-2cff4288382b

# Optional: REPLICATE_API_TOKEN if any imagen.jobs.backend may resolve to
# a Replicate adapter instance.
# REPLICATE_API_TOKEN=CHANGE_ME


@@ -0,0 +1,19 @@
[Unit]
Description=ImaGen worker (consumes imagen.jobs queue)
Documentation=https://mgit.msbls.de/m/ImaGen/issues/8
Wants=network-online.target
After=network-online.target

[Service]
Type=simple
ExecStart=%h/dev/ImaGen/bin/imagen worker
WorkingDirectory=%h/dev/ImaGen
EnvironmentFile=%h/.dotfiles/.env.imagen-worker
Restart=on-failure
RestartSec=5
# Give the worker time to finish an in-flight generation on shutdown
# (FLUX dev up to ~30s, plus the cloud-sync write-back).
TimeoutStopSec=60

[Install]
WantedBy=default.target