Anirban0011 committed on
Commit
fe0eb36
·
1 Parent(s): 9d563c0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ env
2
+ *.pyc
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.9

# Create and switch to an unprivileged account (UID 1000) so nothing below runs as root.
RUN useradd -m -u 1000 user
USER user
# Put the user's local bin directory first on PATH.
# Presumably this is where the non-root pip install places console scripts such as uvicorn.
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# requirements.txt is copied and installed before the rest of the sources,
# so the dependency layer can be reused when only application code changes.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
# Serve the `app` object from app.py on all interfaces, port 7860.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
+ import uvicorn
3
+ import torch
4
+ import os
5
+ import nltk
6
+ nltk.download("stopwords")
7
+ import numpy as np
8
+ from typing import List
9
+ from inference import inference
10
+ from code_base.utils import CFG
11
+
12
# Settings forwarded to `inference()` by the /predict handler.
TKN_PATH = ["bert-base-uncased"]
IMG_SIZE = 256
BATCH_SIZE = 32
img = True

# Select the GPU when torch reports one is available, otherwise the CPU.
CFG.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

app = FastAPI(title="shopee-test-app")
20
+
21
@app.get("/")
async def root():
    """Liveness endpoint: return a fixed status payload."""
    payload = {"status": "ok", "message": "Space is running"}
    return payload
24
+
25
@app.post("/predict")
async def predict_image(files: List[UploadFile] = File(...),
                        texts: List[str] = Form(...)):
    """Run the matching pipeline on uploaded images and their titles.

    Args:
        files: uploaded image files; ``files[i]`` is paired with ``texts[i]``.
        texts: product titles, one per uploaded file.

    Returns:
        ``{"message": "products matched"}`` when ``inference`` returns a
        truthy value, otherwise ``{"message": "products not matched"}``.

    Raises:
        HTTPException: status 400 when the number of files and titles differ.
    """
    # Function-scope import: only this handler needs HTTPException.
    from fastapi import HTTPException

    # zip() would silently discard the unpaired tail, so reject mismatches.
    if len(files) != len(texts):
        raise HTTPException(
            status_code=400,
            detail=(
                "expected one text per image, got "
                f"{len(files)} file(s) and {len(texts)} text(s)"
            ),
        )

    li = [await file.read() for file in files]
    lt = list(texts)

    res = inference(li=li,
                    lt=lt,
                    IMG_SIZE=IMG_SIZE,
                    TKN_PATH=TKN_PATH,
                    BATCH_SIZE=BATCH_SIZE)
    msg = "products matched" if res else "products not matched"

    return {"message": msg}
42
+
43
if __name__ == "__main__":
    # Direct-execution entry point: serves on port 8000.
    # The Dockerfile CMD starts uvicorn itself on port 7860 instead.
    uvicorn.run("app:app", host="0.0.0.0", port=8000)
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
inference.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import torch
3
+ from utils.predict import predict
4
+ from utils.filterfunc import filter_match_titles
5
+ from utils.ckpts import img_ckpt, txt_ckpt
6
+ from utils.utilfuncs import gen_data, load_model, return_feas
7
+
8
# Image backbones; img_backbone[i] is loaded together with img_ckpt[i].
img_backbone = ["timm/eca_nfnet_l1.ra2_in1k"]
# Only the length of this list is used by inference(); the text model itself
# is built from TKN_PATH[i] and txt_ckpt[i].
txt_backbone = ["google-bert/bert-base-uncased"]
10
+
11
def clean():
    """Trigger a garbage-collection pass; returns nothing."""
    gc.collect()
    return None
13
+
14
def inference(li, lt, IMG_SIZE,
              TKN_PATH,
              BATCH_SIZE,
              num_workers = 4,
              ):
    """Decide whether the two submitted products are the same item.

    Args:
        li: image payloads, forwarded to ``gen_data``.
        lt: product titles, index-aligned with ``li``; also used for the
            title-based filtering of the matches.
        IMG_SIZE: image size forwarded to ``gen_data``.
        TKN_PATH: list of tokenizer identifiers; entry 0 builds the text
            dataloader, entry ``i`` is the backbone name of text model ``i``.
        BATCH_SIZE: batch size for both dataloaders.
        num_workers: worker count for both dataloaders.

    Returns:
        True when items 0 and 1 end up with the same set of matches.

    Raises:
        ValueError: if the pipeline does not produce exactly two match lists.
    """
    dataloader_img, dataloader_txt = gen_data(li,
                                              lt,
                                              IMG_SIZE,
                                              BATCH_SIZE,
                                              TKN_PATH[0],
                                              num_workers)

    img_model = [load_model(backbone=img_backbone[i],
                            ckpt_path=img_ckpt[i],
                            img=True)
                 for i in range(len(img_backbone))]

    # Features of every image model are concatenated along the feature axis.
    img_feas = torch.cat([return_feas(model, dataloader_img, img=True)
                          for model in img_model], dim=1)

    # NOTE(review): the text backbone name is taken from TKN_PATH while the
    # loop length comes from txt_backbone; both lists must stay index-aligned.
    txt_model = [load_model(backbone=TKN_PATH[i], ckpt_path=txt_ckpt[i])
                 for i in range(len(txt_backbone))]

    txt_feas = torch.cat([return_feas(model, dataloader_txt)
                          for model in txt_model], dim=1)

    match_final = predict(img_feas=img_feas,
                          txt_feas=txt_feas)

    match_final = filter_match_titles(match_final, title_list=lt)

    # The original `assert len(match_final == 2)` computed len(False) and
    # always raised TypeError; check the length itself, and do it with an
    # exception so the check is not stripped under `python -O`.
    if len(match_final) != 2:
        raise ValueError(
            f"expected match lists for exactly 2 items, got {len(match_final)}"
        )

    return set(match_final[0]) == set(match_final[1])
52
+
53
+
54
+
55
+
56
+
pp/__init__.py ADDED
File without changes
pp/albu.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+
3
def transform(size):
    """Build the albumentations pipeline: resize to size x size, then normalize."""
    steps = [A.Resize(size, size), A.Normalize()]
    return A.Compose(steps)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ fastapi
3
+ uvicorn[standard]
4
+ huggingface_hub
5
+ albumentations
6
+ transformers
7
+ scikit-learn
8
+ unidecode
9
+ nltk
10
+ timm
11
+ faiss-cpu
12
+ hf-xet
13
+ git+https://github.com/Anirban0011/shopee-product-matching.git
utils/__init__.py ADDED
File without changes
utils/ckpts.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+
3
+ REPO_ID = "Anirban0011/multimodal-shopee-finetune"
4
+
5
def get_path(filename, repo):
    """Return whatever ``hf_hub_download`` yields for ``filename`` in ``repo``."""
    return hf_hub_download(repo_id=repo, filename=filename)
8
+
9
# Both checkpoints are resolved at import time of this module.
img_path = get_path(filename="img_model_eca_nfnet_l1.ra2_in1k.pth", repo=REPO_ID)
txt_path = get_path(filename="txt_model_bert-base-uncased_35.pth", repo=REPO_ID)

# Single-element lists, indexed per backbone by the inference code.
img_ckpt, txt_ckpt = [img_path], [txt_path]
utils/dataset.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from transformers import AutoTokenizer
6
+ from torch.utils.data import Dataset
7
+
8
class ImageDataset(Dataset):
    """Dataset over encoded images held in memory as bytes objects."""

    def __init__(self, li, transform=None):
        # li: sequence of encoded image bytes.
        # transform: optional callable invoked as transform(image=array),
        # whose result is indexed with "image".
        self.li = li
        self.transform = transform

    def __len__(self):
        return len(self.li)

    def __getitem__(self, index):
        """Decode item ``index`` to RGB and return it as a float tensor."""
        stream = io.BytesIO(self.li[index])
        pixels = np.array(Image.open(stream).convert("RGB")).copy()

        if self.transform is not None:
            pixels = self.transform(image=pixels)["image"]

        # Move the last axis (RGB channels) to the front.
        channels_first = pixels.astype(np.float32).transpose(2, 0, 1)

        return torch.tensor(channels_first).float()
29
+
30
class TextDataset(Dataset):
    """Dataset that tokenizes one title per item.

    Args:
        li: sequence of strings to tokenize.
        tokenizer: identifier passed to ``AutoTokenizer.from_pretrained``.
        max_length: length every item is padded or truncated to
            (defaults to 35, the previously hard-coded value).
    """

    def __init__(self, li, tokenizer=None, max_length=35):
        self.li = li
        # The original stored this as `self.to`, but __getitem__ reads
        # `self.tokenizer`, so every lookup raised AttributeError.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.max_length = max_length

    def __len__(self):
        return len(self.li)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask)`` for item ``index``."""
        encoded = self.tokenizer(
            self.li[index],
            padding="max_length",
            truncation=True,
            # getattr keeps instances built without __init__ working.
            max_length=getattr(self, "max_length", 35),
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch axis; take the single row.
        input_ids = encoded["input_ids"][0]
        attention_mask = encoded["attention_mask"][0]
        return input_ids, attention_mask
50
+
51
+
utils/filterfunc.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import regex
2
+
3
+ # adapted from kaggle.com/code/slawekbiel/resnet18-0-772-public-lb/notebook
4
# Per category: (unit name, multiplier) pairs. The multiplier scales a value
# to the category's first-listed unit (mg, mm, pc, gb, ml).
measurements = {
    'weight': [
        ('mg', 1),
        ('g', 1000),
        ('gr', 1000),
        ('gram', 1000),
        ('kg', 1000000),
    ],
    'length': [
        ('mm', 1),
        ('cm', 10),
        ('m', 1000),
        ('meter', 1000),
    ],
    'pieces': [('pc', 1)],
    'memory': [('gb', 1)],
    'volume': [
        ('ml', 1),
        ('l', 1000),
        ('liter', 1000),
    ],
}
11
+
12
def to_num(x, mult=1):
    """Parse a decimal string (comma or dot separator), scale by ``mult``, truncate to int."""
    normalized = x.replace(',', '.')
    scaled = float(normalized) * mult
    return int(scaled)
15
+
16
def extract_unit(tit, m):
    """Return the set of numeric strings in ``tit`` that are followed by unit ``m``.

    A number counts when it is preceded by a non-word character and followed
    by an optional space, the unit, an optional plural "s" and another
    non-word character. Overlapping matches are included.
    """
    pattern = rf'\W(\d+(?:[\,\.]\d+)?) ?{m}s?\W'
    return set(regex.findall(pattern, tit, overlapped=True))
20
+
21
def extract(tit):
    """Map each measurement category found in title ``tit`` to its set of scaled values."""
    # Pad with spaces so numbers at either end of the title still have the
    # non-word neighbours the unit pattern requires.
    padded = ' ' + tit.lower() + ' '
    found = {}
    for category, unit_table in measurements.items():
        amounts = set()
        for unit, factor in unit_table:
            amounts |= {to_num(raw, factor) for raw in extract_unit(padded, unit)}
        if amounts:
            found[category] = amounts
    return found
33
+
34
def match_measures(m1, m2):
    """Tell whether two measurement dicts are compatible.

    Returns True when the dicts share no category, or when at least one
    shared category has a value in common; otherwise False.
    """
    shared = set(m1.keys()) & set(m2.keys())
    if not shared:
        return True
    return any(m1[key].intersection(m2[key]) for key in shared)
44
+
45
def filter_match_titles(matches : list, title_list): # filter matches override
    """Drop matches whose title measurements contradict the item's own.

    Args:
        matches: ``matches[i]`` lists the candidate indices for item ``i``;
            the list is updated in place.
        title_list: titles, indexed by the same item indices.

    Returns:
        ``matches``, where each row keeps the item itself plus every
        candidate for which ``match_measures`` accepts the two titles.
    """
    # The original called extract() on the whole title list, which fails on
    # `.lower()`. Parse each title once and look the results up by index.
    parsed = [extract(title) for title in title_list]
    for i, candidates in enumerate(matches):
        own = parsed[i]
        matches[i] = [
            match for match in candidates
            if match == i or match_measures(own, parsed[int(match)])
        ]
    return matches
utils/knn.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from code_base.utils import CFG
2
+ import faiss
3
+ import torch
4
+
5
def build_faiss(feas, dim):
    """Create an inner-product faiss index of width ``dim`` and add ``feas`` to it.

    A CPU flat index is used when ``CFG.device`` is the CPU, otherwise a GPU
    flat index.
    """
    use_cpu = CFG.device.type == "cpu"
    if use_cpu:
        index = faiss.IndexFlatIP(dim)
    else:
        gpu_resources = faiss.StandardGpuResources()
        index = faiss.GpuIndexFlatIP(gpu_resources, dim)
    index.add(feas)
    return index
13
+
14
def get_batches(bs, n_batch, feas):
    """Split the rows of ``feas`` into ``n_batch`` consecutive slices.

    Every slice has ``bs`` rows except the last one, which runs to the end
    of ``feas`` and absorbs any remainder.
    """
    final = n_batch - 1
    return [
        feas[bs * i:(feas.shape[0] if i == final else bs * (i + 1)), :]
        for i in range(n_batch)
    ]
23
+
24
def get_matches(bs, n_batch, feas, dim, k=51):
    """Search every row of ``feas`` against an index built from ``feas``.

    Args:
        bs: rows per search batch (see ``get_batches``).
        n_batch: number of search batches.
        feas: feature tensor, used both as the index content and as queries.
        dim: feature width given to the index.
        k: number of neighbours requested per row.

    Returns:
        ``(m, s)``: neighbour indices as an int32 tensor and their
        similarities, both concatenated over all batches.
    """
    index = build_faiss(feas, dim)
    match_chunks = []
    sim_chunks = []
    for batch in get_batches(bs, n_batch, feas):
        sims, matches = index.search(batch.to(CFG.device), k)
        # index.search returns NumPy arrays unless faiss.contrib.torch_utils
        # has been imported somewhere (not visible in this file), and
        # torch.cat rejects NumPy arrays. as_tensor accepts both forms and
        # passes an existing tensor through unchanged.
        match_chunks.append(torch.as_tensor(matches))
        sim_chunks.append(torch.as_tensor(sims))
    m = torch.cat(match_chunks, dim=0).to(torch.int32)
    s = torch.cat(sim_chunks, dim=0)
    return m, s
36
+
37
def th_matches(bs, n_batch, matches, sims, th):
    """Keep, per row, only the neighbours whose similarity exceeds ``th``.

    Returns two lists of Python lists (indices and similarities), one entry
    per input row.
    """
    match_batches = get_batches(bs, n_batch, matches)
    sim_batches = get_batches(bs, n_batch, sims)
    kept_ids, kept_sims = [], []
    for ids_chunk, sims_chunk in zip(match_batches, sim_batches):
        ids_np = ids_chunk.cpu().numpy()
        sims_np = sims_chunk.cpu().numpy()
        keep = sims_np > th
        for ids_row, sims_row, keep_row in zip(ids_np, sims_np, keep):
            kept_ids.append(ids_row[keep_row].tolist())
            kept_sims.append(sims_row[keep_row].tolist())
    return kept_ids, kept_sims
utils/predict.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from code_base.utils import CFG
4
+ import torch.nn.functional as F
5
+ from functools import reduce
6
+ from utils.knn import get_matches, th_matches
7
+
8
K = 51         # neighbours requested from every index search
di = 1792      # index width used for image features
dt = 1024      # index width used for text features
dc = di + dt   # index width used for the concatenated features
n_batch = 10   # number of search batches
13
+
14
def filter_embeddings(feas, matches, sims):
    """Replace each row with the similarity-weighted sum of its matched rows.

    Row ``i`` becomes ``sum(sims[i][j] * feas[matches[i][j]])``. The result
    is normalized with ``F.normalize`` and moved to ``CFG.device``.
    """
    source = feas.detach().cpu()
    blended = source.clone()

    for row in range(source.shape[0]):
        neighbours = source[matches[row]]
        weights = torch.Tensor(sims[row]).unsqueeze(1)
        blended[row] = weights.T @ neighbours

    return F.normalize(blended).to(CFG.device)
24
+
25
def filter_matches(matches, sims, th=1.0, k=3, dist=1e-2):
    """Cut a row down to its first ``k`` entries when similarity drops sharply after them.

    A row is truncated only if it has more than ``k`` entries, the drop from
    entry ``k-1`` to entry ``k`` is at least ``dist``, and that drop is
    larger than ``th`` times the preceding drop. Both lists are updated in
    place and returned.
    """
    for idx, row in enumerate(matches):
        if len(row) < k + 1:
            continue
        row_sims = sims[idx]
        gap_before = row_sims[k - 2] - row_sims[k - 1]
        gap_after = row_sims[k - 1] - row_sims[k]
        if gap_after < dist:
            continue
        if th * gap_before < gap_after:
            matches[idx] = row[:k]
            sims[idx] = row_sims[:k]
    return matches, sims
39
+
40
def union_matches(*lists):
    """Merge several per-item match lists row by row.

    Row ``i`` of the result is the sorted union (via ``np.union1d``) of row
    ``i`` of every input.
    """
    return [reduce(np.union1d, rows).tolist() for rows in zip(*lists)]
45
+
46
def predict(img_feas, txt_feas):
    """Compute the final match list for every item from its image and text features.

    Image, text and concatenated features are each searched and thresholded.
    The concatenated features are then re-weighted with their matches and
    searched again. Each result is pruned with ``filter_matches`` and the
    three are unioned row by row.
    """
    device = CFG.device
    img_feas = F.normalize(img_feas).to(device)
    txt_feas = F.normalize(txt_feas).to(device)
    comb_feas = F.normalize(torch.cat([img_feas, txt_feas], dim=1)).to(device)

    bs = len(comb_feas) // n_batch

    def search(feas, dim, th):
        # Neighbour search followed by similarity thresholding.
        ids, scores = get_matches(bs, n_batch, feas, dim, k=K)
        return th_matches(bs, n_batch, ids, scores, th)

    img_final, img_sims = search(img_feas, di, 0.704)
    text_final, text_sims = search(txt_feas, dt, 0.764)
    comb_final, comb_sims = search(comb_feas, dc, 0.52)

    # Second pass on the combined features after blending in their matches.
    comb_feas = filter_embeddings(comb_feas, comb_final, comb_sims)
    comb_final, comb_sims = search(comb_feas, dc, 0.9)

    img_final, _ = filter_matches(img_final, img_sims, 1.1, 4, 2e-2)
    text_final, _ = filter_matches(text_final, text_sims, 1.2, 4, 2e-2)
    comb_final, _ = filter_matches(comb_final, comb_sims, 1.0, 3, 2e-2)

    return union_matches(img_final, text_final, comb_final)
72
+
73
+
74
+
75
+
76
+
77
+
utils/utilfuncs.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pp.albu import transform
3
+ from code_base.utils import CFG
4
+ from torch.utils.data import DataLoader
5
+ from utils.dataset import ImageDataset, TextDataset
6
+ from code_base.pipeline import ImgEncoder, TextEncoder
7
+
8
def gen_data(li,
             lt,
             IMG_SIZE,
             BATCH_SIZE,
             TKN_PATH,
             num_workers):
    """Build the image and text dataloaders (unshuffled, same batch size and workers)."""
    loader_options = dict(batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=num_workers)
    image_set = ImageDataset(li=li, transform=transform(size=IMG_SIZE))
    text_set = TextDataset(li=lt, tokenizer=TKN_PATH)
    image_loader = DataLoader(image_set, **loader_options)
    text_loader = DataLoader(text_set, **loader_options)
    return image_loader, text_loader
22
+
23
def load_model(backbone, ckpt_path, num_classes=11014, img = False):
    """Build an image or text encoder, load its checkpoint and move it to ``CFG.device``.

    Every occurrence of ``"module."`` is removed from the checkpoint keys
    before the state dict is loaded.
    """
    if img:
        model = ImgEncoder(num_classes, backbone=backbone, pretrained=False, p=4)
    else:
        model = TextEncoder(num_classes, backbone=backbone, eval_model=True)

    ckpt = torch.load(ckpt_path, weights_only=True, map_location=CFG.device)
    cleaned = {key.replace("module.", ""): value for key, value in ckpt.items()}

    model.load_state_dict(cleaned)
    model = model.to(CFG.device)
    print(f"model {backbone} loaded successfully")
    return model
41
+
42
class gen_feas:
    """Run a model over a dataloader and collect its outputs as one NumPy array."""

    def __init__(self, model, dataloader):
        self.model = model
        self.dataloader = dataloader

    def gen_img_feas(self):
        """Collect outputs for a dataloader that yields image batches."""
        self.model.eval()
        collected = []
        with torch.no_grad():
            for images in self.dataloader:
                output = self.model(images.to(CFG.device))
                collected.append(output.detach().cpu())
        return torch.cat(collected).cpu().numpy()

    def gen_txt_feas(self):
        """Collect outputs for a dataloader that yields (input_ids, attention_mask) batches."""
        self.model.eval()
        collected = []
        with torch.no_grad():
            for inp_ids, att_masks in self.dataloader:
                inp_ids = inp_ids.to(CFG.device)
                att_masks = att_masks.to(CFG.device)
                output = self.model(inp_ids, att_masks)
                collected.append(output.detach().cpu())
        return torch.cat(collected).cpu().numpy()
78
+
79
+
80
def return_feas(model, dataloader, img=False):
    """Extract features with ``gen_feas`` and return them as a tensor on ``CFG.device``."""
    extractor = gen_feas(model, dataloader)
    feas = extractor.gen_img_feas() if img else extractor.gen_txt_feas()
    return torch.tensor(feas).to(CFG.device)