Upload TextSearch.py
helpers/TextSearch.py  +85 -0  (ADDED)
@@ -0,0 +1,85 @@
import pandas as pd
import faiss
import torch
from open_clip import create_model_from_pretrained, get_tokenizer

class SearchSigLIP():

    def __init__(self, index_path, metadata_path):

        # 1. Initialise Index
        print(f'Loading index from PATH={index_path}')
        self.index_path = index_path
        self.init_index()
        print('[DONE]')

        # 2. Initialise Metadata
        print(f'Loading metadata from PATH={metadata_path}')
        self.metadata_path = metadata_path
        self.metadata_df = pd.read_parquet(self.metadata_path)
        print('[DONE]')

        # 3. Initialise Text Encoder
        self.init_model()

    def init_index(self):
        self.cpu_index = faiss.read_index(self.index_path)
        res = faiss.StandardGpuResources()
        cloner_options = faiss.GpuClonerOptions()
        cloner_options.useFloat16LookupTables = True
        self.gpu_index = faiss.index_cpu_to_gpu(res, 0, self.cpu_index, cloner_options)
        self.gpu_index.nprobe = 32  # Higher = more accurate, slower

    def init_model(self):
        self.model, self.preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
        self.model.eval()
        self.tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP')

    def encode_text(self, text, device='cuda'):
        self.model.to(device)
        with torch.no_grad():
            text = self.tokenizer([text], context_length=self.model.context_length)
            return self.model.encode_text(text.to(device))

    def search_with_grid(self, query_vec, k=5):
        # Prepare query
        if isinstance(query_vec, torch.Tensor):
            query_vec = query_vec.cpu().squeeze().numpy()

        query_vec = query_vec.reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_vec)

        # Search
        distances, indices = self.gpu_index.search(query_vec, k)

        # Flatten results
        ids = indices[0]
        scores = distances[0]

        results = []

        # Batch lookup in pandas (Faster than looping)
        # We ignore -1 (which happens if k > total vectors, unlikely here)
        valid_mask = ids != -1
        valid_ids = ids[valid_mask]
        valid_scores = scores[valid_mask]

        if len(valid_ids) > 0:
            # MAGIC LINE: Direct lookup by integer index
            matches = self.metadata_df.iloc[valid_ids].copy()
            matches['score'] = valid_scores

            # Convert to list of dicts for easy usage
            results = matches.to_dict(orient='records')

        return results

    def faiss(self, text, k=1):  # k - number of neighbours

        # 1. Compute query
        q = self.encode_text(text)

        # 2. Find Hits
        results = self.search_with_grid(q, k=k)

        return results
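For context, a minimal usage sketch of the class added in this commit is shown below. The index and metadata paths are placeholders (not files from this repo), and it assumes the FAISS index was built from SigLIP image embeddings inserted in the same order as the rows of the metadata parquet, which is what the `iloc` lookup in `search_with_grid` relies on.

# Usage sketch; paths are placeholders, not part of this commit.
searcher = SearchSigLIP(
    index_path='image_embeddings.index',     # placeholder: FAISS index of SigLIP image vectors
    metadata_path='image_metadata.parquet',  # placeholder: one row per indexed vector, same order as the index
)

# Top-k images for a free-text query; each hit is a metadata dict plus a 'score' field.
hits = searcher.faiss('a red vintage car parked on a cobbled street', k=5)
for hit in hits:
    print(hit['score'], hit)

Because `metadata_df.iloc[valid_ids]` maps FAISS result IDs directly to dataframe rows, the parquet file must preserve the exact order in which vectors were added to the index; re-sorting or filtering the metadata would break the lookup.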