AlsuGibadullina committed
Commit dd3b111 · verified · 1 Parent(s): 4bb5cf0

Create app.py

Files changed (1)
  1. app.py +943 -0
app.py ADDED
@@ -0,0 +1,943 @@
1
+ import tempfile
2
+ from typing import List, Tuple, Any
3
+
4
+ import gradio as gr
5
+ import soundfile as sf
6
+ import torch
7
+ import torch.nn.functional as torch_functional
8
+ from gtts import gTTS
9
+ from PIL import Image, ImageDraw
10
+ from transformers import (
11
+ AutoTokenizer,
12
+ CLIPModel,
13
+ CLIPProcessor,
14
+ SamModel,
15
+ SamProcessor,
16
+ VitsModel,
17
+ pipeline,
18
+ BlipForQuestionAnswering,
19
+ BlipProcessor,
20
+ )
21
+
22
+
23
+ MODEL_STORE = {}
24
+
25
+ def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
26
+ if not gallery_value:
27
+ return []
28
+
29
+ normalized_images: List[Image.Image] = []
30
+
31
+ for item in gallery_value:
32
+ if isinstance(item, Image.Image):
33
+ normalized_images.append(item)
34
+ continue
35
+
36
+ if isinstance(item, str):
37
+ try:
38
+ image_object = Image.open(item).convert("RGB")
39
+ normalized_images.append(image_object)
40
+ except Exception:
41
+ continue
42
+ continue
43
+
44
+ if isinstance(item, (list, tuple)) and item:
45
+ candidate = item[0]
46
+ if isinstance(candidate, Image.Image):
47
+ normalized_images.append(candidate)
48
+ continue
49
+
50
+ if isinstance(item, dict):
51
+ candidate = item.get("image") or item.get("value")
52
+ if isinstance(candidate, Image.Image):
53
+ normalized_images.append(candidate)
54
+ continue
55
+
56
+ return normalized_images
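
# Illustrative only, not called by the app: a quick self-check of the helper
# above using in-memory Pillow images, so no files or models are needed.
def _demo_normalize_gallery_images() -> None:
    sample_image = Image.new("RGB", (8, 8), color="red")
    assert _normalize_gallery_images(None) == []
    assert _normalize_gallery_images([sample_image]) == [sample_image]
    assert _normalize_gallery_images([(sample_image, "caption")]) == [sample_image]
    assert _normalize_gallery_images([{"image": sample_image}]) == [sample_image]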
57
+
58
+ def get_audio_pipeline(model_key: str):
59
+ if model_key in MODEL_STORE:
60
+ return MODEL_STORE[model_key]
61
+
62
+ if model_key == "whisper":
63
+ audio_pipeline = pipeline(
64
+ task="automatic-speech-recognition",
65
+ model="distil-whisper/distil-small.en",
66
+ )
67
+ elif model_key == "wav2vec2":
68
+ audio_pipeline = pipeline(
69
+ task="automatic-speech-recognition",
70
+ model="openai/whisper-small",
71
+ )
72
+ elif model_key == "audio_classifier":
73
+ audio_pipeline = pipeline(
74
+ task="audio-classification",
75
+ model="MIT/ast-finetuned-audioset-10-10-0.4593",
76
+ )
77
+ elif model_key == "emotion_classifier":
78
+ audio_pipeline = pipeline(
79
+ task="audio-classification",
80
+ model="superb/hubert-large-superb-er",
81
+ )
82
+ else:
83
+ raise ValueError(f"Неизвестный тип аудио модели: {model_key}")
84
+
85
+ MODEL_STORE[model_key] = audio_pipeline
86
+ return audio_pipeline
87
+
88
+
89
+ def get_zero_shot_audio_pipeline():
90
+ if "audio_zero_shot_clap" not in MODEL_STORE:
91
+ zero_shot_pipeline = pipeline(
92
+ task="zero-shot-audio-classification",
93
+ model="laion/clap-htsat-unfused",
94
+ )
95
+ MODEL_STORE["audio_zero_shot_clap"] = zero_shot_pipeline
96
+ return MODEL_STORE["audio_zero_shot_clap"]
97
+
98
+
99
+ def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
100
+ if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
101
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
102
+ blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
103
+ MODEL_STORE["blip_vqa_model"] = blip_model
104
+ MODEL_STORE["blip_vqa_processor"] = blip_processor
105
+
106
+ blip_model = MODEL_STORE["blip_vqa_model"]
107
+ blip_processor = MODEL_STORE["blip_vqa_processor"]
108
+ return blip_model, blip_processor
109
+
110
+ def get_vision_pipeline(model_key: str):
111
+ if model_key in MODEL_STORE:
112
+ return MODEL_STORE[model_key]
113
+
114
+ if model_key == "object_detection_conditional_detr":
115
+ vision_pipeline = pipeline(
116
+ task="object-detection",
117
+ model="microsoft/conditional-detr-resnet-50",
118
+ )
119
+ elif model_key == "object_detection_yolos_small":
120
+ vision_pipeline = pipeline(
121
+ task="object-detection",
122
+ model="hustvl/yolos-small",
123
+ )
124
+
125
+ elif model_key == "segmentation":
126
+ vision_pipeline = pipeline(
127
+ task="image-segmentation",
128
+ model="nvidia/segformer-b0-finetuned-ade-512-512",
129
+ )
130
+
131
+ elif model_key == "depth_estimation":
132
+ vision_pipeline = pipeline(
133
+ task="depth-estimation",
134
+ model="Intel/dpt-hybrid-midas",
135
+ )
136
+
137
+ elif model_key == "captioning_blip_base":
138
+ vision_pipeline = pipeline(
139
+ task="image-to-text",
140
+ model="Salesforce/blip-image-captioning-base",
141
+ )
142
+ elif model_key == "captioning_blip_large":
143
+ vision_pipeline = pipeline(
144
+ task="image-to-text",
145
+ model="Salesforce/blip-image-captioning-large",
146
+ )
147
+
148
+ elif model_key == "vqa_blip_base":
149
+ vision_pipeline = pipeline(
150
+ task="visual-question-answering",
151
+ model="Salesforce/blip-vqa-base",
152
+ )
153
+ elif model_key == "vqa_vilt_b32":
154
+ vision_pipeline = pipeline(
155
+ task="visual-question-answering",
156
+ model="dandelin/vilt-b32-finetuned-vqa",
157
+ )
158
+
159
+ else:
160
+ raise ValueError(f"Неизвестный тип визуальной модели: {model_key}")
161
+
162
+ MODEL_STORE[model_key] = vision_pipeline
163
+ return vision_pipeline
164
+
165
+
166
+ def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
167
+ model_store_key_model = f"clip_model_{clip_key}"
168
+ model_store_key_processor = f"clip_processor_{clip_key}"
169
+
170
+ if model_store_key_model not in MODEL_STORE or model_store_key_processor not in MODEL_STORE:
171
+ if clip_key == "clip_large_patch14":
172
+ clip_name = "openai/clip-vit-large-patch14"
173
+ elif clip_key == "clip_base_patch32":
174
+ clip_name = "openai/clip-vit-base-patch32"
175
+ else:
176
+ raise ValueError(f"Неизвестный вариант CLIP модели: {clip_key}")
177
+
178
+ clip_model = CLIPModel.from_pretrained(clip_name)
179
+ clip_processor = CLIPProcessor.from_pretrained(clip_name)
180
+
181
+ MODEL_STORE[model_store_key_model] = clip_model
182
+ MODEL_STORE[model_store_key_processor] = clip_processor
183
+
184
+ clip_model = MODEL_STORE[model_store_key_model]
185
+ clip_processor = MODEL_STORE[model_store_key_processor]
186
+ return clip_model, clip_processor
187
+
188
+
189
+ def get_silero_tts_model():
190
+ if "silero_tts_model" not in MODEL_STORE:
191
+ silero_model, _ = torch.hub.load(
192
+ repo_or_dir="snakers4/silero-models",
193
+ model="silero_tts",
194
+ language="ru",
195
+ speaker="ru_v3",
196
+ )
197
+ MODEL_STORE["silero_tts_model"] = silero_model
198
+ return MODEL_STORE["silero_tts_model"]
199
+
200
+
201
+ def get_mms_tts_components():
202
+ if "mms_tts_pipeline" not in MODEL_STORE:
203
+ tts_pipeline = pipeline(
204
+ task="text-to-speech",
205
+ model="facebook/mms-tts-rus",
206
+ )
207
+ MODEL_STORE["mms_tts_pipeline"] = tts_pipeline
208
+
209
+ return MODEL_STORE["mms_tts_pipeline"]
210
+
211
+
212
+ def get_sam_components() -> Tuple[SamModel, SamProcessor]:
213
+ if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
214
+ sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
215
+ sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
216
+ MODEL_STORE["sam_model"] = sam_model
217
+ MODEL_STORE["sam_processor"] = sam_processor
218
+
219
+ sam_model = MODEL_STORE["sam_model"]
220
+ sam_processor = MODEL_STORE["sam_processor"]
221
+ return sam_model, sam_processor
222
+
223
+
224
+
225
+ def classify_audio_file(audio_path: str, model_key: str) -> str:
226
+ audio_classifier = get_audio_pipeline(model_key)
227
+ prediction_list = audio_classifier(audio_path)
228
+
229
+ result_lines = ["Топ-5 предсказаний:"]
230
+ for prediction_index, prediction_item in enumerate(prediction_list[:5], start=1):
231
+ label_value = prediction_item["label"]
232
+ score_value = prediction_item["score"]
233
+ result_lines.append(
234
+ f"{prediction_index}. {label_value}: {score_value:.4f}"
235
+ )
236
+
237
+ return "\n".join(result_lines)
238
+
239
+
240
+ def classify_audio_zero_shot_clap(audio_path: str, label_texts: str) -> str:
241
+
242
+ clap_pipeline = get_zero_shot_audio_pipeline()
243
+
244
+ label_list = [
245
+ label_item.strip()
246
+ for label_item in label_texts.split(",")
247
+ if label_item.strip()
248
+ ]
249
+ if not label_list:
250
+ return "Не задано ни одной текстовой метки для zero-shot классификации."
251
+
252
+ prediction_list = clap_pipeline(
253
+ audio_path,
254
+ candidate_labels=label_list,
255
+ )
256
+
257
+ result_lines = ["Zero-Shot Audio Classification (CLAP):"]
258
+ for prediction_index, prediction_item in enumerate(prediction_list, start=1):
259
+ label_value = prediction_item["label"]
260
+ score_value = prediction_item["score"]
261
+ result_lines.append(
262
+ f"{prediction_index}. {label_value}: {score_value:.4f}"
263
+ )
264
+
265
+ return "\n".join(result_lines)
266
+
267
+
268
+ def recognize_speech(audio_path: str, model_key: str) -> str:
269
+ speech_pipeline = get_audio_pipeline(model_key)
270
+
271
+ prediction_result = speech_pipeline(audio_path)
272
+
273
+ return prediction_result["text"]
274
+
275
+
276
+ def synthesize_speech(text_value: str, model_key: str):
277
+ if model_key == "Google TTS":
278
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
279
+ text_to_speech_engine = gTTS(text=text_value, lang="ru")
280
+ text_to_speech_engine.save(file_object.name)
281
+ return file_object.name
282
+ elif model_key == "mms":
283
+ model = VitsModel.from_pretrained("facebook/mms-tts-rus")
284
+ tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
285
+
286
+ inputs = tokenizer(text_value, return_tensors="pt")
287
+ with torch.no_grad():
288
+ output = model(**inputs).waveform
289
+
290
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
291
+ sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
292
+ return f.name
293
+
294
+ raise ValueError(f"Неизвестная модель: {model_key}")
295
+
296
+
297
+
298
+ def detect_objects_on_image(image_object, model_key: str):
299
+ detector_pipeline = get_vision_pipeline(model_key)
300
+ detection_results = detector_pipeline(image_object)
301
+
302
+ drawer_object = ImageDraw.Draw(image_object)
303
+ for detection_item in detection_results:
304
+ box_data = detection_item["box"]
305
+ label_value = detection_item["label"]
306
+ score_value = detection_item["score"]
307
+
308
+ drawer_object.rectangle(
309
+ [
310
+ box_data["xmin"],
311
+ box_data["ymin"],
312
+ box_data["xmax"],
313
+ box_data["ymax"],
314
+ ],
315
+ outline="red",
316
+ width=3,
317
+ )
318
+ drawer_object.text(
319
+ (box_data["xmin"], box_data["ymin"]),
320
+ f"{label_value}: {score_value:.2f}",
321
+ fill="red",
322
+ )
323
+
324
+ return image_object
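
# Illustrative only: the same ImageDraw calls on a blank canvas with a
# hand-written detection dict, to show the layout the detection pipeline
# returns ("box" with xmin/ymin/xmax/ymax plus "label" and "score").
def _demo_draw_detection_box() -> Image.Image:
    canvas = Image.new("RGB", (200, 120), color="white")
    fake_detection = {
        "label": "cat",
        "score": 0.93,
        "box": {"xmin": 20, "ymin": 15, "xmax": 180, "ymax": 100},
    }
    drawer = ImageDraw.Draw(canvas)
    box = fake_detection["box"]
    drawer.rectangle(
        [box["xmin"], box["ymin"], box["xmax"], box["ymax"]],
        outline="red",
        width=3,
    )
    drawer.text(
        (box["xmin"], box["ymin"]),
        f"{fake_detection['label']}: {fake_detection['score']:.2f}",
        fill="red",
    )
    return canvas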
325
+
326
+
327
+ def segment_image(image_object):
328
+ segmentation_pipeline = get_vision_pipeline("segmentation")
329
+ segmentation_results = segmentation_pipeline(image_object)
330
+ return segmentation_results[0]["mask"]
331
+
332
+
333
+ def estimate_image_depth(image_object):
334
+ depth_pipeline = get_vision_pipeline("depth_estimation")
335
+ depth_output = depth_pipeline(image_object)
336
+
337
+ predicted_depth_tensor = depth_output["predicted_depth"]
338
+
339
+ if predicted_depth_tensor.ndim == 3:
340
+ predicted_depth_tensor = predicted_depth_tensor.unsqueeze(1)
341
+ elif predicted_depth_tensor.ndim == 2:
342
+ predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0).unsqueeze(0)
343
+ else:
344
+ raise ValueError(
345
+ f"Неожиданная размерность predicted_depth: {predicted_depth_tensor.shape}"
346
+ )
347
+
348
+ resized_depth_tensor = torch_functional.interpolate(
349
+ predicted_depth_tensor,
350
+ size=image_object.size[::-1],
351
+ mode="bicubic",
352
+ align_corners=False,
353
+ )
354
+
355
+ depth_array = resized_depth_tensor.squeeze().cpu().numpy()
356
+ max_value = float(depth_array.max())
357
+
358
+ if max_value <= 0.0:
359
+ return Image.new("L", image_object.size, color=0)
360
+
361
+ normalized_depth_array = (depth_array * 255.0 / max_value).astype("uint8")
362
+ depth_image = Image.fromarray(normalized_depth_array, mode="L")
363
+ return depth_image
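
# Illustrative only: the depth map above is rescaled to 0-255 before being
# shown as a grayscale image; the same normalization on a toy tensor.
def _demo_depth_normalization() -> Image.Image:
    toy_depth = torch.tensor([[0.0, 1.0], [2.0, 4.0]])
    toy_array = (toy_depth * 255.0 / toy_depth.max()).numpy().astype("uint8")
    return Image.fromarray(toy_array, mode="L")  # 2x2 grayscale image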
364
+
365
+
366
+ def generate_image_caption(image_object, model_key: str) -> str:
367
+ caption_pipeline = get_vision_pipeline(model_key)
368
+ caption_result = caption_pipeline(image_object)
369
+ return caption_result[0]["generated_text"]
370
+
371
+
372
+ def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
373
+ if image_object is None:
374
+ return "Пожалуйста, сначала загрузите изображение."
375
+
376
+ if not question_text.strip():
377
+ return "Пожалуйста, введите вопрос об изображении."
378
+
379
+ if model_key == "vqa_blip_base":
380
+ blip_model, blip_processor = get_blip_vqa_components()
381
+
382
+ inputs = blip_processor(
383
+ images=image_object,
384
+ text=question_text,
385
+ return_tensors="pt",
386
+ )
387
+
388
+ with torch.no_grad():
389
+ output_ids = blip_model.generate(**inputs)
390
+
391
+ decoded_answers = blip_processor.batch_decode(
392
+ output_ids,
393
+ skip_special_tokens=True,
394
+ )
395
+ answer_text = decoded_answers[0] if decoded_answers else ""
396
+
397
+ return answer_text or "Модель не смогла сгенерировать ответ."
398
+
399
+ vqa_pipeline = get_vision_pipeline(model_key)
400
+
401
+ vqa_result = vqa_pipeline(
402
+ image=image_object,
403
+ question=question_text,
404
+ )
405
+
406
+ top_item = vqa_result[0]
407
+ answer_text = top_item["answer"]
408
+ confidence_value = top_item["score"]
409
+
410
+ return f"{answer_text} (confidence: {confidence_value:.3f})"
411
+
412
+ def perform_zero_shot_classification(
413
+ image_object,
414
+ class_texts: str,
415
+ clip_key: str,
416
+ ) -> str:
417
+ clip_model, clip_processor = get_clip_components(clip_key)
418
+
419
+ class_list = [
420
+ class_name.strip()
421
+ for class_name in class_texts.split(",")
422
+ if class_name.strip()
423
+ ]
424
+ if not class_list:
425
+ return "Не задано ни одного класса для классификации."
426
+
427
+ input_batch = clip_processor(
428
+ text=class_list,
429
+ images=image_object,
430
+ return_tensors="pt",
431
+ padding=True,
432
+ )
433
+
434
+ with torch.no_grad():
435
+ clip_outputs = clip_model(**input_batch)
436
+ logits_per_image = clip_outputs.logits_per_image
437
+ probability_tensor = logits_per_image.softmax(dim=1)
438
+
439
+ result_lines = ["Zero-Shot Classification Results:"]
440
+ for class_index, class_name in enumerate(class_list):
441
+ probability_value = probability_tensor[0][class_index].item()
442
+ result_lines.append(f"{class_name}: {probability_value:.4f}")
443
+
444
+ return "\n".join(result_lines)
445
+
446
+
447
+ def retrieve_best_image(
448
+ gallery_value: Any,
449
+ query_text: str,
450
+ clip_key: str,
451
+ ) -> Tuple[str, Image.Image | None]:
452
+ image_list = _normalize_gallery_images(gallery_value)
453
+
454
+ if not image_list or not query_text.strip():
455
+ return "Пожалуйста, загрузите изображения и введите запрос", None
456
+
457
+ clip_model, clip_processor = get_clip_components(clip_key)
458
+
459
+ image_inputs = clip_processor(
460
+ images=image_list,
461
+ return_tensors="pt",
462
+ padding=True,
463
+ )
464
+ with torch.no_grad():
465
+ image_features = clip_model.get_image_features(**image_inputs)
466
+ image_features = image_features / image_features.norm(
467
+ dim=-1,
468
+ keepdim=True,
469
+ )
470
+
471
+ text_inputs = clip_processor(
472
+ text=[query_text],
473
+ return_tensors="pt",
474
+ padding=True,
475
+ )
476
+ with torch.no_grad():
477
+ text_features = clip_model.get_text_features(**text_inputs)
478
+ text_features = text_features / text_features.norm(
479
+ dim=-1,
480
+ keepdim=True,
481
+ )
482
+
483
+ similarity_tensor = image_features @ text_features.T
484
+ best_index_tensor = similarity_tensor.argmax()
485
+ best_index_value = best_index_tensor.item()
486
+ best_score_value = similarity_tensor[best_index_value].item()
487
+
488
+ description_text = (
489
+ f"Лучшее изображение: #{best_index_value + 1} "
490
+ f"(схожесть: {best_score_value:.4f})"
491
+ )
492
+ return description_text, image_list[best_index_value]
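
# Illustrative only: the retrieval above boils down to cosine similarity
# between L2-normalized embeddings; the same arithmetic on toy vectors,
# without downloading CLIP.
def _demo_similarity_ranking() -> int:
    image_features = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
    text_features = torch.tensor([[0.6, 0.8]])
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = image_features @ text_features.T  # shape: (3 images, 1 query)
    return int(similarity.argmax().item())  # index of the best match (here: 2)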
493
+
494
+
495
+ def segment_image_with_sam_points(
496
+ image_object,
497
+ point_coordinates_list: List[List[int]],
498
+ ) -> Image.Image:
499
+ if image_object is None:
500
+ raise ValueError("Изображение не передано в segment_image_with_sam_points")
501
+
502
+ if not point_coordinates_list:
503
+ return Image.new("L", image_object.size, color=0)
504
+
505
+ sam_model, sam_processor = get_sam_components()
506
+
507
+ batched_points: List[List[List[int]]] = [point_coordinates_list]
508
+ batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]]
509
+
510
+ sam_inputs = sam_processor(
511
+ images=image_object,
512
+ input_points=batched_points,
513
+ input_labels=batched_labels,
514
+ return_tensors="pt",
515
+ )
516
+
517
+ with torch.no_grad():
518
+ sam_outputs = sam_model(**sam_inputs, multimask_output=True)
519
+
520
+ processed_masks_list = sam_processor.image_processor.post_process_masks(
521
+ sam_outputs.pred_masks.squeeze(1).cpu(),
522
+ sam_inputs["original_sizes"].cpu(),
523
+ sam_inputs["reshaped_input_sizes"].cpu(),
524
+ )
525
+
526
+ batch_masks_tensor = processed_masks_list[0]
527
+
528
+ if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0:
529
+ return Image.new("L", image_object.size, color=0)
530
+
531
+ first_mask_tensor = batch_masks_tensor[0]
532
+ mask_array = first_mask_tensor.numpy()
533
+
534
+ binary_mask_array = (mask_array > 0.5).astype("uint8") * 255
535
+
536
+ mask_image = Image.fromarray(binary_mask_array, mode="L")
537
+ return mask_image
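
# Illustrative only: SAM expects prompt points nested per image in the batch;
# for one image with two foreground clicks the structure looks like this.
def _demo_sam_point_batching() -> None:
    clicks = [[120, 80], [200, 150]]        # two clicks as [x, y] pairs
    batched_points = [clicks]               # a single image in the batch
    batched_labels = [[1 for _ in clicks]]  # label 1 marks a foreground point
    assert batched_points == [[[120, 80], [200, 150]]]
    assert batched_labels == [[1, 1]]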
538
+
539
+
540
+ def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
541
+
542
+ if image_object is None:
543
+ return None
544
+
545
+ coordinates_text_clean = coordinates_text.strip()
546
+ if not coordinates_text_clean:
547
+ return Image.new("L", image_object.size, color=0)
548
+
549
+ point_coordinates_list: List[List[int]] = []
550
+
551
+ for raw_pair in coordinates_text_clean.replace("\n", ";").split(";"):
552
+ raw_pair_clean = raw_pair.strip()
553
+ if not raw_pair_clean:
554
+ continue
555
+
556
+ parts = raw_pair_clean.split(",")
557
+ if len(parts) != 2:
558
+ continue
559
+
560
+ try:
561
+ x_value = int(parts[0].strip())
562
+ y_value = int(parts[1].strip())
563
+ except ValueError:
564
+ continue
565
+
566
+ point_coordinates_list.append([x_value, y_value])
567
+
568
+ if not point_coordinates_list:
569
+ return Image.new("L", image_object.size, color=0)
570
+
571
+ return segment_image_with_sam_points(image_object, point_coordinates_list)
572
+
573
+
574
+ def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
575
+ if not coordinates_text.strip():
576
+ return []
577
+
578
+ point_list: List[List[int]] = []
579
+ for raw_pair in coordinates_text.split(";"):
580
+ cleaned_pair = raw_pair.strip()
581
+ if not cleaned_pair:
582
+ continue
583
+ coordinate_parts = cleaned_pair.split(",")
584
+ if len(coordinate_parts) != 2:
585
+ continue
586
+ try:
587
+ x_value = int(coordinate_parts[0].strip())
588
+ y_value = int(coordinate_parts[1].strip())
589
+ except ValueError:
590
+ continue
591
+ point_list.append([x_value, y_value])
592
+
593
+ return point_list
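
# Illustrative only: the expected text format is "x,y" pairs separated by ";";
# malformed pairs are silently skipped.
def _demo_parse_point_coordinates() -> None:
    parsed = parse_point_coordinates_text("100,150; 200,250; oops; 10")
    assert parsed == [[100, 150], [200, 250]]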
594
+
595
+ def build_interface():
596
+ with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo_block:
597
+ gr.Markdown("# AI модели")
598
+
599
+ with gr.Tab("Классификация аудио"):
600
+ gr.Markdown("## Классификация аудио")
601
+ with gr.Row():
602
+ audio_input_component = gr.Audio(
603
+ label="Загрузите аудиофайл",
604
+ type="filepath",
605
+ )
606
+ audio_model_selector = gr.Dropdown(
607
+ choices=["audio_classifier", "emotion_classifier"],
608
+ label="Выберите модель",
609
+ value="audio_classifier",
610
+ info=(
611
+ "audio_classifier - общая классификация (курс)"
612
+ "emotion_classifier - эмоции в речи "
613
+ ),
614
+ )
615
+ audio_classify_button = gr.Button("Применить")
616
+
617
+ audio_output_component = gr.Textbox(
618
+ label="Результаты классификации",
619
+ lines=10,
620
+ )
621
+
622
+ audio_classify_button.click(
623
+ fn=classify_audio_file,
624
+ inputs=[audio_input_component, audio_model_selector],
625
+ outputs=audio_output_component,
626
+ )
627
+
628
+ with gr.Tab("Zero-Shot аудио"):
629
+ gr.Markdown("## Zero-Shot аудио классификатор")
630
+ with gr.Row():
631
+ clap_audio_input_component = gr.Audio(
632
+ label="Загрузите аудиофайл",
633
+ type="filepath",
634
+ )
635
+ clap_label_texts_component = gr.Textbox(
636
+ label="Кандидатные метки (через запятую)",
637
+ placeholder="лай собаки, шум дождя, музыка, разговор",
638
+ lines=2,
639
+ )
640
+ clap_button = gr.Button("Применить")
641
+
642
+ clap_output_component = gr.Textbox(
643
+ label="Результаты zero-shot классификации",
644
+ lines=10,
645
+ )
646
+
647
+ clap_button.click(
648
+ fn=classify_audio_zero_shot_clap,
649
+ inputs=[clap_audio_input_component, clap_label_texts_component],
650
+ outputs=clap_output_component,
651
+ )
652
+
653
+ with gr.Tab("Распознавание речи"):
654
+ gr.Markdown("## Распознавание реч")
655
+ with gr.Row():
656
+ asr_audio_input_component = gr.Audio(
657
+ label="Загрузите аудио с речью",
658
+ type="filepath",
659
+ )
660
+ asr_model_selector = gr.Dropdown(
661
+ choices=["whisper", "wav2vec2"],
662
+ label="Выберите модель",
663
+ value="whisper",
664
+ info=(
665
+ "whisper - distil-whisper/distil-small.en (курс),\n"
666
+ "wav2vec2 - openai/whisper-small"
667
+ ),
668
+ )
669
+ asr_button = gr.Button("Применить")
670
+
671
+ asr_output_component = gr.Textbox(
672
+ label="Транскрипция",
673
+ lines=5,
674
+ )
675
+
676
+ asr_button.click(
677
+ fn=recognize_speech,
678
+ inputs=[asr_audio_input_component, asr_model_selector],
679
+ outputs=asr_output_component,
680
+ )
681
+ with gr.Tab("Синтез речи"):
682
+ gr.Markdown("## Text-to-Speech")
683
+ with gr.Row():
684
+ tts_text_component = gr.Textbox(
685
+ label="Введите текст для синтеза",
686
+ placeholder="Введите текст на русском или английском языке...",
687
+ lines=3,
688
+ )
689
+ tts_model_selector = gr.Dropdown(
690
+ choices=["mms", "Google TTS"],
691
+ label="Выберите модель",
692
+ value="mms",
693
+ info=(
694
+ "facebook/mms-tts-rus\n"
695
+ "Google TTS"
696
+ ),
697
+ )
698
+ tts_button = gr.Button("Применить")
699
+
700
+ tts_audio_output_component = gr.Audio(
701
+ label="Синтезированная речь",
702
+ type="filepath",
703
+ )
704
+
705
+ tts_button.click(
706
+ fn=synthesize_speech,
707
+ inputs=[tts_text_component, tts_model_selector],
708
+ outputs=tts_audio_output_component,
709
+ )
710
+
711
+ with gr.Tab("Детекция объектов"):
712
+ gr.Markdown("## Детекция объектов")
713
+ with gr.Row():
714
+ object_input_image = gr.Image(
715
+ label="Загрузите изображение",
716
+ type="pil",
717
+ )
718
+ object_model_selector = gr.Dropdown(
719
+ choices=[
720
+ "object_detection_conditional_detr",
721
+ "object_detection_yolos_small",
722
+ ],
723
+ label="Модель",
724
+ value="object_detection_conditional_detr",
725
+ info=(
726
+ "object_detection_conditional_detr - microsoft/conditional-detr-resnet-50\n"
727
+ "object_detection_yolos_small - hustvl/yolos-small"
728
+ ),
729
+ )
730
+ object_detect_button = gr.Button("Применить")
731
+
732
+ object_output_image = gr.Image(
733
+ label="Результат",
734
+ )
735
+
736
+ object_detect_button.click(
737
+ fn=detect_objects_on_image,
738
+ inputs=[object_input_image, object_model_selector],
739
+ outputs=object_output_image,
740
+ )
741
+
742
+ with gr.Tab("Сегментация"):
743
+ gr.Markdown("## Сегментация")
744
+ with gr.Row():
745
+ segmentation_input_image = gr.Image(
746
+ label="Загрузите изображение",
747
+ type="pil",
748
+ )
749
+ segmentation_button = gr.Button("Применить")
750
+
751
+ segmentation_output_image = gr.Image(
752
+ label="Маска",
753
+ )
754
+
755
+ segmentation_button.click(
756
+ fn=segment_image,
757
+ inputs=segmentation_input_image,
758
+ outputs=segmentation_output_image,
759
+ )
760
+
761
+ with gr.Tab("Глубина"):
762
+ gr.Markdown("## Глубина (Depth Estimation)")
763
+ with gr.Row():
764
+
765
+ depth_input_image = gr.Image(
766
+ label="Загрузите изображение",
767
+ type="pil",
768
+ )
769
+ depth_button = gr.Button("Применить")
770
+
771
+ depth_output_image = gr.Image(
772
+ label="Глубины",
773
+ )
774
+
775
+ depth_button.click(
776
+ fn=estimate_image_depth,
777
+ inputs=depth_input_image,
778
+ outputs=depth_output_image,
779
+ )
780
+
781
+ with gr.Tab("Описание изображений"):
782
+ gr.Markdown("## Описание изображений")
783
+ with gr.Row():
784
+ caption_input_image = gr.Image(
785
+ label="Загрузите изображение",
786
+ type="pil",
787
+ )
788
+ caption_model_selector = gr.Dropdown(
789
+ choices=[
790
+ "captioning_blip_base",
791
+ "captioning_blip_large",
792
+ ],
793
+ label="Модель",
794
+ value="captioning_blip_base",
795
+ info=(
796
+ "captioning_blip_base - Salesforce/blip-image-captioning-base (курс)\n"
797
+ "captioning_blip_large - Salesforce/blip-image-captioning-large"
798
+ ),
799
+ )
800
+ caption_button = gr.Button("Применить")
801
+
802
+ caption_output_text = gr.Textbox(
803
+ label="Описание изображения",
804
+ lines=3,
805
+ )
806
+
807
+ caption_button.click(
808
+ fn=generate_image_caption,
809
+ inputs=[caption_input_image, caption_model_selector],
810
+ outputs=caption_output_text,
811
+ )
812
+
813
+ with gr.Tab("Визуальные вопросы"):
814
+ gr.Markdown("## Visual Question Answering")
815
+ with gr.Row():
816
+ vqa_input_image = gr.Image(
817
+ label="Загрузите изображение",
818
+ type="pil",
819
+ )
820
+ vqa_question_text = gr.Textbox(
821
+ label="Вопрос",
822
+ placeholder="Вопрос",
823
+ lines=2,
824
+ )
825
+ vqa_model_selector = gr.Dropdown(
826
+ choices=[
827
+ "vqa_blip_base",
828
+ "vqa_vilt_b32",
829
+ ],
830
+ label="Модель",
831
+ value="vqa_blip_base",
832
+ info=(
833
+ "vqa_blip_base - Salesforce/blip-vqa-base (курс)\n"
834
+ "vqa_vilt_b32 - dandelin/vilt-b32-finetuned-vqa"
835
+ ),
836
+ )
837
+ vqa_button = gr.Button("Ответить на вопрос")
838
+
839
+ vqa_output_text = gr.Textbox(
840
+ label="Ответ",
841
+ lines=3,
842
+ )
843
+
844
+ vqa_button.click(
845
+ fn=answer_visual_question,
846
+ inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
847
+ outputs=vqa_output_text,
848
+ )
849
+
850
+ with gr.Tab("Zero-Shot классификация"):
851
+ gr.Markdown("## Zero-Shot классификация")
852
+ with gr.Row():
853
+ zero_shot_input_image = gr.Image(
854
+ label="Загрузите изображение",
855
+ type="pil",
856
+ )
857
+ zero_shot_classes_text = gr.Textbox(
858
+ label="Классы для классификации (через запятую)",
859
+ placeholder="человек, машина, дерево, здание, животное",
860
+ lines=2,
861
+ )
862
+ clip_model_selector = gr.Dropdown(
863
+ choices=[
864
+ "clip_large_patch14",
865
+ "clip_base_patch32",
866
+ ],
867
+ label="модель",
868
+ value="clip_large_patch14",
869
+ info=(
870
+ "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
871
+ "clip_base_patch32 - openai/clip-vit-base-patch32"
872
+ ),
873
+ )
874
+ zero_shot_button = gr.Button("Применить")
875
+
876
+ zero_shot_output_text = gr.Textbox(
877
+ label="Результаты",
878
+ lines=10,
879
+ )
880
+
881
+ zero_shot_button.click(
882
+ fn=perform_zero_shot_classification,
883
+ inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector],
884
+ outputs=zero_shot_output_text,
885
+ )
886
+
887
+ with gr.Tab("Поиск изображений"):
888
+ gr.Markdown("## Поиск изображений")
889
+ with gr.Row():
890
+
891
+ retrieval_dir = gr.File(
892
+ label="Загрузите папку с изображениями",
893
+ file_count="directory",
894
+ file_types=["image"],
895
+ type="filepath",
896
+ )
897
+ retrieval_query_text = gr.Textbox(
898
+ label="Текстовый запрос",
899
+ placeholder="описание того, что вы ищете...",
900
+ lines=2,
901
+ )
902
+ retrieval_clip_selector = gr.Dropdown(
903
+ choices=[
904
+ "clip_large_patch14",
905
+ "clip_base_patch32",
906
+ ],
907
+ label="модель",
908
+ value="clip_large_patch14",
909
+ info=(
910
+ "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
911
+ "clip_base_patch32 - openai/clip-vit-base-patch32 (альтернатива)"
912
+ ),
913
+ )
914
+ retrieval_button = gr.Button("Поиск")
915
+
916
+ retrieval_output_text = gr.Textbox(
917
+ label="Результат",
918
+ )
919
+ retrieval_output_image = gr.Image(
920
+ label="Наиболее подходящее изображение",
921
+ )
922
+
923
+ retrieval_button.click(
924
+ fn=retrieve_best_image,
925
+ inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector],
926
+ outputs=[retrieval_output_text, retrieval_output_image],
927
+ )
928
+
929
+ gr.Markdown("---")
930
+ gr.Markdown("### Задачи:")
931
+ gr.Markdown(
932
+ """
933
+ - Аудио: классификация, распознавание речи, синтез речи
934
+ - Компьютерное зрение: детекция объектов, сегментация, оценка глубины, генерация описаний изображений
935
+ - Мультимодальные задачи: вопросы к изображению, zero-shot классификация изображений, поиск по изображениям по текстовому запросу
936
+ """
937
+ )
938
+ return demo_block
939
+
940
+
941
+ if __name__ == "__main__":
942
+ interface_block = build_interface()
943
+ interface_block.launch(share=True)