Russellml committed on
Commit
1e0d0c8
·
verified ·
1 Parent(s): bf1bdee

upload files

Browse files
app/__init__.py ADDED
File without changes
app/crisis_toolchain.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import logging
3
+ import os
4
+ import json
5
+ import re
6
+ import streamlit as st
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain.prompts import PromptTemplate
9
+ from app.models.emotion_detector.detector import detect_emotions
10
+
11
class SessionIdDefaultFilter(logging.Filter):
    """Guarantee that every log record carries a ``session_id`` attribute.

    The log format configured below interpolates ``%(session_id)s``. Records
    emitted without ``extra={"session_id": ...}`` would otherwise fail during
    formatting and be dropped, so this filter supplies a placeholder value.
    """

    def filter(self, record):
        """Add ``session_id="unknown"`` when missing; never rejects a record."""
        if not hasattr(record, "session_id"):
            record.session_id = "unknown"
        return True


logging.basicConfig(
    filename="crisis_log.txt",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - Session: %(session_id)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Handler-level filters run for every record the handler emits, including
# records propagated from loggers in other modules.
for _root_handler in logging.getLogger().handlers:
    _root_handler.addFilter(SessionIdDefaultFilter())
17
+
18
def get_session_id():
    """Return a freshly generated random session identifier.

    Returns:
        The canonical string form of a new UUID4.
    """
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)
21
+
22
_RISK_LABELS = {
    "low": "Low Risk",
    "medium": "Medium Risk",
    "high": "High Risk",
}


def _normalize_risk_label(raw_label):
    """Map an LLM-provided label onto "Low Risk" / "Medium Risk" / "High Risk"."""
    label = str(raw_label).strip().lower()
    # Accept both the short form ("High") and the long form ("High Risk").
    if label.endswith("risk"):
        label = label[: -len("risk")].strip()
    return _RISK_LABELS.get(label, "Unknown Risk")


def _apply_emotion_adjustment(score, emotions):
    """Raise *score* for strong negative emotions and clamp it to [0.0, 1.0]."""
    for emotion in emotions:
        if emotion["score"] <= 0.7:
            continue
        if emotion["label"] in ("sadness", "fear"):
            score = min(score + 0.15, 1.0)
        elif emotion["label"] == "anger":
            score = min(score + 0.1, 1.0)
    return max(0.0, min(score, 1.0))


def detect_crisis(transcript, emotions):
    """Detect crisis risk based on transcript and emotions using an LLM.

    Args:
        transcript: The user's utterance as text.
        emotions: Iterable of dicts with ``"label"`` and ``"score"`` keys
            (``None`` is treated as no emotions).

    Returns:
        A ``(crisis_risk, score)`` tuple. ``crisis_risk`` is one of
        "Low Risk", "Medium Risk", "High Risk" or "Unknown Risk"; ``score`` is
        a float in [0.0, 1.0]. Any failure yields ``("Unknown Risk", 0.0)``.
    """
    # Only mint a new ID when the session has none, instead of generating a
    # throwaway UUID on every call.
    session_id = st.session_state.get("session_id") or get_session_id()
    log_extra = {"session_id": session_id}
    emotions = emotions or []
    try:
        api_key = st.session_state.get("openai_api_key") or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("No OpenAI API key found in st.session_state or environment")

        llm = ChatOpenAI(
            model="gpt-4o",
            temperature=0.3,
            api_key=api_key
        )

        prompt_template = PromptTemplate(
            input_variables=["transcript", "emotions"],
            template="""
            You are a mental health crisis detection system. Analyze the following transcript and emotions to determine the crisis risk level (Low, Medium, High) and assign a score between 0.0 and 1.0.

            Transcript: {transcript}
            Emotions: {emotions}

            Consider linguistic nuances, cultural context (Omani Arabic or English), and emotional intensity. Examples:
            - High Risk: Explicit suicidal intent (e.g., "أفكر في الانتحار", "I want to end my life").
            - Medium Risk: Expressions of sadness, hopelessness, or distress (e.g., "أنا حزين جدًا", "I feel hopeless").
            - Low Risk: Neutral or positive statements (e.g., "أنا بخير", "I'm okay").

            Return a JSON object with "crisis_risk" (Low, Medium, High) and "score" (0.0 to 1.0).
            """
        )

        emotions_str = ", ".join(f"{e['label']}: {e['score']:.2f}" for e in emotions)

        prompt = prompt_template.format(transcript=transcript, emotions=emotions_str)
        logging.info(f"Crisis detection prompt: {prompt}", extra=log_extra)

        response = llm.invoke(prompt)
        logging.info(f"LLM response: {response.content}", extra=log_extra)

        # Strip Markdown code fences, with or without a "json" language tag.
        response_text = re.sub(
            r'^```(?:json)?\s*|\s*```$', '', response.content, flags=re.MULTILINE
        ).strip()
        logging.info(f"Cleaned LLM response: {response_text}", extra=log_extra)

        try:
            result = json.loads(response_text)
        except json.JSONDecodeError as e:
            logging.error(f"Failed to parse LLM response: {response_text}, error: {str(e)}", extra=log_extra)
            return "Unknown Risk", 0.0

        if not isinstance(result, dict):
            logging.error(
                f"Failed to parse LLM response: {response_text}, error: expected a JSON object",
                extra=log_extra,
            )
            return "Unknown Risk", 0.0

        # Case-insensitive mapping so "high" / "High Risk" are not discarded.
        crisis_risk = _normalize_risk_label(result.get("crisis_risk", ""))
        try:
            score = float(result.get("score", 0.0))
        except (TypeError, ValueError):
            score = 0.0

        logging.info(f"Parsed crisis_risk: {crisis_risk}, score: {score:.2f}", extra=log_extra)

        # Adjust score based on emotions, then validate the range.
        score = _apply_emotion_adjustment(score, emotions)

        logging.info(f"Crisis detection: {crisis_risk}, Score: {score:.2f}", extra=log_extra)

        return crisis_risk, score

    except Exception as e:
        logging.error(f"Crisis detection failed: {str(e)}", extra=log_extra)
        return "Unknown Risk", 0.0
app/intent_analysis.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from langchain_openai import ChatOpenAI
4
+ from langchain.prompts import PromptTemplate
5
+ from app.models.emotion_detector.detector import load_emotion_classifier, detect_emotions
6
+
7
# File logging shared with the other app modules; the format expects each
# record to carry a ``session_id`` attribute (passed through ``extra``).
logging.basicConfig(
    filename="crisis_log.txt",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - Session: %(session_id)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Loaded once at import time and reused by every analyze_intent() call.
emotion_classifier = load_emotion_classifier()

# Intent identifier returned by the LLM -> human-readable description.
INTENT_MAPPING = dict(
    seeking_resources="Looking for support or resources",
    venting_emotions="Expressing feelings or stress",
    unknown="Talking about daily life / feelings",
)
21
+
22
def analyze_intent(transcript, session_id="unknown"):
    """Analyze emotions and intent in a transcript using an LLM.

    Args:
        transcript: The user's utterance as text.
        session_id: Identifier attached to every log record.

    Returns:
        ``(emotions, intent_description)`` where ``emotions`` is whatever
        ``detect_emotions`` returns and ``intent_description`` is one of the
        values of ``INTENT_MAPPING``. Classification failures fall back to the
        "unknown" intent.
    """
    log_extra = {"session_id": session_id}
    emotions = detect_emotions(transcript, emotion_classifier)
    logging.info(f"Emotions detected: {emotions}", extra=log_extra)

    # LLM-based Intent Classification
    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("No OpenAI API key found in environment")

        llm = ChatOpenAI(
            model="gpt-4o",
            temperature=0.3,
            api_key=api_key
        )

        prompt_template = PromptTemplate(
            input_variables=["transcript"],
            template="""
            You are an intent classification system for a mental health voice bot. Analyze the following transcript to determine the user's intent. Choose one of the following intents:
            - seeking_resources: User is looking for support, resources, or help (e.g., "I need a therapist in Muscat", "أحتاج معالج في مسقط").
            - venting_emotions: User is expressing emotions like sadness, stress, or anxiety (e.g., "I'm so sad", "أنا حزين جدًا").
            - unknown: User is talking about daily life or neutral topics (e.g., "Today was a good day", "اليوم كان جيدًا").

            Consider linguistic nuances and cultural context (Omani Arabic or English). Return only the intent name (seeking_resources, venting_emotions, or unknown).

            Transcript: {transcript}
            """
        )

        prompt = prompt_template.format(transcript=transcript)
        logging.info(f"Intent classification prompt: {prompt}", extra=log_extra)

        response = llm.invoke(prompt)
        # Models frequently wrap the label in quotes/backticks, append a
        # period, or change its case; normalize before validating.
        intent = response.content.strip().strip("\"'`.").strip().lower()
        logging.info(f"LLM intent response: {intent}", extra=log_extra)

        # Validate intent
        if intent not in INTENT_MAPPING:
            logging.warning(f"Invalid intent '{intent}' detected, defaulting to 'unknown'", extra=log_extra)
            intent = "unknown"

    except Exception as e:
        logging.error(f"Intent classification failed: {str(e)}", extra=log_extra)
        intent = "unknown"

    # Map to human-readable description
    intent_description = INTENT_MAPPING.get(intent, INTENT_MAPPING["unknown"])
    logging.info(f"Intent classified: {intent_description}", extra=log_extra)

    return emotions, intent_description
app/rag_layer.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import streamlit as st
4
+ from langchain_openai import OpenAIEmbeddings
5
+ from langchain_chroma import Chroma
6
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader
7
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
8
+ from langchain.schema import Document
9
+
10
# Root logging configuration for this module (same log file as the rest of the app).
logging.basicConfig(
    filename="crisis_log.txt",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# Suppress ChromaDB telemetry + non-critical logs
logging.getLogger("chromadb").setLevel(logging.ERROR)

KB_PATH = "data/kb/"  # Directory scanned for knowledge-base files
CHROMA_PATH = "data/chroma_db"  # Persistent Chroma DB
# Embedding client is created at import time from the environment key.
embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
22
+
23
def load_documents():
    """Load and split knowledge-base files, with error handling.

    Scans ``KB_PATH`` for ``.pdf`` and ``.txt`` files (extension matched
    case-insensitively), loads them, and splits them into overlapping chunks.
    Files that fail to load are logged and skipped.

    Returns:
        A list of ``Document`` objects. When the directory is missing, a
        single unsplit fallback document is returned; when no file could be
        loaded, the fallback document is split like any other.
    """
    fallback_doc = Document(
        page_content="No knowledge base documents available.",
        metadata={"source": "fallback"},
    )

    if not os.path.exists(KB_PATH):
        logging.error(f"Knowledge base directory {KB_PATH} does not exist")
        st.warning(f"Knowledge base directory {KB_PATH} not found. Using fallback document.")
        return [fallback_doc]

    docs = []
    # Sorted so the chunk order does not depend on filesystem enumeration order.
    for file in sorted(os.listdir(KB_PATH)):
        file_path = os.path.join(KB_PATH, file)
        try:
            if not os.path.isfile(file_path):
                logging.warning(f"Skipping {file_path}: Not a file")
                continue
            suffix = os.path.splitext(file)[1].lower()
            if suffix == ".pdf":
                loader = PyPDFLoader(file_path)
                file_docs = loader.load()
                docs.extend(file_docs)
                logging.info(f"Loaded PDF: {file_path} with {len(file_docs)} pages")
            elif suffix == ".txt":
                # The knowledge base contains Arabic text; decode as UTF-8
                # explicitly rather than relying on the platform default.
                loader = TextLoader(file_path, encoding="utf-8")
                file_docs = loader.load()
                docs.extend(file_docs)
                logging.info(f"Loaded text file: {file_path} with {len(file_docs)} chunks")
        except Exception as e:
            logging.error(f"Error loading {file_path}: {str(e)}")
            st.warning(f"Failed to load {file_path}. Skipping.")

    if not docs:
        logging.warning("No documents loaded from knowledge base")
        st.warning("No valid documents found in knowledge base. Using fallback document.")
        docs = [fallback_doc]

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = splitter.split_documents(docs)
    logging.info(f"Split {len(docs)} documents into {len(split_docs)} chunks")
    return split_docs
60
+
61
def setup_vectorstore(force_rebuild=False):
    """Set up or load the Chroma vectorstore.

    Args:
        force_rebuild: When true, embed the knowledge base again even if
            ``CHROMA_PATH`` already exists.

    Returns:
        A ``Chroma`` vectorstore. On failure an in-memory placeholder store
        is returned so the app keeps running with limited functionality.
    """
    try:
        if force_rebuild or not os.path.exists(CHROMA_PATH):
            docs = load_documents()
            if not docs:
                raise ValueError("No documents available for vectorstore creation")
            # NOTE(review): with force_rebuild and an existing directory this
            # appears to add to the existing collection rather than replace it.
            vectorstore = Chroma.from_documents(docs, embeddings, persist_directory=CHROMA_PATH)
            # langchain_chroma writes to persist_directory automatically and has
            # no persist() method; only older wrappers expose one. Calling it
            # unconditionally raised and pushed every fresh build into the
            # fallback branch below.
            persist = getattr(vectorstore, "persist", None)
            if callable(persist):
                persist()
            logging.info(f"Created new vectorstore at {CHROMA_PATH} with {len(docs)} documents")
        else:
            vectorstore = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
            logging.info(f"Loaded existing vectorstore from {CHROMA_PATH}")
        return vectorstore
    except Exception as e:
        logging.error(f"Error setting up vectorstore: {str(e)}")
        st.error(f"Failed to initialize vectorstore: {str(e)}. App may have limited functionality.")
        # Return a dummy vectorstore to prevent app crash. It is deliberately
        # NOT persisted: writing it to CHROMA_PATH would make later startups
        # load the placeholder instead of rebuilding the real knowledge base.
        return Chroma.from_texts(
            texts=["No knowledge base available"],
            embedding=embeddings,
        )
84
+
85
def retrieve_context(query, k=3):
    """Retrieve the most relevant knowledge-base chunks for a query.

    Args:
        query: Text to search for.
        k: Number of chunks to request from the retriever.

    Returns:
        A list of ``Document`` objects. On any failure a single placeholder
        document with ``metadata["source"] == "error"`` is returned.
    """
    try:
        retriever = st.session_state.vectorstore.as_retriever(search_kwargs={"k": k})
        # invoke() is the Runnable entry point that replaces the deprecated
        # get_relevant_documents().
        return retriever.invoke(query)
    except Exception as e:
        logging.error(f"Error retrieving context: {str(e)}")
        st.warning(f"Failed to retrieve context: {str(e)}. Using fallback response.")
        return [Document(page_content="Unable to retrieve context.", metadata={"source": "error"})]
app/response_gen.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import TypedDict
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_anthropic import ChatAnthropic
6
+ from langchain.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
7
+ from langchain_core.output_parsers import StrOutputParser
8
+ from langgraph.graph import StateGraph, END
9
+ from langchain.schema import AIMessage, HumanMessage
10
+ from app.rag_layer import retrieve_context
11
+ from app.validation import validate_response
12
+ from app import config
13
+ import streamlit as st
14
+ from langsmith import Client
15
+
16
# LangSmith tracing is optional; it is only enabled when an API key is configured.
langsmith_client = Client() if config.LANGSMITH_API_KEY else None

# The few-shot file holds Arabic text, so decode it as UTF-8 explicitly instead
# of relying on the platform's default encoding.
with open("data/kb/few_shot_prompts.json", "r", encoding="utf-8") as f:
    few_shot_data = json.load(f)

# Keep only the fields consumed by the example prompt template.
few_shot_examples = [
    {"user": ex["user"], "context": ex["context"], "response": ex["response"]}
    for ex in few_shot_data["examples"]
]
25
+
26
# Shape of one few-shot exchange: the user's message plus its context,
# followed by the expected assistant reply.
example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{user}\nContext: {context}"),
        ("ai", "{response}"),
    ]
)

few_shot_prompt = FewShotChatMessagePromptTemplate(
    examples=few_shot_examples,
    example_prompt=example_prompt,
)

# System instructions; the placeholders are filled in by generate_with_failover().
system_prompt = """
You are an empathetic mental health chatbot specializing in Omani Arabic. Respond in Omani dialect Arabic.
Use retrieved context and best practices: {best_practices}.
Be supportive, suggest resources, but never diagnose or give medical advice.
If crisis_risk is 'High Risk', respond with: "ياخي، أشوفك محتاج دعم فوري. تواصل مع مستشفى المسارة على 2487 3268 أو اتصل 9999 للطوارئ."
If crisis_risk is 'Medium Risk', include a CBT de-escalation technique from context and suggest professional help.
User query: {query}
Retrieved context: {context}
Emotions: {emotions}
Intent: {intent}
Crisis Risk: {crisis_risk}
"""

# Full prompt: system instructions, few-shot examples, then the live user query.
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        few_shot_prompt,
        ("human", "{query}"),
    ]
)

# LLMs
gpt4 = ChatOpenAI(model="gpt-4o", api_key=config.OPENAI_API_KEY)
claude = ChatAnthropic(model="claude-3-opus-20240229", api_key=config.ANTHROPIC_API_KEY)
58
+
59
def _record_langsmith_run(run_name, query, crisis_risk, response):
    """Best-effort LangSmith trace; telemetry errors must never affect the reply."""
    if not langsmith_client:
        return
    try:
        langsmith_client.create_run(
            name=run_name,
            inputs={"query": query, "crisis_risk": crisis_risk},
            outputs={"response": response},
            run_type="chain"
        )
    except Exception as exc:
        import logging

        logging.getLogger(__name__).warning(
            "LangSmith logging failed for %s: %s",
            run_name,
            exc,
            extra={"session_id": st.session_state.get("session_id", "unknown")},
        )


def generate_with_failover(query, context, emotions, intent, crisis_risk):
    """Generate a reply with GPT-4o, failing over to Claude if that call fails.

    Args:
        query: The user's message.
        context: Retrieved knowledge-base text.
        emotions: Detected emotions, interpolated into the system prompt.
        intent: Detected intent description.
        crisis_risk: Risk label such as "High Risk".

    Returns:
        The generated response text.

    Raises:
        Exception: Whatever the Claude chain raises when the failover call
            also fails.
    """
    payload = {
        "query": query,
        "context": context,
        "emotions": emotions,
        "intent": intent,
        "crisis_risk": crisis_risk,
        "best_practices": json.dumps(few_shot_data["best_practices"])
    }

    chain = final_prompt | gpt4 | StrOutputParser()
    try:
        response = chain.invoke(payload)
        run_name = "response_generation"
    except Exception as primary_error:
        import logging

        logging.getLogger(__name__).warning(
            "Primary model failed, failing over to Claude: %s",
            primary_error,
            extra={"session_id": st.session_state.get("session_id", "unknown")},
        )
        # Failover to Claude
        chain = final_prompt | claude | StrOutputParser()
        response = chain.invoke(payload)
        run_name = "response_generation_failover"

    # Log to LangSmith outside the try block: a tracing error must neither
    # trigger the failover nor discard an answer that was already generated.
    _record_langsmith_run(run_name, query, crisis_risk, response)
    return response
98
+
99
+ # LangGraph workflow: retrieve -> generate -> validate
100
class AgentState(TypedDict):
    """State dictionary passed between the retrieve, generate and validate nodes."""

    # Inputs supplied by the caller of the workflow.
    query: str
    emotions: str
    intent: str
    crisis_risk: str
    # Values filled in by the graph nodes.
    context: str
    response: str
    validation_score: float
108
+
109
def retrieve(state: AgentState) -> AgentState:
    """Graph node: fetch knowledge-base passages for the query.

    Stores the passages' text, joined by newlines, under ``state["context"]``
    and returns the same (mutated) state object.
    """
    passages = [doc.page_content for doc in retrieve_context(state["query"])]
    state["context"] = "\n".join(passages)
    return state
113
+
114
def generate(state: AgentState) -> AgentState:
    """Graph node: produce the reply and store it under ``state["response"]``.

    Returns the same (mutated) state object.
    """
    reply = generate_with_failover(
        state["query"],
        state["context"],
        state["emotions"],
        state["intent"],
        state["crisis_risk"],
    )
    state["response"] = reply
    return state
119
+
120
# Minimum validator score a response needs before it is shown to the user.
_VALIDATION_THRESHOLD = 0.7

# Shown when no generated response passes validation. This is the same
# emergency-contact message the system prompt prescribes for high-risk users,
# so it is safe to deliver in any situation.
_SAFE_FALLBACK_RESPONSE = (
    "ياخي، أشوفك محتاج دعم فوري. تواصل مع مستشفى المسارة على 2487 3268 أو اتصل 9999 للطوارئ."
)


def validate(state: AgentState) -> AgentState:
    """Graph node: score the response and replace it when it fails validation.

    A response scoring below the threshold is regenerated once. If the second
    attempt also fails, a fixed emergency-contact message is used instead, so
    an internal placeholder is never shown (or spoken) to the user.

    Returns the same (mutated) state object with ``validation_score`` set to
    the score of the last response that was validated.
    """
    score, _feedback = validate_response(state["response"], state["query"])
    if score < _VALIDATION_THRESHOLD:
        retry_response = generate_with_failover(
            state["query"], state["context"], state["emotions"], state["intent"], state["crisis_risk"]
        )
        score, _feedback = validate_response(retry_response, state["query"])
        if score >= _VALIDATION_THRESHOLD:
            state["response"] = retry_response
        else:
            state["response"] = _SAFE_FALLBACK_RESPONSE
    state["validation_score"] = score
    return state
126
+
127
# Linear pipeline: retrieve -> generate -> validate -> END.
workflow = StateGraph(AgentState)

for _node_name, _node_fn in (
    ("retrieve", retrieve),
    ("generate", generate),
    ("validate", validate),
):
    workflow.add_node(_node_name, _node_fn)

for _source, _target in (
    ("retrieve", "generate"),
    ("generate", "validate"),
    ("validate", END),
):
    workflow.add_edge(_source, _target)

workflow.set_entry_point("retrieve")

# Compiled graph invoked by generate_response().
app = workflow.compile()
137
+
138
def generate_response(query, emotions, intent, crisis_risk):
    """Run the compiled workflow for one user turn and return the reply text.

    Args:
        query: The user's message.
        emotions: Detected emotions for the message.
        intent: Detected intent description.
        crisis_risk: Risk label for the message.

    Returns:
        The ``"response"`` entry of the workflow's final state.
    """
    final_state = app.invoke(
        dict(query=query, emotions=emotions, intent=intent, crisis_risk=crisis_risk)
    )
    return final_state["response"]
app/stt_pipeline.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import openai
3
+ from io import BytesIO
4
+ from streamlit_mic_recorder import mic_recorder
5
+ from app import config
6
+ import logging
7
+
8
# Only surface errors from the OpenAI client library.
logging.getLogger("openai").setLevel(logging.ERROR)

# Without an API key the client stays None and a message is shown in the UI.
client = openai.OpenAI(api_key=config.OPENAI_API_KEY) if config.OPENAI_API_KEY else None
if client is None:
    st.error("OpenAI API key not found in .env file. Please set OPENAI_API_KEY.")

# session state
for _state_key, _initial_value in (
    ("chat_history", []),
    ("audio_history", []),
    ("transcript", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
23
+
24
def transcribe_callback():
    """Transcribe the latest recording when the recorder widget stops.

    Reads the recording from ``st.session_state.recorder_output``, sends it to
    Whisper, and appends the transcript to the chat and audio histories. Does
    nothing when there is no recording. Errors are logged and shown in the UI.
    """
    audio_data = st.session_state.get("recorder_output")
    if not audio_data:
        return

    if not client:
        st.session_state.is_processing = False
        st.error("No OpenAI API key provided.")
        return

    session_id = st.session_state.get('session_id', 'unknown')
    log_extra = {"session_id": session_id}
    st.session_state.is_processing = True
    try:
        audio_file = BytesIO(audio_data['bytes'])
        # The API infers the audio format from the file name.
        audio_file.name = "audio.webm"

        transcript = client.audio.transcriptions.create(
            model="whisper-1",
            file=audio_file,
            response_format="text"
        )

        st.session_state.transcript = transcript
        st.session_state.chat_history.append(f"User: {transcript}")
        st.session_state.audio_history.append(("user", audio_data['bytes'], transcript))
        st.success(f"Transcription: {transcript}")
        logging.info(f"Session: {session_id} - Transcription completed: {transcript}", extra=log_extra)
        # No st.rerun() here: Streamlit reruns the script after a widget
        # callback on its own, and st.rerun() inside a callback is a no-op
        # that only emits a warning.
    except Exception as e:
        logging.error(f"Session: {session_id} - Transcription error: {str(e)}", extra=log_extra)
        st.error(f"Transcription error: {str(e)}")
    finally:
        # Always clear the flag so the recorder is rendered again.
        st.session_state.is_processing = False
57
+
58
def render_mic_recorder():
    """Render the mic recorder component.

    Returns:
        The value returned by ``mic_recorder``, or ``None`` while a previous
        recording is still being processed (only a spinner is shown then).
    """
    if st.session_state.get("is_processing", False):
        with st.spinner("Processing audio..."):
            pass
        # Explicit return so this path can never hit an unbound ``audio``.
        return None

    audio = mic_recorder(
        key="recorder",
        start_prompt="🎤 Start Recording",
        stop_prompt="⏹️ Stop & Transcribe",
        just_once=False,
        use_container_width=True,
        format="webm",
        callback=transcribe_callback
    )
    return audio
app/tts_pipeline.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import openai
3
+ from io import BytesIO
4
+ from app import config
5
+ import logging
6
+
7
# Only surface errors from the OpenAI client library.
logging.getLogger("openai").setLevel(logging.ERROR)

# Mirror stt_pipeline: only build the client when a key is configured, so a
# missing key does not make importing this module fail.
client = openai.OpenAI(api_key=config.OPENAI_API_KEY) if config.OPENAI_API_KEY else None
10
+
11
def text_to_speech(text):
    """Convert text to speech using OpenAI TTS and play it in the UI.

    On success the audio is rendered with ``st.audio`` and the reply is
    appended to ``chat_history`` / ``audio_history`` when those exist in the
    session state. Errors are logged and shown in the UI; nothing is returned.

    Args:
        text: The reply to synthesize. Empty or whitespace-only text is skipped.
    """
    session_id = st.session_state.get('session_id', 'unknown')
    log_extra = {"session_id": session_id}

    if client is None:
        st.session_state.is_processing = False
        logging.error(f"Session: {session_id} - TTS error: OpenAI client is not configured", extra=log_extra)
        st.error("TTS error: OpenAI client is not configured")
        return

    if not text or not text.strip():
        # Nothing to synthesize; avoid sending an empty request to the API.
        logging.warning(f"Session: {session_id} - TTS skipped: empty text", extra=log_extra)
        return

    st.session_state.is_processing = True
    try:
        response = client.audio.speech.create(
            model="tts-1",
            voice="nova",
            input=text
        )
        audio_bytes = BytesIO(response.content)
        st.audio(audio_bytes, format="audio/mp3")
        if "chat_history" in st.session_state:
            st.session_state.chat_history.append(f"Bot: {text}")
        if "audio_history" in st.session_state:
            st.session_state.audio_history.append(("bot", response.content, text))
        logging.info(f"Session: {session_id} - TTS completed for response: {text[:50]}...", extra=log_extra)
    except Exception as e:
        logging.error(f"Session: {session_id} - TTS error: {str(e)}", extra=log_extra)
        st.error(f"TTS error: {str(e)}")
    finally:
        # Always clear the flag, whichever path was taken.
        st.session_state.is_processing = False
app/ui.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from app.stt_pipeline import render_mic_recorder
3
+ from app.crisis_toolchain import get_session_id
4
+
5
def consent_banner():
    """Render the consent banner and keep the consent flag in session state.

    Returns:
        ``True`` when the consent checkbox is ticked, otherwise ``False``.
    """
    st.markdown("## 🛡️ Consent & Disclosure")
    disclosure = (
        "This AI chatbot is **not a substitute for professional therapy**.\n\n"
        "Your voice may be processed for analysis. "
        "Data will only be stored if you consent."
    )
    st.info(disclosure)

    if "consent" not in st.session_state:
        st.session_state.consent = False

    agreed = st.checkbox(
        "I consent to continue and agree that this is not a substitute for professional therapy.",
        value=st.session_state.consent
    )
    st.session_state.consent = agreed

    # With consent given, clear any recorder output left in the session.
    if agreed and "recorder_output" in st.session_state:
        st.session_state.recorder_output = None

    return st.session_state.consent
27
+
28
def emergency_resources():
    """Display the list of emergency and mental-health contacts."""
    # Each entry is rendered as its own markdown element, in this order.
    resource_lines = (
        "### Ministry of Health (MOH)",
        "- **General Adult Psychiatry:** +968 2487 3127",
        "- **Child & Adolescence Psychiatry:** +968 2487 3127",
        "- **Psychology Services:** +968 2487 3983",
        "### National Hospitals",
        "- **Al Masarra Hospital (Psychiatry & Addiction):** +968 2487 3268",
        "- **Royal Hospital – Mental Health Support (Mon–Fri, 8 AM–8 PM):** +968 24 607 555",
        "### Private Clinics",
        "- **KIMSHEALTH Oman Hospital (Darsait):** +968 2476 0100",
        "- **Oman International Hospital (Al Ghubrah):** +968 2490 3500 | WhatsApp: +968 9938 9376",
        "- **Muscat Private Hospital:** +968 2458 3600",
        "- **Aster Al Raffah Hospital:** +968 2249 6000",
        "- **Badr Al Samaa Hospital:** +968 2459 1000",
        "- **Burjeel Medical Center:** +968 24 399 777",
        "- **Hatat House Polyclinic:** +968 2456 3641 / 9943 1173",
        "### 🌐 Global Resource",
        "- **WHO – Mental Health:** https://www.who.int/health-topics/mental-health",
    )
    for line in resource_lines:
        st.markdown(line)
    st.info("🚨 Immediate emergency? Call Royal Oman Police / Ambulance: 9999.")
48
+
49
def audio_input():
    """Render the audio input UI and return the chat history.

    Initializes the session-state entries this view relies on, shows the
    session ID and, once consent has been given, renders the recorder and the
    conversation history with one audio player per message.

    Returns:
        ``st.session_state.chat_history`` (a list of "User: ..." / "Bot: ..."
        strings).
    """
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "audio_history" not in st.session_state:
        st.session_state.audio_history = []  # Store (type, audio_bytes, transcript) tuples
    if "is_processing" not in st.session_state:
        st.session_state.is_processing = False

    # session ID
    if "session_id" not in st.session_state:
        st.session_state.session_id = get_session_id()
    session_id = st.session_state.session_id
    st.markdown(f"**Session ID**: {session_id}")

    if st.session_state.consent:
        st.success("✅ Consent given. Ready to start.")
        st.markdown("Click on **START** to begin speaking.")

        with st.container():
            render_mic_recorder()

        # Display chat history + audio playback
        st.markdown("### Conversation History")
        # (entry marker, audio_history type, display label, audio MIME type)
        speakers = (
            ("User:", "user", "You", "audio/webm"),
            ("Bot:", "bot", "Bot", "audio/mp3"),
        )
        # Indices of clips already shown, so repeated identical messages each
        # get their own clip instead of every matching clip being repeated.
        shown_clips = set()
        for entry in st.session_state.chat_history:
            for marker, audio_kind, label, mime in speakers:
                if not entry.startswith(marker):
                    continue
                # Remove only the leading marker; the same text inside the
                # message itself must be left untouched.
                prefix = marker + " "
                transcript = entry[len(prefix):] if entry.startswith(prefix) else entry
                st.markdown(f"**{label}**: {transcript}")

                for idx, (audio_type, audio_bytes, audio_transcript) in enumerate(
                    st.session_state.audio_history
                ):
                    if idx in shown_clips:
                        continue
                    if audio_type == audio_kind and audio_transcript == transcript:
                        shown_clips.add(idx)
                        st.audio(audio_bytes, format=mime)
                        break
                break

    return st.session_state.chat_history
90
+
91
def display_latest_audio():
    """Do nothing; kept as a placeholder because audio is shown in the chat history."""
    return None
app/validation.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import ChatOpenAI
2
+ from langchain.prompts import ChatPromptTemplate
3
+ from langchain_core.output_parsers import JsonOutputParser
4
+ from app import config
5
+
6
# LLM used to grade generated replies.
validator = ChatOpenAI(model="gpt-4o", api_key=config.OPENAI_API_KEY)

# Grading rubric; the doubled braces keep the JSON example literal in the template.
validation_prompt = ChatPromptTemplate.from_template(
    """
Validate this mental health response for safety and quality.
Criteria:
- Empathy: 0-1 (shows understanding?)
- Accuracy: 0-1 (facts correct, no diagnosis?)
- Non-harmful: 0-1 (no encouragement of harm?)
- Cultural fit: 0-1 (Omani/Arabic sensitive?)
Overall score: average of above.
Query: {query}
Response: {response}
Output JSON: {{"score": float, "feedback": str}}
"""
)

# Prompt -> model -> parsed JSON.
chain = validation_prompt | validator | JsonOutputParser()
22
+
23
def validate_response(response, query):
    """Grade a generated reply with the validator chain.

    Args:
        response: The generated reply to grade.
        query: The user message the reply answers.

    Returns:
        ``(score, feedback)`` where ``score`` is a float clamped to
        [0.0, 1.0] and ``feedback`` is a string. A missing or malformed score
        is treated as 0.0 so that an unusable verdict counts as a failed
        validation.
    """
    result = chain.invoke({"query": query, "response": response})
    if not isinstance(result, dict):
        result = {}

    try:
        score = float(result.get("score", 0.0))
    except (TypeError, ValueError):
        score = 0.0
    score = max(0.0, min(score, 1.0))

    feedback = result.get("feedback", "")
    return score, "" if feedback is None else str(feedback)