maxiaolong03 committed
Commit 5d55daf · 1 Parent(s): 317a8dd
Files changed (2)
  1. app.py +103 -27
  2. bot_requests.py +39 -14
app.py CHANGED
@@ -47,10 +47,21 @@ def get_args() -> argparse.Namespace:
     """
     parser = ArgumentParser(description="ERNIE models web chat demo.")

-    parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
-    parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
-    parser.add_argument("--max_char", type=int, default=8000, help="Maximum character limit for messages.")
-    parser.add_argument("--max_retry_num", type=int, default=3, help="Maximum retry number for request.")
+    parser.add_argument(
+        "--server-port", type=int, default=7860, help="Demo server port."
+    )
+    parser.add_argument(
+        "--server-name", type=str, default="0.0.0.0", help="Demo server name."
+    )
+    parser.add_argument(
+        "--max_char",
+        type=int,
+        default=8000,
+        help="Maximum character limit for messages.",
+    )
+    parser.add_argument(
+        "--max_retry_num", type=int, default=3, help="Maximum retry number for request."
+    )
     parser.add_argument(
         "--model_name_map",
         type=str,
@@ -97,7 +108,15 @@ def get_args() -> argparse.Namespace:
             * ERNIE-4.5-VL[-*]: Multimodal models (image+text)
         """,
     )
-    parser.add_argument("--api_key", type=str, default="bce-v3/xxx", help="Model service API key.")
+    parser.add_argument(
+        "--api_key", type=str, default="bce-v3/xxx", help="Model service API key."
+    )
+    parser.add_argument(
+        "--concurrency_limit", type=int, default=10, help="Default concurrency limit."
+    )
+    parser.add_argument(
+        "--max_queue_size", type=int, default=50, help="Maximum queue size for request."
+    )

     args = parser.parse_args()
     try:
@@ -202,7 +221,12 @@ class GradioEvents:
                 if idx in image_history:
                     content = []
                     content.append(
-                        {"type": "image_url", "image_url": {"url": GradioEvents.get_image_url(image_history[idx])}}
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": GradioEvents.get_image_url(image_history[idx])
+                            },
+                        }
                     )
                     content.append({"type": "text", "text": query_h})
                     conversation.append({"role": "user", "content": content})
@@ -211,9 +235,16 @@ class GradioEvents:
                     conversation.append({"role": "assistant", "content": response_h})

            content = []
-            if file_url and (len(image_history) == 0 or file_url != list(image_history.values())[-1]):
+            if file_url and (
+                len(image_history) == 0 or file_url != list(image_history.values())[-1]
+            ):
                image_history[len(task_history)] = file_url
-                content.append({"type": "image_url", "image_url": {"url": GradioEvents.get_image_url(file_url)}})
+                content.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": GradioEvents.get_image_url(file_url)},
+                    }
+                )
            content.append({"type": "text", "text": query})
            conversation.append({"role": "user", "content": content})
        else:
@@ -222,7 +253,9 @@ class GradioEvents:
        try:
            req_data = {"messages": conversation}
            model_name = model_name_map.get(model_name, model_name)
-            for chunk in bot_client.process_stream(model_name, req_data, max_tokens, temperature, top_p):
+            for chunk in bot_client.process_stream(
+                model_name, req_data, max_tokens, temperature, top_p
+            ):
                if "error" in chunk:
                    raise Exception(chunk["error"])

@@ -395,7 +428,9 @@ class GradioEvents:
        """
        GradioEvents.gc()

-        reset_result = namedtuple("reset_result", ["chatbot", "task_history", "image_history", "file_btn"])
+        reset_result = namedtuple(
+            "reset_result", ["chatbot", "task_history", "image_history", "file_btn"]
+        )
        return reset_result(
            [],  # clear chatbot
            [],  # clear task_history
@@ -421,7 +456,9 @@ class GradioEvents:
        Returns:
            gr.update: An update object representing the visibility of the file button.
        """
-        return gr.update(visible=model_name.upper().startswith(MULTI_MODEL_PREFIX))  # file_btn
+        return gr.update(
+            visible=model_name.upper().startswith(MULTI_MODEL_PREFIX)
+        )  # file_btn


 def launch_demo(args: argparse.Namespace, bot_client: BotClient):
@@ -477,11 +514,16 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
        (本演示基于文心大模型实现。)</center>"""
        )

-        chatbot = gr.Chatbot(label="ERNIE", elem_classes="control-height", type="messages")
+        chatbot = gr.Chatbot(
+            label="ERNIE", elem_classes="control-height", type="messages"
+        )
        model_names = list(args.model_name_map.keys())
        with gr.Row():
            model_name = gr.Dropdown(
-                label="Select Model", choices=model_names, value=model_names[0], allow_custom_value=True
+                label="Select Model",
+                choices=model_names,
+                value=model_names[0],
+                allow_custom_value=True,
            )
            file_btn = gr.File(
                label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
@@ -497,55 +539,89 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
            submit_btn = gr.Button("🚀 Submit(发送)", elem_id="submit-button")
            regen_btn = gr.Button("🤔️ Regenerate(重试)")

-        with gr.Accordion("⚙️ Advanced Config", open=False):  # open=False means collapsed by default
+        with gr.Accordion(
+            "⚙️ Advanced Config", open=False
+        ):  # open=False means collapsed by default
            system_message = gr.Textbox(value="", label="System message", visible=True)
            additional_inputs = [
                system_message,
-                gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max new tokens"),
-                gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Temperature"),
-                gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Top-p (nucleus sampling)"),
+                gr.Slider(
+                    minimum=1, maximum=4096, value=2048, step=1, label="Max new tokens"
+                ),
+                gr.Slider(
+                    minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Temperature"
+                ),
+                gr.Slider(
+                    minimum=0.1,
+                    maximum=1.0,
+                    value=0.7,
+                    step=0.05,
+                    label="Top-p (nucleus sampling)",
+                ),
            ]

        task_history = gr.State([])
        image_history = gr.State({})

-        model_name.change(GradioEvents.toggle_components_visibility, inputs=model_name, outputs=file_btn)
        model_name.change(
-            GradioEvents.reset_state, outputs=[chatbot, task_history, image_history, file_btn], show_progress=True
+            GradioEvents.toggle_components_visibility,
+            inputs=model_name,
+            outputs=file_btn,
+        )
+        model_name.change(
+            GradioEvents.reset_state,
+            outputs=[chatbot, task_history, image_history, file_btn],
+            show_progress=True,
        )
        predict_with_clients = partial(
-            GradioEvents.predict_stream, model_name_map=args.model_name_map, bot_client=bot_client
+            GradioEvents.predict_stream,
+            model_name_map=args.model_name_map,
+            bot_client=bot_client,
        )
        regenerate_with_clients = partial(
-            GradioEvents.regenerate, model_name_map=args.model_name_map, bot_client=bot_client
+            GradioEvents.regenerate,
+            model_name_map=args.model_name_map,
+            bot_client=bot_client,
        )
        query.submit(
            predict_with_clients,
-            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+            inputs=[query, chatbot, task_history, image_history, model_name, file_btn]
+            + additional_inputs,
            outputs=[chatbot],
            show_progress=True,
        )
        query.submit(GradioEvents.reset_user_input, [], [query])
        submit_btn.click(
            predict_with_clients,
-            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+            inputs=[query, chatbot, task_history, image_history, model_name, file_btn]
+            + additional_inputs,
            outputs=[chatbot],
            show_progress=True,
        )
        submit_btn.click(GradioEvents.reset_user_input, [], [query])
        empty_btn.click(
-            GradioEvents.reset_state, outputs=[chatbot, task_history, image_history, file_btn], show_progress=True
+            GradioEvents.reset_state,
+            outputs=[chatbot, task_history, image_history, file_btn],
+            show_progress=True,
        )
        regen_btn.click(
            regenerate_with_clients,
-            inputs=[chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+            inputs=[chatbot, task_history, image_history, model_name, file_btn]
+            + additional_inputs,
            outputs=[chatbot],
            show_progress=True,
        )

-        demo.load(GradioEvents.toggle_components_visibility, inputs=gr.State(model_names[0]), outputs=file_btn)
+        demo.load(
+            GradioEvents.toggle_components_visibility,
+            inputs=gr.State(model_names[0]),
+            outputs=file_btn,
+        )

-    demo.queue().launch(server_port=args.server_port, server_name=args.server_name)
+    demo.queue(
+        default_concurrency_limit=args.concurrency_limit, max_size=args.max_queue_size
+    )
+    demo.launch(server_port=args.server_port, server_name=args.server_name)


 def main():
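Beyond the reflowed lines, the one behavioral change in app.py is that the queue is now sized from the new `--concurrency_limit` and `--max_queue_size` flags instead of being left at `demo.queue()` defaults. A minimal sketch of the same pattern, assuming Gradio 4.x semantics for `default_concurrency_limit` and `max_size`; the toy app and literal values are illustrative, not from this commit:

```python
import gradio as gr

# Toy Blocks app (illustrative), wired the same way as the demo in app.py.
with gr.Blocks() as demo:
    box = gr.Textbox(label="Echo")
    btn = gr.Button("Run")
    btn.click(lambda text: text, inputs=box, outputs=box)

# default_concurrency_limit caps how many queued events run at once;
# max_size bounds how many requests may wait before new ones are rejected.
# app.py passes args.concurrency_limit and args.max_queue_size here.
demo.queue(default_concurrency_limit=10, max_size=50)
demo.launch(server_name="0.0.0.0", server_port=7860)
```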
 
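For context on the message-building hunks: each user turn is assembled in the OpenAI-style multimodal content format, a list of typed parts with any image part preceding the text part. A minimal sketch of the shape predict_stream constructs; the URL literal is a hypothetical stand-in for what GradioEvents.get_image_url(file_url) returns:

```python
# Hypothetical stand-in for GradioEvents.get_image_url(file_url), which is
# assumed to return a fetchable URL or data URI for the uploaded image.
image_url = "data:image/png;base64,..."

content = [
    {"type": "image_url", "image_url": {"url": image_url}},
    {"type": "text", "text": "Describe this image."},
]
conversation = [{"role": "user", "content": content}]
print(conversation)
```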
bot_requests.py CHANGED
@@ -40,17 +40,21 @@ class BotClient:
         """
         self.logger = logging.getLogger(__name__)

-        self.max_retry_num = getattr(args, 'max_retry_num', 3)
-        self.max_char = getattr(args, 'max_char', 8000)
+        self.max_retry_num = getattr(args, "max_retry_num", 3)
+        self.max_char = getattr(args, "max_char", 8000)

-        self.model_map = getattr(args, 'model_map', {})
+        self.model_map = getattr(args, "model_map", {})
         self.api_key = os.environ.get("API_KEY")

-        self.embedding_service_url = getattr(args, 'embedding_service_url', 'embedding_service_url')
-        self.embedding_model = getattr(args, 'embedding_model', 'embedding_model')
+        self.embedding_service_url = getattr(
+            args, "embedding_service_url", "embedding_service_url"
+        )
+        self.embedding_model = getattr(args, "embedding_model", "embedding_model")

-        self.web_search_service_url = getattr(args, 'web_search_service_url', 'web_search_service_url')
-        self.max_search_results_num = getattr(args, 'max_search_results_num', 15)
+        self.web_search_service_url = getattr(
+            args, "web_search_service_url", "web_search_service_url"
+        )
+        self.max_search_results_num = getattr(args, "max_search_results_num", 15)

         self.qianfan_api_key = os.environ.get("API_KEY")

@@ -109,7 +113,12 @@ class BotClient:
             raise

     def process(
-        self, model_name: str, req_data: dict, max_tokens: int = 2048, temperature: float = 1.0, top_p: float = 0.7
+        self,
+        model_name: str,
+        req_data: dict,
+        max_tokens: int = 2048,
+        temperature: float = 1.0,
+        top_p: float = 0.7,
     ) -> dict:
         """
         Handles chat completion requests by mapping the model name to its endpoint, preparing request parameters
@@ -152,7 +161,12 @@ class BotClient:
         return res

     def process_stream(
-        self, model_name: str, req_data: dict, max_tokens: int = 2048, temperature: float = 1.0, top_p: float = 0.7
+        self,
+        model_name: str,
+        req_data: dict,
+        max_tokens: int = 2048,
+        temperature: float = 1.0,
+        top_p: float = 0.7,
     ) -> dict:
         """
         Processes streaming requests by mapping the model name to its endpoint, configuring request parameters,
@@ -188,7 +202,9 @@ class BotClient:

             except Exception as e:
                 last_error = e
-                self.logger.error(f"Stream request failed (attempt {_ + 1}/{self.max_retry_num}): {e}")
+                self.logger.error(
+                    f"Stream request failed (attempt {_ + 1}/{self.max_retry_num}): {e}"
+                )

         self.logger.error("All retry attempts failed for stream request")
         yield {"error": str(last_error)}
@@ -209,7 +225,9 @@ class BotClient:
         en_ch_words = []

         for word in words:
-            if word.isalpha() and not any("\u4e00" <= char <= "\u9fff" for char in word):
+            if word.isalpha() and not any(
+                "\u4e00" <= char <= "\u9fff" for char in word
+            ):
                 en_ch_words.append(word)
             else:
                 en_ch_words.extend(list(word))
@@ -341,7 +359,9 @@ class BotClient:
         Returns:
             list: A list of floats representing the embedding.
         """
-        client = OpenAI(base_url=self.embedding_service_url, api_key=self.qianfan_api_key)
+        client = OpenAI(
+            base_url=self.embedding_service_url, api_key=self.qianfan_api_key
+        )
         response = client.embeddings.create(input=[text], model=self.embedding_model)
         return response.data[0].embedding

@@ -355,7 +375,10 @@ class BotClient:
         Returns:
             list: List of responses from the AI Search service.
         """
-        headers = {"Authorization": "Bearer " + self.qianfan_api_key, "Content-Type": "application/json"}
+        headers = {
+            "Authorization": "Bearer " + self.qianfan_api_key,
+            "Content-Type": "application/json",
+        }

         results = []
         top_k = self.max_search_results_num // len(query_list)
@@ -364,7 +387,9 @@ class BotClient:
                 "messages": [{"role": "user", "content": query}],
                 "resource_type_filter": [{"type": "web", "top_k": top_k}],
             }
-            response = requests.post(self.web_search_service_url, headers=headers, json=payload)
+            response = requests.post(
+                self.web_search_service_url, headers=headers, json=payload
+            )

             if response.status_code == 200:
                 response = response.json()
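The BotClient.__init__ hunk reads every setting through getattr(args, name, default), so a partial argparse.Namespace (or any config object) configures the client without raising AttributeError for missing fields. A self-contained sketch of that pattern; the values here are hypothetical:

```python
from argparse import Namespace

# Hypothetical partial config: only one of the expected attributes is set.
args = Namespace(max_retry_num=5)

max_retry_num = getattr(args, "max_retry_num", 3)  # -> 5, attribute present
max_char = getattr(args, "max_char", 8000)         # -> 8000, falls back
model_map = getattr(args, "model_map", {})         # -> {}, falls back

print(max_retry_num, max_char, model_map)
```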
 
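The process_stream hunk shows only the except branch of the retry loop, but its shape follows from the log messages: a loop over max_retry_num attempts that records the last exception and, once all attempts fail, yields it as an {"error": ...} chunk (the same chunk app.py checks for). A hedged reconstruction of that control flow, not the file's actual code:

```python
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
max_retry_num = 3  # mirrors the --max_retry_num default

def stream_with_retries(make_stream):
    """Yield chunks from make_stream(); retry up to max_retry_num times."""
    last_error = None
    for _ in range(max_retry_num):
        try:
            yield from make_stream()
            return  # stream finished cleanly; stop retrying
        except Exception as e:
            last_error = e
            logger.error(
                f"Stream request failed (attempt {_ + 1}/{max_retry_num}): {e}"
            )
    logger.error("All retry attempts failed for stream request")
    yield {"error": str(last_error)}

def always_fails():
    raise RuntimeError("connection reset")
    yield  # unreachable; marks this function as a generator

print(list(stream_with_retries(always_fails)))  # [{'error': 'connection reset'}]
```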