amaye15 committed
Commit 4b2d0b0 · 1 Parent(s): fab7fd4

Initial Commit
.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ chat_template.json filter=lfs diff=lfs merge=lfs -text
+ generation_config.json filter=lfs diff=lfs merge=lfs -text
+ preprocessor_config.json filter=lfs diff=lfs merge=lfs -text
+ tokenizer_config.json filter=lfs diff=lfs merge=lfs -text
+ adapter_config.json filter=lfs diff=lfs merge=lfs -text
+ added_tokens.json filter=lfs diff=lfs merge=lfs -text
+ special_tokens_map.json filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ vocab.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ *.DS*
+ *__pycache__*
+ *.pdf
+ *.ipynb
README.md CHANGED
@@ -1,3 +1,117 @@
  ---
  license: mit
  ---
+
+
+ # EndpointHandler
+
+ `EndpointHandler` is a Python class that processes image and text data to generate embeddings and similarity scores using the ColQwen2 model, a visual retriever based on Qwen2-VL-2B-Instruct that uses the ColBERT strategy. The handler is optimized for retrieving documents and visual information by their visual and textual features.
+
+ ## Overview
+
+ - **Efficient Document Retrieval**: Uses the ColQwen2 model to produce embeddings for images and text, enabling accurate document retrieval.
+ - **Multi-vector Representation**: Generates ColBERT-style multi-vector embeddings for improved similarity search.
+ - **Flexible Image Resolution**: Supports dynamic image resolution without altering the aspect ratio, capped at 768 patches for memory efficiency.
+ - **Device Compatibility**: Automatically uses available CUDA devices and falls back to the CPU otherwise.
+
+ ## Model Details
+
+ The **ColQwen2** model extends Qwen2-VL-2B with a focus on vision-language tasks, making it suitable for content indexing and retrieval. Key features include:
+ - **Training**: Pre-trained with a batch size of 256 over 5 epochs, with a modified pad token.
+ - **Input Flexibility**: Handles various image resolutions without resizing, ensuring accurate multi-vector representation.
+ - **Similarity Scoring**: Uses a ColBERT-style late-interaction scoring approach for efficient retrieval across image and text modalities (see the sketch below).
+
+ This base version is untrained, providing deterministic initialization of the projection layer for further customization.
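+
+ The ColBERT-style scoring mentioned above can be sketched in a few lines. This is a minimal illustration of the late-interaction (MaxSim) idea, not the exact `score_multi_vector` implementation shipped with `colpali-engine`:
+
+ ```python
+ import torch
+
+ def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
+     """Late-interaction (MaxSim) score between one query and one document.
+
+     query_emb: (num_query_tokens, dim) multi-vector query embedding
+     doc_emb:   (num_doc_tokens, dim) multi-vector document embedding
+     """
+     sim = query_emb @ doc_emb.T          # token-to-token similarity matrix
+     return sim.max(dim=1).values.sum()   # best doc token per query token, summed
+ ```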
+
+ ## How to Use
+
+ The following example demonstrates how to use `EndpointHandler` to process PDF documents and text. PDF pages are converted to base64-encoded images, which are then passed as input alongside text data to the handler.
+
+ ### Example Script
+
+ ```python
+ from pdf2image import convert_from_path
+ import base64
+ from io import BytesIO
+ import requests
+
+
+ # Function to convert a PIL Image to a base64 string
+ def pil_image_to_base64(image):
+     """Converts a PIL Image to a base64-encoded string."""
+     buffer = BytesIO()
+     image.save(buffer, format="PNG")
+     return base64.b64encode(buffer.getvalue()).decode()
+
+
+ # Function to convert PDF pages to base64 images
+ def convert_pdf_to_base64_images(pdf_path):
+     """Converts PDF pages to base64-encoded images."""
+     pages = convert_from_path(pdf_path)
+     return [pil_image_to_base64(page) for page in pages]
+
+
+ # Function to send a payload to the API and retrieve the response
+ def query_api(payload, api_url, headers):
+     """Sends a POST request to the API and returns the JSON response."""
+     response = requests.post(api_url, headers=headers, json=payload)
+     return response.json()
+
+
+ # Main execution
+ if __name__ == "__main__":
+     # Convert PDF pages to base64-encoded images
+     encoded_images = convert_pdf_to_base64_images("document.pdf")
+
+     # Prepare the payload
+     payload = {
+         "inputs": [],
+         "image": encoded_images,
+         "text": ["example query text"],
+     }
+
+     # API configuration
+     API_URL = "https://your-api-url"
+     headers = {
+         "Accept": "application/json",
+         "Authorization": "Bearer your_access_token",
+         "Content-Type": "application/json",
+     }
+
+     # Query the API and print the output
+     output = query_api(payload=payload, api_url=API_URL, headers=headers)
+     print(output)
+ ```
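+
+ Note that `pdf2image` requires the Poppler utilities to be installed on the system (e.g., `apt-get install poppler-utils`). The empty `"inputs"` field is presumably included to satisfy the default Hugging Face Inference Endpoints payload schema; the handler itself only reads the `"image"`, `"text"`, and optional `"batch_size"` keys.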
+
+ ## Inputs and Outputs
+
+ ### Input Format
+ The `EndpointHandler` expects a dictionary containing:
+ - **image**: A list of base64-encoded strings for images (e.g., PDF pages converted to images).
+ - **text**: A list of text strings representing queries or document contents.
+ - **batch_size** (optional): The batch size for processing images and text. Defaults to `4`.
+
+ Example payload:
+ ```json
+ {
+   "image": ["base64_image_string_1", "base64_image_string_2"],
+   "text": ["sample text 1", "sample text 2"],
+   "batch_size": 4
+ }
+ ```
+
+ ### Output Format
+ The handler returns a dictionary with the following keys:
+ - **image**: List of embeddings for each image.
+ - **text**: List of embeddings for each text entry.
+ - **scores**: List of similarity scores between the image and text embeddings.
+
+ Example output:
+ ```json
+ {
+   "image": [[0.12, 0.34, ...], [0.56, 0.78, ...]],
+   "text": [[0.11, 0.22, ...], [0.33, 0.44, ...]],
+   "scores": [[0.87, 0.45], [0.23, 0.67]]
+ }
+ ```
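+
+ Each row of `scores` corresponds to one text query and each column to one image, so ranking pages for a query amounts to sorting one row. A minimal client-side sketch, reusing the `output` returned by the example script above:
+
+ ```python
+ # Rank images (e.g., PDF pages) for the first query, best match first
+ query_scores = output["scores"][0]
+ ranked = sorted(range(len(query_scores)), key=lambda i: query_scores[i], reverse=True)
+ print(ranked)  # with the example scores above: [0, 1]
+ ```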
114
+
115
+ ### Error Handling
116
+ If any issues occur during processing (e.g., decoding images or model inference), the handler logs the error and returns an error message in the output dictionary.
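+
+ Client code can therefore branch on the presence of an `"error"` key before using the embeddings. A minimal check, reusing `query_api` from the example script:
+
+ ```python
+ output = query_api(payload=payload, api_url=API_URL, headers=headers)
+ if "error" in output:
+     raise RuntimeError(f"Endpoint error: {output['error']}")
+ scores = output["scores"]
+ ```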
adapter_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c88fb289a155188a09737629830dc32e753bb679d6bddd5f94ddf9daa1921114
+ size 727
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fc856312174dc99a4c7f88a2c54d9590a3b3f5b5a86e2728d7138c7f4758c6d5
+ size 74018232
added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7fa54985b58718a8fdb4f4d97484c4bd908db114847675e4bf3afe3e1d5d7bd4
+ size 392
chat_template.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:94174d7176c52a7192f96fc34eb2cf23c7c2059d63cdbfadca1586ba89731fb7
+ size 1049
generation_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f31bc5c808ee15908986654279dd054f3e6bd65d52f8ca7b18a2a80552e2d35b
+ size 215
handler.py ADDED
@@ -0,0 +1,168 @@
+ import torch
+ from typing import Dict, Any, List
+ from PIL import Image
+ import base64
+ from io import BytesIO
+ import logging
+
+
+ class EndpointHandler:
+     """
+     A handler class for processing image and text data, generating embeddings using a specified model and processor.
+
+     Attributes:
+         model: The pre-trained model used for generating embeddings.
+         processor: The pre-trained processor used to process images and text before model inference.
+         device: The device (CPU or CUDA) used to run model inference.
+         default_batch_size: The default batch size for processing images and text in batches.
+     """
+
+     def __init__(self, path: str = "", default_batch_size: int = 4):
+         """
+         Initializes the EndpointHandler with a specified model path and default batch size.
+
+         Args:
+             path (str): Path to the pre-trained model and processor.
+             default_batch_size (int): Default batch size for processing images and text data.
+         """
+         # Initialize logging
+         logging.basicConfig(level=logging.INFO)
+         self.logger = logging.getLogger(__name__)
+
+         from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+         self.logger.info("Initializing model and processor.")
+         try:
+             self.model = ColQwen2.from_pretrained(
+                 path,
+                 torch_dtype=torch.bfloat16,
+                 device_map="auto",
+             ).eval()
+             self.processor = ColQwen2Processor.from_pretrained(path)
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+             self.model.to(self.device)
+             self.default_batch_size = default_batch_size
+             self.logger.info("Initialization complete.")
+         except Exception as e:
+             self.logger.error(f"Failed to initialize model or processor: {e}")
+             raise
+
+     def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
+         """
+         Processes a batch of images and generates embeddings.
+
+         Args:
+             images (List[Image.Image]): List of images to process.
+
+         Returns:
+             List[List[float]]: List of embeddings for each image.
+         """
+         self.logger.debug(f"Processing batch of {len(images)} images.")
+         try:
+             batch_images = self.processor.process_images(images).to(self.device)
+             with torch.no_grad():
+                 image_embeddings = self.model(**batch_images)
+             self.logger.debug("Image batch processing complete.")
+             return image_embeddings.cpu().tolist()
+         except Exception as e:
+             self.logger.error(f"Error processing image batch: {e}")
+             raise
+
+     def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
+         """
+         Processes a batch of text queries and generates embeddings.
+
+         Args:
+             texts (List[str]): List of text queries to process.
+
+         Returns:
+             List[List[float]]: List of embeddings for each text query.
+         """
+         self.logger.debug(f"Processing batch of {len(texts)} text queries.")
+         try:
+             batch_queries = self.processor.process_queries(texts).to(self.device)
+             with torch.no_grad():
+                 query_embeddings = self.model(**batch_queries)
+             self.logger.debug("Text batch processing complete.")
+             return query_embeddings.cpu().tolist()
+         except Exception as e:
+             self.logger.error(f"Error processing text batch: {e}")
+             raise
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
+
+         Args:
+             data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
+
+         Returns:
+             Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
+         """
+         images_data = data.get("image", [])
+         text_data = data.get("text", [])
+         batch_size = data.get("batch_size", self.default_batch_size)
+
+         # Decode base64-encoded images into PIL images
+         images = []
+         if images_data:
+             self.logger.info("Decoding images from base64.")
+             for img_data in images_data:
+                 if isinstance(img_data, str):
+                     try:
+                         image_bytes = base64.b64decode(img_data)
+                         image = Image.open(BytesIO(image_bytes)).convert("RGB")
+                         images.append(image)
+                     except Exception as e:
+                         self.logger.error(f"Invalid image data: {e}")
+                         return {"error": f"Invalid image data: {e}"}
+                 else:
+                     self.logger.error("Images should be base64-encoded strings.")
+                     return {"error": "Images should be base64-encoded strings."}
+
+         # Generate image embeddings in batches
+         image_embeddings = []
+         if images:
+             self.logger.info("Processing image embeddings.")
+             try:
+                 for i in range(0, len(images), batch_size):
+                     batch_images = images[i : i + batch_size]
+                     batch_embeddings = self._process_image_batch(batch_images)
+                     image_embeddings.extend(batch_embeddings)
+             except Exception as e:
+                 self.logger.error(f"Error generating image embeddings: {e}")
+                 return {"error": f"Error generating image embeddings: {e}"}
+
+         # Generate text embeddings in batches
+         text_embeddings = []
+         if text_data:
+             self.logger.info("Processing text embeddings.")
+             try:
+                 for i in range(0, len(text_data), batch_size):
+                     batch_texts = text_data[i : i + batch_size]
+                     batch_text_embeddings = self._process_text_batch(batch_texts)
+                     text_embeddings.extend(batch_text_embeddings)
+             except Exception as e:
+                 self.logger.error(f"Error generating text embeddings: {e}")
+                 return {"error": f"Error generating text embeddings: {e}"}
+
+         # Compute similarity scores if both image and text embeddings are available
+         scores = []
+         if image_embeddings and text_embeddings:
+             self.logger.info("Computing similarity scores.")
+             try:
+                 image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
+                 text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
+                 with torch.no_grad():
+                     scores = (
+                         self.processor.score_multi_vector(
+                             text_embeddings_tensor, image_embeddings_tensor
+                         )
+                         .cpu()
+                         .tolist()
+                     )
+                 self.logger.info("Similarity scoring complete.")
+             except Exception as e:
+                 self.logger.error(f"Error computing similarity scores: {e}")
+                 return {"error": f"Error computing similarity scores: {e}"}
+
+         return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
merges.txt ADDED
The diff for this file is too large to render.
preprocessor_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5dd5968b65af7e090e399f39ae94734e400d9d71a3c82fca2720c5ee514034f3
+ size 568
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ colpali-engine==0.3.3
+ pdf2image
+ GPUtil
+ accelerate==0.30.1
special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:76862e765266b85aa9459767e33cbaf13970f327a0e88d1c65846c2ddd3a1ecd
+ size 613
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:091aa7594dc2fcfbfa06b9e3c22a5f0562ac14f30375c13af7309407a0e67b8a
+ size 11420371
tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:955409fb4dab09a71b957ce69f8a8185bbbd3416b9ab5a47e01221545be39c6f
+ size 4298
vocab.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910
+ size 2776833