# Ravindi Gunarathna
# Initial commit
# 0efe8b2
import gradio as gr
from PIL import Image
import pdf2image
from transformers import (
LayoutLMv3FeatureExtractor,
LayoutLMv3TokenizerFast,
LayoutLMv3Processor,
LayoutLMv3ForSequenceClassification
)
import torch
import logging
import traceback
import os
import tempfile
# Configure the root logger at INFO level and create the logger used by
# every function and class in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DocumentClassifier:
    """Classify a document page into one of five document types.

    The classifier wraps a LayoutLMv3 sequence-classification model together
    with the processor that turns a page image into model inputs.

    Attributes set by ``__init__``:
        feature_extractor, tokenizer, processor: LayoutLMv3 preprocessing
            objects.
        model: the fine-tuned classification model, in eval mode.
        id2label: mapping from class index to human-readable label.
        confidence_threshold: minimum value of the largest raw logit required
            to accept a prediction; below it the result is ``"Other"``.
    """

    def __init__(self, confidence_threshold=1.0):
        """Load the processor and model.

        Args:
            confidence_threshold: threshold compared against the largest raw
                logit (not a probability). Defaults to 1.0, the value that was
                previously used when the attribute was absent.

        Raises:
            Exception: any error raised while loading is logged and re-raised.
        """
        try:
            self.confidence_threshold = confidence_threshold
            self.feature_extractor = LayoutLMv3FeatureExtractor()
            self.tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
                "microsoft/layoutlmv3-base"
            )
            self.processor = LayoutLMv3Processor(
                self.feature_extractor, self.tokenizer
            )
            self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
                "RavindiG/layoutlmv3-document-classification-v2"
            )
            self.model.eval()
            self.id2label = {
                0: 'Financial Report',
                1: 'Invoice or Receipt',
                2: 'Legal Document',
                3: 'Medical Record',
                4: 'Research Paper',
            }
            logger.info("DocumentClassifier initialized successfully.")
        except Exception as e:
            logger.exception(f"Error during initialization: {e}")
            raise

    def predict_document_class(self, image):
        """Run the model on a single page image.

        Args:
            image: a PIL image; it is converted to RGB if it is in another
                mode.

        Returns:
            A ``(label, confidences)`` tuple. ``confidences`` maps every label
            in ``id2label`` to its softmax probability, or is
            ``{"Other": 1.0}`` when the largest logit is below
            ``confidence_threshold``.

        Raises:
            Exception: any preprocessing or inference error is logged and
                re-raised.
        """
        try:
            logger.info(f"Image size: {image.size}, mode: {image.mode}")
            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            logger.info("Processing image...")
            # truncation=True is required for max_length to be enforced;
            # without it a text-heavy page can yield more than 512 tokens and
            # fail inside the model.
            encoded_inputs = self.processor(
                image,
                max_length=512,
                truncation=True,
                return_tensors="pt",
            )

            # Move every input tensor to whichever device the model lives on.
            device = next(self.model.parameters()).device
            model_inputs = {
                name: tensor.to(device)
                for name, tensor in encoded_inputs.items()
            }

            logger.info("Running model inference...")
            with torch.no_grad():
                outputs = self.model(**model_inputs)
            logits = outputs.logits
            logger.info(f"Logits: {logits}")

            # The rejection test uses the raw logit, not the probability.
            threshold = getattr(self, "confidence_threshold", 1.0)
            max_logit = logits.max().item()

            if max_logit < threshold:
                logger.info("Max logit below threshold, returning 'Other'")
                predicted_label = "Other"
                confidence_dict = {"Other": 1.0}
            else:
                # Probabilities for the first (only) item in the batch.
                probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
                predicted_class_idx = logits.argmax(-1).item()
                predicted_label = self.id2label.get(predicted_class_idx, "Other")
                confidence_dict = {
                    label: probabilities[idx].item()
                    for idx, label in self.id2label.items()
                }

            logger.info(f"Predicted label: {predicted_label}")
            return predicted_label, confidence_dict
        except Exception as e:
            logger.exception(f"Error in predict_document_class: {e}")
            raise

    @staticmethod
    def _load_image(file_path):
        """Return the page to classify: first PDF page, or the image file."""
        if file_path.lower().endswith('.pdf'):
            logger.info("Converting PDF to image...")
            images = pdf2image.convert_from_path(
                file_path,
                dpi=120,
                # "ppm" is pdf2image's default output format; the previous
                # value "RGB" is not a format name.
                fmt="ppm",
                first_page=1,
                last_page=1,
            )
            if not images:
                raise ValueError("No images extracted from PDF")
            image = images[0]
            logger.info(f"PDF converted successfully, image size: {image.size}")
            return image

        logger.info("Opening image file...")
        # Image.open is lazy and keeps the file open; copy the pixel data and
        # close the file before returning.
        with Image.open(file_path) as opened:
            image = opened.copy()
        logger.info(f"Image opened successfully, size: {image.size}")
        return image

    def classify_document(self, file_path):
        """Classify the document stored at ``file_path``.

        Paths ending in ``.pdf`` (case-insensitive) are rendered to an image
        of their first page; anything else is opened as an image file.

        Args:
            file_path: path to a PDF or image file.

        Returns:
            A ``(doc_type, confidences, image)`` tuple. On any failure the
            error is logged and ``("Error", {"Error": 1.0}, None)`` is
            returned instead of raising.
        """
        try:
            logger.info(f"Processing file: {file_path}")
            image = self._load_image(file_path)
            doc_type, confidence = self.predict_document_class(image)
            logger.info(f"Classification successful: {doc_type}")
            return doc_type, confidence, image
        except Exception as e:
            logger.exception(f"Error classifying document {file_path}: {e}")
            return "Error", {"Error": 1.0}, None
# Initialize the classifier
# A single module-level DocumentClassifier is constructed as soon as this
# file is executed; the `classifier` name is what classify_upload uses.
logger.info("Initializing DocumentClassifier...")
classifier = DocumentClassifier()
logger.info("Classifier ready!")
def classify_upload(file):
    """Gradio callback: classify an uploaded file.

    Args:
        file: the value delivered by the file input. May be ``None`` (nothing
            uploaded), a path string, or an object exposing the path as
            ``.name``.

    Returns:
        A ``(document type text, confidence dict, preview image)`` tuple
        matching the three output components. When no file is given the
        text is ``"No file uploaded"``; on an unexpected error it is
        ``"Error: <message>"``. In both cases the dict is empty and the image
        is ``None``.
    """
    if file is None:
        return "No file uploaded", {}, None
    try:
        # Accept both a plain path string and a file-like wrapper with .name,
        # so the callback does not depend on how the upload is delivered.
        file_path = file if isinstance(file, str) else file.name
        doc_type, confidence, image = classifier.classify_document(file_path)
        return doc_type, confidence, image
    except Exception as e:
        logger.exception(f"Error in classify_upload: {e}")
        return f"Error: {str(e)}", {}, None
# Create Gradio interface
with gr.Blocks(title="Document Classifier") as demo:
    # Intro text. Built from explicit lines so the rendered markdown does not
    # depend on source indentation; the blank lines keep the heading, the
    # list and the note as separate markdown blocks.
    gr.Markdown(
        "# 📄 Document Type Classifier\n"
        "\n"
        "Upload a document image or PDF to classify it into one of the "
        "following categories:\n"
        "\n"
        "- Financial Report\n"
        "- Invoice or Receipt\n"
        "- Legal Document\n"
        "- Medical Record\n"
        "- Research Paper\n"
        "\n"
        "**Note:** Only the first page of PDFs will be analyzed.\n"
    )

    with gr.Row():
        # Left column: upload control and the explicit trigger button.
        with gr.Column():
            file_input = gr.File(
                label="Upload Document (Image or PDF)",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
            )
            classify_btn = gr.Button("Classify Document", variant="primary")

        # Right column: the three outputs filled by classify_upload.
        with gr.Column():
            label_output = gr.Textbox(label="Document Type", lines=2)
            confidence_output = gr.Label(
                label="Confidence Scores", num_top_classes=5
            )
            image_output = gr.Image(label="Document Preview", type="pil")

    # Examples
    gr.Markdown("### Examples")
    gr.Markdown("Upload your document using the file uploader above.")

    # Run the same callback when the button is clicked and whenever the
    # uploaded file changes.
    result_components = [label_output, confidence_output, image_output]
    for register in (classify_btn.click, file_input.change):
        register(
            fn=classify_upload,
            inputs=[file_input],
            outputs=result_components,
        )

# Launch the app
if __name__ == "__main__":
    demo.launch()