import pandas as pd
import faiss
import torch
from open_clip import create_model_from_pretrained, get_tokenizer
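
# Assumed dependencies (install names may vary by platform): a GPU build of
# faiss, open_clip_torch, pandas (with pyarrow for parquet), and torch.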

class SearchSigLIP:

    def __init__(self, index_path, metadata_path):

        # 1. Initialise Index
        print(f'Loading index from PATH={index_path}')
        self.index_path = index_path
        self.init_index()
        print('[DONE]')

        # 2. Initialise Metadata
        print(f'Loading metadata from PATH={metadata_path}')
        self.metadata_path = metadata_path
        self.metadata_df = pd.read_parquet(self.metadata_path)
        print('[DONE]')

        # 3. Initialise Text Encoder
        self.init_model()

    def init_index(self):
        # Load the index from disk, then clone it onto GPU 0
        self.cpu_index = faiss.read_index(self.index_path)
        res = faiss.StandardGpuResources()
        cloner_options = faiss.GpuClonerOptions()
        cloner_options.useFloat16LookupTables = True  # Halve lookup-table memory on the GPU at a small precision cost
        self.gpu_index = faiss.index_cpu_to_gpu(res, 0, self.cpu_index, cloner_options)
        self.gpu_index.nprobe = 32  # IVF cells probed per query; higher = more accurate, slower

    def init_model(self):
        self.model, self.preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
        self.model.eval()
        # Use the tokenizer that matches the checkpoint loaded above
        self.tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP-384')

    def encode_text(self, text, device='cuda'):
        self.model.to(device)
        with torch.no_grad():
            # Tokenise and embed; the embedding is L2-normalised later, in search_with_grid
            tokens = self.tokenizer([text], context_length=self.model.context_length)
            return self.model.encode_text(tokens.to(device))

    def search_with_grid(self, query_vec, k=5):
        # Prepare query
        if isinstance(query_vec, torch.Tensor):
            query_vec = query_vec.cpu().squeeze().numpy()
            
        query_vec = query_vec.reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_vec)
        
        # Search
        distances, indices = self.gpu_index.search(query_vec, k)
        
        # Flatten results
        ids = indices[0]
        scores = distances[0]
        
        results = []
        
        # Batch lookup in pandas (faster than looping)
        # Ids of -1 mean FAISS returned fewer than k results (e.g. k exceeds the index size)
        valid_mask = ids != -1
        valid_ids = ids[valid_mask]
        valid_scores = scores[valid_mask]
        
        if len(valid_ids) > 0:
            # Positional lookup: assumes row i of metadata_df corresponds to FAISS vector id i
            matches = self.metadata_df.iloc[valid_ids].copy()
            matches['score'] = valid_scores
            
            # Convert to list of dicts for easy usage
            results = matches.to_dict(orient='records')
            
        return results

    def search(self, text, k=1): # k - number of neighbours
        # (Renamed from `faiss` to avoid colliding with the imported faiss module)

        # 1. Encode the text query
        q = self.encode_text(text)

        # 2. Retrieve the k nearest hits
        results = self.search_with_grid(q, k=k)

        return results
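

# Minimal usage sketch (the paths below are hypothetical placeholders; assumes
# the FAISS index was built from SigLIP image embeddings stored in the same row
# order as the metadata parquet):
if __name__ == '__main__':
    searcher = SearchSigLIP(
        index_path='embeddings.index',      # hypothetical placeholder
        metadata_path='metadata.parquet',   # hypothetical placeholder
    )
    for hit in searcher.search('a dog playing in the snow', k=5):
        print(f"{hit['score']:.4f}  {hit}")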