VoiceAPI / src /downloader.py
Harshil748's picture
Initial HF Spaces deployment - downloads models at runtime
ecde958
raw
history blame
5.79 kB
"""
Model Downloader for SYSPIN TTS Models
Downloads models from Hugging Face Hub
"""
import os
import logging
from pathlib import Path
from typing import Optional, List
from huggingface_hub import hf_hub_download, snapshot_download
from tqdm import tqdm
from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR
logger = logging.getLogger(__name__)
class ModelDownloader:
"""Downloads and manages SYSPIN TTS models from Hugging Face"""
def __init__(self, models_dir: str = MODELS_DIR):
self.models_dir = Path(models_dir)
self.models_dir.mkdir(parents=True, exist_ok=True)
def download_model(self, voice_key: str, force: bool = False) -> Path:
"""
Download a specific voice model
Args:
voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male', 'bn_female')
force: Re-download even if exists
Returns:
Path to downloaded model directory
"""
if voice_key not in LANGUAGE_CONFIGS:
raise ValueError(
f"Unknown voice: {voice_key}. Available: {list(LANGUAGE_CONFIGS.keys())}"
)
config = LANGUAGE_CONFIGS[voice_key]
model_dir = self.models_dir / voice_key
# Check if already downloaded
model_path = model_dir / config.model_filename
chars_path = model_dir / config.chars_filename
extra_path = model_dir / "extra.py"
if not force and model_path.exists() and chars_path.exists():
logger.info(f"Model {voice_key} already downloaded at {model_dir}")
return model_dir
logger.info(f"Downloading {voice_key} from {config.hf_model_id}...")
# Create model directory
model_dir.mkdir(parents=True, exist_ok=True)
try:
# Download all files from the repo
snapshot_download(
repo_id=config.hf_model_id,
local_dir=str(model_dir),
local_dir_use_symlinks=False,
allow_patterns=["*.pt", "*.pth", "*.txt", "*.py", "*.json"],
)
logger.info(f"Successfully downloaded {voice_key} to {model_dir}")
except Exception as e:
logger.error(f"Failed to download {voice_key}: {e}")
raise
return model_dir
def download_all_models(self, force: bool = False) -> List[Path]:
"""Download all available models"""
downloaded = []
for voice_key in tqdm(LANGUAGE_CONFIGS.keys(), desc="Downloading models"):
try:
path = self.download_model(voice_key, force=force)
downloaded.append(path)
except Exception as e:
logger.warning(f"Failed to download {voice_key}: {e}")
return downloaded
def download_language(self, lang_code: str, force: bool = False) -> List[Path]:
"""Download all voices for a specific language"""
downloaded = []
for voice_key, config in LANGUAGE_CONFIGS.items():
if config.code == lang_code:
try:
path = self.download_model(voice_key, force=force)
downloaded.append(path)
except Exception as e:
logger.warning(f"Failed to download {voice_key}: {e}")
return downloaded
def get_model_path(self, voice_key: str) -> Optional[Path]:
"""Get path to a downloaded model"""
if voice_key not in LANGUAGE_CONFIGS:
return None
config = LANGUAGE_CONFIGS[voice_key]
model_path = self.models_dir / voice_key / config.model_filename
if model_path.exists():
return model_path.parent
return None
def list_downloaded_models(self) -> List[str]:
"""List all downloaded models"""
downloaded = []
for voice_key, config in LANGUAGE_CONFIGS.items():
model_path = self.models_dir / voice_key / config.model_filename
if model_path.exists():
downloaded.append(voice_key)
return downloaded
def get_model_size(self, voice_key: str) -> Optional[int]:
"""Get size of downloaded model in bytes"""
model_path = self.get_model_path(voice_key)
if not model_path:
return None
total_size = 0
for f in model_path.iterdir():
if f.is_file():
total_size += f.stat().st_size
return total_size
def download_models_cli():
"""CLI entry point for downloading models"""
import argparse
parser = argparse.ArgumentParser(description="Download SYSPIN TTS models")
parser.add_argument(
"--voice", type=str, help="Specific voice to download (e.g., hi_male)"
)
parser.add_argument(
"--lang", type=str, help="Download all voices for a language (e.g., hi)"
)
parser.add_argument("--all", action="store_true", help="Download all models")
parser.add_argument("--list", action="store_true", help="List available models")
parser.add_argument("--force", action="store_true", help="Force re-download")
args = parser.parse_args()
downloader = ModelDownloader()
if args.list:
print("Available voices:")
for key, config in LANGUAGE_CONFIGS.items():
downloaded = "✓" if downloader.get_model_path(key) else " "
print(f" [{downloaded}] {key}: {config.name} ({config.code})")
return
if args.voice:
downloader.download_model(args.voice, force=args.force)
elif args.lang:
downloader.download_language(args.lang, force=args.force)
elif args.all:
downloader.download_all_models(force=args.force)
else:
parser.print_help()
if __name__ == "__main__":
download_models_cli()