Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import cv2 | |
| import gradio as gr | |
| import torch | |
| from ade20k_colors import colors | |
| from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation | |
# Hub checkpoint ids of the BEiT semantic-segmentation models offered by the
# demo. The list order matters: `inference` indexes `models` and `extractors`
# with the position chosen in the UI.
beit_models = [
    'microsoft/beit-base-finetuned-ade-640-640',
    'microsoft/beit-large-finetuned-ade-640-640',
]

# Load everything once at start-up so requests never pay the loading cost.
# `models[i]` and `extractors[i]` both correspond to `beit_models[i]`.
models = list(map(BeitForSemanticSegmentation.from_pretrained, beit_models))
extractors = list(map(BeitFeatureExtractor.from_pretrained, beit_models))
def apply_colors(img, palette=None):
    """Turn a per-pixel class-score array into a colour image.

    Args:
        img: Array indexed as ``img[y, x, class]``. For every pixel the
            class with the highest score is selected.
        palette: Optional sequence of colours indexed by class id. When
            omitted, the module-level ``colors`` table is used.

    Returns:
        A ``uint8`` array with one palette entry per pixel, i.e. shape
        ``(img.shape[0], img.shape[1], 3)`` when every palette entry is a
        3-component colour.
    """
    if palette is None:
        palette = colors

    # Lookup table: row i holds the colour for class id i.
    lut = np.asarray(palette, dtype=np.uint8)

    # One vectorised argmax over the class axis replaces the per-pixel
    # Python loop; fancy indexing then maps every class id to its colour.
    class_ids = np.argmax(img, axis=-1)
    return lut[class_ids]
def inference(image, chosen_model):
    """Segment ``image`` with the selected BEiT model and colour the result.

    Args:
        image: Input image as an array whose first two axes are
            (height, width); only ``image.shape`` and the feature
            extractor consume it.
        chosen_model: Index into the module-level ``models`` and
            ``extractors`` lists.

    Returns:
        The colour-coded segmentation map produced by ``apply_colors``,
        resized to the width and height of ``image``.
    """
    feature_extractor = extractors[chosen_model]
    model = models[chosen_model]

    inputs = feature_extractor(images=image, return_tensors='pt')

    # Inference only: disabling autograd avoids building a graph that is
    # never used and keeps memory usage down.
    with torch.no_grad():
        logits = model(**inputs).logits

    # apply_colors only looks at the per-pixel argmax, so the raw logits are
    # used directly. Squashing them through a sigmoid first adds nothing and
    # can saturate large values to 1.0, creating ties that change the
    # selected class.
    scores = logits[0].cpu().numpy()
    # Move the class axis last so the array is indexed as [y, x, class].
    scores = np.transpose(scores, (1, 2, 0))
    color_map = apply_colors(scores)

    # image.shape[1::-1] is (width, height), the order cv2.resize expects.
    # Nearest-neighbour keeps every output pixel an exact palette colour;
    # the default bilinear filter would blend colours at region borders.
    return cv2.resize(color_map, image.shape[1::-1],
                      interpolation=cv2.INTER_NEAREST)
# Input widgets, in the order of `inference`'s parameters: the image to
# segment, then a radio button whose selection is passed as a list index
# (type='index') rather than as its label.
inputs = [
    gr.inputs.Image(label='Input Image'),
    gr.inputs.Radio(['Base', 'Large'], label='BEiT Model', type='index'),
]

# Build the demo UI around `inference` and start serving it.
gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=gr.outputs.Image(label='Output'),
    title='BEiT - Semantic Segmentation',
    description='BEIT: BERT Pre-Training of Image Transformers',
    examples=[
        ['images/armchair.jpg', 'Base'],
        ['images/cat.jpg', 'Base'],
        ['images/plant.jpg', 'Large'],
    ],
).launch()