Spaces: Running on Zero
bug fixes and improvement
- app.py +96 -17
- requirements.txt +4 -2
- utils.py +7 -9
app.py
CHANGED

@@ -8,11 +8,14 @@ import ocrmypdf
 import os
 import pandas as pd
 import pymupdf
+from pypdf import PdfReader
 import spaces
 import torch
 from PIL import Image
 from chromadb.utils import embedding_functions
 from chromadb.utils.data_loaders import ImageLoader
+from doctr.io import DocumentFile
+from doctr.models import ocr_predictor
 from gradio.themes.utils import sizes
 from langchain import PromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter

@@ -22,6 +25,29 @@ from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
 from utils import *


+def result_to_text(result, as_text=False) -> str or list:
+    full_doc = []
+    for _, page in enumerate(result.pages, start=1):
+        text = ""
+        for block in page.blocks:
+            text += "\n\t"
+            for line in block.lines:
+                for word in line.words:
+                    text += word.value + " "
+
+        full_doc.append(clean_text(text) + "\n\n")
+
+    return "\n".join(full_doc) if as_text else full_doc
+
+
+ocr_model = ocr_predictor(
+    "db_resnet50",
+    "crnn_mobilenet_v3_large",
+    pretrained=True,
+    assume_straight_pages=True,
+)
+
+
 if torch.cuda.is_available():
     processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
     vision_model = LlavaNextForConditionalGeneration.from_pretrained(

@@ -75,7 +101,6 @@ def get_vectordb(text, images):
         metadata={"hnsw:space": "cosine"},
     )
     descs = []
-    print(descs)
     for image in images:
         try:
             descs.append(get_image_description(image)[0])

@@ -97,7 +122,9 @@ def get_vectordb(text, images):
         chunk_overlap=10,
     )

-    if len(text)
+    if len(text.replace(" ", "").replace("\n", "")) == 0:
+        gr.Error("No text found in documents")
+    else:
         docs = splitter.create_documents([text])
         doc_texts = [i.page_content for i in docs]
         text_collection.add(

@@ -106,7 +133,16 @@ def get_vectordb(text, images):
     return client


-def extract_data_from_pdfs(docs, session, include_images, progress=gr.Progress()):
+def extract_only_text(reader):
+    text = ""
+    for _, page in enumerate(reader.pages):
+        text = page.extract_text()
+    return text.strip()
+
+
+def extract_data_from_pdfs(
+    docs, session, include_images, do_ocr, progress=gr.Progress()
+):
     if len(docs) == 0:
         raise gr.Error("No documents to process")
     progress(0, "Extracting Images")

@@ -115,18 +151,20 @@ def extract_data_from_pdfs(docs, session, include_images, progress=gr.Progress()

     progress(0.25, "Extracting Text")

-    strategy = "hi_res"
-    model_name = "yolox"
-    all_elements = []
     all_text = ""

     images = []
     for doc in docs:
-
-
-
+        if do_ocr == "Get Text With OCR":
+            pdf_doc = DocumentFile.from_pdf(doc)
+            result = ocr_model(pdf_doc)
+            all_text += result_to_text(result, as_text=True) + "\n\n"
+        else:
+            reader = PdfReader(doc)
+            all_text += extract_only_text(reader) + "\n\n"
+
         if include_images == "Include Images":
-            images.extend(extract_images([
+            images.extend(extract_images([doc]))

     progress(
         0.6, "Generating image descriptions and inserting everything into vectorDB"

@@ -153,20 +191,28 @@ sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFuncti


 def conversation(
-    vectordb_client,
+    vectordb_client,
+    msg,
+    num_context,
+    img_context,
+    history,
+    temperature,
+    max_new_tokens,
+    hf_token,
+    model_path,
 ):
     if hf_token.strip() != "" and model_path.strip() != "":
         llm = HuggingFaceEndpoint(
             repo_id=model_path,
-            temperature=
-            max_new_tokens=
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
             huggingfacehub_api_token=hf_token,
         )
     else:
         llm = HuggingFaceEndpoint(
             repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
-            temperature=
-            max_new_tokens=
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
             huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
         )


@@ -273,6 +319,12 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
                 label="Include/ Exclude Images",
                 interactive=True,
             )
+            do_ocr = gr.Radio(
+                ["Get Text With OCR", "Get Available Text Only"],
+                value="Get Text With OCR",
+                label="OCR/ No OCR",
+                interactive=True,
+            )

         with gr.Row(equal_height=True, variant="panel") as row:
             selected = gr.Dataframe(

@@ -327,6 +379,23 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
                 interactive=True,
                 value=2,
             )
+            with gr.Row(variant="panel", equal_height=True):
+                temp = gr.Slider(
+                    label="Temperature",
+                    minimum=0.1,
+                    maximum=1,
+                    step=0.1,
+                    interactive=True,
+                    value=0.4,
+                )
+                max_tokens = gr.Slider(
+                    label="Max Tokens",
+                    minimum=10,
+                    maximum=2000,
+                    step=10,
+                    interactive=True,
+                    value=500,
+                )
         with gr.Row():
             with gr.Column():
                 ret_images = gr.Gallery("Similar Images", columns=1, rows=2)

@@ -361,7 +430,7 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
     )
     embed.click(
         extract_data_from_pdfs,
-        inputs=[doc_collection, session_states, include_images],
+        inputs=[doc_collection, session_states, include_images, do_ocr],
         outputs=[
             vectordb,
             session_states,

@@ -374,7 +443,17 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:

     submit_btn.click(
         conversation,
-        [
+        [
+            vectordb,
+            msg,
+            num_context,
+            img_context,
+            chatbot,
+            temp,
+            max_tokens,
+            hf_token,
+            model_path,
+        ],
        [chatbot, references, ret_images],
     )
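For reference, the two text-extraction paths introduced above can be exercised outside the Gradio app. This is a minimal sketch, not the Space's own code: the helper names ocr_text/embedded_text and the sample.pdf path are illustrative, and it assumes python-doctr[torch] and pypdf are installed.

from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from pypdf import PdfReader

# Same detection/recognition pair the Space loads at import time.
ocr_model = ocr_predictor(
    "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
)

def ocr_text(path):
    # "Get Text With OCR" path: rasterize the PDF and read recognized words page by page.
    result = ocr_model(DocumentFile.from_pdf(path))
    pages = []
    for page in result.pages:
        words = [w.value for b in page.blocks for l in b.lines for w in l.words]
        pages.append(" ".join(words))
    return "\n\n".join(pages)

def embedded_text(path):
    # "Get Available Text Only" path: pull the existing text layer with pypdf.
    return "\n".join(page.extract_text() or "" for page in PdfReader(path).pages)

print(ocr_text("sample.pdf")[:300])       # "sample.pdf" is a placeholder file
print(embedded_text("sample.pdf")[:300])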
requirements.txt
CHANGED

@@ -7,8 +7,10 @@ pandas==2.2.2
 Pillow==10.3.0
 pymupdf==1.24.5
 sentence_transformers==3.0.1
-unstructured[all-docs]
 accelerate
 bitsandbytes
 easyocr
-ocrmypdf
+ocrmypdf
+tf2onnx
+clean-text[gpl]
+python-doctr[torch]
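The new clean-text[gpl] dependency presumably backs the clean_text() helper that result_to_text() calls in app.py; the helper itself lives in utils.py and is not part of this diff. A hypothetical wrapper, assuming the cleantext package API:

from cleantext import clean  # installed via clean-text[gpl]

def clean_text(text):
    # Illustrative only: normalize unicode and whitespace, keep case and line breaks.
    return clean(text, fix_unicode=True, to_ascii=False, lower=False, no_line_breaks=False)

print(clean_text("R\u00e9sum\u00e9\u00a0 with  odd   spacing"))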
utils.py
CHANGED

@@ -27,19 +27,17 @@ def extract_pdfs(docs, doc_collection):
 def extract_images(docs):
     images = []
     for doc_path in docs:
-        doc = pymupdf.open(doc_path)
+        doc = pymupdf.open(doc_path)

-        for page_index in range(len(doc)):
-            page = doc[page_index]
+        for page_index in range(len(doc)):
+            page = doc[page_index]
             image_list = page.get_images()

-            for
-
-
-                xref = img[0]  # get the XREF of the image
-                pix = pymupdf.Pixmap(doc, xref)  # create a Pixmap
+            for _, img in enumerate(image_list, start=1):
+                xref = img[0]
+                pix = pymupdf.Pixmap(doc, xref)

-            if pix.n - pix.alpha > 3:
+                if pix.n - pix.alpha > 3:
                     pix = pymupdf.Pixmap(pymupdf.csRGB, pix)

                 images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
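A quick way to sanity-check the rewritten extract_images() helper; a sketch assuming utils.py is importable from the working directory, with "sample.pdf" and the output filenames as placeholders:

from utils import extract_images

images = extract_images(["sample.pdf"])  # returns a list of PIL.Image objects
print(f"extracted {len(images)} embedded image(s)")
for i, img in enumerate(images):
    img.save(f"extracted_{i}.jpg")  # filenames are illustrative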