from spaces import GPU
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC
import gradio as gr
import os
# Hugging Face authentication for the private model
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model_id = "canopylabs/3b-es_it-pretrain-research_release"
lora_model_id = "sirekist98/spanish_conversational_tts"
snac_model_id = "hubertsiuzdak/snac_24khz"
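# Base model: Orpheus 3B Spanish/Italian pretrain. The LoRA adapter is a
# Spanish conversational TTS fine-tune applied on top of it. SNAC is the
# 24 kHz neural audio codec that turns generated audio codes into a waveform.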
# Load models
tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    token=hf_token,
)
model = PeftModel.from_pretrained(base_model, lora_model_id, token=hf_token)
model = model.to(device)
model.eval()
snac_model = SNAC.from_pretrained(snac_model_id).to(device)
# Speakers
speakers = [
    "Alex", "Carmen", "Daniel", "Diego", "Hugo", "Lucía", "María", "Pablo", "Sofía"
]
# Helper to decode tokens to audio
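# The model emits audio as a flat stream of 7-token frames. Each frame
# de-interleaves into SNAC's three codebooks (1 coarse + 2 medium + 4 fine
# codes), and the token at position k within a frame carries an offset of
# k * 4096 that must be subtracted before decoding.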
def decode_snac(code_list):
    layer_1, layer_2, layer_3 = [], [], []
    for i in range(len(code_list) // 7):  # one iteration per 7-code frame
        layer_1.append(code_list[7 * i])
        layer_2.append(code_list[7 * i + 1] - 4096)
        layer_3.append(code_list[7 * i + 2] - (2 * 4096))
        layer_3.append(code_list[7 * i + 3] - (3 * 4096))
        layer_2.append(code_list[7 * i + 4] - (4 * 4096))
        layer_3.append(code_list[7 * i + 5] - (5 * 4096))
        layer_3.append(code_list[7 * i + 6] - (6 * 4096))
    # Decode on whichever device the SNAC codebooks live on
    device_snac = snac_model.quantizer.quantizers[0].codebook.weight.device
    layers = [
        torch.tensor(layer_1).unsqueeze(0).to(device_snac),
        torch.tensor(layer_2).unsqueeze(0).to(device_snac),
        torch.tensor(layer_3).unsqueeze(0).to(device_snac),
    ]
    with torch.no_grad():
        audio = snac_model.decode(layers).squeeze().cpu().numpy()
    return audio
# Inference
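# Special token IDs (Orpheus / Llama 3 vocabulary) used below:
#   128259 start of human turn    128009 end of text
#   128260 end of human turn      128263 padding
#   128257 start of speech        128258 end of speech
#   128266 offset of the first audio code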
@GPU
def tts(prompt, speaker):
    full_prompt = f"{speaker}: {prompt}"
    input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
    # Wrap the text in start-of-human / end-of-text / end-of-human markers
    start_token = torch.tensor([[128259]], dtype=torch.long).to(device)
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.long).to(device)
    input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
    # Left-pad to a fixed length of 4260 tokens and mask out the padding
    padding_len = max(0, 4260 - input_ids.shape[1])
    if padding_len > 0:
        pad = torch.full((1, padding_len), 128263, dtype=torch.long).to(device)
        input_ids = torch.cat([pad, input_ids], dim=1)
        attention_mask = torch.cat([
            torch.zeros((1, padding_len), dtype=torch.long),
            torch.ones((1, input_ids.shape[1] - padding_len), dtype=torch.long),
        ], dim=1).to(device)
    else:
        attention_mask = torch.ones_like(input_ids, dtype=torch.long).to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=1200,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1,
            eos_token_id=128258,  # stop at end of speech
            use_cache=True,
        )
    # Keep only the tokens after the last start-of-speech marker
    token_to_find = 128257
    token_to_remove = 128258
    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped = generated_ids
    # Drop end-of-speech tokens, truncate to whole 7-code frames, and shift
    # the IDs down by the audio-token offset to recover raw SNAC codes
    cleaned = cropped[cropped != token_to_remove]
    trimmed = cleaned[: (len(cleaned) // 7) * 7]
    trimmed = [int(t) - 128266 for t in trimmed]
    audio = decode_snac(trimmed)
    return (24000, audio)  # SNAC decodes at 24 kHz
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🗣️ Orpheus Spanish TTS — Conversational\nSelecciona un *speaker* y escribe el texto.")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Texto", placeholder="Escribe aquí el texto a locutar")
            speaker_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
            submit_btn = gr.Button("Generar audio")
        with gr.Column():
            audio_output = gr.Audio(label="Audio generado", type="numpy")
    submit_btn.click(
        fn=tts,
        inputs=[text_input, speaker_dropdown],
        outputs=audio_output,
    )

demo.queue().launch(show_error=True)