| """ | |
| Utilities for working with the local dataset cache. | |
| This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp | |
| Copyright by the AllenNLP authors. | |
| """ | |
| import fnmatch | |
| import json | |
| import logging | |
| import os | |
| import shutil | |
| import sys | |
| import tarfile | |
| import tempfile | |
| from contextlib import contextmanager | |
| from functools import partial, wraps | |
| from hashlib import sha256 | |
| from pathlib import Path | |
| from typing import Dict, Optional, Union | |
| from urllib.parse import urlparse | |
| from zipfile import ZipFile, is_zipfile | |
| import requests | |
| from filelock import FileLock | |
| from tqdm.auto import tqdm | |
| #from . import __version__ | |
| __version__ = "3.0.2" | |
| logger = logging.getLogger(__name__) # pylint: disable=invalid-name | |
try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
        import torch

        _torch_available = True  # pylint: disable=invalid-name
        logger.info("PyTorch version {} available.".format(torch.__version__))
    else:
        logger.info("Disabling PyTorch because USE_TF is set")
        _torch_available = False
except ImportError:
    _torch_available = False  # pylint: disable=invalid-name
try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
        import tensorflow as tf

        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
        _tf_available = True  # pylint: disable=invalid-name
        logger.info("TensorFlow version {} available.".format(tf.__version__))
    else:
        logger.info("Disabling Tensorflow because USE_TORCH is set")
        _tf_available = False
except (ImportError, AssertionError):
    _tf_available = False  # pylint: disable=invalid-name
try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
    )

try:
    import torch_xla.core.xla_model as xm  # noqa: F401

    if _torch_available:
        _torch_tpu_available = True  # pylint: disable=invalid-name
    else:
        _torch_tpu_available = False
except ImportError:
    _torch_tpu_available = False

try:
    import psutil  # noqa: F401

    _psutil_available = True
except ImportError:
    _psutil_available = False

try:
    import py3nvml  # noqa: F401

    _py3nvml_available = True
except ImportError:
    _py3nvml_available = False

try:
    from apex import amp  # noqa: F401

    _has_apex = True
except ImportError:
    _has_apex = False
default_cache_path = os.path.join(torch_cache_home, "transformers")

PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"

MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "/static-proxy?url=https%3A%2F%2Fcdn.huggingface.co%26quot%3B
def is_torch_available():
    return _torch_available


def is_tf_available():
    return _tf_available


def is_torch_tpu_available():
    return _torch_tpu_available


def is_psutil_available():
    return _psutil_available


def is_py3nvml_available():
    return _py3nvml_available


def is_apex_available():
    return _has_apex
def add_start_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator
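# Usage sketch (names below are illustrative): add_start_docstrings simply prepends
# shared text to a function's own docstring, e.g.
#
#   @add_start_docstrings("Shared intro paragraph. ")
#   def forward(self, input_ids):
#       """Model-specific details."""
#
# leaves forward.__doc__ == "Shared intro paragraph. Model-specific details."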
def add_start_docstrings_to_callable(*docstr):
    def docstring_decorator(fn):
        class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
        intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
        note = r"""

    .. note::
        Although the recipe for the forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this, since the former takes care of running the
        pre- and post-processing steps while the latter silently ignores them.
        """
        fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator
def add_end_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = fn.__doc__ + "".join(docstr)
        return fn

    return docstring_decorator
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0)  # Batch size 1

        >>> outputs = model(**inputs, labels=labels)
        >>> loss, scores = outputs[:2]
"""

PT_QUESTION_ANSWERING_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])

        >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
        >>> loss, start_scores, end_scores = outputs[:3]
"""

PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1

        >>> outputs = model(**inputs, labels=labels)
        >>> loss, logits = outputs[:2]
"""

PT_MASKED_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]

        >>> outputs = model(input_ids, labels=input_ids)
        >>> loss, prediction_scores = outputs[:2]
"""

PT_BASE_MODEL_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
"""

PT_MULTIPLE_CHOICE_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."
        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
        >>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels)  # batch size is 1

        >>> # the linear classifier still needs to be trained
        >>> loss, logits = outputs[:2]
"""

PT_CAUSAL_LM_SAMPLE = r"""
    Example::

        >>> import torch
        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs, labels=inputs["input_ids"])
        >>> loss, logits = outputs[:2]
"""
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> input_ids = inputs["input_ids"]
        >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids)))  # Batch size 1

        >>> outputs = model(inputs)
        >>> loss, scores = outputs[:2]
"""

TF_QUESTION_ANSWERING_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        >>> input_dict = tokenizer(question, text, return_tensors='tf')
        >>> start_scores, end_scores = model(input_dict)

        >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
        >>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0] + 1])
"""

TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1))  # Batch size 1

        >>> outputs = model(inputs)
        >>> loss, logits = outputs[:2]
"""

TF_MASKED_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1

        >>> outputs = model(input_ids)
        >>> prediction_scores = outputs[0]
"""

TF_BASE_MODEL_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> outputs = model(inputs)

        >>> last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
"""

TF_MULTIPLE_CHOICE_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."

        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
        >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
        >>> outputs = model(inputs)  # batch size is 1

        >>> # the linear classifier still needs to be trained
        >>> logits = outputs[0]
"""

TF_CAUSAL_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> outputs = model(inputs)
        >>> logits = outputs[0]
"""
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
    def docstring_decorator(fn):
        model_class = fn.__qualname__.split(".")[0]
        is_tf_class = model_class[:2] == "TF"

        if "SequenceClassification" in model_class:
            code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
        elif "QuestionAnswering" in model_class:
            code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
        elif "TokenClassification" in model_class:
            code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
        elif "MultipleChoice" in model_class:
            code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
        elif "MaskedLM" in model_class:
            code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
        elif "LMHead" in model_class:
            code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
        elif "Model" in model_class:
            code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
        else:
            raise ValueError(f"Docstring can't be built for model {model_class}")

        built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
        return fn

    return docstring_decorator
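# Usage sketch (the class, tokenizer and checkpoint names are illustrative): the
# decorator picks the PT/TF sample above from the class name and fills in the
# {tokenizer_class}/{checkpoint} placeholders, e.g.
#
#   class BertForSequenceClassification(...):
#       @add_code_sample_docstrings(tokenizer_class="BertTokenizer", checkpoint="bert-base-uncased")
#       def forward(self, input_ids=None, ...):
#           ...
#
# which appends the formatted PT_SEQUENCE_CLASSIFICATION_SAMPLE to forward.__doc__.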
def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
    """
    Resolve a model identifier, and a file name, to a HF-hosted url
    on either S3 or Cloudfront (a Content Delivery Network, or CDN).

    Cloudfront is replicated over the globe so downloads are way faster
    for the end user (and it also lowers our bandwidth costs). However, it
    is more aggressively cached by default, so may not always reflect the
    latest changes to the underlying file (default TTL is 24 hours).

    In terms of client-side caching from this library, even though
    Cloudfront relays the ETags from S3, using one or the other
    (or switching from one to the other) will affect caching: cached files
    are not shared between the two because the cached file's name contains
    a hash of the url.
    """
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
    legacy_format = "/" not in model_id
    if legacy_format:
        return f"{endpoint}/{model_id}-{filename}"
    else:
        return f"{endpoint}/{model_id}/{filename}"
def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    If the url ends with .h5 (Keras HDF5 weights), append '.h5' to the name
    so that TF 2.0 can identify it as an HDF5 file
    (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
    """
    url_bytes = url.encode("utf-8")
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename
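# Resulting cache-entry naming (hash digests elided for brevity):
#   <sha256(url)>[.<sha256(etag)>][.h5]
# with a sibling "<same name>.json" metadata file ({"url": ..., "etag": ...})
# written by get_from_cache below and read back by filename_to_url.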
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    local_files_only=False,
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.

    Args:
        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if an incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
            file in a folder along the archive.
        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
            re-extract the archive and override the folder where it was extracted.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
                    zip_file.close()
            elif tarfile.is_tarfile(output_path):
                tar_file = tarfile.open(output_path)
                tar_file.extractall(output_path_extracted)
                tar_file.close()
            else:
                raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

        return output_path_extracted

    return output_path
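# Minimal usage sketch (assumes network access; the checkpoint name is illustrative):
#
#   config_url = hf_bucket_url("bert-base-uncased", CONFIG_NAME)
#   local_config = cached_path(config_url)
#   # local_config is a path inside TRANSFORMERS_CACHE, or None if the url is
#   # unreachable and nothing was previously cached.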
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if is_torch_available():
        ua += "; torch/{}".format(torch.__version__)
    if is_tf_available():
        ua += "; tensorflow/{}".format(tf.__version__)
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    headers = {"user-agent": ua}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    if response.status_code == 416:  # Range not satisfiable
        return
    content_length = response.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
    )
    for chunk in response.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
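# Note: http_get only streams into an already-open, writable file handle; the
# caller (get_from_cache below) supplies either a NamedTemporaryFile or the
# partially downloaded ".incomplete" file, and sets resume_size from its size so
# the Range header continues where the previous attempt stopped.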
def get_from_cache(
    url,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    etag = None
    if not local_files_only:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
            if response.status_code == 200:
                etag = response.headers.get("ETag")
        except (EnvironmentError, requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise ValueError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                return None

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path
        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

            http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.replace(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
class cached_property(property):
    """
    Descriptor that mimics @property but caches output in member variable.

    From tensorflow_datasets

    Built into functools as of Python 3.8.
    """

    def __get__(self, obj, objtype=None):
        # See docs.python.org/3/howto/descriptor.html#properties
        if obj is None:
            return self
        if self.fget is None:
            raise AttributeError("unreadable attribute")
        attr = "__cached_" + self.fget.__name__
        cached = getattr(obj, attr, None)
        if cached is None:
            cached = self.fget(obj)
            setattr(obj, attr, cached)
        return cached
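# Usage sketch (class and method names are illustrative): the first access runs
# fget and stores the result on the instance as "__cached_<name>"; later reads
# return the stored value without calling fget again.
#
#   class ModelConfig:
#       @cached_property
#       def vocab_size(self):
#           return self._compute_vocab_size()  # hypothetical expensive call, runs once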
def torch_required(func):
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if is_torch_available():
            return func(*args, **kwargs)
        else:
            raise ImportError(f"Method `{func.__name__}` requires PyTorch.")

    return wrapper
def tf_required(func):
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if is_tf_available():
            return func(*args, **kwargs)
        else:
            raise ImportError(f"Method `{func.__name__}` requires TF.")

    return wrapper
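# Usage sketch (the class and property below are illustrative): guard
# backend-specific code so it fails with a clear ImportError rather than a
# NameError when that backend is not installed.
#
#   class TrainingSetup:
#       @property
#       @torch_required
#       def device(self):
#           return torch.device("cuda" if torch.cuda.is_available() else "cpu")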