# Spaces:
# Sleeping
# Sleeping
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from threading import Lock
from typing import AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse

import pg8000
from pg8000 import Connection
from pg8000.exceptions import DatabaseError as Pg8000DatabaseError
from pydantic import PostgresDsn
from pydantic_settings import BaseSettings, SettingsConfigDict
# Configure plain-text root logging once at import time (timestamp, logger
# name, level, message) and create this module's logger.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class DatabaseSettings(BaseSettings):
    """Database configuration loaded by pydantic-settings.

    Values are read from the environment and, when present, from a ``.env``
    file in the working directory.

    Attributes:
        db_url: PostgreSQL DSN. Required; there is no default.
        pool_size: Number of connections each ``Database`` pool opens (default 5).
    """

    # pydantic v2 configuration style; the inner ``class Config`` form used
    # previously is deprecated in pydantic v2 / pydantic-settings.
    model_config = SettingsConfigDict(env_file=".env")

    db_url: PostgresDsn
    pool_size: int = 5
# Custom database errors
class DatabaseError(Exception):
    """Common base class for the database exceptions defined in this module."""
class ConnectionError(DatabaseError):
    """Raised when a database connection cannot be created or used.

    Note: inside this module this name shadows the built-in ``ConnectionError``.
    """
class PoolExhaustedError(DatabaseError):
    """Raised when a connection is requested but the pool has none idle."""
class QueryExecutionError(DatabaseError):
    """Raised when running a SQL statement against the database fails."""
class HealthCheckError(DatabaseError):
    """Raised when the database health-check query fails or looks wrong."""
class Database:
    """A fixed-size pool of pg8000 connections with async query helpers.

    ``connect()`` opens ``pool_size`` connections up front. ``get_connection()``
    checks one out for the duration of an ``async with`` block and always
    returns it to the pool afterwards. The pool never grows and never waits:
    when no idle connection is left, ``get_connection()`` raises
    ``PoolExhaustedError``.

    NOTE(review): pg8000 is a synchronous driver, so each query below blocks the
    running event loop until it completes. If that matters, offload the driver
    calls (e.g. ``loop.run_in_executor``).
    """

    # Driver exceptions translated into this module's errors. pg8000 reports
    # client-side failures (e.g. a connection that cannot be established) as
    # InterfaceError, which is not a subclass of its DatabaseError.
    _DRIVER_ERRORS = (Pg8000DatabaseError, pg8000.exceptions.InterfaceError)
    # Errors tolerated during best-effort cleanup (close / rollback).
    _CLEANUP_ERRORS = _DRIVER_ERRORS + (OSError,)

    def __init__(self, db_url: PostgresDsn, pool_size: int):
        """Store settings only; no connection is opened until ``connect()``.

        Args:
            db_url: Database URL. It is converted with ``str()`` and parsed with
                ``urlparse``; user, password, host, port (default 5432) and
                database name are taken from it.
            pool_size: Number of connections ``connect()`` opens.
        """
        self.db_url = db_url
        self.pool_size = pool_size
        # Idle connections available for checkout.
        self.pool: List[Connection] = []
        # Guards every mutation of self.pool; never held across an await.
        self.lock = Lock()

    def _close_quietly(self, conn: Connection) -> None:
        """Close *conn*, logging instead of raising if closing fails."""
        try:
            conn.close()
        except self._CLEANUP_ERRORS as e:
            logger.warning(f"Error while closing database connection: {e}")

    def _rollback_quietly(self, conn: Connection) -> None:
        """Roll back *conn*'s current transaction, logging instead of raising."""
        try:
            conn.rollback()
        except self._CLEANUP_ERRORS as e:
            logger.warning(f"Rollback failed while releasing connection: {e}")

    async def connect(self) -> None:
        """Open ``pool_size`` connections and add them to the pool.

        Raises:
            ConnectionError: If any connection cannot be opened. Connections
                already opened by this call are closed first, so a failed
                ``connect()`` adds nothing to the pool.
        """
        parsed = urlparse(str(self.db_url))
        opened: List[Connection] = []
        try:
            for _ in range(self.pool_size):
                opened.append(
                    pg8000.connect(
                        user=parsed.username,
                        password=parsed.password,
                        host=parsed.hostname,
                        port=parsed.port or 5432,
                        database=parsed.path.lstrip("/"),
                    )
                )
        except self._DRIVER_ERRORS as e:
            # Release the connections that did succeed before the failure.
            for conn in opened:
                self._close_quietly(conn)
            logger.error(f"Failed to create database connection pool: {e}")
            raise ConnectionError("Failed to create database connection pool.") from e
        with self.lock:
            self.pool.extend(opened)
        logger.info(
            f"Database connection pool created with {self.pool_size} connections."
        )

    async def disconnect(self) -> None:
        """Close every idle connection and empty the pool.

        A failure to close one connection is logged and does not stop the
        remaining connections from being closed.

        NOTE(review): connections that are checked out at this moment are not
        closed here; they are appended back to the (now empty) pool on release.
        """
        with self.lock:
            idle = list(self.pool)
            self.pool.clear()
        for conn in idle:
            self._close_quietly(conn)
        logger.info("Database connection pool closed.")

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator[Connection, None]:
        """Check out a pooled connection for the duration of an ``async with``.

        Usage::

            async with db.get_connection() as conn:
                ...

        The connection is always returned to the pool on exit. If the body
        raises, the connection's transaction is rolled back first so that an
        open (possibly aborted) transaction is not handed to the next user.

        Raises:
            PoolExhaustedError: If no idle connection is available.
            ConnectionError: If a pg8000 error escapes the ``async with`` body.
        """
        with self.lock:
            if not self.pool:
                logger.error("Connection pool is exhausted.")
                raise PoolExhaustedError("No available connections in the pool.")
            conn = self.pool.pop()
        try:
            yield conn
        except self._DRIVER_ERRORS as e:
            self._rollback_quietly(conn)
            logger.error(f"Connection error: {e}")
            raise ConnectionError("Failed to use database connection.") from e
        except Exception:
            self._rollback_quietly(conn)
            raise
        finally:
            with self.lock:
                self.pool.append(conn)

    async def fetch(self, query: str, *args) -> Dict[str, List]:
        """
        Execute a SELECT query and return the results as a dictionary of lists.

        Args:
            query (str): The SQL query to execute.
            *args: Query parameters, passed to ``cursor.execute`` as one tuple.

        Returns:
            Dict[str, List]: Column name -> list of that column's values, in row
            order. A statement that reports no result columns yields ``{}``.

        Raises:
            QueryExecutionError: If the query execution fails.
            PoolExhaustedError: If no pooled connection is available.

        NOTE(review): result columns sharing a name (e.g. two ``id`` columns
        from a join) are merged into one list; alias them in the SQL.
        """
        async with self.get_connection() as conn:
            # Translate driver errors inside the `async with` so callers get
            # QueryExecutionError rather than get_connection()'s ConnectionError.
            try:
                cursor = conn.cursor()
                cursor.execute(query, args)
                rows = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description or ()]
            except self._DRIVER_ERRORS as e:
                logger.error(f"Query execution failed: {e}")
                raise QueryExecutionError(f"Failed to execute query: {query}") from e

        # Pivot the row sequences into one list per column.
        data_dict: Dict[str, List] = {column: [] for column in columns}
        for row in rows:
            for column, value in zip(columns, row):
                data_dict[column].append(value)
        return data_dict

    async def execute(self, query: str, *args) -> None:
        """
        Execute an INSERT, UPDATE, or DELETE query and commit it.

        Args:
            query (str): The SQL query to execute.
            *args: Query parameters, passed to ``cursor.execute`` as one tuple.

        Raises:
            QueryExecutionError: If the query execution or commit fails; the
                transaction is rolled back before the connection is reused.
            PoolExhaustedError: If no pooled connection is available.
        """
        async with self.get_connection() as conn:
            try:
                cursor = conn.cursor()
                cursor.execute(query, args)
                conn.commit()
            except self._DRIVER_ERRORS as e:
                logger.error(f"Query execution failed: {e}")
                raise QueryExecutionError(f"Failed to execute query: {query}") from e

    async def health_check(self) -> bool:
        """
        Perform a health check by executing ``SELECT 1``.

        Returns:
            bool: Always True; every failure is reported by raising instead.

        Raises:
            HealthCheckError: If the query fails or returns an unexpected result.
            PoolExhaustedError: If no pooled connection is available.
        """
        async with self.get_connection() as conn:
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT 1")
                result = cursor.fetchone()
                cursor.close()
            except self._DRIVER_ERRORS as e:
                logger.error(f"Health check failed: {e}")
                raise HealthCheckError("Failed to perform health check.") from e

        if not result or result[0] != 1:
            logger.error("Database health check failed: Unexpected result.")
            raise HealthCheckError("Unexpected result from health check query.")
        logger.info("Database health check succeeded.")
        return True
async def get_db_from_url() -> AsyncGenerator[Database, None]:
    """Yield a connected Database built from the ``DB_URL`` environment variable.

    Used as an async-generator dependency: the pool is opened before the
    ``yield`` and always closed afterwards.

    NOTE(review): each call opens (and later closes) a brand-new pool of five
    connections, so using this per request gives no connection reuse.

    Raises:
        ConnectionError: If ``DB_URL`` is unset or empty, or the pool cannot be
            opened.
    """
    db_url = os.getenv("DB_URL")
    if not db_url:
        # Without this check str(None) == "None" would be parsed as a URL and
        # the failure would surface later as an opaque driver error.
        raise ConnectionError("DB_URL environment variable is not set.")
    db = Database(db_url=db_url, pool_size=5)
    try:
        # connect() sits inside the try so that, if it fails part-way, any
        # connections already added to the pool are still closed.
        await db.connect()
        yield db
    finally:
        await db.disconnect()
# Dependency to get the database instance
async def get_db() -> AsyncGenerator[Database, None]:
    """Yield a connected Database configured from ``DatabaseSettings``.

    Used as an async-generator dependency: the pool is opened before the
    ``yield`` and always closed afterwards.

    NOTE(review): every call builds and tears down a fresh pool, so using this
    per request provides no connection reuse across requests.
    """
    settings = DatabaseSettings()
    db = Database(db_url=settings.db_url, pool_size=settings.pool_size)
    try:
        # connect() sits inside the try so that, if it fails part-way, any
        # connections already added to the pool are still closed.
        await db.connect()
        yield db
    finally:
        await db.disconnect()
# Example usage: run a single health check against the configured database.
if __name__ == "__main__":

    async def main():
        """Connect, run one health check, print the outcome, then disconnect."""
        config = DatabaseSettings()
        database = Database(db_url=config.db_url, pool_size=config.pool_size)
        await database.connect()
        try:
            healthy = await database.health_check()
            outcome = "Success" if healthy else "Failure"
            print(f"Database health check: {outcome}")
        except HealthCheckError as exc:
            print(f"Health check failed: {exc}")
        finally:
            await database.disconnect()

    asyncio.run(main())