Spaces:
Sleeping
Sleeping
saved_models
#1
by
kkandull
- opened
- app.py +0 -193
- e_text_best_model.bin +0 -3
- p_text_best_model.bin +0 -3
app.py
DELETED
|
@@ -1,193 +0,0 @@
|
|
| 1 |
-
# ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ํฌํธ
|
| 2 |
-
import os
|
| 3 |
-
import pandas as pd # pandas๋ ํ์ฌ ์ฝ๋์์๋ ์ง์ ์ฌ์ฉ๋์ง ์์ง๋ง, ๋ฐ์ดํฐ ์ฒ๋ฆฌ ๊ด๋ จ ์ ํธ๋ฆฌํฐ๋ก ๋จ๊ฒจ๋ ์ ์์ต๋๋ค.
|
| 4 |
-
import numpy as np
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
from torch.utils.data import Dataset, DataLoader # DataLoader์ Dataset์ ์ถ๋ก ์ ์ง์ ์ฌ์ฉ๋์ง ์์ง๋ง, ๋ชจ๋ธ ์ ์์ ํ์ํ ์ ์์ด ๋จ๊ฒจ๋
|
| 9 |
-
from transformers import LongformerForSequenceClassification, AutoTokenizer
|
| 10 |
-
import gradio as gr
|
| 11 |
-
|
| 12 |
-
# =======================================================
|
| 13 |
-
# 1. ์ ์ญ ์ค์ ๋ฐ ์์ ์ ์
|
| 14 |
-
# =======================================================
|
| 15 |
-
MODEL_NAME = 'kiddothe2b/longformer-mini-1024' # HuggingFace ๋ชจ๋ธ ์ด๋ฆ
|
| 16 |
-
MAX_LEN = 1024 # ๋ชจ๋ธ ์
๋ ฅ ์ต๋ ๊ธธ์ด
|
| 17 |
-
|
| 18 |
-
# GPU ์ฌ์ฉ ๊ฐ๋ฅ ์ฌ๋ถ ํ์ธ ๋ฐ ๋๋ฐ์ด์ค ์ค์
|
| 19 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 20 |
-
print(f"Using device: {device}")
|
| 21 |
-
|
| 22 |
-
# ํ ํฌ๋์ด์ ๋ก๋ (์ถ๋ก ์ ํ์)
|
| 23 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 24 |
-
|
| 25 |
-
# =======================================================
|
| 26 |
-
# 2. PyTorch ๋ฐ์ดํฐ์
์ ์ (ํ์ต ์ ์ฌ์ฉ๋์๋ ํด๋์ค. ์ถ๋ก ์ ์ง์ ๋ฐ์ดํฐ ๋ก๋๋ฅผ ๋ง๋ค์ง๋ ์์)
|
| 27 |
-
# =======================================================
|
| 28 |
-
# ์ด ํด๋์ค๋ ๋ชจ๋ธ์ด ํ์ต๋ ๋ ์ฌ์ฉ๋์๋ ๋ฐ์ดํฐ ๊ตฌ์กฐ๋ฅผ ์ ์ํฉ๋๋ค.
|
| 29 |
-
# ์ถ๋ก ์์๋ ๋จ์ผ ํ
์คํธ ์
๋ ฅ์ด ๋ค์ด์ค๋ฏ๋ก ์ง์ DataLoader๋ฅผ ๋ง๋ค ํ์๋ ์์ต๋๋ค.
|
| 30 |
-
# ํ์ง๋ง ๋ชจ๋ธ์ด ๊ธฐ๋ํ๋ ์
๋ ฅ ํํ๋ฅผ ๋ง์ถ๊ธฐ ์ํด encoding ๊ณผ์ ์ด ์ฌ์ฉ๋ฉ๋๋ค.
|
| 31 |
-
class DepressionDataset(Dataset):
|
| 32 |
-
def __init__(self, texts, labels, tokenizer, max_len):
|
| 33 |
-
self.texts = texts
|
| 34 |
-
self.labels = labels
|
| 35 |
-
self.tokenizer = tokenizer
|
| 36 |
-
self.max_len = max_len
|
| 37 |
-
|
| 38 |
-
def __len__(self):
|
| 39 |
-
return len(self.texts)
|
| 40 |
-
|
| 41 |
-
def __getitem__(self, item):
|
| 42 |
-
text = str(self.texts[item])
|
| 43 |
-
label = self.labels[item]
|
| 44 |
-
encoding = self.tokenizer.encode_plus(
|
| 45 |
-
text,
|
| 46 |
-
add_special_tokens=True,
|
| 47 |
-
max_length=self.max_len,
|
| 48 |
-
return_token_type_ids=False,
|
| 49 |
-
padding='max_length',
|
| 50 |
-
truncation=True,
|
| 51 |
-
return_attention_mask=True,
|
| 52 |
-
return_tensors='pt',
|
| 53 |
-
)
|
| 54 |
-
return {
|
| 55 |
-
'input_ids': encoding['input_ids'].flatten(),
|
| 56 |
-
'attention_mask': encoding['attention_mask'].flatten(),
|
| 57 |
-
'labels': torch.tensor(label, dtype=torch.long)
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
# =======================================================
|
| 61 |
-
# 3. ๋ชจ๋ธ ๋ก๋ฉ (ํ์ต๋ ๊ฐ์ค์น๋ฅผ ๋ก๋)
|
| 62 |
-
# =======================================================
|
| 63 |
-
print("\n--- Loading models for inference ---")
|
| 64 |
-
|
| 65 |
-
# ๋ชจ๋ธ ํ์ผ ๊ฒฝ๋ก (saved_models ํด๋๊ฐ ์์ผ๋ฏ๋ก ๋ฃจํธ ๋๋ ํ ๋ฆฌ์ ์๋ค๊ณ ๊ฐ์ )
|
| 66 |
-
# ์ด์ ์ ์๋ save_dir ๋ณ์๋ ์ด์ ํ์ ์์ต๋๋ค.
|
| 67 |
-
p_model_path = 'p_text_best_model.bin' # ํ์ผ๋ช
์ด ๋ฃจํธ์ ๋ฐ๋ก ์๋ค๊ณ ๊ฐ์
|
| 68 |
-
e_model_path = 'e_text_best_model.bin' # ํ์ผ๋ช
์ด ๋ฃจํธ์ ๋ฐ๋ก ์๋ค๊ณ ๊ฐ์
|
| 69 |
-
|
| 70 |
-
# ๋ชจ๋ธ ๋ก๋ฉ ๋ฐ ํ๊ฐ ๋ชจ๋ ์ค์
|
| 71 |
-
p_model_for_inference = None
|
| 72 |
-
e_model_for_inference = None
|
| 73 |
-
|
| 74 |
-
try:
|
| 75 |
-
# ์ฐธ๊ฐ์ ๋ฐํ ๋ชจ๋ธ (P-model) ๋ก๋
|
| 76 |
-
if os.path.exists(p_model_path):
|
| 77 |
-
p_model_for_inference = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
|
| 78 |
-
p_model_for_inference.load_state_dict(torch.load(p_model_path, map_location=device))
|
| 79 |
-
p_model_for_inference.to(device)
|
| 80 |
-
p_model_for_inference.eval() # ํ๊ฐ ๋ชจ๋ ์ค์
|
| 81 |
-
print(f"P-model loaded successfully from {p_model_path}")
|
| 82 |
-
else:
|
| 83 |
-
print(f"Warning: P-model file not found at {p_model_path}. Please ensure it's uploaded to the root directory.")
|
| 84 |
-
|
| 85 |
-
# ์๋ฆฌ ๋ฐํ ๋ชจ๋ธ (E-model) ๋ก๋
|
| 86 |
-
if os.path.exists(e_model_path):
|
| 87 |
-
e_model_for_inference = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
|
| 88 |
-
e_model_for_inference.load_state_dict(torch.load(e_model_path, map_location=device))
|
| 89 |
-
e_model_for_inference.to(device)
|
| 90 |
-
e_model_for_inference.eval() # ํ๊ฐ ๋ชจ๋ ์ค์
|
| 91 |
-
print(f"E-model loaded successfully from {e_model_path}")
|
| 92 |
-
else:
|
| 93 |
-
print(f"Warning: E-model file not found at {e_model_path}. Please ensure it's uploaded to the root directory.")
|
| 94 |
-
|
| 95 |
-
except Exception as e:
|
| 96 |
-
print(f"Error loading models: {e}")
|
| 97 |
-
# ๋ชจ๋ธ ๋ก๋ฉ ์คํจ ์, UI๊ฐ ์คํ๋์ง ์๋๋ก ์ค์
|
| 98 |
-
p_model_for_inference = None
|
| 99 |
-
e_model_for_inference = None
|
| 100 |
-
|
| 101 |
-
# =======================================================
|
| 102 |
-
# 4. Gradio ์์ธก ํจ์ ์ ์
|
| 103 |
-
# =======================================================
|
| 104 |
-
def predict_depression(participant_text, ellie_text):
|
| 105 |
-
# ๋ชจ๋ธ์ด ์ ๋๋ก ๋ก๋๋์๋์ง ํ์ธ
|
| 106 |
-
if p_model_for_inference is None or e_model_for_inference is None:
|
| 107 |
-
return "**์ค๋ฅ:** ๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค. ๊ด๏ฟฝ๏ฟฝ์์๊ฒ ๋ฌธ์ํ๊ฑฐ๋ ๋ชจ๋ธ ํ์ผ ์
๋ก๋ ์ฌ๋ถ๋ฅผ ํ์ธํด์ฃผ์ธ์."
|
| 108 |
-
|
| 109 |
-
# ์๋ฆฌ ๋ฐํ ์ ์ฒ๋ฆฌ (ํ์ต ์์ ๋์ผํ ๋ก์ง ์ ์ฉ)
|
| 110 |
-
e_text_words = ellie_text.split()
|
| 111 |
-
if len(e_text_words) > 0:
|
| 112 |
-
ellie_text_processed = " ".join(e_text_words[len(e_text_words) // 2:])
|
| 113 |
-
else:
|
| 114 |
-
ellie_text_processed = ""
|
| 115 |
-
|
| 116 |
-
# P-model ์์ธก
|
| 117 |
-
p_encoding = tokenizer.encode_plus(
|
| 118 |
-
participant_text,
|
| 119 |
-
add_special_tokens=True,
|
| 120 |
-
max_length=MAX_LEN,
|
| 121 |
-
return_token_type_ids=False,
|
| 122 |
-
padding='max_length',
|
| 123 |
-
truncation=True,
|
| 124 |
-
return_attention_mask=True,
|
| 125 |
-
return_tensors='pt',
|
| 126 |
-
)
|
| 127 |
-
p_input_ids = p_encoding['input_ids'].to(device)
|
| 128 |
-
p_attention_mask = p_encoding['attention_mask'].to(device)
|
| 129 |
-
|
| 130 |
-
with torch.no_grad(): # ์ถ๋ก ์์๋ ๊ทธ๋ผ๋์ธํธ ๊ณ์ฐ ๋ถํ์
|
| 131 |
-
p_outputs = p_model_for_inference(input_ids=p_input_ids, attention_mask=p_attention_mask)
|
| 132 |
-
p_probs = F.softmax(p_outputs.logits, dim=1).cpu().numpy().flatten()
|
| 133 |
-
p_pred_label = np.argmax(p_probs)
|
| 134 |
-
|
| 135 |
-
# E-model ์์ธก
|
| 136 |
-
e_encoding = tokenizer.encode_plus(
|
| 137 |
-
ellie_text_processed,
|
| 138 |
-
add_special_tokens=True,
|
| 139 |
-
max_length=MAX_LEN,
|
| 140 |
-
return_token_type_ids=False,
|
| 141 |
-
padding='max_length',
|
| 142 |
-
truncation=True,
|
| 143 |
-
return_attention_mask=True,
|
| 144 |
-
return_tensors='pt',
|
| 145 |
-
)
|
| 146 |
-
e_input_ids = e_encoding['input_ids'].to(device)
|
| 147 |
-
e_attention_mask = e_encoding['attention_mask'].to(device)
|
| 148 |
-
|
| 149 |
-
with torch.no_grad(): # ์ถ๋ก ์์๋ ๊ทธ๋ผ๋์ธํธ ๊ณ์ฐ ๋ถํ์
|
| 150 |
-
e_outputs = e_model_for_inference(input_ids=e_input_ids, attention_mask=e_attention_mask)
|
| 151 |
-
e_probs = F.softmax(e_outputs.logits, dim=1).cpu().numpy().flatten()
|
| 152 |
-
e_pred_label = np.argmax(e_probs)
|
| 153 |
-
|
| 154 |
-
# ์์๋ธ (OR ์ ๋ต): ๋ ์ค ํ๋๋ผ๋ ์ฐ์ธ์ฆ(1)์ผ๋ก ์์ธกํ๋ฉด ์ฐ์ธ์ฆ์ผ๋ก ๊ฐ์ฃผ
|
| 155 |
-
ensemble_pred_label = 1 if p_pred_label == 1 or e_pred_label == 1 else 0
|
| 156 |
-
|
| 157 |
-
labels = ['Control (๋น์ฐ์ธ)', 'Depressed (์ฐ์ธ)']
|
| 158 |
-
ensemble_result = labels[ensemble_pred_label]
|
| 159 |
-
p_model_result = labels[p_pred_label]
|
| 160 |
-
e_model_result = labels[e_pred_label]
|
| 161 |
-
|
| 162 |
-
return (f"**์ต์ข
์์๋ธ ์์ธก (OR ์ ๋ต): {ensemble_result}**\n\n"
|
| 163 |
-
f" - ์ฐธ๊ฐ์ ๋ชจ๋ธ (P-longBERT) ์์ธก: {p_model_result} (ํ๋ฅ : Control={p_probs[0]:.2f}, Depressed={p_probs[1]:.2f})\n"
|
| 164 |
-
f" - ์๋ฆฌ ๋ชจ๋ธ (E-longBERT) ์์ธก: {e_model_result} (ํ๋ฅ : Control={e_probs[0]:.2f}, Depressed={e_probs[1]:.2f})\n\n"
|
| 165 |
-
f"**์ฐธ๊ณ :**\n"
|
| 166 |
-
f"- ์์ธก์ ๊ฐ ๋ํ ๋ด์ฉ์๋ง ๊ธฐ๋ฐํ๋ฉฐ, ์ค์ ์ง๋จ์ ์ ๋ฌธ๊ฐ์ ์๋ดํด์ผ ํฉ๋๋ค.\n"
|
| 167 |
-
f"- GPU ํ๊ฒฝ์์๋ ์์ธก์ด ๋น ๋ฅด๊ฒ ์ํ๋ฉ๋๋ค."
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
# =======================================================
|
| 171 |
-
# 5. Gradio UI ์ธํฐํ์ด์ค ์์ฑ ๋ฐ ์คํ
|
| 172 |
-
# =======================================================
|
| 173 |
-
print("\n--- Setting up Gradio UI ---")
|
| 174 |
-
|
| 175 |
-
# ๋ชจ๋ธ์ด ์ฑ๊ณต์ ์ผ๋ก ๋ก๋๋์์ ๊ฒฝ์ฐ์๋ง Gradio UI๋ฅผ ์คํ
|
| 176 |
-
if p_model_for_inference is not None and e_model_for_inference is not None:
|
| 177 |
-
gr.Interface(
|
| 178 |
-
fn=predict_depression,
|
| 179 |
-
inputs=[
|
| 180 |
-
gr.Textbox(lines=10, label="์ฐธ๊ฐ์ ๋ฐํ ๋ด์ฉ (Participant's speech)", placeholder="์ฌ๊ธฐ์ ์ฐธ๊ฐ์์ ๋ฐํ ๋ด์ฉ์ ์
๋ ฅํ์ธ์..."),
|
| 181 |
-
gr.Textbox(lines=10, label="์๋ฆฌ ๋ฐํ ๋ด์ฉ (Ellie's speech)", placeholder="์ฌ๊ธฐ์ ์๋ฆฌ(๊ฐ์ ์์ด์ ํธ)์ ๋ฐํ ๋ด์ฉ์ ์
๋ ฅํ์ธ์... (์ ์ฒด ๋ด์ฉ ์ค ํ๋ฐ๋ถ๋ง ์ฌ์ฉ๋จ)")
|
| 182 |
-
],
|
| 183 |
-
outputs="markdown",
|
| 184 |
-
title="DAIC-WOZ ์ฐ์ธ์ฆ ๊ฐ์ง ์์๋ธ ๋ชจ๋ธ (GPU ๊ฐ์)",
|
| 185 |
-
description=f"""์ด ์ฑ์ DAIC-WOZ ๋ฐ์ดํฐ์
์ ๊ธฐ๋ฐ์ผ๋ก ์ฐธ๊ฐ์์ ๊ฐ์ ์์ด์ ํธ(์๋ฆฌ)์ ๋ํ ๋ด์ฉ์ ๋ถ์ํ์ฌ ์ฐ์ธ์ฆ ์ฌ๋ถ๋ฅผ ์์ธกํฉ๋๋ค.
|
| 186 |
-
P-longBERT (์ฐธ๊ฐ์ ๋ฐํ)์ E-longBERT (์๋ฆฌ ๋ฐํ) ๋ชจ๋ธ์ ์์๋ธ (OR ์ ๋ต) ๊ฒฐ๊ณผ๋ฅผ ์ ๊ณตํฉ๋๋ค.
|
| 187 |
-
**GPU ํ๊ฒฝ์์๋ ์์ธก์ด ๋น ๋ฅด๊ฒ ์ํ๋ฉ๋๋ค.**
|
| 188 |
-
**์ฐธ๊ณ :** ์ด๋ AI ๋ชจ๋ธ์ ์์ธก์ผ ๋ฟ์ด๋ฉฐ, **์ค์ ์ํ์ ์ง๋จ์ ๋ฐ๋์ ์ ๋ฌธ๊ฐ์ ์๋ดํด์ผ ํฉ๋๋ค.**
|
| 189 |
-
์ฌ์ฉ ์ค์ธ ๋๋ฐ์ด์ค: {device}
|
| 190 |
-
"""
|
| 191 |
-
).launch() # Hugging Face Spaces์์๋ share=True๊ฐ ํ์ ์์
|
| 192 |
-
else:
|
| 193 |
-
print("\nGradio UI could not be launched because models failed to load. Please check model files.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
e_text_best_model.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:accfd3aab9348aed05f32753969b48f9eae028f8f2139e1ec26d0c746cb0f0e4
|
| 3 |
-
size 56324610
|
|
|
|
|
|
|
|
|
|
|
|
p_text_best_model.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:a29208f799be73233fecbde5a3103c8df2de952d53dd071e4edc2a7eb73cdefe
|
| 3 |
-
size 56324610
|
|
|
|
|
|
|
|
|
|
|
|