Spaces:
Running
Running
| """ | |
| Model Downloader for SYSPIN TTS Models | |
| Downloads models from Hugging Face Hub | |
| """ | |
| import os | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional, List | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| from tqdm import tqdm | |
| from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR | |
| logger = logging.getLogger(__name__) | |
class ModelDownloader:
    """Downloads and manages SYSPIN TTS models from Hugging Face.

    Every voice lives in its own sub-directory of ``models_dir`` named after
    its key in ``LANGUAGE_CONFIGS``.
    """

    def __init__(self, models_dir: str = MODELS_DIR):
        """Remember the target directory and create it if it is missing.

        Args:
            models_dir: Directory under which per-voice folders are stored.
        """
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)

    def download_model(self, voice_key: str, force: bool = False) -> Path:
        """
        Download a specific voice model

        Args:
            voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male', 'bn_female')
            force: Re-download even if exists

        Returns:
            Path to downloaded model directory

        Raises:
            ValueError: If ``voice_key`` is not a key of LANGUAGE_CONFIGS.
            Exception: Whatever ``snapshot_download`` raises is logged and
                re-raised unchanged.
        """
        if voice_key not in LANGUAGE_CONFIGS:
            raise ValueError(
                f"Unknown voice: {voice_key}. Available: {list(LANGUAGE_CONFIGS.keys())}"
            )

        config = LANGUAGE_CONFIGS[voice_key]
        model_dir = self.models_dir / voice_key
        model_path = model_dir / config.model_filename
        chars_path = model_dir / config.chars_filename

        # A voice counts as already downloaded only when both the model file
        # and the character file are present.
        if not force and model_path.exists() and chars_path.exists():
            logger.info(f"Model {voice_key} already downloaded at {model_dir}")
            return model_dir

        logger.info(f"Downloading {voice_key} from {config.hf_model_id}...")
        model_dir.mkdir(parents=True, exist_ok=True)

        try:
            # Download all matching files from the repo.
            snapshot_download(
                repo_id=config.hf_model_id,
                local_dir=str(model_dir),
                local_dir_use_symlinks=False,
                allow_patterns=["*.pt", "*.pth", "*.txt", "*.py", "*.json"],
                # Without this, force=True only bypassed the local existence
                # check above; files already on disk were not re-fetched.
                force_download=force,
            )
        except Exception as e:
            logger.error(f"Failed to download {voice_key}: {e}")
            raise
        else:
            logger.info(f"Successfully downloaded {voice_key} to {model_dir}")

        # Surface a repo/config mismatch early instead of at model-load time.
        missing = [p.name for p in (model_path, chars_path) if not p.exists()]
        if missing:
            logger.warning(
                f"Downloaded {voice_key} but expected file(s) are missing: {missing}"
            )

        return model_dir

    def _download_many(self, voice_keys, force: bool = False) -> List[Path]:
        """Download each voice in ``voice_keys``, skipping the ones that fail.

        Failures are logged as warnings and do not abort the batch.

        Args:
            voice_keys: Iterable of keys from LANGUAGE_CONFIGS.
            force: Passed through to ``download_model``.

        Returns:
            Model directories of the voices that downloaded successfully.
        """
        downloaded = []
        for voice_key in voice_keys:
            try:
                downloaded.append(self.download_model(voice_key, force=force))
            except Exception as e:
                logger.warning(f"Failed to download {voice_key}: {e}")
        return downloaded

    def download_all_models(self, force: bool = False) -> List[Path]:
        """Download all available models, showing a progress bar.

        Returns:
            Model directories of the voices that downloaded successfully.
        """
        return self._download_many(
            tqdm(LANGUAGE_CONFIGS.keys(), desc="Downloading models"), force=force
        )

    def download_language(self, lang_code: str, force: bool = False) -> List[Path]:
        """Download all voices for a specific language.

        Args:
            lang_code: Value compared against each config's ``code``.
            force: Passed through to ``download_model``.

        Returns:
            Model directories of the matching voices that downloaded
            successfully; empty when no config has this ``lang_code``.
        """
        matching = [
            voice_key
            for voice_key, config in LANGUAGE_CONFIGS.items()
            if config.code == lang_code
        ]
        return self._download_many(matching, force=force)

    def get_model_path(self, voice_key: str) -> Optional[Path]:
        """Get path to a downloaded model.

        Returns:
            The voice's directory if its model file exists, otherwise None
            (also None for a voice key that is not in LANGUAGE_CONFIGS).
        """
        if voice_key not in LANGUAGE_CONFIGS:
            return None

        config = LANGUAGE_CONFIGS[voice_key]
        model_path = self.models_dir / voice_key / config.model_filename
        if model_path.exists():
            return model_path.parent
        return None

    def list_downloaded_models(self) -> List[str]:
        """List the keys of all voices whose model file is present locally."""
        return [
            voice_key
            for voice_key in LANGUAGE_CONFIGS
            if self.get_model_path(voice_key) is not None
        ]

    def get_model_size(self, voice_key: str) -> Optional[int]:
        """Get size of downloaded model in bytes.

        Only regular files directly inside the voice directory are counted;
        sub-directories are not traversed.

        Returns:
            Total size in bytes, or None if the voice is not downloaded.
        """
        model_path = self.get_model_path(voice_key)
        if model_path is None:
            return None

        return sum(f.stat().st_size for f in model_path.iterdir() if f.is_file())
def download_models_cli():
    """CLI entry point for downloading models.

    Options are handled in priority order: ``--list``, ``--voice``,
    ``--lang``, ``--all``. Only the first one present is acted on. With no
    option the help text is printed.

    An unknown ``--voice`` or ``--lang`` value ends the program through
    ``parser.error`` (usage message, exit status 2) instead of a traceback or
    a silent no-op.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Download SYSPIN TTS models")
    parser.add_argument(
        "--voice", type=str, help="Specific voice to download (e.g., hi_male)"
    )
    parser.add_argument(
        "--lang", type=str, help="Download all voices for a language (e.g., hi)"
    )
    parser.add_argument("--all", action="store_true", help="Download all models")
    parser.add_argument("--list", action="store_true", help="List available models")
    parser.add_argument("--force", action="store_true", help="Force re-download")

    args = parser.parse_args()
    downloader = ModelDownloader()

    if args.list:
        print("Available voices:")
        for key, config in LANGUAGE_CONFIGS.items():
            downloaded = "✓" if downloader.get_model_path(key) else " "
            print(f" [{downloaded}] {key}: {config.name} ({config.code})")
        return

    if args.voice:
        # Report a bad key as a usage error rather than a ValueError traceback.
        if args.voice not in LANGUAGE_CONFIGS:
            parser.error(
                f"Unknown voice: {args.voice}. Available: {list(LANGUAGE_CONFIGS.keys())}"
            )
        downloader.download_model(args.voice, force=args.force)
    elif args.lang:
        # download_language() returns an empty list for an unknown code, which
        # would look like success; reject the code up front instead.
        known_codes = {config.code for config in LANGUAGE_CONFIGS.values()}
        if args.lang not in known_codes:
            parser.error(
                f"Unknown language: {args.lang}. "
                f"Available: {sorted(map(str, known_codes))}"
            )
        downloader.download_language(args.lang, force=args.force)
    elif args.all:
        downloader.download_all_models(force=args.force)
    else:
        parser.print_help()


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    download_models_cli()