|
|
```python |
|
|
import asyncio
import datetime
import json
import os
import time
from typing import Dict, List, Optional

import alpaca_trade_api as tradeapi
import numpy as np
import pandas as pd
import requests
import uvicorn
import yfinance as yf
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer
from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel, Field
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential, load_model
from transformers import pipeline
|
|
|
|
|
app = FastAPI(
    title="AlgoTradeAI Backend",
    version="2.0.0",
    description="AI-Powered Trading Platform for Indian Markets",
)

# Browser origins allowed to call the API, read from the comma-separated
# ALGOTRADE_CORS_ORIGINS environment variable (for example
# "https://app.example.com,http://localhost:3000"). Unset or empty means "*".
_CORS_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("ALGOTRADE_CORS_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    # A "*" origin list combined with allow_credentials=True makes Starlette
    # reflect the caller's Origin header back. That lets any website make
    # credentialed (cookie-bearing) requests on a visitor's behalf, so
    # credentials are only enabled for an explicit allow-list. Bearer tokens
    # sent in the Authorization header are governed by allow_headers, not by
    # this flag.
    allow_credentials="*" not in _CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
class Database:
    """In-memory data store shared by the request handlers.

    Every collection starts empty and lives only in this process's memory.
    Nothing is persisted, so all contents are lost when the process exits.
    """

    def __init__(self):
        # Keyed collections.
        self.users: dict = dict()
        self.strategies: dict = dict()
        self.trades: dict = dict()
        self.market_data: dict = dict()
        self.portfolio: dict = dict()
        # Sequential collection.
        self.news_data: list = list()
        # Response cache and model registry.
        self.ohlcv_cache: dict = dict()
        self.ai_models: dict = dict()

    def initialize_models(self):
        """Replace the model registry with a freshly built set of models.

        Only the sentiment model is constructed here (a transformers
        text-classification pipeline). The price-prediction and
        risk-assessment slots are reset to None, which discards any model
        trained earlier.
        """
        sentiment_model = pipeline(
            "text-classification",
            model="finiteautomata/bertweet-base-sentiment-analysis",
        )
        self.ai_models = dict(
            price_prediction=None,
            sentiment_analysis=sentiment_model,
            risk_assessment=None,
        )
|
|
|
|
|
# Module-level store used by every endpoint below. It lives in process memory,
# so each server process holds its own independent copy.
db = Database()
|
|
|
|
|
# Validation pattern for User.email: non-blank text, "@", non-blank text, ".", non-blank text.
_EMAIL_PATTERN = r"^\S+@\S+\.\S+$"

# pydantic v2 renamed Field(regex=...) to Field(pattern=...) and raises an
# error for the old keyword when the class is defined. pydantic v1 only
# understands `regex`. Pick the keyword the installed version supports, so the
# module imports on both and the email check is actually enforced.
_EMAIL_FIELD_CONSTRAINT = (
    {"pattern": _EMAIL_PATTERN}
    if int(str(PYDANTIC_VERSION).split(".")[0]) >= 2
    else {"regex": _EMAIL_PATTERN}
)


class User(BaseModel):
    """Public profile and permission set of an API user.

    Attributes:
        username: 4 to 20 characters.
        email: Must match ``_EMAIL_PATTERN``.
        full_name: Optional display name of at most 50 characters.
        disabled: Defaults to False.
        permissions: Permission names checked by the endpoints; defaults to
            ["basic"] (a fresh list per instance).
    """

    username: str = Field(..., min_length=4, max_length=20)
    email: str = Field(..., **_EMAIL_FIELD_CONSTRAINT)
    full_name: Optional[str] = Field(None, max_length=50)
    disabled: Optional[bool] = False
    permissions: List[str] = Field(default_factory=lambda: ["basic"])
|
|
|
|
|
class UserInDB(User):
    """User record as stored server-side: the public User fields plus secrets.

    NOTE(review): this model is not referenced by any endpoint in this file.
    """

    # Presumably a password hash rather than the plain password; confirm with the writer of this field.
    hashed_password: str

    # String-to-string mapping of API keys; the key/value semantics are not defined here.
    api_keys: Dict[str, str] = {}

    # Presumably the time of the most recent login; None when unknown.
    last_login: Optional[datetime.datetime] = None
|
|
|
|
|
class TradeSignal(BaseModel):
    """A trading signal submitted by a client.

    Attributes:
        symbol: Instrument symbol.
        exchange: Exchange code; defaults to "NSE".
        signal_type: Free-form signal kind.
        price: Entry price.
        quantity: Number of units.
        timestamp: Creation time (naive local time), set per instance.
        confidence: Required, within [0, 1].
        stop_loss: Optional stop-loss price.
        target: Optional target price.
        holding_period: Defaults to "intraday".
        strategy: Name of the originating strategy.
        risk_reward: Optional reward/risk ratio.
    """

    symbol: str
    exchange: str = "NSE"
    signal_type: str
    price: float
    quantity: int
    # default_factory is called for each new instance. A plain
    # `= datetime.datetime.now()` default is evaluated once at import, which
    # stamps every signal with the server start time.
    timestamp: datetime.datetime = Field(default_factory=datetime.datetime.now)
    confidence: float = Field(..., ge=0, le=1)
    stop_loss: Optional[float] = None
    target: Optional[float] = None
    holding_period: str = "intraday"
    strategy: str
    risk_reward: Optional[float] = None
|
|
|
|
|
class Strategy(BaseModel):
    """A user-defined trading strategy.

    Attributes:
        name: Strategy name.
        description: Free-form description.
        risk_level: Defaults to "medium".
        created_at: Creation time (naive local time), set per instance.
        active: Defaults to True.
        performance_metrics: Optional metrics mapping.
        parameters: Strategy parameters; a fresh empty dict per instance.
        asset_class: Defaults to "equity".
    """

    name: str
    description: str
    risk_level: str = "medium"
    # Evaluated per instance. A `= datetime.datetime.now()` default would be
    # frozen at import time, so every strategy would share one creation time.
    created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
    active: bool = True
    performance_metrics: Optional[Dict] = None
    parameters: Dict = Field(default_factory=dict)
    asset_class: str = "equity"
|
|
|
|
|
# Exchange API base URLs; callers append the endpoint path (each ends with "/").
NSE_API_URL: str = "https://www.nseindia.com/api/"
BSE_API_URL: str = "https://api.bseindia.com/BseIndiaAPI/api/"
|
|
|
|
|
async def fetch_nse_data(symbol: str, timeout: float = 10.0):
    """Fetch NSE's ``quote-equity`` JSON payload for *symbol*.

    The request uses the synchronous ``requests`` library, so it is run in the
    event loop's default executor. This keeps other requests from being
    blocked while this one waits on the network.

    Args:
        symbol: NSE symbol. It is passed via ``params`` so that characters
            such as "&" or spaces are URL-encoded.
        timeout: Seconds to wait for NSE before giving up.

    Returns:
        The decoded JSON response body.

    Raises:
        HTTPException: 400 on a network error, a non-2xx status, or invalid JSON.
    """
    url = f"{NSE_API_URL}quote-equity"
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Accept-Language": "en-US,en;q=0.9",
    }

    def _request():
        response = requests.get(
            url, params={"symbol": symbol}, headers=headers, timeout=timeout
        )
        # Report 4xx/5xx as errors instead of trying to parse an error page as JSON.
        response.raise_for_status()
        return response.json()

    try:
        return await asyncio.get_running_loop().run_in_executor(None, _request)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"NSE API Error: {str(e)}") from e
|
|
|
|
|
async def fetch_bse_data(symbol: str, timeout: float = 10.0):
    """Fetch BSE ``getScripHeaderData`` JSON for *symbol*, sent as the scrip code.

    The request uses the synchronous ``requests`` library, so it is run in the
    event loop's default executor and does not block other requests.

    Args:
        symbol: Value sent as the ``scripcode`` query parameter (URL-encoded).
        timeout: Seconds to wait for BSE before giving up.

    Returns:
        The decoded JSON response body.

    Raises:
        HTTPException: 400 on a network error, a non-2xx status, or invalid JSON.
    """
    url = f"{BSE_API_URL}StockTradingwAPIs/getScripHeaderData"
    params = {"scripcode": symbol, "Debtflag": "D", "series": "EQ"}

    def _request():
        response = requests.get(url, params=params, timeout=timeout)
        # Report 4xx/5xx as errors instead of trying to parse an error page as JSON.
        response.raise_for_status()
        return response.json()

    try:
        return await asyncio.get_running_loop().run_in_executor(None, _request)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"BSE API Error: {str(e)}") from e
|
|
|
|
|
async def get_ohlcv_data(
    symbol: str,
    exchange: str = "NSE",
    period: str = "1d",
    interval: str = "1m",
    cache_ttl: Optional[float] = 60.0,
):
    """Return market data for *symbol*, served from a short-lived cache.

    Data comes from ``fetch_nse_data`` when *exchange* is "NSE" and from
    ``fetch_bse_data`` otherwise. Results are cached in ``db.ohlcv_cache`` as
    ``(fetched_at, data)`` pairs, keyed by symbol, exchange, period and
    interval.

    NOTE(review): *period* and *interval* only distinguish cache entries. They
    are not forwarded to the fetchers, which accept just the symbol.

    Args:
        symbol: NSE symbol, or BSE scrip code.
        exchange: "NSE" selects the NSE fetcher; any other value uses BSE.
        period: Part of the cache key only.
        interval: Part of the cache key only.
        cache_ttl: Seconds a cached entry stays valid. ``None`` keeps entries
            forever. The finite default prevents "real-time" callers from
            receiving arbitrarily old data.

    Returns:
        The payload returned by the fetcher.

    Raises:
        HTTPException: the fetcher's own error, re-raised unchanged; 400 for
            any other failure.
    """
    cache_key = f"{symbol}_{exchange}_{period}_{interval}"

    cached = db.ohlcv_cache.get(cache_key)
    if cached is not None:
        fetched_at, cached_data = cached
        if cache_ttl is None or time.monotonic() - fetched_at < cache_ttl:
            return cached_data

    try:
        if exchange == "NSE":
            data = await fetch_nse_data(symbol)
        else:
            data = await fetch_bse_data(symbol)
    except HTTPException:
        # Already a client-facing error; re-raise it instead of nesting its
        # message inside another HTTPException.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to fetch OHLCV data: {str(e)}") from e

    db.ohlcv_cache[cache_key] = (time.monotonic(), data)
    return data
|
|
|
|
|
# FastAPI security dependency that supplies the bearer token to get_current_user.
# NOTE(review): tokenUrl="token" names the token-issuing route, but no such route
# is defined in this file. Confirm it exists elsewhere.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
|
def verify_api_key(api_key: str) -> bool:
    """Placeholder API-key check.

    WARNING: no verification is performed. Every key, including the empty
    string, is accepted. NOTE(review): no endpoint in this file calls it.
    """
    del api_key  # unused until real verification is implemented
    return True
|
|
|
|
|
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the user making the request.

    NOTE(review): this is a stub. The token supplied by ``oauth2_scheme`` is
    never inspected, so every caller becomes the same demo user. Replace this
    with real token verification before exposing the service.

    Returns:
        User: the demo user with the "basic", "market_data" and "trading"
        permissions.

    Raises:
        HTTPException: 401 (with a Bearer challenge) if building the user fails.
    """
    try:
        # The address must satisfy User.email's validation pattern. The previous
        # placeholder contained no "@", so validation always failed and every
        # authenticated endpoint answered 401.
        return User(
            username="demo_user",
            email="demo_user@example.com",
            full_name="Demo User",
            permissions=["basic", "market_data", "trading"],
        )
    except Exception as exc:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
|
|
|
|
|
def init_price_prediction_model(lookback: int = 60, n_features: int = 5, horizon: int = 3):
    """Build and compile a stacked-LSTM regressor for multi-step prediction.

    Args:
        lookback: Time steps per input window.
        n_features: Features per time step.
        horizon: Values predicted per window (units of the output layer).

    Returns:
        A compiled Keras ``Sequential`` model (Adam optimizer, MSE loss). It
        takes inputs of shape ``(batch, lookback, n_features)`` and returns
        outputs of shape ``(batch, horizon)``.
    """
    model = Sequential([
        LSTM(128, return_sequences=True, input_shape=(lookback, n_features)),
        Dropout(0.2),
        LSTM(64, return_sequences=True),
        Dropout(0.2),
        # No return_sequences: only the final step's state feeds the output head.
        LSTM(32),
        Dense(horizon),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model
|
|
|
|
|
async def train_price_model(symbol: str):
    """Train a fresh price-prediction model on *symbol*'s data and register it.

    Inputs are sliding windows over the last 60 min-max-scaled rows of the
    open/high/low/close/volume columns. Targets are the next 3 scaled
    ``close`` values. The fitted model is stored in
    ``db.ai_models['price_prediction']``.

    NOTE(review): the fitted scaler is not kept, so predictions cannot be
    mapped back to prices.

    Returns:
        dict: ``{"status": "success", "symbol": symbol}``.

    Raises:
        HTTPException: 400 if fetching, preprocessing or training fails,
            including when there are too few rows to build a single window.
    """
    # Must match init_price_prediction_model()'s default input window and output size.
    lookback, horizon = 60, 3
    close_col = 3  # position of "close" in the feature columns below

    try:
        data = await get_ohlcv_data(symbol)
        df = pd.DataFrame(data)
        scaler = MinMaxScaler()
        scaled_data = scaler.fit_transform(df[['open', 'high', 'low', 'close', 'volume']])

        # Window i uses rows [i - lookback, i) as input and [i, i + horizon) as
        # target. The last valid i is len - horizon, hence the "+ 1" (the old
        # bound skipped the final window).
        starts = range(lookback, len(scaled_data) - horizon + 1)
        if len(starts) == 0:
            raise ValueError(
                f"need at least {lookback + horizon} rows to train, got {len(scaled_data)}"
            )
        X = np.array([scaled_data[i - lookback:i] for i in starts])
        y = np.array([scaled_data[i:i + horizon, close_col] for i in starts])

        model = init_price_prediction_model()
        # Keras training is synchronous and CPU-heavy. Run it in an executor so
        # the event loop keeps serving requests while this background task trains.
        await asyncio.get_running_loop().run_in_executor(
            None, lambda: model.fit(X, y, epochs=10, batch_size=32)
        )

        db.ai_models['price_prediction'] = model
        return {"status": "success", "symbol": symbol}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Model training failed: {str(e)}") from e
|
|
|
|
|
|
|
|
@app.get("/", tags=["Root"])
async def root():
    """Return a welcome message, the API version and the main endpoint paths."""
    endpoint_paths = dict(
        market="/market/{symbol}",
        signals="/signals",
        strategies="/strategies",
        portfolio="/portfolio",
    )
    payload = {"message": "Welcome to AlgoTradeAI API"}
    payload["version"] = "2.0.0"
    payload["endpoints"] = endpoint_paths
    return payload
|
|
def _json_safe(value):
    """Convert a pandas/NumPy scalar into a JSON-serializable Python value.

    NumPy scalars become native Python numbers. Non-finite floats (NaN, inf)
    become None, because JSON has no representation for them.
    """
    if isinstance(value, np.generic):
        value = value.item()
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


@app.get("/market/{symbol}", tags=["Market Data"])
async def get_symbol_data(
    symbol: str,
    exchange: str = "NSE",
    current_user: User = Depends(get_current_user),
):
    """Get real-time market data for a symbol.

    Requires the "market_data" permission. Returns the latest close and
    volume, SMA-20 and RSI-14 of ``close``, and the last 10 rows. NaN values
    (e.g. SMA-20 with fewer than 20 rows) are returned as null instead of
    breaking JSON serialization.

    NOTE(review): confirm that ``get_ohlcv_data`` returns row-oriented data
    with ``close`` and ``volume`` columns; ``pd.DataFrame`` over a quote
    dictionary may not produce them.

    Raises:
        HTTPException: 403 without permission; the fetcher's own error if
            fetching fails; 400 for any other processing error.
    """
    if "market_data" not in current_user.permissions:
        raise HTTPException(status_code=403, detail="Market data access not allowed")

    try:
        data = await get_ohlcv_data(symbol, exchange)
        df = pd.DataFrame(data)

        close_delta = df['close'].diff()
        avg_gain = close_delta.clip(lower=0).rolling(14).mean()
        avg_loss = -close_delta.clip(upper=0).rolling(14).mean()
        df['sma_20'] = df['close'].rolling(20).mean()
        df['rsi_14'] = 100 - (100 / (1 + avg_gain / avg_loss))

        recent_rows = [
            {key: _json_safe(val) for key, val in row.items()}
            for row in df.tail(10).to_dict('records')
        ]
        return {
            "symbol": symbol,
            "exchange": exchange,
            "last_price": _json_safe(df['close'].iloc[-1]),
            "volume": _json_safe(df['volume'].iloc[-1]),
            "indicators": {
                "sma_20": _json_safe(df['sma_20'].iloc[-1]),
                "rsi_14": _json_safe(df['rsi_14'].iloc[-1]),
                "atr_14": None,
                "macd": None,
            },
            "ohlcv": recent_rows,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
@app.post("/signals/", tags=["Signals"])
async def create_signal(
    signal: TradeSignal,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
):
    """Create a new trading signal and queue it for background processing.

    Requires the "trading" permission. When both ``stop_loss`` and ``target``
    are given and the risk is non-zero, ``risk_reward`` is recomputed as
    ``|target - price| / |price - stop_loss|``, rounded to 2 decimals.

    Returns:
        dict with ``status``, the generated ``signal_id`` and the stored signal.

    Raises:
        HTTPException: 403 if the user lacks the "trading" permission.
    """
    if "trading" not in current_user.permissions:
        raise HTTPException(status_code=403, detail="Trading not allowed")

    signal_id = f"signal_{len(db.trades) + 1}"
    signal_dict = signal.dict()

    # `is not None`, so a 0.0 level still counts as provided.
    if signal.stop_loss is not None and signal.target is not None:
        risk = abs(signal.price - signal.stop_loss)
        reward = abs(signal.target - signal.price)
        # A stop-loss equal to the entry price means zero risk and an undefined
        # ratio. Keep the submitted value rather than raising ZeroDivisionError,
        # which would surface as an HTTP 500.
        if risk:
            signal_dict['risk_reward'] = round(reward / risk, 2)

    db.trades[signal_id] = signal_dict
    background_tasks.add_task(process_signal_execution, signal_dict)

    return {
        "status": "success",
        "signal_id": signal_id,
        "signal": signal_dict,
    }
|
|
|
|
|
async def process_signal_execution(signal: dict):
    """Background hook run after a signal is stored.

    It only echoes the signal to stdout; no order is placed.
    """
    message = "Processing signal: {}".format(signal)
    print(message)
|
|
@app.get("/signals/", response_model=List[TradeSignal], tags=["Signals"])
async def get_signals(
    limit: int = 10,
    strategy: Optional[str] = None,
    status: Optional[str] = None,
    current_user: User = Depends(get_current_user),
):
    """Get trading signals, optionally filtered by strategy and status.

    Returns at most *limit* signals, in insertion order. Empty filter values
    are ignored.

    NOTE(review): as written in this file, signals are stored without a
    "status" key, so a non-empty *status* filter currently matches nothing.

    Raises:
        HTTPException: 400 if *limit* is negative. A negative value would
            otherwise slice from the end of the list.
    """
    if limit < 0:
        raise HTTPException(status_code=400, detail="limit must be non-negative")

    signals = [
        s for s in db.trades.values()
        if (not strategy or s.get('strategy') == strategy)
        and (not status or s.get('status') == status)
    ]
    return signals[:limit]
|
|
@app.post("/strategies/", tags=["Strategies"])
async def create_strategy(
    strategy: Strategy,
    current_user: User = Depends(get_current_user),
):
    """Store a new strategy and return it together with its generated id.

    Any client-supplied ``performance_metrics`` are replaced with zeroed metrics.
    """
    new_id = f"strategy_{len(db.strategies) + 1}"
    zeroed_metrics = dict.fromkeys(
        ("win_rate", "profit_factor", "sharpe_ratio", "max_drawdown", "total_return"),
        0,
    )
    # Overriding an existing key keeps its original position in the dict.
    record = {**strategy.dict(), "performance_metrics": zeroed_metrics}

    db.strategies[new_id] = record
    return {"strategy_id": new_id, **record}
|
|
@app.get("/strategies/", response_model=List[Strategy], tags=["Strategies"])
async def get_strategies(
    active: Optional[bool] = None,
    risk_level: Optional[str] = None,
    asset_class: Optional[str] = None,
    current_user: User = Depends(get_current_user),
):
    """List stored strategies that satisfy every supplied filter.

    ``active`` applies whenever it is not None. ``risk_level`` and
    ``asset_class`` apply only when non-empty.
    """
    def _matches(record):
        if active is not None and not record['active'] == active:
            return False
        if risk_level and not record['risk_level'] == risk_level:
            return False
        if asset_class and not record['asset_class'] == asset_class:
            return False
        return True

    return [record for record in db.strategies.values() if _matches(record)]
|
|
def _index_snapshot(quote: dict) -> dict:
    """Map an NSE quote payload to the fields exposed by ``/indices/``."""
    return {
        "price": quote['lastPrice'],
        "change": quote['change'],
        "change_percent": quote['pChange'],
        "direction": "up" if quote['change'] >= 0 else "down",
        "high": quote['intraDayHighLow']['max'],
        "low": quote['intraDayHighLow']['min'],
        "volume": quote['totalTradedVolume'],
    }


@app.get("/indices/", tags=["Market Data"])
async def get_indices(current_user: User = Depends(get_current_user)):
    """Get major Indian indices data.

    Requires the "market_data" permission. India VIX is not implemented yet
    and is returned with null fields.

    NOTE(review): the quotes come from fetch_nse_data, which calls NSE's
    ``quote-equity`` endpoint. Confirm that it accepts "NIFTY 50" and
    "BANKNIFTY" and returns the top-level keys read by ``_index_snapshot``.

    Raises:
        HTTPException: 403 without permission; the fetcher's own error if
            fetching fails; 400 if the payload lacks the expected fields.
    """
    if "market_data" not in current_user.permissions:
        raise HTTPException(status_code=403, detail="Market data access not allowed")

    try:
        nifty = await fetch_nse_data("NIFTY 50")
        banknifty = await fetch_nse_data("BANKNIFTY")
        return {
            "NIFTY 50": _index_snapshot(nifty),
            "BANKNIFTY": _index_snapshot(banknifty),
            "INDIA VIX": {
                "price": None,
                "change": None,
                "direction": None,
            },
        }
    except HTTPException:
        # Already client-facing; don't nest its message in another error.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to fetch indices: {str(e)}") from e
|
|
|
|
|
@app.post("/train-model/", tags=["AI Models"])
async def train_model(
    symbol: str,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
):
    """Queue price-model training for *symbol*; requires the "admin" permission.

    NOTE(review): the demo user built by get_current_user is not granted
    "admin", so this endpoint currently always answers 403.
    """
    is_admin = "admin" in current_user.permissions
    if not is_admin:
        raise HTTPException(status_code=403, detail="Admin access required")

    background_tasks.add_task(train_price_model, symbol)
    return dict(status="training_started", symbol=symbol)
|
|
|
|
|
@app.get("/portfolio/", tags=["Portfolio"])
async def get_portfolio(current_user: User = Depends(get_current_user)):
    """Return the caller's holdings, their summed ``current_value``, and placeholder performance.

    Requires the "portfolio" permission. A user with no entry gets an empty list.
    """
    if "portfolio" not in current_user.permissions:
        raise HTTPException(status_code=403, detail="Portfolio access not allowed")

    holdings = db.portfolio.get(current_user.username, [])
    total_value = sum(holding['current_value'] for holding in holdings)
    placeholder_performance = dict(daily=0, weekly=0, monthly=0)

    return {
        "holdings": holdings,
        "total_value": total_value,
        "performance": placeholder_performance,
    }
|
|
@app.on_event("startup")
async def startup_event():
    """On server start, build the model registry in the shared store, then report readiness on stdout."""
    store = db
    store.initialize_models()
    ready_banner = "AI Trading Bot initialized"
    print(ready_banner)
|
|
|
|
|
if __name__ == "__main__":
    # Build the import string from this file's name rather than hard-coding
    # "backend:app", so reload works whatever the module file is called.
    _module_name = os.path.splitext(os.path.basename(__file__))[0]
    # `workers` is deliberately omitted. uvicorn ignores it when reload=True,
    # and with several real worker processes each would hold its own in-memory
    # `db`, so data written through one worker would be invisible to the others.
    uvicorn.run(
        f"{_module_name}:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        access_log=False,
    )
|
|
``` |