import numpy as np
import os

# Mark that numpy has been imported before the langchain/Chroma imports below.
os.environ['NUMPY_IMPORT'] = 'done'

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_community.chat_models import ChatOpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
from langchain.chains import RetrievalQA
from dotenv import load_dotenv

from app.config import CHROMA_DB_DIR

load_dotenv()

# OpenRouter API key, read from the OPENROUTER environment variable.
OPENAI_ROUTER_TOKEN = os.getenv("OPENROUTER")
# EMBEDDING_MODEL is used below but is neither defined in this module nor imported
# from app.config; a common sentence-transformers model is assumed here as a placeholder.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

# Shared embedding function and persistent Chroma vector store.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
db = Chroma(persist_directory=CHROMA_DB_DIR, embedding_function=embeddings)
def add_document(file_path: str, user_id: str):
    """Load a PDF or TXT file, split it into chunks, tag each chunk with the
    owning user's id, and add the chunks to the shared Chroma store."""
    if file_path.lower().endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    elif file_path.lower().endswith(".txt"):
        loader = TextLoader(file_path, encoding="utf-8")
    else:
        raise ValueError(f"Unsupported file type: {file_path}")

    documents = loader.load()

    # Split the loaded pages into overlapping ~500-character chunks.
    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.split_documents(documents)

    # Re-wrap each chunk so its metadata records the owner and the source file.
    docs_with_metadata = [
        Document(
            page_content=d.page_content,
            metadata={"user_id": user_id, "filename": os.path.basename(file_path)},
        )
        for d in docs
    ]

    db.add_documents(docs_with_metadata)
def get_qa_chain(user_id: str):
    """
    Return a RetrievalQA pipeline for a specific user, backed by the Llama 4 Scout
    model served through OpenRouter.

    Args:
        user_id (str): Unique identifier for the user.
    """
    llm = ChatOpenAI(
        openai_api_key=OPENAI_ROUTER_TOKEN,
        model="meta-llama/llama-4-scout:free",
        temperature=0,
        max_tokens=512,
        openai_api_base="https://openrouter.ai/api/v1",
    )

    # Restrict retrieval to chunks whose metadata matches this user's id.
    retriever = db.as_retriever(search_kwargs={"filter": {"user_id": user_id}})

    # "stuff" concatenates the retrieved chunks into a single prompt for the LLM.
    qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
    return qa
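

# Usage sketch (not part of the original module): a minimal example of how these
# helpers might be wired together. The file path and user id below are hypothetical;
# RetrievalQA's invoke() takes a {"query": ...} mapping and returns a dict whose
# "result" key holds the model's answer.
if __name__ == "__main__":
    add_document("notes.txt", user_id="user-123")  # hypothetical local file
    chain = get_qa_chain("user-123")
    answer = chain.invoke({"query": "What do my notes say about deadlines?"})
    print(answer["result"])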