Files
ImaGen/cmd/imagen/compare.go
mAi 8435817ce1 mAi: #10 - multi-model backend expansion (workflow templates + compare harness)
Path 1 architecture: one comfyui adapter, workflows as data.

- workflow_template.go: embed.FS + token substitution with type-preserving
  whole-value placeholders. ${prompt} → string, ${seed} → int64,
  ${cfg} → float64 — no JSON round-tripping. Partial matches ignored.
- comfyui.go: refactored to load workflow from embedded FS or filesystem
  path. Back-compat preserved: workflow: defaults to flux1-schnell.
- workflows/{flux1-schnell,flux2-klein,sd35-medium}.json — bundled
  templates. flux1-schnell migrated from hardcoded with identical node IDs.
- compare.go: new `imagen compare` subcommand. Sequential N-backend run
  (one GPU on mRock — parallel would OOM), per-backend PNG, sidecar JSON
  with per-model metadata + errors, composite contact sheet via Go image
  package (no ImageMagick dep).
- Sample config gains flux2-klein-local + sd35-medium-local instances.
- docs/backends.md: architecture rationale + per-model HF download paths
  + how to add a new bundled workflow + compare-harness reference.

Live smoke verified: compare mock + flux-schnell-local at 768×768 →
both PNGs written, sidecar JSON has workflow="flux1-schnell" + full
metadata, contact sheet renders. Worker contract (Request → Generate)
unchanged, so flexsiebels /imagine UI API surface preserved.

Tests: 11 existing comfyui + 6 new workflow_template + 5 new compare
tests, all green.

Adding a new model is now yaml + JSON, never Go.
2026-05-11 17:29:57 +02:00

387 lines
12 KiB
Go

package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"image"
"image/color"
"image/draw"
"image/png"
"io"
"os"
"path/filepath"
"sort"
"strings"
"time"
"golang.org/x/image/font"
"golang.org/x/image/font/basicfont"
"golang.org/x/image/math/fixed"
"mgit.msbls.de/m/ImaGen/internal/backend"
"mgit.msbls.de/m/ImaGen/internal/config"
"mgit.msbls.de/m/ImaGen/internal/output"
"mgit.msbls.de/m/ImaGen/internal/prompt"
)
// runCompare implements `imagen compare "<prompt>" --models a,b,c --output <dir>`.
//
// Each backend in --models runs sequentially against the same prompt (mRock
// has a single GPU; parallelising would just OOM). Each generation lands as
// a backend-suffixed file in a per-run directory under --output. A contact
// sheet stitches the successful ones together into one PNG with the backend
// name overlaid on each cell. A compare.json sidecar in the same run
// directory lists every attempt with its per-model metadata (latency, seed,
// model, VRAM) or the error that stopped it.
//
// A single backend failing does not abort the run. runCompare only returns
// an error for bad flags/arguments, an unusable config, or failures creating
// the run directory, the sidecar, or the contact sheet.
func runCompare(ctx context.Context, args []string) error {
	fs := flag.NewFlagSet("compare", flag.ContinueOnError)
	var (
		modelsCSV  string
		size       string
		outDir     string
		style      string
		negative   string
		seed       int64
		steps      int
		configPath string
		noContact  bool
	)
	fs.StringVar(&modelsCSV, "models", "", "comma-separated backend instance names (required)")
	fs.StringVar(&size, "size", "1024x1024", "WxH for every backend")
	fs.StringVar(&outDir, "output", "", "directory to write the images + contact sheet (default: ~/Pictures/imagen/compare)")
	fs.StringVar(&style, "style", "", "style preset applied to the prompt before dispatching to each backend")
	fs.StringVar(&negative, "negative", "", "negative prompt (forwarded to every backend that supports it)")
	fs.Int64Var(&seed, "seed", 0, "deterministic seed for every backend (0 = each backend rolls its own)")
	fs.IntVar(&steps, "steps", 0, "diffusion steps (0 = each backend's default)")
	fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
	fs.BoolVar(&noContact, "no-contact-sheet", false, "skip the composite PNG; only write per-backend images + sidecar")
	fs.Usage = func() {
		fmt.Fprintln(fs.Output(), `Usage: imagen compare "<prompt>" --models a,b,c [flags]`)
		fs.PrintDefaults()
	}

	leadingPositional, flagArgs := splitLeadingPositional(args)
	if err := fs.Parse(flagArgs); err != nil {
		return err
	}
	// Build a fresh slice. Appending straight onto leadingPositional could
	// write into the caller's args backing array if it is a sub-slice of it.
	positional := make([]string, 0, len(leadingPositional)+fs.NArg())
	positional = append(positional, leadingPositional...)
	positional = append(positional, fs.Args()...)
	if len(positional) == 0 {
		fs.Usage()
		return userErr("missing prompt")
	}
	rawPrompt := strings.Join(positional, " ")

	modelNames := splitCSV(modelsCSV)
	if len(modelNames) == 0 {
		return userErr("--models is required (comma-separated backend instance names)")
	}
	w, h, err := parseSize(size)
	if err != nil {
		return userErr("bad --size: %v", err)
	}

	cfg, cfgErr := config.Load(configPath)
	// A missing config file is tolerated; any other load error is fatal.
	// NOTE(review): os.IsNotExist does not see through wrapped errors. If
	// config.Load wraps with %w, errors.Is(cfgErr, fs.ErrNotExist) is needed.
	if cfgErr != nil && !os.IsNotExist(cfgErr) {
		return cfgErr
	}

	if outDir == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			// Without a home dir the default would silently become a path
			// relative to the working directory; make the user choose.
			return userErr("cannot resolve default --output: %v (pass --output explicitly)", err)
		}
		outDir = filepath.Join(home, "Pictures", "imagen", "compare")
	}
	outDir = config.ExpandPath(outDir)

	finalPrompt, err := prompt.Apply(rawPrompt, style)
	if err != nil {
		return userErr("%v", err)
	}

	// One directory per run: timestamp + prompt slug.
	runID := time.Now().Format("20060102-150405")
	runDir := filepath.Join(outDir, runID+"-"+output.Slug(rawPrompt))
	if err := os.MkdirAll(runDir, 0o755); err != nil {
		return fmt.Errorf("mkdir %s: %w", runDir, err)
	}

	results := make([]compareResult, 0, len(modelNames))
	for i, name := range modelNames {
		// Once the context is cancelled (e.g. Ctrl-C), don't start more GPU
		// work. Still record the skipped backends so the sidecar accounts
		// for every requested model.
		if ctxErr := ctx.Err(); ctxErr != nil {
			results = append(results, compareResult{Backend: name, Error: fmt.Sprintf("skipped: %v", ctxErr)})
			continue
		}
		fmt.Fprintf(os.Stderr, "[%d/%d] %s ...\n", i+1, len(modelNames), name)
		res, err := generateOne(ctx, cfg, name, finalPrompt, negative, w, h, seed, steps, runDir, rawPrompt)
		if err != nil {
			// Don't abort the whole run on a single backend failure — record
			// the error and continue. flexsiebels-style consumers want to
			// see N-1 results rather than zero when one model is offline.
			fmt.Fprintf(os.Stderr, " failed: %v\n", err)
			results = append(results, compareResult{Backend: name, Error: err.Error()})
			continue
		}
		fmt.Fprintf(os.Stderr, " %s (%d ms)\n", res.ImagePath, res.LatencyMs)
		results = append(results, res)
	}

	// The sidecar JSON lives inside the run dir and captures every attempt,
	// failures included.
	sidecar := filepath.Join(runDir, "compare.json")
	if err := writeCompareSidecar(sidecar, rawPrompt, style, negative, w, h, seed, steps, results); err != nil {
		return err
	}
	fmt.Fprintln(os.Stderr, "sidecar:", sidecar)

	// The contact sheet stitches the successful results together. If every
	// backend failed there's nothing to draw, so note it and skip.
	if noContact {
		return nil
	}
	successes := successfulResults(results)
	if len(successes) == 0 {
		fmt.Fprintln(os.Stderr, "imagen compare: all backends failed; no contact sheet written")
		return nil
	}
	sheet := filepath.Join(runDir, "contact-sheet.png")
	if err := writeContactSheet(sheet, rawPrompt, successes); err != nil {
		return fmt.Errorf("contact sheet: %w", err)
	}
	// The sheet path goes to stdout (diagnostics go to stderr) so scripts can capture it.
	fmt.Println(sheet)
	return nil
}
// compareResult records the outcome of one backend in a comparison run. It is
// serialised as one element of the sidecar's "results" array.
//
// For a failed backend only Backend and Error are filled in. Every other field
// keeps its zero value and, except for Seed, is omitted from the JSON.
type compareResult struct {
	// Backend is the instance name as given in --models.
	Backend string `json:"backend"`
	// ImagePath is where the generated image was written (empty on failure).
	ImagePath string `json:"image_path,omitempty"`
	// Seed is always emitted, so an explicit 0 stays visible in the sidecar.
	Seed int64 `json:"seed"`
	// LatencyMs is the generation latency reported in the backend metadata.
	LatencyMs int64 `json:"latency_ms,omitempty"`
	// Model is the model identifier reported in the backend metadata.
	Model string `json:"model,omitempty"`
	// VRAMUsedMiB is the VRAM figure reported in the backend metadata.
	VRAMUsedMiB int64 `json:"vram_used_mib,omitempty"`
	// Metadata is the backend's raw metadata map, passed through verbatim.
	Metadata map[string]any `json:"metadata,omitempty"`
	// Error holds the failure message; empty when generation succeeded.
	Error string `json:"error,omitempty"`
}
// generateOne builds the backend instance called name, runs one generation
// with the shared request parameters, and saves the image into runDir.
//
// On error, the returned compareResult has only Backend set. The caller is
// expected to record err itself. On success, ImagePath points at the written
// file. Seed, LatencyMs, Model and VRAMUsedMiB are lifted out of the
// backend's metadata. Seed goes through seedFromMetadata with the requested
// seed as fallback, presumably for backends that don't report one.
func generateOne(ctx context.Context, cfg *config.Config, name, finalPrompt, negative string, w, h int, seed int64, steps int, runDir, rawPrompt string) (compareResult, error) {
	be, err := buildBackend(cfg, name)
	if err != nil {
		return compareResult{Backend: name}, err
	}
	attachUsageSink(be)
	req := backend.Request{
		Prompt:         finalPrompt,
		NegativePrompt: negative,
		Width:          w,
		Height:         h,
		Steps:          steps,
		Seed:           seed,
	}
	res, err := be.Generate(ctx, req)
	if err != nil {
		return compareResult{Backend: name}, err
	}
	defer res.ImageReader.Close()
	imgBytes, err := io.ReadAll(res.ImageReader)
	if err != nil {
		return compareResult{Backend: name}, fmt.Errorf("read image: %w", err)
	}

	base := output.Slug(rawPrompt) + "--" + output.Slug(name)
	imgPath, err := writeCompareImage(runDir, base, "."+extFromMime(res.MimeType), imgBytes)
	if err != nil {
		return compareResult{Backend: name}, err
	}
	return compareResult{
		Backend:   name,
		ImagePath: imgPath,
		Seed:      seedFromMetadata(res.Metadata, seed),
		LatencyMs: metaInt64(res.Metadata, "latency_ms"),
		Model:     metaString(res.Metadata, "model"),
		// Use the same numeric coercion as latency_ms. The old bare .(int64)
		// assertion silently dropped the value when a backend reported it as
		// int or float64 (e.g. decoded from JSON).
		VRAMUsedMiB: metaInt64(res.Metadata, "vram_used_mib"),
		Metadata:    res.Metadata,
	}, nil
}

// writeCompareImage writes data to dir/<base><ext> and returns the path used.
// If that name is taken, it falls back to dir/<base>-2<ext>, -3, and so on.
// Creating with O_EXCL means an earlier image in the same run is never
// silently overwritten. That earlier image can come from the same backend
// listed twice in --models, or from two names with an identical slug. A
// partially written file is removed.
func writeCompareImage(dir, base, ext string, data []byte) (string, error) {
	for n := 1; ; n++ {
		name := base + ext
		if n > 1 {
			name = fmt.Sprintf("%s-%d%s", base, n, ext)
		}
		path := filepath.Join(dir, name)
		f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o644)
		if os.IsExist(err) {
			continue
		}
		if err != nil {
			return "", fmt.Errorf("write %s: %w", path, err)
		}
		_, werr := f.Write(data)
		// Close can surface a deferred write failure, so it counts too.
		if cerr := f.Close(); werr == nil {
			werr = cerr
		}
		if werr != nil {
			_ = os.Remove(path) // best effort; the write error is what gets reported
			return "", fmt.Errorf("write %s: %w", path, werr)
		}
		return path, nil
	}
}
// successfulResults keeps only the entries that actually produced an image:
// no recorded error and a non-empty ImagePath. Input order is preserved. The
// result is a freshly allocated, never-nil slice, and rs is left untouched.
func successfulResults(rs []compareResult) []compareResult {
	kept := make([]compareResult, 0, len(rs))
	for i := range rs {
		failed := rs[i].Error != "" || rs[i].ImagePath == ""
		if failed {
			continue
		}
		kept = append(kept, rs[i])
	}
	return kept
}
// writeCompareSidecar writes the run summary to path as indented JSON with a
// trailing newline. Besides echoing the request parameters, it embeds every
// result (failures included), a sorted list of backend names, and
// success/total counts. The document is a map, so encoding/json emits its
// top-level keys in sorted order.
func writeCompareSidecar(path, rawPrompt, style, negative string, w, h int, seed int64, steps int, results []compareResult) error {
	summary := make(map[string]any, 12)

	// Request parameters, echoed so the sidecar is self-describing.
	summary["timestamp"] = time.Now().UTC().Format(time.RFC3339)
	summary["prompt"] = rawPrompt
	summary["style"] = style
	summary["negative"] = negative
	summary["width"] = w
	summary["height"] = h
	summary["seed"] = seed
	summary["steps"] = steps

	// Per-backend outcomes plus quick aggregate counts.
	summary["results"] = results
	summary["backends"] = backendNames(results)
	summary["successful"] = len(successfulResults(results))
	summary["total"] = len(results)

	encoded, err := json.MarshalIndent(summary, "", " ")
	if err != nil {
		return fmt.Errorf("marshal sidecar: %w", err)
	}
	encoded = append(encoded, '\n')
	return os.WriteFile(path, encoded, 0o644)
}
// backendNames collects the Backend field of every result, failures
// included, and returns the names sorted lexicographically. The slice is
// never nil, so an empty run marshals as [] rather than null.
func backendNames(rs []compareResult) []string {
	names := make([]string, 0, len(rs))
	for _, r := range rs {
		names = append(names, r.Backend)
	}
	sort.Sort(sort.StringSlice(names))
	return names
}
// writeContactSheet renders the successful results into a single PNG at path.
// The sheet has a header band showing the (truncated) prompt, then a grid of
// cells. Each cell holds one image at its native resolution (no scaling)
// with two label lines underneath: the backend name, then latency and seed.
//
// Layout rules:
//   - Every cell is as large as the largest image, and smaller images are
//     centred in their cell.
//   - Up to 4 images sit in a single row.
//   - With 5 or more, the grid is 2 columns wide and ceil(N/2) rows tall.
//
// This is a contact sheet for side-by-side eyeballing, not a gallery.
//
// It returns an error if results is empty, any image cannot be decoded, or
// the output file cannot be created, encoded, or closed.
func writeContactSheet(path, prompt string, results []compareResult) error {
	if len(results) == 0 {
		return fmt.Errorf("no successful results to lay out")
	}
	cells := make([]contactCell, 0, len(results))
	for _, r := range results {
		img, err := readPNG(r.ImagePath)
		if err != nil {
			return fmt.Errorf("read %s: %w", r.ImagePath, err)
		}
		cells = append(cells, contactCell{
			Image:    img,
			Label:    r.Backend,
			SubLabel: fmt.Sprintf("%dms · seed %d", r.LatencyMs, r.Seed),
		})
	}

	cols := len(cells)
	if cols > 4 {
		cols = 2
	}
	rows := (len(cells) + cols - 1) / cols

	const (
		headerH = 48 // band at the top holding the prompt
		labelH  = 64 // room for the two label lines under each image
		pad     = 16 // gap between cells and around the edges
	)

	// Size every cell to the largest image so the grid stays aligned.
	cellW, cellH := 0, 0
	for _, c := range cells {
		if w := c.Image.Bounds().Dx(); w > cellW {
			cellW = w
		}
		if h := c.Image.Bounds().Dy(); h > cellH {
			cellH = h
		}
	}

	totalW := cols*cellW + (cols+1)*pad
	totalH := headerH + rows*(cellH+labelH) + (rows+1)*pad
	canvas := image.NewRGBA(image.Rect(0, 0, totalW, totalH))
	draw.Draw(canvas, canvas.Bounds(), &image.Uniform{C: color.RGBA{R: 30, G: 30, B: 35, A: 255}}, image.Point{}, draw.Src)

	// basicfont.Face7x13 (used by drawText) has no glyph for U+2014 (em dash)
	// or U+2026 (ellipsis, appended by truncate). Both would render as the
	// replacement box, so the header sticks to ASCII.
	headerText := "imagen compare - " + strings.ReplaceAll(truncate(prompt, 100), "…", "...")
	drawText(canvas, headerText, pad, 30, color.RGBA{R: 240, G: 240, B: 245, A: 255})

	for i, c := range cells {
		col, row := i%cols, i/cols
		x0 := pad + col*(cellW+pad)
		y0 := headerH + pad + row*(cellH+labelH+pad)

		// Centre the image inside the cell when it is smaller than the cell.
		iw, ih := c.Image.Bounds().Dx(), c.Image.Bounds().Dy()
		offX := (cellW - iw) / 2
		offY := (cellH - ih) / 2
		dstRect := image.Rect(x0+offX, y0+offY, x0+offX+iw, y0+offY+ih)
		draw.Draw(canvas, dstRect, c.Image, c.Image.Bounds().Min, draw.Src)

		// Label band underneath: backend name, then the dimmer latency/seed line.
		labelY := y0 + cellH + 20
		drawText(canvas, c.Label, x0+8, labelY, color.RGBA{R: 250, G: 250, B: 250, A: 255})
		drawText(canvas, c.SubLabel, x0+8, labelY+22, color.RGBA{R: 180, G: 180, B: 190, A: 255})
	}

	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %s: %w", path, err)
	}
	if err := png.Encode(f, canvas); err != nil {
		f.Close()
		return fmt.Errorf("encode %s: %w", path, err)
	}
	// Close can report a deferred write failure (e.g. disk full). Ignoring it
	// would let a truncated PNG pass as success.
	if err := f.Close(); err != nil {
		return fmt.Errorf("close %s: %w", path, err)
	}
	return nil
}
// contactCell is one tile of the contact sheet: a decoded image plus the two
// text lines drawn beneath it.
type contactCell struct {
	Image    image.Image // drawn at native size, centred within its cell
	Label    string      // first line under the image (the backend name)
	SubLabel string      // second, dimmer line (latency and seed)
}
// readPNG opens the file at path and decodes it with image.Decode. Despite
// the name, any format whose decoder is registered in the binary is
// accepted; this file imports image/png. The file handle is released before
// returning, whether or not decoding succeeds.
func readPNG(path string) (image.Image, error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		return nil, openErr
	}
	defer file.Close()

	decoded, _, decodeErr := image.Decode(file)
	return decoded, decodeErr
}
// drawText renders s onto dst in colour c using the fixed 7x13 bitmap font
// from basicfont. (x, y) is the pen's starting point in pixels: x is the left
// edge and y is the text baseline.
func drawText(dst *image.RGBA, s string, x, y int, c color.Color) {
	pen := font.Drawer{
		Dst:  dst,
		Src:  image.NewUniform(c),
		Face: basicfont.Face7x13,
		Dot:  fixed.P(x, y),
	}
	pen.DrawString(s)
}
// truncate shortens s to at most max characters (runes). When s does not
// fit, the tail is replaced with a single "…".
//
// Counting and cutting by rune rather than by byte keeps multi-byte UTF-8
// characters (umlauts, emoji) intact; a byte cut could split one and produce
// invalid UTF-8. A non-positive max yields "" instead of panicking on a
// negative slice bound.
func truncate(s string, max int) string {
	if max <= 0 {
		return ""
	}
	runes := []rune(s)
	if len(runes) <= max {
		return s
	}
	return string(runes[:max-1]) + "…"
}
// splitCSV splits a comma-separated list, trims surrounding whitespace from
// each entry, and drops entries that end up empty. An input with no usable
// entries yields an empty, non-nil slice.
func splitCSV(s string) []string {
	fields := strings.FieldsFunc(s, func(r rune) bool { return r == ',' })
	names := make([]string, 0, len(fields))
	for _, field := range fields {
		if trimmed := strings.TrimSpace(field); trimmed != "" {
			names = append(names, trimmed)
		}
	}
	return names
}
// metaInt64 reads m[key] as an int64. It accepts any Go integer or float
// type plus json.Number, because backend metadata may be built natively
// (int, int64, ...) or decoded from JSON (float64, json.Number).
//
// Floats are truncated toward zero. Unsigned values above MaxInt64 saturate
// instead of wrapping negative. A missing key, a nil map, or a non-numeric
// value yields 0.
func metaInt64(m map[string]any, key string) int64 {
	switch n := m[key].(type) {
	case int:
		return int64(n)
	case int8:
		return int64(n)
	case int16:
		return int64(n)
	case int32:
		return int64(n)
	case int64:
		return n
	case uint:
		return saturateUint64(uint64(n))
	case uint8:
		return int64(n)
	case uint16:
		return int64(n)
	case uint32:
		return int64(n)
	case uint64:
		return saturateUint64(n)
	case float32:
		return int64(n)
	case float64:
		return int64(n)
	case json.Number:
		if i, err := n.Int64(); err == nil {
			return i
		}
		if f, err := n.Float64(); err == nil {
			return int64(f)
		}
	}
	return 0
}

// saturateUint64 converts u to int64, clamping values above MaxInt64 instead
// of letting them wrap around to negative numbers.
func saturateUint64(u uint64) int64 {
	const maxInt64 = 1<<63 - 1
	if u > maxInt64 {
		return maxInt64
	}
	return int64(u)
}