Spaces:
Runtime error
Runtime error
| import requests | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from torchvision import transforms | |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
| from timm.data import create_transform | |
| from focalnet import FocalNet, build_transforms, build_transforms4display | |
# Download the human-readable ImageNet class names (one label per line).
# A timeout keeps a stalled connection from hanging startup indefinitely, and
# raise_for_status() fails loudly instead of treating an HTTP error page as labels.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
# splitlines() avoids the trailing empty entry and stray "\r" that split("\n")
# would leave behind.
labels = response.text.splitlines()
# ---------------------------------------------------------------------------
# Build the model
# ---------------------------------------------------------------------------
# Single-stage FocalNet with 12 blocks; classify_image() below reads the
# modulators of blocks 2, 5, 8 and 11 of this stage.
model = FocalNet(
    depths=[12],
    patch_size=16,
    embed_dim=768,
    focal_levels=[3],
    use_layerscale=True,
    use_postln=True,
)

# The weights are read from a local file. The same checkpoint was previously
# fetched from:
#   https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_iso_16.pth
# via torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True).
#
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch supports it).
checkpoint = torch.load("./focalnet_base_iso_16.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])

# Switch to inference mode.
model.eval()
# ---------------------------------------------------------------------------
# Build the data transforms
# ---------------------------------------------------------------------------
# Both pipelines are built for size 224 with center cropping:
#   * eval_transforms    -> tensor fed to the model in classify_image()
#   * display_transforms -> image that the heatmaps are blended onto
eval_transforms, display_transforms = (
    build_transforms(224, center_crop=True),
    build_transforms4display(224, center_crop=True),
)

# No module-level upsampler is built: classify_image() resizes each modulator
# map to the input resolution itself (a fixed scale_factor=16 bilinear
# nn.Upsample was used here at one point).
# Adapted from pytorch-grad-cam:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET,
                      image_weight: float = 0.5) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, with values no larger than 1.
    :param mask: The cam mask. It is scaled by 255 and cast to uint8 before the
        colormap is applied, so values are expected to lie in [0, 1].
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: Weight of the base image in the blend; the heatmap gets
        (1 - image_weight). The default of 0.5 blends both equally.
    :returns: The image with the cam overlay, as a uint8 array.
    :raises ValueError: If 'img' contains values above 1 or 'image_weight' is
        outside [0, 1].
    """
    # Validate the inputs first so bad arguments fail before any image work is done.
    if np.max(img) > 1:
        raise ValueError(
            "The input image should be np.float32 in the range [0, 1]")
    if not 0 <= image_weight <= 1:
        raise ValueError(
            f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        # applyColorMap produces BGR; convert so the channels line up with an RGB image.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    # Convex blend of two [0, 1] images, so the result needs no renormalization.
    cam = (1 - image_weight) * heatmap + image_weight * img
    return np.uint8(255 * cam)
def _modulator_heatmap(block, size):
    """Turn the modulator of one FocalNet block into a normalized heatmap.

    :param block: A block exposing ``modulation.modulator``. It is read after
        the forward pass; presumably the block caches it there -- confirm in focalnet.py.
    :param size: Target spatial size (height, width) for the heatmap.
    :returns: A float array of shape (height, width, 1) scaled to [0, 1]. The
        trailing axis is the batch dimension, which is 1 for a single image.
    """
    # L2 norm over dim 1, keeping that dim so the map can be resized as an image.
    magnitude = block.modulation.modulator.norm(2, 1, keepdim=True)
    # align_corners=False is what nn.Upsample(mode='bilinear') uses by default.
    magnitude = nn.functional.interpolate(
        magnitude, size=size, mode="bilinear", align_corners=False)
    heat = magnitude.squeeze(1).detach().permute(1, 2, 0).numpy()

    low, high = heat.min(), heat.max()
    span = high - low
    if span > 0:
        return (heat - low) / span
    # A constant map has no contrast; return zeros instead of dividing by zero,
    # which would fill the mask with NaNs.
    return np.zeros_like(heat)


def classify_image(inp):
    """Classify an image and visualize the modulators of four FocalNet blocks.

    :param inp: The input image, in whatever form ``eval_transforms`` and
        ``display_transforms`` accept.
    :returns: A 6-tuple: a dict mapping each label to its softmax probability,
        four PIL images with the modulator heatmaps of blocks 11, 8, 5 and 2
        (in that order) overlaid on the display image, and the display image
        itself as a PIL image.
    """
    img_t = eval_transforms(inp)
    img_d = display_transforms(inp).permute(1, 2, 0).numpy()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        prediction = model(img_t.unsqueeze(0)).softmax(-1).flatten()

    blocks = model.layers[0].blocks
    size = tuple(img_t.shape[1:])
    # Zero-based block indices; the UI captions them as layers 12, 9, 6 and 3.
    overlays = [
        Image.fromarray(
            show_cam_on_image(
                img_d, _modulator_heatmap(blocks[index], size), use_rgb=True))
        for index in (11, 8, 5, 2)
    ]

    # One entry per predicted class; indexing into labels still fails loudly
    # if the label list is shorter than the model's output.
    confidences = {
        labels[i]: score for i, score in enumerate(prediction.tolist())
    }
    return (confidences, *overlays, Image.fromarray(np.uint8(255 * img_d)))
# Gradio UI. The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3
# and removed in Gradio 4, so the unified components are used instead.
# gr.Image() defaults to type="numpy", matching the old gr.inputs.Image() default.
image = gr.Image()
label = gr.Label(num_top_classes=3)

# One output per value returned by classify_image, in the same order:
# class confidences, four modulator overlays, then the cropped input.
demo = gr.Interface(
    description="Image classification and visualizations with FocalNet (https://github.com/microsoft/FocalNet)",
    fn=classify_image,
    inputs=image,
    outputs=[
        label,
        gr.Image(type="pil", label="Modulator at layer 12"),
        gr.Image(type="pil", label="Modulator at layer 9"),
        gr.Image(type="pil", label="Modulator at layer 6"),
        gr.Image(type="pil", label="Modulator at layer 3"),
        gr.Image(type="pil", label="Cropped Input"),
    ],
    examples=[["./donut.png"], ["./horses.png"], ["./pencil.png"]],
)
demo.launch()