abhiabhash's picture
Also amend these things for the backend. # AI-Powered Trading Bot for Indian Markets
df7da89 verified
```python
import datetime
import json
import os
import time
from typing import Optional, List, Dict

import alpaca_trade_api as tradeapi
import numpy as np
import pandas as pd
import requests
import uvicorn
import yfinance as yf
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer
from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel, Field
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential, load_model
from transformers import pipeline
app = FastAPI(
    title="AlgoTradeAI Backend",
    version="2.0.0",
    description="AI-Powered Trading Platform for Indian Markets",
)

# CORS configuration.
# Allowed origins come from the comma-separated ALLOWED_ORIGINS environment
# variable. When it is unset, every origin is allowed, as before. Credentialed
# CORS is enabled only for an explicit allow-list: a wildcard origin combined
# with credentials would let any website send authenticated cross-origin
# requests on a user's behalf. Bearer-token auth (Authorization header) does
# not need the credentials flag.
_allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials="*" not in _allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Database models
class Database:
    """In-memory store for users, strategies, trades, market data and AI models.

    All state lives in plain Python containers on the instance. Nothing is
    persisted, and each process that creates a Database has its own copy.
    """

    def __init__(self):
        # Dict-backed collections.
        self.users = {}
        self.strategies = {}
        self.trades = {}
        self.market_data = {}
        self.portfolio = {}
        # List of news items.
        self.news_data = []
        # Cache for fetched market-data payloads.
        self.ohlcv_cache = {}
        # Model registry; filled in by initialize_models().
        self.ai_models = {}

    def initialize_models(self):
        """Register the AI models; only the sentiment pipeline is loaded here.

        The price-prediction and risk-assessment slots start as None.
        """
        sentiment = pipeline(
            "text-classification",
            model="finiteautomata/bertweet-base-sentiment-analysis",
        )
        self.ai_models = dict(
            price_prediction=None,
            sentiment_analysis=sentiment,
            risk_assessment=None,
        )


db = Database()
# Enhanced Models

# Pydantic v2 removed Field(regex=...) in favour of Field(pattern=...) and
# raises at class-definition time if the old keyword is used, so choose the
# keyword that the installed Pydantic accepts.
_EMAIL_REGEX = r"^\S+@\S+\.\S+$"
_EMAIL_CONSTRAINT = (
    {"regex": _EMAIL_REGEX}
    if PYDANTIC_VERSION.startswith("1.")
    else {"pattern": _EMAIL_REGEX}
)


class User(BaseModel):
    """Public user profile and the permission flags the endpoints check."""

    username: str = Field(..., min_length=4, max_length=20)
    # Loose sanity check only (non-space@non-space.non-space), not full RFC 5322.
    email: str = Field(..., **_EMAIL_CONSTRAINT)
    full_name: Optional[str] = Field(None, max_length=50)
    disabled: Optional[bool] = False
    # Capability names, e.g. "market_data", "trading", "portfolio", "admin".
    permissions: List[str] = ["basic"]
class UserInDB(User):
    """Server-side user record: the public profile plus credentials and login data."""

    # Per the field name, expected to hold a password hash rather than plaintext.
    hashed_password: str = Field(...)
    # String-to-string map of API keys; presumably keyed by provider — verify.
    api_keys: Dict[str, str] = Field(default={})
    last_login: Optional[datetime.datetime] = Field(default=None)
class TradeSignal(BaseModel):
    """A trading signal for one instrument, with optional exit levels."""

    symbol: str
    exchange: str = "NSE"
    signal_type: str  # BUY/SELL/HOLD
    price: float
    quantity: int
    # default_factory calls now() for every new signal. A plain
    # `= datetime.datetime.now()` default is evaluated once at import time,
    # so every signal would carry the same timestamp.
    timestamp: datetime.datetime = Field(default_factory=datetime.datetime.now)
    confidence: float = Field(..., ge=0, le=1)
    stop_loss: Optional[float] = None
    target: Optional[float] = None
    holding_period: str = "intraday"  # intraday/swing/positional
    strategy: str
    risk_reward: Optional[float] = None
class Strategy(BaseModel):
    """A trading strategy definition plus its performance metrics."""

    name: str
    description: str
    risk_level: str = "medium"  # low/medium/high
    # default_factory stamps each strategy at creation. A plain
    # `= datetime.datetime.now()` default is evaluated once at import time,
    # so every strategy would share one timestamp.
    created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
    active: bool = True
    performance_metrics: Optional[Dict] = None
    parameters: Dict = {}
    asset_class: str = "equity"  # equity/derivatives/currency
# Market Data APIs
# Exchange API base URLs. Both end in "/" so endpoint paths are appended directly.
NSE_API_URL: str = "https://www.nseindia.com/api/"
BSE_API_URL: str = "https://api.bseindia.com/BseIndiaAPI/api/"
async def fetch_nse_data(symbol: str, timeout: float = 10.0):
    """Fetch the NSE ``quote-equity`` JSON payload for *symbol*.

    Args:
        symbol: NSE ticker, e.g. "RELIANCE" or "M&M". It is sent as an encoded
            query parameter, so characters such as "&" or spaces are safe.
        timeout: Seconds to wait for NSE before giving up. Without a timeout,
            a stalled server would hang the request forever.

    Returns:
        The decoded JSON body.

    Raises:
        HTTPException: 400 if the request fails, returns a non-2xx status, or
            the body is not valid JSON.
    """
    # NOTE(review): nseindia.com may require browser-session cookies; confirm
    # this endpoint answers bare requests.
    url = f"{NSE_API_URL}quote-equity"
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Accept-Language": "en-US,en;q=0.9",
    }

    def _get():
        response = requests.get(
            url, params={"symbol": symbol}, headers=headers, timeout=timeout
        )
        response.raise_for_status()
        return response.json()

    try:
        # requests is synchronous. Running it in a worker thread keeps a slow
        # exchange response from blocking the event loop for all other requests.
        return await run_in_threadpool(_get)
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"NSE API Error: {str(e)}") from e
async def fetch_bse_data(symbol: str, timeout: float = 10.0):
    """Fetch the BSE scrip-header JSON payload for *symbol*.

    Args:
        symbol: BSE scrip code. It is sent as an encoded query parameter.
        timeout: Seconds to wait for BSE before giving up.

    Returns:
        The decoded JSON body.

    Raises:
        HTTPException: 400 if the request fails, returns a non-2xx status, or
            the body is not valid JSON.
    """
    # NOTE(review): confirm the "StockTradingwAPIs/getScripHeaderData" path and
    # parameter names against BSE's current API.
    url = f"{BSE_API_URL}StockTradingwAPIs/getScripHeaderData"
    params = {"scripcode": symbol, "Debtflag": "D", "series": "EQ"}

    def _get():
        response = requests.get(url, params=params, timeout=timeout)
        response.raise_for_status()
        return response.json()

    try:
        # Run the synchronous HTTP call off the event loop.
        return await run_in_threadpool(_get)
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"BSE API Error: {str(e)}") from e
async def get_ohlcv_data(
    symbol: str,
    exchange: str = "NSE",
    period: str = "1d",
    interval: str = "1m",
    cache_ttl: float = 60.0,
):
    """Return exchange data for *symbol*, served from a short-lived cache.

    Args:
        symbol: Ticker (NSE) or scrip code (BSE).
        exchange: "NSE" uses fetch_nse_data; any other value uses fetch_bse_data.
        period: Only part of the cache key; it is not passed to the fetchers.
        interval: Only part of the cache key; it is not passed to the fetchers.
        cache_ttl: Seconds a cached payload stays valid. Older entries are
            refetched so "real-time" callers do not get stale quotes forever.
            Pass 0 to bypass the cache.

    Returns:
        The fetcher's payload as-is. Nothing here reshapes it into OHLCV rows.

    Raises:
        HTTPException: 400 if the underlying fetch fails.
    """
    # TODO: pass period/interval to a data source that supports them.
    cache_key = f"{symbol}_{exchange}_{period}_{interval}"
    cached = db.ohlcv_cache.get(cache_key)
    if cached is not None:
        fetched_at, data = cached
        if time.monotonic() - fetched_at < cache_ttl:
            return data
    try:
        if exchange == "NSE":
            data = await fetch_nse_data(symbol)
        else:
            data = await fetch_bse_data(symbol)
    except HTTPException:
        # Already carries a meaningful status/detail; don't wrap it again.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to fetch OHLCV data: {str(e)}") from e
    # Store the fetch time alongside the payload for TTL checks.
    db.ohlcv_cache[cache_key] = (time.monotonic(), data)
    return data
# Enhanced Authentication
# Bearer-token dependency consumed by get_current_user. tokenUrl only tells the
# OpenAPI docs where to obtain a token.
# NOTE(review): no "token" route is defined in this file section — confirm it exists.
oauth2_scheme: OAuth2PasswordBearer = OAuth2PasswordBearer(tokenUrl="token")
def verify_api_key(api_key: str):
    """Check whether *api_key* is valid.

    WARNING: placeholder. Every key, including an empty string, is accepted.
    In production, look the key up in the user store before trusting it.
    """
    is_valid = True  # TODO: validate against the database
    return is_valid
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer *token* to a User.

    Placeholder: the token is not verified. Every caller is treated as the same
    demo user with the "basic", "market_data" and "trading" permissions.

    Raises:
        HTTPException: 401 if the user object cannot be built.
    """
    # TODO: verify the JWT in `token` and load the matching user.
    try:
        return User(
            username="demo_user",
            # Must satisfy User.email's pattern. The previous literal contained
            # no "@", so validation always failed and every authenticated
            # endpoint answered 401.
            email="demo@example.com",
            full_name="Demo User",
            permissions=["basic", "market_data", "trading"],
        )
    except Exception as exc:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
# AI Models
def init_price_prediction_model(lookback: int = 60, n_features: int = 5, horizon: int = 3):
    """Build and compile the stacked-LSTM price model.

    Args:
        lookback: Timesteps per input window.
        n_features: Features per timestep. The default of 5 matches the
            open/high/low/close/volume columns used by train_price_model.
        horizon: Number of future periods predicted, one output unit each.

    Returns:
        An untrained Sequential model with input shape (lookback, n_features),
        compiled with the Adam optimizer and MSE loss.
    """
    model = Sequential([
        LSTM(128, return_sequences=True, input_shape=(lookback, n_features)),
        Dropout(0.2),
        LSTM(64, return_sequences=True),
        Dropout(0.2),
        LSTM(32),
        Dense(horizon),  # one output per future period
    ])
    model.compile(optimizer='adam', loss='mse')
    return model
def _build_training_windows(scaled, lookback, horizon, target_col=3):
    """Slice *scaled* rows into (samples, lookback, features) inputs and (samples, horizon) targets."""
    X, y = [], []
    # The last valid start index is len - horizon, because y needs `horizon`
    # rows starting at i.
    for i in range(lookback, len(scaled) - horizon + 1):
        X.append(scaled[i - lookback:i])
        y.append(scaled[i:i + horizon, target_col])
    return np.array(X), np.array(y)


async def train_price_model(symbol: str):
    """Fit the LSTM price model on *symbol*'s data and store it in ``db.ai_models``.

    Each sample uses the previous 60 scaled open/high/low/close/volume rows to
    predict the scaled close of the next 3 rows.

    Returns:
        {"status": "success", "symbol": symbol} on success.

    Raises:
        HTTPException: 400 wrapping any failure, including too little data.
    """
    lookback, horizon = 60, 3  # must match init_price_prediction_model()'s defaults
    try:
        data = await get_ohlcv_data(symbol)
        # NOTE(review): assumes the payload converts to a DataFrame with these
        # five columns — confirm against the fetcher's response.
        df = pd.DataFrame(data)
        scaler = MinMaxScaler()
        scaled_data = scaler.fit_transform(df[['open', 'high', 'low', 'close', 'volume']])
        X, y = _build_training_windows(scaled_data, lookback, horizon)
        if len(X) == 0:
            raise ValueError(
                f"need at least {lookback + horizon} rows of data, got {len(scaled_data)}"
            )
        model = init_price_prediction_model()
        # Keras training is synchronous and CPU-heavy; keep it off the event loop.
        await run_in_threadpool(model.fit, X, y, epochs=10, batch_size=32)
        db.ai_models['price_prediction'] = model
        # Keep the fitted scaler so predictions can be mapped back to prices.
        db.ai_models['price_prediction_scaler'] = scaler
        return {"status": "success", "symbol": symbol}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Model training failed: {str(e)}") from e
# API Endpoints
@app.get("/", tags=["Root"])
async def root():
    """Return a welcome message, the API version and the main endpoint paths."""
    endpoint_map = {
        "market": "/market/{symbol}",
        "signals": "/signals",
        "strategies": "/strategies",
        "portfolio": "/portfolio",
    }
    return {
        "message": "Welcome to AlgoTradeAI API",
        "version": "2.0.0",
        "endpoints": endpoint_map,
    }
def _json_safe(value):
    """Convert numpy scalars to Python values and NaN to None so *value* JSON-encodes.

    FastAPI's encoder cannot serialise numpy integers, and Starlette's
    JSONResponse rejects NaN (allow_nan=False). Either one would turn a valid
    response into an HTTP 500.
    """
    if isinstance(value, np.generic):
        value = value.item()
    if isinstance(value, float) and np.isnan(value):
        return None
    return value


@app.get("/market/{symbol}", tags=["Market Data"])
async def get_symbol_data(
    symbol: str,
    exchange: str = "NSE",
    current_user: User = Depends(get_current_user),
):
    """Get real-time market data for a symbol.

    Returns the last close and volume, the 20-period SMA and 14-period RSI of
    the close, and the last 10 rows. Indicators that are not yet computable
    (too few rows) are returned as null.

    Raises:
        HTTPException: 403 without the "market_data" permission; 400 on any
            fetch or computation failure.
    """
    if "market_data" not in current_user.permissions:
        raise HTTPException(status_code=403, detail="Market data access not allowed")
    try:
        data = await get_ohlcv_data(symbol, exchange)
        # NOTE(review): assumes the payload converts to a DataFrame with
        # 'close' and 'volume' columns, oldest row first — confirm.
        df = pd.DataFrame(data)
        df['sma_20'] = df['close'].rolling(20).mean()
        delta = df['close'].diff()
        avg_gain = delta.clip(lower=0).rolling(14).mean()
        avg_loss = -delta.clip(upper=0).rolling(14).mean()
        # RSI from simple rolling means of gains and losses (no Wilder smoothing).
        df['rsi_14'] = 100 - (100 / (1 + avg_gain / avg_loss))
        return {
            "symbol": symbol,
            "exchange": exchange,
            "last_price": _json_safe(df['close'].iloc[-1]),
            "volume": _json_safe(df['volume'].iloc[-1]),
            "indicators": {
                "sma_20": _json_safe(df['sma_20'].iloc[-1]),
                "rsi_14": _json_safe(df['rsi_14'].iloc[-1]),
                "atr_14": None,  # Will be calculated
                "macd": None,  # Will be calculated
            },
            "ohlcv": [
                {key: _json_safe(val) for key, val in row.items()}
                for row in df.tail(10).to_dict('records')
            ],
        }
    except HTTPException:
        # Keep the upstream status/detail instead of re-wrapping it.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.post("/signals/", tags=["Signals"])
async def create_signal(
    signal: TradeSignal,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
):
    """Store a new trading signal and queue it for background execution.

    When both stop_loss and target are set (truthy), risk_reward is computed
    from the price levels. The computed value replaces any client-supplied one.

    Raises:
        HTTPException: 403 without the "trading" permission; 400 when
            stop_loss equals price (the risk would be zero).
    """
    if "trading" not in current_user.permissions:
        raise HTTPException(status_code=403, detail="Trading not allowed")
    signal_dict = signal.dict()
    if signal.stop_loss and signal.target:
        risk = abs(signal.price - signal.stop_loss)
        reward = abs(signal.target - signal.price)
        # Zero risk would raise ZeroDivisionError, which surfaces as an HTTP 500.
        if risk == 0:
            raise HTTPException(
                status_code=400,
                detail="stop_loss must differ from price to compute risk/reward",
            )
        signal_dict['risk_reward'] = round(reward / risk, 2)
    signal_id = f"signal_{len(db.trades) + 1}"
    db.trades[signal_id] = signal_dict
    # In production, this would trigger order execution
    background_tasks.add_task(process_signal_execution, signal_dict)
    return {
        "status": "success",
        "signal_id": signal_id,
        "signal": signal_dict,
    }
async def process_signal_execution(signal: dict):
    """Handle a queued signal in the background.

    Placeholder: no broker order is placed yet; the signal is only printed
    to stdout.
    """
    message = "Processing signal: " + str(signal)
    print(message)
@app.get("/signals/", response_model=List[TradeSignal], tags=["Signals"])
async def get_signals(
    limit: int = 10,
    strategy: Optional[str] = None,
    status: Optional[str] = None,
    current_user: User = Depends(get_current_user),
):
    """Get stored trading signals, oldest first, with optional filters.

    Args:
        limit: Maximum number of signals returned; values <= 0 return none.
        strategy: Keep only signals whose 'strategy' equals this value.
        status: Keep only signals whose 'status' equals this value.
            NOTE(review): create_signal never stores a 'status' key, so any
            non-empty status filter currently matches nothing.
    """
    signals = [
        s for s in db.trades.values()
        if (not strategy or s.get('strategy') == strategy)
        and (not status or s.get('status') == status)
    ]
    # A negative slice bound would silently drop items from the end instead of
    # returning none, so clamp it at zero.
    return signals[:max(limit, 0)]
@app.post("/strategies/", tags=["Strategies"])
async def create_strategy(
    strategy: Strategy,
    current_user: User = Depends(get_current_user),
):
    """Store a new trading strategy with all performance metrics set to 0.

    Any client-supplied performance_metrics are replaced. Returns the stored
    fields plus the generated strategy_id.
    NOTE(review): unlike create_signal, no permission is checked here.
    """
    new_id = "strategy_{}".format(len(db.strategies) + 1)
    record = strategy.dict()
    metric_names = ("win_rate", "profit_factor", "sharpe_ratio", "max_drawdown", "total_return")
    record['performance_metrics'] = dict.fromkeys(metric_names, 0)
    db.strategies[new_id] = record
    return {"strategy_id": new_id, **record}
@app.get("/strategies/", response_model=List[Strategy], tags=["Strategies"])
async def get_strategies(
    active: Optional[bool] = None,
    risk_level: Optional[str] = None,
    asset_class: Optional[str] = None,
    current_user: User = Depends(get_current_user),
):
    """Get stored strategies; each provided filter must match exactly.

    `active` filters whenever it is not None; `risk_level` and `asset_class`
    filter only when non-empty.
    """
    def matches(record):
        if active is not None and record['active'] != active:
            return False
        if risk_level and record['risk_level'] != risk_level:
            return False
        if asset_class and record['asset_class'] != asset_class:
            return False
        return True

    return [record for record in db.strategies.values() if matches(record)]
def _summarize_index_quote(quote):
    """Map a fetch_nse_data payload to the price/change/range summary served by /indices/."""
    # NOTE(review): assumes these keys sit at the top level of the payload —
    # confirm against the NSE response schema. A missing key raises KeyError.
    return {
        "price": quote['lastPrice'],
        "change": quote['change'],
        "change_percent": quote['pChange'],
        "direction": "up" if quote['change'] >= 0 else "down",
        "high": quote['intraDayHighLow']['max'],
        "low": quote['intraDayHighLow']['min'],
        "volume": quote['totalTradedVolume'],
    }


@app.get("/indices/", tags=["Market Data"])
async def get_indices(current_user: User = Depends(get_current_user)):
    """Get major Indian indices data (NIFTY 50, BANKNIFTY; INDIA VIX is a stub).

    Raises:
        HTTPException: 403 without the "market_data" permission; 400 on a
            fetch failure or an unexpected payload shape.
    """
    if "market_data" not in current_user.permissions:
        raise HTTPException(status_code=403, detail="Market data access not allowed")
    try:
        # NOTE(review): index names go through the same quote-equity fetcher as
        # stocks — confirm NSE serves indices on that endpoint.
        nifty = await fetch_nse_data("NIFTY 50")
        banknifty = await fetch_nse_data("BANKNIFTY")
        return {
            "NIFTY 50": _summarize_index_quote(nifty),
            "BANKNIFTY": _summarize_index_quote(banknifty),
            "INDIA VIX": {
                "price": None,  # Will fetch from API
                "change": None,
                "direction": None,
            },
        }
    except HTTPException:
        # Keep the fetcher's status/detail instead of re-wrapping it.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to fetch indices: {str(e)}") from e
@app.post("/train-model/", tags=["AI Models"])
async def train_model(
    symbol: str,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
):
    """Queue price-model training for *symbol*; requires the "admin" permission.

    Returns immediately. train_price_model runs as a background task.
    """
    is_admin = "admin" in current_user.permissions
    if not is_admin:
        raise HTTPException(status_code=403, detail="Admin access required")
    background_tasks.add_task(train_price_model, symbol)
    return dict(status="training_started", symbol=symbol)
@app.get("/portfolio/", tags=["Portfolio"])
async def get_portfolio(current_user: User = Depends(get_current_user)):
    """Get current portfolio holdings for the calling user.

    total_value sums each holding's 'current_value'. Performance figures are
    placeholders (0). Requires the "portfolio" permission (403 otherwise).
    """
    if "portfolio" not in current_user.permissions:
        raise HTTPException(status_code=403, detail="Portfolio access not allowed")
    holdings = db.portfolio.get(current_user.username, [])
    total_value = sum(h['current_value'] for h in holdings)
    performance = {"daily": 0, "weekly": 0, "monthly": 0}
    return {"holdings": holdings, "total_value": total_value, "performance": performance}
@app.on_event("startup")
async def startup_event():
    """Load the AI models once when the server starts, then log readiness."""
    # NOTE(review): newer FastAPI versions deprecate on_event in favour of a
    # lifespan handler.
    db.initialize_models()
    ready_message = "AI Trading Bot initialized"
    print(ready_message)
if __name__ == "__main__":
    # Development entry point with auto-reload.
    # Stay single-process: `db` is an in-memory object, so several workers
    # would each hold their own users, signals and strategies. The previous
    # workers=4 was also ignored by uvicorn because reload=True.
    # NOTE(review): "backend:app" assumes this file is importable as `backend`.
    uvicorn.run(
        "backend:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        access_log=False,
    )
```