| """Gradio interface for DeepSeek-OCR on Hugging Face Spaces. | |
| This application loads the `deepseek-ai/DeepSeek-OCR` vision-language model | |
| and exposes a simple interface capable of processing both image and PDF | |
| documents. The implementation targets the Hugging Face free T4 GPU runtime and | |
| optimizes throughput with bfloat16 precision, flash-attention, and optional | |
| vLLM acceleration when available. | |
| """ | |
| from __future__ import annotations | |
| import contextlib | |
| import dataclasses | |
| import logging | |
| import os | |
| import shutil | |
| import subprocess | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
| from typing import List, Optional | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModel, AutoTokenizer | |
| try: # Optional dependency for faster batching | |
| from vllm import LLM, SamplingParams # type: ignore | |
| _HAS_VLLM = True | |
| except Exception: # pragma: no cover - optional path | |
| LLM = None # type: ignore | |
| SamplingParams = None # type: ignore | |
| _HAS_VLLM = False | |
| try: | |
| import fitz # type: ignore[attr-defined] | |
| except Exception as exc: # pragma: no cover - ensures import error is visible | |
| raise RuntimeError( | |
| "PyMuPDF (fitz) is required for PDF processing. Install pymupdf." | |
| ) from exc | |
| logging.basicConfig(level=logging.INFO) | |
| LOGGER = logging.getLogger("deepseek_ocr_app") | |
| MODEL_NAME = "deepseek-ai/DeepSeek-OCR" | |
| DEFAULT_PROMPT = "<image>\n<|grounding|>Convert the document to markdown." | |
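# The settings below follow the "Gundam" dynamic-resolution preset described in the
# DeepSeek-OCR model card (1024-pixel base view with 640-pixel crops). Treat the exact
# values as an assumption if the upstream defaults change.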
GUNDAM_CONFIG = {
    "base_size": 1024,
    "image_size": 640,
    "crop_mode": True,
    "test_compress": True,
}


@dataclasses.dataclass
class PageResult:
    """Result for a single page processed by DeepSeek-OCR."""

    index: int
    text: str


@dataclasses.dataclass
class DocumentResult:
    """Aggregate OCR result for an input document."""

    filename: str
    page_results: List[PageResult]

    def to_markdown(self) -> str:
        sections = []
        for page in self.page_results:
            heading = f"### Page {page.index}"
            sections.append(f"{heading}\n\n{page.text.strip()}".strip())
        return "\n\n".join(sections).strip()


def has_cuda() -> bool:
    return torch.cuda.is_available()


class DeepSeekOCREngine:
    """Wrapper around the DeepSeek-OCR model for document processing."""

    def __init__(
        self,
        model_name: str = MODEL_NAME,
        prompt: str = DEFAULT_PROMPT,
        config: Optional[dict] = None,
        enable_vllm: bool = True,
    ) -> None:
        self.model_name = model_name
        self.prompt_template = prompt
        self.config = {**GUNDAM_CONFIG, **(config or {})}
        self.enable_vllm = enable_vllm and _HAS_VLLM
        self.device = torch.device("cuda" if has_cuda() else "cpu")
        self._model = None
        self._tokenizer = None
        self._vllm_engine = None
        self._vllm_sampling_params = None
        self._output_root = Path(tempfile.mkdtemp(prefix="deepseek_ocr_out_"))
        self._load_model()

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer not initialized")
        return self._tokenizer

    @property
    def model(self):
        if self._model is None:
            raise RuntimeError("Model not initialized")
        return self._model

    def _load_model(self) -> None:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        # Try to install flash-attn if it is not available and we're on CUDA
        if self.device.type == "cuda":
            self._ensure_flash_attention()
        if self.enable_vllm:
            try:
                LOGGER.info("Initializing DeepSeek-OCR with vLLM backend")
                self._vllm_engine = LLM(
                    model=self.model_name,
                    dtype="bfloat16" if has_cuda() else "float32",
                    tokenizer=self.model_name,
                    trust_remote_code=True,
                )
                self._vllm_sampling_params = SamplingParams(
                    temperature=0.0,
                    top_p=0.9,
                    max_tokens=4096,
                )
            except Exception as vllm_error:
                LOGGER.warning(
                    "vLLM initialization failed (%s). Falling back to HF AutoModel.",
                    vllm_error,
                )
                self.enable_vllm = False
        if not self.enable_vllm:
            LOGGER.info("Loading DeepSeek-OCR with transformers backend")
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, trust_remote_code=True
            )
            torch_dtype = torch.bfloat16 if self.device.type == "cuda" else torch.float32
            # Try flash attention first, fall back to standard attention
            attn_implementation = "flash_attention_2" if self._has_flash_attention() else "eager"
            if attn_implementation == "flash_attention_2":
                LOGGER.info("Using flash attention for faster inference")
            else:
                LOGGER.info("Using standard attention (flash attention not available)")
            self._model = AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_safetensors=True,
                _attn_implementation=attn_implementation,
                torch_dtype=torch_dtype,
            )
            self._model = self._model.eval().to(self.device)

    def _ensure_flash_attention(self) -> None:
        """Ensure flash-attn is installed for CUDA devices."""
        if not self._has_flash_attention():
            LOGGER.info("Installing flash-attn for optimized attention")
            try:
                # Install flash-attn with --no-build-isolation to avoid torch import issues
                result = subprocess.run(
                    [
                        sys.executable, "-m", "pip", "install",
                        "flash-attn==2.7.3", "--no-build-isolation",
                    ],
                    capture_output=True, text=True, timeout=300,
                )
                if result.returncode == 0:
                    LOGGER.info("Successfully installed flash-attn")
                else:
                    LOGGER.warning("Failed to install flash-attn: %s", result.stderr)
            except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e:
                LOGGER.warning("Failed to install flash-attn: %s", e)

    def _has_flash_attention(self) -> bool:
        """Check if flash attention is available."""
        try:
            import flash_attn  # noqa: F401
            return True
        except ImportError:
            return False

    def cleanup(self) -> None:
        if self._output_root.exists():
            shutil.rmtree(self._output_root, ignore_errors=True)

    def _infer_transformers(self, image_path: Path, prompt: str) -> str:
        result = self.model.infer(
            self.tokenizer,
            prompt=prompt,
            image_file=str(image_path),
            output_path=str(self._output_root),
            base_size=self.config["base_size"],
            image_size=self.config["image_size"],
            crop_mode=self.config["crop_mode"],
            save_results=False,
            test_compress=self.config.get("test_compress", True),
        )
        if isinstance(result, dict):
            for key in ("text", "markdown", "raw_text", "result"):
                if key in result and isinstance(result[key], str):
                    return result[key]
            return "\n".join(str(v) for v in result.values())
        if isinstance(result, (list, tuple)):
            return "\n".join(str(item) for item in result)
        return str(result)

    def _infer_vllm(self, image_path: Path, prompt: str) -> str:
        if not self.enable_vllm or self._vllm_engine is None:
            raise RuntimeError("vLLM backend is not initialized")
        formatted_prompt = f"<image>{prompt.replace('<image>', '').strip()}"
        outputs = self._vllm_engine.generate(
            prompts=[formatted_prompt],
            image_data=[[Image.open(image_path)]],
            sampling_params=self._vllm_sampling_params,
        )
        return outputs[0].outputs[0].text if outputs else ""

    def _infer(self, image_path: Path, prompt: str) -> str:
        if self.enable_vllm:
            try:
                return self._infer_vllm(image_path, prompt)
            except Exception as error:
                LOGGER.warning(
                    "Falling back to transformers backend after vLLM error: %s",
                    error,
                )
                self.enable_vllm = False
        return self._infer_transformers(image_path, prompt)

    def _convert_pdf_to_images(
        self, pdf_path: Path, output_dir: Path, dpi: int = 192
    ) -> List[Path]:
        document = fitz.open(pdf_path)
        image_paths: List[Path] = []
        zoom = dpi / 72  # Default PDF DPI is 72
        matrix = fitz.Matrix(zoom, zoom)
        for page_index in range(len(document)):
            page = document.load_page(page_index)
            pixmap = page.get_pixmap(matrix=matrix, alpha=False)
            page_path = output_dir / f"page-{page_index + 1:04d}.png"
            pixmap.save(page_path)
            image_paths.append(page_path)
        document.close()
        return image_paths

    def process_document(
        self,
        file_path: Path,
        prompt: Optional[str] = None,
        progress: Optional[gr.Progress] = None,
    ) -> DocumentResult:
        prompt_to_use = prompt.strip() if prompt and prompt.strip() else self.prompt_template
        suffix = file_path.suffix.lower()
        with tempfile.TemporaryDirectory(prefix="deepseek_ocr_tmp_") as tmp_dir:
            tmp_dir_path = Path(tmp_dir)
            if suffix in {".png", ".jpg", ".jpeg", ".bmp", ".webp", ".tif", ".tiff"}:
                image_paths = [self._ensure_rgb_image(file_path, tmp_dir_path)]
            elif suffix == ".pdf":
                if progress:
                    progress(0.0, desc="Converting PDF pages")
                image_paths = self._convert_pdf_to_images(file_path, tmp_dir_path)
            else:
                raise ValueError("Unsupported file format. Please upload an image or PDF.")
            total_pages = len(image_paths)
            page_results: List[PageResult] = []
            for idx, image_path in enumerate(image_paths, start=1):
                if progress:
                    progress(
                        (idx - 1) / max(total_pages, 1),
                        desc=f"Processing page {idx}/{total_pages}",
                    )
                text = self._infer(image_path, prompt_to_use)
                page_results.append(PageResult(index=idx, text=text))
        if progress:
            progress(1.0, desc="Completed")
        return DocumentResult(filename=file_path.name, page_results=page_results)

    def _ensure_rgb_image(self, image_path: Path, output_dir: Path) -> Path:
        """Ensure the provided image is saved as RGB PNG for the model."""
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        output_path = output_dir / f"image-{image_path.stem}.png"
        image.save(output_path, format="PNG", optimize=True)
        return output_path


@contextlib.contextmanager
def progress_tracker(progress: Optional[gr.Progress]):
    yield progress if progress else None


ENGINE: Optional[DeepSeekOCREngine] = None


def get_engine() -> DeepSeekOCREngine:
    global ENGINE
    if ENGINE is None:
        use_vllm_env = os.getenv("USE_VLLM", "1").strip().lower()
        enable_vllm = use_vllm_env not in {"0", "false", "no"}
        LOGGER.info("Instantiating DeepSeek-OCR engine (vLLM=%s)", enable_vllm)
        ENGINE = DeepSeekOCREngine(enable_vllm=enable_vllm)
    return ENGINE
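

# Example: set USE_VLLM=0 in the Space's environment variables to skip the vLLM
# backend and load the model through the transformers path only.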


def handle_upload(
    file: Optional[str],
    prompt: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> str:
    # With gr.File(type="filepath"), Gradio passes the uploaded file's path as a string.
    if file is None:
        raise gr.Error("Please upload an image or PDF file to start OCR.")
    uploaded_path = Path(file)
    fd, tmp_path_str = tempfile.mkstemp(
        prefix="deepseek_upload_",
        suffix=uploaded_path.suffix,
    )
    os.close(fd)
    tmp_copy = Path(tmp_path_str)
    shutil.copy(uploaded_path, tmp_copy)
    engine = get_engine()
    try:
        with progress_tracker(progress) as tracker:
            result = engine.process_document(tmp_copy, prompt=prompt, progress=tracker)
    finally:
        tmp_copy.unlink(missing_ok=True)
    return result.to_markdown()


def build_interface() -> gr.Blocks:
    description = (
        "Upload an image or PDF and DeepSeek-OCR will transcribe it into Markdown. "
        "Optimized for Hugging Face free T4 GPU Spaces with flash-attention and "
        "optional vLLM acceleration."
    )
    with gr.Blocks(title="DeepSeek-OCR", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# DeepSeek-OCR PDF & Image Reader")
        gr.Markdown(description)
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload document",
                    file_count="single",
                    type="filepath",  # "file" is not accepted by recent Gradio releases
                    file_types=[".png", ".jpg", ".jpeg", ".pdf", ".bmp", ".webp", ".tiff", ".tif"],
                )
                prompt_box = gr.Textbox(
                    label="Prompt",
                    value=DEFAULT_PROMPT,
                    lines=3,
                    show_label=True,
                    placeholder="Enter the grounding instruction for OCR",
                )
                submit_btn = gr.Button("Run OCR", variant="primary")
            with gr.Column(scale=1):
                result_output = gr.Markdown(label="OCR Markdown Output")
        submit_btn.click(
            fn=handle_upload,
            inputs=[file_input, prompt_box],
            outputs=[result_output],
        )
    return demo


demo = build_interface()

if __name__ == "__main__":
    # Gradio 4.x removed `concurrency_count` and `status_tracker` from queue();
    # `default_concurrency_limit` caps concurrent OCR jobs instead.
    demo.queue(default_concurrency_limit=2).launch()
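

# Minimal offline usage sketch (an assumption, not part of the Space's runtime):
# the engine can be driven without Gradio, e.g. from a notebook, where
# "sample.pdf" stands in for any local document.
#
#   engine = DeepSeekOCREngine(enable_vllm=False)
#   result = engine.process_document(Path("sample.pdf"))
#   print(result.to_markdown())
#   engine.cleanup()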