# Streamlit chatbot application backed by a Hugging Face inference endpoint.
import os

import streamlit as st
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
# --- Constants ---

# Hugging Face model ID used for inference.
MODEL_ID = "google/gemma-3n-e4b"

# Instruction-style prompt template. The placeholders (system_message,
# chat_history, user_text) are filled in by the LangChain PromptTemplate.
PROMPT_TEMPLATE = """
[INST] {system_message}
ํ์ฌ ๋ํ:
{chat_history}
์ฌ์ฉ์: {user_text}
[/INST]
AI:
"""
# --- LLM and chain setup functions ---
def get_llm(max_new_tokens=128, temperature=0.1):
    """Build and return a Hugging Face inference-endpoint language model.

    Args:
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature; lower values yield more
            deterministic answers.

    Returns:
        HuggingFaceEndpoint: The configured language-model client.
    """
    # NOTE(review): the API token is read from the HF_TOKEN environment
    # variable (loaded via dotenv in main); confirm the installed
    # langchain-huggingface version accepts a `token` keyword argument.
    endpoint_config = {
        "repo_id": MODEL_ID,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "token": os.getenv("HF_TOKEN"),
    }
    return HuggingFaceEndpoint(**endpoint_config)
def get_chain(llm):
    """Compose a conversational chain around the given language model.

    Args:
        llm (HuggingFaceEndpoint): The language model to use.

    Returns:
        RunnableSequence: Runnable LCEL pipeline: prompt -> llm -> string parser.
    """
    template = PromptTemplate.from_template(PROMPT_TEMPLATE)
    parser = StrOutputParser()
    return template | llm | parser
def generate_response(chain, system_message, chat_history, user_text):
    """Invoke the LLM chain and return the assistant's reply.

    Args:
        chain (RunnableSequence): The LLM chain used to produce the reply.
        system_message (str): System message defining the AI's role.
        chat_history (list[dict]): Prior messages, each with "role"/"content".
        user_text (str): The user's current input message.

    Returns:
        str: Generated reply text after the final "AI:" marker, stripped.
    """
    # Flatten the structured history into "role: content" lines.
    history_lines = []
    for message in chat_history:
        history_lines.append(f"{message['role']}: {message['content']}")

    payload = {
        "system_message": system_message,
        "chat_history": "\n".join(history_lines),
        "user_text": user_text,
    }
    raw_output = chain.invoke(payload)
    # Keep only the text that follows the last "AI:" marker in the output.
    return raw_output.split("AI:")[-1].strip()
# --- UI rendering functions ---
def initialize_session_state():
    """Initialize Streamlit session state.

    Sets default values the first time the session starts, and seeds the
    chat history with the starter message when it is empty.
    """
    default_entries = (
        ("avatars", {"user": "๐ค", "assistant": "๐ค"}),
        ("chat_history", []),
        ("max_response_length", 256),
        ("system_message", "๋น์ ์ ์ธ๊ฐ ์ฌ์ฉ์์ ๋ํํ๋ ์น์ ํ AI์ ๋๋ค."),
        ("starter_message", "์๋ ํ์ธ์! ์ค๋ ๋ฌด์์ ๋์๋๋ฆด๊น์?"),
    )
    for key, default in default_entries:
        if key not in st.session_state:
            st.session_state[key] = default

    # An empty history means a brand-new conversation: seed it with the
    # configurable starter message from the assistant.
    if not st.session_state.chat_history:
        st.session_state.chat_history = [
            {"role": "assistant", "content": st.session_state.starter_message}
        ]
def setup_sidebar():
    """Render the sidebar controls.

    Lets the user adjust the system message, the AI's first message, the
    maximum response length, the avatars, and reset the chat history.
    """
    with st.sidebar:
        st.header("์์คํ ์ค์ ")

        st.session_state.system_message = st.text_area(
            "์์คํ ๋ฉ์์ง", value=st.session_state.system_message
        )
        st.session_state.starter_message = st.text_area(
            "์ฒซ ๋ฒ์งธ AI ๋ฉ์์ง", value=st.session_state.starter_message
        )
        st.session_state.max_response_length = st.number_input(
            "์ต๋ ์๋ต ๊ธธ์ด", value=st.session_state.max_response_length
        )

        st.markdown("*์๋ฐํ ์ ํ:*")
        left_col, right_col = st.columns(2)
        with left_col:
            st.session_state.avatars["assistant"] = st.selectbox(
                "AI ์๋ฐํ", options=["๐ค", "๐ฌ", "๐ค"], index=0
            )
        with right_col:
            st.session_state.avatars["user"] = st.selectbox(
                "์ฌ์ฉ์ ์๋ฐํ", options=["๐ค", "๐ฑโโ๏ธ", "๐จ๐พ", "๐ฉ", "๐ง๐พ"], index=0
            )

        # Reset the conversation to just the starter message and re-render.
        if st.button("์ฑํ ๊ธฐ๋ก ์ด๊ธฐํ"):
            st.session_state.chat_history = [
                {"role": "assistant", "content": st.session_state.starter_message}
            ]
            st.rerun()
def display_chat_history():
    """Walk the stored chat history and render each message on screen."""
    for entry in st.session_state.chat_history:
        role = entry["role"]
        # System messages configure the model; they are never shown.
        if role == "system":
            continue
        avatar_icon = st.session_state.avatars.get(role)
        with st.chat_message(role, avatar=avatar_icon):
            st.markdown(entry["content"])
# --- Main application entry point ---
def main():
    """Run the Streamlit chatbot application."""
    # Load HF_TOKEN (and any other settings) from a local .env file.
    load_dotenv()

    st.set_page_config(page_title="HuggingFace ChatBot", page_icon="๐ค")
    st.title("๊ฐ์ธ HuggingFace ์ฑ๋ด")
    st.markdown(
        f"*์ด๊ฒ์ HuggingFace transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ํ ์คํธ ์ ๋ ฅ์ ๋ํ ์๋ต์ ์์ฑํ๋ ๊ฐ๋จํ ์ฑ๋ด์ ๋๋ค. {MODEL_ID} ๋ชจ๋ธ์ ์ฌ์ฉํฉ๋๋ค.*"
    )

    initialize_session_state()
    setup_sidebar()

    # Show the conversation so far.
    display_chat_history()

    # Handle a new user message, if any.
    user_text = st.chat_input("์ฌ๊ธฐ์ ํ ์คํธ๋ฅผ ์ ๋ ฅํ์ธ์.")
    if user_text:
        # Record and echo the user's message.
        st.session_state.chat_history.append({"role": "user", "content": user_text})
        with st.chat_message("user", avatar=st.session_state.avatars["user"]):
            st.markdown(user_text)

        # Produce and display the assistant's reply.
        with st.chat_message("assistant", avatar=st.session_state.avatars["assistant"]):
            with st.spinner("์๊ฐ ์ค..."):
                llm = get_llm(max_new_tokens=st.session_state.max_response_length)
                chain = get_chain(llm)
                reply = generate_response(
                    chain,
                    st.session_state.system_message,
                    st.session_state.chat_history,
                    user_text,
                )
                st.session_state.chat_history.append(
                    {"role": "assistant", "content": reply}
                )
                st.markdown(reply)
if __name__ == "__main__":
    main()