Spaces:
Running
Running
| import numpy as np | |
| import cv2 | |
| import torch | |
| import gc | |
| from .segmenter import YOLOWorldDetector, SAM2Predictor | |
| from .matcher import CLIPMatcher | |
| from .painter import SDXLInpainter | |
class ObjectRemovalPipeline:
    """Detect, segment, rank and inpaint objects selected by a text query.

    Every stage constructs its model wrapper on demand and drops the
    reference again (followed by a garbage-collection / CUDA-cache flush)
    before the next stage starts, so the pipeline itself holds no state.
    """

    # Only the first few detection boxes are forwarded to the segmenter.
    _MAX_BOXES = 3
    # The shadow-expansion kernel is taller than it is wide.
    _SHADOW_HEIGHT_RATIO = 1.5
    _SHADOW_WIDTH_RATIO = 0.5
    # Side length of the square kernel that always pads the final mask.
    _EDGE_KERNEL_SIZE = 10

    def __init__(self):
        """Create a pipeline; models are built lazily inside each method."""

    def _clear_ram(self):
        """Helper to force clear RAM & VRAM."""
        gc.collect()
        torch.cuda.empty_cache()

    @classmethod
    def _shadow_kernel_shape(cls, shadow_expansion):
        """Return the ``(height, width)`` of the shadow dilation kernel.

        Both dimensions are clamped to at least 1: truncating with ``int``
        yields 0 for small inputs (e.g. width for ``shadow_expansion=1``),
        and a zero-sized dimension would make ``np.ones`` return an empty
        array instead of the intended tall, narrow kernel.
        """
        height = max(1, int(shadow_expansion * cls._SHADOW_HEIGHT_RATIO))
        width = max(1, int(shadow_expansion * cls._SHADOW_WIDTH_RATIO))
        return height, width

    def get_candidates(self, image, text_query):
        """Find mask candidates in *image* that match *text_query*.

        Runs three stages in sequence: box detection, per-box mask
        prediction, and ranking of all predicted masks against the query.

        Args:
            image: Image handed unchanged to the detector, segmenter and
                matcher (expected format is defined by those wrappers).
            text_query: Text describing the object to look for.

        Returns:
            A ``(candidates, message)`` tuple. ``candidates`` is whatever
            ``CLIPMatcher.get_top_k_segments`` returns for the collected
            segments, or ``[]`` when the detector reports nothing;
            ``message`` is a short human-readable status string.
        """
        detector = YOLOWorldDetector()
        try:
            box_candidates = detector.detect(image, text_query)
        finally:
            # Drop the detector before the next model is constructed.
            del detector
            self._clear_ram()

        if not box_candidates:
            return [], "No objects detected."

        segmenter = SAM2Predictor()
        segments_to_score = []
        try:
            segmenter.set_image(image)
            for cand in box_candidates[:self._MAX_BOXES]:
                bbox = cand['bbox']
                mask_variations = segmenter.predict_from_box(bbox)
                # Each variation is a (mask, score) pair; the score is not
                # used here because ranking is delegated to the matcher.
                for variant_no, (mask, _sam_score) in enumerate(mask_variations, start=1):
                    segments_to_score.append({
                        'mask': mask,
                        'bbox': bbox,
                        'area': mask.sum(),
                        'label': f"{cand['label']} (Var {variant_no})",
                    })
        finally:
            segmenter.clear_memory()
            del segmenter
            self._clear_ram()

        matcher = CLIPMatcher()
        try:
            # k equals the number of segments, so nothing is filtered out;
            # the matcher is used purely to order the candidates.
            ranked_candidates = matcher.get_top_k_segments(
                image,
                segments_to_score,
                text_query,
                k=len(segments_to_score),
            )
        finally:
            del matcher
            self._clear_ram()

        return ranked_candidates, f"Found {len(ranked_candidates)} options."

    def inpaint_selected(self, image, selected_mask, inpaint_prompt="", shadow_expansion=0):
        """Inpaint the region of *image* covered by *selected_mask*.

        Args:
            image: Image handed unchanged to ``SDXLInpainter.inpaint``.
            selected_mask: Array supporting ``.astype`` (e.g. a NumPy mask)
                marking the region to remove.
            inpaint_prompt: Prompt forwarded to the inpainter.
            shadow_expansion: When positive, the mask is first dilated with
                a tall, narrow kernel sized from this value (height 1.5x,
                width 0.5x, each at least 1 pixel).

        Returns:
            Whatever ``SDXLInpainter.inpaint`` returns.
        """
        mask = selected_mask.astype(np.uint8)

        if shadow_expansion > 0:
            shadow_kernel = np.ones(self._shadow_kernel_shape(shadow_expansion), np.uint8)
            mask = cv2.dilate(mask, shadow_kernel, iterations=1)

        # Always grow the mask a little so the object's outline is covered.
        edge_kernel = np.ones((self._EDGE_KERNEL_SIZE, self._EDGE_KERNEL_SIZE), np.uint8)
        final_mask = cv2.dilate(mask, edge_kernel, iterations=1)

        inpainter = SDXLInpainter()
        try:
            result = inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
        finally:
            del inpainter
            self._clear_ram()

        return result