# Hugging Face Space status at time of capture: Runtime error
# Standard library.
from typing import Dict

# Third-party.
import gradio as gr
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Hub checkpoint id, shared so the model and its processor always match.
_SAM_CHECKPOINT = "facebook/sam-vit-large"

# Prefer the GPU whenever PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Weights and pre/post-processing config are fetched once, at import time.
MODEL = SamModel.from_pretrained(_SAM_CHECKPOINT).to(DEVICE)
PROCESSOR = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
def inference(masked_image: Dict[str, Image.Image]) -> "Image.Image | None":
    """Return the uploaded image from the sketch-mode image input.

    Args:
        masked_image: Value delivered by the input component: a dict with
            the uploaded picture under ``'image'`` and, presumably, the
            drawn brush strokes under ``'mask'``. NOTE(review): Gradio is
            expected to pass ``None`` when Submit is clicked with no image
            uploaded; confirm for the pinned Gradio version.

    Returns:
        The uploaded image, unchanged, or ``None`` when there is no input
        (which leaves the output component empty instead of raising).
    """
    if masked_image is None:
        # Subscripting None would raise TypeError and surface as an error
        # toast; an empty result is the friendlier outcome.
        return None

    # TODO: feed the user's mask (``masked_image['mask']``) and the image
    # through PROCESSOR/MODEL to produce a segmentation. Until then the
    # mask is not read, so no work is spent resizing it.
    return masked_image["image"]
# Layout: one row holding a column with the sketchable input image and a
# Submit button, next to the output image.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # NOTE(review): `tool`, `brush_radius` and `brush_color` are
            # Gradio 3.x `gr.Image` options that Gradio 4 removed (sketching
            # moved to `gr.ImageEditor`, whose value has a different shape).
            # Either pin `gradio<4` for this app, or migrate this component
            # and `inference` together.
            input_image = gr.Image(
                image_mode="RGB",
                type="pil",
                tool="sketch",
                interactive=True,
                brush_radius=20.0,
                brush_color="#FFFFFF",
                height=500,
            )
            submit_button = gr.Button("Submit")
        output_image = gr.Image(image_mode="RGB", type="pil")

    # Event listeners are registered inside the Blocks context.
    submit_button.click(
        inference,
        inputs=[input_image],
        outputs=output_image,
    )

# Start the server only when run as a script (`python app.py`), so importing
# this module (tests, reuse, reload tooling) does not block on `launch()`.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)