Spaces:
Running
on
A10G
Running
on
A10G
| from collections import OrderedDict | |
| import torch | |
| from .common import MODEL_FOLDER, load_sd_inpainting_model, download_file | |
# Registry of supported inpainting checkpoints, keyed by model id.
#
# Each entry carries:
#   sd_version     -- Stable Diffusion major version of the checkpoint.
#   diffusers_ckpt -- True for entries split into per-component files
#                     (unet / text encoder / vae); False for entries that
#                     point at a single checkpoint file.
#   model_path     -- path(s) relative to MODEL_FOLDER where the file(s) live.
#   download_url   -- source URL(s); for split entries the keys mirror
#                     those of model_path.
# NOTE(review): the whole entry is splatted into load_sd_inpainting_model as
# keyword arguments — confirm that helper's signature if keys are added.
model_dict = dict(
    sd15_inp=dict(
        sd_version=1,
        diffusers_ckpt=True,
        model_path=OrderedDict(
            unet='sd-1-5-inpainting/unet.fp16.safetensors',
            encoder='sd-1-5-inpainting/encoder.fp16.safetensors',
            vae='sd-1-5-inpainting/vae.fp16.safetensors',
        ),
        download_url=OrderedDict(
            unet='https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors?download=true',
            encoder='https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/text_encoder/model.fp16.safetensors?download=true',
            vae='https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors?download=true',
        ),
    ),
    ds8_inp=dict(
        sd_version=1,
        diffusers_ckpt=True,
        model_path=OrderedDict(
            unet='ds-8-inpainting/unet.fp16.safetensors',
            encoder='ds-8-inpainting/encoder.fp16.safetensors',
            vae='ds-8-inpainting/vae.fp16.safetensors',
        ),
        download_url=OrderedDict(
            unet='https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors?download=true',
            encoder='https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/text_encoder/model.fp16.safetensors?download=true',
            vae='https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors?download=true',
        ),
    ),
    sd2_inp=dict(
        sd_version=2,
        diffusers_ckpt=False,
        # Single-file checkpoint: plain strings rather than per-component maps.
        model_path='sd-2-0-inpainting/512-inpainting-ema.safetensors',
        download_url='https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.safetensors?download=true',
    ),
)

# In-process store of already-loaded models, filled by load_inpainting_model
# when it is called with cache=True.
model_cache = dict()
def pre_download_inpainting_models():
    """Fetch every checkpoint file listed in ``model_dict`` into MODEL_FOLDER.

    An entry's ``download_url`` / ``model_path`` pair is either two strings
    (one checkpoint file) or two mappings keyed by component name, in which
    case each component URL is downloaded to the path under the same key.
    Files are written to ``f'{MODEL_FOLDER}/{relative_path}'`` via
    ``download_file``.

    Raises:
        TypeError: if an entry's ``download_url`` and ``model_path`` are not
            both strings or both mappings.
        KeyError: if a ``download_url`` component has no matching
            ``model_path`` entry (raised before that entry's downloads start).
    """
    for model_id, model_details in model_dict.items():
        urls = model_details['download_url']
        paths = model_details['model_path']

        if isinstance(urls, str) and isinstance(paths, str):
            download_file(urls, f'{MODEL_FOLDER}/{paths}')
        # isinstance(..., dict) accepts OrderedDict (a dict subclass) as well
        # as plain dicts, unlike the exact ``type(...) == OrderedDict`` check.
        elif isinstance(urls, dict) and isinstance(paths, dict):
            # Validate up front so a malformed entry fails before any of its
            # components are downloaded.
            missing = [component for component in urls if component not in paths]
            if missing:
                raise KeyError(
                    f"model '{model_id}': download_url components {missing} "
                    f"have no matching model_path entry"
                )
            for component, url in urls.items():
                download_file(url, f'{MODEL_FOLDER}/{paths[component]}')
        else:
            raise TypeError(
                f"model '{model_id}': download_url and model_path must both be "
                f"str or both be mappings, got {type(urls).__name__} and "
                f"{type(paths).__name__}"
            )
def load_inpainting_model(model_id, dtype=torch.float16, device='cuda:0', cache=False):
    """Load the inpainting model registered under ``model_id``.

    Args:
        model_id: Key into ``model_dict`` (``'sd15_inp'``, ``'ds8_inp'`` or
            ``'sd2_inp'`` in this file).
        dtype: torch dtype forwarded to ``load_sd_inpainting_model``.
        device: Device forwarded to ``load_sd_inpainting_model``.
        cache: When True, return a previously cached model for the same
            ``(model_id, dtype, device)`` if one exists, and store a freshly
            loaded model in ``model_cache``.

    Returns:
        Whatever ``load_sd_inpainting_model`` returns for this entry.

    Raises:
        ValueError: if ``model_id`` is not a key of ``model_dict``.
    """
    if model_id not in model_dict:
        raise ValueError(f'Unsupported model-id. Choose one from {list(model_dict.keys())}.')

    # Key on dtype and device as well as the id: keying on the id alone would
    # hand a model loaded for one dtype/device to a caller that asked for a
    # different one. str(device) makes 'cuda:0' and torch.device('cuda:0')
    # share an entry.
    cache_key = (model_id, dtype, str(device))
    if cache and cache_key in model_cache:
        return model_cache[cache_key]

    # Every field of the registry entry (sd_version, diffusers_ckpt,
    # model_path, download_url) is passed through as a keyword argument.
    model = load_sd_inpainting_model(
        **model_dict[model_id],
        dtype=dtype,
        device=device,
    )
    if cache:
        model_cache[cache_key] = model
    return model