Spaces:
Running
Running
File size: 3,759 Bytes
144afae c3f0641 144afae c3f0641 dec259d 144afae 03bafc0 144afae 03bafc0 144afae c3f0641 03bafc0 144afae 03bafc0 c3f0641 03bafc0 144afae 03bafc0 144afae 03bafc0 144afae c3f0641 03bafc0 144afae c3f0641 144afae dec259d 03bafc0 144afae 03bafc0 144afae 03bafc0 144afae 03bafc0 144afae 03bafc0 144afae 03bafc0 144afae c3f0641 dec259d c3f0641 dec259d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import gradio as gr
import numpy as np
import argparse
import os
from src.pipeline import ObjectRemovalPipeline
from src.utils import visualize_mask
try:
    import spaces
except ImportError:
    class spaces:  # noqa: N801 - lowercase on purpose: mimics the module name used at call sites
        """No-op stand-in for the Hugging Face ``spaces`` package when it is not installed.

        Only ``spaces.GPU`` is provided; it returns the decorated function
        unchanged so the app also runs locally without ZeroGPU.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            """Accept both ``@spaces.GPU`` and ``@spaces.GPU(duration=..., ...)``.

            Args:
                *args: Either the function being decorated (bare form) or the
                    positional arguments of the factory form (e.g. ``duration``).
                **kwargs: Factory-form options such as ``duration``; ignored.

            Returns:
                The function itself (bare form) or an identity decorator (factory form).
            """
            # Bare form: Python passes the function directly as the only argument.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(func):
                return func

            return decorator
# Build the object-removal pipeline once, when this module is imported; both
# Gradio handlers (detection and removal) share this single instance.
pipeline = ObjectRemovalPipeline()
def ensure_uint8(image):
    """Convert an image-like value to a ``uint8`` numpy array for display.

    Non-uint8 arrays whose maximum is <= 1.0 are treated as normalized and
    scaled by 255. Values are then clipped to [0, 255] and cast to uint8.
    uint8 input is returned as an unchanged copy.

    Args:
        image: Anything ``np.array`` accepts (ndarray, PIL image, nested
            lists), or None.

    Returns:
        A new ``uint8`` ndarray, or None when ``image`` is None.
    """
    if image is None:
        return None
    image = np.array(image)  # np.array copies by default, so the caller's data is never aliased
    if image.dtype == np.uint8:
        return image
    if image.size == 0:
        # max() raises on zero-size arrays; there is nothing to scale, only the dtype to fix.
        return image.astype(np.uint8)
    if np.issubdtype(image.dtype, np.floating):
        # A single NaN would make max() NaN, skipping the [0, 1] -> [0, 255]
        # rescale for the whole image and leaving NaN for an undefined uint8 cast.
        image = np.nan_to_num(image, nan=0.0)
    if image.max() <= 1.0:
        image = image * 255.0
    return np.clip(image, 0, 255).astype(np.uint8)
@spaces.GPU(duration=120)
def step1_detect(image, text_query):
    """Run text-prompted detection and build the candidate gallery.

    Args:
        image: Input image from the ``gr.Image`` component, or None.
        text_query: Free-text description of the object to remove.

    Returns:
        A ``(masks, gallery_items, status)`` tuple:
        - ``masks``: candidate masks, kept in session state for the removal step.
        - ``gallery_items``: ``(overlay_image, caption)`` pairs for the gallery.
        - ``status``: a message for the status box.
        On missing input, or when the pipeline returns no candidates, the first
        two elements are empty lists.
    """
    # A whitespace-only query describes nothing; reject it like an empty one.
    if image is None or not text_query or not text_query.strip():
        return [], [], "Please upload image and enter text."

    candidates, msg = pipeline.get_candidates(image, text_query)
    if not candidates:
        return [], [], f"Error: {msg}"

    masks = []
    gallery_imgs = []
    for option_no, candidate in enumerate(candidates, start=1):
        mask = candidate['mask']
        masks.append(mask)
        overlay = visualize_mask(image, mask)
        score = candidate.get('weighted_score', 0)
        gallery_imgs.append((ensure_uint8(overlay), f"Option {option_no} (Score: {score:.2f})"))
    return masks, gallery_imgs, "Select the best match below."
def on_select(evt: gr.SelectData):
    """Gallery ``select`` handler: return the position of the clicked candidate.

    Gradio supplies ``evt`` because of the ``gr.SelectData`` annotation; the
    returned ``evt.index`` becomes the new value of the event's output.
    """
    picked = evt.index
    return picked
@spaces.GPU(duration=120)
def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
    """Inpaint the region covered by the user's selected candidate mask.

    Args:
        image: Original input image, or None.
        masks: Candidate masks produced by the detection step.
        selected_idx: Gallery position chosen by the user, or None.
        prompt: Text describing the background to paint in.
        shadow_exp: Value forwarded to the pipeline as ``shadow_expansion``.

    Returns:
        ``(result_image, status)``; ``result_image`` is None whenever the
        status reports a problem.
    """
    # The index comes from an earlier gallery click and may be stale relative to
    # the current mask list (e.g. a re-detection returned fewer candidates).
    if not masks or selected_idx is None or not 0 <= selected_idx < len(masks):
        return None, "Please select an object first."
    if image is None:
        return None, "Please upload an image first."

    target_mask = masks[selected_idx]
    result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)
    # NOTE(review): assumes the pipeline may signal failure by returning None; don't report success then.
    if result is None:
        return None, "Error: inpainting produced no result."
    return ensure_uint8(result), "Success!"
# Extra stylesheet handed to launch(): let the container shrink to its content
# and keep gallery thumbnails letterboxed rather than cropped.
css = "\n".join(
    [
        "",
        ".gradio-container {min-height: 0px !important}",
        "button.gallery-item {object-fit: contain !important}",
        "",
    ]
)
with gr.Blocks(title="TextEraser") as demo:
    # Per-session state: candidate masks from the latest detection, and the
    # gallery position used for removal (0 = first candidate until the user clicks).
    mask_state = gr.State([])
    idx_state = gr.State(0)

    gr.Markdown("## TextEraser: Interactive Object Removal")

    # Row 1: input + detection on the left, candidate gallery + status on the right.
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="numpy", height=400)
            text_query = gr.Textbox(label="What to remove?", placeholder="e.g. 'bottle', 'shadow'")
            btn_detect = gr.Button("1. Detect Objects", variant="primary")
        with gr.Column(scale=1):
            gallery = gr.Gallery(
                label="Candidates (Select One)",
                columns=2,
                height=400,
                allow_preview=True,
                object_fit="contain",
            )
            status = gr.Textbox(label="Status", interactive=False)

    # Row 2: removal controls on the left, final result on the right.
    with gr.Row():
        with gr.Column(scale=1):
            shadow_slider = gr.Slider(0, 40, value=10, label="Shadow Fix (Expand Mask Downwards)")
            inpaint_prompt = gr.Textbox(label="Background Description", value="background")
            btn_remove = gr.Button("2. Remove Selected", variant="stop")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Final Result", height=400)

    detect_event = btn_detect.click(
        fn=step1_detect,
        inputs=[input_image, text_query],
        outputs=[mask_state, gallery, status],
    )
    # A new detection replaces mask_state, so an index picked from the previous
    # gallery could point past the end of the new list. Reset it to the same
    # value idx_state starts with.
    detect_event.then(fn=lambda: 0, inputs=None, outputs=idx_state)

    gallery.select(fn=on_select, inputs=None, outputs=idx_state)

    btn_remove.click(
        fn=step2_remove,
        inputs=[input_image, mask_state, idx_state, inpaint_prompt, shadow_slider],
        outputs=[output_image, status],
    )
if __name__ == "__main__":
    # Command-line entry point; `--share` is forwarded to Gradio's launch().
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true")
    options = cli.parse_args()

    # Enable request queuing on the Blocks app, then start serving it.
    app = demo.queue()
    app.launch(share=options.share, css=css)