"""FastAPI speech-to-speech inference server.

Per request: decode a base64-encoded ``.npy`` waveform, transcribe it with
Whisper, continue the transcript with a causal language model, and synthesize
the continuation with OuteTTS using a speaker profile built from the request
audio. The synthesized waveform is returned as base64-encoded ``.npy``.
"""

import base64
import io
import tempfile
import traceback
from pathlib import Path

import librosa
import numpy as np
import outetts
import soundfile as sf
import torch
import uvicorn
import whisper
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

# Rate the audio is resampled to before it is handed to Whisper.
ASR_SAMPLE_RATE = 16_000
# Rate ab64 assumes for OuteTTS output when the generation result does not
# report one itself (the original code hard-coded this value).
TTS_OUTPUT_SAMPLE_RATE = 44_100
# Only the most recent this-many seconds of request audio are used to build
# the speaker profile.
MAX_SPEAKER_SECONDS = 15.0

try:
    INTERFACE = outetts.Interface(
        config=outetts.ModelConfig(
            model_path="models/v10",
            tokenizer_path="models/v10",
            audio_codec_path="models/dsp/weights_24khz_1.5kbps_v1.0.pth",
            device="cuda",
            dtype=torch.bfloat16,
        )
    )
except Exception as e:
    # Chain the cause so the original traceback is not lost.
    raise RuntimeError(f"{e}") from e

asr_model = whisper.load_model("models/wpt/wpt.pt")

model_name = "models/lm"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()

SPEAKER_WAV_PATH = Path(__file__).with_name("spk_001.wav")


def gt(audio: np.ndarray, sr: int) -> str:
    """Transcribe ``audio`` (sampled at ``sr`` Hz) with Whisper.

    Singleton axes are squeezed away, the signal is cast to float32 and, if
    ``sr`` differs from 16 kHz, resampled to 16 kHz. Returns the transcript
    with surrounding whitespace stripped.
    """
    ss = audio.squeeze().astype(np.float32)
    if sr != ASR_SAMPLE_RATE:
        # Resample the squeezed float32 signal, not the raw input, so Whisper
        # receives the same squeezed float32 layout on both paths.
        ss = librosa.resample(
            ss, orig_sr=sr, target_sr=ASR_SAMPLE_RATE
        ).astype(np.float32)
    result = asr_model.transcribe(ss, fp16=False, language=None)
    return result["text"].strip()


def sample(rr: str) -> str:
    """Generate up to 45 new tokens continuing prompt ``rr`` with the LM.

    A blank prompt is replaced by ``"Hello "``. Only the newly generated
    tokens (not the prompt) are decoded and returned.
    """
    if not rr.strip():
        rr = "Hello "
    inputs = tok(rr, return_tensors="pt").to(lm.device)
    with torch.inference_mode():
        out_ids = lm.generate(
            **inputs,
            max_new_tokens=45,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )
    # Slice off the prompt tokens so only the continuation is decoded.
    return tok.decode(
        out_ids[0][inputs.input_ids.shape[-1] :], skip_special_tokens=True
    )


# Loading failures above raise at import time, so reaching this line means
# all models loaded.
INITIALIZATION_STATUS = {"model_loaded": True, "error": None}


class GenerateRequest(BaseModel):
    # Base64 of a NumPy ``.npy`` file containing the input waveform.
    audio_data: str = Field(
        ...,
        description="",
    )
    # Sample rate of ``audio_data`` in Hz; also the rate of the returned audio.
    sample_rate: int = Field(..., gt=0, description="")


class GenerateResponse(BaseModel):
    # Base64 of a float32 NumPy ``.npy`` file at the request's sample rate.
    audio_data: str = Field(..., description="")


app = FastAPI(title="V1", version="0.1")
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests; restrict before exposing
# this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def b64(b64: str) -> np.ndarray:
    """Decode a base64 string holding ``.npy`` bytes into an array.

    ``allow_pickle=False`` keeps untrusted payloads from executing pickled
    objects. Raises ``ValueError`` (including ``binascii.Error``) or
    ``EOFError`` on malformed input. Note that a zipped ``.npz`` payload
    yields an ``NpzFile`` rather than an ndarray; callers should check.
    """
    raw = base64.b64decode(b64)
    return np.load(io.BytesIO(raw), allow_pickle=False)


def ab64(arr: np.ndarray, sr: int, orig_sr: int = TTS_OUTPUT_SAMPLE_RATE) -> str:
    """Resample ``arr`` from ``orig_sr`` to ``sr`` Hz and encode it as base64 ``.npy``.

    ``orig_sr`` defaults to 44100 Hz, the rate previously hard-coded here.
    The saved array is float32.
    """
    buf = io.BytesIO()
    resampled = librosa.resample(arr, orig_sr=orig_sr, target_sr=sr)
    np.save(buf, resampled.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()


def gs(
    audio: np.ndarray,
    sr: int,
    interface: outetts.Interface,
):
    """Build an OuteTTS speaker profile from the last 15 s of ``audio``.

    The clip is written to a temporary WAV file (OuteTTS takes a path), passed
    to ``interface.create_speaker``, and the temporary file is always removed
    afterwards. Returns whatever ``create_speaker`` returns.
    """
    if audio.ndim == 2:
        audio = audio.squeeze()
    audio = audio.astype("float32")
    max_samples = int(MAX_SPEAKER_SECONDS * sr)
    if audio.shape[-1] > max_samples:
        # Keep the tail of the clip: the most recent speech.
        audio = audio[-max_samples:]

    # Reserve a unique path, then close the handle before writing to it.
    with tempfile.NamedTemporaryFile(suffix=".wav", dir="/tmp", delete=False) as f:
        tmp_path = Path(f.name)
    try:
        sf.write(str(tmp_path), audio, sr)
        speaker = interface.create_speaker(
            str(tmp_path),
            whisper_model="models/wpt/wpt.pt",
        )
    finally:
        # delete=False means nothing else cleans this up; without this every
        # request would leave a WAV file behind in /tmp.
        tmp_path.unlink(missing_ok=True)
    return speaker


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    status = {
        "status": "healthy",
        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
        "error": INITIALIZATION_STATUS["error"],
    }
    return status


@app.post("/api/v1/inference", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Run the ASR -> LM -> TTS pipeline on the request audio.

    Responds with 400 when ``audio_data`` cannot be decoded into a non-empty
    array, and with 500 (traceback printed to stderr) on any pipeline failure.
    """
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError, OSError) as e:
        # Bad base64 or a non-.npy payload is a client error, not a 500.
        raise HTTPException(status_code=400, detail=f"invalid audio_data: {e}") from e
    if not isinstance(audio_np, np.ndarray) or audio_np.size == 0:
        raise HTTPException(
            status_code=400, detail="audio_data must encode a non-empty array"
        )
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)

    try:
        text = gt(audio_np, req.sample_rate)
        out = INTERFACE.generate(
            config=outetts.GenerationConfig(
                text=sample(text),
                generation_type=outetts.GenerationType.CHUNKED,
                speaker=gs(audio_np, req.sample_rate, INTERFACE),
                sampler_config=outetts.SamplerConfig(),
            )
        )
        # numpy has no bfloat16 dtype, so cast to float32 before .numpy().
        audio_out = out.audio.squeeze().float().cpu().numpy()
        # Prefer the rate reported by the TTS output; fall back to the
        # previously hard-coded 44.1 kHz.
        out_sr = getattr(out, "sr", TTS_OUTPUT_SAMPLE_RATE)
        encoded = ab64(audio_out, req.sample_rate, out_sr)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e
    return GenerateResponse(audio_data=encoded)


if __name__ == "__main__":
    # Pass the app object rather than the "server:app" import string: the
    # string makes uvicorn re-import this file as a second module (loading
    # every model twice) and only works if the file is named server.py.
    uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)