# =========================
# ONE-CELL: SDXL + CritiCore + SpecFusion + Gradio UI
# - Keeps the original "Enabled Variants" pills UI (CheckboxGroup).
# - Enforces that ONLY ONE variant is selected at a time (auto-fixed on change).
# - 4 variants, with descriptive display names.
# - Does not use Radio.format_fn, so it stays compatible with older Gradio.
# =========================
import asyncio
import base64
import html
import inspect
import io
import json
import os
import re
import time
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import spaces
import torch
from PIL import Image

import nest_asyncio

# Lets loop.run_until_complete() be called even while an event loop is already running
# (e.g. inside notebooks / Gradio workers); _run_async below relies on this.
nest_asyncio.apply()

import gradio as gr
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    DPMSolverMultistepScheduler,
)

os.environ["TOGETHER_NO_BANNER"] = "1"

# =========================
# 0) Variants (MUST be defined BEFORE gr.Blocks)
# =========================
# internal_key -> UI display label
VARIANT_LABELS = {
    "base_original": "Base (Original Prompt)",
    "base_multi_llm": "Base (MoA Tags)",
    "CritiFusion": "CritiFusion (MoA+VLM+SpecFusion)",
    "criticore_on_original__specfusion": "CritiFusion (Original+VLM+SpecFusion)",
}

# Display order of the variant pills in the UI.
VARIANT_ORDER = [
    VARIANT_LABELS["base_original"],
    VARIANT_LABELS["base_multi_llm"],
    VARIANT_LABELS["CritiFusion"],
    VARIANT_LABELS["criticore_on_original__specfusion"],
]

RHO_T_DEFAULT = 0.85  # fixed SpecFusion blend weight for the low-frequency window

# ---- SAFETY: never hardcode API keys ----
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY", "").strip()
if not TOGETHER_API_KEY:
    print("[Warn] TOGETHER_API_KEY is not set. Together-based variants will error if selected.")

# =========================
# 1) SDXL init
# =========================
DEVICE_STR = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_STR)
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
SDXL_ID = os.environ.get("SDXL_ID", "stabilityai/stable-diffusion-xl-base-1.0")
print(f"[Init] DEVICE={DEVICE_STR} DTYPE={DTYPE} SDXL_ID={SDXL_ID}")

SDXL_base = StableDiffusionXLPipeline.from_pretrained(SDXL_ID, torch_dtype=DTYPE).to(DEVICE)
# Reuse the already-loaded modules (UNet, VAE, text encoders) instead of loading the
# same checkpoint a second time; this halves the model memory footprint.
# NOTE: the VAE is therefore shared, so decode_image_sdxl's fp32 upcast affects both pipes.
SDXL_i2i = StableDiffusionXLImg2ImgPipeline(**SDXL_base.components).to(DEVICE)

for p in (SDXL_base, SDXL_i2i):
    try:
        p.enable_vae_slicing()
        p.enable_attention_slicing()
    except Exception:
        pass  # memory optimizations are best-effort
    # Each pipeline gets its own scheduler instance (schedulers keep per-run state).
    p.scheduler = DPMSolverMultistepScheduler.from_config(p.scheduler.config, use_karras_sigmas=True)

DEFAULT_NEG = (
    "blurry, low quality, artifacts, watermark, extra fingers, missing limbs, "
    "over-sharpened, harsh lighting, oversaturated"
)


@torch.no_grad()
def decode_image_sdxl(latents: torch.Tensor, pipe: StableDiffusionXLImg2ImgPipeline, output_type="pil"):
    """Decode SDXL latents with ``pipe``'s VAE and post-process them into a single image.

    When the VAE is in half precision and its config sets ``force_upcast``, the VAE is
    moved to float32 first (and stays in float32 for later calls).
    """
    vae = pipe.vae
    needs_upcast = (vae.dtype in (torch.float16, torch.bfloat16)) and bool(getattr(vae.config, "force_upcast", False))
    if needs_upcast:
        try:
            pipe.upcast_vae()
        except Exception:
            # Fall back to a plain dtype cast if upcast_vae() is unavailable or fails.
            pipe.vae = pipe.vae.to(torch.float32)
        vae = pipe.vae
    # Match the VAE's device and weight dtype, then undo the latent scaling factor.
    lat = latents.to(device=vae.device, dtype=next(vae.post_quant_conv.parameters()).dtype)
    lat = lat / vae.config.scaling_factor
    out = vae.decode(lat)
    x = out[0] if isinstance(out, (list, tuple)) else (out.sample if hasattr(out, "sample") else out)
    if getattr(pipe, "watermark", None) is not None:
        x = pipe.watermark.apply_watermark(x)
    return pipe.image_processor.postprocess(x.detach(), output_type=output_type)[0]


@torch.no_grad()
def base_sample_latent(
    prompt: str,
    seed: int,
    H: int,
    W: int,
    neg: str,
    *,
    guidance_scale: float = 4.5,
    num_inference_steps: int = 50,
):
    """Run SDXL text-to-image and return ``(latents, decoded PIL image)``.

    The latents are returned alongside the image so later stages can refine them
    with img2img and SpecFusion directly.
    """
    g = torch.Generator(device=DEVICE).manual_seed(int(seed))
    out = SDXL_base(
        prompt=prompt,
        negative_prompt=neg,
        height=int(H),
        width=int(W),
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_inference_steps),
        generator=g,
        output_type="latent",
    )
    z0 = out.images
    x0 = decode_image_sdxl(z0, SDXL_i2i)
    return z0, x0


@torch.no_grad()
def img2img_latent(prompt: str, image_or_latent, strength: float, guidance: float, steps: int, seed: int, neg: str):
    """Run SDXL img2img on ``image_or_latent`` and return the output latents.

    Callers in this file pass the latents produced by ``base_sample_latent``.
    """
    g = torch.Generator(device=DEVICE).manual_seed(int(seed))
    out = SDXL_i2i(
        prompt=prompt,
        image=image_or_latent,
        strength=float(strength),
        guidance_scale=float(guidance),
        num_inference_steps=int(steps),
        generator=g,
        output_type="latent",
        negative_prompt=neg,
    )
    return out.images


def strength_for_last_k(k: int, total_steps: int) -> float:
    """Convert "re-run the last k of total_steps steps" into an img2img strength in [0.01, 0.95]."""
    k = max(1, int(k))
    return min(0.95, max(0.01, float(k) / float(max(1, total_steps))))


# =========================
# 2) CLIP-77 + text utils
# =========================
try:
    from transformers import CLIPTokenizerFast

    _clip_tok = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

    def _count_tokens(txt: str) -> int:
        """Count CLIP tokens in ``txt`` (including BOS/EOS special tokens)."""
        return len(_clip_tok(txt, add_special_tokens=True, truncation=False)["input_ids"])

except Exception:
    _clip_tok = None

    def _count_tokens(txt: str) -> int:
        """Rough CLIP token estimate (~1.3 tokens per word) when the tokenizer is unavailable."""
        return int(len(re.findall(r"\w+", txt)) * 1.3)


def _cleanup_commas(s: str) -> str:
    """Normalize comma spacing, collapse repeated commas and trim leading/trailing commas."""
    s = re.sub(r"\s*,\s*", ", ", (s or "").strip())
    s = re.sub(r"(,\s*){2,}", ", ", s)
    return s.strip(" ,")


def clip77_strict(text: str, max_tok: int = 77) -> str:
    """Return the longest whole-word prefix of ``text`` that fits in ``max_tok`` CLIP tokens."""
    text = (text or "").strip()
    if _count_tokens(text) <= max_tok:
        return text
    words = text.split()
    # Binary search on the number of leading words kept.
    lo, hi, best = 0, len(words), ""
    while lo <= hi:
        mid = (lo + hi) // 2
        cand = " ".join(words[:mid]) if mid > 0 else ""
        if _count_tokens(cand) <= max_tok:
            best = cand
            lo = mid + 1
        else:
            hi = mid - 1
    return best.strip()


def _split_tags(s: str) -> List[str]:
    """Split a comma/newline-separated tag string into stripped, non-empty tags."""
    return [p.strip() for p in re.split(r",|\n", (s or "").strip()) if p.strip()]


def _dedup_keep_order(items: List[str]) -> List[str]:
    """Drop case/whitespace-insensitive duplicates, keeping the first occurrence."""
    seen, out = set(), []
    for t in items:
        key = re.sub(r"\s+", " ", t.lower()).strip()
        if key and key not in seen:
            seen.add(key)
            out.append(t.strip())
    return out


def _order_tags(subject_first: List[str], rest: List[str]) -> List[str]:
    """Order tags as subject, style, composition, lighting, color, detail, other (keyword buckets)."""
    buckets = {"subject": [], "style": [], "composition": [], "lighting": [], "color": [], "detail": [], "other": []}
    style_kw = ("style", "painterly", "illustration", "photorealistic", "neon", "poster", "matte painting", "watercolor", "cyberpunk")
    comp_kw = ("composition", "rule of thirds", "centered", "symmetry", "balanced composition")
    light_kw = ("lighting", "light", "glow", "glowing", "rim", "sunset", "sunrise", "golden hour", "global illumination", "cinematic")
    color_kw = ("color", "palette", "vibrant", "muted", "monochrome", "pastel", "warm", "cool", "balanced contrast")
    detail_kw = ("detailed", "hyperdetailed", "texture", "intricate", "high detail", "highly detailed", "sharp focus", "uhd", "8k")
    for t in subject_first:
        if t:
            buckets["subject"].append(t)
    for t in rest:
        lt = t.lower()
        # First matching bucket wins, so the check order matters.
        if any(k in lt for k in style_kw):
            buckets["style"].append(t)
        elif any(k in lt for k in comp_kw):
            buckets["composition"].append(t)
        elif any(k in lt for k in light_kw):
            buckets["lighting"].append(t)
        elif any(k in lt for k in color_kw):
            buckets["color"].append(t)
        elif any(k in lt for k in detail_kw):
            buckets["detail"].append(t)
        else:
            buckets["other"].append(t)
    return (
        buckets["subject"] + buckets["style"] + buckets["composition"] + buckets["lighting"]
        + buckets["color"] + buckets["detail"] + buckets["other"]
    )


def pil_to_base64(img: Image.Image, fmt: str = "PNG") -> str:
    """Encode a PIL image as a base64 ASCII string in the given format."""
    buf = io.BytesIO()
    img.save(buf, format=fmt)
    return base64.b64encode(buf.getvalue()).decode("ascii")


async def _maybe_close_async_together(client) -> None:
    """Best-effort close of a Together client via ``aclose``/``close`` (sync or async); never raises."""
    try:
        if hasattr(client, "aclose") and inspect.iscoroutinefunction(client.aclose):
            await client.aclose()
        elif hasattr(client, "close"):
            fn = client.close
            if inspect.iscoroutinefunction(fn):
                await fn()
            else:
                try:
                    fn()
                except Exception:
                    pass
    except Exception:
        pass


# =========================
# 3) Async runner
# =========================
def _run_async(coro):
    """Run ``coro`` to completion from synchronous code and return its result.

    Only acquiring the event loop is guarded. An exception raised *by the coroutine*
    (including RuntimeError) propagates unchanged instead of triggering a second run
    of the same, already-awaited coroutine object.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No usable event loop in this thread.
        return asyncio.run(coro)
    # nest_asyncio.apply() at import time allows this even if the loop is already running.
    return loop.run_until_complete(coro)


# =========================
# 4) CritiCore (Together)
# =========================
from together import AsyncTogether

AGGREGATOR_MODEL = os.environ.get("AGGREGATOR_MODEL", "Qwen/Qwen2.5-72B-Instruct-Turbo")
LLM_MULTI_CANDIDATES = [
    "meta-llama/Llama-3.3-70B-Instruct-Turbo",
    "Qwen/Qwen2.5-72B-Instruct-Turbo",
    "Qwen/Qwen2.5-Coder-32B-Instruct",
    "deepseek-ai/DeepSeek-V3",
    "nvidia/NVIDIA-Nemotron-Nano-9B-v2",
]
_env_list = [s.strip() for s in os.environ.get("VLM_MOA_CANDIDATES", "").split(",") if s.strip()]
VLM_CANDIDATES = _env_list or ["meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"]

TAG_PRESETS = {
    "hq_preference": {
        "seed_pos": [
            "balanced composition",
            "natural color palette", "vibrant colors", "balanced contrast",
            "high detail", "highly detailed", "hyperdetailed", "sharp focus",
            "UHD", "8k",
        ],
        "seed_neg": [
            "low quality", "blurry", "watermark", "jpeg artifacts", "overexposed", "underexposed",
            "color banding", "extra fingers", "missing limbs", "disfigured", "mutated hands",
        ],
    }
}

_DECOMP_SYS = (
    "Decompose the user's visual instruction into 3-6 concrete, checkable visual components "
    "(entities + interactions + spatial relations). Return ONLY JSON: "
    '{"components":["..."]}'
)

_TXT_SYS = (
    "Expand a VERY SHORT visual idea into a COMMA-SEPARATED TAG LIST for SDXL.\n"
    "Constraints:\n"
    "- Start with the subject phrase first.\n"
    "- Prioritize composition, lighting, color, and detail over style.\n"
    "- Use at most TWO style tags if any.\n"
    "- 16–26 concise tags total. Commas only, no sentences, no 'and'. No trailing period.\n"
    "- Prefer human-preference aesthetics; keep 'high detailed', 'sharp focus', '8k', 'UHD'."
)


def _TAG_RE(tag: str):
    """Compile a regex capturing the text between ``<tag>`` and ``</tag>`` (case-insensitive, spans newlines)."""
    # The closing tag is required: with a bare non-greedy ``(.*?)`` the group always matched "".
    return re.compile(rf"<\s*{tag}\s*>(.*?)<\s*/\s*{tag}\s*>", re.S | re.I)


def _extract_tag(text: str, tag: str, fallback: str = "") -> str:
    """Return the stripped content of ``<tag>...</tag>`` in ``text``, or ``fallback`` if absent."""
    s = (text or "").strip()
    pattern = _TAG_RE(tag)
    m = pattern.search(s)
    if m is None:
        # Some models HTML-escape their tags (&lt;refined&gt;); retry on the unescaped text.
        m = pattern.search(html.unescape(s))
    return m.group(1).strip() if m else fallback.strip()


def _summarize_issues_lines(text: str, max_lines: int = 5) -> str:
    """Turn newline/semicolon-separated issues into at most ``max_lines`` "- " bullets."""
    if not text:
        return ""
    parts = [p.strip(" -•\t") for p in re.split(r"[\n;]+", text) if p.strip()]
    parts = parts[:max_lines]
    return "\n".join(f"- {p}" for p in parts)


def _reply_text(result) -> Optional[str]:
    """Return the message text of one chat-completion result, or None for exceptions/empty replies."""
    try:
        content = result.choices[0].message.content
    except (AttributeError, IndexError, KeyError, TypeError):
        # asyncio.gather(..., return_exceptions=True) hands back exception objects here.
        return None
    return content if isinstance(content, str) and content.strip() else None


def _parse_components_json(txt: Optional[str]) -> List[str]:
    """Extract the ``components`` string list from an LLM reply meant to be JSON; [] if unparsable."""
    txt = txt or ""
    try:
        obj = json.loads(txt)
    except json.JSONDecodeError:
        # Tolerate prose around the JSON object by parsing the outermost {...} span.
        s, e = txt.find("{"), txt.rfind("}")
        try:
            obj = json.loads(txt[s:e + 1]) if (s != -1 and e > s) else {}
        except json.JSONDecodeError:
            obj = {}
    if not isinstance(obj, dict):
        return []
    comps = obj.get("components", [])
    if not isinstance(comps, list):
        return []
    return [c.strip() for c in comps if isinstance(c, str) and c.strip()]


class CritiCore:
    """Together-hosted multi-model (MoA) helpers for prompt tagging and image critique.

    Each async method opens its own ``AsyncTogether`` client, fans requests out to
    several candidate models, merges the replies with ``self.aggregator`` and closes
    the client in a ``finally`` block.
    """

    def __init__(self, preset: str = "hq_preference", aggregator_model: str = AGGREGATOR_MODEL):
        if not os.environ.get("TOGETHER_API_KEY"):
            raise RuntimeError("Missing TOGETHER_API_KEY in environment.")
        self.preset = preset
        self.aggregator = aggregator_model

    async def decompose_components(self, user_prompt: str) -> List[str]:
        """Split ``user_prompt`` into up to 6 checkable visual components ([] if nothing usable)."""
        client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        try:
            tasks = [
                client.chat.completions.create(
                    model=m,
                    messages=[
                        {"role": "system", "content": _DECOMP_SYS},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.4,
                    max_tokens=256,
                )
                for m in LLM_MULTI_CANDIDATES
            ]
            rs = await asyncio.gather(*tasks, return_exceptions=True)
            texts = [t for t in map(_reply_text, rs) if t]
            if not texts:
                return []
            merged = await client.chat.completions.create(
                model=self.aggregator,
                messages=[
                    # Double quotes here: single-quoted "JSON" examples invite invalid JSON replies.
                    {"role": "system", "content": 'Merge JSON candidates and return ONLY {"components":[...]}.'},
                    {"role": "user", "content": "\n\n---\n\n".join(texts)},
                ],
                temperature=0.2,
                max_tokens=256,
            )
            return _parse_components_json(merged.choices[0].message.content)[:6]
        finally:
            await _maybe_close_async_together(client)

    async def make_tags(self, user_prompt: str, clip77: bool = True) -> Tuple[str, str]:
        """Build ``(positive_tags, negative_tags)`` for SDXL from a short idea.

        The positive list starts with the subject, always ends up containing the
        quality tags "high detailed", "sharp focus", "8k", "UHD", and is clipped to
        77 CLIP tokens when ``clip77`` is true. The negative list comes from the preset.
        """
        client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        seed = TAG_PRESETS.get(self.preset, TAG_PRESETS["hq_preference"])
        seed_pos = _dedup_keep_order(seed["seed_pos"])
        seed_neg = seed["seed_neg"]
        try:
            tasks = [
                client.chat.completions.create(
                    model=m,
                    messages=[
                        {"role": "system", "content": _TXT_SYS},
                        {"role": "user", "content": f"Short idea: {user_prompt}\nSeed: {', '.join(seed_pos)}\nOutput: a single comma-separated tag list."},
                    ],
                    temperature=0.7,
                    max_tokens=220,
                )
                for m in LLM_MULTI_CANDIDATES
            ]
            rs = await asyncio.gather(*tasks, return_exceptions=True)
            props = [t for t in map(_reply_text, rs) if t]
            if not props:
                # No model answered: fall back to the raw prompt plus the preset seed tags.
                pos = ", ".join([user_prompt.strip()] + seed_pos)
            else:
                merged = await client.chat.completions.create(
                    model=self.aggregator,
                    messages=[
                        {"role": "system", "content": "Merge candidate tag lists into ONE comma list (16–26 tags). Subject first; at most TWO style tags; keep high detailed/sharp focus/8k/UHD."},
                        {"role": "user", "content": "\n---\n".join(props)},
                    ],
                    temperature=0.2,
                    max_tokens=240,
                )
                tags = _dedup_keep_order(_split_tags(merged.choices[0].message.content))
                subject = user_prompt.strip().rstrip(",.")
                if subject and not any(subject.lower() == t.lower() for t in tags):
                    tags = [subject] + tags
                # tags[:1] rather than [tags[0]] so an empty merge result cannot raise IndexError.
                ordered = _order_tags(tags[:1], tags[1:])
                pos = ", ".join(_dedup_keep_order(ordered))

            # Quality floor: make sure these tags are always present.
            present = {t.lower() for t in _split_tags(pos)}
            for q in ["high detailed", "sharp focus", "8k", "UHD"]:
                if q.lower() not in present:
                    pos += ", " + q
                    present.add(q.lower())
            pos = _cleanup_commas(pos)
            if clip77 and _count_tokens(pos) > 77:
                pos = clip77_strict(pos, 77)
            neg = ", ".join(seed_neg)
            return pos, neg
        finally:
            await _maybe_close_async_together(client)

    async def vlm_refine(self, image: Image.Image, original_prompt: str, components: List[str]) -> Dict[str, object]:
        """Critique ``image`` with the VLM candidates and merge their answers.

        Returns ``{"refined": <prompt clipped to 77 CLIP tokens>, "issues_merged": <bullets>}``.
        If no VLM answers, returns the original prompt and no issues.
        """
        client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        b64 = pil_to_base64(image, "PNG")

        def _user_prompt_text() -> str:
            return (
                "You are a precise image-grounded critic.\n"
                "1) List concrete visual problems and brief corrections.\n"
                "2) Provide a refined prompt that keeps the original intent.\n\n"
                f'Original prompt: "{original_prompt}"\n'
                f"Key components to check: {components}\n"
                "Output EXACTLY two tags:\n"
                "<issues>...</issues>\n<refined>...</refined>"
            )

        try:
            tasks = []
            for m in VLM_CANDIDATES:
                msgs = [
                    {"role": "system", "content": "Return ONLY <issues>...</issues> and <refined>...</refined>. No extra text."},
                    {"role": "user", "content": [
                        {"type": "text", "text": _user_prompt_text()},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
                    ]},
                ]
                tasks.append(client.chat.completions.create(model=m, messages=msgs, temperature=0.2, max_tokens=420))
            rs = await asyncio.gather(*tasks, return_exceptions=True)
            ok = [(m, t) for m, t in zip(VLM_CANDIDATES, map(_reply_text, rs)) if t]
            if not ok:
                return {"refined": original_prompt, "issues_merged": ""}

            refined_items, per_vlm_issues = [], {}
            for m, raw in ok:
                issues = _extract_tag(raw, "issues", "")
                refined = _extract_tag(raw, "refined", original_prompt)
                if refined.strip():
                    refined_items.append((m, refined.strip()))
                if issues.strip():
                    per_vlm_issues[m] = _summarize_issues_lines(issues, 5)

            joined_issues = "\n".join(f"[{m}] {t}" for m, t in per_vlm_issues.items())
            joined_refined = "\n".join(f"[{m}] {t}" for m, t in refined_items) if refined_items else original_prompt
            merged = await client.chat.completions.create(
                model=self.aggregator,
                messages=[
                    {"role": "system", "content": "Merge multiple critics. Output ONLY <issues>...</issues> (≤5 bullets) and <refined>...</refined> (≤70 words)."},
                    {"role": "user", "content": f"{joined_issues}\n\n----\n\n{joined_refined}"},
                ],
                temperature=0.2,
                max_tokens=420,
            )
            final_raw = merged.choices[0].message.content
            final_refined = clip77_strict(_extract_tag(final_raw, "refined", original_prompt), 77)
            issues_merged = _summarize_issues_lines(_extract_tag(final_raw, "issues", ""), 5)
            return {"refined": final_refined, "issues_merged": issues_merged}
        finally:
            await _maybe_close_async_together(client)

    @staticmethod
    def merge_vlm_multi_text(vlm_refined_77: str, tags_77: str) -> str:
        """Merge the VLM-refined prompt with MoA tags into one ordered, deduplicated, 77-token tag string."""
        vlm_tags = _split_tags(vlm_refined_77)
        moa_tags = _split_tags(tags_77)
        merged = _dedup_keep_order(_order_tags([vlm_tags[0] if vlm_tags else ""], (vlm_tags[1:] + moa_tags)))
        merged = [t for t in merged if t]
        text = _cleanup_commas(", ".join(merged))
        if _count_tokens(text) > 77:
            text = clip77_strict(text, 77)
        return text


# =========================
# 5) SpecFusion (latent FFT gate)
# =========================
@torch.no_grad()
def frequency_fusion(
    x_hi_latent: torch.Tensor,
    x_lo_latent: torch.Tensor,
    base_c: float = 0.5,
    rho_t: float = 0.85,
    device=None,
) -> torch.Tensor:
    """Blend two latents in the centered 2-D Fourier domain over the last two axes.

    Outside a centered window every frequency comes from ``x_hi_latent``. Inside the
    window (half-size ``base_c * (1 - rho_t)`` of H and W) the result is
    ``rho_t * hi + (1 - rho_t) * lo``. A small Gaussian noise (std 0.001) is added and
    the result is cast back to ``x_hi_latent``'s dtype.
    Expects 4-D ``(B, C, H, W)`` tensors of identical shape (unpacked below).
    """
    if device is None:
        device = x_hi_latent.device
    B, C, H, W = x_hi_latent.shape
    x_h = x_hi_latent.to(torch.float32).to(device)
    x_l = x_lo_latent.to(torch.float32).to(device)

    Xh = torch.fft.fftshift(torch.fft.fftn(x_h, dim=(-2, -1)), dim=(-2, -1))
    Xl = torch.fft.fftshift(torch.fft.fftn(x_l, dim=(-2, -1)), dim=(-2, -1))

    cy, cx = H // 2, W // 2
    # Clamp to the center index: otherwise large base_c / small rho_t would give negative
    # slice starts, which Python wraps around and which would put the window in the wrong place.
    tau_h = min(int(H * base_c * (1 - rho_t)), cy)
    tau_w = min(int(W * base_c * (1 - rho_t)), cx)

    mask = torch.ones((B, C, H, W), device=device, dtype=torch.float32)
    if tau_h > 0 and tau_w > 0:
        mask[..., cy - tau_h : cy + tau_h, cx - tau_w : cx + tau_w] = rho_t

    Xf = Xh * mask + Xl * (1 - mask)
    x = torch.fft.ifftn(torch.fft.ifftshift(Xf, dim=(-2, -1)), dim=(-2, -1)).real
    x = x + torch.randn_like(x) * 0.001
    return x.to(dtype=x_hi_latent.dtype)


def _decode_to_pil(latents):
    """Decode latents through the img2img pipeline's VAE and return a PIL image when possible."""
    out = decode_image_sdxl(latents, SDXL_i2i)
    if isinstance(out, Image.Image):
        return out
    if hasattr(out, "images"):
        return out.images[0]
    return out


def _guidance_for_k(k: int) -> float:
    """Fallback img2img guidance scale: larger refinement windows (last_k) get stronger guidance."""
    if k >= 20:
        return 12.0
    if k >= 10:
        return 7.5
    return 5.2


# =========================
# 6) ONE-variant generator (the UI enforces single selection)
# =========================
async def _critique_refine_fuse(
    critic: CritiCore,
    base_img: Image.Image,
    z0: torch.Tensor,
    critique_prompt: str,
    components: List[str],
    pos_tags_77: str,
    *,
    strength: float,
    guidance: float,
    steps: int,
    seed: int,
) -> Tuple[Image.Image, Dict[str, object]]:
    """Critique ``base_img`` with the VLMs, refine ``z0`` via img2img and SpecFusion-blend the result.

    Returns the decoded fused image plus the meta fields describing the refinement.
    """
    vlm_out = await critic.vlm_refine(base_img, critique_prompt, components or [])
    vlm_refined_77 = clip77_strict(vlm_out.get("refined") or critique_prompt, 77)
    enhanced_77 = CritiCore.merge_vlm_multi_text(vlm_refined_77, pos_tags_77)
    z_ref = img2img_latent(
        enhanced_77, z0,
        strength=strength, guidance=guidance, steps=steps,
        seed=seed, neg=DEFAULT_NEG,
    )
    fused_lat = frequency_fusion(z_ref, z0, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
    return _decode_to_pil(fused_lat), {
        "vlm_refined_77": vlm_refined_77,
        "enhanced_prompt_77": enhanced_77,
        "vlm_issues": vlm_out.get("issues_merged", ""),
    }


async def generate_one_variant(
    user_prompt: str,
    seed: int,
    H: int,
    W: int,
    total_steps_refine: int,
    last_k: int,
    guidance: float,
    preset: str,
    variant_key: str,
    out_dir: Optional[Path] = None,
) -> Tuple[Image.Image, str, Dict[str, object]]:
    """Generate the image for a single variant.

    Returns ``(image, display_name, meta_dict)``. Saves a PNG named after the
    variant label into ``out_dir`` when one is given.

    Raises:
        RuntimeError: a Together-based variant was requested without TOGETHER_API_KEY.
        ValueError: ``variant_key`` is not one of ``VARIANT_LABELS``.
    """
    meta: Dict[str, object] = {
        "user_prompt": user_prompt,
        "variant_key": variant_key,
    }

    def _save(im: Image.Image, display_name: str):
        if out_dir is None:
            return
        out_dir.mkdir(parents=True, exist_ok=True)
        # Keep only [A-Za-z0-9_-]; the previous class let backslashes (path separators) through.
        safe = re.sub(r"[^a-zA-Z0-9_-]+", "_", display_name)[:120]
        im.save(out_dir / f"{safe}.png")

    # ----------------------------------------------------------
    # Variant 1: Base (Original Prompt) [NO Together needed]
    # ----------------------------------------------------------
    if variant_key == "base_original":
        _, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)
        meta.update({"note": "SDXL base generation from original prompt."})
        _save(base_og, VARIANT_LABELS[variant_key])
        return base_og, VARIANT_LABELS[variant_key], meta

    # The remaining variants need Together.
    if not TOGETHER_API_KEY:
        raise RuntimeError("TOGETHER_API_KEY not set, but selected variant requires Together.")

    critic = CritiCore(preset=preset)

    # Common refine params
    lk = int(last_k)
    strength = float(strength_for_last_k(lk, total_steps_refine))
    use_guidance = float(guidance) if float(guidance) > 0 else float(_guidance_for_k(lk))
    steps = int(total_steps_refine)
    meta.update({"strength": strength, "guidance": use_guidance, "steps": steps, "last_k": lk})

    # ----------------------------------------------------------
    # Variant 2: Base (MoA Tags)
    # ----------------------------------------------------------
    if variant_key == "base_multi_llm":
        pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
        _, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)
        meta.update({
            "pos_tags_77": pos_tags_77,
            "neg_tags": neg_tags,
            "note": "SDXL base generation from MoA-generated tags.",
        })
        _save(base_enh, VARIANT_LABELS[variant_key])
        return base_enh, VARIANT_LABELS[variant_key], meta

    # ----------------------------------------------------------
    # Variant 3: CritiFusion (MoA+VLM+SpecFusion)
    # ----------------------------------------------------------
    if variant_key == "CritiFusion":
        pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
        comps = await critic.decompose_components(user_prompt)
        z0_enh, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)
        img_sf, refine_meta = await _critique_refine_fuse(
            critic, base_enh, z0_enh, pos_tags_77, comps, pos_tags_77,
            strength=strength, guidance=use_guidance, steps=steps, seed=seed + 2100 + lk,
        )
        meta.update({
            "pos_tags_77": pos_tags_77,
            "neg_tags": neg_tags,
            "components": comps,
            **refine_meta,
            "note": "MoA tags + VLM critique prompt + img2img + SpecFusion.",
        })
        _save(img_sf, VARIANT_LABELS[variant_key])
        return img_sf, VARIANT_LABELS[variant_key], meta

    # ----------------------------------------------------------
    # Variant 4: CritiFusion (Original+VLM+SpecFusion)
    # ----------------------------------------------------------
    if variant_key == "criticore_on_original__specfusion":
        pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
        comps = await critic.decompose_components(user_prompt)
        z0_og, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)
        img_sf, refine_meta = await _critique_refine_fuse(
            critic, base_og, z0_og, user_prompt, comps, pos_tags_77,
            strength=strength, guidance=use_guidance, steps=steps, seed=seed + 2400 + lk,
        )
        meta.update({
            "pos_tags_77": pos_tags_77,
            "neg_tags": neg_tags,
            "components": comps,
            **refine_meta,
            "note": "Original prompt + VLM critique prompt + img2img + SpecFusion.",
        })
        _save(img_sf, VARIANT_LABELS[variant_key])
        return img_sf, VARIANT_LABELS[variant_key], meta

    raise ValueError(f"Unknown variant_key: {variant_key}")


# =========================
# 7) UI callbacks
# =========================
def ui_run_once(
    user_prompt: str,
    seed: int,
    H: int,
    W: int,
    preset: str,
    total_steps_refine: int,
    last_k: int,
    guidance: float,
    enabled_variants_display: List[str],
    save_outputs: bool,
    out_dir: str,
):
    """Gradio "Run" handler: returns ``(gallery_items, meta_json_or_error_text)``.

    Any exception is caught and its traceback is shown in the meta box instead.
    """
    t0 = time.time()
    try:
        if not user_prompt or not user_prompt.strip():
            return [], "Empty prompt."

        # display label -> internal key
        display_to_internal = {v: k for k, v in VARIANT_LABELS.items()}
        chosen_display = (enabled_variants_display or [])[-1:]  # enforce single selection here too
        if not chosen_display:
            return [], "Please select ONE variant."
        chosen_display = chosen_display[0]
        variant_key = display_to_internal.get(chosen_display)
        if variant_key is None:
            return [], f"Unknown selected variant: {chosen_display}"

        out_path = Path(out_dir) if (save_outputs and out_dir) else None

        img, disp_name, meta = _run_async(generate_one_variant(
            user_prompt=user_prompt.strip(),
            seed=int(seed), H=int(H), W=int(W),
            total_steps_refine=int(total_steps_refine),
            last_k=int(last_k),
            guidance=float(guidance),
            preset=preset,
            variant_key=variant_key,
            out_dir=out_path,
        ))

        meta["ui"] = {
            "seed": int(seed), "H": int(H), "W": int(W),
            "preset": preset,
            "total_steps_refine": int(total_steps_refine),
            "last_k": int(last_k),
            "guidance": float(guidance),
            "selected_variant": chosen_display,
            "save_outputs": bool(save_outputs),
            "out_dir": out_dir if save_outputs else None,
        }
        meta["elapsed_sec"] = round(time.time() - t0, 3)

        gallery = [(img, disp_name)]
        return gallery, json.dumps(meta, ensure_ascii=False, indent=2)

    except Exception:
        return [], traceback.format_exc()


@spaces.GPU
def ui_run_once_gpu(*args, **kwargs):
    """GPU-decorated pass-through to ``ui_run_once`` (HF Spaces ``@spaces.GPU``)."""
    return ui_run_once(*args, **kwargs)


# =========================
# 8) Single-select enforcement for CheckboxGroup
# =========================
def enforce_single_variant(new_list: List[str], prev_list: List[str]):
    """Collapse a CheckboxGroup selection to at most one item.

    Prefers the item the user just ticked (in ``new_list`` but not in ``prev_list``);
    otherwise keeps the last item. Returns the same list twice: for the CheckboxGroup
    value and for the tracking State.
    """
    new_list = new_list or []
    prev_set = set(prev_list or [])
    # Scan new_list in order so the choice is deterministic (set-difference order is not).
    added = [v for v in new_list if v not in prev_set]
    out = [added[-1]] if added else new_list[-1:]
    return out, out


# =========================
# 9) Gradio UI
# =========================
with gr.Blocks(title="CritiFusion (SDXL) Demo") as demo:
    gr.Markdown(
        "## CritiFusion Demo (SDXL)\n"
        "- Keep **Enabled Variants** pills UI, but **only one** can be selected.\n"
        f"- Device: **{DEVICE_STR}**, DType: **{DTYPE}**\n"
        f"- Together API: {'✅ set' if TOGETHER_API_KEY else '❌ missing (set TOGETHER_API_KEY)'}"
    )

    with gr.Row():
        with gr.Column(scale=7):
            user_prompt = gr.Textbox(
                label="Prompt",
                value="A fluffy orange cat lying on a window ledge, front-facing, stylized 3D, soft indoor lighting",
                lines=3,
            )
            with gr.Row():
                seed = gr.Number(label="Seed", value=2026, precision=0)
                preset = gr.Dropdown(label="Preset", choices=["hq_preference"], value="hq_preference")
            with gr.Row():
                H = gr.Number(label="H", value=1024, precision=0)
                W = gr.Number(label="W", value=1024, precision=0)
            with gr.Row():
                total_steps_refine = gr.Slider(label="total_steps_refine", minimum=10, maximum=80, step=1, value=50)
                last_k = gr.Slider(label="last_k", minimum=1, maximum=50, step=1, value=37)
            guidance = gr.Slider(
                label="Guidance (0 => fallback rule)",
                minimum=0.0, maximum=15.0, step=0.1, value=0.0,
            )

            # --- pills UI, but single-select enforced ---
            selected_state = gr.State([VARIANT_LABELS["base_original"]])
            enabled_variants = gr.CheckboxGroup(
                label="Enabled Variants (select ONE)",
                choices=list(VARIANT_ORDER),
                value=[VARIANT_LABELS["base_original"]],
            )

            # enforce single selection on change
            enabled_variants.change(
                fn=enforce_single_variant,
                inputs=[enabled_variants, selected_state],
                outputs=[enabled_variants, selected_state],
            )

            with gr.Row():
                save_outputs = gr.Checkbox(label="Save output to disk", value=False)
                out_dir = gr.Textbox(label="Output dir (only if save enabled)", value="./variants_demo_gradio")

            run_btn = gr.Button("Run", variant="primary")

        with gr.Column(scale=8):
            gallery = gr.Gallery(label="Result", columns=1, height=600)
            meta_json = gr.Code(label="Meta / Debug (JSON)", language="json")

    run_btn.click(
        fn=ui_run_once_gpu,
        inputs=[user_prompt, seed, H, W, preset, total_steps_refine, last_k, guidance,
                enabled_variants, save_outputs, out_dir],
        outputs=[gallery, meta_json],
        api_name=False,  # gradio-safe (avoid schema issues)
    )

demo.queue().launch(
    debug=True,
    share=True,  # optional; helps if you run outside Spaces
)