import uuid
from typing import Tuple, List

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
from transformers import pipeline, CLIPModel, CLIPProcessor
MARKDOWN = """
# Auto ⚡ ProPainter 🧑‍🎨

This is a demo for automatic removal of objects from videos using
[Segment Anything Model](https://github.com/facebookresearch/segment-anything),
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), and
[ProPainter](https://github.com/sczhou/ProPainter) combo.

- [x] Automated object masking using SAM + MetaCLIP
- [ ] Automated inpainting using ProPainter
- [ ] Automated ⚡ object masking using FastSAM + MetaCLIP
"""
START_FRAME = 0
END_FRAME = 10
TOTAL = END_FRAME - START_FRAME
MINIMUM_AREA = 0.01

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE)
CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")

def run_sam(frame: np.ndarray) -> sv.Detections:
    # convert from NumPy BGR to PIL RGB
    image = Image.fromarray(frame[:, :, ::-1])
    outputs = SAM_GENERATOR(image)
    mask = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)

def run_clip(frame: np.ndarray, text: List[str]) -> np.ndarray:
    # convert from NumPy BGR to PIL RGB
    image = Image.fromarray(frame[:, :, ::-1])
    inputs = CLIP_PROCESSOR(text=text, images=image, return_tensors="pt").to(DEVICE)
    outputs = CLIP_MODEL(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs.detach().cpu().numpy()
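
# gray_background pastes the segmented object onto a flat gray canvas so that the
# CLIP score reflects the object itself rather than its surroundings.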
def gray_background(image: np.ndarray, mask: np.ndarray, gray_value=128):
    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
    return np.where(mask[..., None], image, gray_color)

def filter_detections_by_area(frame: np.ndarray, detections: sv.Detections, minimum_area: float) -> sv.Detections:
    frame_width, frame_height = frame.shape[1], frame.shape[0]
    frame_area = frame_width * frame_height
    return detections[detections.area > minimum_area * frame_area]
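
# Prompt filtering is a binary zero-shot classification: each crop is scored against
# [f"a picture of {prompt}", "a picture of background"], and a detection is kept only
# if the prompt-class probability exceeds the confidence threshold from the UI slider.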
def filter_detections_by_prompt(frame: np.ndarray, detections: sv.Detections, prompt: str, confidence: float) -> sv.Detections:
    text = [f"a picture of {prompt}", "a picture of background"]
    filtering_mask = []
    for xyxy, mask in zip(detections.xyxy, detections.mask):
        crop = gray_background(
            image=sv.crop_image(image=frame, xyxy=xyxy),
            mask=sv.crop_image(image=mask, xyxy=xyxy))
        probs = run_clip(frame=crop, text=text)
        filtering_mask.append(probs[0][0] > confidence)
    # dtype=bool keeps the index a boolean mask even when no detections remain
    return detections[np.array(filtering_mask, dtype=bool)]

def mask_frame(frame: np.ndarray, prompt: str, confidence: float) -> np.ndarray:
    detections = run_sam(frame)
    detections = filter_detections_by_area(
        frame=frame, detections=detections, minimum_area=MINIMUM_AREA)
    detections = filter_detections_by_prompt(
        frame=frame, detections=detections, prompt=prompt, confidence=confidence)
    # converting set of masks to a single mask
    mask = np.any(detections.mask, axis=0).astype(np.uint8) * 255
    # converting single channel mask to 3 channel mask
    return np.repeat(mask[:, :, np.newaxis], 3, axis=2)
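
# mask_video writes the per-frame masks twice: as an MP4 returned to the UI and as
# numbered PNGs ("00000.png", ...) in a directory named after the run.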
def mask_video(source_video: str, prompt: str, confidence: float, name: str) -> str:
    video_info = sv.VideoInfo.from_video_path(source_video)
    frame_iterator = iter(sv.get_video_frames_generator(
        source_path=source_video, start=START_FRAME, end=END_FRAME))

    with sv.ImageSink(name, image_name_pattern="{:05d}.png") as image_sink:
        with sv.VideoSink(f"{name}.mp4", video_info=video_info) as video_sink:
            for _ in tqdm(range(TOTAL), desc="Masking frames"):
                frame = next(frame_iterator)
                annotated_frame = mask_frame(frame, prompt, confidence)
                video_sink.write_frame(annotated_frame)
                image_sink.save_image(annotated_frame)
    return f"{name}.mp4"
def process(
    source_video: str,
    prompt: str,
    confidence: float,
    progress=gr.Progress(track_tqdm=True)
) -> Tuple[str, str]:
    name = str(uuid.uuid4())
    masked_video = mask_video(source_video, prompt, confidence, name)
    return masked_video, masked_video

with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            source_video_player = gr.Video(
                label="Source video", source="upload", format="mp4")
            prompt_text = gr.Textbox(
                label="Prompt", value="person")
            confidence_slider = gr.Slider(
                label="Confidence", minimum=0.5, maximum=1.0, step=0.05, value=0.6)
            submit_button = gr.Button("Submit")
        with gr.Column():
            masked_video_player = gr.Video(label="Masked video")
            painted_video_player = gr.Video(label="Painted video")

    submit_button.click(
        process,
        inputs=[source_video_player, prompt_text, confidence_slider],
        outputs=[masked_video_player, painted_video_player])

demo.queue().launch(debug=False, show_error=True)
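
# Note: gr.Video(source="upload", ...) uses the Gradio 3.x API; Gradio 4 removed the
# `source` argument, so running this script as written likely requires an older
# gradio release (e.g. gradio<4, an assumption; no version is pinned here).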