# =========================
# ONE-CELL: SDXL + CritiCore + SpecFusion + Gradio UI
# - Keep the original "Enabled Variants" pills UI (CheckboxGroup)
# - Enforce: ONLY ONE variant can be selected at a time (auto-fixed on change)
# - 4 variants (with clearer names)
# - No Radio.format_fn (safe for older Gradio versions)
# =========================
import os, re, io, json, time, base64, asyncio, inspect, traceback
from pathlib import Path
from typing import List, Dict, Optional, Tuple

import spaces
import torch
from PIL import Image

import nest_asyncio
nest_asyncio.apply()

import gradio as gr
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    DPMSolverMultistepScheduler,
)

os.environ["TOGETHER_NO_BANNER"] = "1"
# =========================
# 0) Variants (MUST be defined BEFORE Blocks)
# =========================
# internal_key -> UI display label
VARIANT_LABELS = {
    "base_original": "Base (Original Prompt)",
    "base_multi_llm": "Base (MoA Tags)",
    "CritiFusion": "CritiFusion (MoA+VLM+SpecFusion)",
    "criticore_on_original__specfusion": "CritiFusion (Original+VLM+SpecFusion)",
}

# Order for gallery display
VARIANT_ORDER = [
    VARIANT_LABELS["base_original"],
    VARIANT_LABELS["base_multi_llm"],
    VARIANT_LABELS["CritiFusion"],
    VARIANT_LABELS["criticore_on_original__specfusion"],
]

RHO_T_DEFAULT = 0.85  # fixed

# ---- SAFETY: do NOT hardcode API keys ----
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY", "").strip()
if not TOGETHER_API_KEY:
    print("[Warn] TOGETHER_API_KEY is not set. Together-based variants will error if selected.")
# =========================
# 1) SDXL init
# =========================
DEVICE_STR = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_STR)
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
SDXL_ID = os.environ.get("SDXL_ID", "stabilityai/stable-diffusion-xl-base-1.0")
print(f"[Init] DEVICE={DEVICE_STR} DTYPE={DTYPE} SDXL_ID={SDXL_ID}")

SDXL_base = StableDiffusionXLPipeline.from_pretrained(SDXL_ID, torch_dtype=DTYPE).to(DEVICE)
SDXL_i2i = StableDiffusionXLImg2ImgPipeline.from_pretrained(SDXL_ID, torch_dtype=DTYPE).to(DEVICE)

for p in (SDXL_base, SDXL_i2i):
    try:
        p.enable_vae_slicing()
        p.enable_attention_slicing()
    except Exception:
        pass
    p.scheduler = DPMSolverMultistepScheduler.from_config(p.scheduler.config, use_karras_sigmas=True)

DEFAULT_NEG = (
    "blurry, low quality, artifacts, watermark, extra fingers, missing limbs, "
    "over-sharpened, harsh lighting, oversaturated"
)
def decode_image_sdxl(latents: torch.Tensor, pipe: StableDiffusionXLImg2ImgPipeline, output_type="pil"):
    vae = pipe.vae
    needs_upcast = (vae.dtype in (torch.float16, torch.bfloat16)) and bool(getattr(vae.config, "force_upcast", False))
    if needs_upcast:
        try:
            pipe.upcast_vae()
        except Exception:
            pipe.vae = pipe.vae.to(torch.float32)
        vae = pipe.vae
    lat = latents.to(device=vae.device, dtype=next(vae.post_quant_conv.parameters()).dtype)
    lat = lat / vae.config.scaling_factor
    out = vae.decode(lat)
    x = out[0] if isinstance(out, (list, tuple)) else (out.sample if hasattr(out, "sample") else out)
    if getattr(pipe, "watermark", None) is not None:
        x = pipe.watermark.apply_watermark(x)
    img = pipe.image_processor.postprocess(x.detach(), output_type=output_type)[0]
    return img
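# Note on decode_image_sdxl: the SDXL VAE is numerically fragile in fp16, so diffusers marks it
# with `force_upcast` and decoding is done in fp32 when that flag is set. Latents are divided by
# vae.config.scaling_factor (0.13025 for the SDXL VAE) before decoding.
# Minimal usage sketch (assumes `latents` came from a pipeline call with output_type="latent"):
#   pil_img = decode_image_sdxl(latents, SDXL_i2i, output_type="pil")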
def base_sample_latent(prompt: str, seed: int, H: int, W: int, neg: str):
    g = torch.Generator(device=DEVICE).manual_seed(int(seed))
    out = SDXL_base(
        prompt=prompt,
        negative_prompt=neg,
        height=int(H), width=int(W),
        guidance_scale=4.5,
        num_inference_steps=50,
        generator=g,
        output_type="latent",
    )
    z0 = out.images
    x0 = decode_image_sdxl(z0, SDXL_i2i)
    return z0, x0

def img2img_latent(prompt: str, image_or_latent, strength: float, guidance: float, steps: int, seed: int, neg: str):
    g = torch.Generator(device=DEVICE).manual_seed(int(seed))
    out = SDXL_i2i(
        prompt=prompt,
        image=image_or_latent,
        strength=float(strength),
        guidance_scale=float(guidance),
        num_inference_steps=int(steps),
        generator=g,
        output_type="latent",
        negative_prompt=neg,
    )
    return out.images

def strength_for_last_k(k: int, total_steps: int) -> float:
    k = max(1, int(k))
    return min(0.95, max(0.01, float(k) / float(max(1, total_steps))))
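# strength_for_last_k maps "refine only the last k of N steps" onto diffusers' img2img `strength`
# parameter: img2img executes roughly int(num_inference_steps * strength) denoising steps, so
# strength = k / N re-noises just enough to redo about the last k steps (clamped to [0.01, 0.95]).
# Example: last_k=37 with total_steps_refine=50 gives strength = 0.74.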
# =========================
# 2) CLIP-77 + text utils
# =========================
try:
    from transformers import CLIPTokenizerFast
    _clip_tok = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

    def _count_tokens(txt: str) -> int:
        return len(_clip_tok(txt, add_special_tokens=True, truncation=False)["input_ids"])
except Exception:
    _clip_tok = None

    def _count_tokens(txt: str) -> int:
        return int(len(re.findall(r"\w+", txt)) * 1.3)

def _cleanup_commas(s: str) -> str:
    s = re.sub(r"\s*,\s*", ", ", (s or "").strip())
    s = re.sub(r"(,\s*){2,}", ", ", s)
    return s.strip(" ,")

def clip77_strict(text: str, max_tok: int = 77) -> str:
    text = (text or "").strip()
    if _count_tokens(text) <= max_tok:
        return text
    words = text.split()
    lo, hi, best = 0, len(words), ""
    while lo <= hi:
        mid = (lo + hi) // 2
        cand = " ".join(words[:mid]) if mid > 0 else ""
        if _count_tokens(cand) <= max_tok:
            best = cand
            lo = mid + 1
        else:
            hi = mid - 1
    return best.strip()
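# clip77_strict binary-searches for the longest word prefix whose CLIP token count stays within
# max_tok (77 is the CLIP text-encoder context size, special tokens included), instead of cutting
# at a fixed character length. Illustrative call (hypothetical input, not executed):
#   clip77_strict("subject phrase, tag1, tag2, ... very long tag list ...", 77)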
def _split_tags(s: str) -> List[str]:
    return [p.strip() for p in re.split(r",|\n", (s or "").strip()) if p.strip()]

def _dedup_keep_order(items: List[str]) -> List[str]:
    seen, out = set(), []
    for t in items:
        key = re.sub(r"\s+", " ", t.lower()).strip()
        if key and key not in seen:
            seen.add(key)
            out.append(t.strip())
    return out

def _order_tags(subject_first: List[str], rest: List[str]) -> List[str]:
    buckets = {"subject": [], "style": [], "composition": [], "lighting": [], "color": [], "detail": [], "other": []}
    style_kw = ("style", "painterly", "illustration", "photorealistic", "neon", "poster", "matte painting", "watercolor", "cyberpunk")
    comp_kw = ("composition", "rule of thirds", "centered", "symmetry", "balanced composition")
    light_kw = ("lighting", "light", "glow", "glowing", "rim", "sunset", "sunrise", "golden hour", "global illumination", "cinematic")
    color_kw = ("color", "palette", "vibrant", "muted", "monochrome", "pastel", "warm", "cool", "balanced contrast")
    detail_kw = ("detailed", "hyperdetailed", "texture", "intricate", "high detail", "highly detailed", "sharp focus", "uhd", "8k")
    for t in subject_first:
        if t:
            buckets["subject"].append(t)
    for t in rest:
        lt = t.lower()
        if any(k in lt for k in style_kw):
            buckets["style"].append(t)
        elif any(k in lt for k in comp_kw):
            buckets["composition"].append(t)
        elif any(k in lt for k in light_kw):
            buckets["lighting"].append(t)
        elif any(k in lt for k in color_kw):
            buckets["color"].append(t)
        elif any(k in lt for k in detail_kw):
            buckets["detail"].append(t)
        else:
            buckets["other"].append(t)
    return buckets["subject"] + buckets["style"] + buckets["composition"] + buckets["lighting"] + buckets["color"] + buckets["detail"] + buckets["other"]
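# _order_tags keeps the subject tag(s) first, then groups the remaining tags by simple keyword
# matching into style -> composition -> lighting -> color -> detail -> other.
# Example (hypothetical input, not executed):
#   _order_tags(["orange cat on a window ledge"], ["8k", "golden hour", "watercolor"])
#   -> ["orange cat on a window ledge", "watercolor", "golden hour", "8k"]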
def pil_to_base64(img: Image.Image, fmt: str = "PNG") -> str:
    buf = io.BytesIO()
    img.save(buf, format=fmt)
    return base64.b64encode(buf.getvalue()).decode("ascii")

async def _maybe_close_async_together(client) -> None:
    try:
        if hasattr(client, "aclose") and inspect.iscoroutinefunction(client.aclose):
            await client.aclose()
        elif hasattr(client, "close"):
            fn = client.close
            if inspect.iscoroutinefunction(fn):
                await fn()
            else:
                try:
                    fn()
                except Exception:
                    pass
    except Exception:
        pass
# =========================
# 3) Async runner
# =========================
def _run_async(coro):
    try:
        loop = asyncio.get_event_loop()
        # nest_asyncio lets run_until_complete work even when this loop is already running
        return loop.run_until_complete(coro)
    except RuntimeError:
        return asyncio.run(coro)
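# generate_one_variant is async because the Together calls are issued concurrently with
# asyncio.gather, while Gradio invokes ui_run_once synchronously inside its own event loop.
# nest_asyncio.apply() (see the imports above) patches the running loop so run_until_complete
# can be re-entered from that synchronous callback.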
# =========================
# 4) CritiCore (Together)
# =========================
from together import AsyncTogether

AGGREGATOR_MODEL = os.environ.get("AGGREGATOR_MODEL", "Qwen/Qwen2.5-72B-Instruct-Turbo")
LLM_MULTI_CANDIDATES = [
    "meta-llama/Llama-3.3-70B-Instruct-Turbo",
    "Qwen/Qwen2.5-72B-Instruct-Turbo",
    "Qwen/Qwen2.5-Coder-32B-Instruct",
    "deepseek-ai/DeepSeek-V3",
    "nvidia/NVIDIA-Nemotron-Nano-9B-v2",
]
_env_list = [s.strip() for s in os.environ.get("VLM_MOA_CANDIDATES", "").split(",") if s.strip()]
VLM_CANDIDATES = _env_list or ["meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"]

TAG_PRESETS = {
    "hq_preference": {
        "seed_pos": [
            "balanced composition",
            "natural color palette", "vibrant colors", "balanced contrast",
            "high detail", "highly detailed", "hyperdetailed", "sharp focus",
            "UHD", "8k",
        ],
        "seed_neg": [
            "low quality", "blurry", "watermark", "jpeg artifacts", "overexposed", "underexposed",
            "color banding", "extra fingers", "missing limbs", "disfigured", "mutated hands",
        ],
    }
}
_DECOMP_SYS = (
    "Decompose the user's visual instruction into 3-6 concrete, checkable visual components "
    "(entities + interactions + spatial relations). Return ONLY JSON: "
    '{"components":["..."]}'
)

_TXT_SYS = (
    "Expand a VERY SHORT visual idea into a COMMA-SEPARATED TAG LIST for SDXL.\n"
    "Constraints:\n"
    "- Start with the subject phrase first.\n"
    "- Prioritize composition, lighting, color, and detail over style.\n"
    "- Use at most TWO style tags if any.\n"
    "- 16-26 concise tags total. Commas only, no sentences, no 'and'. No trailing period.\n"
    "- Prefer human-preference aesthetics; keep 'high detailed', 'sharp focus', '8k', 'UHD'."
)
def _TAG_RE(tag: str):
    return re.compile(rf"<\s*{tag}\s*>(.*?)</\s*{tag}\s*>", re.S | re.I)

def _extract_tag(text: str, tag: str, fallback: str = "") -> str:
    s = (text or "").strip()
    r = _TAG_RE(tag)
    m = r.search(s)
    if m:
        return m.group(1).strip()
    # retry after unescaping HTML-escaped angle brackets
    s2 = s.replace("&lt;", "<").replace("&gt;", ">")
    m2 = r.search(s2)
    return m2.group(1).strip() if m2 else fallback.strip()

def _summarize_issues_lines(text: str, max_lines: int = 5) -> str:
    if not text:
        return ""
    parts = [p.strip(" -•\t") for p in re.split(r"[\n;]+", text) if p.strip()]
    parts = parts[:max_lines]
    return "\n".join(f"- {p}" for p in parts)
class CritiCore:
    def __init__(self, preset: str = "hq_preference", aggregator_model: str = AGGREGATOR_MODEL):
        if not os.environ.get("TOGETHER_API_KEY"):
            raise RuntimeError("Missing TOGETHER_API_KEY in environment.")
        self.preset = preset
        self.aggregator = aggregator_model

    async def decompose_components(self, user_prompt: str) -> List[str]:
        client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        try:
            tasks = [client.chat.completions.create(
                model=m,
                messages=[{"role": "system", "content": _DECOMP_SYS},
                          {"role": "user", "content": user_prompt}],
                temperature=0.4, max_tokens=256,
            ) for m in LLM_MULTI_CANDIDATES]
            rs = await asyncio.gather(*tasks, return_exceptions=True)
            texts = []
            for r in rs:
                try:
                    texts.append(r.choices[0].message.content)
                except Exception:
                    pass
            if not texts:
                return []
            joined = "\n\n---\n\n".join(texts)
            merged = await client.chat.completions.create(
                model=self.aggregator,
                messages=[{"role": "system", "content": "Merge JSON candidates and return ONLY {'components':[...]}."},
                          {"role": "user", "content": joined}],
                temperature=0.2, max_tokens=256,
            )
            txt = merged.choices[0].message.content
            try:
                obj = json.loads(txt)
            except Exception:
                s, e = txt.find("{"), txt.rfind("}")
                obj = json.loads(txt[s:e + 1]) if (s != -1 and e != -1) else {"components": []}
            comps = [c.strip() for c in obj.get("components", []) if isinstance(c, str) and c.strip()]
            return comps[:6]
        finally:
            await _maybe_close_async_together(client)
    async def make_tags(self, user_prompt: str, clip77: bool = True) -> Tuple[str, str]:
        client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        seed = TAG_PRESETS.get(self.preset, TAG_PRESETS["hq_preference"])
        seed_pos = _dedup_keep_order(seed["seed_pos"])
        seed_neg = seed["seed_neg"]
        try:
            tasks = [client.chat.completions.create(
                model=m,
                messages=[{"role": "system", "content": _TXT_SYS},
                          {"role": "user", "content":
                           f"Short idea: {user_prompt}\nSeed: {', '.join(seed_pos)}\nOutput: a single comma-separated tag list."}],
                temperature=0.7, max_tokens=220,
            ) for m in LLM_MULTI_CANDIDATES]
            rs = await asyncio.gather(*tasks, return_exceptions=True)
            props = []
            for r in rs:
                try:
                    props.append(r.choices[0].message.content)
                except Exception:
                    pass
            if not props:
                pos = ", ".join([user_prompt.strip()] + seed_pos)
            else:
                joined = "\n---\n".join(props)
                merged = await client.chat.completions.create(
                    model=self.aggregator,
                    messages=[{"role": "system", "content":
                               "Merge candidate tag lists into ONE comma list (16-26 tags). Subject first; at most TWO style tags; keep high detailed/sharp focus/8k/UHD."},
                              {"role": "user", "content": joined}],
                    temperature=0.2, max_tokens=240,
                )
                raw = merged.choices[0].message.content
                tags = _dedup_keep_order(_split_tags(raw))
                subject = user_prompt.strip().rstrip(",.")
                if subject and not any(subject.lower() == t.lower() for t in tags):
                    tags = [subject] + tags
                ordered = _order_tags([tags[0]], tags[1:])
                pos = ", ".join(_dedup_keep_order(ordered))
            # quality floor
            for q in ["high detailed", "sharp focus", "8k", "UHD"]:
                if q.lower() not in {t.lower() for t in _split_tags(pos)}:
                    pos += ", " + q
            pos = _cleanup_commas(pos)
            if clip77 and _count_tokens(pos) > 77:
                pos = clip77_strict(pos, 77)
            neg = ", ".join(seed_neg)
            return pos, neg
        finally:
            await _maybe_close_async_together(client)
    async def vlm_refine(self, image: Image.Image, original_prompt: str, components: List[str]) -> Dict[str, object]:
        client = AsyncTogether(api_key=os.environ["TOGETHER_API_KEY"])
        b64 = pil_to_base64(image, "PNG")

        def _user_prompt_text() -> str:
            return (
                "You are a precise image-grounded critic.\n"
                "1) List concrete visual problems and brief corrections.\n"
                "2) Provide a refined prompt that keeps the original intent.\n\n"
                f'Original prompt: "{original_prompt}"\n'
                f"Key components to check: {components}\n"
                "Output EXACTLY two tags:\n"
                "<issues>...</issues>\n<refined>...</refined>"
            )

        try:
            tasks = []
            for m in VLM_CANDIDATES:
                msgs = [
                    {"role": "system", "content": "Return ONLY <issues> and <refined>. No extra text."},
                    {"role": "user", "content": [
                        {"type": "text", "text": _user_prompt_text()},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
                    ]},
                ]
                tasks.append(client.chat.completions.create(model=m, messages=msgs, temperature=0.2, max_tokens=420))
            rs = await asyncio.gather(*tasks, return_exceptions=True)
            ok = []
            for m, r in zip(VLM_CANDIDATES, rs):
                try:
                    ok.append((m, r.choices[0].message.content))
                except Exception:
                    pass
            if not ok:
                return {"refined": original_prompt, "issues_merged": ""}
            refined_items, per_vlm_issues = [], {}
            for m, raw in ok:
                issues = _extract_tag(raw, "issues", "")
                refined = _extract_tag(raw, "refined", original_prompt)
                if refined.strip():
                    refined_items.append((m, refined.strip()))
                if issues.strip():
                    per_vlm_issues[m] = _summarize_issues_lines(issues, 5)
            joined_issues = "\n".join(f"[{m}] {t}" for m, t in per_vlm_issues.items())
            joined_refined = "\n".join(f"[{m}] {t}" for m, t in refined_items) if refined_items else original_prompt
            merged = await client.chat.completions.create(
                model=self.aggregator,
                messages=[{"role": "system", "content":
                           "Merge multiple critics. Output ONLY <issues> (≤5 bullets) and <refined> (≤70 words)."},
                          {"role": "user", "content": f"{joined_issues}\n\n----\n\n{joined_refined}"}],
                temperature=0.2, max_tokens=420,
            )
            final_raw = merged.choices[0].message.content
            final_refined = clip77_strict(_extract_tag(final_raw, "refined", original_prompt), 77)
            issues_merged = _summarize_issues_lines(_extract_tag(final_raw, "issues", ""), 5)
            return {"refined": final_refined, "issues_merged": issues_merged}
        finally:
            await _maybe_close_async_together(client)
    @staticmethod
    def merge_vlm_multi_text(vlm_refined_77: str, tags_77: str) -> str:
        vlm_tags = _split_tags(vlm_refined_77)
        moa_tags = _split_tags(tags_77)
        merged = _dedup_keep_order(_order_tags([vlm_tags[0] if vlm_tags else ""], (vlm_tags[1:] + moa_tags)))
        merged = [t for t in merged if t]
        text = _cleanup_commas(", ".join(merged))
        if _count_tokens(text) > 77:
            text = clip77_strict(text, 77)
        return text
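# CritiCore.merge_vlm_multi_text keeps the VLM's refined subject phrase first, appends the MoA
# tags, de-duplicates, re-buckets with _order_tags, and clips the result to the 77-token CLIP
# budget. Illustrative call (hypothetical strings, not executed):
#   CritiCore.merge_vlm_multi_text("orange cat, softer rim light", "orange cat, golden hour, 8k")
#   -> "orange cat, softer rim light, golden hour, 8k"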
# =========================
# 5) SpecFusion (latent FFT gate)
# =========================
def frequency_fusion(
    x_hi_latent: torch.Tensor,
    x_lo_latent: torch.Tensor,
    base_c: float = 0.5,
    rho_t: float = 0.85,
    device=None,
) -> torch.Tensor:
    if device is None:
        device = x_hi_latent.device
    B, C, H, W = x_hi_latent.shape
    x_h = x_hi_latent.to(torch.float32).to(device)
    x_l = x_lo_latent.to(torch.float32).to(device)
    Xh = torch.fft.fftshift(torch.fft.fftn(x_h, dim=(-2, -1)), dim=(-2, -1))
    Xl = torch.fft.fftshift(torch.fft.fftn(x_l, dim=(-2, -1)), dim=(-2, -1))
    tau_h = int(H * base_c * (1 - rho_t))
    tau_w = int(W * base_c * (1 - rho_t))
    mask = torch.ones((B, C, H, W), device=device, dtype=torch.float32)
    cy, cx = H // 2, W // 2
    if tau_h > 0 and tau_w > 0:
        mask[..., cy - tau_h : cy + tau_h, cx - tau_w : cx + tau_w] = rho_t
    Xf = Xh * mask + Xl * (1 - mask)
    x = torch.fft.ifftn(torch.fft.ifftshift(Xf, dim=(-2, -1)), dim=(-2, -1)).real
    x = x + torch.randn_like(x) * 0.001
    return x.to(dtype=x_hi_latent.dtype)
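# frequency_fusion blends two latents in the 2D Fourier domain. After fftshift, low spatial
# frequencies sit at the centre of the spectrum; the mask is 1.0 everywhere except a centred box
# of half-size tau = H * base_c * (1 - rho_t), where it equals rho_t. High frequencies therefore
# come entirely from the refined latent (x_hi_latent), while inside the box the spectrum mixes
# rho_t of the refined latent with (1 - rho_t) of the base latent (x_lo_latent). With the defaults
# (base_c=0.5, rho_t=0.85) on a 128x128 SDXL latent, tau = int(128 * 0.5 * 0.15) = 9, i.e. an
# 18x18 central window. A tiny Gaussian perturbation (sigma=0.001) is added before casting back
# to the original latent dtype.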
def _decode_to_pil(latents):
    out = decode_image_sdxl(latents, SDXL_i2i)
    if isinstance(out, Image.Image):
        return out
    if hasattr(out, "images"):
        return out.images[0]
    return out

def _guidance_for_k(k: int) -> float:
    if k >= 20:
        return 12.0
    if k >= 10:
        return 7.5
    return 5.2
# =========================
# 6) ONE-variant generator (the UI enforces single selection)
# =========================
async def generate_one_variant(
    user_prompt: str,
    seed: int,
    H: int,
    W: int,
    total_steps_refine: int,
    last_k: int,
    guidance: float,
    preset: str,
    variant_key: str,
    out_dir: Optional[Path] = None,
) -> Tuple[Image.Image, str, Dict[str, object]]:
    """
    Returns:
        img, display_name, meta_dict
    """
    meta: Dict[str, object] = {
        "user_prompt": user_prompt,
        "variant_key": variant_key,
    }

    def _save(im: Image.Image, display_name: str):
        if out_dir is None:
            return
        out_dir.mkdir(parents=True, exist_ok=True)
        safe = re.sub(r"[^a-zA-Z0-9_\\-]+", "_", display_name)[:120]
        im.save(out_dir / f"{safe}.png")

    # ----------------------------------------------------------
    # Variant 1: Base (Original Prompt)  [no Together key needed]
    # ----------------------------------------------------------
    if variant_key == "base_original":
        z0_og, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)
        meta.update({"note": "SDXL base generation from original prompt."})
        _save(base_og, VARIANT_LABELS[variant_key])
        return base_og, VARIANT_LABELS[variant_key], meta

    # The remaining variants need Together
    if not TOGETHER_API_KEY:
        raise RuntimeError("TOGETHER_API_KEY not set, but the selected variant requires Together.")
    critic = CritiCore(preset=preset)

    # Common refine params
    lk = int(last_k)
    strength = float(strength_for_last_k(lk, total_steps_refine))
    use_guidance = float(guidance) if float(guidance) > 0 else float(_guidance_for_k(lk))
    steps = int(total_steps_refine)
    meta.update({"strength": strength, "guidance": use_guidance, "steps": steps, "last_k": lk})
    # ----------------------------------------------------------
    # Variant 2: Base (MoA Tags)
    # ----------------------------------------------------------
    if variant_key == "base_multi_llm":
        pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
        z0_enh, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)
        meta.update({
            "pos_tags_77": pos_tags_77,
            "neg_tags": neg_tags,
            "note": "SDXL base generation from MoA-generated tags.",
        })
        _save(base_enh, VARIANT_LABELS[variant_key])
        return base_enh, VARIANT_LABELS[variant_key], meta

    # ----------------------------------------------------------
    # Variant 3: CritiFusion (MoA+VLM+SpecFusion)
    # ----------------------------------------------------------
    if variant_key == "CritiFusion":
        pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
        comps = await critic.decompose_components(user_prompt)
        z0_enh, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)
        vlm_out = await critic.vlm_refine(base_enh, pos_tags_77, comps or [])
        vlm_agg_77 = vlm_out.get("refined") or pos_tags_77
        refined_on_enh = CritiCore.merge_vlm_multi_text(vlm_agg_77, pos_tags_77)
        z_ref = img2img_latent(
            refined_on_enh, z0_enh,
            strength=strength, guidance=use_guidance, steps=steps,
            seed=seed + 2100 + lk,
            neg=DEFAULT_NEG,
        )
        fused_lat = frequency_fusion(z_ref, z0_enh, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
        img_sf = _decode_to_pil(fused_lat)
        meta.update({
            "pos_tags_77": pos_tags_77,
            "neg_tags": neg_tags,
            "components": comps,
            "vlm_refined_77": vlm_agg_77,
            "enhanced_prompt_77": refined_on_enh,
            "vlm_issues": vlm_out.get("issues_merged", ""),
            "note": "MoA tags + VLM critique prompt + img2img + SpecFusion.",
        })
        _save(img_sf, VARIANT_LABELS[variant_key])
        return img_sf, VARIANT_LABELS[variant_key], meta

    # ----------------------------------------------------------
    # Variant 4: CritiFusion (Original+VLM+SpecFusion)
    # ----------------------------------------------------------
    if variant_key == "criticore_on_original__specfusion":
        pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
        comps = await critic.decompose_components(user_prompt)
        z0_og, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)
        vlm_on_og = await critic.vlm_refine(base_og, user_prompt, comps or [])
        refined_og_77 = clip77_strict(vlm_on_og.get("refined") or user_prompt, 77)
        refined_merge = CritiCore.merge_vlm_multi_text(refined_og_77, pos_tags_77)
        z_ref = img2img_latent(
            refined_merge, z0_og,
            strength=strength, guidance=use_guidance, steps=steps,
            seed=seed + 2400 + lk,
            neg=DEFAULT_NEG,
        )
        fused_lat = frequency_fusion(z_ref, z0_og, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
        img_sf = _decode_to_pil(fused_lat)
        meta.update({
            "pos_tags_77": pos_tags_77,
            "neg_tags": neg_tags,
            "components": comps,
            "vlm_refined_77": refined_og_77,
            "enhanced_prompt_77": refined_merge,
            "vlm_issues": vlm_on_og.get("issues_merged", ""),
            "note": "Original prompt + VLM critique prompt + img2img + SpecFusion.",
        })
        _save(img_sf, VARIANT_LABELS[variant_key])
        return img_sf, VARIANT_LABELS[variant_key], meta

    raise ValueError(f"Unknown variant_key: {variant_key}")
# =========================
# 7) UI callbacks
# =========================
def ui_run_once(
    user_prompt: str,
    seed: int,
    H: int,
    W: int,
    preset: str,
    total_steps_refine: int,
    last_k: int,
    guidance: float,
    enabled_variants_display: List[str],
    save_outputs: bool,
    out_dir: str,
):
    t0 = time.time()
    try:
        if not user_prompt or not user_prompt.strip():
            return [], "Empty prompt."
        # display label -> internal key
        display_to_internal = {v: k for k, v in VARIANT_LABELS.items()}
        chosen_display = (enabled_variants_display or [])[-1:]  # enforce single selection here too
        if not chosen_display:
            return [], "Please select ONE variant."
        chosen_display = chosen_display[0]
        variant_key = display_to_internal.get(chosen_display)
        if variant_key is None:
            return [], f"Unknown selected variant: {chosen_display}"
        out_path = Path(out_dir) if (save_outputs and out_dir) else None
        img, disp_name, meta = _run_async(generate_one_variant(
            user_prompt=user_prompt.strip(),
            seed=int(seed),
            H=int(H), W=int(W),
            total_steps_refine=int(total_steps_refine),
            last_k=int(last_k),
            guidance=float(guidance),
            preset=preset,
            variant_key=variant_key,
            out_dir=out_path,
        ))
        meta["ui"] = {
            "seed": int(seed),
            "H": int(H),
            "W": int(W),
            "preset": preset,
            "total_steps_refine": int(total_steps_refine),
            "last_k": int(last_k),
            "guidance": float(guidance),
            "selected_variant": chosen_display,
            "save_outputs": bool(save_outputs),
            "out_dir": out_dir if save_outputs else None,
        }
        meta["elapsed_sec"] = round(time.time() - t0, 3)
        gallery = [(img, disp_name)]
        return gallery, json.dumps(meta, ensure_ascii=False, indent=2)
    except Exception:
        return [], traceback.format_exc()
@spaces.GPU  # assumes a ZeroGPU Space: request a GPU only for the duration of this call
def ui_run_once_gpu(*args, **kwargs):
    return ui_run_once(*args, **kwargs)
# =========================
# 8) Single-select enforcement for CheckboxGroup
# =========================
def enforce_single_variant(new_list: List[str], prev_list: List[str]):
    new_list = new_list or []
    prev_list = prev_list or []
    new_set = set(new_list)
    prev_set = set(prev_list)
    added = list(new_set - prev_set)
    if added:
        # keep the newly added choice
        chosen = added[-1]
        out = [chosen]
    else:
        # nothing was added (removed or unchanged); if multiple remain, keep the last one
        out = new_list[-1:] if len(new_list) > 1 else new_list
    return out, out  # update both the CheckboxGroup value and the state
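# Example of the enforcement, assuming "Base (Original Prompt)" was selected and the user then
# clicks "CritiFusion (MoA+VLM+SpecFusion)":
#   enforce_single_variant(
#       ["Base (Original Prompt)", "CritiFusion (MoA+VLM+SpecFusion)"],
#       ["Base (Original Prompt)"],
#   )
#   -> (["CritiFusion (MoA+VLM+SpecFusion)"], ["CritiFusion (MoA+VLM+SpecFusion)"])
# The CheckboxGroup therefore behaves like a radio group while keeping the pills look.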
# =========================
# 9) Gradio UI
# =========================
with gr.Blocks(title="CritiFusion (SDXL) Demo") as demo:
    gr.Markdown(
        "## CritiFusion Demo (SDXL)\n"
        "- Keeps the **Enabled Variants** pills UI, but **only one** can be selected.\n"
        f"- Device: **{DEVICE_STR}**, DType: **{DTYPE}**\n"
        f"- Together API: {'set' if TOGETHER_API_KEY else 'MISSING (set TOGETHER_API_KEY)'}"
    )
    with gr.Row():
        with gr.Column(scale=7):
            user_prompt = gr.Textbox(
                label="Prompt",
                value="A fluffy orange cat lying on a window ledge, front-facing, stylized 3D, soft indoor lighting",
                lines=3,
            )
            with gr.Row():
                seed = gr.Number(label="Seed", value=2026, precision=0)
                preset = gr.Dropdown(label="Preset", choices=["hq_preference"], value="hq_preference")
            with gr.Row():
                H = gr.Number(label="H", value=1024, precision=0)
                W = gr.Number(label="W", value=1024, precision=0)
            with gr.Row():
                total_steps_refine = gr.Slider(label="total_steps_refine", minimum=10, maximum=80, step=1, value=50)
                last_k = gr.Slider(label="last_k", minimum=1, maximum=50, step=1, value=37)
                guidance = gr.Slider(
                    label="Guidance (0 => fallback rule)",
                    minimum=0.0, maximum=15.0, step=0.1, value=0.0,
                )

            # --- pills UI, but single-select enforced ---
            selected_state = gr.State([VARIANT_LABELS["base_original"]])
            enabled_variants = gr.CheckboxGroup(
                label="Enabled Variants (select ONE)",
                choices=[VARIANT_LABELS[k] for k in VARIANT_LABELS.keys()],
                value=[VARIANT_LABELS["base_original"]],
            )
            # enforce single selection on change
            enabled_variants.change(
                fn=enforce_single_variant,
                inputs=[enabled_variants, selected_state],
                outputs=[enabled_variants, selected_state],
            )

            with gr.Row():
                save_outputs = gr.Checkbox(label="Save output to disk", value=False)
                out_dir = gr.Textbox(label="Output dir (only if save enabled)", value="./variants_demo_gradio")
            run_btn = gr.Button("Run", variant="primary")

        with gr.Column(scale=8):
            gallery = gr.Gallery(label="Result", columns=1, height=600)
            meta_json = gr.Code(label="Meta / Debug (JSON)", language="json")

    run_btn.click(
        fn=ui_run_once_gpu,
        inputs=[user_prompt, seed, H, W, preset, total_steps_refine, last_k, guidance, enabled_variants, save_outputs, out_dir],
        outputs=[gallery, meta_json],
        api_name=False,  # Gradio-safe (avoid schema issues)
    )
demo.queue().launch(
    debug=True,
    share=True,  # optional; helps when running outside Spaces
)