import os
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.middleware.gzip import GZipMiddleware
from pg8000 import DatabaseError
from pydantic import BaseModel
from typing import List, Dict
from datasets import Dataset
from src.api.models.embedding_models import (
    CreateEmbeddingRequest,
    DeleteByColumnRequest,
    ReadEmbeddingRequest,
    UpdateEmbeddingRequest,
    DeleteEmbeddingRequest,
    EmbedRequest,
    SearchEmbeddingRequest,
    ResetEmbeddingsRequest,
)
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
from src.api.services.embedding_service import EmbeddingService
from src.api.services.huggingface_service import HuggingFaceService
from src.api.services.postgresql_service import PostgresqlService
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
import logging
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
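
# Environment variables expected by this module (names taken from the
# os.getenv calls in this file; any Hugging Face token used by
# HuggingFaceService is presumably read elsewhere and is not visible here):
#   OPENAI_API_KEY - OpenAI API key passed to EmbeddingService
#   DB_URL         - PostgreSQL connection URL referenced by the (commented-out)
#                    database helpers below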

# Set up structured logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

description = """A FastAPI application for similarity search with PostgreSQL and OpenAI embeddings.
Direct/API URL:
https://re-mind-similarity-search.hf.space
"""

# Initialize FastAPI app
app = FastAPI(
    title="Similarity Search API",
    description=description,
    version="1.0.0",
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
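# The GZip middleware above compresses responses of at least minimum_size
# (1000) bytes; smaller bodies are returned uncompressed.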

# def get_database_service() -> Database:
#     return Database(db_url=os.getenv("DB_URL"))

# Dependency to get EmbeddingService
def get_embedding_service() -> EmbeddingService:
    return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))

# def get_db_from_env():
#     return get_db_from_url(os.getenv("DB_URL"))

# Dependency to get HuggingFaceService
def get_huggingface_service() -> HuggingFaceService:
    return HuggingFaceService()

# Root endpoint redirects to /docs
async def root():
    return RedirectResponse(url="/docs")

# Health check endpoint
async def health_check(db: Database = Depends(get_db)):
    try:
        is_healthy = await db.health_check()
        if not is_healthy:
            raise HTTPException(status_code=500, detail="Database is unhealthy")
        return {"status": "healthy"}
    except HealthCheckError as e:
        raise HTTPException(status_code=500, detail=str(e))

# Endpoint to generate embeddings for a list of strings
async def embed(
    request: EmbedRequest,
    embedding_service: EmbeddingService = Depends(get_embedding_service),
):
    """
    Generate embeddings for a list of strings and return them in the response.
    """
    try:
        # Step 1: Generate embeddings
        logger.info("Generating embeddings for list of texts...")
        embeddings = await embedding_service.create_embeddings(request.texts)
        return JSONResponse(
            content={
                "message": "Embeddings generated successfully.",
                "embeddings": embeddings,
                "num_texts": len(request.texts),
            }
        )
    except OpenAIError as e:
        logger.error(f"OpenAI API error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenAI API error: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
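
# Example request body for embed() (field name taken from how the handler
# reads EmbedRequest above; values are placeholders):
#   {"texts": ["first sentence to embed", "second sentence to embed"]}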

# Endpoint to create embeddings from a database query
async def create_embedding(
    request: CreateEmbeddingRequest,
    db: Database = Depends(get_db),
    embedding_service: EmbeddingService = Depends(get_embedding_service),
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Create embeddings for the target column in the dataset.
    """
    try:
        embedding_service.model = request.model
        embedding_service.batch_size = request.batch_size
        # embedding_service.max_concurrent_requests = request.max_concurrent_requests
        # Step 1: Query the database
        logger.info("Fetching data from the database...")
        result = await db.fetch(request.query)
        # logger.info(f"{result}")
        dataset = Dataset.from_dict(result)
        # Step 2: Generate embeddings
        dataset = await embedding_service.create_embeddings(
            dataset, request.target_column, request.output_column
        )
        # Step 3: Push to Hugging Face Hub
        await huggingface_service.push_to_hub(dataset, request.dataset_name)
        return JSONResponse(
            content={
                "message": "Embeddings created and pushed to Hugging Face Hub.",
                "dataset_name": request.dataset_name,
                "num_rows": len(dataset),
            }
        )
    except QueryExecutionError as e:
        logger.error(f"Database query failed: {e}")
        raise HTTPException(status_code=500, detail=f"Database query failed: {e}")
    except OpenAIError as e:
        logger.error(f"OpenAI API error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenAI API error: {e}")
    except DatasetPushError as e:
        logger.error(f"Failed to push dataset: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to push dataset: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
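
# Example request body for create_embedding() (field names taken from how the
# handler reads CreateEmbeddingRequest above; table, column, model, and dataset
# names are placeholders):
#   {
#       "query": "SELECT id, description FROM products",
#       "target_column": "description",
#       "output_column": "embedding",
#       "dataset_name": "username/products-embeddings",
#       "model": "text-embedding-3-small",
#       "batch_size": 32
#   }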

# Endpoint to read embeddings
async def read_embeddings(
    request: ReadEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Read embeddings from a Hugging Face dataset.
    """
    try:
        dataset = await huggingface_service.read_dataset(request.dataset_name)
        return dataset.to_dict()
    except DatasetNotFoundError as e:
        logger.error(f"Dataset not found: {e}")
        raise HTTPException(status_code=404, detail=f"Dataset not found: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")

# Endpoint to update embeddings
async def update_embeddings(
    request: UpdateEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Update embeddings in a Hugging Face dataset by generating embeddings for new
    data and concatenating it with the existing dataset.
    """
    try:
        # Call the update_dataset method to generate embeddings, concatenate, and push the updated dataset
        updated_dataset = await huggingface_service.update_dataset(
            request.dataset_name,
            request.updates,
            request.target_column,
            request.output_column,
        )
        return {
            "message": "Embeddings updated successfully.",
            "dataset_name": request.dataset_name,
            "num_rows": len(updated_dataset),
        }
    except DatasetPushError as e:
        logger.error(f"Failed to update dataset: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to update dataset: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
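
# Example request body for update_embeddings() (field names taken from how the
# handler reads UpdateEmbeddingRequest above; the exact shape of "updates" is
# whatever HuggingFaceService.update_dataset accepts - the columnar dict below
# is a placeholder):
#   {
#       "dataset_name": "username/products-embeddings",
#       "updates": {"id": [101], "description": ["new product description"]},
#       "target_column": "description",
#       "output_column": "embedding"
#   }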

# Endpoint to delete embeddings
async def delete_embeddings(
    request: DeleteEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Delete embeddings from a Hugging Face dataset.
    """
    try:
        await huggingface_service.delete_dataset(request.dataset_name)
        return {
            "message": "Embeddings deleted successfully.",
            "dataset_name": request.dataset_name,
        }
    except DatasetPushError as e:
        logger.error(f"Failed to delete dataset: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to delete dataset: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")

async def delete_rows_by_key(
    request: DeleteByColumnRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Delete specific rows from a Hugging Face dataset based on a key column and values.
    """
    try:
        await huggingface_service.delete_rows_from_dataset(
            request.dataset_name, request.key_column, request.keys_to_delete
        )
        return {
            "message": "Rows deleted successfully from dataset.",
            "dataset_name": request.dataset_name,
            "key_column": request.key_column,
            "deleted_keys": request.keys_to_delete,
        }
    except DatasetNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logger.error(f"An error occurred while deleting rows: {e}")
        raise HTTPException(status_code=500, detail=str(e))
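
# Example request body for delete_rows_by_key() (field names taken from how the
# handler reads DeleteByColumnRequest above; values are placeholders):
#   {
#       "dataset_name": "username/products-embeddings",
#       "key_column": "id",
#       "keys_to_delete": [101, 102]
#   }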

async def search_embedding(
    request: SearchEmbeddingRequest,
    embedding_service: EmbeddingService = Depends(get_embedding_service),
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Search for similar texts in a dataset using embeddings.
    """
    try:
        # Step 1: Generate embeddings for the query texts
        logger.info("Generating embeddings for query texts...")
        query_embeddings = await embedding_service.create_embeddings(request.texts)
        # Step 2: Load the dataset from Hugging Face Hub
        logger.info(f"Loading dataset from Hugging Face Hub: {request.dataset_name}...")
        dataset = await huggingface_service.read_dataset(request.dataset_name)
        # Step 3: Perform cosine similarity search
        logger.info("Performing cosine similarity search...")
        results = await embedding_service.search_embeddings(
            query_embeddings,
            dataset,
            request.embedding_column,
            request.target_column,
            request.num_results,
            request.additional_columns,
        )
        return JSONResponse(
            content={
                "message": "Search completed successfully.",
                "results": results,
            }
        )
    except DatasetNotFoundError as e:
        logger.error(f"Dataset not found: {e}")
        raise HTTPException(status_code=404, detail=f"Dataset not found: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
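
# Example request body for search_embedding() (field names taken from how the
# handler reads SearchEmbeddingRequest above; values are placeholders):
#   {
#       "texts": ["waterproof hiking boots"],
#       "dataset_name": "username/products-embeddings",
#       "embedding_column": "embedding",
#       "target_column": "description",
#       "num_results": 5,
#       "additional_columns": ["id"]
#   }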

async def reset_embeddings(
    request: ResetEmbeddingsRequest,
    db: Database = Depends(get_db_from_url),
    embedding_service: EmbeddingService = Depends(get_embedding_service),
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Reset embeddings for a Hugging Face dataset by regenerating them from the
    current database rows and overwriting the existing dataset.
    """
    postgresql_service = PostgresqlService(db)
    try:
        # List of rows from database
        results = await postgresql_service.get_db_rows_from_dataset_name(request.dataset_name)
        # Generation of embeddings for each row
        dataset = Dataset.from_dict(results)
        dataset_embedded = await embedding_service.create_embeddings(dataset, request.target_column, "embedding")
        # Embeddings up-to-date with database will overwrite old dataset
        await huggingface_service.push_to_hub(dataset_embedded, request.dataset_name)
        return {
            "message": "Dataset updated successfully with up-to-date rows from the database.",
            "dataset_name": request.dataset_name,
            "num_rows": len(dataset_embedded),
        }
    except DatabaseError as e:
        raise HTTPException(status_code=500, detail=str(e))
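
# Route registration sketch: the decorators/paths for the handlers above are
# not shown in this file, so the paths and HTTP methods below are assumptions
# derived from the handler names and request models; adjust them to match the
# real routes if they differ.
app.add_api_route("/", root, methods=["GET"], include_in_schema=False)
app.add_api_route("/health", health_check, methods=["GET"])
app.add_api_route("/embed", embed, methods=["POST"])
app.add_api_route("/create_embedding", create_embedding, methods=["POST"])
app.add_api_route("/read_embeddings", read_embeddings, methods=["POST"])
app.add_api_route("/update_embeddings", update_embeddings, methods=["POST"])
app.add_api_route("/delete_embeddings", delete_embeddings, methods=["POST"])
app.add_api_route("/delete_rows_by_key", delete_rows_by_key, methods=["POST"])
app.add_api_route("/search_embedding", search_embedding, methods=["POST"])
app.add_api_route("/reset_embeddings", reset_embeddings, methods=["POST"])

if __name__ == "__main__":
    # Minimal local entry point; port 7860 is the usual Hugging Face Spaces
    # port, but the Space itself may launch the app differently.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)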