code clean up
- app.py +6 -9
- src/matcher.py +7 -11
- src/painter.py +17 -17
- src/pipeline.py +11 -24
- src/segmenter.py +5 -5
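
The commit is mechanical cleanup: it removes the startup and per-stage debug print() calls, stale comments, and two short docstrings, and normalizes blank lines. Where a -/+ pair below shows identical text, the change appears to be whitespace-only or a trailing comment that the diff view does not surface; no behavioral change is evident in any hunk.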
app.py CHANGED

@@ -3,7 +3,7 @@ import numpy as np
 from src.pipeline import ObjectRemovalPipeline
 from src.utils import visualize_mask
 
-
+
 pipeline = ObjectRemovalPipeline()
 
 def ensure_uint8(image):
@@ -18,7 +18,7 @@ def step1_detect(image, text_query):
     if image is None or not text_query:
         return [], [], "Please upload image and enter text."
 
-
+
     candidates, msg = pipeline.get_candidates(image, text_query)
 
     if not candidates:
@@ -26,11 +26,11 @@ def step1_detect(image, text_query):
 
     masks = [c['mask'] for c in candidates]
 
-
+
     gallery_imgs = []
     for i, mask in enumerate(masks):
         viz = visualize_mask(image, mask)
-
+
         label = f"Option {i+1} (Score: {candidates[i].get('weighted_score', 0):.2f})"
         gallery_imgs.append((ensure_uint8(viz), label))
 
@@ -45,15 +45,13 @@ def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
 
     target_mask = masks[selected_idx]
 
-
+
     result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)
 
     return ensure_uint8(result), "Success!"
 
-# CSS for cleaner UI
 css = """
 .gradio-container {min-height: 0px !important}
-/* Ensure images in gallery don't get cropped strictly */
 button.gallery-item {object-fit: contain !important}
 """
 
@@ -70,8 +68,7 @@ with gr.Blocks(title="TextEraser", css=css, theme=gr.themes.Soft()) as demo:
             btn_detect = gr.Button("1. Detect Objects", variant="primary")
 
         with gr.Column(scale=1):
-
-            # allow_preview=True lets you click to zoom
+
             gallery = gr.Gallery(
                 label="Candidates (Select One)",
                 columns=2,
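
app.py funnels every image through ensure_uint8 before handing it to Gradio, but the helper's body falls outside the hunks above. A minimal sketch of what such a helper typically does, assuming float images arrive in [0, 1] (the body below is illustrative, not the Space's actual code):

import numpy as np

def ensure_uint8(image):
    # Hypothetical body: gr.Gallery and gr.Image expect HxWx3 uint8
    # arrays, while diffusion outputs are often float32 in [0, 1].
    arr = np.asarray(image)
    if arr.dtype == np.uint8:
        return arr
    if np.issubdtype(arr.dtype, np.floating) and arr.max() <= 1.0:
        arr = arr * 255.0
    return np.clip(arr, 0, 255).astype(np.uint8)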
src/matcher.py CHANGED

@@ -7,19 +7,19 @@ from transformers import CLIPProcessor, CLIPModel
 class CLIPMatcher:
     def __init__(self, model_name='openai/clip-vit-large-patch14'):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
         self.model = CLIPModel.from_pretrained(model_name).to("cpu")
         self.processor = CLIPProcessor.from_pretrained(model_name)
 
     def get_top_k_segments(self, image, segments, text_query, k=5):
         if not segments: return []
 
-
+
         ignore = ['remove', 'delete', 'erase', 'the', 'a', 'an']
         words = [w for w in text_query.lower().split() if w not in ignore]
         clean_text = " ".join(words) if words else text_query
 
-
+
         pil_image = Image.fromarray(image)
         crops = []
         valid_segments = []
@@ -30,11 +30,11 @@
         for seg in segments:
             if 'bbox' not in seg: continue
 
-
+
             bbox = np.array(seg['bbox']).astype(int)
             x1, y1, x2, y2 = bbox
 
-
+
             w_box, h_box = x2 - x1, y2 - y1
             pad_x = int(w_box * 0.3)
             pad_y = int(h_box * 0.3)
@@ -49,7 +49,7 @@
 
         if not crops: return []
 
-
+
         try:
             self.model.to(self.device)
             inputs = self.processor(
@@ -58,17 +58,14 @@
 
             with torch.no_grad():
                 outputs = self.model(**inputs)
-            # FIX: Use raw logits for meaningful scores.
-            # (Softmax forces sum=1, concealing bad matches)
             probs = outputs.logits_per_image.cpu().numpy().flatten()
         except Exception as e:
             print(f"CLIP Error: {e}")
             return []
         finally:
-            # Move back to CPU immediately
             self.model.to("cpu")
 
-
+
         final_results = []
         for i, score in enumerate(probs):
             seg = valid_segments[i]
@@ -78,7 +75,6 @@
             w, h = seg['bbox'][2]-seg['bbox'][0], seg['bbox'][3]-seg['bbox'][1]
             area_ratio = (w*h) / total_img_area
 
-            # Logits are roughly 15-30 range. Add small boost for area.
             weighted_score = float(score) + (area_ratio * 2.0)
 
             final_results.append({
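
The deleted comments in this file carried real rationale worth restating: scores come from raw logits_per_image (CLIP ViT-L logits land roughly in the 15-30 range) rather than softmax, because softmax forces the scores to sum to 1 and conceals the case where every candidate is a bad match; the final ranking then adds a small area bonus (area_ratio * 2.0). The try/finally structure implements a GPU shuttle, sketched below as a standalone helper (the function name and standalone form are illustrative, not the repo's API):

import torch

def forward_on_gpu(model, inputs, device="cuda"):
    # Keep the model on the CPU between calls and borrow the GPU for a
    # single forward pass, so VRAM stays free for other pipeline stages.
    try:
        model.to(device)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            return model(**inputs)
    finally:
        model.to("cpu")  # runs even if the forward pass raises
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # hand cached blocks back to the driver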
src/painter.py CHANGED

@@ -6,7 +6,7 @@ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
 class SDInpainter:
     def __init__(self, model_id="runwayml/stable-diffusion-inpainting"):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
         self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
@@ -18,36 +18,36 @@
     def inpaint(self, image, mask, prompt="background"):
         pil_image = Image.fromarray(image).convert('RGB')
 
-
+
         mask = self._dilate_mask(mask)
         pil_mask = Image.fromarray((mask * 255).astype(np.uint8)).convert('L')
 
-
+
         w, h = pil_image.size
-        factor = 512 / max(w, h)
+        factor = 512 / max(w, h)
         new_w = int(w * factor) - (int(w * factor) % 8)
         new_h = int(h * factor) - (int(h * factor) % 8)
 
         resized_image = pil_image.resize((new_w, new_h), Image.LANCZOS)
         resized_mask = pil_mask.resize((new_w, new_h), Image.NEAREST)
 
-
+
         output = self.pipe(
             prompt=prompt,
-            negative_prompt="artifacts, low quality, distortion, object",
+            negative_prompt="artifacts, low quality, distortion, object",
             image=resized_image,
             mask_image=resized_mask,
             num_inference_steps=30,
             guidance_scale=7.5,
         ).images[0]
 
-
+
         result = output.resize((w, h), Image.LANCZOS)
 
         return np.array(result)
 
     def _dilate_mask(self, mask, kernel_size=9):
-
+
         import cv2
         kernel = np.ones((kernel_size, kernel_size), np.uint8)
         return cv2.dilate(mask, kernel, iterations=1)
@@ -56,24 +56,24 @@
 class SDXLInpainter:
     def __init__(self, model_id="diffusers/stable-diffusion-xl-1.0-inpainting-0.1"):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
         self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
             model_id,
             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
-            variant="fp16",
+            variant="fp16",
             use_safetensors=True
         ).to(self.device)
 
         if self.device == "cuda":
-            self.pipe.enable_model_cpu_offload()
+            self.pipe.enable_model_cpu_offload()
 
-    def inpaint(self, image, mask, prompt=""):
+    def inpaint(self, image, mask, prompt=""):
         pil_image = Image.fromarray(image).convert('RGB')
 
-
+
         mask = self._dilate_mask(mask, kernel_size=15)
 
-
+
         import cv2
         mask = cv2.GaussianBlur(mask, (21, 21), 0)
 
@@ -90,7 +90,7 @@
 
         if not prompt or prompt == "background":
             final_prompt = "clean background, empty space, seamless texture, high quality"
-
+
             guidance_scale = 4.5
         else:
             final_prompt = prompt
@@ -108,8 +108,8 @@
             image=resized_image,
             mask_image=resized_mask,
             num_inference_steps=40,
-            guidance_scale=guidance_scale,
-            strength=0.99,
+            guidance_scale=guidance_scale,
+            strength=0.99,
         ).images[0]
 
         result = output.resize((w, h), Image.LANCZOS)
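
The resize arithmetic in SDInpainter.inpaint (context lines above) is the subtle part: Stable Diffusion's VAE downsamples by a factor of 8, so the image is scaled to put its longer side at 512 and both sides are then snapped down to the nearest multiple of 8. The same arithmetic as a standalone helper (hypothetical name, identical math):

def fit_for_sd(w, h, target=512):
    # Scale so the longer side equals `target`, then round each side
    # down to a multiple of 8 to satisfy the VAE's 8x downsampling.
    factor = target / max(w, h)
    new_w = int(w * factor) - (int(w * factor) % 8)
    new_h = int(h * factor) - (int(h * factor) % 8)
    return new_w, new_h

print(fit_for_sd(1920, 1080))  # (512, 288)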
src/pipeline.py CHANGED

@@ -2,15 +2,13 @@ import numpy as np
 import cv2
 import torch
 import gc
-
+
 from .segmenter import YOLOWorldDetector, SAM2Predictor
 from .matcher import CLIPMatcher
 from .painter import SDXLInpainter
 
 class ObjectRemovalPipeline:
     def __init__(self):
-        print("Initializing Pipeline in LOW MEMORY mode...")
-        # No models loaded at startup!
         pass
 
     def _clear_ram(self):
@@ -19,32 +17,26 @@
         torch.cuda.empty_cache()
 
     def get_candidates(self, image, text_query):
-        """
-        Step 1: Detect & Segment & Rank
-        Strategy: Load one model at a time, use it, then delete it.
-        """
+
         candidates = []
         box_candidates = []
 
-
-        print("Loading YOLO...")
+
         detector = YOLOWorldDetector()
         try:
             box_candidates = detector.detect(image, text_query)
         finally:
-            del detector
+            del detector
             self._clear_ram()
 
         if not box_candidates:
             return [], "No objects detected."
 
-
-        print("Loading SAM2...")
+
         segmenter = SAM2Predictor()
         segments_to_score = []
         try:
             segmenter.set_image(image)
-            # Process top 3 boxes -> up to 9 masks
             for cand in box_candidates[:3]:
                 bbox = cand['bbox']
                 mask_variations = segmenter.predict_from_box(bbox)
@@ -56,14 +48,12 @@
                     'label': f"{cand['label']} (Var {i+1})"
                 })
         finally:
-
-
-            segmenter.clear_memory()
+
+            segmenter.clear_memory()
             del segmenter
             self._clear_ram()
 
-
-        print("Loading CLIP...")
+
         matcher = CLIPMatcher()
         ranked_candidates = []
         try:
@@ -80,10 +70,8 @@
         return ranked_candidates, f"Found {len(ranked_candidates)} options."
 
     def inpaint_selected(self, image, selected_mask, inpaint_prompt="", shadow_expansion=0):
-        """
-
-        """
-        # Shadow / Edge Logic (CPU ops)
+
+
         if shadow_expansion > 0:
             kernel_h = int(shadow_expansion * 1.5)
             kernel_w = int(shadow_expansion * 0.5)
@@ -95,8 +83,7 @@
 
         result = None
 
-
-        print("Loading SDXL...")
+
         inpainter = SDXLInpainter()
         try:
             result = inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
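
The deleted docstring described the strategy this file still follows: load one model at a time, use it, then delete it. Nothing is loaded in __init__; each stage constructs its model, runs it inside try, and tears it down in finally so at most one heavyweight model is resident. A minimal sketch of that shape (the clear_ram body assumes a gc.collect() precedes the torch.cuda.empty_cache() call visible in the diff context; run_stage is a hypothetical helper, not the repo's API):

import gc
import torch

def clear_ram():
    # Drop unreferenced objects, then release CUDA's cached allocations.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def run_stage(make_model, use):
    # One model resident at a time: build, use, and always tear down,
    # even when `use` raises.
    model = make_model()
    try:
        return use(model)
    finally:
        del model
        clear_ram()

# e.g.: boxes = run_stage(YOLOWorldDetector, lambda d: d.detect(image, query))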
src/segmenter.py CHANGED

@@ -6,7 +6,7 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
 
 class YOLOWorldDetector:
     def __init__(self, model_name='yolov8s-worldv2.pt'):
-
+
         self.model = YOLO(model_name)
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -16,7 +16,7 @@
 
         boxes = []
         try:
-
+
             self.model.to('cpu')
             self.model.set_classes([clean_text])
 
@@ -37,7 +37,7 @@
         except Exception as e:
             print(f"YOLO Error: {e}")
         finally:
-
+
             self.model.to('cpu')
 
         boxes.sort(key=lambda x: x['score'], reverse=True)
@@ -57,7 +57,7 @@ class SAM2Predictor:
 
     def predict_from_box(self, bbox):
         box_input = np.array(bbox)[None, :]
-
+
         masks, scores, logits = self.predictor.predict(
             point_coords=None,
             point_labels=None,
@@ -68,7 +68,7 @@
         return [(m.astype(np.uint8), s) for m, s in sorted_results]
 
     def clear_memory(self):
-
+
         self.predictor.reset_predictor()
         self.predictor.model.to('cpu')
         del self.predictor
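
predict_from_box prompts SAM2 with a box and no point prompts, then returns the mask hypotheses best-first; the pipeline labels them "(Var 1)", "(Var 2)", and so on. The box= and multimask_output= arguments fall outside the visible hunk, so the sketch below fills them in as assumptions:

import numpy as np

def predict_from_box(predictor, bbox):
    # Box-only prompt: SAM2ImagePredictor accepts an (N, 4) xyxy array.
    box_input = np.array(bbox)[None, :]
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=box_input,          # assumed argument; outside the visible hunk
        multimask_output=True,  # assumed; yields the "Var" candidates
    )
    order = np.argsort(scores)[::-1]  # highest-confidence mask first
    return [(masks[i].astype(np.uint8), scores[i]) for i in order]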