Spaces:
Sleeping
Sleeping
Commit
·
b1341dd
1
Parent(s):
b777c8f
reset embeddings effective
Browse files
- src/api/models/embedding_models.py +1 -1
- src/main.py +24 -6
src/api/models/embedding_models.py
CHANGED
|
@@ -68,4 +68,4 @@ class SearchEmbeddingRequest(BaseModel):
|
|
| 68 |
|
| 69 |
|
| 70 |
class ResetEmbeddingsRequest(BaseModel):
|
| 71 |
-
dataset_name: str
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
class ResetEmbeddingsRequest(BaseModel):
|
| 71 |
+
dataset_name: str = "re-mind/product_type_embedding"
|
src/main.py
CHANGED
|
@@ -2,6 +2,7 @@ import os
|
|
| 2 |
from fastapi import FastAPI, Depends, HTTPException
|
| 3 |
from fastapi.responses import JSONResponse, RedirectResponse
|
| 4 |
from fastapi.middleware.gzip import GZipMiddleware
|
|
|
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from typing import List, Dict
|
| 7 |
from datasets import Dataset
|
|
@@ -18,6 +19,7 @@ from src.api.models.embedding_models import (
|
|
| 18 |
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
|
| 19 |
from src.api.services.embedding_service import EmbeddingService
|
| 20 |
from src.api.services.huggingface_service import HuggingFaceService
|
|
|
|
| 21 |
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
|
| 22 |
|
| 23 |
import logging
|
|
@@ -309,16 +311,32 @@ async def search_embedding(
|
|
| 309 |
@app.post("/reset_embeddings")
|
| 310 |
async def reset_embeddings(
|
| 311 |
request: ResetEmbeddingsRequest,
|
| 312 |
-
db: Database = Depends(get_db_from_url)
|
|
|
|
|
|
|
| 313 |
):
|
| 314 |
"""
|
| 315 |
Reset embeddings from a Hugging Face dataset by deleting them, then reloading them
|
| 316 |
using the actual database
|
| 317 |
"""
|
|
|
|
|
|
|
| 318 |
try:
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
| 2 |
from fastapi import FastAPI, Depends, HTTPException
|
| 3 |
from fastapi.responses import JSONResponse, RedirectResponse
|
| 4 |
from fastapi.middleware.gzip import GZipMiddleware
|
| 5 |
+
from pg8000 import DatabaseError
|
| 6 |
from pydantic import BaseModel
|
| 7 |
from typing import List, Dict
|
| 8 |
from datasets import Dataset
|
|
|
|
| 19 |
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
|
| 20 |
from src.api.services.embedding_service import EmbeddingService
|
| 21 |
from src.api.services.huggingface_service import HuggingFaceService
|
| 22 |
+
from src.api.services.postgresql_service import PostgresqlService
|
| 23 |
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
|
| 24 |
|
| 25 |
import logging
|
|
|
|
| 311 |
@app.post("/reset_embeddings")
|
| 312 |
async def reset_embeddings(
|
| 313 |
request: ResetEmbeddingsRequest,
|
| 314 |
+
db: Database = Depends(get_db_from_url),
|
| 315 |
+
embedding_service: EmbeddingService = Depends(get_embedding_service),
|
| 316 |
+
huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
|
| 317 |
):
|
| 318 |
"""
|
| 319 |
Reset embeddings from a Hugging Face dataset by deleting them, then reloading them
|
| 320 |
using the actual database
|
| 321 |
"""
|
| 322 |
+
postgresql_service = PostgresqlService(db)
|
| 323 |
+
|
| 324 |
try:
|
| 325 |
+
# List of rows from database
|
| 326 |
+
results = await postgresql_service.get_db_rows_from_dataset_name(request.dataset_name)
|
| 327 |
+
|
| 328 |
+
# Generation of embeddings for each row
|
| 329 |
+
dataset = Dataset.from_dict(results)
|
| 330 |
+
target_column = "type" if request.dataset_name == "re-mind/product_type_embedding" else "name"
|
| 331 |
+
dataset_embedded = await embedding_service.create_embeddings(dataset, target_column, "embedding")
|
| 332 |
+
# Embeddings up-to-date with database will overwrite old dataset
|
| 333 |
+
await huggingface_service.push_to_hub(dataset_embedded, request.dataset_name)
|
| 334 |
+
|
| 335 |
+
return {
|
| 336 |
+
"message": "Dataset updated succesfully with up-to-date rows from database",
|
| 337 |
+
"dataset_name": request.dataset_name,
|
| 338 |
+
"num_rows": len(dataset_embedded)
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
except DatabaseError as e:
|
| 342 |
raise HTTPException(status_code=500, detail=str(e))
|