orgoflu committed on
Commit
0b37765
·
verified ·
1 Parent(s): ab4dfe6
Files changed (1) hide show
  1. app.py +70 -32
app.py CHANGED
@@ -54,67 +54,105 @@ def chunk_by_tokens(sentences, max_tokens=900):
54
  chunks.append(" ".join(cur))
55
  return chunks
56
 
 
 
 
 
 
 
 
57
  # ===== ์š”์•ฝ =====
58
def summarize_raw(text: str, min_len: int, max_len: int) -> str:
    """Summarize *text* with the module-level KoT5 model using beam search.

    The prompt is encoded with truncation to the encoder's 1024-token
    limit, generation runs with 4 beams inside a no-grad context, and
    the best sequence is decoded without special tokens.
    """
    encoded = tokenizer.encode(
        text, return_tensors="pt", truncation=True, max_length=1024
    )
    generation_options = dict(
        num_beams=4,
        min_length=min_len,
        max_length=max_len,
        early_stopping=True,
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = model.generate(encoded, **generation_options)
    # generate() returns a batch; only the first (best) beam is decoded.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
69
 
70
def apply_style_prompt(text: str, mode: str, final: bool = False) -> str:
    """Prefix *text* with a Korean instruction matching the summary *mode*.

    Any mode other than "concise"/"explanatory" falls back to the bullet
    instruction. When *final* is true (the merge pass of a chunked run),
    an extra clause asking to keep the original order is appended.
    """
    instructions = {
        "concise": "๋‹ค์Œ ํ•œ๊ตญ์–ด ํ…์ŠคํŠธ๋ฅผ ํ•ต์‹ฌ๋งŒ ๊ฐ„๊ฒฐํ•˜๊ฒŒ ์š”์•ฝํ•˜์„ธ์š”.",
        "explanatory": "๋‹ค์Œ ํ•œ๊ตญ์–ด ํ…์ŠคํŠธ๋ฅผ ๋งฅ๋ฝ์„ ๋ณด์กดํ•˜๋ฉฐ ์ดํ•ดํ•˜๊ธฐ ์‰ฝ๊ฒŒ ์š”์•ฝํ•˜์„ธ์š”.",
    }
    inst = instructions.get(mode, "๋‹ค์Œ ํ•œ๊ตญ์–ด ํ…์ŠคํŠธ๋ฅผ bullet ํ˜•ํƒœ๋กœ ํ•ต์‹ฌ๋งŒ ์š”์•ฝํ•˜์„ธ์š”.")
    if final:
        inst = inst + " ์›๋ž˜ ์ˆœ์„œ๋ฅผ ์œ ์ง€ํ•˜๋ฉฐ ๋ฌธ์žฅ ์—ฐ๊ฒฐ์„ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ํ•˜์„ธ์š”."
    return f"{inst}\n\n{text}"
80
 
81
def postprocess(summary: str, mode: str) -> str:
    """Normalize whitespace; in "bullets" mode, reformat as a dash list.

    Bullet items are taken from existing bullet separators when two or
    more are present, otherwise from sentence boundaries.
    """
    s = re.sub(r"\s+", " ", summary.strip())
    if mode != "bullets":
        return s
    # Prefer splitting on bullet markers already present in the text.
    pieces = [b.strip() for b in re.split(r"\s*[-โ€ข]\s*", s) if b.strip()]
    if len(pieces) <= 1:
        # No usable markers — fall back to sentence-level splitting.
        pieces = [p.strip() for p in re.split(r"(?<=[\.!?])\s+", s) if p.strip()]
    return "\n".join(f"- {b}" for b in pieces)
94
 
95
def summarize_long(text: str, target_chars: int, mode: str):
    """Summarize *text*; inputs over ~1000 tokens are chunked and merged.

    Short inputs get one summarization pass; long inputs are split into
    ~900-token chunks, each summarized, then the concatenated partials
    are summarized once more with the order-preserving final prompt.
    """
    def length_bounds(chars, lo_frac, hi_frac, lo_floor, hi_floor):
        # chars → token heuristic: roughly 2 characters per token,
        # hence the division by 2; floors keep budgets sane.
        return (max(lo_floor, int(chars * lo_frac / 2)),
                max(hi_floor, int(chars * hi_frac / 2)))

    text = normalize_text(text)
    if not text:
        return "โš ๏ธ ์š”์•ฝํ•  ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”."

    # Single-pass path: the whole input fits the model.
    if token_length(text) <= 1000:
        lo, hi = length_bounds(target_chars, 0.4, 0.8, 60, 120)
        return postprocess(
            summarize_raw(apply_style_prompt(text, mode), lo, hi), mode
        )

    # Map step: summarize each chunk with a per-chunk character budget.
    chunks = chunk_by_tokens(split_into_sentences(text), max_tokens=900)
    per_chunk = max(250, int(target_chars * 1.5) // max(1, len(chunks)))
    lo, hi = length_bounds(per_chunk, 0.4, 0.9, 50, 100)
    partials = [
        summarize_raw(apply_style_prompt(c, mode), lo, hi) for c in chunks
    ]

    # Reduce step: summarize the merged partials with the final prompt.
    merged = normalize_text(" ".join(partials))
    lo, hi = length_bounds(target_chars, 0.45, 1.05, 80, 160)
    return postprocess(
        summarize_raw(apply_style_prompt(merged, mode, final=True), lo, hi),
        mode,
    )
 
 
118
 
119
  # ===== Gradio UI =====
120
  def ui_summarize(text, target_len, style):
@@ -122,7 +160,7 @@ def ui_summarize(text, target_len, style):
122
  return summarize_long(text, int(target_len), mode)
123
 
124
  with gr.Blocks() as demo:
125
- gr.Markdown("## ๐Ÿ“ KoT5 ํ•œ๊ตญ์–ด ์š”์•ฝ๊ธฐ (๊ธด ๋ฌธ์„œ ์ž๋™ ๋ถ„ํ•  + ์ˆœ์„œ ๋ณด์กด)")
126
  with gr.Row():
127
  with gr.Column():
128
  input_text = gr.Textbox(label="์›๋ฌธ ์ž…๋ ฅ", lines=16)
 
54
  chunks.append(" ".join(cur))
55
  return chunks
56
 
57
+ # ===== ๋ฐ˜๋ณต ์ œ๊ฑฐ =====
58
def derpeat(text: str) -> str:
    """Collapse degenerate repetitions in model output.

    Three passes:
      1. any single character repeated 3+ times is reduced to 2
      2. a whitespace-separated immediate word repeat is kept once
      3. runs of the punctuation . ! ? - ~ longer than 2 are reduced to 2

    Returns the cleaned, stripped string.
    """
    # 1) "sooooo" -> "soo" (also collapses long whitespace runs).
    text = re.sub(r'(.)\1{2,}', r'\1\1', text)
    # 2) "foo foo foo" -> "foo". BUGFIX: the backreference needs a trailing
    #    \b, otherwise a word that merely prefixes the next one is eaten
    #    (the original turned "the theory" into "theory").
    text = re.sub(r'(\b\w+\b)(?:\s+\1\b)+', r'\1', text)
    # 3) "!!!" / "..." etc. -> at most two marks.
    text = re.sub(r'([\.!?\-~])\1{2,}', r'\1\1', text)
    return text.strip()
63
+
64
  # ===== ์š”์•ฝ =====
65
def approx_tokens_from_chars(n_chars: int) -> int:
    """Estimate a token count from a character count.

    Uses the rough Korean heuristic of ~2 characters per token and
    never returns less than 1, so generation budgets stay positive.
    """
    estimate = n_chars // 2
    return estimate if estimate > 0 else 1
67
+
68
def summarize_raw_t5(input_text: str, target_chars: int, input_tokens: int) -> str:
    """Generate a sampled summary of *input_text* with the global KoT5 model.

    The new-token budget is derived from *target_chars*, capped so the
    summary cannot be asked to exceed ~90% of the input length (floor 120
    chars) nor 300 tokens, then tightened further for short inputs.
    Top-p sampling plus n-gram blocking and a repetition penalty are
    used to suppress degenerate loops.

    NOTE(review): do_sample=True makes output non-deterministic between
    calls, and early_stopping only affects beam search (num_beams=1
    here) — confirm both are intended.
    """
    # Character budget: never much longer than the input itself.
    capped_chars = min(target_chars, max(120, int(len(input_text) * 0.9)))
    max_new = max(40, min(approx_tokens_from_chars(capped_chars), 300))
    # Shorter inputs get a proportionally smaller token budget.
    if input_tokens <= 200:
        max_new = min(max_new, max(40, int(input_tokens * 0.6)))
    if input_tokens <= 60:
        max_new = min(max_new, 60)

    encoded = tokenizer.encode(
        input_text, return_tensors="pt", truncation=True, max_length=1024
    )
    with torch.no_grad():
        outputs = model.generate(
            encoded,
            max_new_tokens=max_new,
            do_sample=True,
            top_p=0.92,
            temperature=0.7,
            num_beams=1,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=4,
            repetition_penalty=1.2,
            renormalize_logits=True,
            early_stopping=True,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
94
 
95
def apply_style_prompt_t5(text: str, mode: str, final: bool = False) -> str:
    """Build the T5 prompt: a short Korean task tag, newline, then *text*.

    Any mode other than "concise"/"explanatory" falls back to the bullet
    tag. When *final* is true (the merge pass), a parenthesized guide
    asking to keep order and remove duplicates is appended to the tag.
    """
    tags = {
        "concise": "๊ฐ„๊ฒฐ ์š”์•ฝ:",
        "explanatory": "์„ค๋ช… ์š”์•ฝ:",
    }
    tag = tags.get(mode, "๋ถˆ๋ฆฟ ์š”์•ฝ:")
    guide = " (์›๋ž˜ ๋ฌธ์„œ์˜ ์ˆœ์„œ๋ฅผ ์œ ์ง€ํ•˜๊ณ  ์ค‘๋ณต์„ ์ œ๊ฑฐํ•˜์„ธ์š”.)" if final else ""
    return f"{tag}{guide}\n{text}"
106
 
107
def postprocess_strict(summary: str, mode: str) -> str:
    """Whitespace-normalize, de-repeat, and drop duplicate sentences.

    Sentences are split on ./!/? boundaries and deduplicated exactly,
    preserving first-occurrence order. In "bullets" mode the surviving
    sentences (at most 12) are rendered as a dash list; otherwise they
    are re-joined with single spaces.
    """
    collapsed = derpeat(re.sub(r"\s+", " ", summary.strip()))
    unique_sentences = []
    seen = set()
    for candidate in re.split(r"(?<=[\.!?])\s+", collapsed):
        candidate = candidate.strip()
        if candidate and candidate not in seen:
            seen.add(candidate)
            unique_sentences.append(candidate)
    if mode == "bullets":
        # Cap the list at 12 bullets to keep the UI readable.
        return "\n".join(f"- {line}" for line in unique_sentences[:12])
    return " ".join(unique_sentences)
122
 
123
def summarize_long(text: str, target_chars: int, mode: str):
    """Summarize *text*; long inputs are chunked, summarized, and merged.

    Inputs that fit the encoder (≤1000 tokens) get one pass — very short
    inputs (≤60 tokens) with the character budget capped at 300. Longer
    inputs are split into ~900-token chunks, summarized independently,
    and the cleaned concatenation is summarized once more with the
    order-preserving final prompt.
    """
    text = normalize_text(text)
    if not text:
        return "โš ๏ธ ์š”์•ฝํ•  ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”."

    n_tokens = token_length(text)

    # Single-pass path.
    if n_tokens <= 1000:
        budget = min(target_chars, 300) if n_tokens <= 60 else target_chars
        single = summarize_raw_t5(
            apply_style_prompt_t5(text, mode, final=False), budget, n_tokens
        )
        return postprocess_strict(single, mode)

    # Map step: summarize each ~900-token chunk independently.
    chunks = chunk_by_tokens(split_into_sentences(text), max_tokens=900)
    per_chunk_chars = max(180, int(target_chars * 1.2 / max(1, len(chunks))))
    partials = [
        summarize_raw_t5(
            apply_style_prompt_t5(chunk, mode, final=False),
            per_chunk_chars,
            token_length(chunk),
        )
        for chunk in chunks
    ]

    # Reduce step: clean the concatenation, then summarize once more
    # with the "keep order / drop duplicates" guide.
    merged = derpeat(normalize_text(" ".join(partials)))
    final_summary = summarize_raw_t5(
        apply_style_prompt_t5(merged, mode, final=True),
        target_chars,
        token_length(merged),
    )
    return postprocess_strict(final_summary, mode)
156
 
157
  # ===== Gradio UI =====
158
  def ui_summarize(text, target_len, style):
 
160
  return summarize_long(text, int(target_len), mode)
161
 
162
  with gr.Blocks() as demo:
163
+ gr.Markdown("## ๐Ÿ“ KoT5 ํ•œ๊ตญ์–ด ์š”์•ฝ๊ธฐ (๋ฐ˜๋ณต ์–ต์ œ + ์ˆœ์„œ ๋ณด์กด)")
164
  with gr.Row():
165
  with gr.Column():
166
  input_text = gr.Textbox(label="์›๋ฌธ ์ž…๋ ฅ", lines=16)