Spaces:
Running
on
Zero
Running
on
Zero
| from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast | |
| import logging | |
def load_text_encoders(args, class_one, class_two):
    """Instantiate the two text encoders of the checkpoint described by ``args``.

    ``class_one`` is loaded from the ``text_encoder`` subfolder and ``class_two``
    from ``text_encoder_2``. Both calls use ``args.pretrained_model_name_or_path``,
    ``args.revision`` and ``args.variant``.

    Returns:
        Tuple ``(text_encoder_one, text_encoder_two)``.
    """

    def _load(encoder_cls, folder):
        # Shared loading call; only the class and the subfolder differ.
        return encoder_cls.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder=folder,
            revision=args.revision,
            variant=args.variant,
        )

    # The first encoder is fully loaded before the second one is requested.
    return _load(class_one, "text_encoder"), _load(class_two, "text_encoder_2")
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Resolve the transformers text-encoder class declared by a checkpoint config.

    Loads the config found under ``subfolder`` with
    ``PretrainedConfig.from_pretrained``. It then maps the first entry of the
    config's ``architectures`` list to the matching transformers class.

    Args:
        pretrained_model_name_or_path: Model id or path passed to
            ``PretrainedConfig.from_pretrained``.
        revision: Revision forwarded to ``PretrainedConfig.from_pretrained``.
        subfolder: Sub-directory holding the text encoder config.

    Returns:
        The class ``CLIPTextModel`` or ``T5EncoderModel`` (not an instance).

    Raises:
        ValueError: If the config declares no architecture, or declares one
            that is not supported here.
    """
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    # ``architectures`` is optional in a transformers config and can be None or
    # an empty list. Indexing it blindly would raise an opaque TypeError or
    # IndexError instead of the ValueError this function documents.
    architectures = getattr(text_encoder_config, "architectures", None)
    if not architectures:
        raise ValueError(
            f"Config in subfolder '{subfolder}' of {pretrained_model_name_or_path!r} "
            "does not declare any `architectures`; cannot determine the text encoder class."
        )
    model_class = architectures[0]

    # Imports stay inside each branch, so only the selected class is resolved.
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    raise ValueError(f"{model_class} is not supported.")
def create_logger(logging_dir, accelerator):
    """
    Create a logger that writes to a log file and stdout.

    On the main process (``accelerator.is_main_process`` truthy), the root logger
    is configured with ``logging.basicConfig``. It logs at INFO level to a
    ``StreamHandler`` and to ``<logging_dir>/log.txt``, and ``logging_dir`` is
    created if it does not exist yet. On any other process a single
    ``NullHandler`` is attached to this module's logger instead.

    Args:
        logging_dir: Directory that will contain ``log.txt`` (str or path-like).
        accelerator: Object with an ``is_main_process`` attribute
            (presumably an ``accelerate.Accelerator``).

    Returns:
        The ``logging.Logger`` named after this module.
    """
    import os

    logger = logging.getLogger(__name__)
    if accelerator.is_main_process:  # real logger
        # FileHandler opens the file immediately and does not create parent
        # directories; a fresh run directory would otherwise crash here.
        os.makedirs(logging_dir, exist_ok=True)
        # NOTE: basicConfig is a no-op if the root logger already has handlers.
        # The ANSI colour codes in the format are written to log.txt as well.
        logging.basicConfig(
            level=logging.INFO,
            format="[\033[34m%(asctime)s\033[0m] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(os.path.join(logging_dir, "log.txt")),
            ],
        )
    elif not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
        # Dummy logger: attach exactly one NullHandler so repeated calls do not
        # keep stacking handlers. Records still propagate to the root logger,
        # so this stays silent only while the root has no handlers on this
        # process (basicConfig above runs on the main process only).
        logger.addHandler(logging.NullHandler())
    return logger