Commit
·
fe0eb36
1
Parent(s):
9d563c0
init
Browse files- .gitignore +2 -0
- Dockerfile +16 -0
- app.py +52 -0
- inference.py +56 -0
- pp/__init__.py +0 -0
- pp/albu.py +8 -0
- requirements.txt +13 -0
- utils/__init__.py +0 -0
- utils/ckpts.py +15 -0
- utils/dataset.py +51 -0
- utils/filterfunc.py +57 -0
- utils/knn.py +49 -0
- utils/predict.py +77 -0
- utils/utilfuncs.py +86 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
env
|
| 2 |
+
*.pyc
|
Dockerfile
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
| 2 |
+
# you will also find guides on how best to write your Dockerfile
|
| 3 |
+
|
| 4 |
+
FROM python:3.9
|
| 5 |
+
|
| 6 |
+
RUN useradd -m -u 1000 user
|
| 7 |
+
USER user
|
| 8 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 9 |
+
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
| 13 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 14 |
+
|
| 15 |
+
COPY --chown=user . /app
|
| 16 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, File, UploadFile, Form
|
| 2 |
+
import uvicorn
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
import nltk
|
| 6 |
+
nltk.download("stopwords")
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import List
|
| 9 |
+
from inference import inference
|
| 10 |
+
from code_base.utils import CFG
|
| 11 |
+
|
| 12 |
+
TKN_PATH= ["bert-base-uncased"]
|
| 13 |
+
IMG_SIZE = 256
|
| 14 |
+
BATCH_SIZE = 32
|
| 15 |
+
img = True
|
| 16 |
+
|
| 17 |
+
CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 18 |
+
|
| 19 |
+
app = FastAPI(title="shopee-test-app")
|
| 20 |
+
|
| 21 |
+
@app.get("/")
async def root():
    """Health-check endpoint: report that the service is up."""
    payload = {"status": "ok", "message": "Space is running"}
    return payload
|
| 24 |
+
|
| 25 |
+
@app.post("/predict")
async def predict_image(files: List[UploadFile] = File(...),
                        texts: List[str] = Form(...)):
    """Report whether the uploaded products are considered a match.

    Args:
        files: Uploaded product images; ``files[i]`` belongs to ``texts[i]``.
        texts: Product titles, one per uploaded image.

    Returns:
        A dict with a single ``"message"`` key holding either
        ``"products matched"`` or ``"products not matched"``.

    Raises:
        HTTPException: 400 when the number of images and titles differ.
    """
    # Local import so this block only relies on the already-used fastapi package.
    from fastapi import HTTPException

    # zip() would silently drop the unpaired tail, so reject mismatches early.
    if len(files) != len(texts):
        raise HTTPException(
            status_code=400,
            detail="files and texts must contain the same number of items",
        )

    li = [await file.read() for file in files]
    lt = list(texts)

    # NOTE(review): inference() is synchronous and runs inside this coroutine.
    res = inference(li=li,
                    lt=lt,
                    IMG_SIZE=IMG_SIZE,
                    TKN_PATH=TKN_PATH,
                    BATCH_SIZE=BATCH_SIZE)
    msg = "products matched" if res else "products not matched"

    return {"message": f"{msg}"}
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
uvicorn.run("app:app", host="0.0.0.0", port=8000)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
inference.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import torch
|
| 3 |
+
from utils.predict import predict
|
| 4 |
+
from utils.filterfunc import filter_match_titles
|
| 5 |
+
from utils.ckpts import img_ckpt, txt_ckpt
|
| 6 |
+
from utils.utilfuncs import gen_data, load_model, return_feas
|
| 7 |
+
|
| 8 |
+
img_backbone = ["timm/eca_nfnet_l1.ra2_in1k"]
|
| 9 |
+
txt_backbone = ["google-bert/bert-base-uncased"]
|
| 10 |
+
|
| 11 |
+
def clean():
    """Run a garbage-collection pass; the collected count is discarded."""
    _ = gc.collect()
    return None
|
| 13 |
+
|
| 14 |
+
def inference(li, lt, IMG_SIZE,
              TKN_PATH,
              BATCH_SIZE,
              num_workers = 4,
              ):
    """Decide whether two products (image + title each) match.

    Args:
        li: Raw image bytes, one entry per product.
        lt: Product titles, aligned with ``li``.
        IMG_SIZE: Side length passed to the image preprocessing.
        TKN_PATH: List of tokenizer identifiers; ``TKN_PATH[0]`` builds the
            text dataset and ``TKN_PATH[i]`` is the backbone name for the
            i-th text model.
        BATCH_SIZE: Batch size for both data loaders.
        num_workers: Worker count for both data loaders.

    Returns:
        True when the final match sets of item 0 and item 1 are equal.

    Raises:
        ValueError: If the pipeline did not produce exactly two match lists.
    """
    dataloader_img, dataloader_txt = gen_data(li,
                                              lt,
                                              IMG_SIZE,
                                              BATCH_SIZE,
                                              TKN_PATH[0],
                                              num_workers)

    # One image encoder per configured backbone; features are concatenated
    # along the feature axis.
    img_model = [load_model(backbone=backbone,
                            ckpt_path=img_ckpt[i],
                            img=True)
                 for i, backbone in enumerate(img_backbone)]
    img_feas = torch.cat([return_feas(model, dataloader_img, img=True)
                          for model in img_model], dim=1)

    # NOTE(review): the loop length comes from txt_backbone while the backbone
    # name comes from TKN_PATH; kept as in the original - confirm intent.
    txt_model = [load_model(backbone=TKN_PATH[i], ckpt_path=txt_ckpt[i])
                 for i in range(len(txt_backbone))]
    txt_feas = torch.cat([return_feas(model, dataloader_txt)
                          for model in txt_model], dim=1)

    match_final = predict(img_feas=img_feas,
                          txt_feas=txt_feas)

    match_final = filter_match_titles(match_final, title_list=lt)

    # The original `assert len(match_final == 2)` took len() of a bool and
    # always raised TypeError; check the length explicitly instead.
    if len(match_final) != 2:
        raise ValueError(
            f"expected match lists for exactly 2 items, got {len(match_final)}"
        )

    return set(match_final[0]) == set(match_final[1])
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
pp/__init__.py
ADDED
|
File without changes
|
pp/albu.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import albumentations as A
|
| 2 |
+
|
| 3 |
+
def transform(size):
    """Build the image preprocessing pipeline.

    Returns an albumentations ``Compose`` that resizes to ``size`` x ``size``
    and then applies ``Normalize`` with its default arguments.
    """
    steps = [A.Resize(size, size), A.Normalize()]
    return A.Compose(steps)
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
fastapi
|
| 3 |
+
uvicorn[standard]
|
| 4 |
+
huggingface_hub
|
| 5 |
+
albumentations
|
| 6 |
+
transformers
|
| 7 |
+
scikit-learn
|
| 8 |
+
unidecode
|
| 9 |
+
nltk
|
| 10 |
+
timm
|
| 11 |
+
faiss-cpu
|
| 12 |
+
hf-xet
|
| 13 |
+
git+https://github.com/Anirban0011/shopee-product-matching.git
|
utils/__init__.py
ADDED
|
File without changes
|
utils/ckpts.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import hf_hub_download
|
| 2 |
+
|
| 3 |
+
REPO_ID = "Anirban0011/multimodal-shopee-finetune"
|
| 4 |
+
|
| 5 |
+
def get_path(filename, repo):
    """Return what ``hf_hub_download`` yields for ``filename`` in ``repo``."""
    return hf_hub_download(repo_id=repo, filename=filename)
|
| 8 |
+
|
| 9 |
+
img_path = get_path(repo=REPO_ID,
|
| 10 |
+
filename="img_model_eca_nfnet_l1.ra2_in1k.pth")
|
| 11 |
+
txt_path = get_path(repo=REPO_ID,
|
| 12 |
+
filename="txt_model_bert-base-uncased_35.pth")
|
| 13 |
+
|
| 14 |
+
img_ckpt = [img_path]
|
| 15 |
+
txt_ckpt = [txt_path]
|
utils/dataset.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
|
| 8 |
+
class ImageDataset(Dataset):
    """Dataset over in-memory encoded images.

    ``li`` holds raw image bytes; ``transform`` is an optional callable that is
    invoked as ``transform(image=array)`` and returns a dict with an
    ``"image"`` entry.
    """

    def __init__(self, li, transform=None):
        self.li = li
        self.transform = transform

    def __len__(self):
        return len(self.li)

    def __getitem__(self, index):
        """Decode item ``index`` to RGB and return it as a float tensor."""
        raw = self.li[index]
        pixels = np.array(Image.open(io.BytesIO(raw)).convert("RGB")).copy()

        if self.transform is not None:
            pixels = self.transform(image=pixels)["image"]

        # NOTE(review): cast and axis reorder are applied regardless of whether
        # a transform is set - confirm against the original indentation.
        pixels = pixels.astype(np.float32).transpose(2, 0, 1)

        return torch.tensor(pixels).float()
|
| 29 |
+
|
| 30 |
+
class TextDataset(Dataset):
    """Dataset that tokenizes a list of strings on access.

    ``li`` holds the texts; ``tokenizer`` is the identifier handed to
    ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, li, tokenizer=None):
        self.li = li
        # The original stored the tokenizer only as ``self.to`` while
        # ``__getitem__`` read ``self.tokenizer``, so every lookup raised
        # AttributeError.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        # Old attribute name kept as an alias for any external reader.
        self.to = self.tokenizer

    def __len__(self):
        return len(self.li)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask)`` for the text at ``index``.

        The text is padded/truncated to 35 tokens; the first row of each
        returned tensor is taken.
        """
        encoded = self.tokenizer(
            self.li[index],
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"][0]
        attention_mask = encoded["attention_mask"][0]
        return input_ids, attention_mask
|
| 50 |
+
|
| 51 |
+
|
utils/filterfunc.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import regex
|
| 2 |
+
|
| 3 |
+
# adapted from kaggle.com/code/slawekbiel/resnet18-0-772-public-lb/notebook
|
| 4 |
+
measurements = {
|
| 5 |
+
'weight': [('mg',1), ('g', 1000), ('gr', 1000), ('gram', 1000), ('kg', 1000000)],
|
| 6 |
+
'length': [('mm',1), ('cm', 10), ('m',1000), ('meter', 1000)],
|
| 7 |
+
'pieces': [ ('pc',1)],
|
| 8 |
+
'memory': [('gb', 1)],
|
| 9 |
+
'volume': [('ml', 1), ('l', 1000), ('liter',1000)]
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
def to_num(x, mult=1):
    """Parse a numeric string and scale it to an integer.

    A comma is accepted as the decimal separator. The parsed float is
    multiplied by ``mult`` and truncated with ``int``.
    """
    normalized = x.replace(',', '.')
    scaled = float(normalized) * mult
    return int(scaled)
|
| 15 |
+
|
| 16 |
+
def extract_unit(tit, m):
    """Collect the distinct numbers in ``tit`` that are followed by unit ``m``.

    A hit is: a non-word character, a number with an optional ``,``/``.``
    decimal part, an optional space, the unit, an optional ``s`` and another
    non-word character. Only the number string is captured.
    """
    pattern = rf'\W(\d+(?:[\,\.]\d+)?) ?{m}s?\W'
    # overlapped=True lets neighbouring hits share their delimiter character.
    return set(regex.findall(pattern, tit, overlapped=True))
|
| 20 |
+
|
| 21 |
+
def extract(tit):
    """Extract measurement values from a title, grouped by category.

    Returns a dict mapping each category of ``measurements`` that occurs in
    the title to the set of values found, each scaled by its unit multiplier.
    Categories with no hit are omitted.
    """
    # Lower-case and pad with spaces so the pattern's surrounding non-word
    # characters can also match at the title's edges.
    padded = ' ' + tit.lower() + ' '
    res = {}
    for cat, units in measurements.items():
        found = set()
        for unit_name, mult in units:
            found |= {to_num(v, mult) for v in extract_unit(padded, unit_name)}
        if found:
            res[cat] = found
    return res
|
| 33 |
+
|
| 34 |
+
def match_measures(m1, m2):
    """Tell whether two measurement dicts are compatible.

    Returns True when the dicts share no category, or when at least one
    shared category has a value in common; otherwise False.
    """
    shared = set(m1.keys()) & set(m2.keys())
    if not shared:
        return True
    return any(m1[key].intersection(m2[key]) for key in shared)
|
| 44 |
+
|
| 45 |
+
def filter_match_titles(matches : list, title_list): # filter matches override
    """Drop matches whose titles carry conflicting measurements.

    Args:
        matches: ``matches[i]`` lists the indices matched to item ``i``.
            The list is modified in place.
        title_list: Titles, indexable by the indices used in ``matches``.

    Returns:
        The same ``matches`` list. Each row keeps the item itself plus every
        match whose measurements are compatible per ``match_measures``.
    """
    cache = {}

    def measures_of(idx):
        # Extract each title once; the original re-ran extraction per pair.
        if idx not in cache:
            cache[idx] = extract(title_list[idx])
        return cache[idx]

    for i in range(len(matches)):
        # The original passed the whole title_list to extract(), which fails
        # on .lower(); the title of item i is what is needed here.
        item_measures = measures_of(i)
        kept = []
        for match in matches[i]:
            if match == i:
                kept.append(i)
                continue
            # int() guards against float-typed indices from upstream.
            if match_measures(item_measures, measures_of(int(match))):
                kept.append(match)
        matches[i] = kept
    return matches
|
utils/knn.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from code_base.utils import CFG
|
| 2 |
+
import faiss
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
def build_faiss(feas, dim):
    """Create a flat inner-product faiss index of width ``dim`` holding ``feas``.

    A GPU index is built unless ``CFG.device.type`` is ``"cpu"``.
    """
    if CFG.device.type != "cpu":
        gpu_res = faiss.StandardGpuResources()
        index = faiss.GpuIndexFlatIP(gpu_res, dim)
    else:
        index = faiss.IndexFlatIP(dim)
    index.add(feas)
    return index
|
| 13 |
+
|
| 14 |
+
def get_batches(bs, n_batch, feas):
    """Split the rows of ``feas`` into ``n_batch`` consecutive slices.

    Every slice covers ``bs`` rows, except the last one, which runs to the
    end of ``feas`` so that leftover rows are not lost.
    """
    last = n_batch - 1
    return [
        feas[bs * i:(feas.shape[0] if i == last else bs * (i + 1)), :]
        for i in range(n_batch)
    ]
|
| 23 |
+
|
| 24 |
+
def get_matches(bs, n_batch, feas, dim, k=51):
    """Search the ``k`` nearest neighbours of every row of ``feas``.

    An index is built over ``feas`` and then queried batch by batch with the
    same rows. Returns ``(matches, sims)``: neighbour indices cast to int32
    and the corresponding similarities, concatenated over all batches.
    """
    index = build_faiss(feas, dim)
    id_parts, sim_parts = [], []
    # NOTE(review): searching with torch tensors presumably relies on faiss'
    # torch interop being enabled elsewhere - confirm.
    for chunk in get_batches(bs, n_batch, feas):
        chunk_sims, chunk_ids = index.search(chunk.to(CFG.device), k)
        id_parts.append(chunk_ids)
        sim_parts.append(chunk_sims)
    all_ids = torch.cat(id_parts, dim=0).to(torch.int32)
    all_sims = torch.cat(sim_parts, dim=0)
    return all_ids, all_sims
|
| 36 |
+
|
| 37 |
+
def th_matches(bs, n_batch, matches, sims, th):
    """Keep, per row, only the neighbours whose similarity exceeds ``th``.

    Returns two lists of Python lists, one row per query: the surviving
    neighbour indices and their similarities.
    """
    kept_ids, kept_sims = [], []
    id_batches = get_batches(bs, n_batch, matches)
    sim_batches = get_batches(bs, n_batch, sims)
    for ids_t, sims_t in zip(id_batches, sim_batches):
        ids_np = ids_t.cpu().numpy()
        sims_np = sims_t.cpu().numpy()
        keep = sims_np > th
        for ids_row, sims_row, keep_row in zip(ids_np, sims_np, keep):
            kept_ids.append(ids_row[keep_row].tolist())
            kept_sims.append(sims_row[keep_row].tolist())
    return kept_ids, kept_sims
|
utils/predict.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from code_base.utils import CFG
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from functools import reduce
|
| 6 |
+
from utils.knn import get_matches, th_matches
|
| 7 |
+
|
| 8 |
+
K = 51
|
| 9 |
+
di = 1792
|
| 10 |
+
dt = 1024
|
| 11 |
+
dc = di+dt
|
| 12 |
+
n_batch = 10
|
| 13 |
+
|
| 14 |
+
def filter_embeddings(feas, matches, sims):
    """Blend each embedding with those of its matched neighbours.

    Row ``i`` becomes the sum of the embeddings at ``matches[i]`` weighted by
    ``sims[i]``. The result is normalised with ``F.normalize`` and moved to
    ``CFG.device``.
    """
    source = feas.detach().cpu()
    blended = source.clone()

    for i in range(source.shape[0]):
        neighbours = source[matches[i]]
        weights = torch.Tensor(sims[i]).unsqueeze(1)
        # (1, n) @ (n, d) -> (1, d): the similarity-weighted sum.
        blended[i] = torch.matmul(weights.T, neighbours)
    blended = F.normalize(blended)
    return blended.to(CFG.device)
|
| 24 |
+
|
| 25 |
+
def filter_matches(matches, sims, th=1.0, k=3, dist=1e-2):
    """Cut match lists down to their first ``k`` entries at a similarity drop.

    A row with more than ``k`` entries is truncated when the gap between
    entries ``k-1`` and ``k`` is at least ``dist`` and exceeds ``th`` times
    the gap between entries ``k-2`` and ``k-1``. Both input lists are
    modified in place and returned.
    """
    head_matches = [row[:k] for row in matches]
    head_sims = [row[:k] for row in sims]
    for i, row in enumerate(matches):
        if len(row) <= k:
            continue
        row_sims = sims[i]
        gap_before = row_sims[k - 2] - row_sims[k - 1]
        gap_after = row_sims[k - 1] - row_sims[k]
        if gap_after >= dist and th * gap_before < gap_after:
            matches[i] = head_matches[i]
            sims[i] = head_sims[i]
    return matches, sims
|
| 39 |
+
|
| 40 |
+
def union_matches(*lists):
    """Merge several per-item match lists into one sorted, de-duplicated list.

    Args:
        *lists: Any number of sequences; position ``i`` of each holds the
            match indices for item ``i``.

    Returns:
        A list with one entry per item: the sorted union of that item's
        matches across all inputs, as a plain Python list.
    """
    merged = []
    for group in zip(*lists):
        # Empty members are skipped: an empty list becomes a float64 array
        # and would turn the integer indices of the union into floats.
        arrays = [np.asarray(member) for member in group if len(member)]
        if not arrays:
            merged.append([])
            continue
        # unique(concatenate(...)) is what np.union1d computes pairwise; in
        # this form a single input list works as well.
        merged.append(np.unique(np.concatenate(arrays, axis=None)).tolist())
    return merged
|
| 45 |
+
|
| 46 |
+
def predict(img_feas, txt_feas):
    """Compute the final match list for every item.

    Image, text and concatenated (combined) features are each searched and
    thresholded. The combined features are then blended with their matches
    and searched again. All three match lists go through ``filter_matches``
    and are merged with ``union_matches``.
    """
    img_emb = F.normalize(img_feas).to(CFG.device)
    txt_emb = F.normalize(txt_feas).to(CFG.device)
    comb_emb = F.normalize(torch.cat([img_emb, txt_emb], dim=1)).to(CFG.device)

    bs = len(comb_emb) // n_batch

    def search(feas, dim, th):
        # Neighbour search followed by similarity thresholding.
        ids, scores = get_matches(bs, n_batch, feas, dim, k=K)
        return th_matches(bs, n_batch, ids, scores, th)

    img_final, img_sims = search(img_emb, di, 0.704)
    text_final, text_sims = search(txt_emb, dt, 0.764)
    comb_final, comb_sims = search(comb_emb, dc, 0.52)

    # Second pass on the combined features after blending with their matches.
    comb_emb = filter_embeddings(comb_emb, comb_final, comb_sims)
    comb_final, comb_sims = search(comb_emb, dc, 0.9)

    img_final, _ = filter_matches(img_final, img_sims, 1.1, 4, 2e-2)
    text_final, _ = filter_matches(text_final, text_sims, 1.2, 4, 2e-2)
    comb_final, _ = filter_matches(comb_final, comb_sims, 1.0, 3, 2e-2)

    return union_matches(img_final, text_final, comb_final)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
utils/utilfuncs.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from pp.albu import transform
|
| 3 |
+
from code_base.utils import CFG
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from utils.dataset import ImageDataset, TextDataset
|
| 6 |
+
from code_base.pipeline import ImgEncoder, TextEncoder
|
| 7 |
+
|
| 8 |
+
def gen_data(li,
             lt,
             IMG_SIZE,
             BATCH_SIZE,
             TKN_PATH,
             num_workers):
    """Build the image and text data loaders.

    ``li`` feeds an ``ImageDataset`` with the ``IMG_SIZE`` transform and
    ``lt`` feeds a ``TextDataset`` using tokenizer ``TKN_PATH``. Both loaders
    are unshuffled and share batch size and worker count.

    Returns:
        ``(dataloader_img, dataloader_txt)``.
    """
    image_ds = ImageDataset(li=li, transform=transform(size=IMG_SIZE))
    text_ds = TextDataset(li=lt, tokenizer=TKN_PATH)

    loader_opts = {
        "batch_size": BATCH_SIZE,
        "shuffle": False,
        "num_workers": num_workers,
    }
    dataloader_img = DataLoader(image_ds, **loader_opts)
    dataloader_txt = DataLoader(text_ds, **loader_opts)

    return dataloader_img, dataloader_txt
|
| 22 |
+
|
| 23 |
+
def load_model(backbone, ckpt_path, num_classes=11014, img = False):
    """Build an encoder, load its checkpoint and move it to ``CFG.device``.

    An ``ImgEncoder`` is created when ``img`` is true, otherwise a
    ``TextEncoder``. Every ``"module."`` occurrence is removed from the
    checkpoint keys before the weights are loaded.
    """
    encoder = (
        ImgEncoder(num_classes, backbone=backbone, pretrained=False, p=4)
        if img
        else TextEncoder(num_classes, backbone=backbone, eval_model=True)
    )

    raw_state = torch.load(ckpt_path, weights_only=True, map_location=CFG.device)

    # remove module. prefix
    cleaned_state = {
        name.replace("module.", ""): tensor for name, tensor in raw_state.items()
    }

    encoder.load_state_dict(cleaned_state)
    encoder = encoder.to(CFG.device)
    print(f"model {backbone} loaded successfully")
    return encoder
|
| 41 |
+
|
| 42 |
+
class gen_feas:
    """Run a model over a data loader and collect its outputs as a NumPy array."""

    def __init__(self, model, dataloader):
        self.model = model
        self.dataloader = dataloader

    def _collect(self, forward):
        # Shared loop: eval mode, no gradients, outputs gathered on the CPU.
        self.model.eval()

        outputs = []

        with torch.no_grad():
            for batch in self.dataloader:
                logits = forward(batch)
                outputs.append(logits.detach().cpu())

        return torch.cat(outputs).cpu().numpy()

    def gen_img_feas(self):
        """Return model outputs for a loader that yields image batches."""

        def forward(images):
            return self.model(images.to(CFG.device))

        return self._collect(forward)

    def gen_txt_feas(self):
        """Return model outputs for a loader that yields (ids, masks) batches."""

        def forward(batch):
            inp_ids, att_masks = batch
            return self.model(inp_ids.to(CFG.device), att_masks.to(CFG.device))

        return self._collect(forward)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def return_feas(model, dataloader, img=False):
    """Extract features with ``gen_feas`` and return them as a tensor on ``CFG.device``.

    ``img`` selects the image path (``gen_img_feas``) over the text path
    (``gen_txt_feas``).
    """
    extractor = gen_feas(model, dataloader)
    raw = extractor.gen_img_feas() if img else extractor.gen_txt_feas()
    return torch.tensor(raw).to(CFG.device)
|