Spaces:
Sleeping
Sleeping
app.py
Browse files
app.py
CHANGED
|
@@ -54,67 +54,105 @@ def chunk_by_tokens(sentences, max_tokens=900):
|
|
| 54 |
chunks.append(" ".join(cur))
|
| 55 |
return chunks
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# ===== ์์ฝ =====
|
| 58 |
-
def
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
with torch.no_grad():
|
| 61 |
summary_ids = model.generate(
|
| 62 |
input_ids,
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
early_stopping=True
|
| 67 |
)
|
| 68 |
return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
| 69 |
|
| 70 |
-
def
|
| 71 |
if mode == "concise":
|
| 72 |
-
|
| 73 |
elif mode == "explanatory":
|
| 74 |
-
|
| 75 |
else:
|
| 76 |
-
|
|
|
|
| 77 |
if final:
|
| 78 |
-
|
| 79 |
-
return f"{
|
| 80 |
|
| 81 |
-
def
|
| 82 |
s = summary.strip()
|
| 83 |
s = re.sub(r"\s+", " ", s)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
if mode == "bullets":
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
if len(bullets) > 1:
|
| 88 |
-
s = "\n".join([f"- {b}" for b in bullets])
|
| 89 |
-
else:
|
| 90 |
-
parts = re.split(r"(?<=[\.!?])\s+", s)
|
| 91 |
-
parts = [p.strip() for p in parts if p.strip()]
|
| 92 |
-
s = "\n".join([f"- {p}" for p in parts])
|
| 93 |
return s
|
| 94 |
|
| 95 |
def summarize_long(text: str, target_chars: int, mode: str):
|
| 96 |
text = normalize_text(text)
|
| 97 |
if not text:
|
| 98 |
return "โ ๏ธ ์์ฝํ ํ
์คํธ๋ฅผ ์
๋ ฅํ์ธ์."
|
|
|
|
| 99 |
approx_tokens = token_length(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
if approx_tokens <= 1000:
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
return
|
|
|
|
| 104 |
sentences = split_into_sentences(text)
|
| 105 |
chunks = chunk_by_tokens(sentences, max_tokens=900)
|
|
|
|
| 106 |
partial_summaries = []
|
| 107 |
-
|
| 108 |
-
per_chunk_chars = max(250, budget_total // max(1, len(chunks)))
|
| 109 |
for c in chunks:
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
psum = summarize_raw(apply_style_prompt(c, mode), min_len, max_len)
|
| 113 |
partial_summaries.append(psum)
|
|
|
|
| 114 |
merged = normalize_text(" ".join(partial_summaries))
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
| 118 |
|
| 119 |
# ===== Gradio UI =====
|
| 120 |
def ui_summarize(text, target_len, style):
|
|
@@ -122,7 +160,7 @@ def ui_summarize(text, target_len, style):
|
|
| 122 |
return summarize_long(text, int(target_len), mode)
|
| 123 |
|
| 124 |
with gr.Blocks() as demo:
|
| 125 |
-
gr.Markdown("## ๐ KoT5 ํ๊ตญ์ด ์์ฝ๊ธฐ (
|
| 126 |
with gr.Row():
|
| 127 |
with gr.Column():
|
| 128 |
input_text = gr.Textbox(label="์๋ฌธ ์
๋ ฅ", lines=16)
|
|
|
|
| 54 |
chunks.append(" ".join(cur))
|
| 55 |
return chunks
|
| 56 |
|
| 57 |
+
# ===== ๋ฐ๋ณต ์ ๊ฑฐ =====
|
| 58 |
+
def derpeat(text: str) -> str:
|
| 59 |
+
text = re.sub(r'(.)\1{2,}', r'\1\1', text) # ๋จ์ผ ๋ฌธ์ 3ํ ์ด์ ๋ฐ๋ณต โ 2ํ
|
| 60 |
+
text = re.sub(r'(\b\w+\b)(\s+\1){1,}', r'\1', text) # ๋จ์ด ๋ฐ๋ณต ์ ๊ฑฐ
|
| 61 |
+
text = re.sub(r'([\.!?\-~])\1{2,}', r'\1\1', text) # ๊ตฌ๋์ ๋ฐ๋ณต ์ถ์
|
| 62 |
+
return text.strip()
|
| 63 |
+
|
| 64 |
# ===== ์์ฝ =====
|
| 65 |
+
def approx_tokens_from_chars(n_chars: int) -> int:
|
| 66 |
+
return max(1, n_chars // 2) # ํ๊ธ ๋๋ต 1ํ ํฐ โ 2๋ฌธ์
|
| 67 |
+
|
| 68 |
+
def summarize_raw_t5(input_text: str, target_chars: int, input_tokens: int) -> str:
|
| 69 |
+
safe_target_chars = min(target_chars, max(120, int(len(input_text) * 0.9)))
|
| 70 |
+
max_new = max(40, min(approx_tokens_from_chars(safe_target_chars), 300))
|
| 71 |
+
|
| 72 |
+
if input_tokens <= 200:
|
| 73 |
+
max_new = min(max_new, max(40, int(input_tokens * 0.6)))
|
| 74 |
+
if input_tokens <= 60:
|
| 75 |
+
max_new = min(max_new, 60)
|
| 76 |
+
|
| 77 |
+
input_ids = tokenizer.encode(input_text, return_tensors="pt", truncation=True, max_length=1024)
|
| 78 |
+
|
| 79 |
with torch.no_grad():
|
| 80 |
summary_ids = model.generate(
|
| 81 |
input_ids,
|
| 82 |
+
max_new_tokens=max_new,
|
| 83 |
+
do_sample=True,
|
| 84 |
+
top_p=0.92,
|
| 85 |
+
temperature=0.7,
|
| 86 |
+
num_beams=1,
|
| 87 |
+
no_repeat_ngram_size=4,
|
| 88 |
+
encoder_no_repeat_ngram_size=4,
|
| 89 |
+
repetition_penalty=1.2,
|
| 90 |
+
renormalize_logits=True,
|
| 91 |
early_stopping=True
|
| 92 |
)
|
| 93 |
return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
| 94 |
|
| 95 |
+
def apply_style_prompt_t5(text: str, mode: str, final: bool=False) -> str:
|
| 96 |
if mode == "concise":
|
| 97 |
+
tag = "๊ฐ๊ฒฐ ์์ฝ:"
|
| 98 |
elif mode == "explanatory":
|
| 99 |
+
tag = "์ค๋ช
์์ฝ:"
|
| 100 |
else:
|
| 101 |
+
tag = "๋ถ๋ฆฟ ์์ฝ:"
|
| 102 |
+
guide = ""
|
| 103 |
if final:
|
| 104 |
+
guide = " (์๋ ๋ฌธ์์ ์์๋ฅผ ์ ์งํ๊ณ ์ค๋ณต์ ์ ๊ฑฐํ์ธ์.)"
|
| 105 |
+
return f"{tag}{guide}\n{text}"
|
| 106 |
|
| 107 |
+
def postprocess_strict(summary: str, mode: str) -> str:
|
| 108 |
s = summary.strip()
|
| 109 |
s = re.sub(r"\s+", " ", s)
|
| 110 |
+
s = derpeat(s)
|
| 111 |
+
seen, outs = set(), []
|
| 112 |
+
for sent in re.split(r"(?<=[\.!?])\s+", s):
|
| 113 |
+
ss = sent.strip()
|
| 114 |
+
if ss and ss not in seen:
|
| 115 |
+
seen.add(ss)
|
| 116 |
+
outs.append(ss)
|
| 117 |
+
s = " ".join(outs)
|
| 118 |
if mode == "bullets":
|
| 119 |
+
parts = [p for p in outs if p]
|
| 120 |
+
s = "\n".join([f"- {p}" for p in parts[:12]])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
return s
|
| 122 |
|
| 123 |
def summarize_long(text: str, target_chars: int, mode: str):
|
| 124 |
text = normalize_text(text)
|
| 125 |
if not text:
|
| 126 |
return "โ ๏ธ ์์ฝํ ํ
์คํธ๋ฅผ ์
๋ ฅํ์ธ์."
|
| 127 |
+
|
| 128 |
approx_tokens = token_length(text)
|
| 129 |
+
|
| 130 |
+
if approx_tokens <= 60:
|
| 131 |
+
prompt = apply_style_prompt_t5(text, mode, final=False)
|
| 132 |
+
out = summarize_raw_t5(prompt, min(target_chars, 300), approx_tokens)
|
| 133 |
+
return postprocess_strict(out, mode)
|
| 134 |
+
|
| 135 |
if approx_tokens <= 1000:
|
| 136 |
+
prompt = apply_style_prompt_t5(text, mode, final=False)
|
| 137 |
+
out = summarize_raw_t5(prompt, target_chars, approx_tokens)
|
| 138 |
+
return postprocess_strict(out, mode)
|
| 139 |
+
|
| 140 |
sentences = split_into_sentences(text)
|
| 141 |
chunks = chunk_by_tokens(sentences, max_tokens=900)
|
| 142 |
+
|
| 143 |
partial_summaries = []
|
| 144 |
+
per_chunk_chars = max(180, int(target_chars * 1.2 / max(1, len(chunks))))
|
|
|
|
| 145 |
for c in chunks:
|
| 146 |
+
prompt = apply_style_prompt_t5(c, mode, final=False)
|
| 147 |
+
psum = summarize_raw_t5(prompt, per_chunk_chars, token_length(c))
|
|
|
|
| 148 |
partial_summaries.append(psum)
|
| 149 |
+
|
| 150 |
merged = normalize_text(" ".join(partial_summaries))
|
| 151 |
+
merged = derpeat(merged)
|
| 152 |
+
|
| 153 |
+
final_prompt = apply_style_prompt_t5(merged, mode, final=True)
|
| 154 |
+
final = summarize_raw_t5(final_prompt, target_chars, token_length(merged))
|
| 155 |
+
return postprocess_strict(final, mode)
|
| 156 |
|
| 157 |
# ===== Gradio UI =====
|
| 158 |
def ui_summarize(text, target_len, style):
|
|
|
|
| 160 |
return summarize_long(text, int(target_len), mode)
|
| 161 |
|
| 162 |
with gr.Blocks() as demo:
|
| 163 |
+
gr.Markdown("## ๐ KoT5 ํ๊ตญ์ด ์์ฝ๊ธฐ (๋ฐ๋ณต ์ต์ + ์์ ๋ณด์กด)")
|
| 164 |
with gr.Row():
|
| 165 |
with gr.Column():
|
| 166 |
input_text = gr.Textbox(label="์๋ฌธ ์
๋ ฅ", lines=16)
|