import io, os, re, json
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
from PIL import Image, ImageOps, ImageFilter
import streamlit as st
import torch
import torchvision.transforms as T

# --- word detector (Tesseract) ---
import pytesseract
from pytesseract import Output

# --- PDF -> images ---
from pdf2image import convert_from_bytes

# ---- import the repo's models ----
# Install via requirements.txt (git+https URL) OR copy repo files into root.
# The repo defines model classes: Swin_CTC, VED
import models as pdrt_models  # from dparres/Pretrained-Document-Recognition-Transformers

st.set_page_config(page_title="Invoice OCR (ViT recognizer + Tesseract detector)", layout="wide")

# ========================= UI SIDEBAR =========================
st.sidebar.header("Model")
arch = st.sidebar.selectbox("Architecture", ["Swin_CTC", "VED"], index=0)
ckpt_path = st.sidebar.text_input("Checkpoint path (inside Space)", value="checkpoints/pdrt_weights.pth")
alphabet = st.sidebar.text_input(
    "Alphabet (ordered classes, exclude CTC blank)",
    value="0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_/.,:;()[]{}#+*&%$@!?\"' ",
)
img_h = st.sidebar.number_input("Recognizer input height", 64, 256, 128, 8)
img_w = st.sidebar.number_input("Recognizer input width", 128, 2048, 512, 16)
det_lang = st.sidebar.text_input("Tesseract lang(s) for detection only", value="eng")
show_boxes = st.sidebar.checkbox("Show word boxes", value=False)

device = "cuda" if torch.cuda.is_available() else "cpu"

# ========================= UTILITIES =========================
def load_pages(file_bytes: bytes, name: str) -> List[Image.Image]:
    name = (name or "").lower()
    if name.endswith(".pdf"):
        return convert_from_bytes(file_bytes, dpi=300)
    return [Image.open(io.BytesIO(file_bytes)).convert("RGB")]

def preprocess_for_detection(img: Image.Image) -> Image.Image:
    g = ImageOps.grayscale(img)
    g = ImageOps.autocontrast(g)
    g = g.filter(ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3))
    return g

@st.cache_resource
def load_pdrt(arch_name: str, ckpt: str, num_classes: int):
    if arch_name == "Swin_CTC":
        model = pdrt_models.Swin_CTC(num_classes=num_classes)
    elif arch_name == "VED":
        model = pdrt_models.VED(num_classes=num_classes)
    else:
        raise ValueError("Unknown model")
    state = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(state, strict=False)
    model.eval().to(device)
    return model

def build_transform(img_h: int, img_w: int):
    return T.Compose([
        T.Grayscale(num_output_channels=3),  # keep 3ch if encoder expects RGB
        T.Resize((img_h, img_w)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

def greedy_ctc_decode(logits: torch.Tensor, alphabet: str) -> str:
    """
    logits: (1, T, C), (T, 1, C) or (T, C). Map argmax to chars, collapse repeats, remove blank.
    We assume blank_id = len(alphabet).
    """
    # Reduce any batch-of-one shape to (T, C).
    if logits.dim() == 3:
        if logits.shape[0] == 1:      # (1, T, C)
            logits = logits[0]
        elif logits.shape[1] == 1:    # (T, 1, C)
            logits = logits[:, 0, :]
    probs = logits.softmax(-1)
    ids = probs.argmax(-1).tolist()
    blank_id = len(alphabet)
    out = []
    prev = None
    for i in ids:
        if i != prev and i != blank_id:
            out.append(alphabet[i] if i < len(alphabet) else "")
        prev = i
    return "".join(out)

def recognize_word_crops(model, crops: List[Image.Image], tfm, arch_name: str, alphabet: str) -> List[str]:
    texts = []
    with torch.no_grad():
        for im in crops:
            x = tfm(im).unsqueeze(0).to(device)
            y = model(x)
            if arch_name == "Swin_CTC":
                # expect CTC logits [B, T, C] or [T, B, C]
                if y.dim() == 3 and y.shape[0] == 1:     # [1, T, C]
                    logits = y[0]                        # [T, C]
                elif y.dim() == 3 and y.shape[1] == 1:   # [T, 1, C]
                    logits = y[:, 0, :]
                else:
                    logits = y
                txt = greedy_ctc_decode(logits, alphabet)
            else:
                # VED: if it returns token ids/logits, plug the repo's decoding in here.
                # Fallback: argmax over the last dim per step and map ids to the alphabet (no blank).
                if y.dim() == 3 and y.shape[0] == 1:
                    y = y[0]
                ids = y.argmax(-1).tolist()
                txt = "".join(alphabet[i] if i < len(alphabet) else "" for i in ids).strip()
            texts.append(txt)
    return texts

def detect_words(img: Image.Image, lang="eng") -> pd.DataFrame:
    df = pytesseract.image_to_data(img, lang=lang, output_type=Output.DATAFRAME)
    df = df.dropna(subset=["text"])
    # drop structural rows (conf == -1) and blank strings so every row yields exactly one crop
    df = df[(df["conf"] > -1) & (df["text"].astype(str).str.strip() != "")].reset_index(drop=True)
    df["x2"] = df["left"] + df["width"]
    df["y2"] = df["top"] + df["height"]
    return df

def crop_words(img: Image.Image, df: pd.DataFrame) -> Tuple[List[Image.Image], List[Dict]]:
    crops, metas = [], []
    for _, r in df.iterrows():
        box = (int(r["left"]), int(r["top"]), int(r["x2"]), int(r["y2"]))
        crops.append(img.crop(box))
        metas.append({"box": box})
    return crops, metas
# ---------------- key fields & table (same logic as earlier Tesseract app) ----------------
CURRENCY = r"(?P<curr>USD|CAD|EUR|GBP|\$|C\$|€|£)?"
MONEY = rf"{CURRENCY}\s?(?P<amt>\d{{1,3}}(?:[,]\d{{3}})*(?:[.]\d{{2}})?)"
DATE = r"(?P<date>(?:\d{4}[-/]\d{1,2}[-/]\d{1,2})|(?:\d{1,2}[-/]\d{1,2}[-/]\d{2,4})|(?:[A-Za-z]{3,9}\s+\d{1,2},\s*\d{2,4}))"
INV_PAT = r"(?:invoice\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<inv>[A-Z0-9\-_/]{4,}))"
PO_PAT = r"(?:po\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<po>[A-Z0-9\-_/]{3,}))"
TOTAL_PAT = rf"(?:\b(total(?:\s*amount)?|amount\s*due|grand\s*total)\b.*?{MONEY})"
SUBTOTAL_PAT = rf"(?:\bsub\s*total\b.*?{MONEY})"
TAX_PAT = rf"(?:\b(tax|gst|vat|hst)\b.*?{MONEY})"

def parse_fields(fulltext: str):
    t = re.sub(r"[ \t]+", " ", fulltext)
    t = re.sub(r"\n{2,}", "\n", t)
    out = {"invoice_number": None, "invoice_date": None, "po_number": None,
           "subtotal": None, "tax": None, "total": None, "currency": None}
    m = re.search(INV_PAT, t, re.I)
    out["invoice_number"] = m.group("inv") if m else None
    m = re.search(PO_PAT, t, re.I)
    out["po_number"] = m.group("po") if m else None
    m = re.search(rf"(invoice\s*date[:\-\s]*){DATE}", t, re.I)
    out["invoice_date"] = (m.group("date") if m
                           else (re.search(DATE, t, re.I).group("date") if re.search(DATE, t, re.I) else None))
    m = re.search(SUBTOTAL_PAT, t, re.I | re.S)
    if m:
        out["subtotal"], out["currency"] = m.group("amt").replace(",", ""), m.group("curr") or out["currency"]
    m = re.search(TAX_PAT, t, re.I | re.S)
    if m:
        out["tax"], out["currency"] = m.group("amt").replace(",", ""), m.group("curr") or out["currency"]
    m = re.search(TOTAL_PAT, t, re.I | re.S)
    if m:
        out["total"], out["currency"] = m.group("amt").replace(",", ""), m.group("curr") or out["currency"]
    if out["currency"] in ["$", "C$", "€", "£"]:
        out["currency"] = {"$": "USD", "C$": "CAD", "€": "EUR", "£": "GBP"}[out["currency"]]
    return out

HEAD_CANDIDATES = ["description", "item", "qty", "quantity", "price", "unit", "rate", "amount", "total"]

def items_from_wordgrid(df: pd.DataFrame) -> pd.DataFrame:
    # Group words into lines
    df = df.copy()
    df["cx"] = df["left"] + 0.5 * df["width"]
    df["cy"] = df["top"] + 0.5 * df["height"]
    lines = []
    for (b, p, l), g in df.groupby(["block_num", "par_num", "line_num"]):
        text = " ".join([t for t in g["text"].astype(str) if t.strip()])
        if text.strip():
            lines.append({
                "block_num": b, "par_num": p, "line_num": l,
                "text": text.lower(),
                "top": g["top"].min(),
                "bottom": (g["top"] + g["height"]).max(),
                "left": g["left"].min(),
                "right": (g["left"] + g["width"]).max(),
                "words": g.sort_values("cx")[["cx", "left", "top", "width", "height"]].values.tolist(),
            })
    L = pd.DataFrame(lines)
    if L.empty:
        return pd.DataFrame()
    # Header line = line containing at least two header keywords
    L["score"] = L["text"].apply(lambda s: sum(1 for h in HEAD_CANDIDATES if h in s))
    headers = L[L["score"] >= 2].sort_values(["score", "top"], ascending=[False, True])
    if headers.empty:
        return pd.DataFrame()
    H = headers.iloc[0]
    header_y = H["bottom"] + 4
    # choose column anchors from the header words' positions (reuse df within the header band)
    header_band = df[(df["top"] >= H["top"] - 5) & ((df["top"] + df["height"]) <= H["bottom"] + 5)]
    header_band = header_band.sort_values("left")
    col_x = header_band["left"].tolist()
    col_names = [str(t).lower() for t in header_band["text"].tolist()]
    if len(col_x) < 2:
        return pd.DataFrame()
    # region below header until a totals line
    below = df[df["top"] > header_y].copy()
    totals_mask = below["text"].str.lower().str.contains(
        r"(sub\s*total|amount\s*due|total|grand\s*total|balance)", regex=True, na=False)
    if totals_mask.any():
        stop_y = below.loc[totals_mask, "top"].min()
        below = below[below["top"] < stop_y]
    if below.empty:
        return pd.DataFrame()
    # rebuild rows line by line (simple heuristic: each word goes to the nearest header column by x)
    rows = []
    for _, g in below.groupby(["block_num", "par_num", "line_num"]):
        g = g.sort_values("left")
        cells = {name: [] for name in col_names}
        for _, w in g.iterrows():
            j = int(np.argmin([abs(w["left"] - x) for x in col_x]))
            cells[col_names[j]].append(str(w["text"]))
        row = {name: " ".join(v).strip() for name, v in cells.items()}
        if any(row.values()):
            rows.append(row)
    return pd.DataFrame(rows)
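# Illustrative expectation for parse_fields on a toy invoice string (not real data; the
# values follow from the regexes above, amounts are returned as strings with commas stripped):
#
#   parse_fields("Invoice #: INV-1029\nInvoice Date: 2024-03-01\n"
#                "Subtotal $100.00\nTax $13.00\nTotal $113.00")
#   -> {"invoice_number": "INV-1029", "invoice_date": "2024-03-01", "po_number": None,
#       "subtotal": "100.00", "tax": "13.00", "total": "113.00", "currency": "USD"}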
# ========================= APP =========================
st.title("Invoice OCR (ViT recognizer + Tesseract detector)")
up = st.file_uploader("Upload an invoice (PDF, PNG, JPG)", type=["pdf", "png", "jpg", "jpeg"])
if not up:
    st.info("Upload a document to start.")
    st.stop()

pages = load_pages(up.read(), up.name)
# +1 output class for the CTC blank symbol
model = load_pdrt(arch, ckpt_path, num_classes=len(alphabet) + 1)
tfm = build_transform(img_h, img_w)

page_idx = 0
if len(pages) > 1:
    page_idx = st.number_input("Page", 1, len(pages), 1) - 1
img = pages[page_idx]

col1, col2 = st.columns([1.1, 1.3], gap="large")

with col1:
    st.subheader("Preview")
    st.image(img, use_column_width=True)
    det_img = preprocess_for_detection(img)
    with st.expander("Detection view"):
        st.image(det_img, use_column_width=True)

with col2:
    st.subheader("OCR & Extraction")
    # 1) detect words (boxes only)
    det_df = detect_words(det_img, lang=det_lang)
    # 2) crop & recognize each word via the ViT recognizer
    crops, metas = crop_words(det_img, det_df)
    texts = recognize_word_crops(model, crops, tfm, arch, alphabet)
    # 3) stitch line by line using Tesseract's line indices
    det_df = det_df.reset_index(drop=True)
    det_df["pred"] = texts
    lines = []
    for _, g in det_df.groupby(["block_num", "par_num", "line_num"]):
        g = g.sort_values("left")
        lines.append(" ".join([t for t in g["pred"].tolist() if t]))
    full_text = "\n".join([ln for ln in lines if ln.strip()])

    if show_boxes:
        st.caption("First 15 predicted words")
        st.write(det_df[["left", "top", "width", "height", "text", "pred"]].head(15))

    # 4) key fields
    key_fields = parse_fields(full_text)
    k1, k2, k3 = st.columns(3)
    with k1:
        st.write(f"**Invoice #:** {key_fields.get('invoice_number') or '—'}")
        st.write(f"**Invoice Date:** {key_fields.get('invoice_date') or '—'}")
    with k2:
        st.write(f"**PO #:** {key_fields.get('po_number') or '—'}")
        st.write(f"**Subtotal:** {key_fields.get('subtotal') or '—'}")
    with k3:
        st.write(f"**Tax:** {key_fields.get('tax') or '—'}")
        tot = key_fields.get('total') or '—'
        cur = key_fields.get('currency') or ''
        st.write(f"**Total:** {tot} {cur}".strip())

    # 5) line items (geometry heuristic)
    items = items_from_wordgrid(det_df.assign(text=det_df["pred"]))
    st.markdown("**Line Items**")
    if items.empty:
        st.caption("No line items confidently detected.")
    else:
        st.dataframe(items, use_container_width=True)

    # 6) downloads
    result = {
        "file": up.name,
        "page": page_idx + 1,
        "key_fields": key_fields,
        "items": items.to_dict(orient="records") if not items.empty else [],
        "full_text": full_text,
    }
    st.download_button("Download JSON", data=json.dumps(result, indent=2),
                       file_name="invoice_extraction.json", mime="application/json")
    if not items.empty:
        st.download_button("Download Items CSV", data=items.to_csv(index=False),
                           file_name="invoice_items.csv", mime="text/csv")
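# To run locally (a minimal sketch, assuming this file is saved as app.py, a compatible
# checkpoint exists at the path given in the sidebar, and the system packages
# `tesseract-ocr` and `poppler-utils` are installed for pytesseract / pdf2image):
#
#   pip install streamlit torch torchvision pytesseract pdf2image pillow pandas numpy
#   streamlit run app.py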