# app/main.py
import os
import asyncio
from pathlib import Path
import zipfile
import io
import requests
import uuid  # Add this import

from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks  # Add BackgroundTasks
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Any

# Import utilities
from inference_utils import (
    load_models_and_config,
    process_uploaded_zip,
    find_matches_for_item,
    process_shared_dataset_directory,
    cache_local_dataset,
)
# Import our new download helper
from download_utils import download_yandex_file

# --- Initialization ---
app = FastAPI()

# Allow CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Global caches ---
SHARED_DATASET_FULL_DATA = {}
SHARED_DATASET_ID = "shared_dataset_1"
PROCESSING_STATUS = {}  # NEW: For tracking progress


# --- Helper Functions ---
def download_and_unzip_yandex_archive(public_url: str, destination_dir: Path, description: str):
    # This function is unchanged
    print(f"--- 📥 Checking for {description} ---")
    if destination_dir.exists() and any(destination_dir.iterdir()):
        print(f"✅ {description} already exists in '{destination_dir}'. Skipping download.")
        return True
    print(f"⏳ {description} not found. Starting download from Yandex.Disk...")
    destination_dir.mkdir(parents=True, exist_ok=True)
    if "YOUR_" in public_url or "ВАША_" in public_url:
        print(f"🔥 WARNING: Placeholder URL detected for {description}. Download skipped.")
        return False
    try:
        api_url = "https://cloud-api.yandex.net/v1/disk/public/resources/download"
        params = {'public_key': public_url}
        response = requests.get(api_url, params=params)
        response.raise_for_status()
        download_url = response.json().get('href')
        if not download_url:
            raise RuntimeError(f"Could not retrieve download URL for {description} from Yandex.Disk API.")
        print(f"  🔗 Got download link. Fetching ZIP archive for {description}...")
        zip_response = requests.get(download_url, stream=True)
        zip_response.raise_for_status()
        zip_in_memory = io.BytesIO(zip_response.content)
        print(f"  🗂️ Unzipping archive for {description}...")
        with zipfile.ZipFile(zip_in_memory, 'r') as zip_ref:
            zip_ref.extractall(destination_dir)
        print(f"🎉 {description} successfully downloaded and extracted to '{destination_dir}'.")
        return True
    except Exception as e:
        print(f"🔥 CRITICAL ERROR downloading or unzipping {description}: {e}")
        return False
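# Usage sketch (illustrative): the startup event below calls this helper with the
# public Yandex.Disk share links for the shared dataset and the pre-computed embeddings, e.g.
#
#   ok = download_and_unzip_yandex_archive(
#       "https://disk.yandex.ru/d/G9C3_FGGzSLAXw",   # public share link
#       Path("static/shared_dataset"),
#       "shared dataset files",
#   )
#   # Returns True if the directory already has content or the archive was extracted,
#   # False on placeholder URLs or download/unzip failures.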
# --- NEW: Background Processing Wrapper ---
def background_process_zip(zip_bytes: bytes, original_filename: str, job_id: str):
    """Wrapper function to run processing and update status."""

    def update_status(stage: str, progress: int):
        """Callback to update the global status dictionary."""
        print(f"Job {job_id}: {stage} - {progress}%")
        PROCESSING_STATUS[job_id] = {"stage": stage, "progress": progress, "status": "processing"}

    try:
        processed_data = process_uploaded_zip(zip_bytes, original_filename, update_status)
        PROCESSING_STATUS[job_id] = {"status": "complete", "result": processed_data}
    except Exception as e:
        import traceback
        traceback.print_exc()
        PROCESSING_STATUS[job_id] = {"status": "error", "message": f"An error occurred during processing: {e}"}


class SingleMatchRequest(BaseModel):
    modality: str
    content: str
    dataset_id: str


# --- MODIFIED: process-dataset endpoint ---
class ProcessDatasetResponse(BaseModel):
    job_id: str


class DataItemModel(BaseModel):
    id: str
    name: str
    content: str | None = None  # Frontend sends content as string (base64 or text)
    contentUrl: str | None = None


class DatasetDataModel(BaseModel):
    images: List[DataItemModel]
    texts: List[DataItemModel]
    meshes: List[DataItemModel]


class LocalDatasetModel(BaseModel):
    id: str
    name: str
    data: DatasetDataModel
    # We only need the core data for re-hydration, other fields are optional
    # Use 'Any' for complex fields we don't need to strictly validate here
    fullComparison: Dict[str, Any] | None = None
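# Payload sketch for /api/cache-local-dataset (field values are illustrative
# placeholders, not real data; the shape follows LocalDatasetModel above):
#
#   {
#     "id": "local_dataset_1",
#     "name": "My Local Dataset",
#     "data": {
#       "images": [{"id": "img_1", "name": "bracket.png", "content": "<base64>"}],
#       "texts":  [{"id": "txt_1", "name": "bracket.txt", "content": "a steel bracket"}],
#       "meshes": [{"id": "mesh_1", "name": "bracket.obj", "contentUrl": "/static/..."}]
#     },
#     "fullComparison": null
#   }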
""" try: # Pydantic's .dict() is deprecated, use .model_dump() dataset_dict = dataset.model_dump() await asyncio.to_thread(cache_local_dataset, dataset_dict) return {"status": "cached", "id": dataset.id} except Exception as e: import traceback traceback.print_exc() raise HTTPException(status_code=500, detail=f"Failed to cache dataset: {e}") # --- Startup Event --- @app.on_event("startup") def startup_event(): # This function is unchanged SHARED_DATASET_DIR = Path("static/shared_dataset") SHARED_EMBEDDINGS_DIR = Path("static/shared_embeddings") SHARED_DATASET_ZIP_URL = "https://disk.yandex.ru/d/G9C3_FGGzSLAXw" SHARED_EMBEDDINGS_ZIP_URL = "https://disk.yandex.ru/d/aVTX6n2pc0hrCw" dataset_ready = download_and_unzip_yandex_archive(SHARED_DATASET_ZIP_URL, SHARED_DATASET_DIR, "shared dataset files") embeddings_ready = download_and_unzip_yandex_archive(SHARED_EMBEDDINGS_ZIP_URL, SHARED_EMBEDDINGS_DIR, "pre-computed embeddings") DATA_DIR = Path("data/") MODEL_URLS = { "text_proj.pth": "https://disk.yandex.ru/d/uMH1ls0nYM4txw", "text_encoder.pth": "https://disk.yandex.ru/d/R0BBLPXj828OhA", "moe.pth": "https://disk.yandex.ru/d/vDfuIPziuO45wg", "pc_encoder.pth": "https://disk.yandex.ru/d/03Ps2TMcWAKkww", } print("--- 📥 Checking and loading models ---") DATA_DIR.mkdir(parents=True, exist_ok=True) all_models_present = True for filename, url in MODEL_URLS.items(): destination_file = DATA_DIR / filename if not destination_file.exists(): print(f"⏳ Модель '{filename}' не найдена. Начинаю загрузку...") if not "ВАША_ССЫЛКА" in url: success = download_yandex_file(public_file_url=url, destination_path=str(DATA_DIR), filename=filename) if not success: all_models_present = False print(f"🔥 Критическая ошибка: не удалось скачать модель '{filename}'.") else: all_models_present = False print(f"🔥 ВНИМАНИЕ: Пропущена загрузка '{filename}', т.к. ссылка является плейсхолдером.") else: print(f"✅ Модель '{filename}' уже существует. Пропускаю загрузку.") if not all_models_present: raise RuntimeError("Не удалось загрузить все необходимые модели. Приложение не может запуститься.") print("--- ✅ Все модели готовы к использованию ---") model_paths = {"text_proj": str(DATA_DIR / "text_proj.pth"), "text_encoder": str(DATA_DIR / "text_encoder.pth"), "moe": str(DATA_DIR / "moe.pth"), "pc_encoder": str(DATA_DIR / "pc_encoder.pth")} config_path = "cad_retrieval_utils/config/config.py" try: load_models_and_config(config_path=config_path, model_paths=model_paths) print("✅ Все модели успешно загружены в память.") except Exception as e: print(f"🔥 Ошибка при загрузке моделей: {e}") import traceback traceback.print_exc() raise RuntimeError(f"Ошибка загрузки моделей, приложение не может запуститься.") from e if dataset_ready and embeddings_ready: print("--- 🧠 Loading pre-computed embeddings for shared dataset ---") try: full_data = process_shared_dataset_directory(directory_path=SHARED_DATASET_DIR, embeddings_path=SHARED_EMBEDDINGS_DIR, dataset_id=SHARED_DATASET_ID, dataset_name="Cloud Multi-Modal Dataset") if full_data: SHARED_DATASET_FULL_DATA[SHARED_DATASET_ID] = full_data print("--- ✅ Shared dataset processed and cached successfully. ---") else: print("--- ⚠️ Shared dataset processing returned no data. Caching skipped. ---") except Exception as e: print(f"🔥 CRITICAL ERROR processing shared dataset: {e}") import traceback traceback.print_exc() else: print("--- ⚠️ Shared dataset or embeddings not available. Processing skipped. 
---") # --- API Endpoints --- @app.get("/api/shared-dataset-metadata") async def get_shared_dataset_metadata(): # This function is unchanged metadata_list = [] for dataset_id, full_data in SHARED_DATASET_FULL_DATA.items(): metadata = {"id": full_data["id"], "name": full_data["name"], "uploadDate": full_data["uploadDate"], "processingState": full_data["processingState"], "itemCounts": {"images": len(full_data["data"]["images"]), "texts": len(full_data["data"]["texts"]), "meshes": len(full_data["data"]["meshes"])}, "isShared": True} metadata_list.append(metadata) return metadata_list @app.get("/api/shared-dataset") async def get_shared_dataset(id: str): # This function is unchanged dataset = SHARED_DATASET_FULL_DATA.get(id) if not dataset: raise HTTPException(status_code=404, detail=f"Shared dataset with id '{id}' not found.") return dataset @app.post("/api/process-dataset", response_model=ProcessDatasetResponse) async def process_dataset_endpoint( background_tasks: BackgroundTasks, file: UploadFile = File(...) ): if not file.filename or not file.filename.endswith('.zip'): raise HTTPException(status_code=400, detail="A ZIP archive is required.") zip_bytes = await file.read() job_id = str(uuid.uuid4()) PROCESSING_STATUS[job_id] = {"status": "starting", "stage": "Queued", "progress": 0} background_tasks.add_task( background_process_zip, zip_bytes, file.filename, job_id ) return {"job_id": job_id} # --- NEW: processing-status endpoint --- class StatusResponse(BaseModel): status: str stage: str | None = None progress: int | None = None message: str | None = None result: dict | None = None @app.get("/api/processing-status/{job_id}", response_model=StatusResponse) async def get_processing_status(job_id: str): """Poll this endpoint to get the status of a processing job.""" status = PROCESSING_STATUS.get(job_id) if not status: raise HTTPException(status_code=404, detail="Job ID not found.") return status @app.post("/api/find-matches") async def find_matches_endpoint(request: SingleMatchRequest): # This function is unchanged try: match_results = await asyncio.to_thread( find_matches_for_item, request.modality, request.content, request.dataset_id ) source_item_data = {"id": "source_item", "name": "Source Item", "content": request.content} final_response = {"sourceItem": source_item_data, "sourceModality": request.modality, **match_results} return final_response except ValueError as ve: raise HTTPException(status_code=404, detail=str(ve)) except Exception as e: import traceback traceback.print_exc() raise HTTPException(status_code=500, detail=f"Ошибка при поиске совпадений: {e}") app.mount("/", StaticFiles(directory="static", html=True), name="static")