Spaces:
Running
Running
| """ | |
| REST API Server for Multi-lingual TTS | |
| FastAPI-based server with OpenAPI documentation | |
| Hackathon API Specification: | |
| - GET /Get_Inference with text, lang, speaker_wav parameters | |
| """ | |
| import os | |
| import io | |
| import time | |
| import logging | |
| import tempfile | |
| from typing import Optional, List | |
| from pathlib import Path | |
| import numpy as np | |
| from fastapi import ( | |
| FastAPI, | |
| HTTPException, | |
| Query, | |
| Response, | |
| BackgroundTasks, | |
| UploadFile, | |
| File, | |
| ) | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse, FileResponse, JSONResponse | |
| from pydantic import BaseModel, Field | |
| import soundfile as sf | |
| from .engine import TTSEngine, TTSOutput | |
| from .config import ( | |
| LANGUAGE_CONFIGS, | |
| get_available_languages, | |
| get_available_voices, | |
| STYLE_PRESETS, | |
| ) | |
# Hackathon API lookup table: lowercase language name -> voice key.
LANG_TO_VOICE = dict(
    hindi="hi_female",
    bengali="bn_female",
    marathi="mr_female",
    telugu="te_female",
    kannada="kn_female",
    bhojpuri="bho_female",
    chhattisgarhi="hne_female",
    maithili="mai_female",
    magahi="mag_female",
    english="en_female",
    gujarati="gu_mms",
)
# Logging setup: request INFO-level root logging at import time and create
# the module-level logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize the FastAPI application. The title, description, version,
# contact and license values below are the API metadata handed to FastAPI.
# The "Supported Languages" list includes Gujarati because the hackathon
# language mapping (LANG_TO_VOICE) accepts it.
app = FastAPI(
    title="Voice Tech for All - Multi-lingual TTS API",
    description="""
A multi-lingual Text-to-Speech API supporting 10+ Indian languages.
## Features
- 10 Indian languages with male/female voices
- Real-time speech synthesis
- Text normalization for Indian languages
- Speed control
- Multiple audio formats (WAV, MP3)
## Supported Languages
Hindi, Bengali, Marathi, Telugu, Kannada, Bhojpuri,
Chhattisgarhi, Maithili, Magahi, Gujarati, English
## Use Case
Built for an LLM-based healthcare assistant for pregnant mothers
in low-income communities.
""",
    version="1.0.0",
    contact={
        "name": "Voice Tech for All Hackathon",
        "url": "https://huggingface.co/SYSPIN",
    },
    license_info={
        "name": "CC BY 4.0",
        "url": "https://creativecommons.org/licenses/by/4.0/",
    },
)
# CORS middleware: every origin, method and header is allowed, with
# credentials enabled.
# NOTE(review): a wildcard origin combined with allow_credentials=True should
# be checked against the FastAPI/Starlette CORS documentation — confirm that
# credentialed cross-origin requests are actually required here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Shared TTS engine instance; stays None until the first get_engine() call
# (lazy loading).
_engine: Optional[TTSEngine] = None


def get_engine() -> TTSEngine:
    """Return the shared TTS engine, creating it on the first call.

    The instance is stored in the module-level ``_engine`` variable, so
    subsequent calls return the same object.
    """
    global _engine
    if _engine is not None:
        return _engine
    _engine = TTSEngine(device="auto")
    return _engine
| # Request/Response Models | |
class SynthesizeRequest(BaseModel):
    """Request body for text synthesis"""

    # Required input text; length must be between 1 and 5000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=5000,
        description="Text to synthesize",
    )
    # Voice key; "hi_male" is used when the caller omits it.
    voice: str = Field(
        "hi_male",
        description="Voice key (e.g., hi_male, bn_female, gu_mms)",
    )
    # Prosody multipliers: each defaults to 1.0 and must lie in [0.5, 2.0].
    speed: float = Field(1.0, ge=0.5, le=2.0, description="Speech speed (0.5-2.0)")
    pitch: float = Field(1.0, ge=0.5, le=2.0, description="Pitch multiplier (0.5-2.0)")
    energy: float = Field(1.0, ge=0.5, le=2.0, description="Energy/volume (0.5-2.0)")
    # Optional style preset name; None means no preset was requested.
    style: Optional[str] = Field(
        None,
        description="Style preset (happy, sad, calm, excited, etc.)",
    )
    # Text normalization is applied unless the caller turns it off.
    normalize: bool = Field(True, description="Apply text normalization")

    class Config:
        # Example payload attached to the model's schema.
        schema_extra = {
            "example": {
                "text": "નમસ્તે, હું તમારી કેવી રીતે મદદ કરી શકું?",
                "voice": "gu_mms",
                "speed": 1.0,
                "pitch": 1.0,
                "energy": 1.0,
                "style": "calm",
                "normalize": True,
            }
        }
class SynthesizeResponse(BaseModel):
    """Response metadata for synthesis"""

    # Outcome flag for the synthesis request.
    success: bool
    # Audio properties of the generated clip (units are not stated in this file).
    duration: float
    sample_rate: int
    # Echo of what was synthesized: the voice key and the input text.
    voice: str
    text: str
    # Time spent on inference (unit not stated in this file).
    inference_time: float
class VoiceInfo(BaseModel):
    """Information about a voice"""

    # Identity of the voice: its key, display name, language code and gender.
    key: str
    name: str
    language_code: str
    gender: str
    # Availability flags as reported by the engine's voice listing.
    loaded: bool
    downloaded: bool
    # Model family; "vits" when no other type is supplied.
    model_type: str = "vits"
class HealthResponse(BaseModel):
    """Health check response"""

    # Overall service status string.
    status: str
    # Device the TTS engine reports, rendered as text.
    device: str
    # Keys of the voices currently loaded in the engine.
    loaded_voices: List[str]
    # Count of configured voices.
    available_voices: int
    # Names of the available style presets.
    style_presets: List[str]
| # API Endpoints | |
async def root():
    """API root - welcome message with pointers to the main endpoints."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file — confirm it is registered on `app`.
    welcome = {"message": "Voice Tech for All - Multi-lingual TTS API"}
    welcome.update(docs="/docs", health="/health", synthesize="/synthesize")
    return welcome
async def health_check():
    """Health check endpoint reporting engine device, voices and style presets."""
    tts = get_engine()
    # Collect the report fields first, then build the response model from them.
    report = {
        "status": "healthy",
        "device": str(tts.device),
        "loaded_voices": tts.get_loaded_voices(),
        "available_voices": len(LANGUAGE_CONFIGS),
        "style_presets": list(STYLE_PRESETS.keys()),
    }
    return HealthResponse(**report)
def _voice_info(voice_key, meta):
    """Build a VoiceInfo from one entry of the engine's voice listing."""
    return VoiceInfo(
        key=voice_key,
        name=meta["name"],
        language_code=meta["code"],
        gender=meta["gender"],
        loaded=meta["loaded"],
        downloaded=meta["downloaded"],
        # Entries without an explicit type are reported as "vits".
        model_type=meta.get("type", "vits"),
    )


async def list_voices():
    """List all available voices"""
    catalog = get_engine().get_available_voices()
    return [_voice_info(voice_key, meta) for voice_key, meta in catalog.items()]
async def list_styles():
    """List available style presets for prosody control"""
    # Human-readable explanation of each prosody parameter.
    parameter_help = {
        "speed": "Speech rate multiplier (0.5-2.0)",
        "pitch": "Pitch multiplier (0.5-2.0), >1 = higher",
        "energy": "Volume/energy multiplier (0.5-2.0)",
    }
    return {"presets": STYLE_PRESETS, "description": parameter_help}
async def list_languages():
    """List supported languages"""
    # The listing comes straight from the config helper; nothing is reshaped here.
    languages = get_available_languages()
    return languages
async def synthesize_audio(request: SynthesizeRequest):
    """
    Synthesize speech from text

    Returns WAV audio file directly. Synthesis metadata is attached as the
    X-Duration, X-Sample-Rate, X-Voice, X-Style and X-Inference-Time
    response headers.

    Raises:
        HTTPException: 400 when the requested voice is not a key of
            LANGUAGE_CONFIGS; 500 when synthesis or WAV encoding fails.
    """
    engine = get_engine()
    # Validate voice
    if request.voice not in LANGUAGE_CONFIGS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown voice: {request.voice}. Use /voices to see available options.",
        )
    try:
        # perf_counter() is monotonic, so the measured interval is not
        # affected by system clock adjustments (unlike time.time()).
        start_time = time.perf_counter()
        # Synthesize
        output = engine.synthesize(
            text=request.text,
            voice=request.voice,
            speed=request.speed,
            pitch=request.pitch,
            energy=request.energy,
            style=request.style,
            normalize_text=request.normalize,
        )
        inference_time = time.perf_counter() - start_time
        # Convert to WAV bytes; getvalue() returns the whole buffer without
        # needing a seek/read pair.
        buffer = io.BytesIO()
        sf.write(buffer, output.audio, output.sample_rate, format="WAV")
        # Return audio with metadata headers
        return Response(
            content=buffer.getvalue(),
            media_type="audio/wav",
            headers={
                "X-Duration": str(output.duration),
                "X-Sample-Rate": str(output.sample_rate),
                "X-Voice": output.voice,
                "X-Style": output.style or "default",
                "X-Inference-Time": str(inference_time),
            },
        )
    except Exception as e:
        # logger.exception keeps the traceback in the log; chaining keeps the
        # original error attached to the HTTP error.
        logger.exception("Synthesis error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def synthesize_stream(request: SynthesizeRequest):
    """
    Synthesize speech and stream the audio

    Returns streaming WAV audio. The clip is fully synthesized and encoded
    into an in-memory buffer first; the response then streams that buffer as
    an attachment named speech.wav.

    Raises:
        HTTPException: 400 when the requested voice is not a key of
            LANGUAGE_CONFIGS; 500 when synthesis or WAV encoding fails.
    """
    engine = get_engine()
    if request.voice not in LANGUAGE_CONFIGS:
        raise HTTPException(status_code=400, detail=f"Unknown voice: {request.voice}")
    try:
        output = engine.synthesize(
            text=request.text,
            voice=request.voice,
            speed=request.speed,
            pitch=request.pitch,
            energy=request.energy,
            style=request.style,
            normalize_text=request.normalize,
        )
        # Create streaming response; rewind so streaming starts at byte 0.
        buffer = io.BytesIO()
        sf.write(buffer, output.audio, output.sample_rate, format="WAV")
        buffer.seek(0)
        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=speech.wav"},
        )
    except Exception as e:
        # Keep the traceback in the log and chain the original error.
        logger.exception("Streaming error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def synthesize_get(
    text: str = Query(
        ..., description="Text to synthesize", min_length=1, max_length=1000
    ),
    voice: str = Query("hi_male", description="Voice key"),
    speed: float = Query(1.0, description="Speech speed", ge=0.5, le=2.0),
    pitch: float = Query(1.0, description="Pitch", ge=0.5, le=2.0),
    energy: float = Query(1.0, description="Energy", ge=0.5, le=2.0),
    style: Optional[str] = Query(None, description="Style preset"),
):
    """
    GET endpoint for simple synthesis

    Packs the query parameters into a SynthesizeRequest and delegates to
    synthesize_audio. Useful for testing and simple integrations.
    """
    params = {
        "text": text,
        "voice": voice,
        "speed": speed,
        "pitch": pitch,
        "energy": energy,
        "style": style,
    }
    return await synthesize_audio(SynthesizeRequest(**params))
async def get_inference(
    text: str = Query(
        ...,
        description="The input text to be converted into speech. For English, text must be lowercase.",
    ),
    lang: str = Query(
        ...,
        description="Language of input text. Supported: bhojpuri, bengali, english, gujarati, hindi, chhattisgarhi, kannada, magahi, maithili, marathi, telugu",
    ),
    speaker_wav: UploadFile = File(
        ...,
        description="A reference WAV file representing the speaker's voice (mandatory per hackathon spec).",
    ),
):
    """
    Hackathon API - Generate speech audio from text

    This endpoint follows the Voice Tech for All hackathon specification.
    Supports both GET and POST methods with multipart form data.

    Parameters:
    - text: Input text to synthesize (query param)
    - lang: Language (query param) - bhojpuri, bengali, english, gujarati, hindi, chhattisgarhi, kannada, magahi, maithili, marathi, telugu
    - speaker_wav: Reference WAV file (multipart file upload, mandatory)

    Returns:
    - 200 OK: WAV audio file as streaming response

    Errors:
    - 400: unsupported language, unreadable speaker_wav, or speaker_wav
      smaller than a minimal WAV header
    - 500: synthesis or WAV encoding failed
    """
    # Minimum WAV header size, used as a basic sanity check on the upload.
    min_wav_header_bytes = 44

    engine = get_engine()
    # Normalize language name
    lang_lower = lang.lower().strip()
    # Enforce lowercase for English text (per spec)
    if lang_lower == "english":
        text = text.lower()
    # Map language to voice
    if lang_lower not in LANG_TO_VOICE:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language: {lang}. Supported languages: {', '.join(LANG_TO_VOICE)}",
        )
    voice = LANG_TO_VOICE[lang_lower]

    # Read speaker_wav (mandatory per spec)
    # Note: Current VITS models don't support voice cloning, but we accept the file
    # for API compatibility and validation. In future, this could be used for voice adaptation.
    try:
        speaker_audio_bytes = await speaker_wav.read()
    except Exception as e:
        logger.exception("Could not read speaker_wav: %s", e)
        raise HTTPException(
            status_code=400, detail=f"Failed to read speaker_wav file: {str(e)}"
        ) from e
    logger.info(
        "Received speaker reference WAV: %d bytes, filename: %s",
        len(speaker_audio_bytes),
        speaker_wav.filename,
    )
    # Validate it's a valid audio file (basic check)
    if len(speaker_audio_bytes) < min_wav_header_bytes:
        raise HTTPException(
            status_code=400,
            detail="Invalid speaker_wav: file too small to be a valid WAV",
        )

    try:
        # Synthesize audio
        output = engine.synthesize(
            text=text,
            voice=voice,
            speed=1.0,
            normalize_text=True,
        )
        # Convert to WAV bytes
        buffer = io.BytesIO()
        sf.write(buffer, output.audio, output.sample_rate, format="WAV")
        buffer.seek(0)
        # Return as streaming response (per spec)
        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={
                "Content-Disposition": "attachment; filename=output.wav",
                "X-Duration": str(output.duration),
                "X-Sample-Rate": str(output.sample_rate),
                # Use the validated, normalized name: the raw query value may
                # carry surrounding whitespace/newlines or characters that
                # only match after lowercasing, which do not belong in a
                # response header.
                "X-Language": lang_lower,
                "X-Voice": voice,
            },
        )
    except Exception as e:
        # Keep the traceback in the log and chain the original error.
        logger.exception("Synthesis error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def preload_voice(voice: str):
    """Preload a voice model into memory

    Raises:
        HTTPException: 400 when the voice is not a key of LANGUAGE_CONFIGS;
            500 when the engine fails to load it.
    """
    engine = get_engine()
    if voice not in LANGUAGE_CONFIGS:
        raise HTTPException(status_code=400, detail=f"Unknown voice: {voice}")
    try:
        engine.load_voice(voice)
    except Exception as e:
        # Log the failure with its traceback, like the synthesis handlers
        # do, and chain the original error onto the HTTP error.
        logger.exception("Voice preload error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": f"Voice {voice} loaded successfully"}
async def unload_voice(voice: str):
    """Unload a voice model from memory

    Raises:
        HTTPException: 400 when the voice is not a key of LANGUAGE_CONFIGS;
            500 when the engine fails to unload it.
    """
    engine = get_engine()
    # Reject unknown voice keys up front, matching preload_voice.
    if voice not in LANGUAGE_CONFIGS:
        raise HTTPException(status_code=400, detail=f"Unknown voice: {voice}")
    try:
        engine.unload_voice(voice)
    except Exception as e:
        logger.exception("Voice unload error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": f"Voice {voice} unloaded"}
async def batch_synthesize(
    texts: List[str], voice: str = "hi_male", speed: float = 1.0
):
    """
    Synthesize multiple texts

    Returns a list with one entry per input text, in input order. Each entry
    is a dict with the keys "text" (the input), "audio_base64" (the WAV
    bytes, base64-encoded) and "duration" (as reported by the engine output).
    An empty ``texts`` list yields an empty list.

    Raises:
        HTTPException: 400 when the voice is not a key of LANGUAGE_CONFIGS;
            500 when synthesis or WAV encoding of any text fails.
    """
    import base64

    engine = get_engine()
    if voice not in LANGUAGE_CONFIGS:
        raise HTTPException(status_code=400, detail=f"Unknown voice: {voice}")
    results = []
    try:
        for text in texts:
            # Keyword arguments, as in the other handlers, so the call does
            # not depend on the engine's positional parameter order.
            output = engine.synthesize(text=text, voice=voice, speed=speed)
            buffer = io.BytesIO()
            sf.write(buffer, output.audio, output.sample_rate, format="WAV")
            results.append(
                {
                    "text": text,
                    "audio_base64": base64.b64encode(buffer.getvalue()).decode(),
                    "duration": output.duration,
                }
            )
    except Exception as e:
        # Report failures the same way the single-text handlers do.
        logger.exception("Batch synthesis error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return results
| # Startup/Shutdown events | |
async def startup_event():
    """Initialize on startup; currently this only logs a start message."""
    logger.info("Starting TTS API server...")
    # A default voice could optionally be preloaded here, e.g.:
    # get_engine().load_voice("hi_male")
async def shutdown_event():
    """Cleanup on shutdown; currently this only logs a shutdown message."""
    logger.info("Shutting down TTS API server...")
def start_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
    """Start the API server

    Runs the application through ``uvicorn.run`` using the import string
    "src.api:app", forwarding ``host``, ``port`` and ``reload`` with the log
    level fixed to "info".
    """
    # Imported here so uvicorn is only needed when the server is launched.
    import uvicorn

    server_options = {
        "host": host,
        "port": port,
        "reload": reload,
        "log_level": "info",
    }
    uvicorn.run("src.api:app", **server_options)


# Script entry point: launch the server with the default settings.
if __name__ == "__main__":
    start_server()