import base64
import io
import json
import logging
import os
import random
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

import requests
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from PIL import Image

from embeddings import EmbeddingModelFactory, get_default_model_configs
from patch_attention import PatchAttentionAnalyzer
from search_engines import SearchEngineManager
from utils import SearchCache, URLValidator

# Load environment variables from .env file
load_dotenv()

# Configuration
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Tattoo Search Engine", version="1.0.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class TattooSearchEngine:
    def __init__(self, embedding_model_type: str = "clip"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Initialize HuggingFace InferenceClient for VLM captioning
        logger.info("Initializing HuggingFace InferenceClient...")
        self.client = InferenceClient(
            provider="novita",
            api_key=HF_TOKEN,
        )
        self.vlm_model = "zai-org/GLM-4.5V"
        logger.info(f"Using VLM model: {self.vlm_model}")

        # Load embedding model; remember the factory key so get_search_engine
        # can tell whether a different model was requested
        logger.info(f"Loading embedding model: {embedding_model_type}")
        self.embedding_model_type = embedding_model_type
        self.embedding_model = EmbeddingModelFactory.create_model(
            embedding_model_type, self.device
        )
        logger.info(f"Using embedding model: {self.embedding_model.get_model_name()}")

        # Initialize search system
        logger.info("Initializing search system...")
        self.search_manager = SearchEngineManager(max_workers=5)
        self.url_validator = URLValidator(max_workers=10, timeout=10)
        self.search_cache = SearchCache(default_ttl=3600, max_size=1000)

        # Browser user agents rotated when scraping images
        self.user_agents = [
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15",
            "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
        ]

        logger.info("Search system initialized successfully!")
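    # The VLM is asked to reply in JSON; generate_caption pulls the first
    # {...} span out of the reply and reads its "search_query" field. A
    # typical (hypothetical) reply looks like:
    #
    #   {"search_query": "traditional japanese dragon tattoo sleeve"}
    #
    # Anything that fails to parse falls back to the generic query
    # "tattoo artwork".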
    def generate_caption(self, image: Image.Image) -> str:
        """Generate a tattoo search query using the HuggingFace InferenceClient."""
        try:
            # Encode the PIL image as a base64 data URL for the chat API
            img_buffer = io.BytesIO()
            image.save(img_buffer, format="JPEG", quality=95)
            img_buffer.seek(0)

            image_b64 = base64.b64encode(img_buffer.getvalue()).decode()
            image_url = f"data:image/jpeg;base64,{image_b64}"

            completion = self.client.chat.completions.create(
                model=self.vlm_model,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": (
                                    "Generate one search engine query to find tattoos "
                                    "most similar to this image. Respond in JSON format "
                                    'with a single "search_query" field.'
                                ),
                            },
                            {
                                "type": "image_url",
                                "image_url": {"url": image_url},
                            },
                        ],
                    }
                ],
            )

            caption = completion.choices[0].message.content
            if caption:
                # Extract the first JSON object from the reply (DOTALL so
                # multi-line JSON still matches)
                match = re.search(r"\{.*\}", caption, re.DOTALL)
                if match:
                    data = json.loads(match.group())
                    return data["search_query"]
                logger.warning("VLM reply contained no JSON object")
            else:
                logger.warning("No caption generated from VLM")
            return "tattoo artwork"
        except Exception as e:
            logger.error(f"Failed to generate caption: {e}")
            return "tattoo artwork"

    def search_images(self, query: str, max_results: int = 50) -> List[str]:
        """Search for tattoo images across multiple platforms with caching and validation."""
        # Check cache first
        cache_key = SearchCache.create_cache_key(query, max_results)
        cached_result = self.search_cache.get(cache_key)
        if cached_result:
            logger.info(f"Cache hit for query: {query}")
            return cached_result

        logger.info(f"Searching for images: {query}")

        # Use the search system with fallback across platforms
        search_result = self.search_manager.search_with_fallback(
            query=query, max_results=max_results, min_results_threshold=10
        )

        # Extract URLs from search results
        urls = [image.url for image in search_result.images]
        if not urls:
            logger.warning(f"No URLs found for query: {query}")
            return []

        # Validate URLs
        logger.info(f"Validating {len(urls)} URLs...")
        valid_urls = self.url_validator.validate_urls(urls)
        if not valid_urls:
            logger.warning(f"No valid URLs found for query: {query}")
            return []

        # Cache the result
        self.search_cache.set(cache_key, valid_urls, ttl=3600)

        logger.info(
            f"Search completed: {len(valid_urls)} valid URLs from "
            f"{len(search_result.platforms_used)} platforms in "
            f"{search_result.search_duration:.2f}s"
        )
        return valid_urls[:max_results]
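    # Downloads retry with exponential backoff plus jitter: after a failed
    # attempt n (0-indexed) the worker sleeps 2**n + uniform(0, 1) seconds,
    # i.e. roughly 1-2s and then 2-3s with the default max_retries=3, before
    # giving up on the URL.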
    def download_image(self, url: str, max_retries: int = 3) -> Optional[Image.Image]:
        """Download and validate an image URL, returning None on any failure."""
        for attempt in range(max_retries):
            try:
                # Browser-like headers to reduce the chance of being blocked
                headers = {
                    "User-Agent": random.choice(self.user_agents),
                    "Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
                    "Accept-Language": "en-US,en;q=0.9",
                    "Accept-Encoding": "gzip, deflate, br",
                    "DNT": "1",
                    "Connection": "keep-alive",
                    "Upgrade-Insecure-Requests": "1",
                    "Sec-Fetch-Dest": "image",
                    "Sec-Fetch-Mode": "no-cors",
                    "Sec-Fetch-Site": "cross-site",
                    "Cache-Control": "no-cache",
                    "Pragma": "no-cache",
                }

                # Pinterest-specific headers
                if "pinterest" in url.lower() or "pinimg" in url.lower():
                    headers.update(
                        {
                            "Referer": "https://www.pinterest.com/",
                            "Origin": "https://www.pinterest.com",
                            "X-Requested-With": "XMLHttpRequest",
                            "Sec-Fetch-User": "?1",
                            "X-Pinterest-Source": "web",
                            "X-APP-VERSION": "web",
                        }
                    )
                else:
                    headers["Referer"] = "https://www.google.com/"

                response = requests.get(
                    url, headers=headers, timeout=15, allow_redirects=True, stream=True
                )
                response.raise_for_status()

                # Validate content type
                content_type = response.headers.get("content-type", "").lower()
                if not content_type.startswith("image/"):
                    logger.warning(f"Invalid content type for {url}: {content_type}")
                    return None

                # Check file size (avoid downloading huge files)
                content_length = response.headers.get("content-length")
                if content_length and int(content_length) > 10 * 1024 * 1024:  # 10MB limit
                    logger.warning(f"Image too large: {url} ({content_length} bytes)")
                    return None

                # Download and process image
                image_data = response.content
                if len(image_data) < 1024:  # Skip very small images (likely broken)
                    logger.warning(f"Image too small: {url} ({len(image_data)} bytes)")
                    return None

                image = Image.open(io.BytesIO(image_data)).convert("RGB")

                # Validate image dimensions
                if image.size[0] < 50 or image.size[1] < 50:
                    logger.warning(f"Image dimensions too small: {url} {image.size}")
                    return None

                return image

            except requests.exceptions.RequestException as e:
                if attempt < max_retries - 1:
                    wait_time = (2**attempt) + random.uniform(0, 1)
                    logger.info(f"Retry {attempt + 1} for {url} in {wait_time:.1f}s")
                    time.sleep(wait_time)
                else:
                    logger.warning(
                        f"Failed to download image {url} after {max_retries} attempts: {e}"
                    )
            except Exception as e:
                logger.warning(f"Failed to process image {url}: {e}")
                break

        return None

    def download_and_process_image(
        self,
        url: str,
        query_features: torch.Tensor,
        query_image: Optional[Image.Image] = None,
        include_patch_attention: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """Download a candidate image and compute its similarity to the query."""
        candidate_image = self.download_image(url)
        if candidate_image is None:
            return None

        try:
            candidate_features = self.embedding_model.encode_image(candidate_image)
            similarity = self.embedding_model.compute_similarity(
                query_features, candidate_features
            )
            result = {"score": float(similarity), "url": url}

            # Add patch attention analysis if requested
            if include_patch_attention and query_image is not None:
                try:
                    analyzer = PatchAttentionAnalyzer(self.embedding_model)
                    patch_data = analyzer.compute_patch_similarities(
                        query_image, candidate_image
                    )
                    result["patch_attention"] = {
                        "overall_similarity": patch_data["overall_similarity"],
                        "query_grid_size": patch_data["query_grid_size"],
                        "candidate_grid_size": patch_data["candidate_grid_size"],
                        "attention_summary": analyzer.get_similarity_summary(
                            patch_data
                        ),
                    }
                except Exception as e:
                    logger.warning(f"Failed to compute patch attention for {url}: {e}")
                    result["patch_attention"] = None

            return result
        except Exception as e:
            logger.warning(f"Error processing candidate image {url}: {e}")
            return None

    def compute_similarity(
        self,
        query_image: Image.Image,
        candidate_urls: List[str],
        include_patch_attention: bool = False,
    ) -> List[Dict[str, Any]]:
        """Rank candidate URLs by embedding similarity to the query image."""
        # Encode query image using the selected embedding model
        query_features = self.embedding_model.encode_image(query_image)

        results = []

        # Use ThreadPoolExecutor for concurrent downloading and processing
        max_workers = min(10, len(candidate_urls))  # Limit concurrent downloads
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all download tasks
            future_to_url = {
                executor.submit(
                    self.download_and_process_image,
                    url,
                    query_features,
                    query_image,
                    include_patch_attention,
                ): url
                for url in candidate_urls
            }

            # Process completed downloads with rate limiting
            for future in as_completed(future_to_url):
                url = future_to_url[future]
                try:
                    result = future.result()
                    if result is not None:
                        results.append(result)

                        # Stop early once we have enough good results; patch
                        # attention is expensive, so collect fewer in that mode
                        target_count = 5 if include_patch_attention else 20
                        if len(results) >= target_count:
                            # Cancel futures that have not started yet
                            for remaining_future in future_to_url:
                                remaining_future.cancel()
                            break
                except Exception as e:
                    logger.warning(f"Error in concurrent processing for {url}: {e}")

                # Small delay to be respectful to servers
                time.sleep(0.1)

        # Sort by similarity score (highest first)
        results.sort(key=lambda x: x["score"], reverse=True)
        final_count = 3 if include_patch_attention else 15
        return results[:final_count]
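# Minimal sketch of driving the engine directly, outside FastAPI. The file
# name is hypothetical; the calls match the methods defined above:
#
#   engine = TattooSearchEngine("clip")
#   query = Image.open("query_tattoo.jpg").convert("RGB")
#   caption = engine.generate_caption(query)
#   urls = engine.search_images(caption, max_results=50)
#   ranked = engine.compute_similarity(query, urls)
#   print(ranked[0])  # e.g. {"score": 0.83, "url": "https://..."}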
# Global variable to store the search engine instance
search_engine = None


def get_search_engine(embedding_model: str = "clip") -> TattooSearchEngine:
    """Get or create the search engine instance for the requested embedding model.

    Only one instance is kept alive; requesting a different model replaces it
    (and reloads model weights).
    """
    global search_engine
    if search_engine is None or search_engine.embedding_model_type != embedding_model:
        search_engine = TattooSearchEngine(embedding_model)
    return search_engine


@app.post("/search")
async def search_tattoos(
    file: UploadFile = File(...),
    embedding_model: str = Query(
        default="clip", description="Embedding model to use (clip, dinov2, siglip)"
    ),
    include_patch_attention: bool = Query(
        default=False, description="Include patch-level attention analysis"
    ),
):
    """Find tattoos similar to an uploaded image."""
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # Validate embedding model
    available_models = EmbeddingModelFactory.get_available_models()
    if embedding_model not in available_models:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid embedding model. Available: {available_models}",
        )

    try:
        # Get search engine with specified embedding model
        engine = get_search_engine(embedding_model)

        # Read and process the uploaded image
        image_data = await file.read()
        query_image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Generate caption
        logger.info("Generating caption...")
        caption = engine.generate_caption(query_image)
        logger.info(f"Generated caption: {caption}")

        # Search for candidate images
        logger.info("Searching for candidate images...")
        candidate_urls = engine.search_images(caption, max_results=100)

        if not candidate_urls:
            return {
                "caption": caption,
                "results": [],
                "embedding_model": engine.embedding_model.get_model_name(),
            }

        # Compute similarities and rank
        logger.info("Computing similarities...")
        results = engine.compute_similarity(
            query_image, candidate_urls, include_patch_attention
        )

        return {
            "caption": caption,
            "results": results,
            "embedding_model": engine.embedding_model.get_model_name(),
            "patch_attention_enabled": include_patch_attention,
        }
    except HTTPException:
        # Preserve deliberate 4xx responses instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=str(e))
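# Example request (hypothetical host, port, and file name):
#
#   curl -X POST "http://localhost:8000/search?embedding_model=clip" \
#        -F "file=@query_tattoo.jpg"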
@app.post("/analyze-attention")
async def analyze_patch_attention(
    query_file: UploadFile = File(...),
    candidate_url: str = Query(
        ..., description="URL of the candidate image to compare"
    ),
    embedding_model: str = Query(
        default="clip", description="Embedding model to use (clip, dinov2, siglip)"
    ),
    include_visualizations: bool = Query(
        default=True, description="Include attention visualizations"
    ),
):
    """Analyze patch-level attention between a query image and a specific candidate image."""
    if not query_file.content_type or not query_file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Query file must be an image")

    # Validate embedding model
    available_models = EmbeddingModelFactory.get_available_models()
    if embedding_model not in available_models:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid embedding model. Available: {available_models}",
        )

    try:
        # Get search engine with specified embedding model
        engine = get_search_engine(embedding_model)

        # Read query image
        query_image_data = await query_file.read()
        query_image = Image.open(io.BytesIO(query_image_data)).convert("RGB")

        # Download candidate image
        candidate_image = engine.download_image(candidate_url)
        if candidate_image is None:
            raise HTTPException(
                status_code=400, detail="Failed to download candidate image"
            )

        # Analyze patch attention
        analyzer = PatchAttentionAnalyzer(engine.embedding_model)
        similarity_data = analyzer.compute_patch_similarities(
            query_image, candidate_image
        )

        result = {
            "query_image_size": query_image.size,
            "candidate_image_size": candidate_image.size,
            "candidate_url": candidate_url,
            "embedding_model": engine.embedding_model.get_model_name(),
            "similarity_analysis": analyzer.get_similarity_summary(similarity_data),
            "attention_matrix_shape": similarity_data["attention_matrix"].shape,
            "top_correspondences": similarity_data["top_correspondences"][:10],  # top 10
        }

        # Add visualizations if requested
        if include_visualizations:
            try:
                attention_heatmap = analyzer.visualize_attention_heatmap(
                    query_image, candidate_image, similarity_data
                )
                top_correspondences_viz = analyzer.visualize_top_correspondences(
                    query_image, candidate_image, similarity_data
                )
                result["visualizations"] = {
                    "attention_heatmap": f"data:image/png;base64,{attention_heatmap}",
                    "top_correspondences": f"data:image/png;base64,{top_correspondences_viz}",
                }
            except Exception as e:
                logger.warning(f"Failed to generate visualizations: {e}")
                result["visualizations"] = None

        return result
    except HTTPException:
        # Preserve deliberate 4xx responses instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error analyzing patch attention: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/models")
async def get_available_models():
    """Get the list of available embedding models and their configurations."""
    models = EmbeddingModelFactory.get_available_models()
    configs = get_default_model_configs()
    return {"available_models": models, "model_configs": configs}


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)
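# Example requests (hypothetical host, file name, and candidate URL):
#
#   # Patch-level attention between a query image and one candidate:
#   curl -X POST "http://localhost:8000/analyze-attention?candidate_url=https://example.com/tattoo.jpg" \
#        -F "query_file=@query_tattoo.jpg"
#
#   # List the available embedding models:
#   curl "http://localhost:8000/models"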