# TextEraser/src/pipeline.py
# lxzcpro's picture
# code clean up
# 9de67ae
import numpy as np
import cv2
import torch
import gc
from .segmenter import YOLOWorldDetector, SAM2Predictor
from .matcher import CLIPMatcher
from .painter import SDXLInpainter
class ObjectRemovalPipeline:
    """Text-guided object removal: detect, segment, rank, then inpaint.

    Every stage constructs its model on demand and drops its reference
    (followed by a garbage-collection / CUDA-cache flush) before the next
    stage constructs its own.
    """

    # How many of the first detector boxes are forwarded to segmentation.
    MAX_BOX_CANDIDATES = 3
    # Side length of the square kernel that always pads the final mask.
    MASK_PADDING_KERNEL_SIZE = 10

    def __init__(self):
        """Create the pipeline; no models are loaded until a stage runs."""
        pass

    def _clear_ram(self):
        """Run the garbage collector and empty PyTorch's CUDA cache."""
        gc.collect()
        torch.cuda.empty_cache()

    def get_candidates(self, image, text_query):
        """Find and rank segmentation masks that match ``text_query``.

        Args:
            image: Input image; handed unchanged to the detector, the
                segmenter and the matcher.
            text_query: Text describing the object to look for.

        Returns:
            A ``(candidates, message)`` tuple. ``candidates`` is the list
            produced by ``CLIPMatcher.get_top_k_segments`` (empty when
            nothing was detected or segmented) and ``message`` is a short
            human-readable status string.
        """
        # Stage 1: box detection.
        detector = YOLOWorldDetector()
        try:
            box_candidates = detector.detect(image, text_query)
        finally:
            del detector
            self._clear_ram()

        if not box_candidates:
            return [], "No objects detected."

        # Stage 2: one or more mask variations per detected box.
        segments_to_score = []
        segmenter = SAM2Predictor()
        try:
            segmenter.set_image(image)
            for cand in box_candidates[:self.MAX_BOX_CANDIDATES]:
                bbox = cand['bbox']
                mask_variations = segmenter.predict_from_box(bbox)
                # The per-mask score is not used; ranking happens in stage 3.
                for index, (mask, _sam_score) in enumerate(mask_variations, start=1):
                    segments_to_score.append({
                        'mask': mask,
                        'bbox': bbox,
                        'area': mask.sum(),
                        'label': f"{cand['label']} (Var {index})",
                    })
        finally:
            segmenter.clear_memory()
            del segmenter
            self._clear_ram()

        # Nothing to rank: skip constructing the matcher for a k=0 request.
        if not segments_to_score:
            return [], "Found 0 options."

        # Stage 3: rank every segment against the text query.
        matcher = CLIPMatcher()
        try:
            ranked_candidates = matcher.get_top_k_segments(
                image,
                segments_to_score,
                text_query,
                k=len(segments_to_score),
            )
        finally:
            del matcher
            self._clear_ram()

        return ranked_candidates, f"Found {len(ranked_candidates)} options."

    @staticmethod
    def _shadow_kernel_shape(shadow_expansion):
        """Return the ``(height, width)`` of the shadow dilation kernel.

        The kernel is 1.5x the expansion tall and 0.5x the expansion wide,
        truncated to integers. Each side is clamped to at least 1 so that a
        small positive expansion never yields a zero-sized kernel.
        """
        kernel_h = max(1, int(shadow_expansion * 1.5))
        kernel_w = max(1, int(shadow_expansion * 0.5))
        return kernel_h, kernel_w

    def inpaint_selected(self, image, selected_mask, inpaint_prompt="", shadow_expansion=0):
        """Inpaint the region of ``image`` covered by ``selected_mask``.

        Args:
            image: Image handed to the inpainter.
            selected_mask: Mask of the region to remove; it must support
                ``astype`` and is converted to ``uint8`` before dilation.
            inpaint_prompt: Prompt forwarded to the inpainter.
            shadow_expansion: When greater than 0, the mask is first
                dilated with a tall, narrow kernel (1.5x by 0.5x of this
                value) before the fixed padding dilation is applied.

        Returns:
            Whatever ``SDXLInpainter.inpaint`` returns.
        """
        mask = selected_mask.astype(np.uint8)

        if shadow_expansion > 0:
            # Clamped shape: an empty structuring element is not a valid
            # explicit kernel for cv2.dilate.
            shadow_kernel = np.ones(self._shadow_kernel_shape(shadow_expansion), np.uint8)
            mask = cv2.dilate(mask, shadow_kernel, iterations=1)

        # Always grow the mask by a fixed margin before inpainting.
        padding = self.MASK_PADDING_KERNEL_SIZE
        padding_kernel = np.ones((padding, padding), np.uint8)
        final_mask = cv2.dilate(mask, padding_kernel, iterations=1)

        inpainter = SDXLInpainter()
        try:
            result = inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
        finally:
            del inpainter
            self._clear_ram()

        return result