"""Voice-assistant HTTP API: speech-to-text -> language model -> text-to-speech.

All three models (Whisper, a small Spanish causal LM, Coqui TTS) are loaded
once at import time and shared by every request. The endpoints are intended
for a small embedded client (an ESP32, per the original notes).
"""

import io  # noqa: F401  (currently unused in this module)
import os
import tempfile
import threading

import numpy as np  # noqa: F401  (currently unused in this module)
import soundfile as sf  # noqa: F401  (currently unused in this module)
import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, Response, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, StreamingResponse  # noqa: F401
from pydantic import BaseModel
from transformers import (  # noqa: F401  (pipeline is currently unused)
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    pipeline,
)
from TTS.api import TTS

app = FastAPI(title="Asistente de Voz API")

WHISPER_MODEL_ID = "openai/whisper-small"
LLM_MODEL_ID = "DeepESP/gpt2-spanish"
TTS_MODEL_ID = "tts_models/es/css10/vits"

# Audio is resampled to this rate before being handed to Whisper.
TARGET_SAMPLE_RATE = 16000
# Maximum characters of generated answer returned / spoken.
MAX_ANSWER_CHARS = 200
# Maximum characters accepted by /tts (longer text is truncated).
MAX_TTS_CHARS = 300
# Default number of tokens the language model may generate for an answer.
DEFAULT_MAX_NEW_TOKENS = 100
# Spoken by /complete when the language model produced no usable text.
EMPTY_ANSWER_FALLBACK = "Lo siento, no tengo una respuesta"


# ============================================
# MODEL LOADING (runs once, at import time)
# ============================================
print("🔄 Cargando modelos...")

# 1. WHISPER (speech-to-text)
print("📝 Cargando Whisper...")
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID)
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_ID)
whisper_model.eval()

# 2. LANGUAGE MODEL (conversational)
print("🤖 Cargando modelo de lenguaje...")
# Small Spanish causal LM. A larger option such as google/flan-t5-base is a
# seq2seq model: it would need AutoModelForSeq2SeqLM and a different decoding
# path than the causal-LM code in _generate_answer below.
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
llm_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_ID)
llm_model.eval()

# 3. TTS (text-to-speech) - Coqui TTS with a Spanish model, CPU only.
print("🔊 Cargando TTS...")
tts = TTS(model_name=TTS_MODEL_ID, progress_bar=False, gpu=False)

print("✅ Todos los modelos cargados!\n")

# Inference runs in worker threads so the event loop stays responsive; this
# lock keeps model usage one-request-at-a-time, as it was when inference ran
# directly on the event loop.
_model_lock = threading.Lock()


# ============================================
# REQUEST MODELS
# ============================================
class ChatRequest(BaseModel):
    """Body of ``POST /chat``."""

    question: str
    # Maximum number of tokens to generate for the answer (the prompt is not
    # counted; the value is further clamped to the model's context window).
    max_length: int = DEFAULT_MAX_NEW_TOKENS


class TTSRequest(BaseModel):
    """Body of ``POST /tts``."""

    text: str


# ============================================
# INTERNAL HELPERS (blocking; call via run_in_threadpool)
# ============================================
def _transcribe_bytes(audio_bytes: bytes) -> str:
    """Transcribe an encoded audio file (e.g. WAV) to text with Whisper.

    The bytes go through a temporary file for ``torchaudio.load``; the file
    is removed even when decoding fails.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(audio_bytes)
        tmp_path = tmp.name
    try:
        waveform, sample_rate = torchaudio.load(tmp_path)
    finally:
        os.unlink(tmp_path)

    if sample_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
        waveform = resampler(waveform)
    # torchaudio.load returns (channels, frames); average channels to mono.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    input_features = whisper_processor(
        waveform.squeeze().numpy(),
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
    ).input_features

    with _model_lock, torch.no_grad():
        predicted_ids = whisper_model.generate(input_features)
    return whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]


def _generate_answer(question: str, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS) -> str:
    """Generate an (untruncated) answer to *question* with the language model.

    Only the newly generated tokens are decoded, so the prompt - including
    any "Respuesta:" the user happened to say - never leaks into the answer.
    If the model continues with an invented "Pregunta:" turn, that tail is
    dropped. The result may be an empty string.
    """
    prompt = f"Pregunta: {question}\nRespuesta:"
    encoded = llm_tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    prompt_len = input_ids.shape[-1]

    # max_new_tokens (unlike max_length) excludes the prompt, so long
    # questions still leave room for an answer. Stay inside the context window.
    context_size = getattr(llm_model.config, "max_position_embeddings", None)
    if context_size:
        max_new_tokens = min(max_new_tokens, context_size - prompt_len)
    max_new_tokens = max(1, max_new_tokens)

    with _model_lock, torch.no_grad():
        outputs = llm_model.generate(
            input_ids=input_ids,
            attention_mask=encoded["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=llm_tokenizer.eos_token_id,
        )

    generated = llm_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return generated.split("Pregunta:")[0].strip()


def _synthesize(text: str) -> bytes:
    """Render *text* to WAV bytes with Coqui TTS (via a temporary file)."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp_path = tmp.name
    try:
        with _model_lock:
            tts.tts_to_file(text=text, file_path=tmp_path)
        with open(tmp_path, "rb") as f:
            return f.read()
    finally:
        os.unlink(tmp_path)


def _header_value(text: str) -> str:
    """Make arbitrary text safe to send as an HTTP header value.

    Line breaks are not allowed in header values and Starlette encodes
    header values as latin-1, so whitespace runs are collapsed to single
    spaces and characters outside latin-1 become "?".
    """
    collapsed = " ".join(text.split())
    return collapsed.encode("latin-1", errors="replace").decode("latin-1")


# ============================================
# ENDPOINT 1: TRANSCRIPTION (speech-to-text)
# ============================================
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Convert an uploaded WAV file to text using Whisper.

    Returns ``{"text": ..., "success": True}``. Any failure is reported as
    HTTP 500 with the error message as the detail.
    """
    try:
        print(f"📥 Recibiendo audio: {file.filename}")
        audio_bytes = await file.read()
        # CPU-heavy inference runs off the event loop.
        transcription = await run_in_threadpool(_transcribe_bytes, audio_bytes)
        print(f"✅ Transcrito: {transcription}")
        return JSONResponse({
            "text": transcription,
            "success": True
        })
    except Exception as e:
        print(f"❌ Error en transcripción: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# ============================================
# ENDPOINT 2: CHAT (conversational AI)
# ============================================
@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer a text question with the language model.

    Always responds with HTTP 200. ``success`` is False for an empty
    question or when generation fails (then ``error`` carries the message).
    Answers longer than MAX_ANSWER_CHARS are truncated and get "...".
    """
    try:
        question = request.question.strip()
        print(f"💬 Pregunta: {question}")

        if not question:
            return JSONResponse({
                "answer": "No escuché ninguna pregunta",
                "success": False
            })

        answer = await run_in_threadpool(_generate_answer, question, request.max_length)
        if len(answer) > MAX_ANSWER_CHARS:
            answer = answer[:MAX_ANSWER_CHARS] + "..."

        print(f"✅ Respuesta: {answer}")
        return JSONResponse({
            "answer": answer,
            "success": True
        })
    except Exception as e:
        print(f"❌ Error en chat: {str(e)}")
        return JSONResponse({
            "answer": "Lo siento, tuve un error al procesar tu pregunta",
            "success": False,
            "error": str(e)
        })


# ============================================
# ENDPOINT 3: TEXT-TO-SPEECH
# ============================================
@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """Convert text to a WAV attachment using Coqui TTS.

    Raises HTTP 400 for empty text and HTTP 500 if synthesis fails. Text
    longer than MAX_TTS_CHARS is truncated (with "...") to avoid timeouts.
    """
    text = request.text.strip()
    print(f"🔊 Generando voz para: {text[:50]}...")

    # Validated outside the try below so the client really receives a 400
    # instead of having it rewrapped as a 500.
    if not text:
        raise HTTPException(status_code=400, detail="Texto vacío")

    if len(text) > MAX_TTS_CHARS:
        text = text[:MAX_TTS_CHARS] + "..."

    try:
        audio_data = await run_in_threadpool(_synthesize, text)
    except Exception as e:
        print(f"❌ Error en TTS: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    print(f"✅ Audio generado: {len(audio_data)} bytes")
    return Response(
        content=audio_data,
        media_type="audio/wav",
        headers={
            "Content-Disposition": "attachment; filename=speech.wav"
        }
    )


# ============================================
# ENDPOINT 4: FULL PIPELINE (OPTIONAL)
# ============================================
@app.post("/complete")
async def complete_conversation(file: UploadFile = File(...)):
    """Full pipeline: audio -> text -> AI answer -> audio.

    Simpler alternative for the ESP32: returns the spoken answer as WAV,
    with the transcription and answer text in the ``X-Transcription`` and
    ``X-Answer`` headers (sanitized to be valid header values). Any failure
    is reported as HTTP 500.
    """
    try:
        print("🔄 Iniciando proceso completo...")

        # 1. Transcribe
        audio_bytes = await file.read()
        transcription = await run_in_threadpool(_transcribe_bytes, audio_bytes)
        print(f"✅ Transcrito: {transcription}")

        # 2. Generate an answer (same empty-question reply as /chat)
        question = transcription.strip()
        if question:
            answer = await run_in_threadpool(_generate_answer, question)
            answer = answer[:MAX_ANSWER_CHARS]
        else:
            answer = "No escuché ninguna pregunta"
        if not answer:
            # TTS cannot synthesize an empty string.
            answer = EMPTY_ANSWER_FALLBACK
        print(f"✅ Respuesta: {answer}")

        # 3. Synthesize audio
        audio_data = await run_in_threadpool(_synthesize, answer)
        print("✅ Proceso completo!")

        return Response(
            content=audio_data,
            media_type="audio/wav",
            headers={
                "X-Transcription": _header_value(transcription),
                "X-Answer": _header_value(answer)
            }
        )
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# ============================================
# UTILITY ENDPOINTS
# ============================================
@app.get("/")
async def root():
    """Describe the API and list its endpoints."""
    return {
        "message": "🤖 API Asistente de Voz",
        "version": "1.0",
        "endpoints": {
            "/transcribe": "POST - Audio WAV → Texto",
            "/chat": "POST - Pregunta → Respuesta IA",
            "/tts": "POST - Texto → Audio",
            "/complete": "POST - Audio → Audio (proceso completo)"
        }
    }


@app.get("/health")
async def health_check():
    """Liveness probe; models are loaded at import, so this is static."""
    return {
        "status": "ok",
        "models": {
            "whisper": "loaded",
            "llm": "loaded",
            "tts": "loaded"
        }
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)