czq0719 committed on
Commit
e571656
·
1 Parent(s): 6bffc9a

Add gradio app

Browse files
Files changed (1) hide show
  1. app.py +264 -197
app.py CHANGED
@@ -1,8 +1,9 @@
1
  # =========================
2
  # ONE-CELL: SDXL + CritiCore + SpecFusion + Gradio UI
3
- # - 3 variants only
4
- # - single-choice dropdown (no format_fn; compatible with older Gradio)
5
- # - generate ONE image per click
 
6
  # =========================
7
 
8
  import os, re, io, json, time, base64, asyncio, inspect, traceback
@@ -25,22 +26,30 @@ from diffusers import (
25
  os.environ["TOGETHER_NO_BANNER"] = "1"
26
 
27
  # =========================
28
- # 0) Variant labels (MUST be BEFORE Blocks)
29
  # =========================
30
- # Expose ONLY 3 variants in UI.
31
- # Map internal "criticore_on_multi_llm__specfusion" -> UI name "CritiFusion".
32
  VARIANT_LABELS = {
33
- "base_original": "Base (Original prompt)",
34
- "base_multi_llm": "Base (Multi-LLM tag expansion)",
35
- "CritiFusion": "CritiFusion (Multi-LLM + VLM critique + SpecFusion)",
 
36
  }
37
- VARIANT_KEYS_UI = ["base_original", "base_multi_llm", "CritiFusion"]
38
 
39
- RHO_T_DEFAULT = 0.85 # fixed as requested
 
 
 
 
 
 
 
 
40
 
 
41
  TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY", "").strip()
42
  if not TOGETHER_API_KEY:
43
- print("[Warn] TOGETHER_API_KEY is not set. Together steps will fail for base_multi_llm/CritiFusion.")
44
 
45
  # =========================
46
  # 1) SDXL init
@@ -105,7 +114,7 @@ def base_sample_latent(prompt: str, seed: int, H: int, W: int, neg: str):
105
  return z0, x0
106
 
107
  @torch.no_grad()
108
- def img2img_latent(prompt: str, image_or_latent, strength: float, guidance: float, steps: int, seed: int):
109
  g = torch.Generator(device=DEVICE).manual_seed(int(seed))
110
  out = SDXL_i2i(
111
  prompt=prompt,
@@ -115,7 +124,7 @@ def img2img_latent(prompt: str, image_or_latent, strength: float, guidance: floa
115
  num_inference_steps=int(steps),
116
  generator=g,
117
  output_type="latent",
118
- negative_prompt=DEFAULT_NEG
119
  )
120
  return out.images
121
 
@@ -151,8 +160,7 @@ def clip77_strict(text: str, max_tok: int = 77) -> str:
151
  mid = (lo + hi) // 2
152
  cand = " ".join(words[:mid]) if mid > 0 else ""
153
  if _count_tokens(cand) <= max_tok:
154
- best = cand
155
- lo = mid + 1
156
  else:
157
  hi = mid - 1
158
  return best.strip()
@@ -165,8 +173,7 @@ def _dedup_keep_order(items: List[str]) -> List[str]:
165
  for t in items:
166
  key = re.sub(r"\s+", " ", t.lower()).strip()
167
  if key and key not in seen:
168
- seen.add(key)
169
- out.append(t.strip())
170
  return out
171
 
172
  def _order_tags(subject_first: List[str], rest: List[str]) -> List[str]:
@@ -188,10 +195,7 @@ def _order_tags(subject_first: List[str], rest: List[str]) -> List[str]:
188
  elif any(k in lt for k in detail_kw): buckets["detail"].append(t)
189
  else: buckets["other"].append(t)
190
 
191
- return (
192
- buckets["subject"] + buckets["style"] + buckets["composition"] +
193
- buckets["lighting"] + buckets["color"] + buckets["detail"] + buckets["other"]
194
- )
195
 
196
  def pil_to_base64(img: Image.Image, fmt: str = "PNG") -> str:
197
  buf = io.BytesIO()
@@ -207,10 +211,8 @@ async def _maybe_close_async_together(client) -> None:
207
  if inspect.iscoroutinefunction(fn):
208
  await fn()
209
  else:
210
- try:
211
- fn()
212
- except Exception:
213
- pass
214
  except Exception:
215
  pass
216
 
@@ -221,7 +223,7 @@ def _run_async(coro):
221
  try:
222
  loop = asyncio.get_event_loop()
223
  if loop.is_running():
224
- return loop.run_until_complete(coro)
225
  return loop.run_until_complete(coro)
226
  except RuntimeError:
227
  return asyncio.run(coro)
@@ -278,10 +280,8 @@ def _TAG_RE(tag: str):
278
 
279
  def _extract_tag(text: str, tag: str, fallback: str = "") -> str:
280
  s = (text or "").strip()
281
- r = _TAG_RE(tag)
282
- m = r.search(s)
283
- if m:
284
- return m.group(1).strip()
285
  s2 = s.replace("&lt;","<").replace("&gt;",">")
286
  m2 = r.search(s2)
287
  return m2.group(1).strip() if m2 else fallback.strip()
@@ -303,22 +303,17 @@ class CritiCore:
303
  async def decompose_components(self, user_prompt: str) -> List[str]:
304
  client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
305
  try:
306
- tasks = [
307
- client.chat.completions.create(
308
- model=m,
309
- messages=[{"role":"system","content": _DECOMP_SYS},
310
- {"role":"user","content": user_prompt}],
311
- temperature=0.4, max_tokens=256
312
- )
313
- for m in LLM_MULTI_CANDIDATES
314
- ]
315
  rs = await asyncio.gather(*tasks, return_exceptions=True)
316
  texts = []
317
  for r in rs:
318
- try:
319
- texts.append(r.choices[0].message.content)
320
- except Exception:
321
- pass
322
  if not texts:
323
  return []
324
  joined = "\n\n---\n\n".join(texts)
@@ -332,8 +327,8 @@ class CritiCore:
332
  try:
333
  obj = json.loads(txt)
334
  except Exception:
335
- s, e = txt.find("{"), txt.rfind("}")
336
- obj = json.loads(txt[s:e+1]) if (s != -1 and e != -1) else {"components":[]}
337
  comps = [c.strip() for c in obj.get("components", []) if isinstance(c, str) and c.strip()]
338
  return comps[:6]
339
  finally:
@@ -345,23 +340,18 @@ class CritiCore:
345
  seed_pos = _dedup_keep_order(seed["seed_pos"])
346
  seed_neg = seed["seed_neg"]
347
  try:
348
- tasks = [
349
- client.chat.completions.create(
350
- model=m,
351
- messages=[{"role":"system","content": _TXT_SYS},
352
- {"role":"user","content":
353
- f"Short idea: {user_prompt}\nSeed: {', '.join(seed_pos)}\nOutput: a single comma-separated tag list."}],
354
- temperature=0.7, max_tokens=220
355
- )
356
- for m in LLM_MULTI_CANDIDATES
357
- ]
358
  rs = await asyncio.gather(*tasks, return_exceptions=True)
359
  props = []
360
  for r in rs:
361
- try:
362
- props.append(r.choices[0].message.content)
363
- except Exception:
364
- pass
365
 
366
  if not props:
367
  pos = ", ".join([user_prompt.strip()] + seed_pos)
@@ -382,6 +372,7 @@ class CritiCore:
382
  ordered = _order_tags([tags[0]], tags[1:])
383
  pos = ", ".join(_dedup_keep_order(ordered))
384
 
 
385
  for q in ["high detailed","sharp focus","8k","UHD"]:
386
  if q.lower() not in {t.lower() for t in _split_tags(pos)}:
387
  pos += ", " + q
@@ -409,7 +400,6 @@ class CritiCore:
409
  "Output EXACTLY two tags:\n"
410
  "<issues>...</issues>\n<refined>...</refined>"
411
  )
412
-
413
  try:
414
  tasks = []
415
  for m in VLM_CANDIDATES:
@@ -425,10 +415,8 @@ class CritiCore:
425
  rs = await asyncio.gather(*tasks, return_exceptions=True)
426
  ok = []
427
  for m, r in zip(VLM_CANDIDATES, rs):
428
- try:
429
- ok.append((m, r.choices[0].message.content))
430
- except Exception:
431
- pass
432
 
433
  if not ok:
434
  return {"refined": original_prompt, "issues_merged": ""}
@@ -437,13 +425,11 @@ class CritiCore:
437
  for m, raw in ok:
438
  issues = _extract_tag(raw, "issues", "")
439
  refined = _extract_tag(raw, "refined", original_prompt)
440
- if refined.strip():
441
- refined_items.append((m, refined.strip()))
442
- if issues.strip():
443
- per_vlm_issues[m] = _summarize_issues_lines(issues, 5)
444
 
445
- joined_issues = "\n".join(f"[{m}] {t}" for m, t in per_vlm_issues.items())
446
- joined_refined = "\n".join(f"[{m}] {t}" for m, t in refined_items) if refined_items else original_prompt
447
 
448
  merged = await client.chat.completions.create(
449
  model=self.aggregator,
@@ -472,7 +458,7 @@ class CritiCore:
472
  return text
473
 
474
  # =========================
475
- # 5) SpecFusion
476
  # =========================
477
  @torch.no_grad()
478
  def frequency_fusion(
@@ -505,9 +491,13 @@ def frequency_fusion(
505
  x = x + torch.randn_like(x) * 0.001
506
  return x.to(dtype=x_hi_latent.dtype)
507
 
508
- def _decode_to_pil(latents, pipe):
509
- out = decode_image_sdxl(latents, pipe)
510
- return out if isinstance(out, Image.Image) else out.images[0]
 
 
 
 
511
 
512
  def _guidance_for_k(k: int) -> float:
513
  if k >= 20: return 12.0
@@ -515,116 +505,145 @@ def _guidance_for_k(k: int) -> float:
515
  return 5.2
516
 
517
  # =========================
518
- # 6) Shared + one-variant generator
519
  # =========================
520
- async def _shared_materials(user_prompt: str, seed: int, H: int, W: int, preset: str):
521
- critic = CritiCore(preset=preset)
522
-
523
- pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
524
- comps = await critic.decompose_components(user_prompt)
525
-
526
- z0_og, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)
527
- z0_enh, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)
528
-
529
- vlm_out = await critic.vlm_refine(base_enh, pos_tags_77, comps or [])
530
- vlm_agg_77 = vlm_out.get("refined") or pos_tags_77
531
-
532
- return dict(
533
- pos_tags_77=pos_tags_77, neg_tags=neg_tags, comps=comps,
534
- z0_og=z0_og, base_og=base_og,
535
- z0_enh=z0_enh, base_enh=base_enh,
536
- vlm_agg_77=vlm_agg_77,
537
- vlm_issues=vlm_out.get("issues_merged",""),
538
- )
539
-
540
  async def generate_one_variant(
541
  user_prompt: str,
542
  seed: int,
543
- H: int, W: int,
544
- preset: str,
545
- variant_key: str,
546
  total_steps_refine: int,
547
  last_k: int,
548
  guidance: float,
549
- save_outputs: bool,
550
- out_dir: str,
551
- ):
552
- shared = await _shared_materials(user_prompt, seed, H, W, preset)
553
-
554
- pos_tags_77 = shared["pos_tags_77"]
555
- neg_tags = shared["neg_tags"]
556
- comps = shared["comps"]
557
- z0_og = shared["z0_og"]
558
- base_og = shared["base_og"]
559
- z0_enh = shared["z0_enh"]
560
- base_enh = shared["base_enh"]
561
- vlm_agg_77 = shared["vlm_agg_77"]
562
- vlm_issues = shared["vlm_issues"]
563
-
564
- out_path = Path(out_dir) if (save_outputs and out_dir) else None
565
- if out_path is not None:
566
- out_path.mkdir(parents=True, exist_ok=True)
567
-
568
- meta = {
569
  "user_prompt": user_prompt,
570
- "preset": preset,
571
  "variant_key": variant_key,
572
- "variant_label": VARIANT_LABELS.get(variant_key, variant_key),
573
- "seed": int(seed),
574
- "H": int(H),
575
- "W": int(W),
576
- "pos_tags_77": pos_tags_77,
577
- "neg_tags": neg_tags,
578
- "components": comps,
579
- "vlm_agg_77_on_multi_llm": vlm_agg_77,
580
- "vlm_issues": vlm_issues,
581
- "params": {
582
- "total_steps_refine": int(total_steps_refine),
583
- "last_k": int(last_k),
584
- "guidance": float(guidance),
585
- "rho_t": float(RHO_T_DEFAULT),
586
- }
587
  }
588
 
 
 
 
 
 
 
 
 
 
 
589
  if variant_key == "base_original":
590
- img = base_og
591
- if out_path is not None:
592
- img.save(out_path / "base_original.png")
593
- return img, meta
594
 
595
- if variant_key == "base_multi_llm":
596
- img = base_enh
597
- if out_path is not None:
598
- img.save(out_path / "base_multi_llm.png")
599
- return img, meta
 
 
 
 
 
 
 
 
600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
  if variant_key == "CritiFusion":
602
- lk = int(last_k)
603
- strength = float(strength_for_last_k(lk, int(total_steps_refine)))
604
- gscale = float(guidance) if float(guidance) > 0 else float(_guidance_for_k(lk))
605
- steps = int(total_steps_refine)
606
 
 
 
607
  refined_on_enh = CritiCore.merge_vlm_multi_text(vlm_agg_77, pos_tags_77)
608
- meta["refined_prompt_77"] = refined_on_enh
609
- meta["img2img"] = {"strength": strength, "guidance_scale": gscale, "steps": steps}
610
 
611
  z_ref = img2img_latent(
612
  refined_on_enh, z0_enh,
613
- strength=strength, guidance=gscale, steps=steps,
614
- seed=int(seed) + 2100 + lk
 
615
  )
616
  fused_lat = frequency_fusion(z_ref, z0_enh, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
617
- img = _decode_to_pil(fused_lat, SDXL_i2i)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
 
619
- if out_path is not None:
620
- img.save(out_path / "CritiFusion.png")
621
- (out_path / "meta.json").write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
622
- return img, meta
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
 
624
  raise ValueError(f"Unknown variant_key: {variant_key}")
625
 
626
  # =========================
627
- # 7) UI callback
628
  # =========================
629
  def ui_run_once(
630
  user_prompt: str,
@@ -632,55 +651,95 @@ def ui_run_once(
632
  H: int,
633
  W: int,
634
  preset: str,
635
- variant_key: str,
636
  total_steps_refine: int,
637
  last_k: int,
638
  guidance: float,
 
639
  save_outputs: bool,
640
  out_dir: str,
641
  ):
642
  t0 = time.time()
643
  try:
644
  if not user_prompt or not user_prompt.strip():
645
- return None, "Empty prompt."
 
 
 
 
 
 
646
 
647
- # Only these require Together
648
- if variant_key in ("base_multi_llm", "CritiFusion") and not TOGETHER_API_KEY:
649
- return None, "ERROR: TOGETHER_API_KEY not set (required for this variant)."
 
650
 
651
- img, meta = _run_async(generate_one_variant(
 
 
652
  user_prompt=user_prompt.strip(),
653
  seed=int(seed),
654
  H=int(H), W=int(W),
655
- preset=preset,
656
- variant_key=variant_key,
657
  total_steps_refine=int(total_steps_refine),
658
  last_k=int(last_k),
659
  guidance=float(guidance),
660
- save_outputs=bool(save_outputs),
661
- out_dir=str(out_dir or ""),
 
662
  ))
663
 
 
 
 
 
 
 
 
 
 
 
 
 
664
  meta["elapsed_sec"] = round(time.time() - t0, 3)
665
- return img, json.dumps(meta, ensure_ascii=False, indent=2)
 
 
666
 
667
  except Exception:
668
- return None, traceback.format_exc()
669
 
670
  @spaces.GPU
671
  def ui_run_once_gpu(*args, **kwargs):
672
  return ui_run_once(*args, **kwargs)
673
 
674
  # =========================
675
- # 8) Gradio UI (Dropdown: single choice; 3 variants only)
676
  # =========================
677
- VARIANT_CHOICES_DISPLAY = [VARIANT_LABELS[k] for k in VARIANT_KEYS_UI]
678
- DISPLAY_TO_KEY = {VARIANT_LABELS[k]: k for k in VARIANT_KEYS_UI}
 
679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
680
  with gr.Blocks(title="CritiFusion (SDXL) Demo") as demo:
681
  gr.Markdown(
682
  "## CritiFusion Demo (SDXL)\n"
683
- "- Choose **one** variant and generate **one** image per click.\n"
684
  f"- Device: **{DEVICE_STR}**, DType: **{DTYPE}**\n"
685
  f"- Together API: {'✅ set' if TOGETHER_API_KEY else '❌ missing (set TOGETHER_API_KEY)'}"
686
  )
@@ -689,7 +748,7 @@ with gr.Blocks(title="CritiFusion (SDXL) Demo") as demo:
689
  with gr.Column(scale=7):
690
  user_prompt = gr.Textbox(
691
  label="Prompt",
692
- value="A fluffy orange cat lying on a window ledge, front-facing, stylized in 3D Pixar look, soft indoor lighting",
693
  lines=3,
694
  )
695
  with gr.Row():
@@ -698,42 +757,50 @@ with gr.Blocks(title="CritiFusion (SDXL) Demo") as demo:
698
  with gr.Row():
699
  H = gr.Number(label="H", value=1024, precision=0)
700
  W = gr.Number(label="W", value=1024, precision=0)
701
-
702
- variant_display = gr.Dropdown(
703
- label="Variant (select ONE)",
704
- choices=VARIANT_CHOICES_DISPLAY,
705
- value=VARIANT_LABELS["CritiFusion"],
706
- )
707
-
708
  with gr.Row():
709
  total_steps_refine = gr.Slider(label="total_steps_refine", minimum=10, maximum=80, step=1, value=50)
710
  last_k = gr.Slider(label="last_k", minimum=1, maximum=50, step=1, value=37)
711
- guidance = gr.Slider(label="Guidance (0 => fallback rule)", minimum=0.0, maximum=15.0, step=0.1, value=0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
 
713
  with gr.Row():
714
- save_outputs = gr.Checkbox(label="Save outputs to disk", value=False)
715
  out_dir = gr.Textbox(label="Output dir (only if save enabled)", value="./variants_demo_gradio")
716
 
717
- run_btn = gr.Button("Generate", variant="primary")
718
 
719
  with gr.Column(scale=8):
720
- out_img = gr.Image(label="Result", type="pil")
721
  meta_json = gr.Code(label="Meta / Debug (JSON)", language="json")
722
 
723
- def _map_variant_display_to_key(vdisp: str) -> str:
724
- return DISPLAY_TO_KEY.get(vdisp, "CritiFusion")
725
-
726
  run_btn.click(
727
- fn=lambda p, s, h, w, pre, vdisp, ts, lk, g, sv, od: ui_run_once_gpu(
728
- p, s, h, w, pre, _map_variant_display_to_key(vdisp), ts, lk, g, sv, od
729
- ),
730
- inputs=[user_prompt, seed, H, W, preset, variant_display, total_steps_refine, last_k, guidance, save_outputs, out_dir],
731
- outputs=[out_img, meta_json],
732
- api_name=False,
733
  )
734
 
735
  demo.queue().launch(
736
  debug=True,
737
- share=True,
738
- theme=gr.themes.Soft(),
739
  )
 
1
  # =========================
2
  # ONE-CELL: SDXL + CritiCore + SpecFusion + Gradio UI
3
+ # - Keep original "Enabled Variants" pills UI (CheckboxGroup)
4
+ # - Enforce: ONLY ONE can be selected at a time (auto-fix on change)
5
+ # - 4 variants (but names are clearer)
6
+ # - No Radio.format_fn (older gradio safe)
7
  # =========================
8
 
9
  import os, re, io, json, time, base64, asyncio, inspect, traceback
 
26
  os.environ["TOGETHER_NO_BANNER"] = "1"
27
 
28
  # =========================
29
+ # 0) Variants (MUST be BEFORE Blocks)
30
  # =========================
31
+ # internal_key -> UI display label
 
32
  VARIANT_LABELS = {
33
+ "base_original": "Base (Original Prompt)",
34
+ "base_multi_llm": "Base (MoA Tags)",
35
+ "CritiFusion": "CritiFusion (MoA+VLM+SpecFusion)",
36
+ "criticore_on_original__specfusion": "CritiFusion (Original+VLM+SpecFusion)",
37
  }
 
38
 
39
+ # order for gallery display
40
+ VARIANT_ORDER = [
41
+ VARIANT_LABELS["base_original"],
42
+ VARIANT_LABELS["base_multi_llm"],
43
+ VARIANT_LABELS["CritiFusion"],
44
+ VARIANT_LABELS["criticore_on_original__specfusion"],
45
+ ]
46
+
47
+ RHO_T_DEFAULT = 0.85 # fixed
48
 
49
+ # ---- SAFETY: do NOT hardcode API keys ----
50
  TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY", "").strip()
51
  if not TOGETHER_API_KEY:
52
+ print("[Warn] TOGETHER_API_KEY is not set. Together-based variants will error if selected.")
53
 
54
  # =========================
55
  # 1) SDXL init
 
114
  return z0, x0
115
 
116
  @torch.no_grad()
117
+ def img2img_latent(prompt: str, image_or_latent, strength: float, guidance: float, steps: int, seed: int, neg: str):
118
  g = torch.Generator(device=DEVICE).manual_seed(int(seed))
119
  out = SDXL_i2i(
120
  prompt=prompt,
 
124
  num_inference_steps=int(steps),
125
  generator=g,
126
  output_type="latent",
127
+ negative_prompt=neg
128
  )
129
  return out.images
130
 
 
160
  mid = (lo + hi) // 2
161
  cand = " ".join(words[:mid]) if mid > 0 else ""
162
  if _count_tokens(cand) <= max_tok:
163
+ best = cand; lo = mid + 1
 
164
  else:
165
  hi = mid - 1
166
  return best.strip()
 
173
  for t in items:
174
  key = re.sub(r"\s+", " ", t.lower()).strip()
175
  if key and key not in seen:
176
+ seen.add(key); out.append(t.strip())
 
177
  return out
178
 
179
  def _order_tags(subject_first: List[str], rest: List[str]) -> List[str]:
 
195
  elif any(k in lt for k in detail_kw): buckets["detail"].append(t)
196
  else: buckets["other"].append(t)
197
 
198
+ return buckets["subject"] + buckets["style"] + buckets["composition"] + buckets["lighting"] + buckets["color"] + buckets["detail"] + buckets["other"]
 
 
 
199
 
200
  def pil_to_base64(img: Image.Image, fmt: str = "PNG") -> str:
201
  buf = io.BytesIO()
 
211
  if inspect.iscoroutinefunction(fn):
212
  await fn()
213
  else:
214
+ try: fn()
215
+ except Exception: pass
 
 
216
  except Exception:
217
  pass
218
 
 
223
  try:
224
  loop = asyncio.get_event_loop()
225
  if loop.is_running():
226
+ return loop.run_until_complete(coro) # nest_asyncio enabled
227
  return loop.run_until_complete(coro)
228
  except RuntimeError:
229
  return asyncio.run(coro)
 
280
 
281
def _extract_tag(text: str, tag: str, fallback: str = "") -> str:
    """Return the stripped first capture group of the ``tag`` pattern in ``text``.

    The pattern comes from ``_TAG_RE(tag)``. The text is searched as given
    first; if nothing matches, ``&lt;`` / ``&gt;`` are replaced by ``<`` / ``>``
    and the search is repeated. When both passes miss, the stripped
    ``fallback`` is returned.

    Args:
        text: Raw model output; ``None`` or empty is treated as ``""``.
        tag: Tag name handed to ``_TAG_RE``.
        fallback: Value to return (stripped) when no match is found.
    """
    stripped = (text or "").strip()
    pattern = _TAG_RE(tag)

    hit = pattern.search(stripped)
    if hit is None:
        # Second chance: the angle brackets may have arrived HTML-escaped.
        unescaped = stripped.replace("&lt;","<").replace("&gt;",">")
        hit = pattern.search(unescaped)

    if hit is None:
        return fallback.strip()
    return hit.group(1).strip()
 
303
  async def decompose_components(self, user_prompt: str) -> List[str]:
304
  client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
305
  try:
306
+ tasks = [client.chat.completions.create(
307
+ model=m,
308
+ messages=[{"role":"system","content": _DECOMP_SYS},
309
+ {"role":"user","content": user_prompt}],
310
+ temperature=0.4, max_tokens=256
311
+ ) for m in LLM_MULTI_CANDIDATES]
 
 
 
312
  rs = await asyncio.gather(*tasks, return_exceptions=True)
313
  texts = []
314
  for r in rs:
315
+ try: texts.append(r.choices[0].message.content)
316
+ except Exception: pass
 
 
317
  if not texts:
318
  return []
319
  joined = "\n\n---\n\n".join(texts)
 
327
  try:
328
  obj = json.loads(txt)
329
  except Exception:
330
+ s,e = txt.find("{"), txt.rfind("}")
331
+ obj = json.loads(txt[s:e+1]) if (s!=-1 and e!=-1) else {"components":[]}
332
  comps = [c.strip() for c in obj.get("components", []) if isinstance(c, str) and c.strip()]
333
  return comps[:6]
334
  finally:
 
340
  seed_pos = _dedup_keep_order(seed["seed_pos"])
341
  seed_neg = seed["seed_neg"]
342
  try:
343
+ tasks = [client.chat.completions.create(
344
+ model=m,
345
+ messages=[{"role":"system","content": _TXT_SYS},
346
+ {"role":"user","content":
347
+ f"Short idea: {user_prompt}\nSeed: {', '.join(seed_pos)}\nOutput: a single comma-separated tag list."}],
348
+ temperature=0.7, max_tokens=220
349
+ ) for m in LLM_MULTI_CANDIDATES]
 
 
 
350
  rs = await asyncio.gather(*tasks, return_exceptions=True)
351
  props = []
352
  for r in rs:
353
+ try: props.append(r.choices[0].message.content)
354
+ except Exception: pass
 
 
355
 
356
  if not props:
357
  pos = ", ".join([user_prompt.strip()] + seed_pos)
 
372
  ordered = _order_tags([tags[0]], tags[1:])
373
  pos = ", ".join(_dedup_keep_order(ordered))
374
 
375
+ # quality floor
376
  for q in ["high detailed","sharp focus","8k","UHD"]:
377
  if q.lower() not in {t.lower() for t in _split_tags(pos)}:
378
  pos += ", " + q
 
400
  "Output EXACTLY two tags:\n"
401
  "<issues>...</issues>\n<refined>...</refined>"
402
  )
 
403
  try:
404
  tasks = []
405
  for m in VLM_CANDIDATES:
 
415
  rs = await asyncio.gather(*tasks, return_exceptions=True)
416
  ok = []
417
  for m, r in zip(VLM_CANDIDATES, rs):
418
+ try: ok.append((m, r.choices[0].message.content))
419
+ except Exception: pass
 
 
420
 
421
  if not ok:
422
  return {"refined": original_prompt, "issues_merged": ""}
 
425
  for m, raw in ok:
426
  issues = _extract_tag(raw, "issues", "")
427
  refined = _extract_tag(raw, "refined", original_prompt)
428
+ if refined.strip(): refined_items.append((m, refined.strip()))
429
+ if issues.strip(): per_vlm_issues[m] = _summarize_issues_lines(issues, 5)
 
 
430
 
431
+ joined_issues = "\n".join(f"[{m}] {t}" for m,t in per_vlm_issues.items())
432
+ joined_refined = "\n".join(f"[{m}] {t}" for m,t in refined_items) if refined_items else original_prompt
433
 
434
  merged = await client.chat.completions.create(
435
  model=self.aggregator,
 
458
  return text
459
 
460
  # =========================
461
+ # 5) SpecFusion (latent FFT gate)
462
  # =========================
463
  @torch.no_grad()
464
  def frequency_fusion(
 
491
  x = x + torch.randn_like(x) * 0.001
492
  return x.to(dtype=x_hi_latent.dtype)
493
 
494
def _decode_to_pil(latents):
    """Decode ``latents`` with ``decode_image_sdxl`` on ``SDXL_i2i`` and unwrap the result.

    Returns the decoded object itself when it is already a PIL image, the
    first entry of its ``images`` attribute when it has one, and otherwise
    the decoded object unchanged.
    """
    decoded = decode_image_sdxl(latents, SDXL_i2i)
    # Only non-PIL results that carry an `.images` attribute need unwrapping.
    needs_unwrap = (not isinstance(decoded, Image.Image)) and hasattr(decoded, "images")
    return decoded.images[0] if needs_unwrap else decoded
501
 
502
  def _guidance_for_k(k: int) -> float:
503
  if k >= 20: return 12.0
 
505
  return 5.2
506
 
507
  # =========================
508
+ # 6) ONE-variant generator (because UI enforces single selection)
509
  # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
async def generate_one_variant(
    user_prompt: str,
    seed: int,
    H: int,
    W: int,
    total_steps_refine: int,
    last_k: int,
    guidance: float,
    preset: str,
    variant_key: str,
    out_dir: Optional[Path] = None,
) -> Tuple[Image.Image, str, Dict[str, object]]:
    """Generate the image for exactly one variant and describe how it was made.

    Args:
        user_prompt: Prompt text as entered by the user.
        seed: Seed for the base sample; the img2img pass uses
            ``seed + <variant offset> + last_k``.
        H: Height forwarded to ``base_sample_latent``.
        W: Width forwarded to ``base_sample_latent``.
        total_steps_refine: Step count of the img2img refinement pass.
        last_k: Passed to ``strength_for_last_k`` and ``_guidance_for_k``.
        guidance: Guidance scale; a value ``<= 0`` falls back to
            ``_guidance_for_k(last_k)``.
        preset: Forwarded to ``CritiCore(preset=...)``.
        variant_key: One of the internal keys of ``VARIANT_LABELS``.
        out_dir: When given, the image is written there as
            ``<sanitised display name>.png`` (directory created on demand).

    Returns:
        ``(img, display_name, meta_dict)``.

    Raises:
        RuntimeError: ``TOGETHER_API_KEY`` is empty and ``variant_key`` is
            anything other than ``"base_original"``.
        ValueError: ``variant_key`` matches no known variant (only reached
            when the API key is set).
    """
    meta: Dict[str, object] = {
        "user_prompt": user_prompt,
        "variant_key": variant_key,
    }

    def _save(im: Image.Image, display_name: str) -> None:
        """Write ``im`` into ``out_dir`` under a filesystem-safe name (no-op without ``out_dir``)."""
        if out_dir is None:
            return
        out_dir.mkdir(parents=True, exist_ok=True)
        # Keep only letters, digits, "_" and "-". The earlier pattern
        # (r"[^a-zA-Z0-9_\\-]+") also whitelisted a literal backslash, which
        # is a path separator on Windows.
        safe = re.sub(r"[^a-zA-Z0-9_-]+", "_", display_name)[:120]
        im.save(out_dir / f"{safe}.png")

    # ----------------------------------------------------------
    # Variant 1: Base (Original Prompt)   [NO Together needed]
    # ----------------------------------------------------------
    if variant_key == "base_original":
        _, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)
        meta.update({"note": "SDXL base generation from original prompt."})
        _save(base_og, VARIANT_LABELS[variant_key])
        return base_og, VARIANT_LABELS[variant_key], meta

    # The rest need Together
    if not TOGETHER_API_KEY:
        raise RuntimeError("TOGETHER_API_KEY not set, but selected variant requires Together.")

    critic = CritiCore(preset=preset)

    # Common refine params
    lk = int(last_k)
    strength = float(strength_for_last_k(lk, total_steps_refine))
    use_guidance = float(guidance) if float(guidance) > 0 else float(_guidance_for_k(lk))
    steps = int(total_steps_refine)

    meta.update({"strength": strength, "guidance": use_guidance, "steps": steps, "last_k": lk})

    def _refine_and_fuse(prompt_77: str, z0, seed_offset: int):
        """Run img2img from ``z0``, fuse the result with ``z0`` via ``frequency_fusion``, and decode."""
        z_ref = img2img_latent(
            prompt_77, z0,
            strength=strength, guidance=use_guidance, steps=steps,
            seed=seed + seed_offset + lk,
            neg=DEFAULT_NEG,
        )
        fused_lat = frequency_fusion(z_ref, z0, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
        return _decode_to_pil(fused_lat)

    # ----------------------------------------------------------
    # Variant 2: Base (MoA Tags)
    # ----------------------------------------------------------
    if variant_key == "base_multi_llm":
        pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
        _, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)
        meta.update({
            "pos_tags_77": pos_tags_77,
            "neg_tags": neg_tags,
            "note": "SDXL base generation from MoA-generated tags."
        })
        _save(base_enh, VARIANT_LABELS[variant_key])
        return base_enh, VARIANT_LABELS[variant_key], meta

    # ----------------------------------------------------------
    # Variant 3: CritiFusion (MoA+VLM+SpecFusion)
    # ----------------------------------------------------------
    if variant_key == "CritiFusion":
        pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
        comps = await critic.decompose_components(user_prompt)

        z0_enh, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)

        vlm_out = await critic.vlm_refine(base_enh, pos_tags_77, comps or [])
        vlm_agg_77 = vlm_out.get("refined") or pos_tags_77
        refined_on_enh = CritiCore.merge_vlm_multi_text(vlm_agg_77, pos_tags_77)

        img_sf = _refine_and_fuse(refined_on_enh, z0_enh, 2100)

        meta.update({
            "pos_tags_77": pos_tags_77,
            "neg_tags": neg_tags,
            "components": comps,
            "vlm_refined_77": vlm_agg_77,
            "enhanced_prompt_77": refined_on_enh,
            "vlm_issues": vlm_out.get("issues_merged", ""),
            "note": "MoA tags + VLM critique prompt + img2img + SpecFusion."
        })
        _save(img_sf, VARIANT_LABELS[variant_key])
        return img_sf, VARIANT_LABELS[variant_key], meta

    # ----------------------------------------------------------
    # Variant 4: CritiFusion (Original+VLM+SpecFusion)
    # ----------------------------------------------------------
    if variant_key == "criticore_on_original__specfusion":
        pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
        comps = await critic.decompose_components(user_prompt)

        z0_og, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)

        vlm_on_og = await critic.vlm_refine(base_og, user_prompt, comps or [])
        refined_og_77 = clip77_strict(vlm_on_og.get("refined") or user_prompt, 77)
        refined_merge = CritiCore.merge_vlm_multi_text(refined_og_77, pos_tags_77)

        img_sf = _refine_and_fuse(refined_merge, z0_og, 2400)

        meta.update({
            "pos_tags_77": pos_tags_77,
            "neg_tags": neg_tags,
            "components": comps,
            "vlm_refined_77": refined_og_77,
            "enhanced_prompt_77": refined_merge,
            "vlm_issues": vlm_on_og.get("issues_merged", ""),
            "note": "Original prompt + VLM critique prompt + img2img + SpecFusion."
        })
        _save(img_sf, VARIANT_LABELS[variant_key])
        return img_sf, VARIANT_LABELS[variant_key], meta

    raise ValueError(f"Unknown variant_key: {variant_key}")
644
 
645
  # =========================
646
+ # 7) UI callbacks
647
  # =========================
648
  def ui_run_once(
649
  user_prompt: str,
 
651
  H: int,
652
  W: int,
653
  preset: str,
 
654
  total_steps_refine: int,
655
  last_k: int,
656
  guidance: float,
657
+ enabled_variants_display: List[str],
658
  save_outputs: bool,
659
  out_dir: str,
660
  ):
661
  t0 = time.time()
662
  try:
663
  if not user_prompt or not user_prompt.strip():
664
+ return [], "Empty prompt."
665
+
666
+ # display -> internal
667
+ display_to_internal = {v: k for k, v in VARIANT_LABELS.items()}
668
+ chosen_display = (enabled_variants_display or [])[-1:] # enforce single here too
669
+ if not chosen_display:
670
+ return [], "Please select ONE variant."
671
 
672
+ chosen_display = chosen_display[0]
673
+ variant_key = display_to_internal.get(chosen_display)
674
+ if variant_key is None:
675
+ return [], f"Unknown selected variant: {chosen_display}"
676
 
677
+ out_path = Path(out_dir) if (save_outputs and out_dir) else None
678
+
679
+ img, disp_name, meta = _run_async(generate_one_variant(
680
  user_prompt=user_prompt.strip(),
681
  seed=int(seed),
682
  H=int(H), W=int(W),
 
 
683
  total_steps_refine=int(total_steps_refine),
684
  last_k=int(last_k),
685
  guidance=float(guidance),
686
+ preset=preset,
687
+ variant_key=variant_key,
688
+ out_dir=out_path,
689
  ))
690
 
691
+ meta["ui"] = {
692
+ "seed": int(seed),
693
+ "H": int(H),
694
+ "W": int(W),
695
+ "preset": preset,
696
+ "total_steps_refine": int(total_steps_refine),
697
+ "last_k": int(last_k),
698
+ "guidance": float(guidance),
699
+ "selected_variant": chosen_display,
700
+ "save_outputs": bool(save_outputs),
701
+ "out_dir": out_dir if save_outputs else None,
702
+ }
703
  meta["elapsed_sec"] = round(time.time() - t0, 3)
704
+
705
+ gallery = [(img, disp_name)]
706
+ return gallery, json.dumps(meta, ensure_ascii=False, indent=2)
707
 
708
  except Exception:
709
+ return [], traceback.format_exc()
710
 
711
# NOTE(review): `spaces.GPU` presumably requests GPU hardware for the duration
# of the call on Hugging Face Spaces — confirm against the `spaces` package docs.
@spaces.GPU
def ui_run_once_gpu(*args, **kwargs):
    """Forward all positional and keyword arguments to ``ui_run_once`` and return its result unchanged."""
    return ui_run_once(*args, **kwargs)
714
 
715
  # =========================
716
+ # 8) Single-select enforcement for CheckboxGroup
717
  # =========================
718
def enforce_single_variant(new_list: List[str], prev_list: List[str]) -> Tuple[List[str], List[str]]:
    """Collapse a multi-select value down to at most one selected label.

    Args:
        new_list: Labels currently ticked (``None`` is treated as empty).
        prev_list: Labels that were ticked before this change (``None`` is
            treated as empty).

    Returns:
        ``(value, state)`` — the same single-item (or empty) list twice, one
        for the checkbox component and one for the remembered state.

    Selection rule: if any label in ``new_list`` was not in ``prev_list``,
    the last such label in ``new_list`` order wins. Otherwise the last item
    of ``new_list`` is kept (or nothing when ``new_list`` is empty).
    """
    current = list(new_list or [])
    previous = set(prev_list or [])

    # Walk `current` in order rather than diffing two sets: set iteration
    # order is arbitrary, so the winner would be unpredictable whenever more
    # than one label is new.
    newly_added = [label for label in current if label not in previous]

    if newly_added:
        out = [newly_added[-1]]
    else:
        # Nothing new (a removal or no change): keep the last remaining item.
        out = current[-1:]

    return out, out  # update checkbox value + state
735
+
736
+ # =========================
737
+ # 9) Gradio UI
738
+ # =========================
739
  with gr.Blocks(title="CritiFusion (SDXL) Demo") as demo:
740
  gr.Markdown(
741
  "## CritiFusion Demo (SDXL)\n"
742
+ "- Keep **Enabled Variants** pills UI, but **only one** can be selected.\n"
743
  f"- Device: **{DEVICE_STR}**, DType: **{DTYPE}**\n"
744
  f"- Together API: {'✅ set' if TOGETHER_API_KEY else '❌ missing (set TOGETHER_API_KEY)'}"
745
  )
 
748
  with gr.Column(scale=7):
749
  user_prompt = gr.Textbox(
750
  label="Prompt",
751
+ value="A fluffy orange cat lying on a window ledge, front-facing, stylized 3D, soft indoor lighting",
752
  lines=3,
753
  )
754
  with gr.Row():
 
757
  with gr.Row():
758
  H = gr.Number(label="H", value=1024, precision=0)
759
  W = gr.Number(label="W", value=1024, precision=0)
 
 
 
 
 
 
 
760
  with gr.Row():
761
  total_steps_refine = gr.Slider(label="total_steps_refine", minimum=10, maximum=80, step=1, value=50)
762
  last_k = gr.Slider(label="last_k", minimum=1, maximum=50, step=1, value=37)
763
+
764
+ guidance = gr.Slider(
765
+ label="Guidance (0 => fallback rule)",
766
+ minimum=0.0, maximum=15.0, step=0.1, value=0.0
767
+ )
768
+
769
+ # --- pills UI, but single-select enforced ---
770
+ selected_state = gr.State([VARIANT_LABELS["base_original"]])
771
+
772
+ enabled_variants = gr.CheckboxGroup(
773
+ label="Enabled Variants (select ONE)",
774
+ choices=[VARIANT_LABELS[k] for k in VARIANT_LABELS.keys()],
775
+ value=[VARIANT_LABELS["base_original"]],
776
+ )
777
+
778
+ # enforce single selection on change
779
+ enabled_variants.change(
780
+ fn=enforce_single_variant,
781
+ inputs=[enabled_variants, selected_state],
782
+ outputs=[enabled_variants, selected_state],
783
+ )
784
 
785
  with gr.Row():
786
+ save_outputs = gr.Checkbox(label="Save output to disk", value=False)
787
  out_dir = gr.Textbox(label="Output dir (only if save enabled)", value="./variants_demo_gradio")
788
 
789
+ run_btn = gr.Button("Run", variant="primary")
790
 
791
  with gr.Column(scale=8):
792
+ gallery = gr.Gallery(label="Result", columns=1, height=600)
793
  meta_json = gr.Code(label="Meta / Debug (JSON)", language="json")
794
 
 
 
 
795
  run_btn.click(
796
+ fn=ui_run_once_gpu,
797
+ inputs=[user_prompt, seed, H, W, preset, total_steps_refine, last_k, guidance, enabled_variants, save_outputs, out_dir],
798
+ outputs=[gallery, meta_json],
799
+ api_name=False, # gradio-safe (avoid schema issues)
 
 
800
  )
801
 
802
  demo.queue().launch(
803
  debug=True,
804
+ share=True, # optional; helps if you run outside Spaces
805
+ show_api=False,
806
  )