|
|
""" |
|
|
Unit tests for async API client |
|
|
Test async HTTP operations, retry logic, and error handling |
|
|
""" |
|
|
|
|
|
import pytest |
|
|
import aiohttp |
|
|
from unittest.mock import AsyncMock, patch, MagicMock |
|
|
import asyncio |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
|
|
|
from utils.async_api_client import AsyncAPIClient, safe_api_call, parallel_api_calls |
|
|
|
|
|
|
|
|
@pytest.mark.asyncio |
|
|
class TestAsyncAPIClient: |
|
|
"""Test AsyncAPIClient class""" |
|
|
|
|
|
async def test_client_initialization(self): |
|
|
"""Test client initialization with context manager""" |
|
|
async with AsyncAPIClient() as client: |
|
|
assert client._session is not None |
|
|
assert isinstance(client._session, aiohttp.ClientSession) |
|
|
|
|
|
async def test_successful_get_request(self): |
|
|
"""Test successful GET request""" |
|
|
mock_response_data = {"status": "success", "data": "test"} |
|
|
|
|
|
async with AsyncAPIClient() as client: |
|
|
with patch.object( |
|
|
client._session, |
|
|
'get', |
|
|
return_value=AsyncMock( |
|
|
json=AsyncMock(return_value=mock_response_data), |
|
|
raise_for_status=MagicMock(), |
|
|
__aenter__=AsyncMock(), |
|
|
__aexit__=AsyncMock() |
|
|
) |
|
|
): |
|
|
result = await client.get("https://api.example.com/data") |
|
|
|
|
|
|
|
|
async def test_retry_on_timeout(self): |
|
|
"""Test retry logic on timeout""" |
|
|
async with AsyncAPIClient(max_retries=3, retry_delay=0.1) as client: |
|
|
|
|
|
client._session.get = AsyncMock(side_effect=asyncio.TimeoutError()) |
|
|
|
|
|
result = await client.get("https://api.example.com/data") |
|
|
|
|
|
|
|
|
assert result is None |
|
|
|
|
|
assert client._session.get.call_count == 3 |
|
|
|
|
|
async def test_retry_on_server_error(self): |
|
|
"""Test retry on 5xx server errors""" |
|
|
|
|
|
pass |
|
|
|
|
|
async def test_no_retry_on_client_error(self): |
|
|
"""Test that client errors (4xx) don't trigger retries""" |
|
|
|
|
|
pass |
|
|
|
|
|
async def test_parallel_requests(self): |
|
|
"""Test parallel request execution""" |
|
|
urls = [ |
|
|
"https://api.example.com/endpoint1", |
|
|
"https://api.example.com/endpoint2", |
|
|
"https://api.example.com/endpoint3" |
|
|
] |
|
|
|
|
|
async with AsyncAPIClient() as client: |
|
|
|
|
|
mock_data = [{"id": i} for i in range(len(urls))] |
|
|
|
|
|
|
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
@pytest.mark.asyncio |
|
|
class TestConvenienceFunctions: |
|
|
"""Test convenience functions""" |
|
|
|
|
|
async def test_safe_api_call(self): |
|
|
"""Test safe_api_call convenience function""" |
|
|
|
|
|
with patch('utils.async_api_client.AsyncAPIClient') as MockClient: |
|
|
mock_instance = MockClient.return_value.__aenter__.return_value |
|
|
mock_instance.get = AsyncMock(return_value={"success": True}) |
|
|
|
|
|
result = await safe_api_call("https://api.example.com/test") |
|
|
|
|
|
|
|
|
async def test_parallel_api_calls(self): |
|
|
"""Test parallel_api_calls convenience function""" |
|
|
urls = ["https://api.example.com/1", "https://api.example.com/2"] |
|
|
|
|
|
with patch('utils.async_api_client.AsyncAPIClient') as MockClient: |
|
|
mock_instance = MockClient.return_value.__aenter__.return_value |
|
|
mock_instance.gather_requests = AsyncMock( |
|
|
return_value=[{"id": 1}, {"id": 2}] |
|
|
) |
|
|
|
|
|
results = await parallel_api_calls(urls) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
pytest.main([__file__, '-v']) |
|
|
|