"""Search engine manager for coordinating multi-platform searches.""" import asyncio import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Optional, Set from .base import BaseSearchEngine, ImageResult, SearchPlatform, SearchResult from .instagram import InstagramSearchEngine from .pinterest import PinterestSearchEngine from .reddit import RedditSearchEngine class SearchEngineManager: """Manages and coordinates searches across multiple platforms.""" def __init__(self, max_workers: int = 5): self.max_workers = max_workers self.engines: Dict[SearchPlatform, BaseSearchEngine] = { SearchPlatform.PINTEREST: PinterestSearchEngine(), SearchPlatform.REDDIT: RedditSearchEngine(), SearchPlatform.INSTAGRAM: InstagramSearchEngine(), } def search_all_platforms( self, query: str, max_results_per_platform: int = 20, platforms: Optional[Set[SearchPlatform]] = None ) -> SearchResult: """Search across multiple platforms concurrently.""" start_time = time.time() if platforms is None: platforms = set(self.engines.keys()) all_results = [] platforms_used = set() with ThreadPoolExecutor(max_workers=self.max_workers) as executor: # Submit search tasks for each platform future_to_platform = { executor.submit( self._search_single_platform, platform, query, max_results_per_platform ): platform for platform in platforms if platform in self.engines } # Collect results as they complete for future in as_completed(future_to_platform): platform = future_to_platform[future] try: platform_results = future.result(timeout=30) # 30s timeout per platform if platform_results: all_results.extend(platform_results) platforms_used.add(platform) except Exception as e: print(f"Platform {platform.value} search failed: {e}") # Remove duplicates and sort by quality unique_results = self._deduplicate_results(all_results) sorted_results = sorted(unique_results, key=lambda x: x.quality_score, reverse=True) search_duration = time.time() - start_time return SearchResult( images=sorted_results, total_found=len(sorted_results), platforms_used=platforms_used, search_duration=search_duration ) def search_with_fallback( self, query: str, max_results: int = 50, min_results_threshold: int = 10 ) -> SearchResult: """Search with intelligent fallback strategies.""" # Try primary platforms first primary_platforms = {SearchPlatform.PINTEREST, SearchPlatform.REDDIT} result = self.search_all_platforms( query, max_results_per_platform=max_results // 2, platforms=primary_platforms ) # If we don't have enough results, try additional platforms if len(result.images) < min_results_threshold: additional_platforms = {SearchPlatform.INSTAGRAM} additional_result = self.search_all_platforms( query, max_results_per_platform=max_results // 2, platforms=additional_platforms ) # Merge results all_images = result.images + additional_result.images unique_images = self._deduplicate_results(all_images) sorted_images = sorted(unique_images, key=lambda x: x.quality_score, reverse=True) result = SearchResult( images=sorted_images, total_found=len(sorted_images), platforms_used=result.platforms_used | additional_result.platforms_used, search_duration=result.search_duration + additional_result.search_duration ) # If still not enough, try simplified queries if len(result.images) < min_results_threshold: simplified_query = self._simplify_query(query) if simplified_query != query: fallback_result = self.search_all_platforms( simplified_query, max_results_per_platform=max_results // 3 ) # Merge with existing results all_images = result.images + fallback_result.images unique_images = self._deduplicate_results(all_images) sorted_images = sorted(unique_images, key=lambda x: x.quality_score, reverse=True) result = SearchResult( images=sorted_images, total_found=len(sorted_images), platforms_used=result.platforms_used | fallback_result.platforms_used, search_duration=result.search_duration + fallback_result.search_duration ) return result def _search_single_platform( self, platform: SearchPlatform, query: str, max_results: int ) -> List[ImageResult]: """Search a single platform (thread-safe).""" engine = self.engines.get(platform) if not engine: return [] try: return engine.search(query, max_results) except Exception as e: print(f"Error searching {platform.value}: {e}") return [] def _deduplicate_results(self, results: List[ImageResult]) -> List[ImageResult]: """Remove duplicate URLs while preserving the highest quality version.""" seen_urls = {} for result in results: if result.url in seen_urls: # Keep the result with higher quality score if result.quality_score > seen_urls[result.url].quality_score: seen_urls[result.url] = result else: seen_urls[result.url] = result return list(seen_urls.values()) def _simplify_query(self, query: str) -> str: """Simplify query by removing complex terms and keeping core concepts.""" # Remove adjectives and keep main nouns words = query.split() # Common tattoo-related keywords to keep core_keywords = { 'tattoo', 'design', 'art', 'ink', 'traditional', 'realistic', 'geometric', 'tribal', 'watercolor', 'minimalist', 'blackwork', 'dotwork', 'dragon', 'flower', 'skull', 'rose', 'bird', 'lion', 'butterfly' } # Keep important words and first few words simplified_words = [] for i, word in enumerate(words): if i < 3 or word.lower() in core_keywords: simplified_words.append(word) simplified = ' '.join(simplified_words) return simplified if simplified else 'tattoo art' def get_platform_stats(self) -> Dict[str, Dict]: """Get statistics about available platforms.""" stats = {} for platform, engine in self.engines.items(): stats[platform.value] = { 'available': True, 'class': engine.__class__.__name__ } return stats