# (Residue from the Hugging Face Spaces page this file was copied from:
#  "Spaces: Sleeping".)
import logging
import re

import gradio as gr
import torch
from transformers import PreTrainedTokenizerFast, T5ForConditionalGeneration
# KoT5 summarization model (original comment: "KoT5 요약 모델").
MODEL_NAME = "psyche/KoT5-summarization"
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Apply dynamic int8 quantization to the nn.Linear layers for CPU inference.
# This is deliberately best-effort: if quantization fails (e.g. unsupported
# backend/build), keep the float model instead of failing at import time.
# Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still
# propagate, and log the reason instead of hiding it.
try:
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
except Exception as exc:
    logging.getLogger(__name__).warning(
        "Dynamic quantization failed; falling back to the float model: %s", exc
    )
model.eval()
# ===== Utilities (유틸) =====
def normalize_text(text: str) -> str:
    """Return *text* with every whitespace run collapsed to one space, trimmed."""
    words = text.split()
    return " ".join(words)
def split_into_sentences(text: str):
    """Split *text* into sentences at whitespace that follows '.', '!' or '?'.

    Newlines are first turned into spaces. Each sentence is stripped and empty
    fragments are dropped.
    """
    flattened = text.replace("\n", " ")
    sentences = []
    for fragment in re.split(r"(?<=[\.!?])\s+", flattened):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
def token_length(s: str) -> int:
    """Return the number of tokenizer tokens *s* encodes to (no special tokens)."""
    token_ids = tokenizer.encode(s, add_special_tokens=False)
    return len(token_ids)
def chunk_by_tokens(sentences, max_tokens=900, *, length_fn=None):
    """Greedily pack sentences, in order, into chunks of at most ~max_tokens.

    Sentences in a chunk are joined with single spaces. A single sentence
    longer than ``max_tokens`` is emitted as its own run of character-sliced
    pieces. Each piece is sized from that sentence's chars-per-token ratio,
    with a 200-character floor, so a piece may still slightly exceed
    ``max_tokens`` when token density is uneven.

    Args:
        sentences: Iterable of sentence strings.
        max_tokens: Token budget per chunk.
        length_fn: Callable returning the token length of a string; defaults
            to ``token_length`` (the module tokenizer).

    Returns:
        List of chunk strings, preserving the input order.
    """
    if length_fn is None:
        length_fn = token_length
    chunks, cur, cur_tokens = [], [], 0
    for s in sentences:
        tl = length_fn(s)
        if tl > max_tokens:
            # Flush the sentences accumulated so far *before* splitting the
            # oversized one; previously this buffer was silently discarded.
            if cur:
                chunks.append(" ".join(cur))
            cur, cur_tokens = [], 0
            piece_size = max(200, int(len(s) * (max_tokens / tl)))
            for i in range(0, len(s), piece_size):
                sub = s[i:i + piece_size].strip()
                if sub:
                    chunks.append(sub)
            continue
        if cur_tokens + tl <= max_tokens:
            cur.append(s)
            cur_tokens += tl
        else:
            if cur:
                chunks.append(" ".join(cur))
            cur, cur_tokens = [s], tl
    if cur:
        chunks.append(" ".join(cur))
    return chunks
# ===== Repetition removal (반복 제거) =====
def derpeat(text: str) -> str:
    """Suppress degenerate repetition in generated text.

    Rules, applied in order:

    1. Any non-digit, non-newline character repeated three or more times in a
       row is cut down to two (``"!!!!"`` -> ``"!!"``, ``"ㅋㅋㅋㅋ"`` -> ``"ㅋㅋ"``).
       Digits are excluded so numbers such as ``"1000"`` survive intact. This
       rule also covers repeated punctuation, which previously had a separate
       (unreachable) rule.
    2. A whole word immediately repeated (separated only by whitespace) is
       collapsed to one occurrence (``"the the cat"`` -> ``"the cat"``). The
       repeat must end on a word boundary, so ``"a apple"`` is left untouched.
    3. Leading and trailing whitespace is stripped.
    """
    # [^\d\n] keeps the original `.` semantics (no newline) minus digits.
    text = re.sub(r"([^\d\n])\1{2,}", r"\1\1", text)
    # Trailing \b stops a word from matching the prefix of the next word.
    text = re.sub(r"\b(\w+)(?:\s+\1\b)+", r"\1", text)
    return text.strip()
# ===== Summarization (요약) =====
def approx_tokens_from_chars(n_chars: int) -> int:
    """Roughly convert a character count to a token count, never below 1.

    Heuristic from the original comment: for Korean text, about two
    characters make up one token.
    """
    estimate = n_chars // 2
    if estimate < 1:
        return 1
    return estimate
def summarize_raw_t5(input_text: str, target_chars: int, input_tokens: int) -> str:
    """Run one sampled T5 generation pass over *input_text* and decode it.

    Args:
        input_text: Full encoder input (style tag/prompt already prepended).
        target_chars: Desired output length in characters. It is converted to
            a new-token budget and clamped as described below.
        input_tokens: Token length of the source text. It is used to shrink
            the budget for short inputs, so the output does not outgrow them.

    Returns:
        The decoded generation with special tokens removed.
    """
    # Never request more than ~90% of the input's character length (floor 120).
    safe_target_chars = min(target_chars, max(120, int(len(input_text) * 0.9)))
    # chars -> tokens heuristic, clamped to [40, 300] new tokens.
    max_new = max(40, min(approx_tokens_from_chars(safe_target_chars), 300))
    if input_tokens <= 200:
        # Short inputs: cap at ~60% of the input length (floor 40). This also
        # keeps inputs of <= 60 tokens at <= 40 new tokens, which made the
        # former extra "<= 60 -> cap 60" rule unreachable; it was removed.
        max_new = min(max_new, max(40, int(input_tokens * 0.6)))
    # Longer inputs are truncated to 1024 tokens here; callers chunk first.
    input_ids = tokenizer.encode(
        input_text, return_tensors="pt", truncation=True, max_length=1024
    )
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids,
            max_new_tokens=max_new,
            do_sample=True,
            top_p=0.92,
            temperature=0.7,
            num_beams=1,
            no_repeat_ngram_size=4,
            # NOTE(review): this bans every 4-gram that appears in the encoder
            # input (prompt included) from the output, forcing paraphrase;
            # confirm that is intended for summarization.
            encoder_no_repeat_ngram_size=4,
            repetition_penalty=1.2,
            renormalize_logits=True,
            # `early_stopping` was dropped: it only applies to beam search and
            # is a no-op (and a validation warning) with num_beams=1.
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
def apply_style_prompt_t5(text: str, mode: str, final: bool=False) -> str:
    """Prefix *text* with a style tag line (plus a guidance note when final).

    ``"concise"`` and ``"explanatory"`` pick their own tags; any other mode
    falls back to the bullet-style tag. With ``final=True``, a note asking
    the model to keep the document order and remove duplicates is appended
    to the tag. The tag line and *text* are separated by one newline.
    """
    # NOTE(review): these Korean literals appear mis-encoded (mojibake) in
    # this copy of the file; verify them against the original UTF-8 source.
    style_tags = {
        "concise": "๊ฐ๊ฒฐ ์์ฝ:",
        "explanatory": "์ค๋ช ์์ฝ:",
    }
    header = style_tags.get(mode, "๋ถ๋ฆฟ ์์ฝ:")
    if final:
        header += " (์๋ ๋ฌธ์์ ์์๋ฅผ ์ ์งํ๊ณ ์ค๋ณต์ ์ ๊ฑฐํ์ธ์.)"
    return f"{header}\n{text}"
def postprocess_strict(summary: str, mode: str) -> str:
    """Clean a raw model summary: whitespace, repetition, duplicate sentences.

    The summary is trimmed, whitespace is collapsed, and ``derpeat`` is
    applied. The result is then split into sentences (after '.', '!' or '?'),
    keeping only the first occurrence of each. In ``"bullets"`` mode the
    output is at most 12 lines of the form ``"- <sentence>"``. Otherwise the
    unique sentences are joined with single spaces.
    """
    cleaned = derpeat(re.sub(r"\s+", " ", summary.strip()))
    stripped = (part.strip() for part in re.split(r"(?<=[\.!?])\s+", cleaned))
    # dict.fromkeys keeps first-seen order while dropping exact duplicates.
    unique = list(dict.fromkeys(part for part in stripped if part))
    if mode == "bullets":
        return "\n".join(f"- {item}" for item in unique[:12])
    return " ".join(unique)
def _map_summarize(text: str, target_chars: int, mode: str) -> str:
    """Summarize *text* chunk by chunk and return the merged, de-repeated result."""
    chunks = chunk_by_tokens(split_into_sentences(text), max_tokens=900)
    # Spread ~120% of the target length over the chunks (floor 180 chars each).
    per_chunk_chars = max(180, int(target_chars * 1.2 / max(1, len(chunks))))
    partial_summaries = [
        summarize_raw_t5(
            apply_style_prompt_t5(c, mode, final=False),
            per_chunk_chars,
            token_length(c),
        )
        for c in chunks
    ]
    merged = normalize_text(" ".join(partial_summaries))
    return derpeat(merged)


def summarize_long(text: str, target_chars: int, mode: str):
    """Summarize arbitrarily long *text* to about *target_chars* characters.

    Inputs of <= 1000 tokens get a single generation pass. Longer inputs are
    chunked and each chunk is summarized. This map step repeats (at most 4
    rounds) while the merged partial summaries are still over 1000 tokens.
    A final "keep order, remove duplicates" pass then produces the result.

    Args:
        text: Source text; whitespace is normalized first.
        target_chars: Target summary length in characters.
        mode: ``"concise"``, ``"explanatory"`` or ``"bullets"``.

    Returns:
        The post-processed summary, or a warning message for empty input.
    """
    text = normalize_text(text)
    if not text:
        return "โ ๏ธ ์์ฝํ ํ ์คํธ๋ฅผ ์ ๋ ฅํ์ธ์."
    approx_tokens = token_length(text)
    if approx_tokens <= 1000:
        # Fits in one encoder pass; very short inputs get a tighter length cap.
        budget = min(target_chars, 300) if approx_tokens <= 60 else target_chars
        prompt = apply_style_prompt_t5(text, mode, final=False)
        out = summarize_raw_t5(prompt, budget, approx_tokens)
        return postprocess_strict(out, mode)

    # Repeat the map step while the merged partials are too long. Without
    # this, summarize_raw_t5 would silently truncate them at 1024 tokens and
    # drop the later sections of long documents from the final summary.
    max_rounds = 4
    merged, merged_tokens, rounds = text, approx_tokens, 0
    while merged_tokens > 1000 and rounds < max_rounds:
        merged = _map_summarize(merged, target_chars, mode)
        merged_tokens = token_length(merged)
        rounds += 1

    final_prompt = apply_style_prompt_t5(merged, mode, final=True)
    final = summarize_raw_t5(final_prompt, target_chars, merged_tokens)
    return postprocess_strict(final, mode)
# ===== Gradio UI =====
def ui_summarize(text, target_len, style):
    """Gradio callback: map the radio label to a mode and run ``summarize_long``.

    Raises ``KeyError`` if *style* is not one of the three known labels.
    """
    label_to_mode = {
        "๊ฐ๊ฒฐํ": "concise",
        "์ค๋ช ํ": "explanatory",
        "ํต์ฌ bullet": "bullets",
    }
    chosen_mode = label_to_mode[style]
    length_chars = int(target_len)
    return summarize_long(text, length_chars, chosen_mode)
# Gradio UI: one row with two columns. The first column holds the source
# text, style radio, length slider and run button; the second holds the
# output box.
# NOTE(review): the Korean UI strings appear mis-encoded (mojibake) in this
# copy of the file; verify them against the original UTF-8 source.
with gr.Blocks() as demo:
    gr.Markdown("## ๐ KoT5 ํ๊ตญ์ด ์์ฝ๊ธฐ (๋ฐ๋ณต ์ต์ + ์์ ๋ณด์กด)")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="์๋ฌธ ์ ๋ ฅ", lines=16)
            # These choices must match the keys of the label->mode mapping in ui_summarize.
            style = gr.Radio(["๊ฐ๊ฒฐํ", "์ค๋ช ํ", "ํต์ฌ bullet"], value="๊ฐ๊ฒฐํ", label="์์ฝ ์คํ์ผ")
            # Target summary length in characters: 300-1500 in steps of 50, default 1000.
            target_len = gr.Slider(300, 1500, value=1000, step=50, label="๋ชฉํ ์์ฝ ๊ธธ์ด(๋ฌธ์)")
            btn = gr.Button("์์ฝ ์คํ")
        with gr.Column():
            output_text = gr.Textbox(label="์์ฝ ๊ฒฐ๊ณผ", lines=16)
    # Component values are passed in this order to ui_summarize(text, target_len, style).
    btn.click(ui_summarize, inputs=[input_text, target_len, style], outputs=output_text)
if __name__ == "__main__":
    demo.launch()