File size: 2,458 Bytes
48ae4e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
Database Initialization and Session Management
"""

import logging
from contextlib import contextmanager
from typing import Iterator

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session

from config import config
from database.models import Base, Provider, ProviderStatusEnum

logger = logging.getLogger(__name__)

# Create engine.
# sqlite3 connections by default refuse use from a thread other than the one
# that created them; `check_same_thread=False` lifts that for SQLite only.
# A prefix match (covers "sqlite://", "sqlite+pysqlite://", ...) is used rather
# than a substring test so that other backends whose URL merely contains
# "sqlite" (e.g. a database named "sqlite_db") don't receive this
# driver-specific argument, which their DBAPI would reject.
engine = create_engine(
    config.DATABASE_URL,
    connect_args=(
        {"check_same_thread": False}
        if config.DATABASE_URL.startswith("sqlite")
        else {}
    ),
)

# Create session factory: callers commit explicitly, and pending objects are
# not flushed automatically before queries (autoflush=False).
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def init_database():
    """Create all tables and insert any configured providers not yet stored.

    Tables are created with ``Base.metadata.create_all``. Each entry of
    ``config.PROVIDERS`` whose ``name`` is not already present is inserted as
    a ``Provider`` row with status ``ProviderStatusEnum.UNKNOWN`` and a masked
    API key; rows that already exist are left untouched. All inserts are
    committed in a single transaction.

    Raises:
        Exception: any error from table creation or provider population is
            logged with its traceback and re-raised unchanged.
    """
    try:
        # Create all tables
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables created successfully")

        # Populate providers from config
        db = SessionLocal()
        try:
            # Fetch every stored provider name in one query instead of issuing
            # one existence query per configured provider.
            known_names = {name for (name,) in db.query(Provider.name).all()}
            added = 0
            for provider_config in config.PROVIDERS:
                # SessionLocal uses autoflush=False, so a per-item query would
                # not see providers added earlier in this loop; tracking names
                # locally also prevents inserting a name duplicated in config.
                if provider_config.name in known_names:
                    continue
                db.add(
                    Provider(
                        name=provider_config.name,
                        category=provider_config.category,
                        endpoint_url=provider_config.endpoint_url,
                        requires_key=provider_config.requires_key,
                        api_key_masked=(
                            mask_api_key(provider_config.api_key)
                            if provider_config.api_key
                            else None
                        ),
                        rate_limit_type=provider_config.rate_limit_type,
                        rate_limit_value=provider_config.rate_limit_value,
                        timeout_ms=provider_config.timeout_ms,
                        priority_tier=provider_config.priority_tier,
                        status=ProviderStatusEnum.UNKNOWN,
                    )
                )
                known_names.add(provider_config.name)
                added += 1

            db.commit()
            logger.info(
                "Initialized %d providers (%d newly added)",
                len(config.PROVIDERS),
                added,
            )
        finally:
            db.close()

    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Database initialization failed: %s", e)
        raise


@contextmanager
def get_db() -> Iterator[Session]:
    """Provide a database session for use in a ``with`` block.

    Usage::

        with get_db() as db:
            ...

    Yields:
        A new session from ``SessionLocal``. It is always closed when the
        ``with`` block exits, including when the block raises. Nothing is
        committed here; callers must call ``db.commit()`` themselves.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def mask_api_key(key: str, visible: int = 4) -> str:
    """Return a display-safe version of an API key.

    Keeps the first and last ``visible`` characters, joined by ``...``.
    Falsy keys (``""`` or ``None``) and keys too short to hide at least one
    character are replaced entirely by ``"****"``, so the full key is never
    revealed.

    Args:
        key: The API key to mask.
        visible: Number of characters to keep at each end (default 4).

    Returns:
        The masked key, e.g. ``"sk-1...cdef"``, or ``"****"``.

    Raises:
        ValueError: If ``visible`` is negative.
    """
    if visible < 0:
        raise ValueError("visible must be non-negative")
    # If len(key) <= 2 * visible, the kept prefix and suffix would together
    # cover every character (e.g. an 8-char key with visible=4).
    if not key or len(key) <= 2 * visible:
        return "****"
    # Slice the tail by absolute index: key[-0:] would return the whole key.
    return f"{key[:visible]}...{key[len(key) - visible:]}"