# sight_chat/preprocess.py — version 2.0.0 (commit ef821d9, verified; 6.93 kB)
"""
Refactored preprocessing pipeline for all RAG methods.
Uses utils.py functions and supports multiple retrieval methods.
Directory Layout:
/data/ # Original PDFs, HTML
/embeddings/ # FAISS, Chroma, DPR vector stores
/graph/ # Graph database files
/metadata/ # Image metadata (SQLite or MongoDB)
"""
import logging
from pathlib import Path
from config import *
from utils import (
DocumentLoader, TextPreprocessor, VectorStoreManager,
ImageProcessor, ImageData
)
# Module-level logger; handlers and levels are configured by the host application.
logger = logging.getLogger(__name__)
# Ensure all directories exist
# (ensure_directories comes from the `config` star import above.)
ensure_directories()
def preprocess_for_method(method: str, documents: list):
    """Build and persist the retrieval index for one RAG method.

    Args:
        method: One of 'vanilla', 'dpr', 'bm25', 'graph', 'context_stuffing'.
        documents: Loaded documents to chunk and index.

    Raises:
        ValueError: If *method* is not a recognized retrieval method.
        Exception: Re-raised after logging if any indexing step fails.
    """
    banner = "=" * 50
    print(f"\n{banner}")
    print(f"Preprocessing for method: {method}")
    print(f"{banner}")
    try:
        preprocessor = TextPreprocessor()
        store = VectorStoreManager()
        # Chunk the documents the way this particular method expects.
        chunks = preprocessor.preprocess_for_method(documents, method)
        if method in ('vanilla', 'dpr'):
            # Dense FAISS index: OpenAI embeddings for 'vanilla',
            # sentence-transformer embeddings for 'dpr'.
            index, metadata = store.build_faiss_index(chunks, method=method)
            store.save_index(index, metadata, method)
        elif method == 'bm25':
            # Sparse lexical index; the chunks themselves are the payload.
            store.save_index(store.build_bm25_index(chunks), chunks, method)
        elif method == 'graph':
            # NetworkX graph over the chunks; no separate metadata payload.
            store.save_index(store.build_graph_index(chunks), None, method)
        elif method == 'context_stuffing':
            # No index at all — persist the full chunks for direct stuffing.
            store.save_index(None, chunks, method)
        else:
            raise ValueError(f"Unknown method: {method}")
        print(f"Successfully preprocessed for method '{method}'")
    except Exception as e:
        logger.error(f"Error preprocessing for {method}: {e}")
        raise
def extract_and_process_images(documents: list):
    """Extract images from loaded documents, filter, classify, and store metadata.

    Args:
        documents: Loaded documents; each may carry an ``'images'`` list of dicts
            with ``'image_path'``, ``'image_id'``, and optionally ``'page'``.
            Each doc is also expected to provide ``'source'`` and ``'path'``.
            (Schema inferred from usage here — confirm against DocumentLoader.)
    """
    print("\n" + "="*50)
    print("Extracting and processing images...")
    print("="*50)
    image_processor = ImageProcessor()
    processed_count = 0
    filtered_count = 0
    filter_reasons = {}  # reason -> count of images rejected for that reason
    for doc in documents:
        if 'images' in doc and doc['images']:
            for image_info in doc['images']:
                try:
                    # Check if image should be filtered out (e.g. decorative/junk).
                    should_filter, reason = image_processor.should_filter_image(image_info['image_path'])
                    if should_filter:
                        filtered_count += 1
                        filter_reasons[reason] = filter_reasons.get(reason, 0) + 1
                        print(f" Filtered: {image_info['image_id']} - {reason}")
                        # Best-effort deletion of the rejected file.
                        # Fix: use pathlib (already imported at module top) instead
                        # of re-running `import os` on every filtered image.
                        try:
                            Path(image_info['image_path']).unlink()
                            print(f" Deleted: {image_info['image_path']}")
                        except Exception as e:
                            logger.warning(f"Could not delete filtered image {image_info['image_path']}: {e}")
                        continue
                    # Classify image
                    classification = image_processor.classify_image(image_info['image_path'])
                    # Generate embedding (placeholder for now)
                    # embedding = embed_image_clip([image_info['image_path']])[0]
                    # Bundle everything the metadata store needs.
                    image_data = ImageData(
                        image_path=image_info['image_path'],
                        image_id=image_info['image_id'],
                        classification=classification,
                        metadata={
                            'source': doc['source'],
                            'page': image_info.get('page'),
                            'extracted_from': doc['path']
                        }
                    )
                    # Store in database
                    image_processor.store_image_metadata(image_data)
                    processed_count += 1
                except Exception as e:
                    # One bad image must not abort the whole extraction pass.
                    logger.error(f"Error processing image {image_info['image_id']}: {e}")
                    continue
    # Print filtering summary
    if filtered_count > 0:
        print(f"\nImage Filtering Summary:")
        print(f" Total filtered: {filtered_count}")
        for reason, count in filter_reasons.items():
            print(f" {reason}: {count}")
        print()
    if processed_count > 0:
        print(f"Processed and stored metadata for {processed_count} images")
    else:
        print("No images found in documents")
def main():
    """Main preprocessing pipeline.

    Validates configuration, loads documents, builds one index per retrieval
    method, then extracts and processes embedded images. A failure in any one
    stage is reported but does not abort the remaining stages.
    """
    # Fix: import traceback once here instead of re-importing it inside
    # each `except` block below.
    import traceback

    # Validate configuration
    try:
        validate_api_key()
    except ValueError as e:
        print(f"Error: {e}")
        return
    # Print configuration
    print_config()
    print("\nStarting preprocessing pipeline...")
    # Load documents using utils
    print("\nLoading documents...")
    loader = DocumentLoader()
    documents = loader.load_text_documents()
    print(f"Loaded {len(documents)} documents")
    # Define methods to preprocess
    methods = ['vanilla', 'dpr', 'bm25', 'graph', 'context_stuffing']
    # Preprocess for each method; one failing method must not block the others.
    for method in methods:
        try:
            preprocess_for_method(method, documents)
        except Exception as e:
            print(f"Error preprocessing for {method}: {e}")
            traceback.print_exc()
    # Extract and process images
    try:
        extract_and_process_images(documents)
    except Exception as e:
        print(f"Error processing images: {e}")
        traceback.print_exc()
    print("\n" + "="*50)
    print("Preprocessing complete!")
    print("="*50)
# Run the full pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()