import re

import gradio as gr
import torch
from transformers import PreTrainedTokenizerFast, T5ForConditionalGeneration

# KoT5 summarization model
MODEL_NAME = "psyche/KoT5-summarization"
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Apply dynamic quantization for faster CPU inference; fall back silently if
# the current PyTorch build does not support it.
try:
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
except Exception:
    pass

model.eval()


# ===== Utilities =====
def normalize_text(text: str) -> str:
    return re.sub(r"\s+", " ", text).strip()


def split_into_sentences(text: str):
    text = text.replace("\n", " ")
    parts = re.split(r"(?<=[\.!?])\s+", text)
    return [p.strip() for p in parts if p.strip()]


def token_length(s: str) -> int:
    return len(tokenizer.encode(s, add_special_tokens=False))


def chunk_by_tokens(sentences, max_tokens=900):
    """Greedily pack sentences into chunks of at most max_tokens tokens."""
    chunks, cur, cur_tokens = [], [], 0
    for s in sentences:
        tl = token_length(s)
        if tl > max_tokens:
            # Flush the buffer first so sentences accumulated so far are not lost.
            if cur:
                chunks.append(" ".join(cur))
                cur, cur_tokens = [], 0
            # Split an oversized sentence into roughly max_tokens-sized pieces.
            piece_size = max(200, int(len(s) * (max_tokens / tl)))
            for i in range(0, len(s), piece_size):
                sub = s[i:i + piece_size]
                if sub.strip():
                    chunks.append(sub.strip())
            continue
        if cur_tokens + tl <= max_tokens:
            cur.append(s)
            cur_tokens += tl
        else:
            if cur:
                chunks.append(" ".join(cur))
            cur, cur_tokens = [s], tl
    if cur:
        chunks.append(" ".join(cur))
    return chunks


# ===== Repetition removal =====
def derepeat(text: str) -> str:
    text = re.sub(r'(.)\1{2,}', r'\1\1', text)           # 3+ repeats of a character -> 2
    text = re.sub(r'(\b\w+\b)(\s+\1){1,}', r'\1', text)  # drop consecutive duplicate words
    text = re.sub(r'([\.!?\-~])\1{2,}', r'\1\1', text)   # collapse repeated punctuation
    return text.strip()


# ===== Summarization =====
def approx_tokens_from_chars(n_chars: int) -> int:
    return max(1, n_chars // 2)  # rough heuristic for Korean: 1 token ≈ 2 characters


def summarize_raw_t5(input_text: str, target_chars: int, input_tokens: int) -> str:
    # Cap the target length so short inputs cannot request overly long outputs.
    safe_target_chars = min(target_chars, max(120, int(len(input_text) * 0.9)))
    max_new = max(40, min(approx_tokens_from_chars(safe_target_chars), 300))
    if input_tokens <= 200:
        max_new = min(max_new, max(40, int(input_tokens * 0.6)))
    if input_tokens <= 60:
        max_new = min(max_new, 60)
    input_ids = tokenizer.encode(
        input_text, return_tensors="pt", truncation=True, max_length=1024
    )
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids,
            max_new_tokens=max_new,
            do_sample=True,
            top_p=0.92,
            temperature=0.7,
            num_beams=1,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=4,
            repetition_penalty=1.2,
            renormalize_logits=True,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def apply_style_prompt_t5(text: str, mode: str, final: bool = False) -> str:
    # Korean prompt tags: "concise summary:", "explanatory summary:", "bullet summary:".
    if mode == "concise":
        tag = "간결 요약:"
    elif mode == "explanatory":
        tag = "설명 요약:"
    else:
        tag = "불릿 요약:"
    guide = ""
    if final:
        # "(Keep the original document order and remove duplicates.)"
        guide = " (원래 문서의 순서를 유지하고 중복을 제거하세요.)"
    return f"{tag}{guide}\n{text}"


def postprocess_strict(summary: str, mode: str) -> str:
    s = summary.strip()
    s = re.sub(r"\s+", " ", s)
    s = derepeat(s)
    # Drop exact duplicate sentences while preserving order.
    seen, outs = set(), []
    for sent in re.split(r"(?<=[\.!?])\s+", s):
        ss = sent.strip()
        if ss and ss not in seen:
            seen.add(ss)
            outs.append(ss)
    s = " ".join(outs)
    if mode == "bullets":
        s = "\n".join(f"- {p}" for p in outs[:12])
    return s


def summarize_long(text: str, target_chars: int, mode: str) -> str:
    text = normalize_text(text)
    if not text:
        return "⚠️ 요약할 텍스트를 입력하세요."  # "Please enter text to summarize."
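
    # Routing overview: inputs of <= 60 tokens get a single tightly capped pass,
    # inputs of <= 1000 tokens get a single pass at the requested length, and
    # anything longer goes through a map-reduce pipeline: chunk by token budget,
    # summarize each chunk, merge and de-repeat the partials, then summarize the
    # merged text once more with an order-preserving final prompt.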
    approx_tokens = token_length(text)

    if approx_tokens <= 60:
        prompt = apply_style_prompt_t5(text, mode, final=False)
        out = summarize_raw_t5(prompt, min(target_chars, 300), approx_tokens)
        return postprocess_strict(out, mode)

    if approx_tokens <= 1000:
        prompt = apply_style_prompt_t5(text, mode, final=False)
        out = summarize_raw_t5(prompt, target_chars, approx_tokens)
        return postprocess_strict(out, mode)

    sentences = split_into_sentences(text)
    chunks = chunk_by_tokens(sentences, max_tokens=900)
    partial_summaries = []
    # Give each chunk a proportional share of the budget, with 20% headroom.
    per_chunk_chars = max(180, int(target_chars * 1.2 / max(1, len(chunks))))
    for c in chunks:
        prompt = apply_style_prompt_t5(c, mode, final=False)
        psum = summarize_raw_t5(prompt, per_chunk_chars, token_length(c))
        partial_summaries.append(psum)
    merged = normalize_text(" ".join(partial_summaries))
    merged = derepeat(merged)
    final_prompt = apply_style_prompt_t5(merged, mode, final=True)
    final = summarize_raw_t5(final_prompt, target_chars, token_length(merged))
    return postprocess_strict(final, mode)


# ===== Gradio UI =====
def ui_summarize(text, target_len, style):
    # Map the Korean radio labels ("concise", "explanatory", "key bullets") to modes.
    mode = {"간결형": "concise", "설명형": "explanatory", "핵심 bullet": "bullets"}[style]
    return summarize_long(text, int(target_len), mode)


with gr.Blocks() as demo:
    # Title: "KoT5 Korean Summarizer (repetition suppression + order preservation)"
    gr.Markdown("## 📝 KoT5 한국어 요약기 (반복 억제 + 순서 보존)")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="원문 입력", lines=16)  # source text
            style = gr.Radio(
                ["간결형", "설명형", "핵심 bullet"],
                value="간결형",
                label="요약 스타일",  # summary style
            )
            target_len = gr.Slider(
                300, 1500, value=1000, step=50,
                label="목표 요약 길이(문자)",  # target summary length (characters)
            )
            btn = gr.Button("요약 실행")  # run summarization
        with gr.Column():
            output_text = gr.Textbox(label="요약 결과", lines=16)  # summary output
    btn.click(ui_summarize, inputs=[input_text, target_len, style], outputs=output_text)

if __name__ == "__main__":
    demo.launch()
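
# --- Usage sketch (illustrative, not part of the app) ---
# Assuming the model weights download successfully, summarize_long() can also be
# called directly, bypassing the Gradio UI, e.g. from another script:
#
#   from app import summarize_long  # "app" is a hypothetical module name
#   text = open("article.txt", encoding="utf-8").read()
#   print(summarize_long(text, target_chars=500, mode="concise"))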