File size: 9,585 Bytes
e4e4574 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
"""
Enhanced Rate Limiting System
Implements token bucket and sliding window algorithms for API rate limiting
"""
import inspect
import logging
import threading
import time
from collections import deque
from dataclasses import dataclass
from functools import wraps
from typing import Dict, Optional, Tuple
# Logger named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
@dataclass
class RateLimitConfig:
    """
    Rate limit configuration.

    Attributes:
        requests_per_minute: Maximum requests allowed in a 60-second window.
        requests_per_hour: Maximum requests allowed in a 3600-second window.
        burst_size: Maximum number of back-to-back requests (token-bucket
            capacity).

    Raises:
        ValueError: If any limit is zero or negative.
    """

    requests_per_minute: int = 30
    requests_per_hour: int = 1000
    burst_size: int = 10  # Allow burst requests

    def __post_init__(self) -> None:
        # Fail fast on non-positive limits. A zero requests_per_minute would
        # otherwise produce a zero refill rate and surface much later as a
        # ZeroDivisionError when computing a retry delay.
        for field_name in ("requests_per_minute", "requests_per_hour", "burst_size"):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(
                    f"{field_name} must be a positive integer, got {value!r}"
                )
class TokenBucket:
    """
    Token bucket algorithm for rate limiting.

    The bucket starts full with ``capacity`` tokens and refills continuously
    at ``rate`` tokens per second, never exceeding ``capacity``. Each request
    consumes tokens, so a full bucket admits a burst of ``capacity`` requests
    while the long-run average is bounded by ``rate``.

    ``consume`` and ``get_wait_time`` hold an internal lock while touching
    the bucket state.
    """

    def __init__(self, rate: float, capacity: int):
        """
        Initialize token bucket (initially full).

        Args:
            rate: Tokens added per second
            capacity: Maximum bucket capacity (burst size)
        """
        self.rate = rate
        self.capacity = capacity
        self.tokens = capacity
        # time.monotonic() cannot jump backwards on wall-clock adjustments
        # (NTP, manual changes); with time.time() a backwards step produced a
        # negative elapsed time that silently drained tokens.
        self.last_update = time.monotonic()
        self.lock = threading.Lock()

    def _refill(self) -> None:
        """Add the tokens accrued since the last update. Caller must hold ``self.lock``."""
        now = time.monotonic()
        elapsed = now - self.last_update
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last_update = now

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume tokens from the bucket.

        Args:
            tokens: Number of tokens to consume

        Returns:
            True if the tokens were available (and are now spent),
            False if there were insufficient tokens (nothing is spent).
        """
        with self.lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    def get_wait_time(self, tokens: int = 1) -> float:
        """
        Get the time to wait before ``tokens`` tokens are available.

        Does not consume anything.

        Args:
            tokens: Number of tokens needed

        Returns:
            Wait time in seconds: 0.0 if available now, ``float("inf")`` if
            the request can never be satisfied (rate <= 0, or more tokens
            than the bucket can ever hold).
        """
        with self.lock:
            # Refill first so the estimate reflects time elapsed since the
            # last consume() instead of a stale token count.
            self._refill()
            if self.tokens >= tokens:
                return 0.0
            if self.rate <= 0 or tokens > self.capacity:
                return float("inf")
            return (tokens - self.tokens) / self.rate
class SlidingWindowCounter:
    """
    Sliding window algorithm for rate limiting.

    Keeps the timestamp of every admitted request that is still inside the
    trailing ``window_seconds`` interval and admits a new request only while
    fewer than ``max_requests`` are recorded. Because a timestamp is stored
    only when a request is admitted, at most ``max_requests`` are held.
    """

    def __init__(self, window_seconds: int, max_requests: int):
        """
        Initialize sliding window counter.

        Args:
            window_seconds: Window size in seconds
            max_requests: Maximum requests in window
        """
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        # time.monotonic() timestamps of admitted requests, oldest first.
        self.requests: deque = deque()
        self.lock = threading.Lock()

    def _evict_expired(self, now: float) -> None:
        """Drop timestamps older than the window. Caller must hold ``self.lock``."""
        cutoff = now - self.window_seconds
        requests = self.requests
        while requests and requests[0] < cutoff:
            requests.popleft()

    def allow_request(self) -> bool:
        """
        Check whether a request is allowed and record it if so.

        Returns:
            True if allowed (the request is recorded),
            False if the rate limit is exceeded (nothing is recorded).
        """
        with self.lock:
            # Monotonic clock: wall-clock adjustments cannot shrink or
            # stretch the window.
            now = time.monotonic()
            self._evict_expired(now)
            if len(self.requests) < self.max_requests:
                self.requests.append(now)
                return True
            return False

    def get_remaining(self) -> int:
        """Get the number of requests still allowed in the current window (records nothing)."""
        with self.lock:
            self._evict_expired(time.monotonic())
            return max(0, self.max_requests - len(self.requests))
class RateLimiter:
    """
    Comprehensive rate limiter combining multiple algorithms.

    Each client ID gets three limiters, and a request is admitted only when
    all three allow it:

    * a token bucket of ``burst_size`` tokens refilled at
      ``requests_per_minute / 60`` tokens per second (burst control);
    * a 60-second sliding window capped at ``requests_per_minute``;
    * a 3600-second sliding window capped at ``requests_per_hour``.

    Client IDs are opaque strings (IP, user, or API key). The public methods
    serialize on ``self.lock``.

    NOTE(review): per-client limiters are only removed by ``reset_client``;
    with an unbounded set of client IDs (e.g. per-IP keys) the three dicts
    grow without limit. Consider evicting idle clients periodically.
    """

    def __init__(self, config: Optional[RateLimitConfig] = None):
        """
        Initialize rate limiter.

        Args:
            config: Rate limit configuration; ``RateLimitConfig()`` defaults
                are used when omitted.
        """
        self.config = config or RateLimitConfig()
        # Per-client limiters (keyed by IP/user/API key), created lazily.
        self.minute_limiters: Dict[str, SlidingWindowCounter] = {}
        self.hour_limiters: Dict[str, SlidingWindowCounter] = {}
        self.burst_limiters: Dict[str, TokenBucket] = {}
        self.lock = threading.Lock()
        logger.info(
            "Rate limiter initialized: %s/min, %s/hour, burst=%s",
            self.config.requests_per_minute,
            self.config.requests_per_hour,
            self.config.burst_size,
        )

    def check_rate_limit(self, client_id: str) -> Tuple[bool, Optional[str]]:
        """
        Check whether a request from ``client_id`` is within all limits and,
        if so, record it.

        Limits are evaluated in the order burst -> minute -> hour, and the
        first failing one determines the error message. Quota is recorded
        only for admitted requests; a rejected request consumes nothing.

        Args:
            client_id: Client identifier (IP, user, or API key)

        Returns:
            ``(True, None)`` if allowed, otherwise ``(False, error_message)``.
        """
        with self.lock:
            if client_id not in self.minute_limiters:
                self._create_limiters(client_id)
            burst = self.burst_limiters[client_id]
            minute = self.minute_limiters[client_id]
            hour = self.hour_limiters[client_id]

            # Phase 1: inspect every limiter without recording the request,
            # so a request rejected by the minute or hour window does not
            # spend a burst token or fill a minute slot.
            burst.consume(0)  # zero-token consume always succeeds; it only refills
            wait_time = burst.get_wait_time()
            if wait_time > 0:
                return False, f"Rate limit exceeded. Retry after {wait_time:.1f}s"
            if minute.get_remaining() <= 0:
                return False, f"Rate limit: {self.config.requests_per_minute} requests/minute exceeded"
            if hour.get_remaining() <= 0:
                return False, f"Rate limit: {self.config.requests_per_hour} requests/hour exceeded"

            # Phase 2: record the admitted request in every limiter. We still
            # hold self.lock, and elapsed time only adds tokens / expires old
            # window entries, so these are expected to succeed (assuming the
            # per-client limiters are only driven through this class).
            if not burst.consume():
                # Defensive: only reachable if the bucket's clock stepped
                # backwards between the check above and this call.
                return False, f"Rate limit exceeded. Retry after {burst.get_wait_time():.1f}s"
            minute.allow_request()
            hour.allow_request()
            return True, None

    def _create_limiters(self, client_id: str):
        """Create the minute, hour and burst limiters for a new client. Caller must hold ``self.lock``."""
        self.minute_limiters[client_id] = SlidingWindowCounter(
            window_seconds=60,
            max_requests=self.config.requests_per_minute
        )
        self.hour_limiters[client_id] = SlidingWindowCounter(
            window_seconds=3600,
            max_requests=self.config.requests_per_hour
        )
        self.burst_limiters[client_id] = TokenBucket(
            rate=self.config.requests_per_minute / 60.0,  # per second
            capacity=self.config.burst_size
        )

    def get_limits_info(self, client_id: str) -> Dict[str, int]:
        """
        Get current limits info for a client without recording a request.

        The same keys are returned whether or not the client has been seen:
        ``minute_remaining``, ``hour_remaining``, ``minute_limit``,
        ``hour_limit`` and ``burst_available`` (whole tokens currently in the
        bucket).

        Args:
            client_id: Client identifier

        Returns:
            Dictionary with limit information
        """
        with self.lock:
            if client_id not in self.minute_limiters:
                # Unseen client: every limiter would start at full capacity.
                minute_remaining = self.config.requests_per_minute
                hour_remaining = self.config.requests_per_hour
                burst_available = self.config.burst_size
            else:
                bucket = self.burst_limiters[client_id]
                bucket.consume(0)  # refill only, so the token count is current
                minute_remaining = self.minute_limiters[client_id].get_remaining()
                hour_remaining = self.hour_limiters[client_id].get_remaining()
                burst_available = int(bucket.tokens)
            return {
                'minute_remaining': minute_remaining,
                'hour_remaining': hour_remaining,
                'minute_limit': self.config.requests_per_minute,
                'hour_limit': self.config.requests_per_hour,
                'burst_available': burst_available,
            }

    def reset_client(self, client_id: str):
        """Discard all limiter state for a client; unknown IDs are ignored."""
        with self.lock:
            self.minute_limiters.pop(client_id, None)
            self.hour_limiters.pop(client_id, None)
            self.burst_limiters.pop(client_id, None)
        logger.info("Reset rate limits for client: %s", client_id)
# Process-wide limiter used by the module-level helpers check_rate_limit and
# get_rate_limit_info; passing no config selects the RateLimitConfig defaults.
global_rate_limiter = RateLimiter(config=None)
# ==================== DECORATORS ====================
class RateLimitExceeded(Exception):
    """
    Raised by functions wrapped with :func:`rate_limit` when a call is rejected.

    Subclasses ``Exception`` so existing ``except Exception`` handlers keep
    working; a web framework can map it to HTTP 429 Too Many Requests.

    Attributes:
        client_id: The client ID that exceeded its limit.
    """

    def __init__(self, message: str, client_id: str):
        super().__init__(message)
        self.client_id = client_id


def rate_limit(
    requests_per_minute: int = 30,
    requests_per_hour: int = 1000,
    get_client_id=lambda: "default",
    burst_size: int = 10,
):
    """
    Decorator for rate limiting endpoints.

    Each call to ``rate_limit(...)`` builds its own RateLimiter, separate from
    ``global_rate_limiter`` and shared by every function decorated with the
    returned decorator. Both ``async def`` and plain functions are supported;
    the limit is checked each time the wrapped function runs.

    Args:
        requests_per_minute: Max requests per minute
        requests_per_hour: Max requests per hour
        get_client_id: Callable returning the client ID. It is invoked with
            no arguments, so it cannot read the request directly (e.g. use a
            context variable). Defaults to one shared "default" client.
        burst_size: Max back-to-back requests (token-bucket capacity)

    Raises:
        RateLimitExceeded: From the wrapped call, when the client is over a
            limit; its message is ``"Rate limit exceeded: <reason>"``.

    Usage:
        @rate_limit(requests_per_minute=60)
        async def my_endpoint():
            ...
    """
    config = RateLimitConfig(
        requests_per_minute=requests_per_minute,
        requests_per_hour=requests_per_hour,
        burst_size=burst_size,
    )
    limiter = RateLimiter(config)

    def _enforce() -> None:
        """Record one call for the current client or raise RateLimitExceeded."""
        client_id = get_client_id()
        allowed, error_msg = limiter.check_rate_limit(client_id)
        if not allowed:
            raise RateLimitExceeded(f"Rate limit exceeded: {error_msg}", client_id)

    def decorator(func):
        # An always-async wrapper around a sync function would return a
        # coroutine that then tries to await a plain value, so pick the
        # wrapper kind that matches the decorated function.
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                _enforce()
                return await func(*args, **kwargs)
            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _enforce()
            return func(*args, **kwargs)
        return sync_wrapper

    return decorator
# ==================== HELPER FUNCTIONS ====================
def check_rate_limit(client_id: str) -> Tuple[bool, Optional[str]]:
    """
    Module-level shortcut: run ``client_id`` through the shared
    ``global_rate_limiter``.

    Args:
        client_id: Client identifier (IP, user, or API key)

    Returns:
        ``(allowed, error_message)`` exactly as produced by
        ``RateLimiter.check_rate_limit``
    """
    limiter = global_rate_limiter
    verdict = limiter.check_rate_limit(client_id)
    return verdict
def get_rate_limit_info(client_id: str) -> Dict[str, int]:
    """
    Get rate limit info for a client from the shared ``global_rate_limiter``.

    No request is recorded by this call.

    Args:
        client_id: Client identifier

    Returns:
        Mapping of counter name (e.g. ``'minute_remaining'``,
        ``'hour_remaining'``) to an integer; the exact key set is defined by
        ``RateLimiter.get_limits_info``.
    """
    # Annotated Dict[str, int]: every value get_limits_info returns is an
    # int (the builtin function ``any`` is not a valid type).
    return global_rate_limiter.get_limits_info(client_id)
|