Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| import torch | |
| import gradio as gr | |
| from tqdm import tqdm | |
| from PIL import Image | |
| # LangChain & LangGraph | |
| from langgraph.graph import StateGraph | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain.tools import tool | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| # Web Search | |
| from duckduckgo_search import DDGS | |
| # Llama GGUF Model Loader | |
| from llama_cpp import Llama | |
# ------------------------------
# 🔹 Setup Logging
# ------------------------------
# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------
# 🔹 Load GGUF Model with llama-cpp-python
# ------------------------------
# Update with actual GGUF model path.
model_path = "./Bio-Medical-MultiModal-Llama-3-8B-V1.i1-Q6_K.gguf"

# Optimized for Hugging Face T4 GPU: 2048-token context, 35 layers offloaded.
llm = Llama(
    model_path=model_path,
    n_ctx=2048,
    n_gpu_layers=35,
)
logger.info("Llama GGUF Model Loaded Successfully.")
# ------------------------------
# 🔹 Define Expert System Prompts
# ------------------------------
# One system prompt per expert persona used by the workflow.
GP_PROMPT = (
    "You are a General Practitioner AI Assistant. "
    "Answer medical questions with scientifically accurate information."
)
RADIOLOGY_PROMPT = (
    "You are a Radiology AI expert. "
    "Analyze medical images and provide diagnostic insights."
)
WEBSEARCH_PROMPT = (
    "You are a Web Search AI. "
    "Retrieve up-to-date medical information."
)
# ------------------------------
# 🔹 FAISS Vector Store for RAG
# ------------------------------
# Holds the store built from the default PDF at start-up (None until then).
_vector_store_cache = None


def load_vectorstore(pdf_path="medical_docs.pdf"):
    """Build a FAISS vector store from the PDF at ``pdf_path``.

    The PDF is loaded with ``PyPDFLoader``, split with a recursive character
    splitter (``chunk_size=512``, ``chunk_overlap=50``), embedded with the
    ``sentence-transformers/all-MiniLM-L6-v2`` model and indexed with FAISS.

    Args:
        pdf_path: Path of the PDF file to index.

    Returns:
        The FAISS vector store, or ``None`` if any step raised; the error is
        logged rather than propagated.
    """
    try:
        pages = PyPDFLoader(pdf_path).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
        chunks = splitter.split_documents(pages)
        embedder = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        index = FAISS.from_documents(chunks, embedder)
        logger.info(f"Vector store loaded with {len(chunks)} documents.")
        return index
    except Exception as exc:
        logger.error(f"Error loading vector store: {str(exc)}")
        return None
def update_vector_store(pdf_file):
    """Build a vector store from a newly uploaded PDF.

    Args:
        pdf_file: The upload. Either a filesystem path (``str`` or
            ``os.PathLike``) or an object exposing ``read()`` that returns
            the PDF bytes. ``None`` is treated as "nothing uploaded".

    Returns:
        The vector store built from the upload, or the cached start-up store
        (``_vector_store_cache``, possibly ``None``) when nothing was
        uploaded or the upload could not be indexed.
    """
    import tempfile  # local import: only this function needs it

    if pdf_file is None:
        return _vector_store_cache

    # A path-style upload is already on disk: index it in place and leave
    # the caller's file alone.
    if isinstance(pdf_file, (str, os.PathLike)):
        store = load_vectorstore(os.fspath(pdf_file))
        return store if store is not None else _vector_store_cache

    pdf_path = None
    try:
        # A unique temporary file keeps concurrent uploads from overwriting
        # each other, which a fixed file name in the working directory would.
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            pdf_path = tmp.name
            tmp.write(pdf_file.read())
        store = load_vectorstore(pdf_path)
        # load_vectorstore reports failure by returning None; fall back to the
        # cached store in that case, the same as the except path does.
        return store if store is not None else _vector_store_cache
    except Exception as e:
        logger.error(f"Error updating vector store: {str(e)}")
        return _vector_store_cache  # Fallback to cached version
    finally:
        # Clean up on every path, including a failed write.
        if pdf_path is not None and os.path.exists(pdf_path):
            os.remove(pdf_path)
# Index the bundled PDF at start-up when it is present; otherwise start
# without a knowledge base.
_vector_store_cache = (
    load_vectorstore("medical_docs.pdf")
    if os.path.exists("medical_docs.pdf")
    else None
)
# Store consulted by the retrieval tool.
vector_store = _vector_store_cache
# ------------------------------
# 🔹 Define AI Tools
# ------------------------------
def analyze_medical_image(image_path: str):
    """Ask the model for a diagnostic explanation of a medical image.

    Args:
        image_path: Filesystem path of the image to analyze.

    Returns:
        The model's completion text, or ``"Error processing image."`` when
        the image cannot be opened.
    """
    try:
        # The context manager closes the file handle that Image.open leaves
        # open, so repeated requests do not leak descriptors.
        with Image.open(image_path) as image:
            image_description = f"{image}"
    except Exception as e:
        logger.error(f"Error opening image: {str(e)}")
        return "Error processing image."
    # NOTE(review): only the textual PIL representation of the image is placed
    # in the prompt; this text-completion call does not pass pixel data.
    # An explicit max_tokens is passed because llama-cpp-python's default
    # completion limit is very short and truncates the answer.
    output = llm(
        f"Analyze this medical image and provide a diagnosis:\n{image_description}",
        max_tokens=512,
    )
    return output["choices"][0]["text"]
def retrieve_medical_knowledge(query: str):
    """Look up passages relevant to ``query`` in the module-level vector store.

    Args:
        query: Free-text question used for the similarity search.

    Returns:
        The page contents of the top three matches joined by newlines,
        followed by a ``**Citations:**`` section listing each match's
        ``source`` metadata. When no vector store is loaded, returns
        ``"No external medical knowledge available."``.
    """
    if vector_store is None:
        return "No external medical knowledge available."

    matches = vector_store.as_retriever(
        search_type="similarity", search_kwargs={"k": 3}
    ).get_relevant_documents(query)

    passages = []
    citation_lines = []
    for position, match in enumerate(matches, start=1):
        passages.append(match.page_content)
        origin = match.metadata.get('source', 'Unknown Source')
        citation_lines.append(f"[{position}] {origin}")

    citations_text = "\n".join(citation_lines)
    return "\n".join(passages) + f"\n\n**Citations:**\n{citations_text}"
def web_search(query: str):
    """Run a DuckDuckGo text search and summarize the top results.

    Args:
        query: Search terms.

    Returns:
        One ``title: body (href)`` line per result (at most three),
        ``"No relevant results found."`` when the search returns nothing, or
        ``"Error retrieving web search results."`` when the search raises.
    """
    try:
        # The file imports DDGS; the legacy module-level ddg() helper is not
        # imported, so calling it always failed with NameError.
        with DDGS() as ddgs:
            # list() covers releases where text() yields results lazily.
            results = list(ddgs.text(query, max_results=3))
        summary = "\n".join([f"{r['title']}: {r['body']} ({r['href']})" for r in results]) or "No relevant results found."
        return summary
    except Exception as e:
        logger.error(f"Web search error: {str(e)}")
        return "Error retrieving web search results."
# ------------------------------
# 🔹 Define Multi-Agent Workflow (LangGraph)
# ------------------------------
class AgentState:
    """State record handed between the router and the expert functions.

    Attributes:
        query: The user's question text.
        response: The answer produced by an expert.
        image_path: Path of an uploaded image, or ``None``.
        expert: Label of the expert that answered:
            "GP", "Radiology" or "Web Search".
    """

    def __init__(self, query="", response="", image_path=None, expert=""):
        self.query, self.response = query, response
        self.image_path, self.expert = image_path, expert
# LangGraph state graph whose state schema is AgentState.
agent_graph = StateGraph(AgentState)

# In-memory checkpointer passed to the graph when it is compiled.
checkpointer = MemorySaver()
def route_query(state: AgentState):
    """Pick the expert node that should handle ``state``.

    An attached image takes priority and goes to the radiology specialist.
    Otherwise a query containing "latest", "update" or "breaking news"
    (case-insensitive substring match) goes to the web search expert, and
    everything else goes to the general practitioner.

    Returns:
        The name of the chosen node.
    """
    if state.image_path:
        return "radiology_specialist"

    lowered_query = state.query.lower()
    for keyword in ("latest", "update", "breaking news"):
        if keyword in lowered_query:
            return "web_search_expert"

    return "general_practitioner"
def general_practitioner(state: AgentState):
    """GP expert: answer a text query and append retrieved reference material.

    Args:
        state: Current workflow state; only ``state.query`` is read.

    Returns:
        A new ``AgentState`` with ``expert="GP"`` whose response is the
        model's completion, a blank line, then the retrieved knowledge text.
    """
    query = state.query
    # retrieve_medical_knowledge is a plain function, not a LangChain tool,
    # so it has no .run attribute and must be called directly.
    retrieved_info = retrieve_medical_knowledge(query)
    # An explicit max_tokens is passed because llama-cpp-python's default
    # completion limit is very short and truncates the answer.
    output = llm(f"{GP_PROMPT}\nQ: {query}\nA:", max_tokens=512)
    return AgentState(query=query, response=output["choices"][0]["text"] + "\n\n" + retrieved_info, expert="GP")
def radiology_specialist(state: AgentState):
    """Radiology expert: analyze the image attached to the state.

    Args:
        state: Current workflow state; ``state.image_path`` and
            ``state.query`` are read.

    Returns:
        A new ``AgentState`` with ``expert="Radiology"`` whose response is
        the image analysis text.
    """
    # analyze_medical_image is a plain function, not a LangChain tool, so it
    # has no .run attribute and must be called directly.
    image_analysis = analyze_medical_image(state.image_path)
    return AgentState(query=state.query, response=image_analysis, expert="Radiology")
def web_search_expert(state: AgentState):
    """Web search expert: answer the query from live search results.

    Args:
        state: Current workflow state; only ``state.query`` is read.

    Returns:
        A new ``AgentState`` with ``expert="Web Search"`` whose response is
        the search summary text.
    """
    # web_search is a plain function, not a LangChain tool, so it has no .run
    # attribute and must be called directly.
    search_result = web_search(state.query)
    return AgentState(query=state.query, response=search_result, expert="Web Search")
# Add nodes: one per expert.
agent_graph.add_node("general_practitioner", general_practitioner)
agent_graph.add_node("radiology_specialist", radiology_specialist)
agent_graph.add_node("web_search_expert", web_search_expert)

# route_query is a routing function, not a node, so it is installed as the
# conditional entry point instead of being named as an edge source or entry
# node. The path map is an explicit mapping from each value route_query
# returns to the node of the same name.
agent_graph.set_conditional_entry_point(
    route_query,
    {
        "general_practitioner": "general_practitioner",
        "radiology_specialist": "radiology_specialist",
        "web_search_expert": "web_search_expert",
    },
)

# Each expert produces the final answer, so every expert node ends the run.
agent_graph.set_finish_point("general_practitioner")
agent_graph.set_finish_point("radiology_specialist")
agent_graph.set_finish_point("web_search_expert")

# Compile graph
app = agent_graph.compile(checkpointer=checkpointer)
# ------------------------------
# 🔹 Gradio Interface
# ------------------------------
def chat_with_agent(user_input, image_file=None, pdf_file=None):
    """Gradio callback: run one question through the agent workflow.

    Args:
        user_input: Question text from the textbox.
        image_file: Path of the uploaded image, or ``None``.
        pdf_file: Uploaded PDF, or ``None``; when present it replaces the
            vector store used for retrieval.

    Returns:
        The assistant's response text, or a short error message when the
        workflow raises.
    """
    import uuid  # local import: only this callback needs it

    global vector_store
    if pdf_file is not None:
        # retrieve_medical_knowledge reads the module-level vector_store, so
        # the freshly built store has to be rebound here to take effect.
        vector_store = update_vector_store(pdf_file)

    state = AgentState(query=user_input or "", image_path=image_file)
    # A graph compiled with a checkpointer needs a thread_id on every call; a
    # fresh id per request keeps separate users' runs from sharing state.
    run_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
    try:
        result = app.invoke(state, config=run_config)
    except Exception:
        logger.exception("Agent workflow failed")
        return "Error: the assistant could not process this request."

    # NOTE(review): the shape of the graph's result is not visible in this
    # file; accept both a mapping and an AgentState-like object. Confirm
    # against the installed LangGraph version.
    if isinstance(result, dict):
        answer = result.get("response")
    else:
        answer = getattr(result, "response", None)
    return answer or "No response generated."


with gr.Blocks(title="Llama3-Med Multi-Agent AI") as demo:
    gr.Markdown("# 🏥 AI Medical Assistant")
    with gr.Row():
        user_input = gr.Textbox(label="Your Question")
        # "filepath" hands the callback a path string, which is what
        # analyze_medical_image passes to Image.open; "file" is not a valid
        # gr.Image type.
        image_file = gr.Image(label="Upload Medical Image (Optional)", type="filepath")
        pdf_file = gr.File(label="Upload PDF (Optional)", file_types=[".pdf"])
    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(label="Assistant's Response", interactive=False)
    submit_btn.click(fn=chat_with_agent, inputs=[user_input, image_file, pdf_file], outputs=output_text)
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()