temp state
README.md CHANGED
@@ -1 +1,11 @@
-
+---
+title: SuperFeatures
+emoji: 📚
+colorFrom: red
+colorTo: yellow
+sdk: gradio
+app_file: app.py
+pinned: false
+---
+# Learning Super-Features for Image Retrieval
+A demo for the ICLR 22 paper "Learning Super-Features for Image Retrieval". [[Paper](https://openreview.net/pdf?id=wogsFPHwftY)] [[Official Github Repo](https://github.com/naver/fire)]
app.py CHANGED
@@ -3,6 +3,40 @@ import gradio as gr
 def greet(name):
     return "Hello " + name + "!!"
 
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
 
+# Model to use
+net_path = 'fire.pth'
+
+# CPU / GPU
+device = 'cpu'
+
+# Images will be downscaled to this size prior to processing with the network
+image_size = 1024
+
+# Wrapper
+def generate_matching_superfeatures(im1, im2, scale=6):
+
+    # Possible scales for multiscale inference
+    scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
+
+
+# GRADIO APP
+title = "Visualizing Super-features"
+description = "TBD"
+article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"
+
+
+iface = gr.Interface(
+    fn=generate_matching_superfeatures,
+    inputs=[
+        gr.inputs.Image(shape=(240, 240), type="pil"),
+        gr.inputs.Image(shape=(240, 240), type="pil"),
+        gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale")],
+    outputs="plot",
+    enable_queue=True,
+    title=title,
+    description=description,
+    article=article,
+    examples=[["chateau_1.png", "chateau_2.png", 6]],
+)
+iface.launch()
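Note that the wrapper above is left unfinished in this commit (hence the Space's build error): generate_matching_superfeatures only defines the list of scales and returns nothing for the "plot" output. Below is a minimal sketch of how it could be completed, picking up the net_path, device and image_size variables defined earlier in app.py. It assumes that fire.pth stores 'net_params' and 'state_dict' entries, as checkpoints from the official FIRe/HOW code do, and uses standard ImageNet normalization; both are assumptions, not part of this commit.

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms

import fire_network

# Rebuild the network from the checkpoint (assumption: 'fire.pth' stores
# 'net_params' and 'state_dict', as FIRe/HOW checkpoints do).
state = torch.load(net_path, map_location='cpu')
state['net_params']['pretrained'] = None  # weights come from the state dict below
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'], strict=False)
net.eval()

# Standard ImageNet preprocessing (assumed; not specified in this commit).
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def generate_matching_superfeatures(im1, im2, scale=6):
    """Extract Super-feature attention maps for both images at one scale and plot them."""
    scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
    s = scales[scale - 1]  # the slider is 1-indexed
    with torch.no_grad():
        _, attns1, _ = net.get_superfeatures(transform(im1).unsqueeze(0).to(device), scales=[s])
        _, attns2, _ = net.get_superfeatures(transform(im2).unsqueeze(0).to(device), scales=[s])
    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(attns1[0][0].sum(dim=0).cpu().numpy())
    axes[0].set_title("image 1: summed attention")
    axes[1].imshow(attns2[0][0].sum(dim=0).cpu().numpy())
    axes[1].set_title("image 2: summed attention")
    return fig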
fire_network.py ADDED
@@ -0,0 +1,130 @@
+# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import os
+import torch
+from torch import nn
+import torchvision
+
+from cirtorch.networks import imageretrievalnet
+
+from how import layers
+from how.layers import functional as HF
+
+from lit import LocalfeatureIntegrationTransformer
+
+from how.networks.how_net import HOWNet, CORERCF_SIZE
+
+class FIReNet(HOWNet):
+
+    def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
+        super().__init__(features, attention, None, dim_reduction, meta, runtime)
+        self.lit = lit
+        self.return_global = False
+
+    def copy_excluding_dim_reduction(self):
+        """Return a copy of this network without the dim_reduction layer"""
+        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
+        return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)
+
+    def copy_with_runtime(self, runtime):
+        """Return a copy of this network with a different runtime dict"""
+        return self.__class__(self.features, self.attention, self.lit, self.dim_reduction, self.meta, runtime)
+
+    def parameter_groups(self):
+        """Return torch parameter groups"""
+        layers = [self.features, self.attention, self.smoothing, self.lit]
+        parameters = [{'params': x.parameters()} for x in layers if x is not None]
+        if self.dim_reduction:
+            # Do not update dimensionality reduction layer
+            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
+        return parameters
+
+    def get_superfeatures(self, x, *, scales):
+        """
+        Return a tuple (features, attentionmaps, strengths) where each is a list over the requested scales
+        features is a tensor BxDxNx1
+        attentionmaps is a tensor BxNxHxW
+        """
+        feats = []
+        attns = []
+        strengths = []
+        for s in scales:
+            xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
+            o = self.features(xs)
+            o, attn = self.lit(o)
+            strength = self.attention(o)
+            if self.smoothing:
+                o = self.smoothing(o)
+            if self.dim_reduction:
+                o = self.dim_reduction(o)
+            feats.append(o)
+            attns.append(attn)
+            strengths.append(strength)
+        return feats, attns, strengths
+
+    def forward(self, x):
+        if self.return_global:
+            return self.forward_global(x, scales=self.runtime['training_scales'])
+        return self.get_superfeatures(x, scales=self.runtime['training_scales'])
+
+    def forward_global(self, x, *, scales):
+        """Return global descriptor"""
+        feats, _, strengths = self.get_superfeatures(x, scales=scales)
+        return HF.weighted_spoc(feats, strengths)
+
+    def forward_local(self, x, *, features_num, scales):
+        """Return selected super features"""
+        feats, _, strengths = self.get_superfeatures(x, scales=scales)
+        return HF.how_select_local(feats, strengths, scales=scales, features_num=features_num)
+
+def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
+    """Initialize FIRe network
+    :param str architecture: Network backbone architecture (e.g. resnet18)
+    :param str pretrained: url of the pretrained model (None for using random initialization)
+    :param int skip_layer: How many layers of blocks should be skipped (from the end)
+    :param dict dim_reduction: Options for the dimensionality reduction layer
+    :param dict lit: Options for the lit layer
+    :param dict runtime: Runtime options to be stored in the network
+    :return FIRe: Initialized network
+    """
+    # Take convolutional layers as features, always ends with ReLU to make last activations non-negative
+    net_in = getattr(torchvision.models, architecture)(pretrained=False)  # use trained weights including the LIT module instead
+    if architecture.startswith('alexnet') or architecture.startswith('vgg'):
+        features = list(net_in.features.children())[:-1]
+    elif architecture.startswith('resnet'):
+        features = list(net_in.children())[:-2]
+    elif architecture.startswith('densenet'):
+        features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
+    elif architecture.startswith('squeezenet'):
+        features = list(net_in.features.children())
+    else:
+        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))
+
+    if skip_layer > 0:
+        features = features[:-skip_layer]
+    backbone_dim = imageretrievalnet.OUTPUT_DIM[architecture] // (2 ** skip_layer)
+
+    att_layer = layers.attention.L2Attention()
+
+    lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)
+
+    reduction_layer = None
+    if dim_reduction:
+        reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])
+
+    meta = {
+        "architecture": architecture,
+        "backbone_dim": lit['dim'],
+        "outputdim": reduction_layer.out_channels if dim_reduction else lit['dim'],
+        "corercf_size": CORERCF_SIZE[architecture] // (2 ** skip_layer),
+    }
+    net = FIReNet(nn.Sequential(*features), att_layer, lit_layer, reduction_layer, meta, runtime)
+
+    if pretrained is not None:
+        assert os.path.isfile(pretrained), pretrained
+        ckpt = torch.load(pretrained, map_location='cpu')
+        missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
+        assert all(['dim_reduction' in a for a in missing]), "Loading did not go well"
+        assert all(['fc' in a for a in unexpected]), "Loading did not go well"
+    return net
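For reference, a small usage sketch of init_network and get_superfeatures follows. The backbone, LIT options and runtime values below are illustrative placeholders only; the trained model's actual configuration ships inside its checkpoint (net_params), not in this file.

import torch
import fire_network

# Illustrative options only; the trained model's real values come from its checkpoint.
net = fire_network.init_network(
    architecture='resnet50',
    pretrained=None,                      # random init; a state_dict would be loaded separately
    skip_layer=0,
    dim_reduction=None,                   # keep the LIT output dimension
    lit={'T': 3, 'N': 256, 'dim': 1024},  # iterations, number of Super-features, their dimension
    runtime={'training_scales': [1.0]},
)
net.eval()

x = torch.randn(1, 3, 224, 224)
feats, attns, strengths = net.get_superfeatures(x, scales=[1.0])
# feats[0]     : 1 x 1024 x 256 x 1  (one 1024-d vector per Super-feature)
# attns[0]     : 1 x 256 x 7 x 7     (each Super-feature's attention over locations)
# strengths[0] : per-Super-feature strength, used for selection and pooling

net.return_global = True  # forward() now aggregates Super-features into a global descriptor
g = net(x)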
lit.py ADDED
@@ -0,0 +1,92 @@
+# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import torch
+from torch import nn
+
+class LocalfeatureIntegrationTransformer(nn.Module):
+    """Map a set of local features to a fixed number of SuperFeatures"""
+
+    def __init__(self, T, N, input_dim, dim):
+        """
+        T: number of iterations
+        N: number of SuperFeatures
+        input_dim: dimension of input local features
+        dim: dimension of SuperFeatures
+        """
+        super().__init__()
+        self.T = T
+        self.N = N
+        self.input_dim = input_dim
+        self.dim = dim
+        # learnable initialization
+        self.templates_init = nn.Parameter(torch.randn(1, self.N, dim))
+        # qkv
+        self.project_q = nn.Linear(dim, dim, bias=False)
+        self.project_k = nn.Linear(input_dim, dim, bias=False)
+        self.project_v = nn.Linear(input_dim, dim, bias=False)
+        # layer norms
+        self.norm_inputs = nn.LayerNorm(input_dim)
+        self.norm_templates = nn.LayerNorm(dim)
+        # for the normalization
+        self.softmax = nn.Softmax(dim=-1)
+        self.scale = dim ** -0.5
+        # mlp
+        self.norm_mlp = nn.LayerNorm(dim)
+        mlp_dim = dim // 2
+        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
+
+
+    def forward(self, x):
+        """
+        input:
+            x has shape BxCxHxW
+        output:
+            templates (output SuperFeatures): tensor of shape BxDxNx1
+            attn (attention over local features at the last iteration): tensor of shape BxNxHxW
+        """
+        # reshape inputs from BxCxHxW to Bx(H*W)xC
+        B, C, H, W = x.size()
+        x = x.reshape(B, C, H * W).permute(0, 2, 1)
+
+        # k and v projection
+        x = self.norm_inputs(x)
+        k = self.project_k(x)
+        v = self.project_v(x)
+
+        # template initialization
+        templates = torch.repeat_interleave(self.templates_init, B, dim=0)
+        attn = None
+
+        # main iteration loop
+        for _ in range(self.T):
+            templates_prev = templates
+
+            # q projection
+            templates = self.norm_templates(templates)
+            q = self.project_q(templates)
+
+            # attention
+            q = q * self.scale  # Normalization.
+            attn_logits = torch.einsum('bnd,bld->bln', q, k)
+            attn = self.softmax(attn_logits)
+            attn = attn + 1e-8  # to avoid zeros with the L1 normalization below
+            attn = attn / attn.sum(dim=-2, keepdim=True)
+
+            # update template
+            templates = templates_prev + torch.einsum('bld,bln->bnd', v, attn)
+
+            # mlp
+            templates = templates + self.mlp(self.norm_mlp(templates))
+
+        # reshape templates to BxDxNx1
+        templates = templates.permute(0, 2, 1)[:, :, :, None]
+        attn = attn.permute(0, 2, 1).view(B, self.N, H, W)
+
+        return templates, attn
+
+    def __repr__(self):
+        s = str(self.__class__.__name__)
+        for k in ["T", "N", "input_dim", "dim"]:
+            s += "\n  {:s}: {:d}".format(k, getattr(self, k))
+        return s
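To check the tensor shapes this module produces, here is a minimal standalone example; the T, N, input_dim and dim values are arbitrary examples, not the trained configuration.

import torch
from lit import LocalfeatureIntegrationTransformer

# Example sizes only: 2048-d local features mapped to N=256 Super-features of dimension 1024.
lit_layer = LocalfeatureIntegrationTransformer(T=3, N=256, input_dim=2048, dim=1024)

x = torch.randn(2, 2048, 16, 16)  # B x C x H x W feature map from a backbone
templates, attn = lit_layer(x)

print(templates.shape)  # torch.Size([2, 1024, 256, 1]) -> B x dim x N x 1
print(attn.shape)       # torch.Size([2, 256, 16, 16])  -> B x N x H x W
print(lit_layer)        # __repr__ lists T, N, input_dim and dim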