|
|
from app import rag |
|
|
import shutil |
|
|
import os |
|
|
from fastapi import FastAPI, Depends, HTTPException, status, UploadFile, File |
|
|
from fastapi.security import OAuth2PasswordBearer |
|
|
from sqlmodel import select |
|
|
from app.models import User, UserCreate, UserLogin, Token |
|
|
from app.auth import hash_password, verify_password, create_access_token, create_refresh_token, decode_token |
|
|
from app.database import init_db, get_session |
|
|
from sqlmodel import Session |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
|
|
|
|
|
|
|
|
# Run the project's database initialisation at import time, before the
# application object is created (same order as before).
init_db()

# ASGI application instance; the route decorators below register on it.
app = FastAPI()
|
|
|
|
|
|
|
|
# Browser origins allowed to call this API cross-site.
origins = ["http://localhost:5173"]

# Install CORS handling: credentials are allowed for the origins listed above,
# with every HTTP method and every request header permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Dependency that extracts the bearer token from incoming requests;
# "token" is the token URL advertised in the OpenAPI schema.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
|
|
|
|
# Directory that receives uploaded documents (relative to the working directory).
UPLOAD_DIR = "./uploads"

# Create it up front; an already-existing directory is not an error.
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_current_user(token: str = Depends(oauth2_scheme), session: Session = Depends(get_session)):
    """Resolve the request's bearer token to the matching ``User`` row.

    Args:
        token: Bearer token supplied by ``oauth2_scheme``.
        session: Database session supplied by ``get_session``.

    Returns:
        The ``User`` whose ``username`` equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if ``decode_token`` returns a falsy payload or the
            payload carries no ``sub`` claim; 404 if no user has that username.
    """
    payload = decode_token(token)
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid token")

    username = payload.get("sub")
    if not username:
        # A token without a subject cannot identify anyone. Reject it here
        # rather than querying the database with ``username == None``.
        raise HTTPException(status_code=401, detail="Invalid token")

    user = session.exec(select(User).where(User.username == username)).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user
|
|
|
|
|
@app.get("/protected")
def protected_route(current_user: User = Depends(get_current_user)):
    """Return a greeting for the authenticated caller.

    Authentication is enforced by the ``get_current_user`` dependency.
    """
    greeting = f"Hello {current_user.username}, you are authenticated!"
    return {"message": greeting}
|
|
|
|
|
@app.post("/upload")
def upload_file(file: UploadFile = File(...), current_user: User = Depends(get_current_user)):
    """Store an uploaded document on disk and index it for the caller.

    Args:
        file: The multipart upload.
        current_user: Authenticated user from ``get_current_user``; their
            username is passed to ``rag.add_document`` as ``user_id``.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 400 if the upload has no usable filename.
    """
    user_id = current_user.username

    # The filename is chosen by the client. Keep only its last path component
    # so values such as "../../etc/passwd" or "/abs/path" cannot escape the
    # uploads directory. Backslashes are normalised first so Windows-style
    # paths are reduced the same way on every platform.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name)
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid filename")

    file_path = f"./uploads/{safe_name}"
    with open(file_path, "wb") as f:
        shutil.copyfileobj(file.file, f)

    rag.add_document(file_path, user_id=user_id)
    return {"message": "Document uploaded successfully."}
|
|
|
|
|
@app.get("/ask")
def ask(q: str, current_user: User = Depends(get_current_user)):
    """Answer a question using the caller's QA chain.

    Args:
        q: The question text (query parameter).
        current_user: Authenticated user from ``get_current_user``; their
            username is passed to ``rag.get_qa_chain`` as ``user_id``.

    Returns:
        A dict echoing the ``question`` together with the chain's ``answer``.
    """
    chain = rag.get_qa_chain(user_id=current_user.username)
    return {"question": q, "answer": chain.run(q)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/signup", response_model=Token)
def signup(user: UserCreate, session: Session = Depends(get_session)):
    """Register a new account and return tokens for it.

    Args:
        user: Requested username and password.
        session: Database session from ``get_session``.

    Returns:
        A dict with ``access_token`` and ``refresh_token`` whose ``sub`` claim
        is the new username.

    Raises:
        HTTPException: 400 if the username is already taken.
    """
    lookup = select(User).where(User.username == user.username)
    if session.exec(lookup).first():
        raise HTTPException(status_code=400, detail="Username already exists")

    # Only the hash of the password is stored on the row.
    account = User(username=user.username, hashed_password=hash_password(user.password))
    session.add(account)
    session.commit()
    session.refresh(account)

    # Each token helper receives its own claims dict.
    return {
        "access_token": create_access_token({"sub": account.username}),
        "refresh_token": create_refresh_token({"sub": account.username}),
    }
|
|
|
|
|
@app.post("/signin", response_model=Token)
def signin(user: UserLogin, session: Session = Depends(get_session)):
    """Check credentials and return tokens for the account.

    Args:
        user: Submitted username and password.
        session: Database session from ``get_session``.

    Returns:
        A dict with ``access_token`` and ``refresh_token`` whose ``sub`` claim
        is the account's username.

    Raises:
        HTTPException: 401 if the username is unknown or the password does not
            verify; both cases share one message.
    """
    lookup = select(User).where(User.username == user.username)
    account = session.exec(lookup).first()

    # The password is only checked when an account was found.
    if not (account and verify_password(user.password, account.hashed_password)):
        raise HTTPException(status_code=401, detail="Invalid username or password")

    return {
        "access_token": create_access_token({"sub": account.username}),
        "refresh_token": create_refresh_token({"sub": account.username}),
    }
|
|
|
|
|
from fastapi import Body |
|
|
|
|
|
@app.post("/refresh", response_model=Token)
def refresh_token(refresh_token: str = Body(..., embed=True)):
    """Exchange a refresh token for a new access token.

    Args:
        refresh_token: Token read from the ``refresh_token`` field of the
            JSON request body.

    Returns:
        A dict with a newly created ``access_token`` and the same
        ``refresh_token`` that was submitted.

    Raises:
        HTTPException: 401 if ``decode_token`` returns a falsy payload or the
            payload carries no ``sub`` claim.
    """
    payload = decode_token(refresh_token)
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    username = payload.get("sub")
    if not username:
        # Never mint an access token whose subject is missing.
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    # NOTE(review): nothing here distinguishes a refresh token from an access
    # token; confirm whether decode_token / the token helpers enforce that.
    new_access_token = create_access_token({"sub": username})

    return {"access_token": new_access_token, "refresh_token": refresh_token}
|
|
|
|
|
|
|
|
|
|
|
|