Spaces:
Running
on
L4
Running
on
L4
| import pandas as pd | |
| import faiss | |
| import torch | |
| from open_clip import create_model_from_pretrained, get_tokenizer | |
class SearchSigLIP:
    """Text-to-item search over a FAISS index using a SigLIP text encoder.

    Construction loads three things, in this order: a FAISS index (read from
    disk, then cloned to a GPU), a parquet metadata table, and an open_clip
    model together with a tokenizer. ``faiss(text, k)`` is the end-to-end
    entry point: it encodes the text and returns the metadata rows of the
    ``k`` nearest index entries.

    The metadata table is addressed positionally: an id ``i`` returned by the
    index is looked up as row ``i`` of the table (``DataFrame.iloc``), so the
    table must be stored in the same order the vectors were added to the
    index.
    """

    # open_clip identifiers used by init_model().
    # NOTE(review): the model is the "-384" variant while the tokenizer is
    # fetched from the repository without that suffix; presumably both ship
    # the same text tokenizer -- confirm before unifying them.
    MODEL_NAME = 'hf-hub:timm/ViT-SO400M-14-SigLIP-384'
    TOKENIZER_NAME = 'hf-hub:timm/ViT-SO400M-14-SigLIP'

    # GPU ordinal the index is cloned to.
    GPU_DEVICE = 0
    # Value assigned to the GPU index's ``nprobe``.
    # Higher = more accurate, slower.
    NPROBE = 32

    def __init__(self, index_path, metadata_path):
        """Load the index, the metadata table and the text encoder.

        Args:
            index_path: Path handed to ``faiss.read_index``.
            metadata_path: Path handed to ``pandas.read_parquet``.
        """
        # 1. Initialise Index
        print(f'Loading index from PATH={index_path}')
        self.index_path = index_path
        self.init_index()
        print('[DONE]')

        # 2. Initialise Metadata
        print(f'Loading metadata from PATH={metadata_path}')
        self.metadata_path = metadata_path
        self.metadata_df = pd.read_parquet(self.metadata_path)
        print('[DONE]')

        # 3. Initialise Text Encoder
        self.init_model()

    def init_index(self):
        """Read the index at ``self.index_path`` and clone it to the GPU.

        Sets ``self.cpu_index``, ``self.gpu_resources`` and ``self.gpu_index``.
        """
        self.cpu_index = faiss.read_index(self.index_path)

        # Held on the instance (rather than in a local) so the resources
        # object stays referenced for as long as the GPU index is.
        self.gpu_resources = faiss.StandardGpuResources()

        cloner_options = faiss.GpuClonerOptions()
        cloner_options.useFloat16LookupTables = True

        self.gpu_index = faiss.index_cpu_to_gpu(
            self.gpu_resources,
            self.GPU_DEVICE,
            self.cpu_index,
            cloner_options,
        )
        self.gpu_index.nprobe = self.NPROBE

    def init_model(self):
        """Create the open_clip model (in eval mode) and the tokenizer.

        Sets ``self.model``, ``self.preprocess`` and ``self.tokenizer``.
        """
        self.model, self.preprocess = create_model_from_pretrained(self.MODEL_NAME)
        self.model.eval()
        self.tokenizer = get_tokenizer(self.TOKENIZER_NAME)

    def encode_text(self, text, device='cuda'):
        """Encode one text string with the model's text tower.

        The model is moved to ``device`` on every call, the text is tokenized
        as a one-element batch, and encoding runs under ``torch.no_grad()``.

        Args:
            text: The query string.
            device: Torch device for the model and the token tensor.

        Returns:
            The output of ``self.model.encode_text`` for the one-element
            batch. It is not normalized here; ``search_with_grid`` does that.
        """
        self.model.to(device)
        with torch.no_grad():
            tokens = self.tokenizer([text], context_length=self.model.context_length)
            return self.model.encode_text(tokens.to(device))

    @staticmethod
    def _as_query_matrix(query_vec):
        """Return ``query_vec`` as a fresh ``(1, d)`` float32 NumPy array."""
        if isinstance(query_vec, torch.Tensor):
            # detach(): Tensor.numpy() refuses tensors that track gradients.
            # float(): NumPy has no bfloat16, so cast before converting.
            query_vec = query_vec.detach().float().cpu().numpy()
        # astype() copies, so the caller's array is never modified by the
        # in-place normalization that follows in search_with_grid().
        return query_vec.reshape(1, -1).astype('float32')

    def _rows_for_hits(self, ids, scores):
        """Return metadata rows for the given ids, each with a ``score`` key."""
        # We ignore -1 (which happens if k > total vectors, unlikely here).
        keep = ids != -1
        if not keep.any():
            return []
        # Direct positional lookup; .copy() keeps metadata_df itself unchanged
        # when the score column is attached.
        matches = self.metadata_df.iloc[ids[keep]].copy()
        matches['score'] = scores[keep]
        return matches.to_dict(orient='records')

    def search_with_grid(self, query_vec, k=5):
        """Search the GPU index with a single query vector.

        Args:
            query_vec: A ``torch.Tensor`` or NumPy array holding one vector.
                Any shape is accepted as long as it flattens to a single
                vector; it is reshaped to ``(1, -1)``.
            k: Number of neighbours to request from the index.

        Returns:
            A list of dicts, one per hit in the order the index returned them.
            Each dict is a metadata row plus a ``score`` key holding the value
            the index reported for that hit (its meaning depends on the metric
            the index was built with). Empty if the index returned no valid
            ids.
        """
        query = self._as_query_matrix(query_vec)
        faiss.normalize_L2(query)  # in place, on our private copy

        distances, indices = self.gpu_index.search(query, k)

        # One query was submitted, so only the first result row is relevant.
        return self._rows_for_hits(indices[0], distances[0])

    def faiss(self, text, k=1):  # k - number of neighbours
        """Encode ``text`` and return its ``k`` nearest metadata rows.

        The method shares its name with the ``faiss`` module; inside method
        bodies the bare name ``faiss`` still refers to the module, and this
        method is reached as ``self.faiss``.
        """
        # 1. Compute query
        q = self.encode_text(text)
        # 2. Find Hits
        return self.search_with_grid(q, k=k)