seawolf2357 committed
Commit 4bf30b7 · verified · 1 Parent(s): 0889c6d

Update app.py

Files changed (1): app.py (+24 -23)

app.py CHANGED
@@ -30,17 +30,13 @@ SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
 ##############################################################################
 def extract_keywords(text: str, top_k: int = 5) -> str:
     """
-    1) Keep only Korean, English letters, digits, and whitespace (regex changed)
+    1) Keep only Korean (가-힣), English (a-zA-Z), digits (0-9), and whitespace
     2) Split into tokens on whitespace
     3) Keep at most top_k tokens
     """
-    # Preserve only Korean (가-힣) + upper/lowercase English + digits + whitespace
     text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
-    # Split into tokens
     tokens = text.split()
-    # Take at most top_k
     key_tokens = tokens[:top_k]
-    # Join back together
     return " ".join(key_tokens)

 ##############################################################################
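For context, the slimmed-down extract_keywords keeps only Korean, ASCII letters, digits, and whitespace, then truncates to the first top_k tokens. A minimal standalone check (the sample input is illustrative):

    import re

    def extract_keywords(text: str, top_k: int = 5) -> str:
        # Strip everything except Korean (가-힣), ASCII letters, digits, whitespace
        text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
        tokens = text.split()
        return " ".join(tokens[:top_k])

    print(extract_keywords("What's new in Gemma-3 27B today?"))
    # -> "Whats new in Gemma3 27B" (punctuation dropped, capped at 5 tokens)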
@@ -74,7 +70,6 @@ def do_web_search(query: str) -> str:

     summary_lines = []
     for idx, item in enumerate(organic[:20], start=1):
-        # Dump the entire item as a JSON string
         item_json = json.dumps(item, ensure_ascii=False, indent=2)
         summary_lines.append(f"Result {idx}:\n{item_json}\n")

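Each organic result is serialized verbatim, so the model sees every field SerpHouse returns. A sketch of one formatted entry (the item fields here are hypothetical):

    import json

    item = {"position": 1, "title": "Gemma 3 출시", "link": "https://example.com"}
    print(f"Result 1:\n{json.dumps(item, ensure_ascii=False, indent=2)}")
    # ensure_ascii=False keeps Korean text readable instead of \uXXXX escapes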
@@ -89,6 +84,7 @@ def do_web_search(query: str) -> str:
 ##############################################################################
 MAX_CONTENT_CHARS = 4000
 model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
+
 processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
 model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id,
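The hunk cuts off inside the from_pretrained call; a typical continuation would pin dtype and device placement. The two keyword arguments below are assumptions, since the real ones sit outside this hunk:

    import torch
    from transformers import AutoProcessor, Gemma3ForConditionalGeneration

    model_id = "google/gemma-3-27b-it"
    processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",           # assumed: place/shard across available GPUs
        torch_dtype=torch.bfloat16,  # assumed: bf16 so the 27B weights fit
    )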
@@ -390,47 +386,36 @@ def run(
         return

     try:
-        # (1) Buffer up front so the system content is merged into one message
         combined_system_msg = ""

-        # If the user entered a system_prompt
         if system_prompt.strip():
             combined_system_msg += f"[System Prompt]\n{system_prompt.strip()}\n\n"

-        # (2) If web search is checked, extract keywords
         if use_web_search:
             user_text = message["text"]
             ws_query = extract_keywords(user_text, top_k=5)
-            # Skip the search if the extracted keywords are empty
             if ws_query.strip():
                 logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
                 ws_result = do_web_search(ws_query)
-                # Append the search results to the end of the system message
                 combined_system_msg += f"[Search top-20 Full Items Based on user prompt]\n{ws_result}\n\n"
             else:
-                # No extracted keywords, so don't attempt a search
                 combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"

-        # (3) If the final system message is non-empty
         messages = []
         if combined_system_msg.strip():
-            # Create a single system-role message
             messages.append({
                 "role": "system",
                 "content": [{"type": "text", "text": combined_system_msg.strip()}],
             })

-        # (4) Previous conversation history
         messages.extend(process_history(history))

-        # (5) The new user message
         user_content = process_new_user_message(message)
         for item in user_content:
             if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
                 item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
         messages.append({"role": "user", "content": user_content})

-        # (6) Build the LLM input
         inputs = processor.apply_chat_template(
             messages,
             add_generation_prompt=True,
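Taken together, the block assembles one system-role message (user system prompt plus web-search results), then the prior history, then the new user turn, in the multimodal chat format that apply_chat_template expects. Roughly (all values illustrative):

    messages = [
        {"role": "system",
         "content": [{"type": "text",
                      "text": "[System Prompt]\n...\n\n"
                              "[Search top-20 Full Items Based on user prompt]\n..."}]},
        # ...entries from process_history(history)...
        {"role": "user",
         "content": [{"type": "text", "text": "user question, cut at MAX_CONTENT_CHARS"}]},
    ]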
@@ -446,7 +431,7 @@ def run(
             max_new_tokens=max_new_tokens,
         )

-        t = Thread(target=model.generate, kwargs=gen_kwargs)
+        t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
         t.start()

         output = ""
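The start-a-thread-then-accumulate shape strongly suggests transformers' TextIteratorStreamer, whose setup sits outside this diff. A sketch of the surrounding pattern under that assumption (processor, model, and _model_gen_with_oom_catch as defined in app.py):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_reply(inputs, max_new_tokens):
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
        t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
        t.start()
        output = ""
        for new_text in streamer:  # yields decoded chunks as generate() produces them
            output += new_text
            yield output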
@@ -459,6 +444,22 @@
         yield f"Sorry, an error occurred: {str(e)}"


+##############################################################################
+# [Added] Call model.generate(...) in a separate function, catching OOM
+##############################################################################
+def _model_gen_with_oom_catch(**kwargs):
+    """
+    Catches OutOfMemoryError raised inside the separate generation thread.
+    """
+    try:
+        model.generate(**kwargs)
+    except torch.cuda.OutOfMemoryError:
+        raise RuntimeError(
+            "[OutOfMemoryError] Not enough GPU memory. "
+            "Reduce Max New Tokens or shorten the prompt."
+        )
+
+
 ##############################################################################
 # Examples (in Korean)
 ##############################################################################
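One caveat about the new helper: an exception raised in a worker thread is not re-raised in the thread that started it, so the RuntimeError above lands in the worker's traceback (and the stream simply ends) rather than in run()'s except block. A small demonstration of that threading behavior:

    import threading

    def boom():
        raise RuntimeError("dies in the worker thread")

    t = threading.Thread(target=boom)
    t.start()
    t.join()  # returns normally; the exception went to threading.excepthook
    print("main thread continues")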
@@ -658,7 +659,7 @@ with gr.Blocks(css=css, title="Vidraft-Gemma-3-27B") as demo:
             minimum=100,
             maximum=8000,
             step=50,
-            value=2000,
+            value=512,  # default reduced a bit to save GPU memory
         )

         gr.Markdown("<br><br>")
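The widget this hunk touches is presumably the max-new-tokens gr.Slider; reconstructed with an assumed label (only the keyword arguments shown in the diff are certain):

    import gradio as gr

    max_new_tokens = gr.Slider(
        label="Max New Tokens",  # assumed label, not visible in the hunk
        minimum=100,
        maximum=8000,
        step=50,
        value=512,  # lowered from 2000 in this commit to save GPU memory
    )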
@@ -698,12 +699,12 @@ with gr.Blocks(css=css, title="Vidraft-Gemma-3-27B") as demo:
     gr.Markdown("### Example Inputs (click to load)")
     gr.Examples(
         examples=examples,
-        inputs=[],  # empty list since there are no inputs to bind
+        inputs=[],
         cache_examples=False
     )

 if __name__ == "__main__":
-    # Add the comment below to pad the 615 lines out to 715 with filler
-    demo.launch(share=True)
-
+    # share=True triggers a warning on HF Spaces - it only works locally
+    # demo.launch(share=True)
+    demo.launch()
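For context on the launch change: share=True asks Gradio to open a public *.gradio.live tunnel, which Hugging Face Spaces does not support (the Space is already served publicly), so Gradio logs a warning there; a plain demo.launch() avoids it while behaving the same on Spaces.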
 
 