import os import sys import re import json import random import logging import warnings from dataclasses import dataclass import gradio as gr import torch from PIL import Image, ImageDraw, ImageFont import spaces from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler from transformers import AutoModelForCausalLM, AutoTokenizer # ------------------------- 可选依赖:Prompt Enhancer 模板 ------------------------- # 你的原工程里如果有 pe.py,会自动使用;没有也不会报错(enhance 默认关闭) try: sys.path.append(os.path.dirname(os.path.abspath(__file__))) from pe import prompt_template # type: ignore except Exception: prompt_template = ( "You are a helpful prompt engineer. Expand the user prompt into a richer, detailed prompt. " "Return JSON with key revised_prompt." ) # ------------------------- Z-Image 相关(依赖你环境中 diffusers 的实现) ------------------------- from diffusers import ZImagePipeline from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel # ==================== Environment Variables ================================== MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo") ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true" ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true" ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3") DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY") HF_TOKEN = os.environ.get("HF_TOKEN") # ============================================================================= os.environ["TOKENIZERS_PARALLELISM"] = "false" warnings.filterwarnings("ignore") logging.getLogger("transformers").setLevel(logging.ERROR) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32 RES_CHOICES = { "1024": [ "1024x1024 ( 1:1 )", "1152x896 ( 9:7 )", "896x1152 ( 7:9 )", "1152x864 ( 4:3 )", "864x1152 ( 3:4 )", "1248x832 ( 3:2 )", "832x1248 ( 2:3 )", "1280x720 ( 16:9 )", "720x1280 ( 9:16 )", "1344x576 ( 21:9 )", "576x1344 ( 9:21 )", ], "1280": [ "1280x1280 ( 1:1 )", "1440x1120 ( 9:7 )", "1120x1440 ( 7:9 )", "1472x1104 ( 4:3 )", "1104x1472 ( 3:4 )", "1536x1024 ( 3:2 )", "1024x1536 ( 2:3 )", "1536x864 ( 16:9 )", "864x1536 ( 9:16 )", "1680x720 ( 21:9 )", "720x1680 ( 9:21 )", ], "1536": [ "1536x1536 ( 1:1 )", "1728x1344 ( 9:7 )", "1344x1728 ( 7:9 )", "1728x1296 ( 4:3 )", "1296x1728 ( 3:4 )", "1872x1248 ( 3:2 )", "1248x1872 ( 2:3 )", "2048x1152 ( 16:9 )", "1152x2048 ( 9:16 )", "2016x864 ( 21:9 )", "864x2016 ( 9:21 )", ], } RESOLUTION_SET = [] for _k, v in RES_CHOICES.items(): RESOLUTION_SET.extend(v) EXAMPLE_PROMPTS = [ ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"], ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"], ] # ------------------------- HF token 兼容参数 ------------------------- def _hf_token_kwargs(token: str | None): """ transformers / diffusers 的 from_pretrained 近年来从 use_auth_token 迁移到 token。 这里做一个兼容:优先传 token,不支持则回退 use_auth_token。 """ if not token: return {} return {"token": token, "use_auth_token": token} def get_resolution(resolution: str): match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution) if match: return int(match.group(1)), int(match.group(2)) return 1024, 1024 def _make_blocked_image(width=1024, height=1024, text="Blocked by Safety Checker"): img = Image.new("RGB", (width, height), (20, 20, 20)) draw = ImageDraw.Draw(img) try: font = ImageFont.load_default() except Exception: font = None draw.rectangle([0, 0, width, 90], fill=(160, 0, 0)) draw.text((20, 30), text, fill=(255, 255, 255), font=font) return img def _load_nsfw_placeholder(width=1024, height=1024): """ 命中 NSFW 时优先加载工作目录的 nsfw.png; 不存在就生成一张占位图,避免文件缺失导致再次报错。 """ if os.path.exists("nsfw.png"): try: return Image.open("nsfw.png").convert("RGB") except Exception: pass return _make_blocked_image(width, height, "NSFW blocked") def load_models(model_path: str, enable_compile=False, attention_backend="native"): print(f"[Init] Loading models from: {model_path}") print(f"[Init] DEVICE={DEVICE}, DTYPE={DTYPE}, ENABLE_COMPILE={enable_compile}, ATTENTION_BACKEND={attention_backend}") # 远端 repo-id(不存在的本地路径) vs 本地目录 is_local_dir = os.path.exists(model_path) token_kwargs = _hf_token_kwargs(HF_TOKEN) if not is_local_dir else {} # 1) VAE if not is_local_dir: vae = AutoencoderKL.from_pretrained( model_path, subfolder="vae", torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32, **token_kwargs, ) else: vae = AutoencoderKL.from_pretrained( os.path.join(model_path, "vae"), torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32, ) # 2) Text Encoder + Tokenizer if not is_local_dir: text_encoder = AutoModelForCausalLM.from_pretrained( model_path, subfolder="text_encoder", torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32, **token_kwargs, ).eval() tokenizer = AutoTokenizer.from_pretrained( model_path, subfolder="tokenizer", **token_kwargs, ) else: text_encoder = AutoModelForCausalLM.from_pretrained( os.path.join(model_path, "text_encoder"), torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32, ).eval() tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer")) tokenizer.padding_side = "left" # compile 优化(仅 CUDA 才建议打开) if enable_compile and DEVICE == "cuda": print("[Init] Enabling torch.compile optimizations...") torch._inductor.config.conv_1x1_as_mm = True torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.epilogue_fusion = False torch._inductor.config.coordinate_descent_check_all_directions = True torch._inductor.config.max_autotune_gemm = True torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN" torch._inductor.config.triton.cudagraphs = False pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None) # 3) Transformer if not is_local_dir: transformer = ZImageTransformer2DModel.from_pretrained( model_path, subfolder="transformer", **token_kwargs, ) else: transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")) transformer = transformer.to(DEVICE, DTYPE) pipe.transformer = transformer # attention backend 可能在不同环境不支持,做容错 try: pipe.transformer.set_attention_backend(attention_backend) except Exception as e: print(f"[Init] set_attention_backend('{attention_backend}') failed, fallback to 'native'. Error: {e}") try: pipe.transformer.set_attention_backend("native") except Exception as e2: print(f"[Init] fallback set_attention_backend('native') failed: {e2}") if enable_compile and DEVICE == "cuda": try: print("[Init] Compiling transformer...") pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False) except Exception as e: print(f"[Init] torch.compile failed, continue without compile. Error: {e}") pipe = pipe.to(DEVICE, DTYPE) # 4) Safety Checker(用于生成后过滤) try: from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker try: from transformers import CLIPImageProcessor as _CLIPProcessor except Exception: # 老版本兼容 from transformers import CLIPFeatureExtractor as _CLIPProcessor # type: ignore safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_feature_extractor = _CLIPProcessor.from_pretrained(safety_model_id, **_hf_token_kwargs(HF_TOKEN)) safety_checker = StableDiffusionSafetyChecker.from_pretrained( safety_model_id, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, **_hf_token_kwargs(HF_TOKEN), ).to(DEVICE) pipe.safety_feature_extractor = safety_feature_extractor pipe.safety_checker = safety_checker print("[Init] Safety checker loaded.") except Exception as e: print(f"[Init] Safety checker init failed. NSFW filtering will be skipped. Error: {e}") pipe.safety_feature_extractor = None pipe.safety_checker = None return pipe def generate_image( pipe, prompt: str, resolution="1024x1024", seed=42, guidance_scale=5.0, num_inference_steps=50, shift=3.0, max_sequence_length=512, progress=gr.Progress(track_tqdm=True), ): width, height = get_resolution(resolution) if DEVICE == "cuda": generator = torch.Generator(device="cuda").manual_seed(int(seed)) else: generator = torch.Generator().manual_seed(int(seed)) scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift)) pipe.scheduler = scheduler out = pipe( prompt=prompt, height=int(height), width=int(width), guidance_scale=float(guidance_scale), num_inference_steps=int(num_inference_steps), generator=generator, max_sequence_length=int(max_sequence_length), ) image = out.images[0] return image def warmup_model(pipe, resolutions): print("[Warmup] Starting warmup phase...") dummy_prompt = "warmup" for res_str in resolutions: print(f"[Warmup] Resolution: {res_str}") try: for i in range(2): generate_image( pipe, prompt=dummy_prompt, resolution=res_str.split(" ")[0], num_inference_steps=6, guidance_scale=0.0, seed=42 + i, ) except Exception as e: print(f"[Warmup] Failed for {res_str}: {e}") print("[Warmup] Completed.") # ==================== Prompt Expander(保留但默认不启用) ==================== @dataclass class PromptOutput: status: bool prompt: str seed: int system_prompt: str message: str class PromptExpander: def __init__(self, backend="api", **kwargs): self.backend = backend def decide_system_prompt(self, template_name=None): return prompt_template class APIPromptExpander(PromptExpander): def __init__(self, api_config=None, **kwargs): super().__init__(backend="api", **kwargs) self.api_config = api_config or {} self.client = self._init_api_client() def _init_api_client(self): try: from openai import OpenAI api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1") if not api_key: print("[PE] Warning: DASHSCOPE_API_KEY not found. Prompt enhance unavailable.") return None return OpenAI(api_key=api_key, base_url=base_url) except ImportError: print("[PE] Please install openai: pip install openai") return None except Exception as e: print(f"[PE] Failed to initialize API client: {e}") return None def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs): return self.extend(prompt, system_prompt, seed, **kwargs) def extend(self, prompt, system_prompt=None, seed=-1, **kwargs): if self.client is None: return PromptOutput(False, "", seed, system_prompt or "", "API client not initialized") if system_prompt is None: system_prompt = self.decide_system_prompt() if "{prompt}" in system_prompt: system_prompt = system_prompt.format(prompt=prompt) prompt = " " try: model = self.api_config.get("model", "qwen3-max-preview") response = self.client.chat.completions.create( model=model, messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}], temperature=0.7, top_p=0.8, ) content = response.choices[0].message.content or "" # 尝试从 ```json 块中解析 revised_prompt expanded_prompt = content json_start = content.find("```json") if json_start != -1: json_end = content.find("```", json_start + 7) if json_end != -1: json_str = content[json_start + 7 : json_end].strip() try: data = json.loads(json_str) expanded_prompt = data.get("revised_prompt", content) except Exception: expanded_prompt = content return PromptOutput(True, expanded_prompt, seed, system_prompt, content) except Exception as e: return PromptOutput(False, "", seed, system_prompt, str(e)) def create_prompt_expander(backend="api", **kwargs): if backend == "api": return APIPromptExpander(**kwargs) raise ValueError("Only 'api' backend is supported.") pipe = None prompt_expander = None def init_app(): global pipe, prompt_expander try: pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND) print("[Init] Model loaded.") if ENABLE_WARMUP and pipe is not None: all_res = [] for cat in RES_CHOICES.values(): all_res.extend(cat) warmup_model(pipe, all_res) except Exception as e: print(f"[Init] Error loading model: {e}") pipe = None try: prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"}) print("[Init] Prompt expander ready (disabled by default).") except Exception as e: print(f"[Init] Error initializing prompt expander: {e}") prompt_expander = None def prompt_enhance(prompt, enable_enhance: bool): if not enable_enhance or not prompt_expander: return prompt, "Enhancement disabled or unavailable." if not prompt.strip(): return "", "Please enter a prompt." try: result = prompt_expander(prompt) if result.status: return result.prompt, result.message return prompt, f"Enhancement failed: {result.message}" except Exception as e: return prompt, f"Error: {str(e)}" def try_enable_aoti(pipe): """ AoTI(ZeroGPU 加速)可用则启用;不可用则跳过,不影响主流程。 """ if pipe is None: return try: # 优先按你原代码的结构尝试:pipe.transformer.layers if hasattr(pipe, "transformer") and pipe.transformer is not None: target = None if hasattr(pipe.transformer, "layers"): target = pipe.transformer.layers if hasattr(target, "_repeated_blocks"): target._repeated_blocks = ["ZImageTransformerBlock"] else: # 兜底:直接对 transformer 设置 target = pipe.transformer if hasattr(target, "_repeated_blocks"): target._repeated_blocks = ["ZImageTransformerBlock"] if target is not None: spaces.aoti_blocks_load(target, "zerogpu-aoti/Z-Image", variant="fa3") print("[Init] AoTI blocks loaded.") except Exception as e: print(f"[Init] AoTI not enabled (safe to ignore). Error: {e}") @spaces.GPU def generate( prompt, resolution="1024x1024 ( 1:1 )", seed=42, steps=9, shift=3.0, random_seed=True, gallery_images=None, enhance=False, # 默认不启用 progress=gr.Progress(track_tqdm=True), ): if random_seed: new_seed = random.randint(1, 1000000) else: new_seed = int(seed) if int(seed) != -1 else random.randint(1, 1000000) if pipe is None: raise gr.Error("Model not loaded. Please check logs.") final_prompt = prompt or "" if enhance: # 你原注释说 DISABLED,这里仍保留能力但默认关闭 final_prompt, _msg = prompt_enhance(final_prompt, True) print(f"[PE] Enhanced prompt: {final_prompt}") # 解析 "1024x1024 ( 1:1 )" -> "1024x1024" try: resolution_str = str(resolution).split(" ")[0] except Exception: resolution_str = "1024x1024" width, height = get_resolution(resolution_str) # 生成 image = generate_image( pipe=pipe, prompt=final_prompt, resolution=resolution_str, seed=new_seed, guidance_scale=0.0, num_inference_steps=int(steps) + 1, shift=float(shift), ) # 生成后 NSFW 安全检查(已去掉 prompt_check) try: if getattr(pipe, "safety_feature_extractor", None) is not None and getattr(pipe, "safety_checker", None) is not None: # CLIP 输入 clip_inputs = pipe.safety_feature_extractor([image], return_tensors="pt") clip_input = clip_inputs.pixel_values.to(DEVICE) # SafetyChecker 需要 numpy 格式图片(batch, H, W, C),float32 0-1 import numpy as np img_np = np.array(image).astype("float32") / 255.0 img_np = img_np[None, ...] checked_images, has_nsfw = pipe.safety_checker(images=img_np, clip_input=clip_input) # has_nsfw 一般是 list[bool] if isinstance(has_nsfw, (list, tuple)) and len(has_nsfw) > 0 and bool(has_nsfw[0]): image = _load_nsfw_placeholder(width, height) except Exception as e: # Safety checker 失败不应阻塞主流程 print(f"[Safety] Check failed (ignored): {e}") if gallery_images is None: gallery_images = [] gallery_images = [image] + list(gallery_images) return gallery_images, str(new_seed), int(new_seed) # ------------------------- 启动初始化 ------------------------- init_app() try_enable_aoti(pipe) # ==================== Gradio UI ==================== with gr.Blocks(title="Z-Image Demo") as demo: gr.Markdown( """