from copy import deepcopy
import math

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import clip
from omegaconf import OmegaConf
from PIL import Image
from transformers import AutoTokenizer

from kandinsky2.model.text_encoders import TextEncoder
from kandinsky2.vqgan.autoencoder import VQModelInterface, AutoencoderKL, MOVQ
from kandinsky2.model.samplers import DDIMSampler, PLMSSampler
from kandinsky2.model.model_creation import create_model, create_gaussian_diffusion
from kandinsky2.model.prior import PriorDiffusionModel, CustomizedTokenizer
from kandinsky2.utils import prepare_image, q_sample, process_images, prepare_mask


class Kandinsky2_1:
    def __init__(
        self,
        config,
        model_path,
        prior_path,
        device,
        task_type="text2img",
    ):
        self.config = config
        self.device = device
        self.use_fp16 = self.config["model_config"]["use_fp16"]
        self.task_type = task_type
        self.clip_image_size = config["clip_image_size"]
        if task_type == "text2img":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = False
        elif task_type == "inpainting":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = True
        else:
            raise ValueError("Only text2img and inpainting are available")
        self.tokenizer1 = AutoTokenizer.from_pretrained(self.config["tokenizer_name"])
        self.tokenizer2 = CustomizedTokenizer()
        clip_mean, clip_std = torch.load(
            config["prior"]["clip_mean_std_path"], map_location="cpu"
        )
        self.prior = PriorDiffusionModel(
            config["prior"]["params"],
            self.tokenizer2,
            clip_mean,
            clip_std,
        )
        self.prior.load_state_dict(torch.load(prior_path, map_location="cpu"), strict=False)
        if self.use_fp16:
            self.prior = self.prior.half()
        self.text_encoder = TextEncoder(**self.config["text_enc_params"])
        if self.use_fp16:
            self.text_encoder = self.text_encoder.half()
        self.clip_model, self.preprocess = clip.load(
            config["clip_name"], device=self.device, jit=False
        )
        self.clip_model.eval()
        if self.config["image_enc_params"] is not None:
            self.use_image_enc = True
            self.scale = self.config["image_enc_params"]["scale"]
            if self.config["image_enc_params"]["name"] == "AutoencoderKL":
                self.image_encoder = AutoencoderKL(
                    **self.config["image_enc_params"]["params"]
                )
            elif self.config["image_enc_params"]["name"] == "VQModelInterface":
                self.image_encoder = VQModelInterface(
                    **self.config["image_enc_params"]["params"]
                )
            elif self.config["image_enc_params"]["name"] == "MOVQ":
                self.image_encoder = MOVQ(**self.config["image_enc_params"]["params"])
            self.image_encoder.load_state_dict(
                torch.load(self.config["image_enc_params"]["ckpt_path"], map_location="cpu")
            )
            self.image_encoder.eval()
        else:
            self.use_image_enc = False
        self.config["model_config"]["cache_text_emb"] = True
        self.model = create_model(**self.config["model_config"])
        self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
        if self.use_fp16:
            self.model.convert_to_fp16()
            self.image_encoder = self.image_encoder.half()
            self.model_dtype = torch.float16
        else:
            self.model_dtype = torch.float32
        self.image_encoder = self.image_encoder.to(self.device).eval()
        self.text_encoder = self.text_encoder.to(self.device).eval()
        self.prior = self.prior.to(self.device).eval()
        self.model.eval()
        self.model.to(self.device)

    def get_new_h_w(self, h, w):
        new_h = h // 64
        if h % 64 != 0:
            new_h += 1
        new_w = w // 64
        if w % 64 != 0:
            new_w += 1
        return new_h * 8, new_w * 8
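
    # Note on get_new_h_w above: the latent grid is 8x smaller than pixel space,
    # and the requested resolution is first rounded up to a multiple of 64 pixels.
    # For example, h=512, w=768 gives a (64, 96) latent grid, and h=500 gives
    # ceil(500 / 64) * 8 = 64 latent rows.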

    def encode_text(self, text_encoder, tokenizer, prompt, batch_size):
        text_encoding = tokenizer(
            [prompt] * batch_size + [""] * batch_size,
            max_length=77,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        tokens = text_encoding["input_ids"].to(self.device)
        mask = text_encoding["attention_mask"].to(self.device)
        full_emb, pooled_emb = text_encoder(tokens=tokens, mask=mask)
        return full_emb, pooled_emb
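
    # encode_text above doubles the batch: the first half is the real prompt and
    # the second half an empty string, so the decoder's output can later be split
    # into conditional and unconditional halves for classifier-free guidance.
    # Tokenization is padded/truncated to 77 tokens.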

    def generate_clip_emb(
        self,
        prompt,
        batch_size=1,
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
    ):
        prompts_batch = [prompt for _ in range(batch_size)]
        prior_cf_scales_batch = [prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
        max_txt_length = self.prior.model.text_ctx
        tok, mask = self.tokenizer2.padded_tokens_and_mask(
            prompts_batch, max_txt_length
        )
        cf_token, cf_mask = self.tokenizer2.padded_tokens_and_mask(
            [negative_prior_prompt], max_txt_length
        )
        if not (cf_token.shape == tok.shape):
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)
        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)
        tok, mask = tok.to(device=self.device), mask.to(device=self.device)
        x = self.clip_model.token_embedding(tok).type(self.clip_model.dtype)
        x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(self.clip_model.dtype)
        txt_feat_seq = x
        txt_feat = (
            x[torch.arange(x.shape[0]), tok.argmax(dim=-1)]
            @ self.clip_model.text_projection
        )
        txt_feat, txt_feat_seq = (
            txt_feat.float().to(self.device),
            txt_feat_seq.float().to(self.device),
        )
        img_feat = self.prior(
            txt_feat,
            txt_feat_seq,
            mask,
            prior_cf_scales_batch,
            timestep_respacing=prior_steps,
        )
        return img_feat.to(self.model_dtype)
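
    # generate_clip_emb above runs the prompt (stacked with the negative prior
    # prompt) through the CLIP text transformer by hand to obtain pooled and
    # per-token features, then lets the diffusion prior translate them into a
    # CLIP image embedding, applying classifier-free guidance on the prior with
    # prior_cf_scale over prior_steps respaced timesteps.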

    def encode_images(self, image, is_pil=False):
        if is_pil:
            image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.clip_model.encode_image(image).to(self.model_dtype)

    def generate_img(
        self,
        prompt,
        img_prompt,
        batch_size=1,
        diffusion=None,
        guidance_scale=7,
        init_step=None,
        noise=None,
        init_img=None,
        img_mask=None,
        h=512,
        w=512,
        sampler="ddim_sampler",
        num_steps=50,
    ):
        new_h, new_w = self.get_new_h_w(h, w)
        full_batch_size = batch_size * 2
        model_kwargs = {}
        if init_img is not None and self.use_fp16:
            init_img = init_img.half()
        if img_mask is not None and self.use_fp16:
            img_mask = img_mask.half()
        model_kwargs["full_emb"], model_kwargs["pooled_emb"] = self.encode_text(
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer1,
            prompt=prompt,
            batch_size=batch_size,
        )
        model_kwargs["image_emb"] = img_prompt
        if self.task_type == "inpainting":
            init_img = init_img.to(self.device)
            img_mask = img_mask.to(self.device)
            model_kwargs["inpaint_image"] = init_img * img_mask
            model_kwargs["inpaint_mask"] = img_mask

        def model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = model_out[:, :4], model_out[:, 4:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            if sampler == "p_sampler":
                return torch.cat([eps, rest], dim=1)
            else:
                return eps
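
        # model_fn above implements classifier-free guidance: the latent batch
        # holds two identical halves, the conditioning kwargs pair the real
        # prompt with the empty/zero conditioning, and the predicted noise is
        # recombined as uncond_eps + guidance_scale * (cond_eps - uncond_eps).
        # The extra output channels ("rest") are kept only for the p_sampler path.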

        if noise is not None:
            noise = noise.float()

        if self.task_type == "inpainting":

            def denoised_fun(x_start):
                x_start = x_start.clamp(-2, 2)
                return x_start * (1 - img_mask) + init_img * img_mask

        else:

            def denoised_fun(x):
                return x.clamp(-2, 2)

        if sampler == "p_sampler":
            self.model.del_cache()
            samples = diffusion.p_sample_loop(
                model_fn,
                (full_batch_size, 4, new_h, new_w),
                device=self.device,
                noise=noise,
                progress=True,
                model_kwargs=model_kwargs,
                init_step=init_step,
                denoised_fn=denoised_fun,
            )[:batch_size]
            self.model.del_cache()
        else:
            if sampler == "ddim_sampler":
                sampler = DDIMSampler(
                    model=model_fn,
                    old_diffusion=diffusion,
                    schedule="linear",
                )
            elif sampler == "plms_sampler":
                sampler = PLMSSampler(
                    model=model_fn,
                    old_diffusion=diffusion,
                    schedule="linear",
                )
            else:
                raise ValueError("Only ddim_sampler and plms_sampler are available")
            self.model.del_cache()
            samples, _ = sampler.sample(
                num_steps,
                batch_size * 2,
                (4, new_h, new_w),
                conditioning=model_kwargs,
                x_T=noise,
                init_step=init_step,
            )
            self.model.del_cache()
            samples = samples[:batch_size]
        if self.use_image_enc:
            if self.use_fp16:
                samples = samples.half()
            samples = self.image_encoder.decode(samples / self.scale)
        samples = samples[:, :, :h, :w]
        return process_images(samples)
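
    # generate_img above is the shared decoder loop: it samples 4-channel
    # latents at the padded resolution with either the ancestral p_sampler or a
    # DDIM/PLMS sampler, keeps the first batch_size samples, decodes them through
    # the image encoder (dividing by self.scale first), crops to the requested
    # h x w and post-processes the result with process_images.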

    def create_zero_img_emb(self, batch_size):
        img = torch.zeros(1, 3, self.clip_image_size, self.clip_image_size).to(self.device)
        return self.encode_images(img, is_pil=False).repeat(batch_size, 1)

    def generate_text2img(
        self,
        prompt,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        if negative_decoder_prompt == "":
            zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        else:
            zero_image_emb = self.generate_clip_emb(
                negative_decoder_prompt,
                batch_size=batch_size,
                prior_cf_scale=prior_cf_scale,
                prior_steps=prior_steps,
                negative_prior_prompt=negative_prior_prompt,
            )
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)
        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )
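
    # A minimal text-to-image sketch with illustrative settings (assumes the
    # model was built with get_kandinsky2_1, defined below, on a CUDA device):
    #
    #   model = get_kandinsky2_1("cuda", task_type="text2img")
    #   images = model.generate_text2img(
    #       "red cat, 4k photo",
    #       num_steps=100,
    #       batch_size=1,
    #       guidance_scale=4,
    #       sampler="p_sampler",
    #       prior_cf_scale=4,
    #       prior_steps="5",
    #   )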

    def mix_images(
        self,
        images_texts,
        weights,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        assert len(images_texts) == len(weights) and len(images_texts) > 0
        # generate clip embeddings
        image_emb = None
        for i in range(len(images_texts)):
            if image_emb is None:
                if isinstance(images_texts[i], str):
                    image_emb = weights[i] * self.generate_clip_emb(
                        images_texts[i],
                        batch_size=1,
                        prior_cf_scale=prior_cf_scale,
                        prior_steps=prior_steps,
                        negative_prior_prompt=negative_prior_prompt,
                    )
                else:
                    image_emb = self.encode_images(images_texts[i], is_pil=True) * weights[i]
            else:
                if isinstance(images_texts[i], str):
                    image_emb = image_emb + weights[i] * self.generate_clip_emb(
                        images_texts[i],
                        batch_size=1,
                        prior_cf_scale=prior_cf_scale,
                        prior_steps=prior_steps,
                        negative_prior_prompt=negative_prior_prompt,
                    )
                else:
                    image_emb = image_emb + self.encode_images(images_texts[i], is_pil=True) * weights[i]
        image_emb = image_emb.repeat(batch_size, 1)
        if negative_decoder_prompt == "":
            zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        else:
            zero_image_emb = self.generate_clip_emb(
                negative_decoder_prompt,
                batch_size=batch_size,
                prior_cf_scale=prior_cf_scale,
                prior_steps=prior_steps,
                negative_prior_prompt=negative_prior_prompt,
            )
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)
        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])
        return self.generate_img(
            prompt="",
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )
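
    # Illustrative interpolation sketch: mix a text prompt with a PIL image by
    # taking a weighted sum of their CLIP embeddings (the file name and weights
    # below are hypothetical):
    #
    #   from PIL import Image
    #   img = Image.open("cat.png")
    #   images = model.mix_images(
    #       ["a red cat", img],
    #       weights=[0.3, 0.7],
    #       num_steps=100,
    #       sampler="p_sampler",
    #   )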

    def generate_img2img(
        self,
        prompt,
        pil_img,
        strength=0.7,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
    ):
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
        )
        zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)
        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])
        image = prepare_image(pil_img, h=h, w=w).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale
        start_step = int(diffusion.num_timesteps * (1 - strength))
        image = q_sample(
            image,
            torch.tensor(diffusion.timestep_map[start_step - 1]).to(self.device),
            schedule_name=config["diffusion_config"]["noise_schedule"],
            num_steps=config["diffusion_config"]["steps"],
        )
        image = image.repeat(2, 1, 1, 1)
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            noise=image,
            init_step=start_step,
        )
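
    # In generate_img2img above, strength trades fidelity to the input image
    # against the sampler's freedom: the encoded latent is noised with q_sample
    # at the timestep corresponding to start_step = int(num_timesteps * (1 -
    # strength)) and used as the sampler's starting noise, so a higher strength
    # lets the result depart further from the original.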

    def generate_inpainting(
        self,
        prompt,
        pil_img,
        img_mask,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)
        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])
        image = prepare_image(pil_img, w, h).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale
        image_shape = tuple(image.shape[-2:])
        img_mask = torch.from_numpy(img_mask).unsqueeze(0).unsqueeze(0)
        img_mask = F.interpolate(
            img_mask,
            image_shape,
            mode="nearest",
        )
        img_mask = prepare_mask(img_mask).to(self.device)
        if self.use_fp16:
            img_mask = img_mask.half()
        image = image.repeat(2, 1, 1, 1)
        img_mask = img_mask.repeat(2, 1, 1, 1)
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            init_img=image,
            img_mask=img_mask,
        )
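
    # Illustrative inpainting sketch. img_mask is a float numpy array of shape
    # (H, W); which value keeps a region and which regenerates it is decided by
    # prepare_mask together with the blending in denoised_fun above, so check
    # those before relying on a particular convention. The file name below is
    # hypothetical:
    #
    #   import numpy as np
    #   from PIL import Image
    #   model = get_kandinsky2_1("cuda", task_type="inpainting")
    #   pil_img = Image.open("photo.png")
    #   img_mask = np.ones((pil_img.height, pil_img.width), dtype=np.float32)
    #   img_mask[: pil_img.height // 2] = 0
    #   images = model.generate_inpainting("green forest", pil_img, img_mask)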


import os

from huggingface_hub import hf_hub_url, cached_download
from omegaconf.dictconfig import DictConfig


def get_kandinsky2_1(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    use_flash_attention=False,
):
    cache_dir = os.path.join(cache_dir, "2_1")
    config = DictConfig(deepcopy(CONFIG_2_1))
    config["model_config"]["use_flash_attention"] = use_flash_attention
    if task_type == "text2img":
        model_name = "decoder_fp16.ckpt"
        config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=model_name)
    elif task_type == "inpainting":
        model_name = "inpainting_fp16.ckpt"
        config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=model_name)
    else:
        raise ValueError("Only text2img and inpainting are available")
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename=model_name,
        use_auth_token=use_auth_token,
    )
    prior_name = "prior_fp16.ckpt"
    config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=prior_name)
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename=prior_name,
        use_auth_token=use_auth_token,
    )
    cache_dir_text_en = os.path.join(cache_dir, "text_encoder")
    for name in [
        "config.json",
        "pytorch_model.bin",
        "sentencepiece.bpe.model",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]:
        config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=f"text_encoder/{name}")
        cached_download(
            config_file_url,
            cache_dir=cache_dir_text_en,
            force_filename=name,
            use_auth_token=use_auth_token,
        )
    config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename="movq_final.ckpt")
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename="movq_final.ckpt",
        use_auth_token=use_auth_token,
    )
    config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename="ViT-L-14_stats.th")
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename="ViT-L-14_stats.th",
        use_auth_token=use_auth_token,
    )
    config["tokenizer_name"] = cache_dir_text_en
    config["text_enc_params"]["model_path"] = cache_dir_text_en
    config["prior"]["clip_mean_std_path"] = os.path.join(cache_dir, "ViT-L-14_stats.th")
    config["image_enc_params"]["ckpt_path"] = os.path.join(cache_dir, "movq_final.ckpt")
    cache_model_name = os.path.join(cache_dir, model_name)
    cache_prior_name = os.path.join(cache_dir, prior_name)
    model = Kandinsky2_1(config, cache_model_name, cache_prior_name, device, task_type=task_type)
    return model
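
# get_kandinsky2_1 above downloads the decoder (or inpainting) checkpoint, the
# prior, the text-encoder files, the MoVQ weights and the ViT-L/14 CLIP
# statistics from the sberbank-ai/Kandinsky_2.1 repo on the Hugging Face Hub
# into cache_dir, wires the local paths into CONFIG_2_1 (defined elsewhere in
# the package) and builds a Kandinsky2_1 instance. Note that cached_download
# has been deprecated in newer huggingface_hub releases, so recent versions may
# require switching to hf_hub_download.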


def get_kandinsky2(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    model_version="2.1",
    use_flash_attention=False,
):
    if model_version == "2.0":
        model = get_kandinsky2_0(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
        )
    elif model_version == "2.1":
        model = get_kandinsky2_1(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            use_flash_attention=use_flash_attention,
        )
    elif model_version == "2.2":
        model = Kandinsky2_2(device=device, task_type=task_type)
    else:
        raise ValueError("Only 2.0, 2.1 and 2.2 are available")
    return model
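

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original file. It assumes a
    # CUDA device, network access to the Hugging Face Hub, that CONFIG_2_1
    # (referenced above) is defined in the package, and that generate_text2img
    # returns a list of PIL images via process_images.
    model = get_kandinsky2("cuda", task_type="text2img", model_version="2.1")
    images = model.generate_text2img(
        "red cat, 4k photo",
        num_steps=100,
        batch_size=1,
        guidance_scale=4,
        h=512,
        w=512,
        sampler="p_sampler",
        prior_cf_scale=4,
        prior_steps="5",
    )
    images[0].save("kandinsky2_1_sample.png")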