import gradio as gr import torch import nemo.collections.asr as nemo_asr from pydub import AudioSegment import os import logging from typing import Optional import threading # Konfiguracja logowania logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class TimeoutException(Exception): """Wyjątek dla timeoutu transkrypcji.""" pass class TranscriptionService: """Klasa do zarządzania modelami ASR na różnych urządzeniach.""" def __init__(self): # Usunięcie wstępnego ładowania. Modele będą ładowane dynamicznie self.models = { 'mps': None, 'cuda': None, 'cpu': None } self.model_name = "nvidia/parakeet-tdt-0.6b-v3" self.timeout_seconds = 300 # 5 minut timeout self.chunk_length_minutes = 5 # Dziel pliki dłuższe niż 5 minut def _get_optimal_device(self, audio_length_minutes: float) -> str: """ Wybiera optymalne urządzenie na podstawie długości audio i dostępności sprzętu. """ if torch.cuda.is_available(): logger.info("Używam CUDA (GPU) - najlepsza wydajność") return "cuda" if torch.backends.mps.is_available() and audio_length_minutes <= 8: logger.info(f"Plik krótki ({audio_length_minutes:.2f} min) - używam MPS") return "mps" if torch.backends.mps.is_available() and audio_length_minutes > 8: logger.info(f"Plik długi ({audio_length_minutes:.2f} min) - używam CPU zamiast MPS") else: logger.info("Brak GPU/MPS - używam CPU") return "cpu" def _load_model(self, device: str) -> nemo_asr.models.ASRModel: """ Ładuje model na określonym urządzeniu (z cache'owaniem). """ if self.models[device] is None: logger.info(f"Ładowanie modelu na {device.upper()}...") try: model = nemo_asr.models.ASRModel.from_pretrained( model_name=self.model_name ) self.models[device] = model.to(device) logger.info("Model załadowany pomyślnie") except Exception as e: logger.error(f"Błąd ładowania modelu na {device}: {e}") raise return self.models[device] def _split_audio(self, audio_file_path: str, chunk_length_ms: int) -> list: """ Dzieli długi plik audio na mniejsze fragmenty. """ audio = AudioSegment.from_file(audio_file_path) chunks = [] for i, chunk in enumerate(audio[::chunk_length_ms]): chunk_path = f"/tmp/temp_chunk_{i}.wav" chunk.export(chunk_path, format="wav") chunks.append(chunk_path) return chunks def _transcribe_with_timeout(self, audio_file_path: str, device: str) -> str: """ Wykonuje transkrypcję z timeoutem. """ # Ładowanie modelu przeniesione tutaj model = self._load_model(device) result = {"text": None, "error": None} def transcribe_worker(): try: transcriptions = model.transcribe([audio_file_path]) if transcriptions and len(transcriptions) > 0: result["text"] = transcriptions[0].text else: result["error"] = "Model nie zwrócił żadnej transkrypcji." except Exception as e: result["error"] = f"Błąd transkrypcji: {str(e)}" thread = threading.Thread(target=transcribe_worker) thread.start() thread.join(timeout=self.timeout_seconds) if thread.is_alive(): raise TimeoutException(f"Transkrypcja przekroczyła limit {self.timeout_seconds} sekund") if result["error"]: raise Exception(result["error"]) return result["text"] def transcribe(self, audio_file_path: str, progress=None) -> str: """ Główna funkcja transkrypcji. """ if not audio_file_path or not os.path.exists(audio_file_path): return "Błąd: Nie wybrano pliku audio lub plik nie istnieje." temp_files = [] try: logger.info(f"Analizuję plik: {os.path.basename(audio_file_path)}") audio = AudioSegment.from_file(audio_file_path) length_minutes = len(audio) / (1000 * 60) logger.info(f"Długość pliku: {length_minutes:.2f} minut") device = self._get_optimal_device(length_minutes) if length_minutes > self.chunk_length_minutes: if progress: progress(0.1, desc="Dzielę plik na fragmenty...") logger.info(f"Dzielę plik na fragmenty po {self.chunk_length_minutes} minut") chunk_length_ms = self.chunk_length_minutes * 60 * 1000 chunks = self._split_audio(audio_file_path, chunk_length_ms) temp_files.extend(chunks) logger.info(f"Transkrypcja {len(chunks)} fragmentów...") all_transcriptions = [] for i, chunk_path in enumerate(chunks): if progress: progress_value = 0.1 + (0.8 * (i + 1) / len(chunks)) progress(progress_value, desc=f"Transkrypcja fragmentu {i+1}/{len(chunks)}...") logger.info(f"Transkrypcja fragmentu {i+1}/{len(chunks)}...") chunk_text = self._transcribe_with_timeout(chunk_path, device) all_transcriptions.append(chunk_text) logger.info(f"Fragment {i+1} przetworzony") result_text = " ".join(all_transcriptions) else: if progress: progress(0.5, desc="Rozpoczynam transkrypcję...") logger.info("Rozpoczynam transkrypcję...") result_text = self._transcribe_with_timeout(audio_file_path, device) logger.info("Transkrypcja zakończona pomyślnie") return result_text except FileNotFoundError: error_msg = f"Błąd: Plik {audio_file_path} nie został znaleziony." logger.error(error_msg) return error_msg except TimeoutException as e: error_msg = f"Timeout: {str(e)}" logger.error(error_msg) return error_msg except Exception as e: error_msg = f"Wystąpił błąd podczas transkrypcji: {str(e)}" logger.error(error_msg) return error_msg finally: for temp_file in temp_files: try: os.remove(temp_file) except: pass # Globalna instancja serwisu transcription_service = TranscriptionService() def transcribe_audio_wrapper(audio_file_path: str, progress=gr.Progress()) -> str: """Wrapper dla Gradio - izoluje logikę od interfejsu.""" return transcription_service.transcribe(audio_file_path, progress) def create_interface() -> gr.Interface: """Tworzy i konfiguruje interfejs Gradio.""" return gr.Interface( fn=transcribe_audio_wrapper, inputs=gr.Audio( type="filepath", label="Wybierz plik audio", format="wav" # Opcjonalnie: wymuś konkretny format ), outputs=gr.Textbox( lines=10, label="Wynik transkrypcji", placeholder="Tutaj pojawi się transkrypcja..." ), title="🎤 Transkrypcja mowy na tekst", description=""" Wybierz plik audio, a model NVIDIA Parakeet wykona transkrypcję. **Obsługiwane formaty:** WAV, MP3, FLAC, M4A i inne **Optymalizacja urządzenia:** Automatyczny wybór GPU/CPU """, examples=None, cache_examples=False, flagging_options=None, allow_flagging="never" ) if __name__ == "__main__": logger.info("=== Informacje o systemie ===") logger.info(f"CUDA dostępne: {torch.cuda.is_available()}") logger.info(f"MPS dostępne: {torch.backends.mps.is_available()}") if torch.cuda.is_available(): logger.info(f"GPU: {torch.cuda.get_device_name(0)}") interface = create_interface() interface.launch( server_name="0.0.0.0", # Zmieniono z 127.0.0.1 server_port=7860, share=False, debug=False, show_error=True )