| """ | |
| Utility functions for the Multi-Method RAG System. | |
| Directory Layout: | |
| /data/ # Original PDFs, HTML | |
| /embeddings/ # FAISS, Chroma, DPR vector stores | |
| /graph/ # Graph database files | |
| /metadata/ # Image metadata (SQLite or MongoDB) | |
| """ | |
| import os | |
| import json | |
| import pickle | |
| import sqlite3 | |
| import base64 | |
| from pathlib import Path | |
| from typing import List, Dict, Tuple, Optional, Any, Union | |
| from dataclasses import dataclass | |
| import logging | |
| import pymupdf4llm | |
| import pymupdf | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| import requests | |
| from bs4 import BeautifulSoup | |
| # Vector stores and search | |
| import faiss | |
| import chromadb | |
| from rank_bm25 import BM25Okapi | |
| import networkx as nx | |
| # ML models | |
| from openai import OpenAI | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| import torch | |
| # import clip | |
| # Text processing | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| import tiktoken | |
| from config import * | |
| logger = logging.getLogger(__name__) | |
@dataclass
class DocumentChunk:
    """Data structure for document chunks."""
    text: str
    metadata: Dict[str, Any]
    chunk_id: str
    embedding: Optional[np.ndarray] = None


@dataclass
class ImageData:
    """Data structure for image metadata."""
    image_path: str
    image_id: str
    classification: Optional[str] = None
    embedding: Optional[np.ndarray] = None
    metadata: Optional[Dict[str, Any]] = None

class DocumentLoader:
    """Load and extract text from various document formats."""

    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        validate_api_key()

    def load_pdf_documents(self, pdf_paths: List[Union[str, Path]]) -> List[Dict[str, Any]]:
        """Load text from PDF files using pymupdf4llm."""
        documents = []
        for pdf_path in pdf_paths:
            try:
                pdf_path = Path(pdf_path)
                logger.info(f"Loading PDF: {pdf_path}")

                # Extract text using pymupdf4llm
                text = pymupdf4llm.to_markdown(str(pdf_path))

                # Extract images if present
                images = self._extract_pdf_images(pdf_path)

                doc = {
                    'text': text,
                    'source': str(pdf_path.name),
                    'path': str(pdf_path),
                    'type': 'pdf',
                    'images': images,
                    'metadata': {
                        'file_size': pdf_path.stat().st_size,
                        'modified': pdf_path.stat().st_mtime
                    }
                }
                documents.append(doc)
            except Exception as e:
                logger.error(f"Error loading PDF {pdf_path}: {e}")
                continue
        return documents
    def _extract_pdf_images(self, pdf_path: Path) -> List[Dict[str, Any]]:
        """Extract images from a PDF using pymupdf."""
        images = []
        try:
            doc = pymupdf.open(str(pdf_path))
            for page_num in range(len(doc)):
                page = doc[page_num]
                image_list = page.get_images(full=True)
                for img_index, img in enumerate(image_list):
                    try:
                        # Extract image
                        xref = img[0]
                        pix = pymupdf.Pixmap(doc, xref)

                        # Skip if pixmap is invalid or has no colorspace
                        if not pix or pix.colorspace is None:
                            if pix:
                                pix = None
                            continue

                        image_id = f"{pdf_path.stem}_p{page_num}_img{img_index}"
                        image_path = IMAGES_DIR / f"{image_id}.png"

                        # GRAY/RGB pixmaps (pix.n - pix.alpha < 4) can be saved directly;
                        # convert grayscale and CMYK/wide-colorspace pixmaps to RGB first
                        if pix.n == 1:  # Grayscale
                            rgb_pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
                            pix = None  # Clean up original
                            pix = rgb_pix
                        elif pix.n - pix.alpha >= 4:  # CMYK or other wide colorspace
                            rgb_pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
                            pix = None  # Clean up original
                            pix = rgb_pix

                        # Save image
                        pix.save(str(image_path))
                        images.append({
                            'image_id': image_id,
                            'image_path': str(image_path),
                            'page': page_num,
                            'source': str(pdf_path.name)
                        })
                        pix = None
                    except Exception as e:
                        logger.warning(f"Error extracting image {img_index} from page {page_num}: {e}")
                        if 'pix' in locals() and pix:
                            pix = None
                        continue
            doc.close()
        except Exception as e:
            logger.error(f"Error extracting images from {pdf_path}: {e}")
        return images
    def load_html_documents(self, html_sources: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """Load text from HTML sources."""
        documents = []
        for source in html_sources:
            try:
                logger.info(f"Loading HTML: {source.get('title', source['url'])}")

                # Fetch HTML content
                response = requests.get(source['url'], timeout=30)
                response.raise_for_status()

                # Parse with BeautifulSoup
                soup = BeautifulSoup(response.text, 'html.parser')

                # Extract text
                text = soup.get_text(separator=' ', strip=True)

                doc = {
                    'text': text,
                    'source': source.get('title', source['url']),
                    'path': source['url'],
                    'type': 'html',
                    'images': [],
                    'metadata': {
                        'url': source['url'],
                        'title': source.get('title', ''),
                        'year': source.get('year', ''),
                        'category': source.get('category', ''),
                        'format': source.get('format', 'HTML')
                    }
                }
                documents.append(doc)
            except Exception as e:
                logger.error(f"Error loading HTML {source['url']}: {e}")
                continue
        return documents
    def load_text_documents(self, data_dir: Path = DATA_DIR) -> List[Dict[str, Any]]:
        """Load all supported document types from the data directory."""
        documents = []

        # Load PDFs
        pdf_files = list(data_dir.glob("*.pdf"))
        if pdf_files:
            documents.extend(self.load_pdf_documents(pdf_files))

        # Load HTML sources (from config)
        if DEFAULT_HTML_SOURCES:
            documents.extend(self.load_html_documents(DEFAULT_HTML_SOURCES))

        logger.info(f"Loaded {len(documents)} documents total")
        return documents

class TextPreprocessor:
    """Preprocess text for different retrieval methods."""

    def __init__(self):
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def chunk_text_by_tokens(self, text: str, chunk_size: int = CHUNK_SIZE,
                             overlap: int = CHUNK_OVERLAP) -> List[str]:
        """Split text into chunks by token count."""
        tokens = self.encoding.encode(text)
        chunks = []
        start = 0
        while start < len(tokens):
            end = start + chunk_size
            chunk_tokens = tokens[start:end]
            chunk_text = self.encoding.decode(chunk_tokens)
            chunks.append(chunk_text)
            start = end - overlap
        return chunks
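
    # Illustration of the sliding token window above (a sketch; the real
    # CHUNK_SIZE / CHUNK_OVERLAP values come from config.py). Assuming
    # chunk_size=500 and overlap=50, a 1,200-token document yields windows
    # over token indices [0:500], [450:950], and [900:1200] -- each window
    # starts `chunk_size - overlap` tokens after the previous one.
    #
    #     pre = TextPreprocessor()
    #     chunks = pre.chunk_text_by_tokens(long_text, chunk_size=500, overlap=50)
    #     # each chunk decodes to roughly chunk_size tokens (re-encoding a
    #     # decoded span can shift boundaries by a token or two)
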
    def chunk_text_by_sections(self, text: str, method: str = "vanilla") -> List[str]:
        """Split text by sections based on method requirements."""
        if method in ["vanilla", "dpr"]:
            return self.chunk_text_by_tokens(text)
        elif method == "bm25":
            # BM25 works better with paragraph-level chunks
            paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
            return paragraphs
        elif method == "graph":
            # Graph method uses larger sections
            return self.chunk_text_by_tokens(text, chunk_size=CHUNK_SIZE * 2)
        elif method == "context_stuffing":
            # Context stuffing uses full documents
            return [text]
        else:
            return self.chunk_text_by_tokens(text)

    def preprocess_for_method(self, documents: List[Dict[str, Any]],
                              method: str) -> List[DocumentChunk]:
        """Preprocess documents for a specific retrieval method."""
        chunks = []
        for doc in documents:
            text_chunks = self.chunk_text_by_sections(doc['text'], method)
            for i, chunk_text in enumerate(text_chunks):
                chunk_id = f"{doc['source']}_{method}_chunk_{i}"
                chunk = DocumentChunk(
                    text=chunk_text,
                    metadata={
                        'source': doc['source'],
                        'path': doc['path'],
                        'type': doc['type'],
                        'chunk_index': i,
                        'method': method,
                        **doc.get('metadata', {})
                    },
                    chunk_id=chunk_id
                )
                chunks.append(chunk)
        logger.info(f"Created {len(chunks)} chunks for method '{method}'")
        return chunks

class EmbeddingGenerator:
    """Generate embeddings using various models."""

    def __init__(self):
        self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
        self.sentence_transformer = None
        # self.clip_model = None
        # self.clip_preprocess = None

    def _get_sentence_transformer(self):
        """Lazy loading of the sentence transformer."""
        if self.sentence_transformer is None:
            self.sentence_transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
            if DEVICE == "cuda":
                self.sentence_transformer = self.sentence_transformer.to(DEVICE)
        return self.sentence_transformer

    # def _get_clip_model(self):
    #     """Lazy loading of CLIP model."""
    #     if self.clip_model is None:
    #         self.clip_model, self.clip_preprocess = clip.load(CLIP_MODEL, device=DEVICE)
    #     return self.clip_model, self.clip_preprocess

    def embed_text_openai(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings using the OpenAI API."""
        embeddings = []
        # Process in batches
        for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            batch = texts[i:i + EMBEDDING_BATCH_SIZE]
            try:
                response = self.openai_client.embeddings.create(
                    model=OPENAI_EMBEDDING_MODEL,
                    input=batch
                )
                batch_embeddings = [data.embedding for data in response.data]
                embeddings.extend(batch_embeddings)
            except Exception as e:
                logger.error(f"Error generating OpenAI embeddings: {e}")
                raise
        return np.array(embeddings)
    def embed_text_sentence_transformer(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings using sentence transformers."""
        model = self._get_sentence_transformer()
        try:
            embeddings = model.encode(texts, convert_to_numpy=True,
                                      show_progress_bar=True, batch_size=32)
            return embeddings
        except Exception as e:
            logger.error(f"Error generating sentence transformer embeddings: {e}")
            raise

    def embed_image_clip(self, image_paths: List[str]) -> np.ndarray:
        """Generate image embeddings using CLIP."""
        # model, preprocess = self._get_clip_model()
        # embeddings = []
        # for image_path in image_paths:
        #     try:
        #         image = preprocess(Image.open(image_path)).unsqueeze(0).to(DEVICE)
        #
        #         with torch.no_grad():
        #             image_features = model.encode_image(image)
        #             image_features /= image_features.norm(dim=-1, keepdim=True)
        #
        #         embeddings.append(image_features.cpu().numpy().flatten())
        #
        #     except Exception as e:
        #         logger.error(f"Error embedding image {image_path}: {e}")
        #         continue
        # return np.array(embeddings) if embeddings else np.array([])

        # Placeholder for CLIP embeddings
        logger.warning("CLIP embeddings not implemented - returning dummy embeddings")
        return np.random.rand(len(image_paths), 512)

class VectorStoreManager:
    """Manage vector stores for different methods."""

    def __init__(self):
        self.embedding_generator = EmbeddingGenerator()

    def build_faiss_index(self, chunks: List[DocumentChunk], method: str = "vanilla") -> Tuple[Any, List[Dict]]:
        """Build a FAISS index for the vanilla or DPR method."""
        # Generate embeddings
        texts = [chunk.text for chunk in chunks]
        if method == "vanilla":
            embeddings = self.embedding_generator.embed_text_openai(texts)
        elif method == "dpr":
            embeddings = self.embedding_generator.embed_text_sentence_transformer(texts)
        else:
            raise ValueError(f"Unsupported method for FAISS: {method}")

        # Build FAISS index
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity

        # Ensure embeddings are float32 and normalized for cosine similarity
        embeddings = embeddings.astype(np.float32)
        faiss.normalize_L2(embeddings)
        index.add(embeddings)

        # Store chunk metadata
        metadata = []
        for i, chunk in enumerate(chunks):
            metadata.append({
                'chunk_id': chunk.chunk_id,
                'text': chunk.text,
                'metadata': chunk.metadata,
                'embedding': embeddings[i].tolist()
            })
        logger.info(f"Built FAISS index with {index.ntotal} vectors for method '{method}'")
        return index, metadata
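
    # Query-time sketch (not part of this module's build step): because the
    # index uses inner product over L2-normalized vectors, the query embedding
    # must be normalized the same way before searching. `embedding_generator`,
    # `query`, and `k` below are hypothetical names.
    #
    #     query_embedding = embedding_generator.embed_text_openai([query])[0]
    #     q = query_embedding.astype(np.float32).reshape(1, -1)
    #     faiss.normalize_L2(q)
    #     scores, indices = index.search(q, k)  # metadata[indices[0][j]] is the j-th hit
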
    def build_chroma_index(self, chunks: List[DocumentChunk], method: str = "vanilla") -> Any:
        """Build a Chroma vector database."""
        # Initialize Chroma client
        chroma_client = chromadb.PersistentClient(path=str(CHROMA_PATH / method))
        collection = chroma_client.get_or_create_collection(
            name=f"{method}_collection",
            metadata={"method": method}
        )

        # Prepare data for Chroma
        texts = [chunk.text for chunk in chunks]
        ids = [chunk.chunk_id for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]

        # Add to collection (Chroma handles embeddings internally)
        collection.add(
            documents=texts,
            ids=ids,
            metadatas=metadatas
        )
        logger.info(f"Built Chroma collection with {collection.count()} documents for method '{method}'")
        return collection

    def build_bm25_index(self, chunks: List[DocumentChunk]) -> BM25Okapi:
        """Build a BM25 index for keyword search."""
        # Tokenize texts
        tokenized_corpus = []
        for chunk in chunks:
            tokens = chunk.text.lower().split()
            tokenized_corpus.append(tokens)

        # Build BM25 index
        bm25 = BM25Okapi(tokenized_corpus, k1=BM25_K1, b=BM25_B)
        logger.info(f"Built BM25 index with {len(tokenized_corpus)} documents")
        return bm25
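
    # Query-time sketch: the query must be tokenized the same way as the
    # corpus (lowercased whitespace split) before scoring. `query` and `k`
    # are hypothetical names; get_scores() is part of the rank_bm25 API.
    #
    #     tokenized_query = query.lower().split()
    #     scores = bm25.get_scores(tokenized_query)   # one score per chunk
    #     top_k = np.argsort(scores)[::-1][:k]        # indices into `chunks`
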
    def build_graph_index(self, chunks: List[DocumentChunk]) -> nx.Graph:
        """Build a NetworkX graph for graph-based retrieval."""
        # Create graph
        G = nx.Graph()

        # Generate embeddings for similarity calculation
        texts = [chunk.text for chunk in chunks]
        embeddings = self.embedding_generator.embed_text_openai(texts)

        # Add nodes (convert embeddings to lists for GML serialization)
        for i, chunk in enumerate(chunks):
            G.add_node(chunk.chunk_id,
                       text=chunk.text,
                       metadata=chunk.metadata,
                       embedding=embeddings[i].tolist())

        # Add edges between chunks whose cosine similarity exceeds the threshold
        threshold = 0.7  # Similarity threshold
        for i in range(len(chunks)):
            for j in range(i + 1, len(chunks)):
                # Calculate cosine similarity
                sim = np.dot(embeddings[i], embeddings[j]) / (
                    np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
                )
                if sim > threshold:
                    G.add_edge(chunks[i].chunk_id, chunks[j].chunk_id,
                               weight=float(sim))

        logger.info(f"Built graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
        return G
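
    # Retrieval-time sketch: given a seed chunk (e.g., the best FAISS or BM25
    # hit), its graph neighborhood can be pulled in as extra context, ranked by
    # edge weight. `seed_id` and `k` are hypothetical names; neighbors() and
    # edge-attribute access are standard NetworkX.
    #
    #     neighbor_ids = sorted(G.neighbors(seed_id),
    #                           key=lambda n: G[seed_id][n]['weight'],
    #                           reverse=True)[:k]
    #     expanded_texts = [G.nodes[n]['text'] for n in neighbor_ids]
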
    def save_index(self, index: Any, metadata: Any, method: str):
        """Save index and metadata to disk."""
        if method == "vanilla":
            faiss.write_index(index, str(VANILLA_FAISS_INDEX))
            with open(VANILLA_METADATA, 'wb') as f:
                pickle.dump(metadata, f)
        elif method == "dpr":
            faiss.write_index(index, str(DPR_FAISS_INDEX))
            with open(DPR_METADATA, 'wb') as f:
                pickle.dump(metadata, f)
        elif method == "bm25":
            with open(BM25_INDEX, 'wb') as f:
                pickle.dump({'index': index, 'texts': metadata}, f)
        elif method == "context_stuffing":
            with open(CONTEXT_DOCS, 'wb') as f:
                pickle.dump(metadata, f)
        elif method == "graph":
            nx.write_gml(index, str(GRAPH_FILE))
        logger.info(f"Saved {method} index to disk")
class ImageProcessor:
    """Process and classify images."""

    def __init__(self):
        self.embedding_generator = EmbeddingGenerator()
        self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
        self._init_database()

    def _init_database(self):
        """Initialize SQLite database for image metadata."""
        conn = sqlite3.connect(IMAGES_DB)
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS images (
                image_id TEXT PRIMARY KEY,
                image_path TEXT NOT NULL,
                classification TEXT,
                metadata TEXT,
                embedding BLOB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.commit()
        conn.close()

    def classify_image(self, image_path: str) -> str:
        """Classify image using GPT-5 Vision."""
        try:
            # Encode image as base64 and pick the MIME type from the file
            # extension (images extracted from PDFs are saved as PNG)
            with open(image_path, "rb") as image_file:
                image_b64 = base64.b64encode(image_file.read()).decode()
            mime_type = "image/png" if image_path.lower().endswith(".png") else "image/jpeg"

            messages = [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "Classify this image in 1-2 words (e.g., 'machine guard', 'press brake', 'conveyor belt', 'safety sign')."},
                    {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}", "detail": "low"}}
                ]
            }]

            # For GPT-5 vision, temperature must be left at its default (1.0)
            response = self.openai_client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=messages,
                max_completion_tokens=50
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"Error classifying image {image_path}: {e}")
            return "unknown"
    def should_filter_image(self, image_path: str) -> tuple[bool, str]:
        """
        Check whether an image should be filtered out based on height and black-image criteria.

        Args:
            image_path: Path to the image file

        Returns:
            Tuple of (should_filter: bool, reason: str)
        """
        try:
            # Open and analyze the image
            with Image.open(image_path) as img:
                # Convert to RGB if needed
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                width, height = img.size

                # Filter 1: Height less than 40 pixels
                if height < 40:
                    return True, f"height too small ({height}px)"

                # Filter 2: Check if image is mostly black
                img_array = np.array(img)
                mean_brightness = np.mean(img_array)

                # If mean brightness is very low (mostly black)
                if mean_brightness < 10:  # Adjust threshold as needed
                    return True, "mostly black image"
        except Exception as e:
            logger.warning(f"Error analyzing image {image_path}: {e}")
            # If we can't analyze it, don't filter it out
            return False, "analysis failed"

        return False, "passed all filters"
    def store_image_metadata(self, image_data: ImageData):
        """Store image metadata in the database."""
        conn = sqlite3.connect(IMAGES_DB)
        cursor = conn.cursor()

        # Serialize metadata and embedding (embeddings are stored as float32
        # bytes so they can be read back with np.frombuffer(..., dtype=np.float32))
        metadata_json = json.dumps(image_data.metadata) if image_data.metadata else None
        embedding_blob = (image_data.embedding.astype(np.float32).tobytes()
                          if image_data.embedding is not None else None)

        cursor.execute('''
            INSERT OR REPLACE INTO images
            (image_id, image_path, classification, metadata, embedding)
            VALUES (?, ?, ?, ?, ?)
        ''', (image_data.image_id, image_data.image_path,
              image_data.classification, metadata_json, embedding_blob))
        conn.commit()
        conn.close()

    def get_image_metadata(self, image_id: str) -> Optional[ImageData]:
        """Retrieve image metadata from the database."""
        conn = sqlite3.connect(IMAGES_DB)
        cursor = conn.cursor()
        cursor.execute('''
            SELECT image_id, image_path, classification, metadata, embedding
            FROM images WHERE image_id = ?
        ''', (image_id,))
        row = cursor.fetchone()
        conn.close()

        if row:
            image_id, image_path, classification, metadata_json, embedding_blob = row
            metadata = json.loads(metadata_json) if metadata_json else None
            embedding = np.frombuffer(embedding_blob, dtype=np.float32) if embedding_blob else None
            return ImageData(
                image_path=image_path,
                image_id=image_id,
                classification=classification,
                embedding=embedding,
                metadata=metadata
            )
        return None

def load_text_documents() -> List[Dict[str, Any]]:
    """Convenience function to load all text documents."""
    loader = DocumentLoader()
    return loader.load_text_documents()


def embed_image_clip(image_paths: List[str]) -> np.ndarray:
    """Convenience function to embed images with CLIP."""
    generator = EmbeddingGenerator()
    return generator.embed_image_clip(image_paths)


def store_image_metadata(image_data: ImageData):
    """Convenience function to store image metadata."""
    processor = ImageProcessor()
    processor.store_image_metadata(image_data)


def classify_image(image_path: str) -> str:
    """Convenience function to classify an image."""
    processor = ImageProcessor()
    return processor.classify_image(image_path)
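

if __name__ == "__main__":
    # Minimal end-to-end sketch of how these utilities fit together, assuming
    # config.py provides OPENAI_API_KEY, DATA_DIR, and the index path constants
    # used above. It builds and persists indexes for two of the methods; it is
    # illustrative only and not required by the rest of the system.
    logging.basicConfig(level=logging.INFO)

    loader = DocumentLoader()
    preprocessor = TextPreprocessor()
    store_manager = VectorStoreManager()

    documents = loader.load_text_documents()

    # Dense (vanilla) method: token chunks -> OpenAI embeddings -> FAISS
    vanilla_chunks = preprocessor.preprocess_for_method(documents, "vanilla")
    faiss_index, faiss_metadata = store_manager.build_faiss_index(vanilla_chunks, "vanilla")
    store_manager.save_index(faiss_index, faiss_metadata, "vanilla")

    # Sparse method: paragraph chunks -> BM25
    bm25_chunks = preprocessor.preprocess_for_method(documents, "bm25")
    bm25_index = store_manager.build_bm25_index(bm25_chunks)
    store_manager.save_index(bm25_index, [c.text for c in bm25_chunks], "bm25")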