# moro_mini_llm / app.py
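"""KoT5 Korean summarization demo (Gradio).

Pipeline: normalize the input, split it into sentences, pack the sentences
into token-bounded chunks, summarize each chunk with KoT5, then merge the
partial summaries and run one final pass with repetition cleanup and
order-preserving deduplication.
"""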
import re
import gradio as gr
import torch
from transformers import PreTrainedTokenizerFast, T5ForConditionalGeneration
# KoT5 summarization model
MODEL_NAME = "psyche/KoT5-summarization"
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
# Apply dynamic quantization on CPU; fall back to the fp32 model if the
# quantization backend is unavailable.
try:
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
except Exception:
    pass
model.eval()
# ===== Utilities =====
def normalize_text(text: str) -> str:
    return re.sub(r"\s+", " ", text).strip()

def split_into_sentences(text: str):
    text = text.replace("\n", " ")
    parts = re.split(r"(?<=[\.!?])\s+", text)
    return [p.strip() for p in parts if p.strip()]

def token_length(s: str) -> int:
    return len(tokenizer.encode(s, add_special_tokens=False))
def chunk_by_tokens(sentences, max_tokens=900):
    chunks, cur, cur_tokens = [], [], 0
    for s in sentences:
        tl = token_length(s)
        if tl > max_tokens:
            # Flush the chunk in progress before force-splitting an
            # oversized sentence by character count.
            if cur:
                chunks.append(" ".join(cur))
            piece_size = max(200, int(len(s) * (max_tokens / tl)))
            for i in range(0, len(s), piece_size):
                sub = s[i:i + piece_size]
                if sub.strip():
                    chunks.append(sub.strip())
            cur, cur_tokens = [], 0
            continue
        if cur_tokens + tl <= max_tokens:
            cur.append(s)
            cur_tokens += tl
        else:
            if cur:
                chunks.append(" ".join(cur))
            cur, cur_tokens = [s], tl
    if cur:
        chunks.append(" ".join(cur))
    return chunks
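# Usage sketch (illustrative only; exact chunk boundaries depend on the
# KoT5 tokenizer's token counts):
#   sentences = split_into_sentences(long_text)
#   chunks = chunk_by_tokens(sentences, max_tokens=900)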
# ===== Repetition removal =====
def derepeat(text: str) -> str:
    text = re.sub(r'(.)\1{2,}', r'\1\1', text)           # same char 3+ times -> 2
    text = re.sub(r'(\b\w+\b)(\s+\1){1,}', r'\1', text)  # collapse repeated words
    text = re.sub(r'([\.!?\-~])\1{2,}', r'\1\1', text)   # shrink repeated punctuation
    return text.strip()
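# Example, worked by hand against the three patterns above:
#   derepeat("좋다 좋다 좋다!!!!") -> "좋다!!"
# ("!!!!" collapses to "!!", then the repeated word collapses to one copy.)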
# ===== Summarization =====
def approx_tokens_from_chars(n_chars: int) -> int:
    return max(1, n_chars // 2)  # for Korean, roughly 1 token ≈ 2 characters

def summarize_raw_t5(input_text: str, target_chars: int, input_tokens: int) -> str:
    # Never request a summary longer than ~90% of the input.
    safe_target_chars = min(target_chars, max(120, int(len(input_text) * 0.9)))
    max_new = max(40, min(approx_tokens_from_chars(safe_target_chars), 300))
    # Tighten the budget further for short inputs.
    if input_tokens <= 200:
        max_new = min(max_new, max(40, int(input_tokens * 0.6)))
    if input_tokens <= 60:
        max_new = min(max_new, 60)
    input_ids = tokenizer.encode(input_text, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids,
            max_new_tokens=max_new,
            do_sample=True,
            top_p=0.92,
            temperature=0.7,
            num_beams=1,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=4,
            repetition_penalty=1.2,
            renormalize_logits=True,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
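# Length-budget example (arithmetic follows the heuristics above):
# target_chars=1000 on a long input keeps safe_target_chars=1000, so
# approx_tokens_from_chars(1000) = 500, which the min(..., 300) cap trims
# to max_new_tokens = 300.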
def apply_style_prompt_t5(text: str, mode: str, final: bool = False) -> str:
    # Korean task prefixes for KoT5: "concise summary:", "explanatory
    # summary:", "bullet summary:".
    if mode == "concise":
        tag = "간결 요약:"
    elif mode == "explanatory":
        tag = "설명 요약:"
    else:
        tag = "불릿 요약:"
    guide = ""
    if final:
        # "Keep the original document order and remove duplicates."
        guide = " (원래 문서의 순서를 유지하고 중복을 제거하세요.)"
    return f"{tag}{guide}\n{text}"

def postprocess_strict(summary: str, mode: str) -> str:
    s = summary.strip()
    s = re.sub(r"\s+", " ", s)
    s = derepeat(s)
    # Drop exact duplicate sentences while preserving their order.
    seen, outs = set(), []
    for sent in re.split(r"(?<=[\.!?])\s+", s):
        ss = sent.strip()
        if ss and ss not in seen:
            seen.add(ss)
            outs.append(ss)
    s = " ".join(outs)
    if mode == "bullets":
        s = "\n".join(f"- {p}" for p in outs[:12])  # cap at 12 bullets
    return s
def summarize_long(text: str, target_chars: int, mode: str):
    text = normalize_text(text)
    if not text:
        return "⚠️ 요약할 텍스트를 입력하세요."  # "Please enter text to summarize."
    approx_tokens = token_length(text)
    # Very short input: summarize directly with a tight length cap.
    if approx_tokens <= 60:
        prompt = apply_style_prompt_t5(text, mode, final=False)
        out = summarize_raw_t5(prompt, min(target_chars, 300), approx_tokens)
        return postprocess_strict(out, mode)
    # Input fits in one model window: single-pass summarization.
    if approx_tokens <= 1000:
        prompt = apply_style_prompt_t5(text, mode, final=False)
        out = summarize_raw_t5(prompt, target_chars, approx_tokens)
        return postprocess_strict(out, mode)
    # Long input: summarize token-bounded chunks, then summarize the merged
    # partial summaries in a final pass.
    sentences = split_into_sentences(text)
    chunks = chunk_by_tokens(sentences, max_tokens=900)
    partial_summaries = []
    per_chunk_chars = max(180, int(target_chars * 1.2 / max(1, len(chunks))))
    for c in chunks:
        prompt = apply_style_prompt_t5(c, mode, final=False)
        psum = summarize_raw_t5(prompt, per_chunk_chars, token_length(c))
        partial_summaries.append(psum)
    merged = normalize_text(" ".join(partial_summaries))
    merged = derepeat(merged)
    final_prompt = apply_style_prompt_t5(merged, mode, final=True)
    final = summarize_raw_t5(final_prompt, target_chars, token_length(merged))
    return postprocess_strict(final, mode)
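# Chunk-budget example: target_chars=1000 split across 5 chunks gives
# per_chunk_chars = max(180, int(1000 * 1.2 / 5)) = 240.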
# ===== Gradio UI =====
def ui_summarize(text, target_len, style):
    # Map the Korean UI labels (concise / explanatory / key bullets) to
    # internal mode names.
    mode = {"간결형": "concise", "설명형": "explanatory", "핵심 bullet": "bullets"}[style]
    return summarize_long(text, int(target_len), mode)

with gr.Blocks() as demo:
    # "KoT5 Korean summarizer (repetition suppression + order preservation)"
    gr.Markdown("## 📝 KoT5 한국어 요약기 (반복 억제 + 순서 보존)")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="원문 입력", lines=16)  # source text
            style = gr.Radio(["간결형", "설명형", "핵심 bullet"], value="간결형", label="요약 스타일")  # summary style
            target_len = gr.Slider(300, 1500, value=1000, step=50, label="목표 요약 길이(문자)")  # target length (chars)
            btn = gr.Button("요약 실행")  # run summarization
        with gr.Column():
            output_text = gr.Textbox(label="요약 결과", lines=16)  # summary result
    btn.click(ui_summarize, inputs=[input_text, target_len, style], outputs=output_text)

if __name__ == "__main__":
    demo.launch()
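# Quick local check without the UI (a sketch; `sample_text` is a
# placeholder, not part of the original app):
#   from app import summarize_long
#   print(summarize_long(sample_text, target_chars=500, mode="concise"))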