Spaces:
Runtime error
Runtime error
| from pymilvus import MilvusClient | |
| from sentence_transformers import SentenceTransformer | |
| from typing import List, Dict, Any, Optional, Union | |
| import logging | |
| from app.config import MILVUS_DB_URL, MILVUS_DB_TOKEN, EMBEDDING_MODEL, DATASET_ID | |
# Configure logging for the process: INFO level, records rendered as
# "<timestamp> <level> <message>". The module logger is named after this module.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class Database:
    """Wrapper around a Milvus collection that stores sticker embeddings.

    Text is embedded with a SentenceTransformer model and written to, searched
    in and deleted from the ``stickers`` collection through a ``MilvusClient``.
    Every public method except ``encode_text`` catches its own errors, logs
    them and returns a failure value instead of raising.
    """

    # Vector size used when the loaded model cannot report its own dimension.
    DEFAULT_DIMENSION = 768

    def __init__(self):
        """Open the Milvus client and load the embedding model from app.config."""
        self.client = MilvusClient(uri=MILVUS_DB_URL, token=MILVUS_DB_TOKEN)
        self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True)
        logger.info("初始化模型完成 %s", self.model)
        self.collection_name = "stickers"

    def _embedding_dimension(self) -> int:
        """Return the model's embedding size, or DEFAULT_DIMENSION if unknown."""
        get_dim = getattr(self.model, "get_sentence_embedding_dimension", None)
        dimension = get_dim() if callable(get_dim) else None
        return dimension or self.DEFAULT_DIMENSION

    @staticmethod
    def _normalize_tags(tags: Union[str, List[str], None]) -> List[str]:
        """Turn a comma-separated string or a list into a clean list of tags.

        Surrounding whitespace is stripped and empty entries are dropped, so
        ``"a, b,,"`` becomes ``["a", "b"]`` and ``""`` / ``None`` become ``[]``.
        """
        if tags is None:
            return []
        if isinstance(tags, str):
            tags = tags.split(",")
        stripped = (str(tag).strip() for tag in tags if tag is not None)
        return [tag for tag in stripped if tag]

    def init_collection(self) -> bool:
        """Create the sticker collection if it does not exist yet.

        Returns:
            True when the collection exists after the call, False when Milvus
            reported an error (the error is logged with its traceback).
        """
        try:
            logger.info("初始化 Milvus 数据库 %s", self.client.list_collections())
            # Check for THIS collection; the presence of any other collection
            # in the database must not prevent "stickers" from being created.
            if not self.client.has_collection(self.collection_name):
                # Quick-setup creation: per the pymilvus documentation this
                # also builds an index on the vector field and loads the
                # collection, so no separate create_index call is issued.
                # The metric is set explicitly so it matches search_stickers.
                # NOTE(review): if an IVF_SQ8 / nlist=128 index is required,
                # switch to a schema plus prepare_index_params() setup and
                # confirm against the installed pymilvus version.
                self.client.create_collection(
                    collection_name=self.collection_name,
                    dimension=self._embedding_dimension(),
                    primary_field_name="id",
                    metric_type="COSINE",
                    auto_id=True,
                )
                logger.info("Collection initialized: %s", self.collection_name)
            logger.info("初始化 Milvus 数据库成功 %s", self.client.list_collections())
            return True
        except Exception as exc:
            logger.exception("Collection initialization failed: %s", exc)
            return False

    def encode_text(self, text: str) -> List[float]:
        """Embed ``text`` with the loaded model and return the vector as a list."""
        return self.model.encode(text).tolist()

    def store_sticker(
        self,
        title: str,
        description: str,
        tags: Union[str, List[str]],
        file_path: str,
        image_hash: Optional[str] = None,
    ) -> bool:
        """Embed ``description`` and insert one sticker row into Milvus.

        Args:
            title: Sticker title.
            description: Text that is embedded and used for similarity search.
            tags: Comma-separated string or list of tags; normalized to a list.
            file_path: Stored under the ``file_name`` key.
            image_hash: Optional hash later queried by ``check_image_exists``.

        Returns:
            True on success, False if embedding or insertion failed.
        """
        try:
            vector = self.encode_text(description)
            tags = self._normalize_tags(tags)
            logger.info(
                "Storing to Milvus - title: %s, description: %s, file_path: %s, "
                "tags: %s, image_hash: %s",
                title, description, file_path, tags, image_hash,
            )
            self.client.insert(
                collection_name=self.collection_name,
                data=[{
                    "vector": vector,
                    "title": title,
                    "description": description,
                    "tags": tags,
                    "file_name": file_path,
                    "image_hash": image_hash,
                }],
            )
            logger.info("Storing to Milvus Success ✅")
            return True
        except Exception as exc:
            logger.exception("Failed to store sticker: %s", exc)
            return False

    def search_stickers(self, description: str, limit: int = 2) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stickers closest to ``description``.

        An empty or whitespace-only query returns ``[]`` without contacting
        Milvus. Any search error is logged and also yields ``[]``.
        """
        if not description or not description.strip():
            return []
        try:
            text_vector = self.encode_text(description)
            logger.info("Searching Milvus - query: %s, limit: %s", description, limit)
            results = self.client.search(
                collection_name=self.collection_name,
                data=[text_vector],
                limit=limit,
                search_params={"metric_type": "COSINE"},
                output_fields=["title", "description", "tags", "file_name"],
            )
            logger.info("Search Result: %s", results)
            # One query vector was sent, so the hits are in the first entry.
            return results[0] if results else []
        except Exception as exc:
            logger.exception("Search failed: %s", exc)
            return []

    def get_all_stickers(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stored stickers, or ``[]`` on error."""
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter="",
                limit=limit,
                output_fields=["title", "description", "tags", "file_name", "image_hash"],
            )
            logger.info(
                "Query All Stickers - limit: %s, results count: %s", limit, len(results)
            )
            return results
        except Exception as exc:
            logger.exception("Failed to get all stickers: %s", exc)
            return []

    def check_image_exists(self, image_hash: str) -> bool:
        """Report whether a sticker with this image hash is already stored.

        Returns False for an empty hash and also when the query itself fails,
        so a False result does not guarantee that the image is absent.
        """
        if not image_hash:
            return False
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter=f"image_hash == '{image_hash}'",
                limit=1,
                output_fields=["file_name", "image_hash"],
            )
            exists = len(results) > 0
            logger.info(
                "Check file exists - hash: %s, exists: %s, results: %s",
                image_hash, exists, results,
            )
            return exists
        except Exception as exc:
            logger.exception("Failed to check file exists: %s", exc)
            return False

    def delete_sticker(self, sticker_id: int) -> str:
        """Delete the sticker with primary key ``sticker_id``.

        Returns:
            A human-readable message describing success or the failure reason.
        """
        try:
            logger.info("Deleting sticker - id: %s", sticker_id)
            res = self.client.delete(
                collection_name=self.collection_name,
                ids=[sticker_id],
            )
            logger.info("Deleted sticker - id: %s", sticker_id)
            logger.info("Delete result: %s", res)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as exc:
            logger.exception("Failed to delete sticker: %s", exc)
            return f"Failed to delete sticker: {str(exc)}"
# Module-level Database instance, created when this module is imported.
# Constructing it opens the Milvus client and loads the embedding model;
# init_collection() is not invoked here.
db = Database()