lxzcpro commited on
Commit
dec259d
·
1 Parent(s): c3f0641
Files changed (2) hide show
  1. app.py +4 -18
  2. requirements.txt +3 -1
app.py CHANGED
@@ -5,8 +5,6 @@ import os
5
  from src.pipeline import ObjectRemovalPipeline
6
  from src.utils import visualize_mask
7
 
8
- # --- ZeroGPU Compatibility Shim ---
9
- # Allows code to run on local CPU/GPU without crashing on 'import spaces'
10
  try:
11
  import spaces
12
  except ImportError:
@@ -17,11 +15,10 @@ except ImportError:
17
  return func
18
  return decorator
19
 
20
- # Initialize pipeline (Models use lazy-loading to save memory)
21
  pipeline = ObjectRemovalPipeline()
22
 
23
  def ensure_uint8(image):
24
- """Normalize image to uint8 (0-255)"""
25
  if image is None: return None
26
  image = np.array(image)
27
  if image.dtype != np.uint8:
@@ -31,11 +28,9 @@ def ensure_uint8(image):
31
 
32
  @spaces.GPU(duration=120)
33
  def step1_detect(image, text_query):
34
- """Detect objects and return candidates for user selection"""
35
  if image is None or not text_query:
36
  return [], [], "Please upload image and enter text."
37
 
38
- # 1. Detect & Rank candidates via Pipeline
39
  candidates, msg = pipeline.get_candidates(image, text_query)
40
 
41
  if not candidates:
@@ -43,7 +38,6 @@ def step1_detect(image, text_query):
43
 
44
  masks = [c['mask'] for c in candidates]
45
 
46
- # 2. Visualize masks for Gallery
47
  gallery_imgs = []
48
  for i, mask in enumerate(masks):
49
  viz = visualize_mask(image, mask)
@@ -54,30 +48,25 @@ def step1_detect(image, text_query):
54
  return masks, gallery_imgs, "Select the best match below."
55
 
56
  def on_select(evt: gr.SelectData):
57
- """Capture user selection from Gallery"""
58
  return evt.index
59
 
60
  @spaces.GPU(duration=120)
61
  def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
62
- """Inpaint the selected mask"""
63
  if not masks or selected_idx is None:
64
  return None, "Please select an object first."
65
 
66
  target_mask = masks[selected_idx]
67
 
68
- # 3. Inpaint with Shadow Fix logic
69
  result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)
70
 
71
  return ensure_uint8(result), "Success!"
72
 
73
- # CSS for better layout and full image visibility in Gallery
74
  css = """
75
  .gradio-container {min-height: 0px !important}
76
  button.gallery-item {object-fit: contain !important}
77
  """
78
 
79
- with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
80
- # State to hold masks between steps
81
  mask_state = gr.State([])
82
  idx_state = gr.State(0)
83
 
@@ -90,7 +79,6 @@ with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
90
  btn_detect = gr.Button("1. Detect Objects", variant="primary")
91
 
92
  with gr.Column(scale=1):
93
- # Interactive Gallery (Adaptable size)
94
  gallery = gr.Gallery(
95
  label="Candidates (Select One)",
96
  columns=2,
@@ -109,7 +97,6 @@ with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
109
  with gr.Column(scale=1):
110
  output_image = gr.Image(label="Final Result", height=400)
111
 
112
- # Event Wiring
113
  btn_detect.click(
114
  fn=step1_detect,
115
  inputs=[input_image, text_query],
@@ -126,8 +113,7 @@ with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
126
 
127
  if __name__ == "__main__":
128
  parser = argparse.ArgumentParser()
129
- parser.add_argument("--share", action="store_true", help="Create a public link (Colab)")
130
  args = parser.parse_args()
131
 
132
- # queue() is required for ZeroGPU
133
- demo.queue().launch(share=args.share)
 
5
  from src.pipeline import ObjectRemovalPipeline
6
  from src.utils import visualize_mask
7
 
 
 
8
  try:
9
  import spaces
10
  except ImportError:
 
15
  return func
16
  return decorator
17
 
18
+ # Initialize pipeline
19
  pipeline = ObjectRemovalPipeline()
20
 
21
  def ensure_uint8(image):
 
22
  if image is None: return None
23
  image = np.array(image)
24
  if image.dtype != np.uint8:
 
28
 
29
  @spaces.GPU(duration=120)
30
  def step1_detect(image, text_query):
 
31
  if image is None or not text_query:
32
  return [], [], "Please upload image and enter text."
33
 
 
34
  candidates, msg = pipeline.get_candidates(image, text_query)
35
 
36
  if not candidates:
 
38
 
39
  masks = [c['mask'] for c in candidates]
40
 
 
41
  gallery_imgs = []
42
  for i, mask in enumerate(masks):
43
  viz = visualize_mask(image, mask)
 
48
  return masks, gallery_imgs, "Select the best match below."
49
 
50
  def on_select(evt: gr.SelectData):
 
51
  return evt.index
52
 
53
  @spaces.GPU(duration=120)
54
  def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
 
55
  if not masks or selected_idx is None:
56
  return None, "Please select an object first."
57
 
58
  target_mask = masks[selected_idx]
59
 
 
60
  result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)
61
 
62
  return ensure_uint8(result), "Success!"
63
 
 
64
  css = """
65
  .gradio-container {min-height: 0px !important}
66
  button.gallery-item {object-fit: contain !important}
67
  """
68
 
69
+ with gr.Blocks(title="TextEraser") as demo:
 
70
  mask_state = gr.State([])
71
  idx_state = gr.State(0)
72
 
 
79
  btn_detect = gr.Button("1. Detect Objects", variant="primary")
80
 
81
  with gr.Column(scale=1):
 
82
  gallery = gr.Gallery(
83
  label="Candidates (Select One)",
84
  columns=2,
 
97
  with gr.Column(scale=1):
98
  output_image = gr.Image(label="Final Result", height=400)
99
 
 
100
  btn_detect.click(
101
  fn=step1_detect,
102
  inputs=[input_image, text_query],
 
113
 
114
  if __name__ == "__main__":
115
  parser = argparse.ArgumentParser()
116
+ parser.add_argument("--share", action="store_true")
117
  args = parser.parse_args()
118
 
119
+ demo.queue().launch(share=args.share, css=css)
 
requirements.txt CHANGED
@@ -19,4 +19,6 @@ Pillow
19
  sniffio
20
  spaces
21
  clip
22
- git+https://github.com/facebookresearch/sam2.git
 
 
 
19
  sniffio
20
  spaces
21
  clip
22
+ git+https://github.com/facebookresearch/sam2.git
23
+ accelerate
24
+ ftfy