Path 1 architecture: one comfyui adapter, workflows as data.
- workflow_template.go: embed.FS + token substitution with type-preserving
whole-value placeholders: ${prompt} → string, ${seed} → int64,
${cfg} → float64, with no JSON round-tripping. Partial matches (a
placeholder embedded inside a longer string) are ignored.
- comfyui.go: refactored to load workflow from embedded FS or filesystem
path. Back-compat preserved: workflow: defaults to flux1-schnell.
- workflows/{flux1-schnell,flux2-klein,sd35-medium}.json — bundled
templates. flux1-schnell was migrated from the previously hardcoded
workflow with identical node IDs.
- compare.go: new `imagen compare` subcommand. Sequential N-backend run
(one GPU on mRock — parallel would OOM), per-backend PNG, sidecar JSON
with per-model metadata + errors, composite contact sheet via Go image
package (no ImageMagick dep).
- Sample config gains flux2-klein-local + sd35-medium-local instances.
- docs/backends.md: architecture rationale + per-model HF download paths
+ how to add a new bundled workflow + compare-harness reference.
Live smoke test verified: running compare with mock + flux-schnell-local
at 768×768 wrote both PNGs, the sidecar JSON has workflow="flux1-schnell"
plus full metadata, and the contact sheet renders. The worker contract
(Request → Generate) is unchanged, so the flexsiebels /imagine UI's API
surface is preserved.
Tests: 11 existing comfyui + 6 new workflow_template + 5 new compare
tests, all green.
Adding a new model is now yaml + JSON, never Go.
495 lines
15 KiB
Go
495 lines
15 KiB
Go
package backend
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"image"
|
|
"image/color"
|
|
"image/png"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// fakeComfy is an in-process stand-in for the ComfyUI HTTP API. A test tunes
// the knobs below before issuing any request to script how each endpoint
// answers; the atomic counters record how often each endpoint was hit.
type fakeComfy struct {
	t *testing.T

	// POST /prompt: the first failPromptUntil calls are answered with
	// promptFailStatus/promptFailBody, every later call with
	// promptStatus/promptBody. promptCalls counts all calls.
	promptStatus     int
	promptBody       []byte
	promptCalls      atomic.Int32
	failPromptUntil  int32
	promptFailStatus int
	promptFailBody   []byte

	// GET /history/{id}: the first historyReadyAfter polls get `{}` (no entry
	// yet); later polls get a completed entry with one output image. When
	// historyError is set, every poll gets an error-status entry instead.
	// historyCalls counts all polls.
	historyReadyAfter int32
	historyCalls      atomic.Int32
	historyError      bool

	// GET /view: answered with viewStatus and viewBody; viewType is the
	// Content-Type and falls back to image/png when empty.
	viewStatus int
	viewBody   []byte
	viewType   string

	// GET /system_stats: reports a single device with these VRAM figures.
	statsTotal int64
	statsFree  int64

	// server is populated by start and closed through t.Cleanup.
	server *httptest.Server
}

// handler dispatches on method + path. A request matching no route is
// recorded as a test error and answered with 404.
func (f *fakeComfy) handler() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		isGet := r.Method == http.MethodGet

		switch {
		case r.Method == http.MethodPost && path == "/prompt":
			// Scripted failures come first; after that every call succeeds.
			status, payload := f.promptStatus, f.promptBody
			if f.promptCalls.Add(1) <= f.failPromptUntil {
				status, payload = f.promptFailStatus, f.promptFailBody
			}
			w.WriteHeader(status)
			_, _ = w.Write(payload)

		case isGet && strings.HasPrefix(path, "/history/"):
			poll := f.historyCalls.Add(1)
			promptID := strings.TrimPrefix(path, "/history/")
			w.WriteHeader(http.StatusOK)
			switch {
			case f.historyError:
				_, _ = fmt.Fprintf(w, `{"%s":{"status":{"completed":false,"status_str":"error"},"outputs":{}}}`, promptID)
			case poll <= f.historyReadyAfter:
				// Not finished yet: ComfyUI returns an empty object.
				_, _ = w.Write([]byte(`{}`))
			default:
				_, _ = fmt.Fprintf(w,
					`{"%s":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`,
					promptID,
				)
			}

		case isGet && path == "/view":
			contentType := "image/png"
			if f.viewType != "" {
				contentType = f.viewType
			}
			w.Header().Set("Content-Type", contentType)
			w.WriteHeader(f.viewStatus)
			_, _ = w.Write(f.viewBody)

		case isGet && path == "/system_stats":
			w.Header().Set("Content-Type", "application/json")
			device := map[string]any{"vram_total": f.statsTotal, "vram_free": f.statsFree}
			_ = json.NewEncoder(w).Encode(map[string]any{
				"system":  map[string]any{},
				"devices": []map[string]any{device},
			})

		default:
			f.t.Errorf("fakeComfy: unexpected request %s %s", r.Method, r.URL.Path)
			http.NotFound(w, r)
		}
	})
}

// start serves the fake on a loopback httptest server and arranges for the
// server to be closed when the owning test finishes.
func (f *fakeComfy) start() {
	srv := httptest.NewServer(f.handler())
	f.t.Cleanup(srv.Close)
	f.server = srv
}
|
|
|
|
// newFakeComfy spins up a fakeComfy with happy-path defaults.
|
|
func newFakeComfy(t *testing.T) *fakeComfy {
|
|
t.Helper()
|
|
f := &fakeComfy{
|
|
t: t,
|
|
promptStatus: http.StatusOK,
|
|
promptBody: []byte(`{"prompt_id":"pid-abc","number":1,"node_errors":{}}`),
|
|
viewStatus: http.StatusOK,
|
|
viewBody: mustPNG(t, 16, 16),
|
|
statsTotal: 16 * 1024 * 1024 * 1024,
|
|
statsFree: 8 * 1024 * 1024 * 1024,
|
|
}
|
|
f.start()
|
|
return f
|
|
}
|
|
|
|
// newComfy returns a Comfy pointed at f, with poll interval squashed for fast
|
|
// tests and deterministic seed/client_id.
|
|
func newComfy(t *testing.T, f *fakeComfy) *Comfy {
|
|
t.Helper()
|
|
be, err := NewComfy("flux-test", map[string]any{
|
|
"base_url": f.server.URL,
|
|
"model": "flux1-schnell.safetensors",
|
|
"default_steps": 4,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewComfy: %v", err)
|
|
}
|
|
c := be.(*Comfy)
|
|
c.pollInterval = time.Millisecond
|
|
c.pollTimeout = 5 * time.Second
|
|
c.randSeed = func() int64 { return 42 }
|
|
c.clientIDFn = func() string { return "imagen-test" }
|
|
return c
|
|
}
|
|
|
|
// mustPNG returns the PNG encoding of a w×h image filled with the opaque
// colour RGBA(200, 100, 50, 255). The test fails if encoding errors.
func mustPNG(t *testing.T, w, h int) []byte {
	t.Helper()

	fill := color.RGBA{R: 200, G: 100, B: 50, A: 255}
	canvas := image.NewRGBA(image.Rect(0, 0, w, h))
	for y := 0; y < h; y++ {
		for x := 0; x < w; x++ {
			canvas.SetRGBA(x, y, fill)
		}
	}

	var encoded bytes.Buffer
	if err := png.Encode(&encoded, canvas); err != nil {
		t.Fatalf("encode png: %v", err)
	}
	return encoded.Bytes()
}
|
|
|
|
func TestComfyConstructorRequiresBaseAndModel(t *testing.T) {
|
|
if _, err := NewComfy("x", map[string]any{}); err == nil {
|
|
t.Errorf("expected error for missing base_url")
|
|
}
|
|
if _, err := NewComfy("x", map[string]any{"base_url": "http://h:1"}); err == nil {
|
|
t.Errorf("expected error for missing model")
|
|
}
|
|
if _, err := NewComfy("", map[string]any{"base_url": "http://h:1", "model": "m"}); err == nil {
|
|
t.Errorf("expected error for empty instance name")
|
|
}
|
|
}
|
|
|
|
func TestComfyHappyPath(t *testing.T) {
|
|
f := newFakeComfy(t)
|
|
f.historyReadyAfter = 2 // exercise the polling loop
|
|
c := newComfy(t, f)
|
|
|
|
res, err := c.Generate(context.Background(), Request{
|
|
Prompt: "a small fishbowl with a cat",
|
|
Width: 512,
|
|
Height: 512,
|
|
Steps: 4,
|
|
Seed: 1234567,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
defer res.ImageReader.Close()
|
|
|
|
if res.MimeType != "image/png" {
|
|
t.Errorf("mime = %q", res.MimeType)
|
|
}
|
|
body, err := io.ReadAll(res.ImageReader)
|
|
if err != nil {
|
|
t.Fatalf("read body: %v", err)
|
|
}
|
|
if !bytes.Equal(body, f.viewBody) {
|
|
t.Errorf("image body did not round-trip")
|
|
}
|
|
|
|
if seed, _ := res.Metadata["seed"].(int64); seed != 1234567 {
|
|
t.Errorf("metadata seed = %v", res.Metadata["seed"])
|
|
}
|
|
if model, _ := res.Metadata["model"].(string); model != "flux1-schnell.safetensors" {
|
|
t.Errorf("metadata model = %v", res.Metadata["model"])
|
|
}
|
|
if steps, _ := res.Metadata["steps"].(int); steps != 4 {
|
|
t.Errorf("metadata steps = %v", res.Metadata["steps"])
|
|
}
|
|
if pid, _ := res.Metadata["prompt_id"].(string); pid != "pid-abc" {
|
|
t.Errorf("metadata prompt_id = %v", res.Metadata["prompt_id"])
|
|
}
|
|
if _, ok := res.Metadata["latency_ms"]; !ok {
|
|
t.Errorf("metadata missing latency_ms")
|
|
}
|
|
// vram_used_mib is best-effort but should be present given our mock stats
|
|
if vram, _ := res.Metadata["vram_used_mib"].(int64); vram != 8192 {
|
|
t.Errorf("metadata vram_used_mib = %v, want 8192", res.Metadata["vram_used_mib"])
|
|
}
|
|
|
|
if got := f.historyCalls.Load(); got < 3 {
|
|
t.Errorf("expected at least 3 /history polls, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestComfyDefaultsAppliedWhenZero(t *testing.T) {
|
|
f := newFakeComfy(t)
|
|
c := newComfy(t, f)
|
|
|
|
res, err := c.Generate(context.Background(), Request{Prompt: "p"}) // all-zero
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
defer res.ImageReader.Close()
|
|
_, _ = io.ReadAll(res.ImageReader)
|
|
|
|
if w, _ := res.Metadata["width"].(int); w != 1024 {
|
|
t.Errorf("width default = %v", res.Metadata["width"])
|
|
}
|
|
if steps, _ := res.Metadata["steps"].(int); steps != 4 {
|
|
t.Errorf("steps default = %v", res.Metadata["steps"])
|
|
}
|
|
if seed, _ := res.Metadata["seed"].(int64); seed != 42 {
|
|
t.Errorf("seed default (test rand hook) = %v", res.Metadata["seed"])
|
|
}
|
|
if s, _ := res.Metadata["sampler"].(string); s != "euler" {
|
|
t.Errorf("sampler default = %q", s)
|
|
}
|
|
}
|
|
|
|
func TestComfyPromptRetriesOnce5xx(t *testing.T) {
|
|
f := newFakeComfy(t)
|
|
f.failPromptUntil = 1
|
|
f.promptFailStatus = http.StatusBadGateway
|
|
f.promptFailBody = []byte("upstream busy")
|
|
c := newComfy(t, f)
|
|
|
|
res, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
|
if err != nil {
|
|
t.Fatalf("Generate (with one 502 then OK): %v", err)
|
|
}
|
|
defer res.ImageReader.Close()
|
|
_, _ = io.ReadAll(res.ImageReader)
|
|
|
|
if got := f.promptCalls.Load(); got != 2 {
|
|
t.Errorf("expected exactly 2 /prompt calls (1 fail + 1 retry), got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestComfyPromptGivesUpAfterTwo5xx(t *testing.T) {
|
|
f := newFakeComfy(t)
|
|
f.failPromptUntil = 99 // every call fails
|
|
f.promptFailStatus = http.StatusServiceUnavailable
|
|
f.promptFailBody = []byte("nope")
|
|
c := newComfy(t, f)
|
|
|
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
|
if err == nil {
|
|
t.Fatal("expected error after sustained 503s")
|
|
}
|
|
if !strings.Contains(err.Error(), "503") {
|
|
t.Errorf("expected error to mention 503, got %v", err)
|
|
}
|
|
if got := f.promptCalls.Load(); got != 2 {
|
|
t.Errorf("expected exactly 2 /prompt calls (no further retries), got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestComfyPromptDoesNotRetryOn4xx(t *testing.T) {
|
|
f := newFakeComfy(t)
|
|
f.failPromptUntil = 99
|
|
f.promptFailStatus = http.StatusBadRequest
|
|
f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation"},"node_errors":{"some":"thing"}}`)
|
|
c := newComfy(t, f)
|
|
|
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
|
if err == nil {
|
|
t.Fatal("expected error for 400")
|
|
}
|
|
if got := f.promptCalls.Load(); got != 1 {
|
|
t.Errorf("expected exactly 1 /prompt call (no retry on 4xx), got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestComfyMissingModelHintsAtSetupDoc(t *testing.T) {
|
|
f := newFakeComfy(t)
|
|
f.failPromptUntil = 99
|
|
f.promptFailStatus = http.StatusBadRequest
|
|
f.promptFailBody = []byte(`{"error":{"type":"prompt_outputs_failed_validation","message":"Prompt outputs failed validation"},"node_errors":{"12":{"errors":[{"type":"value_not_in_list","message":"Value not in list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`)
|
|
c := newComfy(t, f)
|
|
|
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
msg := err.Error()
|
|
if !strings.Contains(msg, "docs/backends.md") {
|
|
t.Errorf("error should point at the setup doc, got %v", err)
|
|
}
|
|
if !strings.Contains(msg, "flux1-schnell.safetensors") {
|
|
t.Errorf("error should name the missing model, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestComfyMissingModelOn200WithNodeErrors(t *testing.T) {
|
|
// Older ComfyUI builds 200 a workflow-validation failure.
|
|
f := newFakeComfy(t)
|
|
f.promptStatus = http.StatusOK
|
|
f.promptBody = []byte(`{"prompt_id":"","node_errors":{"12":{"errors":[{"type":"value_not_in_list","details":"unet_name: 'flux1-schnell.safetensors' not in []"}]}}}`)
|
|
c := newComfy(t, f)
|
|
|
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
|
if err == nil {
|
|
t.Fatal("expected error for node_errors on 200")
|
|
}
|
|
if !strings.Contains(err.Error(), "docs/backends.md") {
|
|
t.Errorf("error should point at the setup doc, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestComfyHistoryErrorSurfaced(t *testing.T) {
|
|
f := newFakeComfy(t)
|
|
f.historyError = true
|
|
c := newComfy(t, f)
|
|
|
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
|
if err == nil {
|
|
t.Fatal("expected error when history reports execution error")
|
|
}
|
|
if !strings.Contains(err.Error(), "errored") {
|
|
t.Errorf("expected 'errored' in message, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestComfyViewFailureSurfaced(t *testing.T) {
|
|
f := newFakeComfy(t)
|
|
f.viewStatus = http.StatusNotFound
|
|
f.viewBody = []byte("nope")
|
|
c := newComfy(t, f)
|
|
|
|
_, err := c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
|
if err == nil {
|
|
t.Fatal("expected error when /view 404s")
|
|
}
|
|
if !strings.Contains(err.Error(), "404") {
|
|
t.Errorf("expected status code in error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestComfyUnreachableHostMentionsBootHelper(t *testing.T) {
|
|
be, err := NewComfy("flux-test", map[string]any{
|
|
"base_url": "http://127.0.0.1:1", // closed port; connection refused
|
|
"model": "flux1-schnell.safetensors",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewComfy: %v", err)
|
|
}
|
|
c := be.(*Comfy)
|
|
c.httpClient = &http.Client{Timeout: 500 * time.Millisecond}
|
|
|
|
_, err = c.Generate(context.Background(), Request{Prompt: "p", Width: 64, Height: 64})
|
|
if err == nil {
|
|
t.Fatal("expected error for unreachable host")
|
|
}
|
|
if !strings.Contains(err.Error(), "boot-whitetower mrock") {
|
|
t.Errorf("expected boot-helper hint, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestComfyContextCancelStopsPolling(t *testing.T) {
|
|
f := newFakeComfy(t)
|
|
f.historyReadyAfter = 1_000_000 // never finishes
|
|
c := newComfy(t, f)
|
|
c.pollInterval = 5 * time.Millisecond
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err := c.Generate(ctx, Request{Prompt: "p", Width: 64, Height: 64})
|
|
if err == nil {
|
|
t.Fatal("expected ctx.Err()")
|
|
}
|
|
if !strings.Contains(err.Error(), "context deadline exceeded") {
|
|
t.Errorf("expected deadline exceeded, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestComfyWorkflowReflectsRequest(t *testing.T) {
|
|
// Capture the workflow body to assert KSampler + EmptyLatentImage values.
|
|
var captured []byte
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/prompt":
|
|
captured, _ = io.ReadAll(r.Body)
|
|
_, _ = w.Write([]byte(`{"prompt_id":"pid","number":1,"node_errors":{}}`))
|
|
case "/history/pid":
|
|
_, _ = w.Write([]byte(`{"pid":{"status":{"completed":true,"status_str":"success"},"outputs":{"9":{"images":[{"filename":"imagen_00001_.png","subfolder":"","type":"output"}]}}}}`))
|
|
case "/view":
|
|
_, _ = w.Write(mustPNG(t, 8, 8))
|
|
case "/system_stats":
|
|
_, _ = w.Write([]byte(`{"devices":[{"vram_total":1,"vram_free":1}]}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
t.Cleanup(srv.Close)
|
|
|
|
be, err := NewComfy("flux-test", map[string]any{
|
|
"base_url": srv.URL,
|
|
"model": "custom.safetensors",
|
|
"default_steps": 7,
|
|
"default_sampler": "dpmpp_2m",
|
|
"default_scheduler": "karras",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewComfy: %v", err)
|
|
}
|
|
c := be.(*Comfy)
|
|
c.pollInterval = time.Millisecond
|
|
c.randSeed = func() int64 { return 9999 }
|
|
|
|
res, err := c.Generate(context.Background(), Request{
|
|
Prompt: "a cat",
|
|
NegativePrompt: "blurry",
|
|
Width: 768,
|
|
Height: 512,
|
|
Steps: 11,
|
|
Seed: 555,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
res.ImageReader.Close()
|
|
|
|
var sent struct {
|
|
Prompt map[string]map[string]any `json:"prompt"`
|
|
ClientID string `json:"client_id"`
|
|
}
|
|
if err := json.Unmarshal(captured, &sent); err != nil {
|
|
t.Fatalf("unmarshal captured: %v", err)
|
|
}
|
|
ks := sent.Prompt["31"]["inputs"].(map[string]any)
|
|
if ks["seed"].(float64) != 555 {
|
|
t.Errorf("KSampler seed = %v, want 555", ks["seed"])
|
|
}
|
|
if ks["steps"].(float64) != 11 {
|
|
t.Errorf("KSampler steps = %v, want 11", ks["steps"])
|
|
}
|
|
if ks["sampler_name"].(string) != "dpmpp_2m" {
|
|
t.Errorf("sampler_name = %v", ks["sampler_name"])
|
|
}
|
|
if ks["scheduler"].(string) != "karras" {
|
|
t.Errorf("scheduler = %v", ks["scheduler"])
|
|
}
|
|
latent := sent.Prompt["27"]["inputs"].(map[string]any)
|
|
if latent["width"].(float64) != 768 || latent["height"].(float64) != 512 {
|
|
t.Errorf("EmptySD3LatentImage size = %vx%v", latent["width"], latent["height"])
|
|
}
|
|
unet := sent.Prompt["12"]["inputs"].(map[string]any)
|
|
if unet["unet_name"].(string) != "custom.safetensors" {
|
|
t.Errorf("unet_name = %v", unet["unet_name"])
|
|
}
|
|
neg := sent.Prompt["13"]["inputs"].(map[string]any)
|
|
if neg["text"].(string) != "blurry" {
|
|
t.Errorf("negative prompt not threaded: %v", neg["text"])
|
|
}
|
|
if !strings.HasPrefix(sent.ClientID, "imagen-") && sent.ClientID == "" {
|
|
t.Errorf("client_id should be set: %q", sent.ClientID)
|
|
}
|
|
}
|
|
|
|
func TestComfyTypeIsRegistered(t *testing.T) {
|
|
if !Default.Has(ComfyType) {
|
|
t.Errorf("comfyui type not registered in Default")
|
|
}
|
|
}
|