|
|
""" |
|
|
Enhanced Rate Limiting System |
|
|
Implements token bucket and sliding window algorithms for API rate limiting |
|
|
""" |
|
|
|
|
|
import time |
|
|
import threading |
|
|
from typing import Dict, Optional, Tuple |
|
|
from collections import deque |
|
|
from dataclasses import dataclass |
|
|
import logging |
|
|
from functools import wraps |
|
|
|
|
|
# Module-level logger named after this module; used by RateLimiter to report
# initialization and client resets.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class RateLimitConfig:
    """Per-client limits consumed by ``RateLimiter``.

    All three values are plain integers with defaults, so the class can be
    instantiated with no arguments.
    """

    # Maximum requests accepted inside a 60-second sliding window.
    requests_per_minute: int = 30
    # Maximum requests accepted inside a 3600-second sliding window.
    requests_per_hour: int = 1000
    # Capacity of the token bucket, i.e. how many requests may arrive back to back.
    burst_size: int = 10
|
|
|
|
|
|
|
|
class TokenBucket:
    """
    Token bucket algorithm for rate limiting.

    Tokens accrue continuously at ``rate`` per second up to ``capacity``.
    Each request spends tokens, so short bursts are allowed while the
    long-run average stays bounded by ``rate``.
    """

    def __init__(self, rate: float, capacity: int):
        """
        Initialize a full token bucket.

        Args:
            rate: Tokens added per second.
            capacity: Maximum number of tokens the bucket can hold (burst size).
        """
        self.rate = rate
        self.capacity = capacity
        self.tokens = capacity
        # Monotonic clock: system clock adjustments (NTP, manual changes) must
        # not be able to drain the bucket or hand out free tokens.
        self.last_update = time.monotonic()
        self.lock = threading.Lock()

    def _refill(self) -> None:
        """Credit tokens accrued since the last update.

        The caller must already hold ``self.lock``.
        """
        now = time.monotonic()
        # Clamp so a tampered ``last_update`` can never subtract tokens.
        elapsed = max(0.0, now - self.last_update)
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last_update = now

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume tokens from the bucket.

        Args:
            tokens: Number of tokens to consume.

        Returns:
            True if the tokens were available and have been deducted,
            False if the bucket did not hold enough tokens (nothing is deducted).
        """
        with self.lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    def get_wait_time(self, tokens: int = 1) -> float:
        """
        Get the time to wait before ``tokens`` tokens are available.

        The bucket is refilled first, so the answer reflects the current
        moment rather than the moment of the last ``consume`` call.

        Args:
            tokens: Number of tokens needed.

        Returns:
            Wait time in seconds; ``0.0`` if the tokens are available now;
            ``float("inf")`` if they can never become available (more tokens
            requested than ``capacity``, or a non-positive ``rate``).
        """
        with self.lock:
            self._refill()
            if self.tokens >= tokens:
                return 0.0
            # The bucket never holds more than ``capacity`` and never refills
            # at a non-positive rate, so these requests cannot be satisfied.
            if tokens > self.capacity or self.rate <= 0:
                return float("inf")
            return (tokens - self.tokens) / self.rate
|
|
|
|
|
|
|
|
class SlidingWindowCounter:
    """
    Sliding window algorithm for rate limiting.

    Keeps the timestamp of every accepted request and admits a new one only
    while fewer than ``max_requests`` timestamps fall inside the trailing
    ``window_seconds``.
    """

    def __init__(self, window_seconds: int, max_requests: int):
        """
        Initialize sliding window counter.

        Args:
            window_seconds: Window size in seconds.
            max_requests: Maximum requests in window.
        """
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        # Timestamps of accepted requests, oldest first.
        self.requests: deque = deque()
        self.lock = threading.Lock()

    def _evict_expired(self, now: float) -> None:
        """Drop timestamps older than the window.

        The caller must already hold ``self.lock``.
        """
        cutoff = now - self.window_seconds
        while self.requests and self.requests[0] < cutoff:
            self.requests.popleft()

    def allow_request(self) -> bool:
        """
        Check if a request is allowed and, if so, record it.

        Returns:
            True if allowed (the request is counted against the window),
            False if the rate limit is exceeded (nothing is recorded).
        """
        with self.lock:
            # Monotonic clock: system clock adjustments must not expire
            # entries early or keep them alive past the window.
            now = time.monotonic()
            self._evict_expired(now)
            if len(self.requests) < self.max_requests:
                self.requests.append(now)
                return True
            return False

    def get_remaining(self) -> int:
        """Get remaining requests in the current window (never negative)."""
        with self.lock:
            self._evict_expired(time.monotonic())
            return max(0, self.max_requests - len(self.requests))
|
|
|
|
|
|
|
|
class RateLimiter:
    """
    Comprehensive rate limiter combining multiple algorithms.

    Every client gets a token bucket (burst control) plus a 60-second and a
    3600-second sliding window. ``client_id`` is an opaque string key, so the
    same limiter can be keyed by IP, user, or API key.

    Limiter state is kept per client until ``reset_client`` is called.
    """

    def __init__(self, config: Optional[RateLimitConfig] = None):
        """
        Initialize rate limiter.

        Args:
            config: Rate limit configuration; a default ``RateLimitConfig``
                is used when omitted.
        """
        self.config = config or RateLimitConfig()

        # Per-client limiter state, keyed by client_id.
        self.minute_limiters: Dict[str, SlidingWindowCounter] = {}
        self.hour_limiters: Dict[str, SlidingWindowCounter] = {}
        self.burst_limiters: Dict[str, TokenBucket] = {}

        # Guards the three dictionaries and serializes limit checks.
        self.lock = threading.Lock()

        logger.info(
            "Rate limiter initialized: %s/min, %s/hour, burst=%s",
            self.config.requests_per_minute,
            self.config.requests_per_hour,
            self.config.burst_size,
        )

    @staticmethod
    def _burst_tokens(bucket: "TokenBucket") -> float:
        """Return the bucket's current token count without spending any."""
        # consume(0) deducts nothing but makes the bucket credit the tokens
        # accrued since its last update, so ``tokens`` is current afterwards.
        bucket.consume(0)
        return bucket.tokens

    def check_rate_limit(self, client_id: str) -> Tuple[bool, Optional[str]]:
        """
        Check if a request is within rate limits and, if so, count it.

        All three limiters are inspected before any of them is charged, so a
        rejected request does not use up burst tokens or window slots.

        Args:
            client_id: Client identifier (IP, user, or API key).

        Returns:
            Tuple of (allowed: bool, error_message: Optional[str]);
            ``error_message`` is None when the request is allowed.
        """
        with self.lock:
            if client_id not in self.minute_limiters:
                self._create_limiters(client_id)

            burst = self.burst_limiters[client_id]
            minute = self.minute_limiters[client_id]
            hour = self.hour_limiters[client_id]

            # Peek phase: burst first, then minute, then hour, which fixes
            # which message is reported when several limits are exhausted.
            if self._burst_tokens(burst) < 1:
                wait_time = burst.get_wait_time()
                return False, f"Rate limit exceeded. Retry after {wait_time:.1f}s"

            if minute.get_remaining() <= 0:
                return False, f"Rate limit: {self.config.requests_per_minute} requests/minute exceeded"

            if hour.get_remaining() <= 0:
                return False, f"Rate limit: {self.config.requests_per_hour} requests/hour exceeded"

            # Charge phase: self.lock serializes every check made through this
            # RateLimiter, so the capacity observed above is still available.
            burst.consume()
            minute.allow_request()
            hour.allow_request()
            return True, None

    def _create_limiters(self, client_id: str):
        """Create limiters for a new client. The caller must hold ``self.lock``."""
        self.minute_limiters[client_id] = SlidingWindowCounter(
            window_seconds=60,
            max_requests=self.config.requests_per_minute,
        )
        self.hour_limiters[client_id] = SlidingWindowCounter(
            window_seconds=3600,
            max_requests=self.config.requests_per_hour,
        )
        # Refill rate is the per-minute allowance expressed in tokens/second.
        self.burst_limiters[client_id] = TokenBucket(
            rate=self.config.requests_per_minute / 60.0,
            capacity=self.config.burst_size,
        )

    def get_limits_info(self, client_id: str) -> Dict[str, int]:
        """
        Get current limits info for a client.

        Args:
            client_id: Client identifier.

        Returns:
            Dictionary with the keys ``minute_remaining``, ``hour_remaining``,
            ``burst_available``, ``minute_limit`` and ``hour_limit``. The same
            keys are returned whether or not the client has been seen; an
            unseen client reports the configured maximums.
        """
        with self.lock:
            info = {
                'minute_remaining': self.config.requests_per_minute,
                'hour_remaining': self.config.requests_per_hour,
                'burst_available': self.config.burst_size,
                'minute_limit': self.config.requests_per_minute,
                'hour_limit': self.config.requests_per_hour,
            }

            if client_id in self.minute_limiters:
                info['minute_remaining'] = self.minute_limiters[client_id].get_remaining()
                info['hour_remaining'] = self.hour_limiters[client_id].get_remaining()
                # Whole requests that could be served from the bucket right now.
                info['burst_available'] = max(
                    0, int(self._burst_tokens(self.burst_limiters[client_id]))
                )

            return info

    def reset_client(self, client_id: str):
        """Reset rate limits for a client by discarding all of its limiter state."""
        with self.lock:
            self.minute_limiters.pop(client_id, None)
            self.hour_limiters.pop(client_id, None)
            self.burst_limiters.pop(client_id, None)
        logger.info("Reset rate limits for client: %s", client_id)
|
|
|
|
|
|
|
|
|
|
|
# Shared module-level limiter built with the default RateLimitConfig; the
# module-level check_rate_limit() and get_rate_limit_info() helpers delegate to it.
global_rate_limiter = RateLimiter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RateLimitExceeded(Exception):
    """Raised by ``rate_limit``-decorated coroutines when a call is rejected."""


def rate_limit(
    requests_per_minute: int = 30,
    requests_per_hour: int = 1000,
    get_client_id=lambda: "default",
    *,
    burst_size: int = 10,
):
    """
    Decorator for rate limiting async endpoints.

    Each use of the decorator creates its own ``RateLimiter``, so limits are
    tracked separately per decorated function.

    Args:
        requests_per_minute: Max requests per minute.
        requests_per_hour: Max requests per hour.
        get_client_id: Zero-argument callable returning the client ID to
            charge for the current call. The default puts every caller in
            one shared "default" bucket.
        burst_size: Token bucket capacity (keyword-only).

    Raises:
        RateLimitExceeded: From the wrapped coroutine when the call is over
            a limit. It subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.

    Usage:
        @rate_limit(requests_per_minute=60)
        async def my_endpoint():
            ...
    """
    config = RateLimitConfig(
        requests_per_minute=requests_per_minute,
        requests_per_hour=requests_per_hour,
        burst_size=burst_size,
    )
    limiter = RateLimiter(config)

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            client_id = get_client_id()
            allowed, error_msg = limiter.check_rate_limit(client_id)

            if not allowed:
                # A dedicated type lets callers map rejections to e.g. HTTP 429
                # without matching on the message text.
                raise RateLimitExceeded(f"Rate limit exceeded: {error_msg}")

            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_rate_limit(client_id: str) -> Tuple[bool, Optional[str]]:
    """
    Run a rate limit check against the module-wide limiter.

    Args:
        client_id: Client identifier.

    Returns:
        Tuple of (allowed, error_message), exactly as produced by
        ``global_rate_limiter.check_rate_limit``.
    """
    verdict = global_rate_limiter.check_rate_limit(client_id)
    return verdict
|
|
|
|
|
|
|
|
def get_rate_limit_info(client_id: str) -> Dict[str, int]:
    """
    Get rate limit info for a client from the global limiter.

    Args:
        client_id: Client identifier.

    Returns:
        Rate limit information dictionary, as produced by
        ``global_rate_limiter.get_limits_info``; every value is an integer
        count or limit.
    """
    return global_rate_limiter.get_limits_info(client_id)
|
|
|