Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2022 The HuggingFace Inc. team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, Optional, Union | |
| from uuid import uuid4 | |
| from huggingface_hub import HfFolder, whoami | |
| from . import __version__ | |
| from .utils import ENV_VARS_TRUE_VALUES, logging | |
| from .utils.import_utils import ( | |
| _flax_version, | |
| _jax_version, | |
| _onnxruntime_version, | |
| _torch_version, | |
| is_flax_available, | |
| is_modelcards_available, | |
| is_onnx_available, | |
| is_torch_available, | |
| ) | |
| if is_modelcards_available(): | |
| from modelcards import CardData, ModelCard | |
logger = logging.get_logger(__name__)

# Template rendered by `create_model_card`; it lives in the `utils/` directory next to this module.
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent.joinpath("utils", "model_card_template.md")

# Random hex identifier created once when this module is imported; reported in the user-agent string.
SESSION_ID = uuid4().hex

# Telemetry opt-out: the env value is upper-cased and checked against ENV_VARS_TRUE_VALUES.
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Build the user-agent string sent with Hub requests.

    The string always starts with the diffusers version, the Python version and the
    per-import session id. If telemetry is disabled, only "; telemetry/off" is appended
    and nothing else is reported. Otherwise, the versions of the available backends
    (torch, jax/flax, onnxruntime), a CI marker and the caller-supplied ``user_agent``
    extras are appended, in that order.

    Args:
        user_agent: Extra information to append. A dict is rendered as ``key/value``
            pairs separated by "; ". A str is appended verbatim. ``None`` adds nothing.

    Returns:
        The assembled user-agent string.
    """
    base = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
    if DISABLE_TELEMETRY:
        return base + "; telemetry/off"

    extras = []
    if is_torch_available():
        extras.append(f"; torch/{_torch_version}")
    if is_flax_available():
        extras.append(f"; jax/{_jax_version}")
        extras.append(f"; flax/{_flax_version}")
    if is_onnx_available():
        extras.append(f"; onnxruntime/{_onnxruntime_version}")

    # The CI environment sets DIFFUSERS_IS_CI so those requests can be told apart.
    if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        extras.append("; is_ci/true")

    if isinstance(user_agent, dict):
        extras.append("; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()))
    elif isinstance(user_agent, str):
        extras.append("; " + user_agent)

    return base + "".join(extras)
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """
    Return the fully-qualified Hub repository id ``"<namespace>/<model_id>"``.

    Args:
        model_id: Repository name, without a namespace.
        organization: Namespace to use. When omitted, the namespace is the ``"name"``
            that ``whoami`` returns for the token.
        token: Hub access token. When omitted, ``HfFolder.get_token()`` supplies it.
            Note that this lookup happens even when ``organization`` is given.

    Returns:
        The ``namespace/model_id`` string.
    """
    resolved_token = HfFolder.get_token() if token is None else token
    if organization is not None:
        return f"{organization}/{model_id}"
    username = whoami(resolved_token)["name"]
    return f"{username}/{model_id}"
def create_model_card(args, model_name):
    """
    Render a model card from ``MODEL_CARD_TEMPLATE_PATH`` and save it as
    ``README.md`` in ``args.output_dir``.

    The card's YAML metadata and template fields are filled from ``args``.
    Optional hyper-parameters that ``args`` does not define are passed to the
    template as ``None``.

    Args:
        args: Training-arguments object (presumably an ``argparse.Namespace`` from a
            training script — TODO confirm). Must define ``learning_rate``,
            ``train_batch_size``, ``eval_batch_size``, ``mixed_precision`` and
            ``output_dir``. The other attributes read here are optional.
        model_name: Repository name. It is used in the card and to build the full
            repo id via ``get_full_repo_name``.

    Raises:
        ValueError: If the ``modelcards`` package is not installed.
    """
    # `is_modelcards_available` must be *called*. The bare function object is
    # always truthy, so the original guard never fired, and a missing package
    # only surfaced later as a NameError on `ModelCard`.
    if not is_modelcards_available():
        raise ValueError(
            "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
            " install the package with `pip install modelcards`."
        )

    # Only local_rank -1 or 0 continues; presumably so that a single process of a
    # distributed run writes the card.
    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    hub_token = getattr(args, "hub_token", None)
    repo_name = get_full_repo_name(model_name, token=hub_token)

    # Read once so the YAML metadata and the template body agree. A missing
    # attribute yields None instead of raising AttributeError, which is what the
    # original's `hasattr` guard on `dataset_name=` intended.
    dataset_name = getattr(args, "dataset_name", None)

    model_card = ModelCard.from_template(
        card_data=CardData(  # Card metadata object that will be converted to YAML block
            language="en",
            license="apache-2.0",
            library_name="diffusers",
            tags=[],
            datasets=dataset_name,
            metrics=[],
        ),
        template_path=MODEL_CARD_TEMPLATE_PATH,
        model_name=model_name,
        repo_name=repo_name,
        dataset_name=dataset_name,
        learning_rate=args.learning_rate,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=getattr(args, "gradient_accumulation_steps", None),
        adam_beta1=getattr(args, "adam_beta1", None),
        adam_beta2=getattr(args, "adam_beta2", None),
        adam_weight_decay=getattr(args, "adam_weight_decay", None),
        adam_epsilon=getattr(args, "adam_epsilon", None),
        lr_scheduler=getattr(args, "lr_scheduler", None),
        lr_warmup_steps=getattr(args, "lr_warmup_steps", None),
        ema_inv_gamma=getattr(args, "ema_inv_gamma", None),
        ema_power=getattr(args, "ema_power", None),
        ema_max_decay=getattr(args, "ema_max_decay", None),
        mixed_precision=args.mixed_precision,
    )

    card_path = os.path.join(args.output_dir, "README.md")
    model_card.save(card_path)