| from huggingface_hub import HfApi, HfFolder, create_repo, upload_folder | |
| import os | |
| import logging | |
| # 设置日志 | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| def upload_model_to_hf(model_path, repo_name): | |
| """ | |
| 上传模型到Hugging Face Hub | |
| Args: | |
| model_path: 本地模型路径 | |
| repo_name: Hugging Face仓库名称 (格式: username/repo_name) | |
| """ | |
| # 创建仓库(如果不存在) | |
| try: | |
| create_repo(repo_name, exist_ok=True) | |
| logger.info(f"仓库 {repo_name} 已创建或已存在") | |
| except Exception as e: | |
| logger.error(f"创建仓库时出错: {e}") | |
| return | |
| # 初始化API | |
| api = HfApi() | |
| # 上传整个文件夹 | |
| try: | |
| api.upload_folder( | |
| folder_path=model_path, | |
| repo_id=repo_name, | |
| repo_type="model" | |
| ) | |
| logger.info(f"模型已成功上传到 {repo_name}") | |
| except Exception as e: | |
| logger.error(f"上传模型时出错: {e}") | |
| if __name__ == "__main__": | |
| # 设置参数 | |
| model_path = "/export/disk2/rotation15/projects/patent/correct-model" # 模型文件夹路径 | |
| repo_name = "yushize/patent-classifier" # Hugging Face仓库名称 | |
| # 上传模型 | |
| upload_model_to_hf(model_path, repo_name) |