czq0719 committed
Commit 1eab9d3 · 1 Parent(s): 8968c91

Add Gradio app for Spaces

Files changed (2)
  1. demo.py +916 -0
  2. requirements.txt +11 -0
demo.py ADDED
@@ -0,0 +1,916 @@
# =========================
# ONE-CELL: SDXL + CritiCore + SpecFusion + Gradio UI
# - Fixes: api_name=False + show_api=False, share=True, VARIANT_* before Blocks
# - UI style follows your reference (ui_run_once + _run_async + gallery/meta)
# =========================

import os, re, io, json, time, base64, asyncio, inspect, traceback
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Iterable, Set

import torch
from PIL import Image
import nest_asyncio
nest_asyncio.apply()

import gradio as gr
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    DPMSolverMultistepScheduler,
)

# =========================
# 0) Variant names / display order (MUST be BEFORE Blocks)
# =========================
VARIANT_LABELS = {
    "base_multi_llm": "1_base_multi_llm",
    "criticore_on_multi_llm__specfusion": "2_criticore_on_multi_llm__specfusion",
    "base_original": "3_base_original",
    "criticore_on_original__specfusion": "4_criticore_on_original__specfusion",
}
VARIANT_ORDER = [
    "1_base_multi_llm",
    "2_criticore_on_multi_llm__specfusion",
    "3_base_original",
    "4_criticore_on_original__specfusion",
]
RHO_T_DEFAULT = 0.85  # fixed as requested

# ---- SAFETY: do NOT hardcode API keys ----
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY", "").strip()
if not TOGETHER_API_KEY:
    print("[Warn] TOGETHER_API_KEY is not set. CritiCore will fail if you use Together models.")

# =========================
# 1) SDXL init
# =========================
DEVICE_STR = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_STR)
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
SDXL_ID = os.environ.get("SDXL_ID", "stabilityai/stable-diffusion-xl-base-1.0")

print(f"[Init] DEVICE={DEVICE_STR} DTYPE={DTYPE} SDXL_ID={SDXL_ID}")

SDXL_base = StableDiffusionXLPipeline.from_pretrained(SDXL_ID, torch_dtype=DTYPE).to(DEVICE)
SDXL_i2i = StableDiffusionXLImg2ImgPipeline.from_pretrained(SDXL_ID, torch_dtype=DTYPE).to(DEVICE)

for p in (SDXL_base, SDXL_i2i):
    try:
        p.enable_vae_slicing()
        p.enable_attention_slicing()
    except Exception:
        pass
    p.scheduler = DPMSolverMultistepScheduler.from_config(p.scheduler.config, use_karras_sigmas=True)

DEFAULT_NEG = (
    "blurry, low quality, artifacts, watermark, extra fingers, missing limbs, "
    "over-sharpened, harsh lighting, oversaturated"
)

@torch.no_grad()
def decode_image_sdxl(latents: torch.Tensor, pipe: StableDiffusionXLImg2ImgPipeline, output_type="pil"):
    vae = pipe.vae
    needs_upcast = (vae.dtype in (torch.float16, torch.bfloat16)) and bool(getattr(vae.config, "force_upcast", False))
    if needs_upcast:
        try:
            pipe.upcast_vae()
        except Exception:
            pipe.vae = pipe.vae.to(torch.float32)
        vae = pipe.vae

    lat = latents.to(device=vae.device, dtype=(next(vae.post_quant_conv.parameters()).dtype))
    lat = lat / vae.config.scaling_factor
    out = vae.decode(lat)
    x = out[0] if isinstance(out, (list, tuple)) else (out.sample if hasattr(out, "sample") else out)
    if getattr(pipe, "watermark", None) is not None:
        x = pipe.watermark.apply_watermark(x)
    img = pipe.image_processor.postprocess(x.detach(), output_type=output_type)[0]
    return img

@torch.no_grad()
def base_sample_latent(prompt: str, seed: int = 2025, H: int = 1024, W: int = 1024, neg: str = ""):
    g = torch.Generator(device=DEVICE).manual_seed(int(seed))
    out = SDXL_base(
        prompt=prompt,
        negative_prompt=neg,
        height=int(H), width=int(W),
        guidance_scale=4.5,
        num_inference_steps=50,
        generator=g,
        output_type="latent",
    )
    z0 = out.images
    x0 = decode_image_sdxl(z0, SDXL_i2i)
    return z0, x0

@torch.no_grad()
def img2img_latent(prompt: str, image_or_latent, strength: float, guidance: float, steps: int, seed: int):
    g = torch.Generator(device=DEVICE).manual_seed(int(seed))
    out = SDXL_i2i(
        prompt=prompt,
        image=image_or_latent,
        strength=float(strength),
        guidance_scale=float(guidance),
        num_inference_steps=int(steps),
        generator=g,
        output_type="latent",
        negative_prompt=DEFAULT_NEG,
    )
    return out.images

def strength_for_last_k(k: int, total_steps: int) -> float:
    k = max(1, int(k))
    return min(0.95, max(0.01, float(k) / float(max(1, total_steps))))

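# Illustrative values for the mapping above (editor's note): refining the last
# k of N steps uses img2img strength ≈ k / N, clamped to [0.01, 0.95], so
# strength_for_last_k(36, 50) == 0.72 and strength_for_last_k(60, 50) == 0.95.
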
# =========================
# 2) CLIP-77 + text utils
# =========================
try:
    from transformers import CLIPTokenizerFast
    _clip_tok = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
    def _count_tokens(txt: str) -> int:
        return len(_clip_tok(txt, add_special_tokens=True, truncation=False)["input_ids"])
except Exception:
    _clip_tok = None
    def _count_tokens(txt: str) -> int:
        # rough word-based estimate when the CLIP tokenizer is unavailable
        return int(len(re.findall(r"\w+", txt)) * 1.3)

def _cleanup_commas(s: str) -> str:
    s = re.sub(r"\s*,\s*", ", ", (s or "").strip())
    s = re.sub(r"(,\s*){2,}", ", ", s)
    return s.strip(" ,")

def clip77_strict(text: str, max_tok: int = 77) -> str:
    text = (text or "").strip()
    if _count_tokens(text) <= max_tok:
        return text
    # binary search for the longest word prefix that fits the token budget
    words = text.split()
    lo, hi, best = 0, len(words), ""
    while lo <= hi:
        mid = (lo + hi) // 2
        cand = " ".join(words[:mid]) if mid > 0 else ""
        if _count_tokens(cand) <= max_tok:
            best = cand; lo = mid + 1
        else:
            hi = mid - 1
    return best.strip()

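# Hypothetical usage (editor's sketch): keep an SDXL prompt within CLIP's
# 77-token window before handing it to the pipelines above. Since the cut is
# on word boundaries, the result can land under the budget but never over it.
#
#   safe_prompt = clip77_strict(very_long_prompt, max_tok=77)
#   assert _count_tokens(safe_prompt) <= 77
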
def _split_tags(s: str) -> List[str]:
    return [p.strip() for p in re.split(r",|\n", (s or "").strip()) if p.strip()]

def _dedup_keep_order(items: List[str]) -> List[str]:
    seen, out = set(), []
    for t in items:
        key = re.sub(r"\s+", " ", t.lower()).strip()
        if key and key not in seen:
            seen.add(key); out.append(t.strip())
    return out

def _order_tags(subject_first: List[str], rest: List[str]) -> List[str]:
    buckets = {"subject": [], "style": [], "composition": [], "lighting": [], "color": [], "detail": [], "other": []}
    style_kw = ("style","painterly","illustration","photorealistic","neon","poster","matte painting","watercolor","cyberpunk")
    comp_kw = ("composition","rule of thirds","centered","symmetry","balanced composition")
    light_kw = ("lighting","light","glow","glowing","rim","sunset","sunrise","golden hour","global illumination","cinematic")
    color_kw = ("color","palette","vibrant","muted","monochrome","pastel","warm","cool","balanced contrast")
    detail_kw = ("detailed","hyperdetailed","texture","intricate","high detail","highly detailed","sharp focus","uhd","8k")

    for t in subject_first:
        if t: buckets["subject"].append(t)
    for t in rest:
        lt = t.lower()
        if any(k in lt for k in style_kw): buckets["style"].append(t)
        elif any(k in lt for k in comp_kw): buckets["composition"].append(t)
        elif any(k in lt for k in light_kw): buckets["lighting"].append(t)
        elif any(k in lt for k in color_kw): buckets["color"].append(t)
        elif any(k in lt for k in detail_kw): buckets["detail"].append(t)
        else: buckets["other"].append(t)

    return buckets["subject"] + buckets["style"] + buckets["composition"] + buckets["lighting"] + buckets["color"] + buckets["detail"] + buckets["other"]

def pil_to_base64(img: Image.Image, fmt: str = "PNG") -> str:
    buf = io.BytesIO()
    img.save(buf, format=fmt)
    return base64.b64encode(buf.getvalue()).decode("ascii")

async def _maybe_close_async_together(client) -> None:
    try:
        if hasattr(client, "aclose") and inspect.iscoroutinefunction(client.aclose):
            await client.aclose()
        elif hasattr(client, "close"):
            fn = client.close
            if inspect.iscoroutinefunction(fn):
                await fn()
            else:
                try: fn()
                except Exception: pass
    except Exception:
        pass

# =========================
# 3) Async runner (reference style, but made safer for notebook/gradio)
# =========================
def _run_async(coro):
    """
    Robust sync->async bridge:
    - notebook (loop already running, same thread): nest_asyncio + run_until_complete
    - normal python (no running loop): asyncio.run
    """
    try:
        loop = asyncio.get_event_loop()
        # nest_asyncio makes run_until_complete workable even when the
        # notebook's main loop is already running
        return loop.run_until_complete(coro)
    except RuntimeError:
        return asyncio.run(coro)

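# Editor's note: the synchronous Gradio callbacks below (ui_run_once, ui_run_full)
# call _run_async(...) to drive the async CritiCore coroutines; nest_asyncio.apply()
# at the top of the file is what makes run_until_complete usable when an event
# loop is already running (e.g. in a notebook).
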
# =========================
# 4) CritiCore (Together)
# =========================
from together import AsyncTogether

AGGREGATOR_MODEL = os.environ.get("AGGREGATOR_MODEL", "Qwen/Qwen2.5-72B-Instruct-Turbo")
LLM_MULTI_CANDIDATES = [
    "meta-llama/Llama-3.3-70B-Instruct-Turbo",
    "Qwen/Qwen2.5-72B-Instruct-Turbo",
    "Qwen/Qwen2.5-Coder-32B-Instruct",
    "deepseek-ai/DeepSeek-V3",
    "nvidia/NVIDIA-Nemotron-Nano-9B-v2",
]
_env_list = [s.strip() for s in os.environ.get("VLM_MOA_CANDIDATES", "").split(",") if s.strip()]
VLM_CANDIDATES = _env_list or ["meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"]

TAG_PRESETS = {
    "hq_preference": {
        "seed_pos": [
            "balanced composition",
            "natural color palette", "vibrant colors", "balanced contrast",
            "high detail", "highly detailed", "hyperdetailed", "sharp focus",
            "UHD", "8k",
        ],
        "seed_neg": [
            "low quality", "blurry", "watermark", "jpeg artifacts", "overexposed", "underexposed",
            "color banding", "extra fingers", "missing limbs", "disfigured", "mutated hands",
        ],
    }
}

_DECOMP_SYS = (
    "Decompose the user's visual instruction into 3-6 concrete, checkable visual components "
    "(entities + interactions + spatial relations). Return ONLY JSON: "
    '{"components":["..."]}'
)

_TXT_SYS = (
    "Expand a VERY SHORT visual idea into a COMMA-SEPARATED TAG LIST for SDXL.\n"
    "Constraints:\n"
    "- Start with the subject phrase first.\n"
    "- Prioritize composition, lighting, color, and detail over style.\n"
    "- Use at most TWO style tags if any.\n"
    "- 16–26 concise tags total. Commas only, no sentences, no 'and'. No trailing period.\n"
    "- Prefer human-preference aesthetics; keep 'high detailed', 'sharp focus', '8k', 'UHD'."
)

def _TAG_RE(tag: str):
    return re.compile(rf"<\s*{tag}\s*>(.*?)</\s*{tag}\s*>", re.S | re.I)

def _extract_tag(text: str, tag: str, fallback: str = "") -> str:
    s = (text or "").strip()
    r = _TAG_RE(tag); m = r.search(s)
    if m: return m.group(1).strip()
    s2 = s.replace("&lt;", "<").replace("&gt;", ">")
    m2 = r.search(s2)
    return m2.group(1).strip() if m2 else fallback.strip()

def _summarize_issues_lines(text: str, max_lines: int = 5) -> str:
    if not text: return ""
    parts = [p.strip(" -•\t") for p in re.split(r"[\n;]+", text) if p.strip()]
    parts = parts[:max_lines]
    return "\n".join(f"- {p}" for p in parts)

class CritiCore:
    def __init__(self, preset: str = "hq_preference", aggregator_model: str = AGGREGATOR_MODEL):
        if not os.environ.get("TOGETHER_API_KEY"):
            raise RuntimeError("Missing TOGETHER_API_KEY in environment.")
        self.preset = preset
        self.aggregator = aggregator_model

    async def decompose_components(self, user_prompt: str) -> List[str]:
        client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        try:
            tasks = [client.chat.completions.create(
                model=m,
                messages=[{"role": "system", "content": _DECOMP_SYS},
                          {"role": "user", "content": user_prompt}],
                temperature=0.4, max_tokens=256,
            ) for m in LLM_MULTI_CANDIDATES]
            rs = await asyncio.gather(*tasks, return_exceptions=True)
            texts = []
            for r in rs:
                try: texts.append(r.choices[0].message.content)
                except Exception: pass
            if not texts: return []
            joined = "\n\n---\n\n".join(texts)
            merged = await client.chat.completions.create(
                model=self.aggregator,
                messages=[{"role": "system", "content": "Merge JSON candidates and return ONLY {'components':[...]}."},
                          {"role": "user", "content": joined}],
                temperature=0.2, max_tokens=256,
            )
            txt = merged.choices[0].message.content
            try:
                obj = json.loads(txt)
            except Exception:
                s, e = txt.find("{"), txt.rfind("}")
                obj = json.loads(txt[s:e+1]) if (s != -1 and e != -1) else {"components": []}
            comps = [c.strip() for c in obj.get("components", []) if isinstance(c, str) and c.strip()]
            return comps[:6]
        finally:
            await _maybe_close_async_together(client)

    async def make_tags(self, user_prompt: str, clip77: bool = True) -> Tuple[str, str]:
        client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        seed = TAG_PRESETS.get(self.preset, TAG_PRESETS["hq_preference"])
        seed_pos = _dedup_keep_order(seed["seed_pos"])
        seed_neg = seed["seed_neg"]
        try:
            tasks = [client.chat.completions.create(
                model=m,
                messages=[{"role": "system", "content": _TXT_SYS},
                          {"role": "user", "content":
                           f"Short idea: {user_prompt}\nSeed: {', '.join(seed_pos)}\nOutput: a single comma-separated tag list."}],
                temperature=0.7, max_tokens=220,
            ) for m in LLM_MULTI_CANDIDATES]
            rs = await asyncio.gather(*tasks, return_exceptions=True)
            props = []
            for r in rs:
                try: props.append(r.choices[0].message.content)
                except Exception: pass

            if not props:
                pos = ", ".join([user_prompt.strip()] + seed_pos)
            else:
                joined = "\n---\n".join(props)
                merged = await client.chat.completions.create(
                    model=self.aggregator,
                    messages=[{"role": "system", "content":
                               "Merge candidate tag lists into ONE comma list (16–26 tags). Subject first; at most TWO style tags; keep high detailed/sharp focus/8k/UHD."},
                              {"role": "user", "content": joined}],
                    temperature=0.2, max_tokens=240,
                )
                raw = merged.choices[0].message.content
                tags = _dedup_keep_order(_split_tags(raw))
                subject = user_prompt.strip().rstrip(",.")
                if subject and not any(subject.lower() == t.lower() for t in tags):
                    tags = [subject] + tags
                ordered = _order_tags([tags[0]], tags[1:])
                pos = ", ".join(_dedup_keep_order(ordered))

            # quality floor
            for q in ["high detailed", "sharp focus", "8k", "UHD"]:
                if q.lower() not in {t.lower() for t in _split_tags(pos)}:
                    pos += ", " + q

            pos = _cleanup_commas(pos)
            if clip77 and _count_tokens(pos) > 77:
                pos = clip77_strict(pos, 77)

            neg = ", ".join(seed_neg)
            return pos, neg
        finally:
            await _maybe_close_async_together(client)

    async def vlm_refine(self, image: Image.Image, original_prompt: str, components: List[str]) -> Dict[str, object]:
        client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        b64 = pil_to_base64(image, "PNG")

        def _user_prompt_text() -> str:
            return (
                "You are a precise image-grounded critic.\n"
                "1) List concrete visual problems and brief corrections.\n"
                "2) Provide a refined prompt that keeps the original intent.\n\n"
                f'Original prompt: "{original_prompt}"\n'
                f"Key components to check: {components}\n"
                "Output EXACTLY two tags:\n"
                "<issues>...</issues>\n<refined>...</refined>"
            )

        try:
            tasks = []
            for m in VLM_CANDIDATES:
                msgs = [
                    {"role": "system", "content": "Return ONLY <issues> and <refined>. No extra text."},
                    {"role": "user", "content": [
                        {"type": "text", "text": _user_prompt_text()},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
                    ]},
                ]
                tasks.append(client.chat.completions.create(model=m, messages=msgs, temperature=0.2, max_tokens=420))

            rs = await asyncio.gather(*tasks, return_exceptions=True)
            ok = []
            for m, r in zip(VLM_CANDIDATES, rs):
                try: ok.append((m, r.choices[0].message.content))
                except Exception: pass

            if not ok:
                return {"refined": original_prompt, "issues_merged": ""}

            refined_items, per_vlm_issues = [], {}
            for m, raw in ok:
                issues = _extract_tag(raw, "issues", "")
                refined = _extract_tag(raw, "refined", original_prompt)
                if refined.strip(): refined_items.append((m, refined.strip()))
                if issues.strip(): per_vlm_issues[m] = _summarize_issues_lines(issues, 5)

            joined_issues = "\n".join(f"[{m}] {t}" for m, t in per_vlm_issues.items())
            joined_refined = "\n".join(f"[{m}] {t}" for m, t in refined_items) if refined_items else original_prompt

            merged = await client.chat.completions.create(
                model=self.aggregator,
                messages=[{"role": "system", "content":
                           "Merge multiple critics. Output ONLY <issues> (≤5 bullets) and <refined> (≤70 words)."},
                          {"role": "user", "content": f"{joined_issues}\n\n----\n\n{joined_refined}"}],
                temperature=0.2, max_tokens=420,
            )
            final_raw = merged.choices[0].message.content
            final_refined = clip77_strict(_extract_tag(final_raw, "refined", original_prompt), 77)
            issues_merged = _summarize_issues_lines(_extract_tag(final_raw, "issues", ""), 5)

            return {"refined": final_refined, "issues_merged": issues_merged}
        finally:
            await _maybe_close_async_together(client)

    @staticmethod
    def merge_vlm_multi_text(vlm_refined_77: str, tags_77: str) -> str:
        vlm_tags = _split_tags(vlm_refined_77)
        moa_tags = _split_tags(tags_77)
        merged = _dedup_keep_order(_order_tags([vlm_tags[0] if vlm_tags else ""], (vlm_tags[1:] + moa_tags)))
        merged = [t for t in merged if t]
        text = _cleanup_commas(", ".join(merged))
        if _count_tokens(text) > 77:
            text = clip77_strict(text, 77)
        return text

# =========================
# 5) Standalone frequency_fusion (your reference)
# =========================
@torch.no_grad()
def frequency_fusion(
    x_hi_latent: torch.Tensor,
    x_lo_latent: torch.Tensor,
    base_c: float = 0.5,
    rho_t: float = 0.85,
    device=None,
) -> torch.Tensor:
    if device is None:
        device = x_hi_latent.device
    B, C, H, W = x_hi_latent.shape

    x_h = x_hi_latent.to(torch.float32).to(device)
    x_l = x_lo_latent.to(torch.float32).to(device)

    Xh = torch.fft.fftshift(torch.fft.fftn(x_h, dim=(-2, -1)), dim=(-2, -1))
    Xl = torch.fft.fftshift(torch.fft.fftn(x_l, dim=(-2, -1)), dim=(-2, -1))

    tau_h = int(H * base_c * (1 - rho_t))
    tau_w = int(W * base_c * (1 - rho_t))

    mask = torch.ones((B, C, H, W), device=device, dtype=torch.float32)
    cy, cx = H // 2, W // 2
    if tau_h > 0 and tau_w > 0:
        mask[..., cy - tau_h : cy + tau_h, cx - tau_w : cx + tau_w] = rho_t

    Xf = Xh * mask + Xl * (1 - mask)
    x = torch.fft.ifftn(torch.fft.ifftshift(Xf, dim=(-2, -1)), dim=(-2, -1)).real
    x = x + torch.randn_like(x) * 0.001
    return x.to(dtype=x_hi_latent.dtype)

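# Editor's note (a hedged reading of the code above): after fftshift, the centered
# box of half-widths H*base_c*(1-rho_t) and W*base_c*(1-rho_t) holds the low
# frequencies. Inside that box the refined spectrum Xh is weighted by rho_t and
# the base spectrum Xl by (1 - rho_t); outside it, Xh passes through unchanged.
# With rho_t = 0.85 the low band is an 85/15 refined/base blend while all
# high-frequency detail comes from the refinement; a small Gaussian perturbation
# (std 0.001) is added before casting back to the input dtype.
# Minimal usage sketch with the definitions above:
#
#   z_ref = img2img_latent(prompt, z0, strength=0.7, guidance=7.5, steps=50, seed=0)
#   fused = frequency_fusion(z_ref, z0, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
#   img = decode_image_sdxl(fused, SDXL_i2i)
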
def _decode_to_pil(latents, pipe):
    out = decode_image_sdxl(latents, pipe)
    if isinstance(out, Image.Image):
        return out
    if hasattr(out, "images"):
        return out.images[0]
    return out

# =========================
# 6) Helpers for variants
# =========================
def _normalize_enabled(enabled_variants: Optional[Iterable[str]]) -> Set[str]:
    default = set(VARIANT_LABELS.keys())
    if enabled_variants is None:
        return default
    return set(enabled_variants)

def _guidance_for_k(k: int) -> float:
    if k >= 20: return 12.0
    if k >= 10: return 7.5
    return 5.2

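# Illustrative values for the fallback rule above (editor's note):
# _guidance_for_k(37) -> 12.0, _guidance_for_k(15) -> 7.5, _guidance_for_k(5) -> 5.2.
# It is only consulted when the UI's Guidance slider is left at 0 (see ui_run_once).
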
async def _shared_materials(user_prompt: str, seed: int, H: int, W: int, preset: str):
    critic = CritiCore(preset=preset)

    pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
    comps = await critic.decompose_components(user_prompt)

    z0_og, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)
    z0_enh, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)

    vlm_out = await critic.vlm_refine(base_enh, pos_tags_77, comps or [])
    vlm_agg_77 = vlm_out.get("refined") or pos_tags_77

    return dict(
        pos_tags_77=pos_tags_77, neg_tags=neg_tags, comps=comps,
        z0_og=z0_og, base_og=base_og,
        z0_enh=z0_enh, base_enh=base_enh,
        vlm_agg_77=vlm_agg_77,
        critic=critic,
    )

async def _collect_meta(user_prompt: str, seed: int, H: int, W: int, preset: str):
    shared = await _shared_materials(user_prompt, seed, H, W, preset)
    return {
        "user_prompt": user_prompt,
        "pos_tags_77": shared["pos_tags_77"],
        "neg_tags": shared["neg_tags"],
        "components": shared["comps"],
        "vlm_agg_77_on_multi_llm": shared["vlm_agg_77"],
    }

# =========================
# 7) Variants generator (signature matches your reference: last_k_list + guidance_list)
# =========================
async def generate_variants(
    user_prompt: str,
    seed: int,
    H: int, W: int,
    total_steps_refine: int,
    last_k_list: Iterable[int],
    guidance_list: Optional[List[float]] = None,
    preset: str = "hq_preference",
    out_dir: Optional[Path] = None,
    enabled_variants: Optional[Iterable[str]] = None,
) -> Dict[str, Dict[int, Image.Image]]:
    enabled = _normalize_enabled(enabled_variants)

    lk = int(last_k_list) if isinstance(last_k_list, int) else (int(list(last_k_list)[-1]) if last_k_list else 36)

    shared = await _shared_materials(user_prompt, seed, H, W, preset)
    pos_tags_77 = shared["pos_tags_77"]; comps = shared["comps"]
    z0_og, base_og = shared["z0_og"], shared["base_og"]
    z0_enh, base_enh = shared["z0_enh"], shared["base_enh"]
    vlm_agg_77 = shared["vlm_agg_77"]
    critic: CritiCore = shared["critic"]

    out: Dict[str, Dict[int, Image.Image]] = {}

    def _save(im: Image.Image, vname: str, k: int = 0):
        if out_dir is None: return
        sub = out_dir / f"var_{vname}"
        sub.mkdir(parents=True, exist_ok=True)
        im.save(sub / f"{vname}_k{k}.png")

    # 1) base_multi_llm
    if "base_multi_llm" in enabled:
        v = VARIANT_LABELS["base_multi_llm"]
        out[v] = {0: base_enh}
        _save(base_enh, v, 0)

    # 2) criticore_on_multi_llm__specfusion
    if "criticore_on_multi_llm__specfusion" in enabled:
        v = VARIANT_LABELS["criticore_on_multi_llm__specfusion"]; out[v] = {}
        refined_on_enh = CritiCore.merge_vlm_multi_text(vlm_agg_77, pos_tags_77)

        strength = float(strength_for_last_k(lk, total_steps_refine))
        guidance = float(guidance_list[-1]) if guidance_list else float(_guidance_for_k(lk))
        steps = int(total_steps_refine)

        z_ref = img2img_latent(
            refined_on_enh, z0_enh,
            strength=strength, guidance=guidance, steps=steps,
            seed=seed + 2100 + lk,
        )
        fused_lat = frequency_fusion(z_ref, z0_enh, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
        img_sf = _decode_to_pil(fused_lat, SDXL_i2i)

        out[v][0] = img_sf
        _save(img_sf, v, 0)

    # 3) base_original
    if "base_original" in enabled:
        v = VARIANT_LABELS["base_original"]
        out[v] = {0: base_og}
        _save(base_og, v, 0)

    # 4) criticore_on_original__specfusion
    if "criticore_on_original__specfusion" in enabled:
        v = VARIANT_LABELS["criticore_on_original__specfusion"]; out[v] = {}

        vlm_on_og = await critic.vlm_refine(base_og, user_prompt, comps or [])
        refined_og_77 = clip77_strict(vlm_on_og.get("refined") or user_prompt, 77)
        refined_merge = CritiCore.merge_vlm_multi_text(refined_og_77, pos_tags_77)

        strength = float(strength_for_last_k(lk, total_steps_refine))
        guidance = float(guidance_list[-1]) if guidance_list else float(_guidance_for_k(lk))
        steps = int(total_steps_refine)

        z_ref = img2img_latent(
            refined_merge, z0_og,
            strength=strength, guidance=guidance, steps=steps,
            seed=seed + 2400 + lk,
        )
        fused_lat = frequency_fusion(z_ref, z0_og, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
        img_sf = _decode_to_pil(fused_lat, SDXL_i2i)

        out[v][0] = img_sf
        _save(img_sf, v, 0)

    return out

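# Minimal usage sketch (editor's example; prompt and values are placeholders):
#
#   results = _run_async(generate_variants(
#       user_prompt="a red fox in snow", seed=2026, H=1024, W=1024,
#       total_steps_refine=50, last_k_list=(37,), guidance_list=None,
#       preset="hq_preference", out_dir=None, enabled_variants=None,
#   ))
#   # results maps display label -> {0: PIL.Image}; the UI walks VARIANT_ORDER
#   # to build the gallery.
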
# =========================
# 8) Full CADR pipeline (kept for the "Full CADR Pipeline" tab)
# =========================
def _try_pref_score(enhanced_prompt: str, base_img: Image.Image) -> Optional[float]:
    fn = globals().get("pref_score", None)  # optional external scorer
    if fn is None:
        return None
    try:
        s01 = float(fn(enhanced_prompt, base_img))
        return max(0.0, min(100.0, s01 * 100.0))
    except Exception:
        return None

def _clamp01(x: float) -> float: return max(0.0, min(1.0, float(x)))
def _lerp(a: float, b: float, t: float) -> float: return a + (b - a) * t

class SpecFusionCADR:
    def __init__(self, device: torch.device):
        self.device = device

    @staticmethod
    def cadr_from_alignment(align_score: float) -> Tuple[float, float, int, float]:
        s = _clamp01(align_score / 100.0); mis = 1.0 - s
        strength = _lerp(0.12, 0.30, mis)
        guidance = _lerp(3.6, 5.0, mis)
        steps = int(round(_lerp(16, 30, mis)))
        rho_t = _lerp(0.60, 0.85, mis)
        return strength, guidance, steps, rho_t

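    # Editor's sketch of the CADR schedule: a lower alignment score means a
    # stronger correction. With align_score = 60 (the UI default), mis = 0.4
    # and the interpolations above give strength ≈ 0.192, guidance ≈ 4.16,
    # steps = 22, rho_t ≈ 0.70.
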
    @torch.no_grad()
    def final_touch(self, enhanced_prompt: str, base_latent: torch.Tensor, align_score: float, seed: int):
        strength, guidance, steps, rho_t = self.cadr_from_alignment(float(align_score))
        z_ref = img2img_latent(enhanced_prompt, base_latent, strength=strength, guidance=guidance, steps=steps, seed=seed)
        fused = frequency_fusion(z_ref, base_latent, base_c=0.5, rho_t=rho_t, device=DEVICE)
        img = decode_image_sdxl(fused, SDXL_i2i)
        return img, dict(strength=strength, guidance=guidance, steps=steps, rho_t=rho_t)

async def pipeline_full_cadr(user_prompt: str, seed: int, H: int, W: int, preset: str, align_score: Optional[float], save_dir: Optional[Path]):
    critic = CritiCore(preset=preset)
    spec = SpecFusionCADR(device=DEVICE)

    z0_base, base_img = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)

    pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
    comps = await critic.decompose_components(user_prompt)
    vlm_out = await critic.vlm_refine(base_img, user_prompt, comps or [])
    vlm_ref_77 = vlm_out.get("refined") or user_prompt
    enhanced_77 = CritiCore.merge_vlm_multi_text(vlm_ref_77, pos_tags_77)

    if align_score is None:
        auto = _try_pref_score(enhanced_77, base_img)
        align_score = auto if auto is not None else 60.0

    final_img, cadr_params = spec.final_touch(enhanced_77, z0_base, float(align_score), seed=seed + 999)

    meta = {
        "user_prompt": user_prompt,
        "pos_tags_77": pos_tags_77,
        "vlm_refined_77": vlm_ref_77,
        "enhanced_prompt_77": enhanced_77,
        "align_score": float(align_score),
        "cadr_params": cadr_params,
        "components": comps,
        "vlm_issues": vlm_out.get("issues_merged", ""),
    }
    meta_json = json.dumps(meta, ensure_ascii=False, indent=2)

    if save_dir is not None:
        save_dir.mkdir(parents=True, exist_ok=True)
        base_img.save(save_dir / "base.png")
        final_img.save(save_dir / "cadr_final.png")
        (save_dir / "record.json").write_text(meta_json, encoding="utf-8")

    return base_img, final_img, enhanced_77, meta_json

# =========================
# 9) UI callbacks (reference style)
# =========================
def ui_run_once(
    user_prompt: str,
    seed: int,
    H: int,
    W: int,
    preset: str,
    total_steps_refine: int,
    last_k: int,
    guidance: float,
    enabled_variants: List[str],
    save_outputs: bool,
    out_dir: str,
):
    t0 = time.time()
    try:
        if not user_prompt or not user_prompt.strip():
            return [], "Empty prompt."

        if not TOGETHER_API_KEY:
            return [], "ERROR: TOGETHER_API_KEY not set."

        # display -> internal
        display_to_internal = {v: k for k, v in VARIANT_LABELS.items()}
        internal_enabled = [display_to_internal.get(v, v) for v in (enabled_variants or [])]

        out_path = Path(out_dir) if (save_outputs and out_dir) else None
        if out_path is not None:
            out_path.mkdir(parents=True, exist_ok=True)

        results = _run_async(generate_variants(
            user_prompt=user_prompt.strip(),
            seed=int(seed),
            H=int(H), W=int(W),
            total_steps_refine=int(total_steps_refine),
            last_k_list=(int(last_k),),
            guidance_list=[float(guidance)] if guidance > 0 else None,
            preset=preset,
            out_dir=out_path,
            enabled_variants=internal_enabled,
        ))

        gallery = []
        for name in VARIANT_ORDER:
            if name in results and 0 in results[name]:
                gallery.append((results[name][0], name))

        try:
            meta = _run_async(_collect_meta(user_prompt.strip(), int(seed), int(H), int(W), preset))
        except Exception as e:
            meta = {"meta_error": str(e)}

        meta["ui"] = {
            "seed": int(seed),
            "H": int(H),
            "W": int(W),
            "preset": preset,
            "total_steps_refine": int(total_steps_refine),
            "last_k": int(last_k),
            "guidance": float(guidance),
            "enabled_variants": enabled_variants,
            "save_outputs": bool(save_outputs),
            "out_dir": out_dir if save_outputs else None,
        }
        meta["elapsed_sec"] = round(time.time() - t0, 3)

        return gallery, json.dumps(meta, ensure_ascii=False, indent=2)

    except Exception:
        return [], traceback.format_exc()

def ui_run_full(
    user_prompt: str,
    seed: int,
    H: int,
    W: int,
    preset: str,
    align_mode: str,
    align_score: float,
    save_outputs: bool,
    out_dir: str,
):
    try:
        if not user_prompt or not user_prompt.strip():
            return None, None, "", "Empty prompt."
        if not TOGETHER_API_KEY:
            return None, None, "", "ERROR: TOGETHER_API_KEY not set."

        save_dir = Path(out_dir) if (save_outputs and out_dir) else None
        a = None if align_mode.startswith("Auto") else float(align_score)

        base_img, final_img, enhanced_77, meta_json = _run_async(
            pipeline_full_cadr(
                user_prompt=user_prompt.strip(),
                seed=int(seed), H=int(H), W=int(W),
                preset=preset,
                align_score=a,
                save_dir=save_dir,
            )
        )
        return base_img, final_img, enhanced_77, meta_json
    except Exception:
        return None, None, "", traceback.format_exc()

# =========================
# 10) Gradio UI (matches your reference fixes)
# =========================
with gr.Blocks(title="CritiFusion (SDXL) Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "## CritiFusion Demo (SDXL)\n"
        "- You can run **Variants** (outputs 4 variants in one click).\n"
        "- You can also run the **Full CADR Pipeline** (end-to-end).\n"
        f"- Device: **{DEVICE_STR}**, DType: **{DTYPE}**\n"
        f"- Together API: {'✅ set' if TOGETHER_API_KEY else '❌ missing (set TOGETHER_API_KEY)'}"
    )

    with gr.Tabs():
        with gr.Tab("Variants (Run Once)"):
            with gr.Row():
                with gr.Column(scale=7):
                    user_prompt = gr.Textbox(
                        label="Prompt",
                        value="A fluffy orange cat lying on a window ledge, front-facing, stylized in 3D Pixar look, soft indoor lighting",
                        lines=3,
                    )
                    with gr.Row():
                        seed = gr.Number(label="Seed", value=2026, precision=0)
                        preset = gr.Dropdown(label="Preset", choices=["hq_preference"], value="hq_preference")
                    with gr.Row():
                        H = gr.Number(label="H", value=1024, precision=0)
                        W = gr.Number(label="W", value=1024, precision=0)
                    with gr.Row():
                        total_steps_refine = gr.Slider(label="total_steps_refine", minimum=10, maximum=80, step=1, value=50)
                        last_k = gr.Slider(label="last_k", minimum=1, maximum=50, step=1, value=37)
                        guidance = gr.Slider(label="Guidance (0 => fallback rule)", minimum=0.0, maximum=15.0, step=0.1, value=0.0)

                    enabled_variants = gr.CheckboxGroup(
                        label="Enabled Variants",
                        choices=[VARIANT_LABELS[k] for k in VARIANT_LABELS.keys()],
                        value=[
                            VARIANT_LABELS["base_original"],
                            VARIANT_LABELS["base_multi_llm"],
                            VARIANT_LABELS["criticore_on_original__specfusion"],
                            VARIANT_LABELS["criticore_on_multi_llm__specfusion"],
                        ],
                    )

                    with gr.Row():
                        save_outputs = gr.Checkbox(label="Save outputs to disk", value=False)
                        out_dir = gr.Textbox(label="Output dir (only if save enabled)", value="./variants_demo_gradio")

                    run_btn = gr.Button("Run", variant="primary")

                with gr.Column(scale=8):
                    gallery = gr.Gallery(label="Results", columns=2, height=600)
                    meta_json = gr.Code(label="Meta / Debug (JSON)", language="json")

            run_btn.click(
                fn=ui_run_once,
                inputs=[user_prompt, seed, H, W, preset, total_steps_refine, last_k, guidance, enabled_variants, save_outputs, out_dir],
                outputs=[gallery, meta_json],
                api_name=False,
                show_api=False,
            )

        with gr.Tab("Full CADR Pipeline"):
            with gr.Row():
                with gr.Column(scale=7):
                    p2 = gr.Textbox(
                        label="Prompt",
                        value="A fluffy orange cat lying on a window ledge, front-facing, stylized in 3D Pixar look, soft indoor lighting",
                        lines=3,
                    )
                    with gr.Row():
                        seed2 = gr.Number(label="Seed", value=2026, precision=0)
                        preset2 = gr.Dropdown(label="Preset", choices=["hq_preference"], value="hq_preference")
                    with gr.Row():
                        H2 = gr.Number(label="H", value=1024, precision=0)
                        W2 = gr.Number(label="W", value=1024, precision=0)

                    align_mode = gr.Radio(
                        label="Alignment score source",
                        choices=["Auto (pref_score if available else 60)", "Manual (use slider below)"],
                        value="Auto (pref_score if available else 60)",
                    )
                    align_score = gr.Slider(label="Manual align_score (0..100)", minimum=0, maximum=100, step=1, value=60)

                    with gr.Row():
                        save2 = gr.Checkbox(label="Save outputs to disk", value=False)
                        out2 = gr.Textbox(label="Output dir (only if save enabled)", value="./full_cadr_gradio")

                    run2 = gr.Button("Run Full CADR", variant="primary")

                with gr.Column(scale=8):
                    base_img = gr.Image(label="Base", type="pil")
                    final_img = gr.Image(label="Final (CADR + SpecFusion)", type="pil")
                    enhanced = gr.Textbox(label="Enhanced prompt (≤77)", lines=3)
                    meta2 = gr.Code(label="Meta / Debug (JSON)", language="json")

            run2.click(
                fn=ui_run_full,
                inputs=[p2, seed2, H2, W2, preset2, align_mode, align_score, save2, out2],
                outputs=[base_img, final_img, enhanced, meta2],
                api_name=False,
                show_api=False,
            )

# IMPORTANT: share=True fixes "localhost not accessible" for local/notebook runs
# (Spaces ignores share and serves the app itself)
demo.queue().launch(debug=True, share=True, max_threads=1, show_api=False)
requirements.txt ADDED
@@ -0,0 +1,11 @@
gradio>=4.44.0
torch
diffusers>=0.30.0
transformers>=4.43.0
accelerate
safetensors
huggingface_hub
pillow
numpy
together
nest_asyncio