Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import pdf2image | |
| from transformers import ( | |
| LayoutLMv3FeatureExtractor, | |
| LayoutLMv3TokenizerFast, | |
| LayoutLMv3Processor, | |
| LayoutLMv3ForSequenceClassification | |
| ) | |
| import torch | |
| import logging | |
| import traceback | |
| import os | |
| import tempfile | |
# Configure the root logger for INFO-and-above output, then create the
# module-level logger that every function and method in this file writes to.
logging.basicConfig(
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
class DocumentClassifier:
    """Classify a single document page with a fine-tuned LayoutLMv3 model.

    On construction the class loads the ``microsoft/layoutlmv3-base``
    tokenizer and the ``RavindiG/layoutlmv3-document-classification-v2``
    sequence-classification checkpoint. A page image is then mapped to one
    of the labels in ``id2label``, or to ``"Other"`` when the strongest
    logit does not reach ``confidence_threshold``.
    """

    # Default acceptance threshold. Despite the name it is compared against
    # the largest *raw logit*, not against a softmax probability.
    confidence_threshold = 1.0

    def __init__(self, confidence_threshold=1.0):
        """Load the processor and the classification model.

        Args:
            confidence_threshold: Minimum value the largest raw logit must
                reach for a prediction to be accepted; anything lower is
                reported as ``"Other"``.

        Raises:
            Exception: Any error raised while loading the tokenizer or the
                model is logged and re-raised unchanged.
        """
        self.confidence_threshold = confidence_threshold
        try:
            # The processor is later called with the image alone, so words and
            # boxes have to come from the feature extractor itself.
            # NOTE(review): newer transformers releases prefer
            # LayoutLMv3ImageProcessor over this class - confirm against the
            # pinned transformers version before switching.
            self.feature_extractor = LayoutLMv3FeatureExtractor()
            self.tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
                "microsoft/layoutlmv3-base"
            )
            self.processor = LayoutLMv3Processor(
                self.feature_extractor, self.tokenizer
            )
            self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
                "RavindiG/layoutlmv3-document-classification-v2"
            )
            # Inference only: switch off dropout and similar training layers.
            self.model.eval()
            self.id2label = {
                0: 'Financial Report',
                1: 'Invoice or Receipt',
                2: 'Legal Document',
                3: 'Medical Record',
                4: 'Research Paper',
            }
            logger.info("DocumentClassifier initialized successfully.")
        except Exception as e:
            logger.exception(f"Error during initialization: {e}")
            raise

    def predict_document_class(self, image):
        """Run the model on one page image.

        Args:
            image: PIL image of the page. It is converted to RGB when it is
                in any other mode.

        Returns:
            tuple: ``(predicted_label, confidence_dict)``. When the largest
            logit is below ``confidence_threshold`` the result is
            ``("Other", {"Other": 1.0})``; otherwise ``confidence_dict`` maps
            every label in ``id2label`` to its softmax probability.

        Raises:
            Exception: Any processing or inference error is logged and
                re-raised.
        """
        try:
            logger.info(f"Image size: {image.size}, mode: {image.mode}")
            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            logger.info("Processing image...")
            # max_length only takes effect together with truncation; request
            # it explicitly instead of relying on the tokenizer's implicit
            # backward-compatibility fallback.
            encoded_inputs = self.processor(
                image,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            # Move every input tensor to wherever the model weights live.
            device = next(self.model.parameters()).device
            model_inputs = {
                key: tensor.to(device) for key, tensor in encoded_inputs.items()
            }

            logger.info("Running model inference...")
            with torch.no_grad():
                outputs = self.model(**model_inputs)
            logits = outputs.logits
            logger.info(f"Logits: {logits}")

            if logits.max().item() < self.confidence_threshold:
                logger.info("Max logit below threshold, returning 'Other'")
                predicted_label = "Other"
                confidence_dict = {"Other": 1.0}
            else:
                # Probabilities of the first (only) item in the batch.
                probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
                predicted_class_idx = logits.argmax(-1).item()
                predicted_label = self.id2label.get(predicted_class_idx, "Other")
                confidence_dict = {
                    label: probabilities[class_idx].item()
                    for class_idx, label in self.id2label.items()
                }

            logger.info(f"Predicted label: {predicted_label}")
            return predicted_label, confidence_dict
        except Exception as e:
            logger.exception(f"Error in predict_document_class: {e}")
            raise

    def classify_document(self, file_path):
        """Classify the document stored at ``file_path``.

        PDFs are rasterised (first page only, 120 dpi); any other file is
        opened as an image.

        Args:
            file_path: Path of the uploaded file as a string.

        Returns:
            tuple: ``(doc_type, confidence, image)`` on success. This method
            never raises: on any failure the error is logged and
            ``("Error", {"Error": 1.0}, None)`` is returned instead.
        """
        try:
            logger.info(f"Processing file: {file_path}")
            if file_path.lower().endswith('.pdf'):
                logger.info("Converting PDF to image...")
                # fmt selects the intermediate file format (ppm/png/jpeg/tiff),
                # not a colour mode; RGB conversion happens at prediction time.
                images = pdf2image.convert_from_path(
                    file_path,
                    dpi=120,
                    fmt="ppm",
                    first_page=1,
                    last_page=1,
                )
                if not images:
                    raise ValueError("No images extracted from PDF")
                image = images[0]
                logger.info(f"PDF converted successfully, image size: {image.size}")
            else:
                logger.info("Opening image file...")
                # Image.open is lazy and keeps the file open; read the pixel
                # data now so the handle is released when the block exits.
                with Image.open(file_path) as image:
                    image.load()
                logger.info(f"Image opened successfully, size: {image.size}")

            doc_type, confidence = self.predict_document_class(image)
            logger.info(f"Classification successful: {doc_type}")
            return doc_type, confidence, image
        except Exception as e:
            # Deliberate best-effort boundary: the UI shows an "Error" result
            # instead of crashing the request.
            logger.exception(f"Error classifying document {file_path}: {e}")
            return "Error", {"Error": 1.0}, None
# Build the single shared classifier at import time so the Gradio callback
# below can reuse it on every request.
logger.info("Initializing DocumentClassifier...")
classifier = DocumentClassifier()
logger.info("Classifier ready!")
def classify_upload(file):
    """Gradio callback: classify the uploaded file.

    Args:
        file: Value delivered by the file input. ``None`` means nothing is
            uploaded; otherwise either a filesystem path given as ``str`` or
            an object that exposes the path through its ``name`` attribute.

    Returns:
        tuple: ``(document type text, {label: score}, preview image or None)``.
        With no upload the result is ``("No file uploaded", {}, None)``; an
        unexpected failure yields ``("Error: <message>", {}, None)``.
    """
    if file is None:
        return "No file uploaded", {}, None

    try:
        # The upload may arrive as a plain path string or as a file-like
        # wrapper; accept both.
        file_path = file if isinstance(file, str) else file.name
        doc_type, confidence, image = classifier.classify_document(file_path)
        return doc_type, confidence, image
    except Exception as e:
        logger.exception(f"Error in classify_upload: {e}")
        return f"Error: {e}", {}, None
# Build the Gradio UI: an upload column on the left, results on the right.
with gr.Blocks(title="Document Classifier") as demo:
    # Assembled line by line so the blank lines Markdown needs between the
    # heading, the paragraph, the list and the note are explicit.
    gr.Markdown(
        "\n".join(
            (
                "# 📄 Document Type Classifier",
                "",
                "Upload a document image or PDF to classify it into one of "
                "the following categories:",
                "",
                "- Financial Report",
                "- Invoice or Receipt",
                "- Legal Document",
                "- Medical Record",
                "- Research Paper",
                "",
                "**Note:** Only the first page of PDFs will be analyzed.",
            )
        )
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Upload Document (Image or PDF)",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
            )
            classify_btn = gr.Button("Classify Document", variant="primary")
        with gr.Column():
            label_output = gr.Textbox(label="Document Type", lines=2)
            confidence_output = gr.Label(
                label="Confidence Scores", num_top_classes=5
            )
            image_output = gr.Image(label="Document Preview", type="pil")

    # Examples
    gr.Markdown("### Examples")
    gr.Markdown("Upload your document using the file uploader above.")

    # Run the same callback when the button is clicked and whenever the
    # uploaded file changes.
    for _register in (classify_btn.click, file_input.change):
        _register(
            fn=classify_upload,
            inputs=[file_input],
            outputs=[label_output, confidence_output, image_output],
        )

# Launch the app only when executed as a script.
if __name__ == "__main__":
    demo.launch()