from typing import Literal

import gradio as gr
import torch
import numpy as np
import colorsys
import yaml
from huggingface_hub import hf_hub_download
from diffusers import VQModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
from chameleon.image_tokenizer import ImageTokenizer
import torch.backends
import torch.mps
from PIL import Image
import spaces

Model = Literal["vqgan", "paella", "chameleon"]
models = ["vqgan", "paella", "chameleon"]

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

class ImageRoundtripPipeline:
    def roundtrip_image(self, image, output_type="pil"): ...


class VQImageRoundtripPipeline(ImageRoundtripPipeline):
    vqvae: VQModel
    vae_scale_factor: int
    vqvae_processor: VaeImageProcessor

    def __init__(self):
        self.vqvae = VQModel.from_pretrained("amused/amused-512", subfolder="vqvae")
        self.vqvae.eval()
        self.vqvae.to(device)
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        self.vqvae_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False
        )
        print("VQ-GAN model loaded", self.vqvae)

    def roundtrip_image(self, image, output_type="pil"):
        image = self.vqvae_processor.preprocess(image)
        device = self.vqvae.device
        needs_upcasting = (
            self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
        )
        batch_size, im_channels, height, width = image.shape
        if needs_upcasting:
            self.vqvae.float()
        latents = self.vqvae.encode(
            image.to(dtype=self.vqvae.dtype, device=device)
        ).latents
        latents_batch_size, latent_channels, latents_height, latents_width = (
            latents.shape
        )
        latents = self.vqvae.quantize(latents)[2][2].reshape(
            batch_size, latents_height, latents_width
        )
        # replace 20% of latents with random values
        # random_latents = torch.randint(
        #     0, self.vqvae.config.num_vq_embeddings, latents.shape, device=device
        # )
        # random_mask = torch.rand(latents.shape, device=device) < 0.2
        # latents = torch.where(random_mask, random_latents, latents)
        output = self.vqvae.decode(
            latents,
            force_not_quantize=True,
            shape=(
                batch_size,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
                self.vqvae.config.latent_channels,
            ),
        ).sample.clip(0, 1)
        output = self.vqvae_processor.postprocess(output, output_type)
        if needs_upcasting:
            self.vqvae.half()
        return output[0], latents.cpu().numpy(), self.vqvae.config.num_vq_embeddings


class ChameleonVQImageRoundtripPipeline(ImageRoundtripPipeline):
    tokenizer: ImageTokenizer
    n_embed: int
    vae_scale_factor: int

    def __init__(self):
        vqgan_path = hf_hub_download(
            "darknoon/chameleon-tokenizer", "tokenizer/vqgan.ckpt"
        )
        vqgan_config_path = hf_hub_download(
            "darknoon/chameleon-tokenizer", "tokenizer/vqgan.yaml"
        )
        self.tokenizer = ImageTokenizer(
            cfg_path=vqgan_config_path, ckpt_path=vqgan_path, device=device
        )
        with open(vqgan_config_path) as f:
            vq_config = yaml.safe_load(f)
        self.n_embed = vq_config["model"]["params"]["n_embed"]
        self.vae_scale_factor = 16
        print("Chameleon VQGan model loaded", self.tokenizer._vq_model, self.n_embed)

    def preprocess(self, image: Image.Image):
        # copied from _vqgan_input_from
        np_img = np.array(image) / 255.0  # Normalize to [0, 1]
        np_img = np_img * 2 - 1  # Scale to [-1, 1]
        tensor_img = (
            torch.from_numpy(np_img).permute(2, 0, 1).float()
        )  # (Channels, Height, Width) format.
        # Add batch dimension.
        return tensor_img.unsqueeze(0)

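    # Worked example: an 8-bit pixel value of 255 maps to 1.0 and 0 maps to -1.0,
    # so the tensor handed to the Chameleon VQ encoder lies in the [-1, 1] range.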
    def roundtrip_image(self, image, output_type="pil"):
        # image = self.tokenizer._vqgan_input_from(image).to(device)
        image = self.preprocess(image).to(device)
        _, _, im_height, im_width = image.shape
        _, _, [_, _, latents] = self.tokenizer._vq_model.encode(image)
        scale = self.vae_scale_factor
        shape = (1, im_height // scale, im_width // scale)
        output = self.tokenizer.pil_from_img_toks(latents, shape=shape)
        # we actually do want this to be a grid, sorry!
        latents = latents.reshape(*shape)
        return (
            output,
            latents.cpu().numpy(),
            self.n_embed,
        )


class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):
    vqgan: PaellaVQModel
    vae_scale_factor: int
    vqvae_processor: VaeImageProcessor

    def __init__(self):
        self.vqgan = PaellaVQModel.from_pretrained(
            "warp-ai/wuerstchen", subfolder="vqgan"
        )
        self.vqgan.eval()
        self.vqgan.to(device)
        self.vae_scale_factor = 4
        self.vqvae_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False
        )
        print("Paella VQ-GAN model loaded", self.vqgan)

    def roundtrip_image(self, image, output_type="pil"):
        image = self.vqvae_processor.preprocess(image)
        device = self.vqgan.device
        batch_size, im_channels, height, width = image.shape
        latents = self.vqgan.encode(
            image.to(dtype=self.vqgan.dtype, device=device)
        ).latents
        latents_batch_size, latent_channels, latents_height, latents_width = (
            latents.shape
        )
        # latents = latents * self.vqgan.config.scale_factor
        # Manually quantize so we can inspect
        latents_q = self.vqgan.vquantizer(latents)[2][2].reshape(
            batch_size, latents_height, latents_width
        )
        print("latents after quantize", (latents_q.shape, latents_q.dtype))
        images = self.vqgan.decode(latents).sample.clamp(0, 1)
        output = self.vqvae_processor.postprocess(images, output_type)
        # if needs_upcasting:
        #     self.vqgan.half()
        return output[0], latents_q.cpu().numpy(), self.vqgan.config.num_vq_embeddings


pipeline_paella = PaellaImageRoundtripPipeline()
pipeline_vq = VQImageRoundtripPipeline()
pipeline_vq_chameleon = ChameleonVQImageRoundtripPipeline()

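# Example (sketch, not executed here): each pipeline can also be called directly;
# it returns the reconstructed image, the integer token grid, and the codebook size:
#   recon, tokens, codebook_size = pipeline_vq.roundtrip_image(
#       Image.open("input.png").convert("RGB").resize((512, 512))  # hypothetical path
#   )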

# Function to generate a list of unique colors
def generate_unique_colors_hsl(n):
    colors = []
    for i in range(n):
        hue = i / (n // 4)  # Distribute hues evenly around the color wheel 4 times
        lightness = 0.8 - (i / n) * 0.6  # Decrease brightness from 0.8 to 0.2
        saturation = 1.0
        rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
        rgb = tuple(int(255 * x) for x in rgb)
        colors.append(rgb)
    return colors


# Function to create the image from VQGAN tokens
def vqgan_tokens_to_image(tokens, codebook_size, downscale_factor):
    # Generate unique colors for each token in the codebook
    colors = generate_unique_colors_hsl(codebook_size)
    # Create a lookup table
    lookup_table = np.array(colors, dtype=np.uint8)
    # Extract the token array (remove the batch dimension)
    token_array = tokens[0]
    # Map tokens to their RGB colors using the lookup table
    color_image = lookup_table[token_array]
    # Create a PIL image from the numpy array
    img = Image.fromarray(color_image, "RGB")
    # Upscale the image using nearest neighbor interpolation
    img = img.resize(
        (
            color_image.shape[1] * downscale_factor,
            color_image.shape[0] * downscale_factor,
        ),
        Image.NEAREST,
    )
    return img

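# For example, a 512x512 input through the Chameleon tokenizer (vae_scale_factor = 16)
# yields a 32x32 token grid; it is drawn as a 32x32 color image and upscaled back to
# 512x512 with nearest-neighbor so each token stays visible as a solid block.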

def describe_shape(shape):
    return f"Shape: {shape} num elements: {np.prod(shape)}"


def calc_psnr(img1: Image.Image, img2: Image.Image):
    if img1.size != img2.size:
        raise ValueError("Images must have the same dimensions")
    # Cast to float before subtracting so uint8 pixel values don't wrap around.
    img1 = np.array(img1, dtype=np.float64)
    img2 = np.array(img2, dtype=np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float("inf")
    return 20 * np.log10(255.0 / np.sqrt(mse))

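# For reference, PSNR for 8-bit images is 20 * log10(255 / sqrt(MSE)) dB; an MSE of
# 100 corresponds to roughly 28.1 dB, and identical images give infinity.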

def roundtrip_image(
    image,
    model: Model,
    size: Literal["256x256", "512x512", "1024x1024"],
    output_type="pil",
):
    if size == "256x256":
        image = image.resize((256, 256))
    elif size == "512x512":
        image = image.resize((512, 512))
    elif size == "1024x1024":
        image = image.resize((1024, 1024))
    else:
        raise ValueError(f"Unknown size {size}")
    image_orig = image
    if model == "vqgan":
        pipeline = pipeline_vq
    elif model == "paella":
        pipeline = pipeline_paella
    elif model == "chameleon":
        pipeline = pipeline_vq_chameleon
    else:
        raise ValueError(f"Unknown model {model}")
    image, latents, codebook_size = pipeline.roundtrip_image(image, output_type)
    return (
        image,
        vqgan_tokens_to_image(
            latents, codebook_size, downscale_factor=pipeline.vae_scale_factor
        ),
        describe_shape(latents.shape),
        f"{calc_psnr(image_orig, image):.2f}",
    )


demo = gr.Interface(
    fn=roundtrip_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(models, label="Model", value="vqgan"),
        gr.Dropdown(["256x256", "512x512", "1024x1024"], label="Size", value="512x512"),
    ],
    outputs=[
        gr.Image(label="Reconstructed", format="png"),
        gr.Image(label="Tokens", format="png"),
        gr.Text(label="VQ Shape"),
        gr.Text(label="PSNR"),
    ],
    title="Image Tokenizer Playground",
    description="Round-trip an image through an encoder-decoder pair to see the quality loss introduced by the VQ-GAN tokenizers used for image generation.",
)

demo.launch()
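
# On Hugging Face Spaces the app is launched by the platform; to try it locally
# (assuming the chameleon and diffusers dependencies are installed), run
# `python app.py` and open the local URL that Gradio prints to the console.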