|
|
""" |
|
|
WebSocket Support Module |
|
|
Provides real-time updates via WebSocket connections with connection management |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import json |
|
|
from datetime import datetime |
|
|
from typing import Set, Dict, Any, Optional, List |
|
|
from fastapi import WebSocket, WebSocketDisconnect, APIRouter |
|
|
from starlette.websockets import WebSocketState |
|
|
from utils.logger import setup_logger |
|
|
from database.db_manager import db_manager |
|
|
from monitoring.rate_limiter import rate_limiter |
|
|
from config import config |
|
|
|
|
|
|
|
|
logger = setup_logger("websocket", level="INFO") |
|
|
|
|
|
|
|
|
router = APIRouter() |
|
|
|
|
|
|
|
|
class ConnectionManager: |
|
|
""" |
|
|
Manages WebSocket connections and broadcasts messages to all connected clients |
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
"""Initialize connection manager""" |
|
|
self.active_connections: Set[WebSocket] = set() |
|
|
self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {} |
|
|
self._broadcast_task: Optional[asyncio.Task] = None |
|
|
self._heartbeat_task: Optional[asyncio.Task] = None |
|
|
self._is_running = False |
|
|
|
|
|
async def connect(self, websocket: WebSocket, client_id: str = None): |
|
|
""" |
|
|
Accept and register a new WebSocket connection |
|
|
|
|
|
Args: |
|
|
websocket: WebSocket connection |
|
|
client_id: Optional client identifier |
|
|
""" |
|
|
await websocket.accept() |
|
|
self.active_connections.add(websocket) |
|
|
|
|
|
|
|
|
self.connection_metadata[websocket] = { |
|
|
'client_id': client_id or f"client_{id(websocket)}", |
|
|
'connected_at': datetime.utcnow().isoformat(), |
|
|
'last_ping': datetime.utcnow().isoformat() |
|
|
} |
|
|
|
|
|
logger.info( |
|
|
f"WebSocket connected: {self.connection_metadata[websocket]['client_id']} " |
|
|
f"(Total connections: {len(self.active_connections)})" |
|
|
) |
|
|
|
|
|
|
|
|
await self.send_personal_message( |
|
|
{ |
|
|
'type': 'connection_established', |
|
|
'client_id': self.connection_metadata[websocket]['client_id'], |
|
|
'timestamp': datetime.utcnow().isoformat(), |
|
|
'message': 'Connected to Crypto API Monitor WebSocket' |
|
|
}, |
|
|
websocket |
|
|
) |
|
|
|
|
|
def disconnect(self, websocket: WebSocket): |
|
|
""" |
|
|
Unregister and close a WebSocket connection |
|
|
|
|
|
Args: |
|
|
websocket: WebSocket connection to disconnect |
|
|
""" |
|
|
if websocket in self.active_connections: |
|
|
client_id = self.connection_metadata.get(websocket, {}).get('client_id', 'unknown') |
|
|
self.active_connections.remove(websocket) |
|
|
|
|
|
if websocket in self.connection_metadata: |
|
|
del self.connection_metadata[websocket] |
|
|
|
|
|
logger.info( |
|
|
f"WebSocket disconnected: {client_id} " |
|
|
f"(Remaining connections: {len(self.active_connections)})" |
|
|
) |
|
|
|
|
|
async def send_personal_message(self, message: Dict[str, Any], websocket: WebSocket): |
|
|
""" |
|
|
Send a message to a specific WebSocket connection |
|
|
|
|
|
Args: |
|
|
message: Message dictionary to send |
|
|
websocket: Target WebSocket connection |
|
|
""" |
|
|
try: |
|
|
if websocket.client_state == WebSocketState.CONNECTED: |
|
|
await websocket.send_json(message) |
|
|
except Exception as e: |
|
|
logger.error(f"Error sending personal message: {e}") |
|
|
self.disconnect(websocket) |
|
|
|
|
|
async def broadcast(self, message: Dict[str, Any]): |
|
|
""" |
|
|
Broadcast a message to all connected clients |
|
|
|
|
|
Args: |
|
|
message: Message dictionary to broadcast |
|
|
""" |
|
|
disconnected = [] |
|
|
|
|
|
for connection in self.active_connections.copy(): |
|
|
try: |
|
|
if connection.client_state == WebSocketState.CONNECTED: |
|
|
await connection.send_json(message) |
|
|
else: |
|
|
disconnected.append(connection) |
|
|
except Exception as e: |
|
|
logger.error(f"Error broadcasting to client: {e}") |
|
|
disconnected.append(connection) |
|
|
|
|
|
|
|
|
for connection in disconnected: |
|
|
self.disconnect(connection) |
|
|
|
|
|
async def broadcast_status_update(self): |
|
|
""" |
|
|
Broadcast system status update to all connected clients |
|
|
""" |
|
|
try: |
|
|
|
|
|
latest_metrics = db_manager.get_latest_system_metrics() |
|
|
|
|
|
|
|
|
providers = config.get_all_providers() |
|
|
|
|
|
|
|
|
rate_limit_statuses = rate_limiter.get_all_statuses() |
|
|
|
|
|
|
|
|
alerts = db_manager.get_alerts(acknowledged=False, hours=1) |
|
|
|
|
|
|
|
|
message = { |
|
|
'type': 'status_update', |
|
|
'timestamp': datetime.utcnow().isoformat(), |
|
|
'system_metrics': { |
|
|
'total_providers': latest_metrics.total_providers if latest_metrics else len(providers), |
|
|
'online_count': latest_metrics.online_count if latest_metrics else 0, |
|
|
'degraded_count': latest_metrics.degraded_count if latest_metrics else 0, |
|
|
'offline_count': latest_metrics.offline_count if latest_metrics else 0, |
|
|
'avg_response_time_ms': latest_metrics.avg_response_time_ms if latest_metrics else 0, |
|
|
'total_requests_hour': latest_metrics.total_requests_hour if latest_metrics else 0, |
|
|
'total_failures_hour': latest_metrics.total_failures_hour if latest_metrics else 0, |
|
|
'system_health': latest_metrics.system_health if latest_metrics else 'unknown' |
|
|
}, |
|
|
'alert_count': len(alerts), |
|
|
'active_websocket_clients': len(self.active_connections) |
|
|
} |
|
|
|
|
|
await self.broadcast(message) |
|
|
logger.debug(f"Broadcasted status update to {len(self.active_connections)} clients") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error broadcasting status update: {e}", exc_info=True) |
|
|
|
|
|
async def broadcast_new_log_entry(self, log_type: str, log_data: Dict[str, Any]): |
|
|
""" |
|
|
Broadcast a new log entry |
|
|
|
|
|
Args: |
|
|
log_type: Type of log (connection, failure, collection, rate_limit) |
|
|
log_data: Log data dictionary |
|
|
""" |
|
|
try: |
|
|
message = { |
|
|
'type': 'new_log_entry', |
|
|
'timestamp': datetime.utcnow().isoformat(), |
|
|
'log_type': log_type, |
|
|
'data': log_data |
|
|
} |
|
|
|
|
|
await self.broadcast(message) |
|
|
logger.debug(f"Broadcasted new {log_type} log entry") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error broadcasting log entry: {e}", exc_info=True) |
|
|
|
|
|
async def broadcast_rate_limit_alert(self, provider_name: str, percentage: float): |
|
|
""" |
|
|
Broadcast rate limit alert |
|
|
|
|
|
Args: |
|
|
provider_name: Provider name |
|
|
percentage: Current usage percentage |
|
|
""" |
|
|
try: |
|
|
message = { |
|
|
'type': 'rate_limit_alert', |
|
|
'timestamp': datetime.utcnow().isoformat(), |
|
|
'provider': provider_name, |
|
|
'percentage': percentage, |
|
|
'severity': 'critical' if percentage >= 95 else 'warning' |
|
|
} |
|
|
|
|
|
await self.broadcast(message) |
|
|
logger.info(f"Broadcasted rate limit alert for {provider_name} ({percentage}%)") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error broadcasting rate limit alert: {e}", exc_info=True) |
|
|
|
|
|
async def broadcast_provider_status_change( |
|
|
self, |
|
|
provider_name: str, |
|
|
old_status: str, |
|
|
new_status: str, |
|
|
details: Optional[Dict] = None |
|
|
): |
|
|
""" |
|
|
Broadcast provider status change |
|
|
|
|
|
Args: |
|
|
provider_name: Provider name |
|
|
old_status: Previous status |
|
|
new_status: New status |
|
|
details: Optional details about the change |
|
|
""" |
|
|
try: |
|
|
message = { |
|
|
'type': 'provider_status_change', |
|
|
'timestamp': datetime.utcnow().isoformat(), |
|
|
'provider': provider_name, |
|
|
'old_status': old_status, |
|
|
'new_status': new_status, |
|
|
'details': details or {} |
|
|
} |
|
|
|
|
|
await self.broadcast(message) |
|
|
logger.info( |
|
|
f"Broadcasted provider status change: {provider_name} " |
|
|
f"{old_status} -> {new_status}" |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error broadcasting provider status change: {e}", exc_info=True) |
|
|
|
|
|
async def _periodic_broadcast_loop(self): |
|
|
""" |
|
|
Background task that broadcasts updates every 10 seconds |
|
|
""" |
|
|
logger.info("Starting periodic broadcast loop") |
|
|
|
|
|
while self._is_running: |
|
|
try: |
|
|
|
|
|
await self.broadcast_status_update() |
|
|
|
|
|
|
|
|
rate_limit_statuses = rate_limiter.get_all_statuses() |
|
|
for provider, status_data in rate_limit_statuses.items(): |
|
|
if status_data and status_data.get('percentage', 0) >= 80: |
|
|
await self.broadcast_rate_limit_alert( |
|
|
provider, |
|
|
status_data['percentage'] |
|
|
) |
|
|
|
|
|
|
|
|
await asyncio.sleep(10) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in periodic broadcast loop: {e}", exc_info=True) |
|
|
await asyncio.sleep(10) |
|
|
|
|
|
logger.info("Periodic broadcast loop stopped") |
|
|
|
|
|
async def _heartbeat_loop(self): |
|
|
""" |
|
|
Background task that sends heartbeat pings to all clients |
|
|
""" |
|
|
logger.info("Starting heartbeat loop") |
|
|
|
|
|
while self._is_running: |
|
|
try: |
|
|
|
|
|
ping_message = { |
|
|
'type': 'ping', |
|
|
'timestamp': datetime.utcnow().isoformat() |
|
|
} |
|
|
|
|
|
await self.broadcast(ping_message) |
|
|
|
|
|
|
|
|
await asyncio.sleep(30) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in heartbeat loop: {e}", exc_info=True) |
|
|
await asyncio.sleep(30) |
|
|
|
|
|
logger.info("Heartbeat loop stopped") |
|
|
|
|
|
async def start_background_tasks(self): |
|
|
""" |
|
|
Start background broadcast and heartbeat tasks |
|
|
""" |
|
|
if self._is_running: |
|
|
logger.warning("Background tasks already running") |
|
|
return |
|
|
|
|
|
self._is_running = True |
|
|
|
|
|
|
|
|
self._broadcast_task = asyncio.create_task(self._periodic_broadcast_loop()) |
|
|
logger.info("Started periodic broadcast task") |
|
|
|
|
|
|
|
|
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) |
|
|
logger.info("Started heartbeat task") |
|
|
|
|
|
async def stop_background_tasks(self): |
|
|
""" |
|
|
Stop background broadcast and heartbeat tasks |
|
|
""" |
|
|
if not self._is_running: |
|
|
logger.warning("Background tasks not running") |
|
|
return |
|
|
|
|
|
self._is_running = False |
|
|
|
|
|
|
|
|
if self._broadcast_task: |
|
|
self._broadcast_task.cancel() |
|
|
try: |
|
|
await self._broadcast_task |
|
|
except asyncio.CancelledError: |
|
|
pass |
|
|
logger.info("Stopped periodic broadcast task") |
|
|
|
|
|
|
|
|
if self._heartbeat_task: |
|
|
self._heartbeat_task.cancel() |
|
|
try: |
|
|
await self._heartbeat_task |
|
|
except asyncio.CancelledError: |
|
|
pass |
|
|
logger.info("Stopped heartbeat task") |
|
|
|
|
|
async def close_all_connections(self): |
|
|
""" |
|
|
Close all active WebSocket connections |
|
|
""" |
|
|
logger.info(f"Closing {len(self.active_connections)} active connections") |
|
|
|
|
|
for connection in self.active_connections.copy(): |
|
|
try: |
|
|
if connection.client_state == WebSocketState.CONNECTED: |
|
|
await connection.close(code=1000, reason="Server shutdown") |
|
|
except Exception as e: |
|
|
logger.error(f"Error closing connection: {e}") |
|
|
|
|
|
self.active_connections.clear() |
|
|
self.connection_metadata.clear() |
|
|
logger.info("All WebSocket connections closed") |
|
|
|
|
|
def get_connection_count(self) -> int: |
|
|
""" |
|
|
Get the number of active connections |
|
|
|
|
|
Returns: |
|
|
Number of active connections |
|
|
""" |
|
|
return len(self.active_connections) |
|
|
|
|
|
def get_connection_info(self) -> List[Dict[str, Any]]: |
|
|
""" |
|
|
Get information about all active connections |
|
|
|
|
|
Returns: |
|
|
List of connection metadata dictionaries |
|
|
""" |
|
|
return [ |
|
|
{ |
|
|
'client_id': metadata['client_id'], |
|
|
'connected_at': metadata['connected_at'], |
|
|
'last_ping': metadata['last_ping'] |
|
|
} |
|
|
for metadata in self.connection_metadata.values() |
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
manager = ConnectionManager() |
|
|
|
|
|
|
|
|
@router.websocket("/ws/live") |
|
|
async def websocket_live_endpoint(websocket: WebSocket): |
|
|
""" |
|
|
WebSocket endpoint for real-time updates |
|
|
|
|
|
Provides: |
|
|
- System status updates every 10 seconds |
|
|
- Real-time log entries |
|
|
- Rate limit alerts |
|
|
- Provider status changes |
|
|
- Heartbeat pings every 30 seconds |
|
|
|
|
|
Message Types: |
|
|
- connection_established: Sent when client connects |
|
|
- status_update: Periodic system status (every 10s) |
|
|
- new_log_entry: New log entry notification |
|
|
- rate_limit_alert: Rate limit warning |
|
|
- provider_status_change: Provider status change |
|
|
- ping: Heartbeat ping (every 30s) |
|
|
""" |
|
|
client_id = None |
|
|
|
|
|
try: |
|
|
|
|
|
await manager.connect(websocket) |
|
|
client_id = manager.connection_metadata.get(websocket, {}).get('client_id', 'unknown') |
|
|
|
|
|
|
|
|
if not manager._is_running: |
|
|
await manager.start_background_tasks() |
|
|
|
|
|
|
|
|
while True: |
|
|
try: |
|
|
|
|
|
data = await websocket.receive_text() |
|
|
|
|
|
|
|
|
try: |
|
|
message = json.loads(data) |
|
|
|
|
|
|
|
|
if message.get('type') == 'pong': |
|
|
if websocket in manager.connection_metadata: |
|
|
manager.connection_metadata[websocket]['last_ping'] = datetime.utcnow().isoformat() |
|
|
logger.debug(f"Received pong from {client_id}") |
|
|
|
|
|
|
|
|
elif message.get('type') == 'subscribe': |
|
|
|
|
|
logger.debug(f"Client {client_id} subscription request: {message}") |
|
|
|
|
|
|
|
|
elif message.get('type') == 'unsubscribe': |
|
|
logger.debug(f"Client {client_id} unsubscribe request: {message}") |
|
|
|
|
|
except json.JSONDecodeError: |
|
|
logger.warning(f"Received invalid JSON from {client_id}: {data}") |
|
|
|
|
|
except WebSocketDisconnect: |
|
|
logger.info(f"Client {client_id} disconnected") |
|
|
break |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error handling message from {client_id}: {e}", exc_info=True) |
|
|
break |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"WebSocket error for {client_id}: {e}", exc_info=True) |
|
|
|
|
|
finally: |
|
|
|
|
|
manager.disconnect(websocket) |
|
|
|
|
|
|
|
|
@router.get("/ws/stats") |
|
|
async def websocket_stats(): |
|
|
""" |
|
|
Get WebSocket connection statistics |
|
|
|
|
|
Returns: |
|
|
Dictionary with connection stats |
|
|
""" |
|
|
return { |
|
|
'active_connections': manager.get_connection_count(), |
|
|
'connections': manager.get_connection_info(), |
|
|
'background_tasks_running': manager._is_running, |
|
|
'timestamp': datetime.utcnow().isoformat() |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['router', 'manager', 'ConnectionManager'] |
|
|
|