tchung1970 committed on
Commit
add1856
·
1 Parent(s): 491977c

Revert "Use local text encoders instead of remote service"

Browse files

This reverts commit b07ac298ee4fd266049b29c24fa576fba90e4d0f.

Files changed (1) hide show
  1. app.py +48 -13
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  import io
3
  import re
4
  import gradio as gr
@@ -6,11 +8,16 @@ import numpy as np
6
  import random
7
  import spaces
8
  import torch
9
- from diffusers import Flux2Pipeline
 
 
10
  from PIL import Image
 
11
  import base64
12
  from huggingface_hub import InferenceClient
13
 
 
 
14
  dtype = torch.bfloat16
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
@@ -44,11 +51,32 @@ Rules:
44
 
45
  Output only the final instruction in plain text and nothing else."""
46
 
47
- # Load model with local text encoders
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  repo_id = "black-forest-labs/FLUX.2-dev"
49
 
 
 
 
 
 
 
50
  pipe = Flux2Pipeline.from_pretrained(
51
  repo_id,
 
 
52
  torch_dtype=torch.bfloat16
53
  )
54
  pipe.to(device)
@@ -131,17 +159,20 @@ def update_dimensions_from_image(image_list):
131
  return new_width, new_height
132
 
133
  # Updated duration function to match generate_image arguments (including progress)
134
- def get_duration(prompt, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
135
  num_images = 0 if image_list is None else len(image_list)
136
  step_duration = 1 + 0.8 * num_images
137
- return max(90, num_inference_steps * step_duration + 30) # Increased for text encoding
138
 
139
  @spaces.GPU(duration=get_duration)
140
- def generate_image(prompt, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
 
 
 
141
  generator = torch.Generator(device=device).manual_seed(seed)
142
-
143
  pipe_kwargs = {
144
- "prompt": prompt,
145
  "image": image_list,
146
  "num_inference_steps": num_inference_steps,
147
  "guidance_scale": guidance_scale,
@@ -149,11 +180,11 @@ def generate_image(prompt, image_list, width, height, num_inference_steps, guida
149
  "width": width,
150
  "height": height,
151
  }
152
-
153
  # Progress bar for the actual generation steps
154
  if progress:
155
- progress(0, desc="Generating image...")
156
-
157
  image = pipe(**pipe_kwargs).images[0]
158
  return image
159
 
@@ -180,10 +211,14 @@ def infer(prompt, aspect_ratio="1:1 (1024x1024)", progress=gr.Progress(track_tqd
180
  num_inference_steps = 30
181
  guidance_scale = 4.0
182
 
183
- # Image Generation (GPU bound - includes text encoding)
184
- progress(0.1, desc="Generating image...")
 
 
 
 
185
  image = generate_image(
186
- prompt,
187
  None, # No input images
188
  width,
189
  height,
 
1
  import os
2
+ import subprocess
3
+ import sys
4
  import io
5
  import re
6
  import gradio as gr
 
8
  import random
9
  import spaces
10
  import torch
11
+ from diffusers import Flux2Pipeline, Flux2Transformer2DModel
12
+ from diffusers import BitsAndBytesConfig as DiffBitsAndBytesConfig
13
+ import requests
14
  from PIL import Image
15
+ import json
16
  import base64
17
  from huggingface_hub import InferenceClient
18
 
19
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "spaces==0.43.0"])
20
+
21
  dtype = torch.bfloat16
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
 
51
 
52
  Output only the final instruction in plain text and nothing else."""
53
 
54
+ def remote_text_encoder(prompts):
55
+ from gradio_client import Client
56
+
57
+ client = Client("multimodalart/mistral-text-encoder")
58
+ result = client.predict(
59
+ prompt=prompts,
60
+ api_name="/encode_text"
61
+ )
62
+
63
+ # Load returns a tensor, usually on CPU by default
64
+ prompt_embeds = torch.load(result[0])
65
+ return prompt_embeds
66
+
67
+ # Load model
68
  repo_id = "black-forest-labs/FLUX.2-dev"
69
 
70
+ dit = Flux2Transformer2DModel.from_pretrained(
71
+ repo_id,
72
+ subfolder="transformer",
73
+ torch_dtype=torch.bfloat16
74
+ )
75
+
76
  pipe = Flux2Pipeline.from_pretrained(
77
  repo_id,
78
+ text_encoder=None,
79
+ transformer=dit,
80
  torch_dtype=torch.bfloat16
81
  )
82
  pipe.to(device)
 
159
  return new_width, new_height
160
 
161
  # Updated duration function to match generate_image arguments (including progress)
162
+ def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
163
  num_images = 0 if image_list is None else len(image_list)
164
  step_duration = 1 + 0.8 * num_images
165
+ return max(65, num_inference_steps * step_duration + 10)
166
 
167
  @spaces.GPU(duration=get_duration)
168
+ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
169
+ # Move embeddings to GPU only when inside the GPU decorated function
170
+ prompt_embeds = prompt_embeds.to(device)
171
+
172
  generator = torch.Generator(device=device).manual_seed(seed)
173
+
174
  pipe_kwargs = {
175
+ "prompt_embeds": prompt_embeds,
176
  "image": image_list,
177
  "num_inference_steps": num_inference_steps,
178
  "guidance_scale": guidance_scale,
 
180
  "width": width,
181
  "height": height,
182
  }
183
+
184
  # Progress bar for the actual generation steps
185
  if progress:
186
+ progress(0, desc="Starting generation...")
187
+
188
  image = pipe(**pipe_kwargs).images[0]
189
  return image
190
 
 
211
  num_inference_steps = 30
212
  guidance_scale = 4.0
213
 
214
+ # Text Encoding (Network bound - No GPU needed)
215
+ progress(0.1, desc="Encoding prompt...")
216
+ prompt_embeds = remote_text_encoder(prompt)
217
+
218
+ # Image Generation (GPU bound)
219
+ progress(0.3, desc="Generating image...")
220
  image = generate_image(
221
+ prompt_embeds,
222
  None, # No input images
223
  width,
224
  height,