AlsuGibadullina committed (verified)
Commit 7dc4be3 · Parent: c426f40

Update app.py

Files changed (1): app.py (+250, −49)
app.py CHANGED
@@ -2,8 +2,10 @@ import tempfile
2
  from typing import List, Tuple, Any
3
 
4
  import gradio as gr
 
5
  import torch
6
  import torch.nn.functional as torch_functional
 
7
  from PIL import Image, ImageDraw
8
  from transformers import (
9
  AutoTokenizer,
@@ -11,22 +13,26 @@ from transformers import (
11
  CLIPProcessor,
12
  SamModel,
13
  SamProcessor,
 
14
  pipeline,
15
  BlipForQuestionAnswering,
16
  BlipProcessor,
17
  )
18
 
19
- MODEL_STORE = {}
20
 
 
21
 
22
  def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
23
  if not gallery_value:
24
  return []
 
25
  normalized_images: List[Image.Image] = []
 
26
  for item in gallery_value:
27
  if isinstance(item, Image.Image):
28
  normalized_images.append(item)
29
  continue
 
30
  if isinstance(item, str):
31
  try:
32
  image_object = Image.open(item).convert("RGB")
@@ -34,18 +40,61 @@ def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
34
  except Exception:
35
  continue
36
  continue
 
37
  if isinstance(item, (list, tuple)) and item:
38
  candidate = item[0]
39
  if isinstance(candidate, Image.Image):
40
  normalized_images.append(candidate)
41
  continue
 
42
  if isinstance(item, dict):
43
  candidate = item.get("image") or item.get("value")
44
  if isinstance(candidate, Image.Image):
45
  normalized_images.append(candidate)
46
  continue
 
47
  return normalized_images
48
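
A brief usage sketch of this helper (the file name is illustrative): it accepts the shapes a gr.Gallery value can take — PIL images, path strings, (image, caption) tuples, or dicts — and returns plain PIL images, silently skipping anything it cannot open.

    images = _normalize_gallery_images([
        Image.new("RGB", (64, 64)),              # already a PIL image
        "example.jpg",                           # path string (opened and converted to RGB)
        (Image.new("RGB", (64, 64)), "caption"), # (image, caption) tuple from a gallery
    ])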
 
49
 
50
  def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
51
  if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
@@ -53,11 +102,11 @@ def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
53
  blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
54
  MODEL_STORE["blip_vqa_model"] = blip_model
55
  MODEL_STORE["blip_vqa_processor"] = blip_processor
 
56
  blip_model = MODEL_STORE["blip_vqa_model"]
57
  blip_processor = MODEL_STORE["blip_vqa_processor"]
58
  return blip_model, blip_processor
59
 
60
-
61
  def get_vision_pipeline(model_key: str):
62
  if model_key in MODEL_STORE:
63
  return MODEL_STORE[model_key]
@@ -72,16 +121,19 @@ def get_vision_pipeline(model_key: str):
72
  task="object-detection",
73
  model="hustvl/yolos-small",
74
  )
 
75
  elif model_key == "segmentation":
76
  vision_pipeline = pipeline(
77
  task="image-segmentation",
78
  model="nvidia/segformer-b0-finetuned-ade-512-512",
79
  )
 
80
  elif model_key == "depth_estimation":
81
  vision_pipeline = pipeline(
82
  task="depth-estimation",
83
  model="Intel/dpt-hybrid-midas",
84
  )
 
85
  elif model_key == "captioning_blip_base":
86
  vision_pipeline = pipeline(
87
  task="image-to-text",
@@ -92,6 +144,7 @@ def get_vision_pipeline(model_key: str):
92
  task="image-to-text",
93
  model="Salesforce/blip-image-captioning-large",
94
  )
 
95
  elif model_key == "vqa_blip_base":
96
  vision_pipeline = pipeline(
97
  task="visual-question-answering",
@@ -102,6 +155,7 @@ def get_vision_pipeline(model_key: str):
102
  task="visual-question-answering",
103
  model="dandelin/vilt-b32-finetuned-vqa",
104
  )
 
105
  else:
106
  raise ValueError(f"Неизвестный тип визуальной модели: {model_key}")
107
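
Usage sketch for the vision keys above (the image path is illustrative); an image-to-text pipeline returns a list of dicts with a generated_text field:

    captioner = get_vision_pipeline("captioning_blip_base")
    result = captioner(Image.open("example.jpg"))
    print(result[0]["generated_text"])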
 
@@ -123,6 +177,7 @@ def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
123
 
124
  clip_model = CLIPModel.from_pretrained(clip_name)
125
  clip_processor = CLIPProcessor.from_pretrained(clip_name)
 
126
  MODEL_STORE[model_store_key_model] = clip_model
127
  MODEL_STORE[model_store_key_processor] = clip_processor
128
 
@@ -131,26 +186,125 @@ def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
131
  return clip_model, clip_processor
132
 
133
 
134
  def get_sam_components() -> Tuple[SamModel, SamProcessor]:
135
  if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
136
  sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
137
  sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
138
  MODEL_STORE["sam_model"] = sam_model
139
  MODEL_STORE["sam_processor"] = sam_processor
 
140
  sam_model = MODEL_STORE["sam_model"]
141
  sam_processor = MODEL_STORE["sam_processor"]
142
  return sam_model, sam_processor
143
 
144
 
145
  def detect_objects_on_image(image_object, model_key: str):
146
  detector_pipeline = get_vision_pipeline(model_key)
147
  detection_results = detector_pipeline(image_object)
148
- drawer_object = ImageDraw.Draw(image_object)
149
 
 
150
  for detection_item in detection_results:
151
  box_data = detection_item["box"]
152
  label_value = detection_item["label"]
153
  score_value = detection_item["score"]
 
154
  drawer_object.rectangle(
155
  [
156
  box_data["xmin"],
@@ -166,6 +320,7 @@ def detect_objects_on_image(image_object, model_key: str):
166
  f"{label_value}: {score_value:.2f}",
167
  fill="red",
168
  )
 
169
  return image_object
170
 
171
 
@@ -178,6 +333,7 @@ def segment_image(image_object):
178
  def estimate_image_depth(image_object):
179
  depth_pipeline = get_vision_pipeline("depth_estimation")
180
  depth_output = depth_pipeline(image_object)
 
181
  predicted_depth_tensor = depth_output["predicted_depth"]
182
 
183
  if predicted_depth_tensor.ndim == 3:
@@ -195,8 +351,10 @@ def estimate_image_depth(image_object):
195
  mode="bicubic",
196
  align_corners=False,
197
  )
 
198
  depth_array = resized_depth_tensor.squeeze().cpu().numpy()
199
  max_value = float(depth_array.max())
 
200
  if max_value <= 0.0:
201
  return Image.new("L", image_object.size, color=0)
202
 
@@ -214,35 +372,42 @@ def generate_image_caption(image_object, model_key: str) -> str:
214
  def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
215
  if image_object is None:
216
  return "Пожалуйста, сначала загрузите изображение."
 
217
  if not question_text.strip():
218
  return "Пожалуйста, введите вопрос об изображении."
219
 
220
  if model_key == "vqa_blip_base":
221
  blip_model, blip_processor = get_blip_vqa_components()
 
222
  inputs = blip_processor(
223
  images=image_object,
224
  text=question_text,
225
  return_tensors="pt",
226
  )
 
227
  with torch.no_grad():
228
  output_ids = blip_model.generate(**inputs)
 
229
  decoded_answers = blip_processor.batch_decode(
230
  output_ids,
231
  skip_special_tokens=True,
232
  )
233
  answer_text = decoded_answers[0] if decoded_answers else ""
 
234
  return answer_text or "Модель не смогла сгенерировать ответ."
235
 
236
  vqa_pipeline = get_vision_pipeline(model_key)
 
237
  vqa_result = vqa_pipeline(
238
  image=image_object,
239
  question=question_text,
240
  )
 
241
  top_item = vqa_result[0]
242
  answer_text = top_item["answer"]
243
  confidence_value = top_item["score"]
244
- return f"{answer_text} (confidence: {confidence_value:.3f})"
245
 
 
246
 
247
  def perform_zero_shot_classification(
248
  image_object,
@@ -250,6 +415,7 @@ def perform_zero_shot_classification(
250
  clip_key: str,
251
  ) -> str:
252
  clip_model, clip_processor = get_clip_components(clip_key)
 
253
  class_list = [
254
  class_name.strip()
255
  for class_name in class_texts.split(",")
@@ -264,6 +430,7 @@ def perform_zero_shot_classification(
264
  return_tensors="pt",
265
  padding=True,
266
  )
 
267
  with torch.no_grad():
268
  clip_outputs = clip_model(**input_batch)
269
  logits_per_image = clip_outputs.logits_per_image
@@ -273,6 +440,7 @@ def perform_zero_shot_classification(
273
  for class_index, class_name in enumerate(class_list):
274
  probability_value = probability_tensor[0][class_index].item()
275
  result_lines.append(f"{class_name}: {probability_value:.4f}")
 
276
  return "\n".join(result_lines)
277
 
278
 
@@ -282,10 +450,12 @@ def retrieve_best_image(
282
  clip_key: str,
283
  ) -> Tuple[str, Image.Image | None]:
284
  image_list = _normalize_gallery_images(gallery_value)
 
285
  if not image_list or not query_text.strip():
286
  return "Пожалуйста, загрузите изображения и введите запрос", None
287
 
288
  clip_model, clip_processor = get_clip_components(clip_key)
 
289
  image_inputs = clip_processor(
290
  images=image_list,
291
  return_tensors="pt",
@@ -293,10 +463,10 @@ def retrieve_best_image(
293
  )
294
  with torch.no_grad():
295
  image_features = clip_model.get_image_features(**image_inputs)
296
- image_features = image_features / image_features.norm(
297
- dim=-1,
298
- keepdim=True,
299
- )
300
 
301
  text_inputs = clip_processor(
302
  text=[query_text],
@@ -305,10 +475,10 @@ def retrieve_best_image(
305
  )
306
  with torch.no_grad():
307
  text_features = clip_model.get_text_features(**text_inputs)
308
- text_features = text_features / text_features.norm(
309
- dim=-1,
310
- keepdim=True,
311
- )
312
 
313
  similarity_tensor = image_features @ text_features.T
314
  best_index_tensor = similarity_tensor.argmax()
@@ -328,10 +498,12 @@ def segment_image_with_sam_points(
328
  ) -> Image.Image:
329
  if image_object is None:
330
  raise ValueError("Изображение не передано в segment_image_with_sam_points")
 
331
  if not point_coordinates_list:
332
  return Image.new("L", image_object.size, color=0)
333
 
334
  sam_model, sam_processor = get_sam_components()
 
335
  batched_points: List[List[List[int]]] = [point_coordinates_list]
336
  batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]]
337
 
@@ -341,6 +513,7 @@ def segment_image_with_sam_points(
341
  input_labels=batched_labels,
342
  return_tensors="pt",
343
  )
 
344
  with torch.no_grad():
345
  sam_outputs = sam_model(**sam_inputs, multimask_output=True)
346
 
@@ -349,37 +522,47 @@ def segment_image_with_sam_points(
349
  sam_inputs["original_sizes"].cpu(),
350
  sam_inputs["reshaped_input_sizes"].cpu(),
351
  )
 
352
  batch_masks_tensor = processed_masks_list[0]
 
353
  if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0:
354
  return Image.new("L", image_object.size, color=0)
355
 
356
  first_mask_tensor = batch_masks_tensor[0]
357
  mask_array = first_mask_tensor.numpy()
 
358
  binary_mask_array = (mask_array > 0.5).astype("uint8") * 255
 
359
  mask_image = Image.fromarray(binary_mask_array, mode="L")
360
  return mask_image
361
 
362
 
363
  def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
 
364
  if image_object is None:
365
  return None
 
366
  coordinates_text_clean = coordinates_text.strip()
367
  if not coordinates_text_clean:
368
  return Image.new("L", image_object.size, color=0)
369
 
370
  point_coordinates_list: List[List[int]] = []
 
371
  for raw_pair in coordinates_text_clean.replace("\n", ";").split(";"):
372
  raw_pair_clean = raw_pair.strip()
373
  if not raw_pair_clean:
374
  continue
 
375
  parts = raw_pair_clean.split(",")
376
  if len(parts) != 2:
377
  continue
 
378
  try:
379
  x_value = int(parts[0].strip())
380
  y_value = int(parts[1].strip())
381
  except ValueError:
382
  continue
 
383
  point_coordinates_list.append([x_value, y_value])
384
 
385
  if not point_coordinates_list:
@@ -391,6 +574,7 @@ def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Ima
391
  def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
392
  if not coordinates_text.strip():
393
  return []
 
394
  point_list: List[List[int]] = []
395
  for raw_pair in coordinates_text.split(";"):
396
  cleaned_pair = raw_pair.strip()
@@ -405,13 +589,13 @@ def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
405
  except ValueError:
406
  continue
407
  point_list.append([x_value, y_value])
408
- return point_list
409
 
 
410
 
411
  def build_interface():
412
  with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo_block:
413
  gr.Markdown("# AI модели")
414
-
415
  with gr.Tab("Детекция объектов"):
416
  gr.Markdown("## Детекция объектов")
417
  with gr.Row():
@@ -428,13 +612,15 @@ def build_interface():
428
  value="object_detection_conditional_detr",
429
  info=(
430
  "object_detection_conditional_detr - microsoft/conditional-detr-resnet-50\n"
431
- "object_detection_yolos_small - hustvl/yolos-small"
432
  ),
433
  )
434
  object_detect_button = gr.Button("Применить")
435
- object_output_image = gr.Image(
436
- label="Результат",
437
- )
 
 
438
  object_detect_button.click(
439
  fn=detect_objects_on_image,
440
  inputs=[object_input_image, object_model_selector],
@@ -449,9 +635,11 @@ def build_interface():
449
  type="pil",
450
  )
451
  segmentation_button = gr.Button("Применить")
452
- segmentation_output_image = gr.Image(
453
- label="Маска",
454
- )
 
 
455
  segmentation_button.click(
456
  fn=segment_image,
457
  inputs=segmentation_input_image,
@@ -461,14 +649,17 @@ def build_interface():
461
  with gr.Tab("Глубина"):
462
  gr.Markdown("## Глубина (Depth Estimation)")
463
  with gr.Row():
 
464
  depth_input_image = gr.Image(
465
  label="Загрузите изображение",
466
  type="pil",
467
  )
468
  depth_button = gr.Button("Применить")
469
- depth_output_image = gr.Image(
470
- label="Глубины",
471
- )
 
 
472
  depth_button.click(
473
  fn=estimate_image_depth,
474
  inputs=depth_input_image,
@@ -490,15 +681,17 @@ def build_interface():
490
  label="Модель",
491
  value="captioning_blip_base",
492
  info=(
493
- "captioning_blip_base - Salesforce/blip-image-captioning-base (курс)\n"
494
  "captioning_blip_large - Salesforce/blip-image-captioning-large"
495
  ),
496
  )
497
  caption_button = gr.Button("Применить")
498
- caption_output_text = gr.Textbox(
499
- label="Описание изображения",
500
- lines=3,
501
- )
 
 
502
  caption_button.click(
503
  fn=generate_image_caption,
504
  inputs=[caption_input_image, caption_model_selector],
@@ -526,14 +719,16 @@ def build_interface():
526
  value="vqa_blip_base",
527
  info=(
528
  "vqa_blip_base - Salesforce/blip-vqa-base (курс)\n"
529
- "vqa_vilt_b32 - dandelin/vilt-b32-finetuned-vqa"
530
  ),
531
  )
532
  vqa_button = gr.Button("Ответить на вопрос")
533
- vqa_output_text = gr.Textbox(
534
- label="Ответ",
535
- lines=3,
536
- )
 
 
537
  vqa_button.click(
538
  fn=answer_visual_question,
539
  inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
@@ -561,14 +756,16 @@ def build_interface():
561
  value="clip_large_patch14",
562
  info=(
563
  "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
564
- "clip_base_patch32 - openai/clip-vit-base-patch32"
565
  ),
566
  )
567
  zero_shot_button = gr.Button("Применить")
568
- zero_shot_output_text = gr.Textbox(
569
- label="Результаты",
570
- lines=10,
571
- )
 
 
572
  zero_shot_button.click(
573
  fn=perform_zero_shot_classification,
574
  inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector],
@@ -578,6 +775,7 @@ def build_interface():
578
  with gr.Tab("Поиск изображений"):
579
  gr.Markdown("## Поиск изображений")
580
  with gr.Row():
 
581
  retrieval_dir = gr.File(
582
  label="Загрузите папку с изображениями",
583
  file_count="directory",
@@ -598,16 +796,18 @@ def build_interface():
598
  value="clip_large_patch14",
599
  info=(
600
  "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
601
- "clip_base_patch32 - openai/clip-vit-base-patch32 (альтернатива)"
602
  ),
603
  )
604
  retrieval_button = gr.Button("Поиск")
605
- retrieval_output_text = gr.Textbox(
606
- label="Результат",
607
- )
608
- retrieval_output_image = gr.Image(
609
- label="Наиболее подходящее изображение",
610
- )
 
 
611
  retrieval_button.click(
612
  fn=retrieve_best_image,
613
  inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector],
@@ -618,14 +818,15 @@ def build_interface():
618
  gr.Markdown("### Задачи:")
619
  gr.Markdown(
620
  """
621
- - Компьютерное зрение: детекция объектов, сегментация, оценка глубины, генерация описаний изображений
 
622
  - Мультимодальные задачи: вопросы к изображению, zero-shot классификация изображений, поиск по изображениям по текстовому запросу
623
- """
624
  )
625
-
626
  return demo_block
627
 
628
 
629
  if __name__ == "__main__":
630
  interface_block = build_interface()
631
  interface_block.launch(share=True)
 
 
2
  from typing import List, Tuple, Any
3
 
4
  import gradio as gr
5
+ import soundfile as sf
6
  import torch
7
  import torch.nn.functional as torch_functional
8
+ from gtts import gTTS
9
  from PIL import Image, ImageDraw
10
  from transformers import (
11
  AutoTokenizer,
 
13
  CLIPProcessor,
14
  SamModel,
15
  SamProcessor,
16
+ VitsModel,
17
  pipeline,
18
  BlipForQuestionAnswering,
19
  BlipProcessor,
20
  )
21
 
 
22
 
23
+ MODEL_STORE = {}
24
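
MODEL_STORE is the module-wide cache that every get_* helper below follows the same lazy pattern against; a minimal sketch of the idiom, with a hypothetical _get_cached helper that is not part of the file:

    def _get_cached(key, loader):
        # Build the object once on first request, then reuse it on every later call.
        if key not in MODEL_STORE:
            MODEL_STORE[key] = loader()
        return MODEL_STORE[key]

    # e.g. _get_cached("clip_base_model", lambda: CLIPModel.from_pretrained("openai/clip-vit-base-patch32"))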
 
25
  def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
26
  if not gallery_value:
27
  return []
28
+
29
  normalized_images: List[Image.Image] = []
30
+
31
  for item in gallery_value:
32
  if isinstance(item, Image.Image):
33
  normalized_images.append(item)
34
  continue
35
+
36
  if isinstance(item, str):
37
  try:
38
  image_object = Image.open(item).convert("RGB")
 
40
  except Exception:
41
  continue
42
  continue
43
+
44
  if isinstance(item, (list, tuple)) and item:
45
  candidate = item[0]
46
  if isinstance(candidate, Image.Image):
47
  normalized_images.append(candidate)
48
  continue
49
+
50
  if isinstance(item, dict):
51
  candidate = item.get("image") or item.get("value")
52
  if isinstance(candidate, Image.Image):
53
  normalized_images.append(candidate)
54
  continue
55
+
56
  return normalized_images
57
 
58
+ def get_audio_pipeline(model_key: str):
59
+ if model_key in MODEL_STORE:
60
+ return MODEL_STORE[model_key]
61
+
62
+ if model_key == "whisper":
63
+ audio_pipeline = pipeline(
64
+ task="automatic-speech-recognition",
65
+ model="distil-whisper/distil-small.en",
66
+ )
67
+ elif model_key == "wav2vec2":
68
+ audio_pipeline = pipeline(
69
+ task="automatic-speech-recognition",
70
+ model="openai/whisper-small",
71
+ )
72
+ elif model_key == "audio_classifier":
73
+ audio_pipeline = pipeline(
74
+ task="audio-classification",
75
+ model="MIT/ast-finetuned-audioset-10-10-0.4593",
76
+ )
77
+ elif model_key == "emotion_classifier":
78
+ audio_pipeline = pipeline(
79
+ task="audio-classification",
80
+ model="superb/hubert-large-superb-er",
81
+ )
82
+ else:
83
+ raise ValueError(f"Неизвестный тип аудио модели: {model_key}")
84
+
85
+ MODEL_STORE[model_key] = audio_pipeline
86
+ return audio_pipeline
87
+
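
Usage sketch for the keys above (the audio path is illustrative; note that the "wav2vec2" key currently loads openai/whisper-small rather than a wav2vec2 checkpoint):

    asr = get_audio_pipeline("whisper")               # distil-whisper/distil-small.en
    print(asr("sample.wav")["text"])

    tagger = get_audio_pipeline("audio_classifier")   # MIT/ast-finetuned-audioset-10-10-0.4593
    print(tagger("sample.wav")[:3])                   # top label/score dicts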
88
+
89
+ def get_zero_shot_audio_pipeline():
90
+ if "audio_zero_shot_clap" not in MODEL_STORE:
91
+ zero_shot_pipeline = pipeline(
92
+ task="zero-shot-audio-classification",
93
+ model="laion/clap-htsat-unfused",
94
+ )
95
+ MODEL_STORE["audio_zero_shot_clap"] = zero_shot_pipeline
96
+ return MODEL_STORE["audio_zero_shot_clap"]
97
+
98
 
99
  def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
100
  if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
 
102
  blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
103
  MODEL_STORE["blip_vqa_model"] = blip_model
104
  MODEL_STORE["blip_vqa_processor"] = blip_processor
105
+
106
  blip_model = MODEL_STORE["blip_vqa_model"]
107
  blip_processor = MODEL_STORE["blip_vqa_processor"]
108
  return blip_model, blip_processor
109
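
A minimal round-trip with the cached BLIP VQA components, mirroring the answer_visual_question branch further down (image path and question are illustrative):

    blip_model, blip_processor = get_blip_vqa_components()
    vqa_inputs = blip_processor(images=Image.open("example.jpg"), text="what is in the picture?", return_tensors="pt")
    with torch.no_grad():
        vqa_ids = blip_model.generate(**vqa_inputs)
    print(blip_processor.batch_decode(vqa_ids, skip_special_tokens=True)[0])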
 
 
110
  def get_vision_pipeline(model_key: str):
111
  if model_key in MODEL_STORE:
112
  return MODEL_STORE[model_key]
 
121
  task="object-detection",
122
  model="hustvl/yolos-small",
123
  )
124
+
125
  elif model_key == "segmentation":
126
  vision_pipeline = pipeline(
127
  task="image-segmentation",
128
  model="nvidia/segformer-b0-finetuned-ade-512-512",
129
  )
130
+
131
  elif model_key == "depth_estimation":
132
  vision_pipeline = pipeline(
133
  task="depth-estimation",
134
  model="Intel/dpt-hybrid-midas",
135
  )
136
+
137
  elif model_key == "captioning_blip_base":
138
  vision_pipeline = pipeline(
139
  task="image-to-text",
 
144
  task="image-to-text",
145
  model="Salesforce/blip-image-captioning-large",
146
  )
147
+
148
  elif model_key == "vqa_blip_base":
149
  vision_pipeline = pipeline(
150
  task="visual-question-answering",
 
155
  task="visual-question-answering",
156
  model="dandelin/vilt-b32-finetuned-vqa",
157
  )
158
+
159
  else:
160
  raise ValueError(f"Неизвестный тип визуальной модели: {model_key}")
161
 
 
177
 
178
  clip_model = CLIPModel.from_pretrained(clip_name)
179
  clip_processor = CLIPProcessor.from_pretrained(clip_name)
180
+
181
  MODEL_STORE[model_store_key_model] = clip_model
182
  MODEL_STORE[model_store_key_processor] = clip_processor
183
 
 
186
  return clip_model, clip_processor
187
 
188
 
189
+ def get_silero_tts_model():
190
+ if "silero_tts_model" not in MODEL_STORE:
191
+ silero_model, _ = torch.hub.load(
192
+ repo_or_dir="snakers4/silero-models",
193
+ model="silero_tts",
194
+ language="ru",
195
+ speaker="ru_v3",
196
+ )
197
+ MODEL_STORE["silero_tts_model"] = silero_model
198
+ return MODEL_STORE["silero_tts_model"]
199
+
200
+
201
+ def get_mms_tts_components():
202
+ if "mms_tts_pipeline" not in MODEL_STORE:
203
+ tts_pipeline = pipeline(
204
+ task="text-to-speech",
205
+ model="facebook/mms-tts-rus",
206
+ )
207
+ MODEL_STORE["mms_tts_pipeline"] = tts_pipeline
208
+
209
+ return MODEL_STORE["mms_tts_pipeline"]
210
+
211
+
212
  def get_sam_components() -> Tuple[SamModel, SamProcessor]:
213
  if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
214
  sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
215
  sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
216
  MODEL_STORE["sam_model"] = sam_model
217
  MODEL_STORE["sam_processor"] = sam_processor
218
+
219
  sam_model = MODEL_STORE["sam_model"]
220
  sam_processor = MODEL_STORE["sam_processor"]
221
  return sam_model, sam_processor
222
 
223
 
224
+
225
+ def classify_audio_file(audio_path: str, model_key: str) -> str:
226
+ audio_classifier = get_audio_pipeline(model_key)
227
+ prediction_list = audio_classifier(audio_path)
228
+
229
+ result_lines = ["Топ-5 предсказаний:"]
230
+ for prediction_index, prediction_item in enumerate(prediction_list[:5], start=1):
231
+ label_value = prediction_item["label"]
232
+ score_value = prediction_item["score"]
233
+ result_lines.append(
234
+ f"{prediction_index}. {label_value}: {score_value:.4f}"
235
+ )
236
+
237
+ return "\n".join(result_lines)
238
+
239
+
240
+ def classify_audio_zero_shot_clap(audio_path: str, label_texts: str) -> str:
241
+
242
+ clap_pipeline = get_zero_shot_audio_pipeline()
243
+
244
+ label_list = [
245
+ label_item.strip()
246
+ for label_item in label_texts.split(",")
247
+ if label_item.strip()
248
+ ]
249
+ if not label_list:
250
+ return "Не задано ни одной текстовой метки для zero-shot классификации."
251
+
252
+ prediction_list = clap_pipeline(
253
+ audio_path,
254
+ candidate_labels=label_list,
255
+ )
256
+
257
+ result_lines = ["Zero-Shot Audio Classification (CLAP):"]
258
+ for prediction_index, prediction_item in enumerate(prediction_list, start=1):
259
+ label_value = prediction_item["label"]
260
+ score_value = prediction_item["score"]
261
+ result_lines.append(
262
+ f"{prediction_index}. {label_value}: {score_value:.4f}"
263
+ )
264
+
265
+ return "\n".join(result_lines)
266
+
267
+
268
+ def recognize_speech(audio_path: str, model_key: str) -> str:
269
+ speech_pipeline = get_audio_pipeline(model_key)
270
+
271
+ prediction_result = speech_pipeline(audio_path)
272
+
273
+ return prediction_result["text"]
274
+
275
+
276
+ def synthesize_speech(text_value: str, model_key: str):
277
+ if model_key == "Google TTS":
278
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
279
+ text_to_speech_engine = gTTS(text=text_value, lang="ru")
280
+ text_to_speech_engine.save(file_object.name)
281
+ return file_object.name
282
+ elif model_key == "mms":
283
+ model = VitsModel.from_pretrained("facebook/mms-tts-rus")
284
+ tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
285
+
286
+ inputs = tokenizer(text_value, return_tensors="pt")
287
+ with torch.no_grad():
288
+ output = model(**inputs).waveform
289
+
290
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
291
+ sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
292
+ return f.name
293
+
294
+ raise ValueError(f"Неизвестная модель: {model_key}")
295
+
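
Usage sketch for both branches. One caveat worth noting: gTTS always encodes MP3, so the "Google TTS" branch writes MP3 data into a file merely named .wav, while the "mms" branch writes a real WAV via soundfile:

    wav_path = synthesize_speech("Привет, мир", "mms")           # VITS (facebook/mms-tts-rus) -> real WAV
    mp3_in_wav = synthesize_speech("Привет, мир", "Google TTS")  # gTTS -> MP3 bytes despite the .wav suffix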
296
+
297
+
298
  def detect_objects_on_image(image_object, model_key: str):
299
  detector_pipeline = get_vision_pipeline(model_key)
300
  detection_results = detector_pipeline(image_object)
 
301
 
302
+ drawer_object = ImageDraw.Draw(image_object)
303
  for detection_item in detection_results:
304
  box_data = detection_item["box"]
305
  label_value = detection_item["label"]
306
  score_value = detection_item["score"]
307
+
308
  drawer_object.rectangle(
309
  [
310
  box_data["xmin"],
 
320
  f"{label_value}: {score_value:.2f}",
321
  fill="red",
322
  )
323
+
324
  return image_object
325
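
For reference, each detection_item consumed by the drawing loop above has this shape (values are illustrative), and the function returns the same PIL image with boxes and labels drawn on it:

    # {"label": "cat", "score": 0.97, "box": {"xmin": 12, "ymin": 30, "xmax": 200, "ymax": 180}}
    annotated = detect_objects_on_image(Image.open("example.jpg").convert("RGB"), "object_detection_yolos_small")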
 
326
 
 
333
  def estimate_image_depth(image_object):
334
  depth_pipeline = get_vision_pipeline("depth_estimation")
335
  depth_output = depth_pipeline(image_object)
336
+
337
  predicted_depth_tensor = depth_output["predicted_depth"]
338
 
339
  if predicted_depth_tensor.ndim == 3:
 
351
  mode="bicubic",
352
  align_corners=False,
353
  )
354
+
355
  depth_array = resized_depth_tensor.squeeze().cpu().numpy()
356
  max_value = float(depth_array.max())
357
+
358
  if max_value <= 0.0:
359
  return Image.new("L", image_object.size, color=0)
360
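
Usage sketch (image path illustrative): the function returns a grayscale PIL image at the input resolution, with brighter pixels corresponding to larger predicted depth values:

    depth_map = estimate_image_depth(Image.open("example.jpg").convert("RGB"))
    depth_map.save("depth.png")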
 
 
372
  def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
373
  if image_object is None:
374
  return "Пожалуйста, сначала загрузите изображение."
375
+
376
  if not question_text.strip():
377
  return "Пожалуйста, введите вопрос об изображении."
378
 
379
  if model_key == "vqa_blip_base":
380
  blip_model, blip_processor = get_blip_vqa_components()
381
+
382
  inputs = blip_processor(
383
  images=image_object,
384
  text=question_text,
385
  return_tensors="pt",
386
  )
387
+
388
  with torch.no_grad():
389
  output_ids = blip_model.generate(**inputs)
390
+
391
  decoded_answers = blip_processor.batch_decode(
392
  output_ids,
393
  skip_special_tokens=True,
394
  )
395
  answer_text = decoded_answers[0] if decoded_answers else ""
396
+
397
  return answer_text or "Модель не смогла сгенерировать ответ."
398
 
399
  vqa_pipeline = get_vision_pipeline(model_key)
400
+
401
  vqa_result = vqa_pipeline(
402
  image=image_object,
403
  question=question_text,
404
  )
405
+
406
  top_item = vqa_result[0]
407
  answer_text = top_item["answer"]
408
  confidence_value = top_item["score"]
 
409
 
410
+ return f"{answer_text} (confidence: {confidence_value:.3f})"
411
 
412
  def perform_zero_shot_classification(
413
  image_object,
 
415
  clip_key: str,
416
  ) -> str:
417
  clip_model, clip_processor = get_clip_components(clip_key)
418
+
419
  class_list = [
420
  class_name.strip()
421
  for class_name in class_texts.split(",")
 
430
  return_tensors="pt",
431
  padding=True,
432
  )
433
+
434
  with torch.no_grad():
435
  clip_outputs = clip_model(**input_batch)
436
  logits_per_image = clip_outputs.logits_per_image
 
440
  for class_index, class_name in enumerate(class_list):
441
  probability_value = probability_tensor[0][class_index].item()
442
  result_lines.append(f"{class_name}: {probability_value:.4f}")
443
+
444
  return "\n".join(result_lines)
445
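
The same computation written directly against the cached CLIP components (labels and image path are illustrative):

    clip_model, clip_processor = get_clip_components("clip_base_patch32")
    batch = clip_processor(text=["a cat", "a dog"], images=Image.open("example.jpg"), return_tensors="pt", padding=True)
    with torch.no_grad():
        probabilities = clip_model(**batch).logits_per_image.softmax(dim=-1)
    print(probabilities)  # one row with one probability per label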
 
446
 
 
450
  clip_key: str,
451
  ) -> Tuple[str, Image.Image | None]:
452
  image_list = _normalize_gallery_images(gallery_value)
453
+
454
  if not image_list or not query_text.strip():
455
  return "Пожалуйста, загрузите изображения и введите запрос", None
456
 
457
  clip_model, clip_processor = get_clip_components(clip_key)
458
+
459
  image_inputs = clip_processor(
460
  images=image_list,
461
  return_tensors="pt",
 
463
  )
464
  with torch.no_grad():
465
  image_features = clip_model.get_image_features(**image_inputs)
466
+ image_features = image_features / image_features.norm(
467
+ dim=-1,
468
+ keepdim=True,
469
+ )
470
 
471
  text_inputs = clip_processor(
472
  text=[query_text],
 
475
  )
476
  with torch.no_grad():
477
  text_features = clip_model.get_text_features(**text_inputs)
478
+ text_features = text_features / text_features.norm(
479
+ dim=-1,
480
+ keepdim=True,
481
+ )
482
 
483
  similarity_tensor = image_features @ text_features.T
484
  best_index_tensor = similarity_tensor.argmax()
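
Because both feature matrices are L2-normalized above, the matrix product is cosine similarity; a self-contained sketch of the ranking step with random stand-in embeddings:

    import torch
    image_embeddings = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)  # 4 images
    text_embedding = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)    # 1 query
    best_index = (image_embeddings @ text_embedding.T).argmax().item()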
 
498
  ) -> Image.Image:
499
  if image_object is None:
500
  raise ValueError("Изображение не передано в segment_image_with_sam_points")
501
+
502
  if not point_coordinates_list:
503
  return Image.new("L", image_object.size, color=0)
504
 
505
  sam_model, sam_processor = get_sam_components()
506
+
507
  batched_points: List[List[List[int]]] = [point_coordinates_list]
508
  batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]]
509
 
 
513
  input_labels=batched_labels,
514
  return_tensors="pt",
515
  )
516
+
517
  with torch.no_grad():
518
  sam_outputs = sam_model(**sam_inputs, multimask_output=True)
519
 
 
522
  sam_inputs["original_sizes"].cpu(),
523
  sam_inputs["reshaped_input_sizes"].cpu(),
524
  )
525
+
526
  batch_masks_tensor = processed_masks_list[0]
527
+
528
  if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0:
529
  return Image.new("L", image_object.size, color=0)
530
 
531
  first_mask_tensor = batch_masks_tensor[0]
532
  mask_array = first_mask_tensor.numpy()
533
+
534
  binary_mask_array = (mask_array > 0.5).astype("uint8") * 255
535
+
536
  mask_image = Image.fromarray(binary_mask_array, mode="L")
537
  return mask_image
538
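
Usage sketch: points are (x, y) pixel coordinates treated as foreground clicks (label 1); the image path and coordinates are illustrative:

    mask_image = segment_image_with_sam_points(Image.open("example.jpg").convert("RGB"), [[320, 240]])
    mask_image.save("mask.png")  # binary L-mode mask, white where SlimSAM predicts the object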
 
539
 
540
  def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
541
+
542
  if image_object is None:
543
  return None
544
+
545
  coordinates_text_clean = coordinates_text.strip()
546
  if not coordinates_text_clean:
547
  return Image.new("L", image_object.size, color=0)
548
 
549
  point_coordinates_list: List[List[int]] = []
550
+
551
  for raw_pair in coordinates_text_clean.replace("\n", ";").split(";"):
552
  raw_pair_clean = raw_pair.strip()
553
  if not raw_pair_clean:
554
  continue
555
+
556
  parts = raw_pair_clean.split(",")
557
  if len(parts) != 2:
558
  continue
559
+
560
  try:
561
  x_value = int(parts[0].strip())
562
  y_value = int(parts[1].strip())
563
  except ValueError:
564
  continue
565
+
566
  point_coordinates_list.append([x_value, y_value])
567
 
568
  if not point_coordinates_list:
 
574
  def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
575
  if not coordinates_text.strip():
576
  return []
577
+
578
  point_list: List[List[int]] = []
579
  for raw_pair in coordinates_text.split(";"):
580
  cleaned_pair = raw_pair.strip()
 
589
  except ValueError:
590
  continue
591
  point_list.append([x_value, y_value])
 
592
 
593
+ return point_list
594
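
The expected text format is semicolon-separated "x,y" pairs; malformed pairs are silently skipped:

    assert parse_point_coordinates_text("100,200; 150,250") == [[100, 200], [150, 250]]
    assert parse_point_coordinates_text("oops; 10,20") == [[10, 20]]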
 
595
  def build_interface():
596
  with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo_block:
597
  gr.Markdown("# AI модели")
598
+
599
  with gr.Tab("Детекция объектов"):
600
  gr.Markdown("## Детекция объектов")
601
  with gr.Row():
 
612
  value="object_detection_conditional_detr",
613
  info=(
614
  "object_detection_conditional_detr - microsoft/conditional-detr-resnet-50\n"
615
+ "object_detection_yolos_small - hustvl/yolos-small"
616
  ),
617
  )
618
  object_detect_button = gr.Button("Применить")
619
+
620
+ object_output_image = gr.Image(
621
+ label="Результат",
622
+ )
623
+
624
  object_detect_button.click(
625
  fn=detect_objects_on_image,
626
  inputs=[object_input_image, object_model_selector],
 
635
  type="pil",
636
  )
637
  segmentation_button = gr.Button("Применить")
638
+
639
+ segmentation_output_image = gr.Image(
640
+ label="Маска",
641
+ )
642
+
643
  segmentation_button.click(
644
  fn=segment_image,
645
  inputs=segmentation_input_image,
 
649
  with gr.Tab("Глубина"):
650
  gr.Markdown("## Глубина (Depth Estimation)")
651
  with gr.Row():
652
+
653
  depth_input_image = gr.Image(
654
  label="Загрузите изображение",
655
  type="pil",
656
  )
657
  depth_button = gr.Button("Применить")
658
+
659
+ depth_output_image = gr.Image(
660
+ label="Глубины",
661
+ )
662
+
663
  depth_button.click(
664
  fn=estimate_image_depth,
665
  inputs=depth_input_image,
 
681
  label="Модель",
682
  value="captioning_blip_base",
683
  info=(
684
+ "captioning_blip_base - Salesforce/blip-image-captioning-base (курс)\n"
685
  "captioning_blip_large - Salesforce/blip-image-captioning-large"
686
  ),
687
  )
688
  caption_button = gr.Button("Применить")
689
+
690
+ caption_output_text = gr.Textbox(
691
+ label="Описание изображения",
692
+ lines=3,
693
+ )
694
+
695
  caption_button.click(
696
  fn=generate_image_caption,
697
  inputs=[caption_input_image, caption_model_selector],
 
719
  value="vqa_blip_base",
720
  info=(
721
  "vqa_blip_base - Salesforce/blip-vqa-base (курс)\n"
722
+ "vqa_vilt_b32 - dandelin/vilt-b32-finetuned-vqa"
723
  ),
724
  )
725
  vqa_button = gr.Button("Ответить на вопрос")
726
+
727
+ vqa_output_text = gr.Textbox(
728
+ label="Ответ",
729
+ lines=3,
730
+ )
731
+
732
  vqa_button.click(
733
  fn=answer_visual_question,
734
  inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
 
756
  value="clip_large_patch14",
757
  info=(
758
  "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
759
+ "clip_base_patch32 - openai/clip-vit-base-patch32"
760
  ),
761
  )
762
  zero_shot_button = gr.Button("Применить")
763
+
764
+ zero_shot_output_text = gr.Textbox(
765
+ label="Результаты",
766
+ lines=10,
767
+ )
768
+
769
  zero_shot_button.click(
770
  fn=perform_zero_shot_classification,
771
  inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector],
 
775
  with gr.Tab("Поиск изображений"):
776
  gr.Markdown("## Поиск изображений")
777
  with gr.Row():
778
+
779
  retrieval_dir = gr.File(
780
  label="Загрузите папку с изображениями",
781
  file_count="directory",
 
796
  value="clip_large_patch14",
797
  info=(
798
  "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
799
+ "clip_base_patch32 - openai/clip-vit-base-patch32 (альтернатива)"
800
  ),
801
  )
802
  retrieval_button = gr.Button("Поиск")
803
+
804
+ retrieval_output_text = gr.Textbox(
805
+ label="Результат",
806
+ )
807
+ retrieval_output_image = gr.Image(
808
+ label="Наиболее подходящее изображение",
809
+ )
810
+
811
  retrieval_button.click(
812
  fn=retrieve_best_image,
813
  inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector],
 
818
  gr.Markdown("### Задачи:")
819
  gr.Markdown(
820
  """
821
+ - Аудио: классификация, распознавание речи, синтез речи
822
+ - Компьютерное зрение: детекция объектов, сегментация, оценка глубины, генерация описаний изображений
823
  - Мультимодальные задачи: вопросы к изображению, zero-shot классификация изображений, поиск по изображениям по текстовому запросу
824
+ """
825
  )
 
826
  return demo_block
827
 
828
 
829
  if __name__ == "__main__":
830
  interface_block = build_interface()
831
  interface_block.launch(share=True)
832
+
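
A small note on the launch call above: share=True asks Gradio to open a temporary public *.gradio.live tunnel in addition to the local server; dropping the flag keeps the app local-only:

    # local-only alternative to launch(share=True)
    interface_block.launch()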