# Source: ooooomeganFox2 / server.py
# Uploaded by Primeness via huggingface_hub ("Upload folder using huggingface_hub", commit e0520d0, verified).
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import tempfile
import traceback
import whisper
import librosa
import numpy as np
import torch
import outetts
import uvicorn
import base64
import io
import soundfile as sf
# OuteTTS text-to-speech interface (TTS model + audio codec), created once at
# import time and shared by every request.
try:
    INTERFACE = outetts.Interface(
        config=outetts.ModelConfig(
            model_path="models/v10",
            tokenizer_path="models/v10",
            audio_codec_path="models/dsp/weights_24khz_1.5kbps_v1.0.pth",
            device="cuda",
            dtype=torch.bfloat16,
        )
    )
except Exception as e:
    # Abort the import with context; ``from e`` keeps the original exception
    # and traceback attached as the direct cause.
    raise RuntimeError(f"Failed to initialize OuteTTS interface: {e}") from e
# Whisper speech-recognition model used by gt() to transcribe incoming audio.
asr_model = whisper.load_model("models/wpt/wpt.pt")

# Causal language model (and its tokenizer) used by sample() to write the reply.
model_name = "models/lm"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
lm.eval()  # inference only; eval() switches the module in place

# Reference speaker clip expected next to this file.
# NOTE(review): appears unused in this module; confirm before removing.
SPEAKER_WAV_PATH = Path(__file__).parent / "spk_001.wav"
def gt(audio: np.ndarray, sr: int):
    """Transcribe a waveform with the module-level Whisper model.

    The input is squeezed to drop singleton axes (the inference endpoint
    passes ``(1, n)`` arrays), cast to float32 and, when ``sr`` differs,
    resampled to 16 kHz before transcription. ``language=None`` lets Whisper
    detect the language; fp16 decoding is disabled.

    Args:
        audio: Waveform samples.
        sr: Sample rate of ``audio`` in Hz.

    Returns:
        The transcript text with surrounding whitespace stripped.
    """
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        # Resample the squeezed float32 signal, not the raw input: ``audio``
        # may still carry a leading singleton axis and a non-float32 dtype,
        # while Whisper expects a 1-D float32 array.
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000).astype(
            np.float32, copy=False
        )
    result = asr_model.transcribe(ss, fp16=False, language=None)
    return result["text"].strip()
def sample(rr: str) -> str:
    """Generate a short continuation of ``rr`` with the module-level LM.

    A blank or whitespace-only prompt is replaced by ``"Hello "``. Only the
    newly generated tokens are decoded (the prompt tokens are sliced off),
    with special tokens removed.
    """
    prompt = rr if rr.strip() else "Hello "
    encoded = tok(prompt, return_tensors="pt").to(lm.device)
    prompt_len = encoded.input_ids.shape[-1]
    generation_kwargs = dict(
        max_new_tokens=45,
        do_sample=True,
        temperature=0.2,
        repetition_penalty=1.1,
        top_k=100,
        top_p=0.95,
    )
    with torch.inference_mode():
        generated = lm.generate(**encoded, **generation_kwargs)
    continuation = generated[0][prompt_len:]
    return tok.decode(continuation, skip_special_tokens=True)
# Status reported by the health endpoint. Nothing in this module updates it;
# if the OuteTTS initialization fails, the import aborts instead.
INITIALIZATION_STATUS = dict(model_loaded=True, error=None)
class GenerateRequest(BaseModel):
    """Request body for the speech-to-speech inference endpoint."""

    # Base64 text of the bytes produced by ``np.save`` on the input waveform.
    audio_data: str = Field(
        ...,
        description="Base64-encoded NumPy .npy file containing the input waveform.",
    )
    # Must be positive: it is used to resample, to size the speaker clip, and
    # as the output rate. Non-positive values are rejected at validation time.
    sample_rate: int = Field(
        ...,
        gt=0,
        description=(
            "Sample rate of the input waveform in Hz; "
            "the generated audio is returned at this rate too."
        ),
    )
class GenerateResponse(BaseModel):
    """Response body for the speech-to-speech inference endpoint."""

    # Same encoding as the request: base64 of ``np.save`` bytes (float32).
    audio_data: str = Field(
        ...,
        description=(
            "Base64-encoded NumPy .npy file (float32) with the generated speech, "
            "at the request's sample_rate."
        ),
    )
app = FastAPI(version="0.1", title="V1")

# Wide-open CORS: any origin, method and header is accepted.
# NOTE(review): combining a wildcard origin with allow_credentials=True makes
# Starlette echo the caller's Origin back for credentialed requests, so any
# site may call this API with credentials. Confirm that is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def b64(b64: str) -> np.ndarray:
    """Decode a base64 string that carries a serialized NumPy ``.npy`` file.

    ``np.load`` runs with ``allow_pickle=False``, so pickled payloads (e.g.
    object arrays) are refused rather than unpickled. Note that a ``.npz``
    archive would come back as an ``NpzFile``, not an ndarray. Errors raised by
    ``base64.b64decode`` or ``np.load`` propagate unchanged.
    """
    stream = io.BytesIO(base64.b64decode(b64))
    return np.load(stream, allow_pickle=False)
def ab64(arr: np.ndarray, sr: int, orig_sr: int = 44100) -> str:
    """Resample a waveform and serialize it as base64 of a float32 ``.npy`` file.

    This is the counterpart of ``b64``: the result decodes back with
    ``np.load(io.BytesIO(base64.b64decode(s)))``.

    Args:
        arr: Waveform sampled at ``orig_sr`` Hz.
        sr: Target sample rate in Hz.
        orig_sr: Sample rate of ``arr``. Defaults to 44100, the rate this
            server assumes for the TTS output.

    Returns:
        ASCII base64 string of the ``np.save`` bytes of a float32 array.
    """
    if orig_sr != sr:
        arr = librosa.resample(arr, orig_sr=orig_sr, target_sr=sr)
    buf = io.BytesIO()
    np.save(buf, np.asarray(arr).astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode("ascii")
def gs(
    audio: np.ndarray,
    sr: int,
    interface: outetts.Interface,
    *,
    max_seconds: float = 15.0,
):
    """Create an OuteTTS speaker profile from the end of a reference recording.

    The waveform is cast to float32 and trimmed to its last ``max_seconds``
    seconds. It is then written to a temporary WAV file, because
    ``create_speaker`` takes a path, and passed to
    ``interface.create_speaker`` with the local Whisper checkpoint. The
    temporary file is always deleted afterwards, even if speaker creation fails.

    Args:
        audio: Reference waveform; a 2-D input is squeezed, so ``(1, n)``
            becomes ``(n,)``.
        sr: Sample rate of ``audio`` in Hz.
        interface: Initialized OuteTTS interface.
        max_seconds: How much audio, counted back from the end, to keep.

    Returns:
        The speaker object returned by ``interface.create_speaker``.
    """
    if audio.ndim == 2:
        audio = audio.squeeze()
    audio = audio.astype("float32")
    max_samples = int(max_seconds * sr)
    if 0 < max_samples < audio.shape[-1]:
        # Trim along the sample (last) axis, matching the length check above;
        # ``audio[-n:]`` would slice the first axis if the array is still 2-D.
        audio = audio[..., -max_samples:]
    # NOTE(review): soundfile expects (frames, channels); an array that is
    # still 2-D here is (channels, samples) by the convention above, so
    # multi-channel input is likely written incorrectly. Confirm input layout.

    # Create and close the file first so it can be reopened by name on any OS;
    # delete=False means it must be removed explicitly in ``finally``.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        wav_path = f.name
    try:
        sf.write(wav_path, audio, sr)
        return interface.create_speaker(
            wav_path,
            whisper_model="models/wpt/wpt.pt",
        )
    finally:
        Path(wav_path).unlink(missing_ok=True)
@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    # "status" is unconditionally "healthy": the server can only be serving
    # requests if module-level model loading completed without raising.
    loaded = INITIALIZATION_STATUS["model_loaded"]
    err = INITIALIZATION_STATUS["error"]
    return {"status": "healthy", "model_loaded": loaded, "error": err}
@app.post("/api/v1/inference", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Speech-to-speech: transcribe the input, write a reply, speak it in the input voice.

    Returns 400 when `audio_data` is not base64 of a non-empty numeric 1-D or
    2-D `.npy` array, and 500 when any model step or output encoding fails.
    """
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError) as e:
        # binascii.Error (bad base64) subclasses ValueError; np.load raises
        # ValueError/EOFError for bytes that are not a readable pickle-free array.
        raise HTTPException(status_code=400, detail=f"Invalid audio_data: {e}") from e
    if (
        not isinstance(audio_np, np.ndarray)  # e.g. an .npz archive -> NpzFile
        or audio_np.size == 0
        or audio_np.ndim > 2
        or not np.issubdtype(audio_np.dtype, np.number)
    ):
        raise HTTPException(
            status_code=400,
            detail="audio_data must encode a non-empty numeric array with 1 or 2 dimensions",
        )
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    try:
        text = gt(audio_np, req.sample_rate)
        out = INTERFACE.generate(
            config=outetts.GenerationConfig(
                text=sample(text),
                generation_type=outetts.GenerationType.CHUNKED,
                speaker=gs(audio_np, req.sample_rate, INTERFACE),
                sampler_config=outetts.SamplerConfig(),
            )
        )
        # .float() before .numpy(): numpy has no bfloat16, so this guards
        # against the synthesized tensor coming back in reduced precision.
        audio_out = out.audio.squeeze().float().cpu().numpy()
        # Encode inside the try so encoding failures are reported like other errors.
        encoded = ab64(audio_out, req.sample_rate)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e
    return GenerateResponse(audio_data=encoded)
if __name__ == "__main__":
    # Pass the already-built app object, not the "server:app" import string.
    # Run as a script, this file is the __main__ module, so uvicorn importing
    # "server" would execute it again and load every model a second time.
    # Import strings are only required for reload/workers, unused here.
    uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)