|
|
""" |
|
|
Unit tests for database module |
|
|
Comprehensive test coverage for database operations |
|
|
""" |
|
|
|
|
|
import pytest |
|
|
import sqlite3 |
|
|
import tempfile |
|
|
import os |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
import sys |
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
|
|
|
from database import db_manager |
|
|
from database.migrations import MigrationManager, auto_migrate |
|
|
|
|
|
|
|
|
@pytest.fixture
def temp_db():
    """Yield the path of a temporary SQLite database file for one test.

    The file is created empty by ``tempfile.mkstemp`` (so it already exists
    when the test starts). On teardown it is removed together with any SQLite
    side files (``-journal``, ``-wal``, ``-shm``) that may have been left next
    to it, so repeated runs do not litter the temp directory.
    """
    fd, path = tempfile.mkstemp(suffix='.db')
    # mkstemp hands back an open OS-level descriptor; close it so only the
    # code under test holds the file open.
    os.close(fd)

    yield path

    # Remove the database and any journal/WAL companions; a missing file is
    # fine (EAFP instead of a racy exists-then-unlink check).
    for leftover in (path, f'{path}-journal', f'{path}-wal', f'{path}-shm'):
        try:
            os.unlink(leftover)
        except FileNotFoundError:
            pass
|
|
|
|
|
|
|
|
@pytest.fixture
def db_instance(temp_db):
    """Provide a ``CryptoDatabase`` bound to the per-test temporary file."""
    from database import CryptoDatabase

    return CryptoDatabase(temp_db)
|
|
|
|
|
|
|
|
class TestDatabaseInitialization:
    """Test database initialization and schema creation"""

    @staticmethod
    def _schema_names(db_path, object_type):
        """Return names of ``object_type`` entries ('table', 'index', ...) in ``db_path``.

        Names starting with ``sqlite_`` are reserved by SQLite for its own
        objects (e.g. ``sqlite_sequence`` and the ``sqlite_autoindex_*``
        indices it creates implicitly for UNIQUE / PRIMARY KEY constraints).
        They are filtered out so only objects the schema code created
        explicitly are returned.
        """
        conn = sqlite3.connect(db_path)
        try:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type = ?",
                (object_type,),
            ).fetchall()
        finally:
            # Close even if the query raises, so the database file isn't left open.
            conn.close()
        return {name for (name,) in rows if not name.startswith('sqlite_')}

    def test_database_creation(self, temp_db):
        """Test that database file is created"""
        from database import CryptoDatabase

        # Keep the instance referenced for the duration of the checks.
        db = CryptoDatabase(temp_db)  # noqa: F841

        # temp_db is pre-created (empty) by tempfile.mkstemp, so existence on
        # its own proves little; the non-zero size shows the schema was written.
        assert os.path.exists(temp_db)
        assert os.path.getsize(temp_db) > 0

    def test_tables_created(self, db_instance):
        """Test that all required tables are created"""
        tables = self._schema_names(db_instance.db_path, 'table')

        required_tables = {'prices', 'news', 'market_analysis', 'user_queries'}
        missing = required_tables - tables
        assert not missing, f"missing tables: {sorted(missing)}"

    def test_indices_created(self, db_instance):
        """Test that explicit indices are created.

        SQLite's implicit ``sqlite_autoindex_*`` entries are ignored; otherwise
        any UNIQUE constraint would make this test pass on its own.
        """
        indices = self._schema_names(db_instance.db_path, 'index')

        assert len(indices) > 0, "no explicitly created indices found"
|
|
|
|
|
|
|
|
class TestPriceOperations:
    """Tests for storing and reading back price rows."""

    def test_save_price(self, db_instance):
        """A fully populated BTC price record is stored successfully."""
        btc = dict(
            symbol='BTC',
            name='Bitcoin',
            price_usd=50000.0,
            volume_24h=1_000_000_000,
            market_cap=950_000_000_000,
            percent_change_1h=0.5,
            percent_change_24h=2.3,
            percent_change_7d=-1.2,
            rank=1,
        )

        assert db_instance.save_price(btc) is True

    def test_get_latest_prices(self, db_instance):
        """``limit`` caps the result and the first row is the rank-1 coin."""
        coins = [
            dict(
                symbol=f'TEST{n}',
                name=f'Test Coin {n}',
                price_usd=100.0 * (n + 1),
                volume_24h=1_000_000,
                market_cap=10_000_000,
                rank=n + 1,
            )
            for n in range(10)
        ]
        for coin in coins:
            db_instance.save_price(coin)

        latest = db_instance.get_latest_prices(limit=5)

        assert len(latest) == 5
        assert latest[0]['rank'] == 1

    def test_get_historical_prices(self, db_instance):
        """Repeated BTC snapshots appear in a 7-day history, all tagged BTC."""
        for step in range(5):
            snapshot = dict(
                symbol='BTC',
                name='Bitcoin',
                price_usd=50000.0 + step * 100,
                volume_24h=1_000_000_000,
                market_cap=950_000_000_000,
                rank=1,
            )
            db_instance.save_price(snapshot)

        history = db_instance.get_historical_prices('BTC', days=7)

        assert len(history) > 0
        assert all(row['symbol'] == 'BTC' for row in history)
|
|
|
|
|
|
|
|
class TestNewsOperations:
    """Tests for persisting and listing news articles."""

    def test_save_news(self, db_instance):
        """A complete article, sentiment included, is saved successfully."""
        article = dict(
            title='Test Article',
            summary='This is a test summary',
            url='https://example.com/test',
            source='Test Source',
            sentiment_score=0.8,
            sentiment_label='positive',
        )

        assert db_instance.save_news(article) is True

    def test_duplicate_news_url(self, db_instance):
        """The same URL is accepted once and refused on the second attempt."""
        article = dict(
            title='Test Article',
            summary='Summary',
            url='https://example.com/unique',
            source='Test',
        )

        # The first insert must go through...
        assert db_instance.save_news(article) is True
        # ...and resubmitting the identical URL must be reported as a failure.
        assert db_instance.save_news(article) is False

    def test_get_latest_news(self, db_instance):
        """Listing is capped by ``limit`` and every item carries a title."""
        articles = [
            dict(
                title=f'Article {k}',
                summary=f'Summary {k}',
                url=f'https://example.com/article{k}',
                source='Test Source',
            )
            for k in range(10)
        ]
        for article in articles:
            db_instance.save_news(article)

        latest = db_instance.get_latest_news(limit=5)

        assert len(latest) == 5
        assert all('title' in item for item in latest)
|
|
|
|
|
|
|
|
class TestAnalysisOperations:
    """Tests for storing and fetching market-analysis records."""

    def test_save_analysis(self, db_instance):
        """A fully specified BTC analysis is saved successfully."""
        outcome = db_instance.save_analysis(dict(
            symbol='BTC',
            timeframe='24h',
            trend='bullish',
            support_level=45000.0,
            resistance_level=55000.0,
            prediction='Price likely to increase',
            confidence=0.75,
        ))

        assert outcome is True

    def test_get_latest_analysis(self, db_instance):
        """A saved BTC analysis is returned by ``get_latest_analysis`` with its trend."""
        db_instance.save_analysis(dict(
            symbol='BTC',
            timeframe='24h',
            trend='bullish',
            confidence=0.8,
        ))

        latest = db_instance.get_latest_analysis('BTC')

        assert latest is not None
        assert latest['symbol'] == 'BTC'
        assert latest['trend'] == 'bullish'
|
|
|
|
|
|
|
|
class TestMigrations:
    """Tests for ``MigrationManager`` and the ``auto_migrate`` helper."""

    def test_migration_manager_init(self, temp_db):
        """A new manager knows about migrations and reports version 0."""
        migrator = MigrationManager(temp_db)

        assert len(migrator.migrations) > 0
        assert migrator.get_current_version() == 0

    def test_apply_migration(self, temp_db):
        """Applying the first pending migration moves the version to match it."""
        migrator = MigrationManager(temp_db)
        pending = migrator.get_pending_migrations()
        assert len(pending) > 0

        first = pending[0]
        assert migrator.apply_migration(first) is True

        # Version is read only after the migration has been applied.
        assert migrator.get_current_version() == first.version

    def test_migrate_to_latest(self, temp_db):
        """Migrating fully applies something and ends on the highest applied version."""
        migrator = MigrationManager(temp_db)
        ok, applied_versions = migrator.migrate_to_latest()

        assert ok is True
        assert len(applied_versions) > 0
        assert migrator.get_current_version() == max(applied_versions)

    def test_auto_migrate(self, temp_db):
        """The one-shot ``auto_migrate`` entry point reports success."""
        assert auto_migrate(temp_db) is True
|
|
|
|
|
|
|
|
class TestDataValidation:
    """Test data validation"""

    @staticmethod
    def _save_rejected(db, price_data):
        """Try to save ``price_data``; return True if the save was rejected.

        A successful save returns ``True`` (as the other price tests assert).
        Any other return value, or a raised exception, counts as rejection.
        NOTE(review): narrow the ``except`` once the rejection exception type
        (if any) is confirmed.
        """
        try:
            result = db.save_price(price_data)
        except Exception:
            return True
        return result is not True

    @staticmethod
    def _count_prices(db_path, symbol):
        """Return how many rows in ``prices`` carry ``symbol``."""
        conn = sqlite3.connect(db_path)
        try:
            (count,) = conn.execute(
                "SELECT COUNT(*) FROM prices WHERE symbol = ?", (symbol,)
            ).fetchone()
        finally:
            conn.close()
        return count

    def test_price_validation(self, db_instance):
        """Test that a negative price is rejected and nothing is stored"""
        invalid_price = {
            'symbol': 'BTC',
            'name': 'Bitcoin',
            'price_usd': -100.0,
            'rank': 1
        }

        assert self._save_rejected(db_instance, invalid_price), \
            "negative price_usd was accepted"
        assert self._count_prices(db_instance.db_path, 'BTC') == 0

    def test_required_fields(self, db_instance):
        """Test that a record missing required fields is rejected"""
        incomplete_price = {
            'symbol': 'BTC'
        }

        assert self._save_rejected(db_instance, incomplete_price), \
            "record with only 'symbol' was accepted"
        assert self._count_prices(db_instance.db_path, 'BTC') == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestConcurrency:
    """Test concurrent database access"""

    def test_concurrent_writes(self, db_instance):
        """Test that ten simultaneous writes all succeed and are all stored"""
        from concurrent.futures import ThreadPoolExecutor

        def write_price(i):
            price_data = {
                'symbol': f'TEST{i}',
                'name': f'Test {i}',
                'price_usd': float(i),
                'rank': i
            }
            return db_instance.save_price(price_data)

        # An exception raised in a plain threading.Thread target is not
        # propagated to the thread that joins it, so a failed write would only
        # surface as a row-count mismatch. Going through futures re-raises
        # worker exceptions here and collects each save_price result.
        with ThreadPoolExecutor(max_workers=10) as pool:
            results = list(pool.map(write_price, range(10)))

        assert results == [True] * 10, f"some writes failed: {results}"

        prices = db_instance.get_latest_prices(limit=10)
        assert len(prices) == 10
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # pytest.main returns the session exit code; pass it through so running
    # this file directly reports failures to the shell/CI instead of exiting 0.
    raise SystemExit(pytest.main([__file__, '-v']))
|
|
|