Spaces:
Runtime error
Runtime error
| from typing import List, Dict, Union, Optional | |
| import tiktoken | |
| from ..embedding_provider import EmbeddingProvider | |
| import numpy as np | |
class OpenAIEmbedding(EmbeddingProvider):
    """OpenAI-backed embedding provider.

    Wraps the OpenAI embeddings API: inputs are token-truncated with
    tiktoken before being sent, and embeddings come back as numpy arrays.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "text-embedding-3-small",
        max_tokens: int = 8191,
        batch_size: int = 100,
    ) -> None:
        """Initialize OpenAI embedding provider.

        Args:
            api_key (Optional[str]): API key for OpenAI. When None, the
                openai client resolves it from the environment.
            model (str, optional): Name of the embedding model.
                Defaults to "text-embedding-3-small".
                More info: https://platform.openai.com/docs/models#embeddings
            max_tokens (int): Maximum tokens kept per input text before
                embedding (8191 is the request limit for these models).
            batch_size (int): Default number of documents sent per API
                request by embed_documents(). Defaults to 100.
        """
        # Imported lazily so this module can be imported without the
        # openai package installed.
        from openai import OpenAI
        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        self.batch_size = batch_size
        try:
            self.tokenizer = tiktoken.encoding_for_model(model)
        except KeyError:
            # tiktoken raises KeyError for model names it does not know
            # (e.g. models newer than the installed tiktoken). Fall back to
            # cl100k_base, the encoding used by the models listed in
            # list_available_models().
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def _truncate_text(self, text: str) -> str:
        """Truncate text to at most self.max_tokens tokens.

        Args:
            text (str): Input text.

        Returns:
            str: Text decoded back from the first max_tokens tokens.
        """
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:self.max_tokens])

    def embed_documents(
        self,
        documents: List[str],
        batch_size: Optional[int] = None,
    ) -> np.ndarray:
        """Embed a list of documents in batches.

        Args:
            documents (List[str]): List of documents to embed.
            batch_size (Optional[int]): Documents per API request.
                Defaults to the batch_size set at construction time.

        Returns:
            np.ndarray: Embeddings, shape (len(documents), embedding_dim).
        """
        if batch_size is None:
            batch_size = self.batch_size
        truncated_docs = [self._truncate_text(doc) for doc in documents]
        embeddings: List[List[float]] = []
        for start in range(0, len(truncated_docs), batch_size):
            batch = truncated_docs[start:start + batch_size]
            response = self.client.embeddings.create(
                input=batch,
                model=self.model,
            )
            # response.data preserves input order per the OpenAI API.
            embeddings.extend(item.embedding for item in response.data)
        return np.array(embeddings)

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query string.

        Args:
            query (str): Query text to embed.

        Returns:
            np.ndarray: 1-D embedding vector for the query.
        """
        truncated_query = self._truncate_text(query)
        response = self.client.embeddings.create(
            input=[truncated_query],
            model=self.model,
        )
        return np.array(response.data[0].embedding)

    def get_embedding_info(self) -> Dict[str, Union[str, int]]:
        """Get information about the current embedding configuration.

        Returns:
            Dict[str, Union[str, int]]: Model name, max token limit and
            the default batch size actually used by embed_documents().
        """
        return {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "batch_size": self.batch_size,
        }

    def list_available_models(self) -> List[str]:
        """List available OpenAI embedding models.

        Returns:
            List[str]: Available embedding model names.
        """
        return [
            "text-embedding-ada-002",   # Most common
            "text-embedding-3-small",   # Newer, more efficient
            "text-embedding-3-large",   # Highest quality
        ]

    def estimate_cost(self, num_documents: int, avg_tokens_per_doc: int = 100) -> float:
        """Estimate embedding cost.

        Args:
            num_documents (int): Number of documents to embed.
            avg_tokens_per_doc (int): Assumed average tokens per document.
                Defaults to 100 (the previously hard-coded estimate).

        Returns:
            float: Estimated cost in USD.
        """
        # Pricing as of 2024 (subject to change); values are USD per token.
        pricing = {
            "text-embedding-ada-002": 0.0001 / 1000,   # $0.0001 per 1000 tokens
            "text-embedding-3-small": 0.00006 / 1000,
            "text-embedding-3-large": 0.00013 / 1000,
        }
        total_tokens = num_documents * avg_tokens_per_doc
        # Unknown models fall back to ada-002 pricing rather than raising.
        return total_tokens * pricing.get(self.model, pricing["text-embedding-ada-002"])