lxzcpro committed
Commit 9de67ae · 1 Parent(s): 03bafc0

code clean up

Files changed (5):
  1. app.py +6 -9
  2. src/matcher.py +7 -11
  3. src/painter.py +17 -17
  4. src/pipeline.py +11 -24
  5. src/segmenter.py +5 -5
app.py CHANGED

@@ -3,7 +3,7 @@ import numpy as np
 from src.pipeline import ObjectRemovalPipeline
 from src.utils import visualize_mask
 
-# Initialize pipeline once
+
 pipeline = ObjectRemovalPipeline()
 
 def ensure_uint8(image):
@@ -18,7 +18,7 @@ def step1_detect(image, text_query):
     if image is None or not text_query:
         return [], [], "Please upload image and enter text."
 
-    # Calls the new method in pipeline.py
+
     candidates, msg = pipeline.get_candidates(image, text_query)
 
     if not candidates:
@@ -26,11 +26,11 @@ def step1_detect(image, text_query):
 
     masks = [c['mask'] for c in candidates]
 
-    # Generate visualization for gallery
+
     gallery_imgs = []
     for i, mask in enumerate(masks):
         viz = visualize_mask(image, mask)
-        # Label with rank and score if available
+
         label = f"Option {i+1} (Score: {candidates[i].get('weighted_score', 0):.2f})"
         gallery_imgs.append((ensure_uint8(viz), label))
 
@@ -45,15 +45,13 @@ def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
 
     target_mask = masks[selected_idx]
 
-    # Calls the pipeline method
+
     result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)
 
     return ensure_uint8(result), "Success!"
 
-# CSS for cleaner UI
 css = """
 .gradio-container {min-height: 0px !important}
-/* Ensure images in gallery don't get cropped strictly */
 button.gallery-item {object-fit: contain !important}
 """
 
@@ -70,8 +68,7 @@ with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
     btn_detect = gr.Button("1. Detect Objects", variant="primary")
 
     with gr.Column(scale=1):
-        # FIXED: object_fit="contain" prevents cropping
-        # allow_preview=True lets you click to zoom
+
         gallery = gr.Gallery(
             label="Candidates (Select One)",
             columns=2,
 
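Note: step1_detect returns (gallery_images, masks, message) and step2_remove returns (result, message), so the two handlers slot into a two-step Blocks flow. A minimal wiring sketch, with state holders and component names that are illustrative rather than part of this commit:

import gradio as gr

with gr.Blocks() as demo:
    masks_state = gr.State([])    # masks returned by step1_detect
    selected_idx = gr.State(0)    # gallery index picked by the user

    image_in = gr.Image(label="Input")
    query_in = gr.Textbox(label="Object to remove")
    gallery = gr.Gallery(label="Candidates (Select One)", columns=2)
    status = gr.Textbox(label="Status")
    result_out = gr.Image(label="Result")
    prompt_in = gr.Textbox(label="Inpaint prompt (optional)")
    shadow = gr.Slider(0, 50, value=0, label="Shadow expansion")

    # Step 1: detect candidates and fill the gallery
    gr.Button("1. Detect Objects").click(
        step1_detect,
        inputs=[image_in, query_in],
        outputs=[gallery, masks_state, status],
    )

    # Remember which candidate was clicked in the gallery
    def on_select(evt: gr.SelectData):
        return evt.index

    gallery.select(on_select, outputs=[selected_idx])

    # Step 2: inpaint the selected candidate
    gr.Button("2. Remove").click(
        step2_remove,
        inputs=[image_in, masks_state, selected_idx, prompt_in, shadow],
        outputs=[result_out, status],
    )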
src/matcher.py CHANGED

@@ -7,19 +7,19 @@ from transformers import CLIPProcessor, CLIPModel
 class CLIPMatcher:
     def __init__(self, model_name='openai/clip-vit-large-patch14'):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        # Load directly to CPU first
+
         self.model = CLIPModel.from_pretrained(model_name).to("cpu")
         self.processor = CLIPProcessor.from_pretrained(model_name)
 
     def get_top_k_segments(self, image, segments, text_query, k=5):
         if not segments: return []
 
-        # 1. Clean Text
+
         ignore = ['remove', 'delete', 'erase', 'the', 'a', 'an']
         words = [w for w in text_query.lower().split() if w not in ignore]
         clean_text = " ".join(words) if words else text_query
 
-        # 2. Crop (CPU)
+
         pil_image = Image.fromarray(image)
         crops = []
         valid_segments = []
@@ -30,11 +30,11 @@ class CLIPMatcher:
         for seg in segments:
             if 'bbox' not in seg: continue
 
-            # Safe numpy cast
+
             bbox = np.array(seg['bbox']).astype(int)
             x1, y1, x2, y2 = bbox
 
-            # Adaptive Context Padding (30%)
+
             w_box, h_box = x2 - x1, y2 - y1
             pad_x = int(w_box * 0.3)
             pad_y = int(h_box * 0.3)
@@ -49,7 +49,7 @@ class CLIPMatcher:
 
         if not crops: return []
 
-        # 3. Inference (Brief GPU usage)
+
         try:
             self.model.to(self.device)
             inputs = self.processor(
@@ -58,17 +58,14 @@ class CLIPMatcher:
 
             with torch.no_grad():
                 outputs = self.model(**inputs)
-            # FIX: Use raw logits for meaningful scores.
-            # (Softmax forces sum=1, concealing bad matches)
             probs = outputs.logits_per_image.cpu().numpy().flatten()
         except Exception as e:
             print(f"CLIP Error: {e}")
             return []
         finally:
-            # Move back to CPU immediately
             self.model.to("cpu")
 
-        # 4. Score & Sort
+
         final_results = []
         for i, score in enumerate(probs):
             seg = valid_segments[i]
@@ -78,7 +75,6 @@ class CLIPMatcher:
             w, h = seg['bbox'][2]-seg['bbox'][0], seg['bbox'][3]-seg['bbox'][1]
             area_ratio = (w*h) / total_img_area
 
-            # Logits are roughly 15-30 range. Add small boost for area.
             weighted_score = float(score) + (area_ratio * 2.0)
 
             final_results.append({
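Note: the 30% padding above intentionally crops more than the raw bbox so CLIP sees the object with some surrounding context. A standalone sketch of that cropping step; the clamping to image bounds is not visible in these hunks, so it is an assumption:

import numpy as np
from PIL import Image

def crop_with_context(image: np.ndarray, bbox, pad_frac: float = 0.3) -> Image.Image:
    # Pad the box by 30% of its width/height so the crop keeps context,
    # clamping to the image bounds (assumed behavior, not shown in the diff).
    h, w = image.shape[:2]
    x1, y1, x2, y2 = np.array(bbox).astype(int)
    pad_x = int((x2 - x1) * pad_frac)
    pad_y = int((y2 - y1) * pad_frac)
    x1, y1 = max(0, x1 - pad_x), max(0, y1 - pad_y)
    x2, y2 = min(w, x2 + pad_x), min(h, y2 + pad_y)
    return Image.fromarray(image).crop((x1, y1, x2, y2))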
 
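Note: the removed comments documented the key scoring decision here. With a single text query, a softmax over logits_per_image always sums to 1 across crops, so even uniformly bad candidates look plausible; the raw logits (roughly 15-30 in practice, per the removed comment) keep absolute match strength, and a small area term favors larger objects. A toy illustration with made-up numbers:

import torch

logits = torch.tensor([[17.0], [18.5], [16.2]])   # one query, three crops
probs = logits.squeeze(1).softmax(dim=0)          # always sums to 1, hides bad matches
raw = logits.squeeze(1)                           # keeps absolute match strength

area_ratio = torch.tensor([0.05, 0.20, 0.01])     # bbox area / image area
weighted = raw + 2.0 * area_ratio                 # the area boost used above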
src/painter.py CHANGED

@@ -6,7 +6,7 @@ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
 class SDInpainter:
     def __init__(self, model_id="runwayml/stable-diffusion-inpainting"):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        # Use float16 for GPU to save VRAM and speed up inference
+
         self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
@@ -18,36 +18,36 @@ class SDInpainter:
     def inpaint(self, image, mask, prompt="background"):
         pil_image = Image.fromarray(image).convert('RGB')
 
-        # Dilate mask to ensure the object edge is covered
+
         mask = self._dilate_mask(mask)
         pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')
 
-        # 1. Keep aspect ratio, resize ensuring dimensions are multiples of 8
+
         w, h = pil_image.size
-        factor = 512 / max(w, h) # Scale based on the longest side
+        factor = 512 / max(w, h)
         new_w = int(w * factor) - (int(w * factor) % 8)
         new_h = int(h * factor) - (int(h * factor) % 8)
 
         resized_image = pil_image.resize((new_w, new_h), Image.LANCZOS)
         resized_mask = pil_mask.resize((new_w, new_h), Image.NEAREST)
 
-        # 2. Inpaint
+
         output = self.pipe(
             prompt=prompt,
-            negative_prompt="artifacts, low quality, distortion, object", # Add negative prompt for better quality
+            negative_prompt="artifacts, low quality, distortion, object",
             image=resized_image,
             mask_image=resized_mask,
             num_inference_steps=30,
             guidance_scale=7.5,
         ).images[0]
 
-        # 3. Resize back to original resolution
+
         result = output.resize((w, h), Image.LANCZOS)
 
         return np.array(result)
 
     def _dilate_mask(self, mask, kernel_size=9):
-        # Increased kernel size slightly for better blending
+
         import cv2
         kernel = np.ones((kernel_size, kernel_size), np.uint8)
         return cv2.dilate(mask, kernel, iterations=1)
@@ -56,24 +56,24 @@ class SDInpainter:
 class SDXLInpainter:
     def __init__(self, model_id="diffusers/stable-diffusion-xl-1.0-inpainting-0.1"):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        # Use float16
+
         self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
             model_id,
             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
-            variant="fp16", # Add variant for faster loading if available
+            variant="fp16",
             use_safetensors=True
         ).to(self.device)
 
         if self.device == "cuda":
-            self.pipe.enable_model_cpu_offload() # Saves VRAM effectively
+            self.pipe.enable_model_cpu_offload()
 
-    def inpaint(self, image, mask, prompt=""): # Default prompt changed to empty
+    def inpaint(self, image, mask, prompt=""):
         pil_image = Image.fromarray(image).convert('RGB')
 
-        # Increase kernel size to 15 or 20 to ensure no edge artifacts remain
+
         mask = self._dilate_mask(mask, kernel_size=15)
 
-        # Blur the mask slightly to make the transition smoother
+
         import cv2
         mask = cv2.GaussianBlur(mask, (21, 21), 0)
 
@@ -90,7 +90,7 @@ class SDXLInpainter:
 
         if not prompt or prompt == "background":
             final_prompt = "clean background, empty space, seamless texture, high quality"
-            # Lower guidance scale for background filling to rely more on image context
+
             guidance_scale = 4.5
         else:
             final_prompt = prompt
@@ -108,8 +108,8 @@ class SDXLInpainter:
             image=resized_image,
             mask_image=resized_mask,
             num_inference_steps=40,
-            guidance_scale=guidance_scale, # Dynamic guidance
-            strength=0.99, # High strength to ensure removal
+            guidance_scale=guidance_scale,
+            strength=0.99,
         ).images[0]
 
         result = output.resize((w, h), Image.LANCZOS)
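Note: both inpainters rescale so the longest side is about 512 px and both dimensions are multiples of 8, matching the 8x downsampling of the Stable Diffusion VAE. The same arithmetic as the hunk above, pulled out as a sketch (the helper name is illustrative):

from PIL import Image

def fit_to_sd(pil_image: Image.Image, target: int = 512):
    # Scale the longest side to ~target, then round both dimensions
    # down to multiples of 8 (required by the VAE's 8x downsampling).
    w, h = pil_image.size
    factor = target / max(w, h)
    new_w = int(w * factor) - (int(w * factor) % 8)
    new_h = int(h * factor) - (int(h * factor) % 8)
    return pil_image.resize((new_w, new_h), Image.LANCZOS), (w, h)

The output is later resized back to the original (w, h) with LANCZOS, as the diff shows.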
 
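Note: SDXLInpainter pairs a larger dilation (15 px kernel) with a 21x21 Gaussian blur so the mask feathers out instead of ending in a hard seam at the object boundary. A minimal sketch of that mask preparation, assuming a uint8 mask like the ones SAM2 produces above:

import cv2
import numpy as np

def soften_mask(mask: np.ndarray, dilate_px: int = 15, blur_px: int = 21) -> np.ndarray:
    # Dilate so the object's edge pixels fall inside the inpainted region,
    # then blur so the mask feathers out (blur_px must be odd for cv2).
    kernel = np.ones((dilate_px, dilate_px), np.uint8)
    mask = cv2.dilate(mask, kernel, iterations=1)
    return cv2.GaussianBlur(mask, (blur_px, blur_px), 0)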
src/pipeline.py CHANGED

@@ -2,15 +2,13 @@ import numpy as np
 import cv2
 import torch
 import gc
-# Note: We import classes but DO NOT instantiate them globally
+
 from .segmenter import YOLOWorldDetector, SAM2Predictor
 from .matcher import CLIPMatcher
 from .painter import SDXLInpainter
 
 class ObjectRemovalPipeline:
     def __init__(self):
-        print("Initializing Pipeline in LOW MEMORY mode...")
-        # No models loaded at startup!
         pass
 
     def _clear_ram(self):
@@ -19,32 +17,26 @@ class ObjectRemovalPipeline:
             torch.cuda.empty_cache()
 
     def get_candidates(self, image, text_query):
-        """
-        Step 1: Detect & Segment & Rank
-        Strategy: Load one model at a time, use it, then delete it.
-        """
+
         candidates = []
         box_candidates = []
 
-        # --- PHASE 1: YOLO (Detect) ---
-        print("Loading YOLO...")
+
        detector = YOLOWorldDetector()
        try:
            box_candidates = detector.detect(image, text_query)
        finally:
-            del detector # Delete model immediately
+            del detector
             self._clear_ram()
 
         if not box_candidates:
             return [], "No objects detected."
 
-        # --- PHASE 2: SAM2 (Segment) ---
-        print("Loading SAM2...")
+
         segmenter = SAM2Predictor()
         segments_to_score = []
         try:
             segmenter.set_image(image)
-            # Process top 3 boxes -> up to 9 masks
             for cand in box_candidates[:3]:
                 bbox = cand['bbox']
                 mask_variations = segmenter.predict_from_box(bbox)
@@ -56,14 +48,12 @@ class ObjectRemovalPipeline:
                     'label': f"{cand['label']} (Var {i+1})"
                 })
         finally:
-            # Critical cleanup for SAM2
-            if hasattr(segmenter, 'clear_memory'):
-                segmenter.clear_memory()
+
+            segmenter.clear_memory()
             del segmenter
             self._clear_ram()
 
-        # --- PHASE 3: CLIP (Rank) ---
-        print("Loading CLIP...")
+
         matcher = CLIPMatcher()
         ranked_candidates = []
         try:
@@ -80,10 +70,8 @@ class ObjectRemovalPipeline:
         return ranked_candidates, f"Found {len(ranked_candidates)} options."
 
     def inpaint_selected(self, image, selected_mask, inpaint_prompt="", shadow_expansion=0):
-        """
-        Step 2: Inpaint
-        """
-        # Shadow / Edge Logic (CPU ops)
+
+
         if shadow_expansion > 0:
             kernel_h = int(shadow_expansion * 1.5)
             kernel_w = int(shadow_expansion * 0.5)
@@ -95,8 +83,7 @@ class ObjectRemovalPipeline:
 
         result = None
 
-        # --- PHASE 4: SDXL (Inpaint) ---
-        print("Loading SDXL...")
+
         inpainter = SDXLInpainter()
         try:
             result = inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
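Note: every phase in this file follows the same low-memory discipline: construct the model, use it inside try/finally, drop the reference, then reclaim memory the way _clear_ram does (gc.collect plus torch.cuda.empty_cache). A generic sketch of the pattern; run_phase is illustrative, not a helper in this repo:

import gc
import torch

def run_phase(make_model, fn, *args):
    # Load one model, use it, and free it before the next model loads.
    model = make_model()
    try:
        return fn(model, *args)
    finally:
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()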
 
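Note on the shadow handling in inpaint_selected: the kernel is about three times taller than it is wide (1.5 vs 0.5 times shadow_expansion), so the mask grows mostly vertically, where a cast shadow usually sits. The dilation call itself is outside the visible hunks, so this reconstruction is an assumption:

import cv2
import numpy as np

def expand_for_shadow(mask: np.ndarray, shadow_expansion: int) -> np.ndarray:
    # Tall, narrow kernel: mostly vertical growth to cover ground shadows.
    # The actual dilation call is not shown in this diff; assumed here.
    kernel_h = max(1, int(shadow_expansion * 1.5))
    kernel_w = max(1, int(shadow_expansion * 0.5))
    kernel = np.ones((kernel_h, kernel_w), np.uint8)
    return cv2.dilate(mask, kernel, iterations=1)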
src/segmenter.py CHANGED

@@ -6,7 +6,7 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
 
 class YOLOWorldDetector:
     def __init__(self, model_name='yolov8s-worldv2.pt'):
-        # Initialize, but manage device carefully
+
         self.model = YOLO(model_name)
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -16,7 +16,7 @@ class YOLOWorldDetector:
 
         boxes = []
         try:
-            # FIX: Force CPU for text encoding to prevent RuntimeError
+
             self.model.to('cpu')
             self.model.set_classes([clean_text])
 
@@ -37,7 +37,7 @@ class YOLOWorldDetector:
         except Exception as e:
             print(f"YOLO Error: {e}")
         finally:
-            # Always offload after use
+
             self.model.to('cpu')
 
         boxes.sort(key=lambda x: x['score'], reverse=True)
@@ -57,7 +57,7 @@ class SAM2Predictor:
 
     def predict_from_box(self, bbox):
         box_input = np.array(bbox)[None, :]
-        # Multimask = True for variety
+
         masks, scores, logits = self.predictor.predict(
             point_coords=None,
             point_labels=None,
@@ -68,7 +68,7 @@ class SAM2Predictor:
         return [(m.astype(np.uint8), s) for m, s in sorted_results]
 
     def clear_memory(self):
-        # Critical for preventing memory leaks
+
         self.predictor.reset_predictor()
         self.predictor.model.to('cpu')
         del self.predictor
 
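Note: predict_from_box relies on SAM2's box prompting, where multimask_output=True yields three mask hypotheses per box; pipeline.py then takes up to 3 boxes, hence up to 9 candidate masks. A minimal sketch, assuming predictor is a SAM2ImagePredictor with set_image already called:

import numpy as np

box_input = np.array([50, 60, 200, 220])[None, :]   # x1, y1, x2, y2
masks, scores, logits = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=box_input,
    multimask_output=True,   # three mask hypotheses per box
)
# Sort hypotheses best-first by predicted quality, as predict_from_box does.
best_first = sorted(zip(masks, scores), key=lambda t: t[1], reverse=True)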