"""Gradio interface for DeepSeek-OCR on Hugging Face Spaces.
This application loads the `deepseek-ai/DeepSeek-OCR` vision-language model
and exposes a simple interface capable of processing both image and PDF
documents. The implementation targets the Hugging Face free T4 GPU runtime and
optimizes throughput with bfloat16 precision, flash-attention, and optional
vLLM acceleration when available.
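
A minimal sketch of programmatic use (the sample file path is illustrative):

    engine = DeepSeekOCREngine(enable_vllm=False)
    result = engine.process_document(Path("sample.pdf"))
    print(result.to_markdown())

On Hugging Face Spaces the module-level ``demo`` object is served directly;
running this file as a script launches the same interface (see the
``__main__`` block).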
"""
from __future__ import annotations
import contextlib
import dataclasses
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import List, Optional
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
try: # Optional dependency for faster batching
from vllm import LLM, SamplingParams # type: ignore
_HAS_VLLM = True
except Exception: # pragma: no cover - optional path
LLM = None # type: ignore
SamplingParams = None # type: ignore
_HAS_VLLM = False
try:
import fitz # type: ignore[attr-defined]
except Exception as exc: # pragma: no cover - ensures import error is visible
raise RuntimeError(
"PyMuPDF (fitz) is required for PDF processing. Install pymupdf."
) from exc
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("deepseek_ocr_app")
MODEL_NAME = "deepseek-ai/DeepSeek-OCR"
DEFAULT_PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
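# Inference preset believed to mirror the upstream "Gundam" dynamic-resolution
# mode (1024-px base canvas, 640-px crops, cropping enabled); the keys map
# directly onto the model's custom `infer` arguments used further below.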
GUNDAM_CONFIG = {
"base_size": 1024,
"image_size": 640,
"crop_mode": True,
"test_compress": True,
}
@dataclasses.dataclass
class PageResult:
"""Result for a single page processed by DeepSeek-OCR."""
index: int
text: str
@dataclasses.dataclass
class DocumentResult:
"""Aggregate OCR result for an input document."""
filename: str
page_results: List[PageResult]
def to_markdown(self) -> str:
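        """Join the per-page text into one Markdown string with page headings."""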
sections = []
for page in self.page_results:
heading = f"### Page {page.index}"
sections.append(f"{heading}\n\n{page.text.strip()}".strip())
return "\n\n".join(sections).strip()
def has_cuda() -> bool:
return torch.cuda.is_available()
class DeepSeekOCREngine:
"""Wrapper around the DeepSeek-OCR model for document processing."""
def __init__(
self,
model_name: str = MODEL_NAME,
prompt: str = DEFAULT_PROMPT,
config: Optional[dict] = None,
enable_vllm: bool = True,
) -> None:
self.model_name = model_name
self.prompt_template = prompt
self.config = {**GUNDAM_CONFIG, **(config or {})}
self.enable_vllm = enable_vllm and _HAS_VLLM
self.device = torch.device("cuda" if has_cuda() else "cpu")
self._model = None
self._tokenizer = None
self._vllm_engine = None
self._vllm_sampling_params = None
self._output_root = Path(tempfile.mkdtemp(prefix="deepseek_ocr_out_"))
self._load_model()
@property
def tokenizer(self):
if self._tokenizer is None:
raise RuntimeError("Tokenizer not initialized")
return self._tokenizer
@property
def model(self):
if self._model is None:
raise RuntimeError("Model not initialized")
return self._model
def _load_model(self) -> None:
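        """Initialize the vLLM engine when requested, otherwise the transformers model."""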
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# Try to install flash-attn if not available and we're on CUDA
if self.device.type == "cuda":
self._ensure_flash_attention()
if self.enable_vllm:
try:
LOGGER.info("Initializing DeepSeek-OCR with vLLM backend")
self._vllm_engine = LLM(
model=self.model_name,
dtype="bfloat16" if has_cuda() else "float32",
tokenizer=self.model_name,
trust_remote_code=True,
)
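                # Temperature 0 gives deterministic (greedy) decoding, which is
                # what we want for OCR; top_p has no effect at temperature 0.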
self._vllm_sampling_params = SamplingParams(
temperature=0.0,
top_p=0.9,
max_tokens=4096,
)
except Exception as vllm_error:
LOGGER.warning(
"vLLM initialization failed (%s). Falling back to HF AutoModel.",
vllm_error,
)
self.enable_vllm = False
if not self.enable_vllm:
LOGGER.info("Loading DeepSeek-OCR with transformers backend")
self._tokenizer = AutoTokenizer.from_pretrained(
self.model_name, trust_remote_code=True
)
torch_dtype = torch.bfloat16 if self.device.type == "cuda" else torch.float32
        # Use flash attention when it is available, otherwise standard (eager) attention
attn_implementation = "flash_attention_2" if self._has_flash_attention() else "eager"
if attn_implementation == "flash_attention_2":
LOGGER.info("Using flash attention for faster inference")
else:
LOGGER.info("Using standard attention (flash attention not available)")
self._model = AutoModel.from_pretrained(
self.model_name,
trust_remote_code=True,
use_safetensors=True,
_attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
self._model = self._model.eval().to(self.device)
def _ensure_flash_attention(self) -> None:
"""Ensure flash-attn is installed for CUDA devices."""
if not self._has_flash_attention():
LOGGER.info("Installing flash-attn for optimized attention")
try:
# Try installing flash-attn with no-build-isolation to avoid torch import issues
result = subprocess.run([
sys.executable, "-m", "pip", "install",
"flash-attn==2.7.3", "--no-build-isolation"
], capture_output=True, text=True, timeout=300)
if result.returncode == 0:
LOGGER.info("Successfully installed flash-attn")
else:
LOGGER.warning("Failed to install flash-attn: %s", result.stderr)
except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e:
LOGGER.warning("Failed to install flash-attn: %s", e)
def _has_flash_attention(self) -> bool:
"""Check if flash attention is available."""
try:
import flash_attn # noqa: F401
return True
except ImportError:
return False
def cleanup(self) -> None:
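        """Delete the scratch output directory created at engine start-up."""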
if self._output_root.exists():
shutil.rmtree(self._output_root, ignore_errors=True)
def _infer_transformers(self, image_path: Path, prompt: str) -> str:
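        # Delegate to the model's custom `infer` method shipped with the
        # checkpoint (loaded via trust_remote_code); the keyword names are
        # intended to mirror the upstream usage example.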
result = self.model.infer(
self.tokenizer,
prompt=prompt,
image_file=str(image_path),
output_path=str(self._output_root),
base_size=self.config["base_size"],
image_size=self.config["image_size"],
crop_mode=self.config["crop_mode"],
save_results=False,
test_compress=self.config.get("test_compress", True),
)
if isinstance(result, dict):
for key in ("text", "markdown", "raw_text", "result"):
if key in result and isinstance(result[key], str):
return result[key]
return "\n".join(str(v) for v in result.values())
if isinstance(result, (list, tuple)):
return "\n".join(str(item) for item in result)
return str(result)
    def _infer_vllm(self, image_path: Path, prompt: str) -> str:
        if not self.enable_vllm or self._vllm_engine is None:
            raise RuntimeError("vLLM backend is not initialized")
        formatted_prompt = f"<image>{prompt.replace('<image>', '').strip()}"
        # Recent vLLM releases take multimodal inputs as request dicts that
        # pair each prompt with its image, rather than a separate keyword.
        outputs = self._vllm_engine.generate(
            [
                {
                    "prompt": formatted_prompt,
                    "multi_modal_data": {"image": Image.open(image_path)},
                }
            ],
            sampling_params=self._vllm_sampling_params,
        )
        return outputs[0].outputs[0].text if outputs else ""
    def _infer(self, image_path: Path, prompt: str) -> str:
        if self.enable_vllm:
            try:
                return self._infer_vllm(image_path, prompt)
            except Exception as error:
                LOGGER.warning(
                    "Falling back to transformers backend after vLLM error: %s",
                    error,
                )
                self.enable_vllm = False
                # The transformers model is only loaded when vLLM is disabled,
                # so load it now before retrying this page.
                if self._model is None:
                    self._load_model()
        return self._infer_transformers(image_path, prompt)
def _convert_pdf_to_images(
self, pdf_path: Path, output_dir: Path, dpi: int = 192
) -> List[Path]:
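        # Rasterize each page; 192 dpi (roughly a 2.7x zoom over PDF's 72-point
        # inch) is meant as a balance between OCR legibility and memory use.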
document = fitz.open(pdf_path)
image_paths: List[Path] = []
zoom = dpi / 72 # Default PDF DPI is 72
matrix = fitz.Matrix(zoom, zoom)
for page_index in range(len(document)):
page = document.load_page(page_index)
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
page_path = output_dir / f"page-{page_index + 1:04d}.png"
pixmap.save(page_path)
image_paths.append(page_path)
document.close()
return image_paths
def process_document(
self,
file_path: Path,
prompt: Optional[str] = None,
progress: Optional[gr.Progress] = None,
) -> DocumentResult:
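        """Run OCR over a single image or a multi-page PDF and collect per-page text."""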
prompt_to_use = prompt.strip() if prompt and prompt.strip() else self.prompt_template
suffix = file_path.suffix.lower()
with tempfile.TemporaryDirectory(prefix="deepseek_ocr_tmp_") as tmp_dir:
tmp_dir_path = Path(tmp_dir)
if suffix in {".png", ".jpg", ".jpeg", ".bmp", ".webp", ".tif", ".tiff"}:
image_paths = [self._ensure_rgb_image(file_path, tmp_dir_path)]
elif suffix == ".pdf":
if progress:
progress(0.0, desc="Converting PDF pages")
image_paths = self._convert_pdf_to_images(file_path, tmp_dir_path)
else:
raise ValueError("Unsupported file format. Please upload an image or PDF.")
total_pages = len(image_paths)
page_results: List[PageResult] = []
for idx, image_path in enumerate(image_paths, start=1):
if progress:
progress(
(idx - 1) / max(total_pages, 1),
desc=f"Processing page {idx}/{total_pages}"
)
text = self._infer(image_path, prompt_to_use)
page_results.append(PageResult(index=idx, text=text))
if progress:
progress(1.0, desc="Completed")
return DocumentResult(filename=file_path.name, page_results=page_results)
def _ensure_rgb_image(self, image_path: Path, output_dir: Path) -> Path:
"""Ensure the provided image is saved as RGB PNG for the model."""
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
output_path = output_dir / f"image-{image_path.stem}.png"
image.save(output_path, format="PNG", optimize=True)
return output_path
@contextlib.contextmanager
def progress_tracker(progress: Optional[gr.Progress]):
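    """Yield the provided Gradio progress object (or None) unchanged."""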
yield progress if progress else None
ENGINE: Optional[DeepSeekOCREngine] = None
def get_engine() -> DeepSeekOCREngine:
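    """Lazily create a process-wide singleton engine so the model loads only once."""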
global ENGINE
if ENGINE is None:
use_vllm_env = os.getenv("USE_VLLM", "1").strip().lower()
enable_vllm = use_vllm_env not in {"0", "false", "no"}
LOGGER.info("Instantiating DeepSeek-OCR engine (vLLM=%s)", enable_vllm)
ENGINE = DeepSeekOCREngine(enable_vllm=enable_vllm)
return ENGINE
def handle_upload(
file: gr.File | None,
prompt: str,
progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> str:
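    """Gradio callback: copy the upload to a temp file, run OCR, return Markdown."""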
if file is None:
raise gr.Error("Please upload an image or PDF file to start OCR.")
uploaded_path = Path(file.name)
fd, tmp_path_str = tempfile.mkstemp(
prefix="deepseek_upload_",
suffix=uploaded_path.suffix,
)
os.close(fd)
tmp_copy = Path(tmp_path_str)
shutil.copy(uploaded_path, tmp_copy)
engine = get_engine()
try:
with progress_tracker(progress) as tracker:
result = engine.process_document(tmp_copy, prompt=prompt, progress=tracker)
finally:
tmp_copy.unlink(missing_ok=True)
return result.to_markdown()
def build_interface() -> gr.Blocks:
description = (
"Upload an image or PDF and DeepSeek-OCR will transcribe it into Markdown. "
"Optimized for Hugging Face free T4 GPU Spaces with flash-attention and "
"optional vLLM acceleration."
)
with gr.Blocks(title="DeepSeek-OCR", theme=gr.themes.Soft()) as demo:
gr.Markdown("# DeepSeek-OCR PDF & Image Reader")
gr.Markdown(description)
with gr.Row(equal_height=False):
with gr.Column(scale=1):
file_input = gr.File(
label="Upload document",
file_count="single",
type="file",
file_types=[".png", ".jpg", ".jpeg", ".pdf", ".bmp", ".webp", ".tiff", ".tif"],
)
prompt_box = gr.Textbox(
label="Prompt",
value=DEFAULT_PROMPT,
lines=3,
show_label=True,
placeholder="Enter the grounding instruction for OCR",
)
submit_btn = gr.Button("Run OCR", variant="primary")
with gr.Column(scale=1):
result_output = gr.Markdown(label="OCR Markdown Output")
submit_btn.click(
fn=handle_upload,
inputs=[file_input, prompt_box],
outputs=[result_output],
)
return demo
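# Built at import time so the Hugging Face Spaces runtime, which imports this
# module, can serve `demo` directly.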
demo = build_interface()
if __name__ == "__main__":
    demo.queue(concurrency_count=2).launch()