Yossilevii100 committed
Commit 237f560 · verified · 1 Parent(s): c96e79c

add app.py

Files changed (3)
  1. app.py +86 -0
  2. requirements.txt +8 -0
  3. vslerp.py +557 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ import os
+
+ from vslerp import UnCLIPImageInterpolationPipeline  # your pipeline + vSLERP
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load pipeline once (fp16 only makes sense on GPU)
+ pipe = UnCLIPImageInterpolationPipeline.from_pretrained(
+     "kakaobrain/karlo-v1-alpha-image-variations",
+     torch_dtype=torch.float16 if device == "cuda" else torch.float32
+ ).to(device)
+
+ # Put your own images in a local "bank" folder
+ IMAGE_BANK = {
+     "Example 1": "lj.png",
+     "Example 2": "kd.png",
+     "Example 3": "vase.png",
+     "Example 4": "lamp.jpeg"
+ }
+
+ def run_vslerp(img0, img1, bank0, bank1, slerp_num_steps, vslerp_start_idx, vslerp_end_idx, vslerp_num_steps):
+     # Decide input images: uploaded takes precedence, else from bank
+     if img0 is None and bank0 != "None":
+         img0 = Image.open(IMAGE_BANK[bank0])
+     if img1 is None and bank1 != "None":
+         img1 = Image.open(IMAGE_BANK[bank1])
+
+     if img0 is None or img1 is None:
+         raise ValueError("Please provide two images (either upload or select from bank).")
+
+     images = [img0, img1]
+
+     # Collect one image per (mean_val, interpolation step) pair for the gallery
+     gallery_images = []
+
+     vslerp_values = np.linspace(vslerp_start_idx, vslerp_end_idx, int(vslerp_num_steps))
+     for m_val in vslerp_values:
+         # Re-seed per mean value (as in vslerp.main) so frames differ only through mean_val;
+         # a single pipeline call already returns `slerp_num_steps` interpolation frames.
+         generator = torch.Generator(device=device).manual_seed(42)
+         out = pipe(
+             image=images,
+             generator=generator,
+             steps=int(slerp_num_steps),
+             decoder_guidance_scale=1,
+             mean_val=float(m_val)
+         )
+         gallery_images.extend(out.images)
+
+     return gallery_images
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## vSLERP Demo")
+     gr.Markdown("Note: The run may take a while, please be patient 🙏")
+
+     with gr.Row():
+         with gr.Column():
+             img0 = gr.Image(label="Upload Image 0", type="pil")
+             bank0 = gr.Dropdown(choices=["None"] + list(IMAGE_BANK.keys()), value="None", label="Or choose from bank")
+         with gr.Column():
+             img1 = gr.Image(label="Upload Image 1", type="pil")
+             bank1 = gr.Dropdown(choices=["None"] + list(IMAGE_BANK.keys()), value="None", label="Or choose from bank")
+
+     with gr.Row():
+         slerp_num_steps = gr.Slider(3, 6, value=6, step=1, label="slerp_num_steps")
+         vslerp_start_idx = gr.Slider(-2, 0, value=-1, step=1, label="vslerp_start_idx")
+         vslerp_end_idx = gr.Slider(1, 3, value=3, step=1, label="vslerp_end_idx")
+         vslerp_num_steps = gr.Slider(3, 6, value=6, step=1, label="vslerp_num_steps")
+
+     run_btn = gr.Button("Run vSLERP")
+     # .style() was removed in recent Gradio releases; pass the grid layout directly
+     gallery = gr.Gallery(label="Generated Interpolations", columns=4)
+
+     run_btn.click(
+         run_vslerp,
+         inputs=[img0, img1, bank0, bank1, slerp_num_steps, vslerp_start_idx, vslerp_end_idx, vslerp_num_steps],
+         outputs=[gallery]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
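
As a quick sanity check outside the UI, the callback above can be called directly. The snippet below is an illustrative sketch, not part of the commit: it assumes the bank files listed in IMAGE_BANK and the mean_feat.pt statistics used by vslerp.py sit next to app.py, and that a GPU is available (importing app loads the full pipeline).

    # headless invocation of the Gradio callback defined in app.py (illustrative)
    from app import run_vslerp

    frames = run_vslerp(
        img0=None, img1=None,
        bank0="Example 1", bank1="Example 2",
        slerp_num_steps=3,
        vslerp_start_idx=-1, vslerp_end_idx=1, vslerp_num_steps=2,
    )
    print(len(frames))  # 2 mean values x 3 interpolation steps = 6 frames
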
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ # requirements.txt for CLIPLatent Space
+
+ torch
+ transformers
+ gradio
+ Pillow
+ numpy
+ matplotlib
+ # imported by vslerp.py (accelerate is optional, used for CPU offload)
+ diffusers
+ accelerate
+ scipy
+ tqdm
vslerp.py ADDED
@@ -0,0 +1,557 @@
+ import inspect
+ from typing import List, Optional, Union
+
+ import argparse
+ import PIL
+ from PIL import Image
+ import torch
+ from torch.nn import functional as F
+ from transformers import (
+     CLIPFeatureExtractor,
+     CLIPTextModelWithProjection,
+     CLIPTokenizer,
+     CLIPVisionModelWithProjection,
+ )
+
+ from diffusers import (
+     DiffusionPipeline,
+     ImagePipelineOutput,
+     UnCLIPScheduler,
+     UNet2DConditionModel,
+     UNet2DModel,
+ )
+ from diffusers.pipelines.unclip import UnCLIPTextProjModel
+ from diffusers.utils import is_accelerate_available, logging
+ from diffusers.utils.torch_utils import randn_tensor
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ import os
+ import scipy.io as sio
+ import numpy as np
+ from tqdm import tqdm
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ def vSLERP(val, low, high, mean_val=1):
+     """
+     Find the interpolation point between the 'low' and 'high' embeddings for the given 'val'.
+     The embeddings are first shifted by a scaled mean CLIP feature (controlled by 'mean_val'),
+     interpolated with SLERP, and then shifted back. See https://en.wikipedia.org/wiki/Slerp
+     for more details on spherical linear interpolation.
+     """
+     # fetch and fit the mean magnitude; match the device/dtype of the embeddings
+     data = torch.load('mean_feat.pt', map_location=low.device)
+     mean_feats = data[0].to(device=low.device, dtype=low.dtype)
+
+     mean_feats = mean_feats * mean_val
+
+     # shift both features
+     low = low - mean_feats
+     high = high - mean_feats
+
+     # apply slerp (omega is computed element-wise over the embedding dimensions)
+     low_norm = low / torch.norm(low)
+     high_norm = high / torch.norm(high)
+     omega = torch.acos(low_norm * high_norm)
+     so = torch.sin(omega)
+     res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
+
+     # shift both features back
+     res = res + mean_feats
+     return res
+
+ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
65
+ """
66
+ Pipeline to generate variations from an input image using unCLIP
67
+
68
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
69
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
70
+
71
+ Args:
72
+ text_encoder ([`CLIPTextModelWithProjection`]):
73
+ Frozen text-encoder.
74
+ tokenizer (`CLIPTokenizer`):
75
+ Tokenizer of class
76
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
77
+ feature_extractor ([`CLIPFeatureExtractor`]):
78
+ Model that extracts features from generated images to be used as inputs for the `image_encoder`.
79
+ image_encoder ([`CLIPVisionModelWithProjection`]):
80
+ Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
81
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
82
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
83
+ text_proj ([`UnCLIPTextProjModel`]):
84
+ Utility class to prepare and combine the embeddings before they are passed to the decoder.
85
+ decoder ([`UNet2DConditionModel`]):
86
+ The decoder to invert the image embedding into an image.
87
+ super_res_first ([`UNet2DModel`]):
88
+ Super resolution unet. Used in all but the last step of the super resolution diffusion process.
89
+ super_res_last ([`UNet2DModel`]):
90
+ Super resolution unet. Used in the last step of the super resolution diffusion process.
91
+ decoder_scheduler ([`UnCLIPScheduler`]):
92
+ Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
93
+ super_res_scheduler ([`UnCLIPScheduler`]):
94
+ Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
95
+
96
+ """
97
+
98
+ decoder: UNet2DConditionModel
99
+ text_proj: UnCLIPTextProjModel
100
+ text_encoder: CLIPTextModelWithProjection
101
+ tokenizer: CLIPTokenizer
102
+ feature_extractor: CLIPFeatureExtractor
103
+ image_encoder: CLIPVisionModelWithProjection
104
+ super_res_first: UNet2DModel
105
+ super_res_last: UNet2DModel
106
+
107
+ decoder_scheduler: UnCLIPScheduler
108
+ super_res_scheduler: UnCLIPScheduler
109
+
110
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.__init__
111
+ def __init__(
112
+ self,
113
+ decoder: UNet2DConditionModel,
114
+ text_encoder: CLIPTextModelWithProjection,
115
+ tokenizer: CLIPTokenizer,
116
+ text_proj: UnCLIPTextProjModel,
117
+ feature_extractor: CLIPFeatureExtractor,
118
+ image_encoder: CLIPVisionModelWithProjection,
119
+ super_res_first: UNet2DModel,
120
+ super_res_last: UNet2DModel,
121
+ decoder_scheduler: UnCLIPScheduler,
122
+ super_res_scheduler: UnCLIPScheduler,
123
+ ):
124
+ super().__init__()
125
+
126
+ self.register_modules(
127
+ decoder=decoder,
128
+ text_encoder=text_encoder,
129
+ tokenizer=tokenizer,
130
+ text_proj=text_proj,
131
+ feature_extractor=feature_extractor,
132
+ image_encoder=image_encoder,
133
+ super_res_first=super_res_first,
134
+ super_res_last=super_res_last,
135
+ decoder_scheduler=decoder_scheduler,
136
+ super_res_scheduler=super_res_scheduler,
137
+ )
138
+
139
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
140
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
141
+ if latents is None:
142
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
143
+ else:
144
+ if latents.shape != shape:
145
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
146
+ latents = latents.to(device)
147
+
148
+ latents = latents * scheduler.init_noise_sigma
149
+ return latents
150
+
151
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_prompt
152
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
153
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
154
+
155
+ # get prompt text embeddings
156
+ text_inputs = self.tokenizer(
157
+ prompt,
158
+ padding="max_length",
159
+ max_length=self.tokenizer.model_max_length,
160
+ return_tensors="pt",
161
+ )
162
+ text_input_ids = text_inputs.input_ids
163
+ text_mask = text_inputs.attention_mask.bool().to(device)
164
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
165
+
166
+ prompt_embeds = text_encoder_output.text_embeds
167
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
168
+
169
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
170
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
171
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
172
+
173
+ if do_classifier_free_guidance:
174
+ uncond_tokens = [""] * batch_size
175
+
176
+ max_length = text_input_ids.shape[-1]
177
+ uncond_input = self.tokenizer(
178
+ uncond_tokens,
179
+ padding="max_length",
180
+ max_length=max_length,
181
+ truncation=True,
182
+ return_tensors="pt",
183
+ )
184
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
185
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
186
+
187
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
188
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
189
+
190
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
191
+
192
+ seq_len = negative_prompt_embeds.shape[1]
193
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
194
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
195
+
196
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
197
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
198
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
199
+ batch_size * num_images_per_prompt, seq_len, -1
200
+ )
201
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
202
+
203
+ # done duplicates
204
+
205
+ # For classifier free guidance, we need to do two forward passes.
206
+ # Here we concatenate the unconditional and text embeddings into a single batch
207
+ # to avoid doing two forward passes
208
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
209
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
210
+
211
+ text_mask = torch.cat([uncond_text_mask, text_mask])
212
+
213
+ return prompt_embeds, text_encoder_hidden_states, text_mask
214
+
215
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_image
216
+ def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
217
+ dtype = next(self.image_encoder.parameters()).dtype
218
+
219
+ if image_embeddings is None:
220
+ if not isinstance(image, torch.Tensor):
221
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
222
+
223
+ image = image.to(device=device, dtype=dtype)
224
+ image_embeddings = self.image_encoder(image).image_embeds
225
+
226
+ image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
227
+
228
+ return image_embeddings
229
+
230
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.enable_sequential_cpu_offload
231
+ def enable_sequential_cpu_offload(self, gpu_id=0):
232
+ r"""
233
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
234
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
235
+ when their specific submodule has its `forward` method called.
236
+ """
237
+ if is_accelerate_available():
238
+ from accelerate import cpu_offload
239
+ else:
240
+ raise ImportError("Please install accelerate via `pip install accelerate`")
241
+
242
+ device = torch.device(f"cuda:{gpu_id}")
243
+
244
+ models = [
245
+ self.decoder,
246
+ self.text_proj,
247
+ self.text_encoder,
248
+ self.super_res_first,
249
+ self.super_res_last,
250
+ ]
251
+ for cpu_offloaded_model in models:
252
+ if cpu_offloaded_model is not None:
253
+ cpu_offload(cpu_offloaded_model, device)
254
+
255
+ @property
256
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
257
+ def _execution_device(self):
258
+ r"""
259
+ Returns the device on which the pipeline's models will be executed. After calling
260
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
261
+ hooks.
262
+ """
263
+ if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
264
+ return self.device
265
+ for module in self.decoder.modules():
266
+ if (
267
+ hasattr(module, "_hf_hook")
268
+ and hasattr(module._hf_hook, "execution_device")
269
+ and module._hf_hook.execution_device is not None
270
+ ):
271
+ return torch.device(module._hf_hook.execution_device)
272
+ return self.device
273
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         image: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
+         steps: int = 5,
+         decoder_num_inference_steps: int = 25,
+         super_res_num_inference_steps: int = 7,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         image_embeddings: Optional[torch.Tensor] = None,
+         decoder_latents: Optional[torch.FloatTensor] = None,
+         super_res_latents: Optional[torch.FloatTensor] = None,
+         decoder_guidance_scale: float = 8.0,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         mean_val: float = 1.0
+     ):
+         """
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             image (`List[PIL.Image.Image]` or `torch.FloatTensor`):
+                 The images to use for the image interpolation. Only accepts a list of two PIL Images. If you provide a
+                 tensor, it needs to comply with the configuration of
+                 [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
+                 `CLIPFeatureExtractor` while still having a shape of two in the 0th dimension. Can be left as `None`
+                 only when `image_embeddings` are passed.
+             steps (`int`, *optional*, defaults to 5):
+                 The number of interpolation images to generate.
+             decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+                 The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+                 image at the expense of slower inference.
+             super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+                 The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+                 quality image at the expense of slower inference.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             image_embeddings (`torch.Tensor`, *optional*):
+                 Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
+                 can be passed for tasks like image interpolation. `image` can then be left as `None`.
+             decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+                 Pre-generated noisy latents to be used as inputs for the decoder.
+             super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+                 Pre-generated noisy latents to be used as inputs for the super resolution unets.
+             decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+             mean_val (`float`, *optional*, defaults to 1.0):
+                 Scale applied to the mean CLIP feature that vSLERP subtracts before interpolating and adds back afterwards.
+         """
+
+         batch_size = steps
+
+         device = self._execution_device
+
+         if isinstance(image, List):
+             if len(image) != 2:
+                 raise AssertionError(
+                     f"Expected 'image' List to be of size 2, but passed 'image' length is {len(image)}"
+                 )
+             elif not (isinstance(image[0], PIL.Image.Image) and isinstance(image[1], PIL.Image.Image)):
+                 raise AssertionError(
+                     f"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}"
+                 )
+         elif isinstance(image, torch.FloatTensor):
+             if image.shape[0] != 2:
+                 raise AssertionError(
+                     f"Expected 'image' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
+                 )
+         elif isinstance(image_embeddings, torch.Tensor):
+             if image_embeddings.shape[0] != 2:
+                 raise AssertionError(
+                     f"Expected 'image_embeddings' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
+                 )
+         else:
+             raise AssertionError(
+                 f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or torch.FloatTensor respectively. Received {type(image)} and {type(image_embeddings)} respectively"
+             )
+
+         original_image_embeddings = self._encode_image(
+             image=image, device=device, num_images_per_prompt=1, image_embeddings=image_embeddings
+         )
+
+         image_embeddings = []
+
+         # Interpolate between the two embeddings with vSLERP, one embedding per interpolation step
+         for interp_step in torch.linspace(0, 1, steps):
+             temp_image_embeddings = vSLERP(
+                 interp_step, original_image_embeddings[0], original_image_embeddings[1], mean_val=mean_val
+             ).unsqueeze(0)
+             image_embeddings.append(temp_image_embeddings)
+
+         image_embeddings = torch.cat(image_embeddings).to(device)
+
+         do_classifier_free_guidance = decoder_guidance_scale > 1.0
+
+         prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+             prompt=["" for i in range(steps)],
+             device=device,
+             num_images_per_prompt=1,
+             do_classifier_free_guidance=do_classifier_free_guidance,
+         )
+
+         text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+             image_embeddings=image_embeddings,
+             prompt_embeds=prompt_embeds,
+             text_encoder_hidden_states=text_encoder_hidden_states,
+             do_classifier_free_guidance=do_classifier_free_guidance,
+         )
+
+         if device.type == "mps":
+             # HACK: MPS: There is a panic when padding bool tensors,
+             # so cast to int tensor for the pad and back to bool afterwards
+             text_mask = text_mask.type(torch.int)
+             decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+             decoder_text_mask = decoder_text_mask.type(torch.bool)
+         else:
+             decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
+
+         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
+         decoder_timesteps_tensor = self.decoder_scheduler.timesteps
+
+         num_channels_latents = self.decoder.config.in_channels
+         height = self.decoder.config.sample_size
+         width = self.decoder.config.sample_size
+
+         # Prepare a single set of decoder latents and share it across all interpolation steps,
+         # so that differences between frames come only from the interpolated embeddings.
+         decoder_latents = self.prepare_latents(
+             (1, num_channels_latents, height, width),
+             text_encoder_hidden_states.dtype,
+             device,
+             generator,
+             None,
+             self.decoder_scheduler,
+         )
+
+         decoder_latents = decoder_latents.repeat(steps, 1, 1, 1)
+
+         for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
+
+             noise_pred = self.decoder(
+                 sample=latent_model_input,
+                 timestep=t,
+                 encoder_hidden_states=text_encoder_hidden_states,
+                 class_labels=additive_clip_time_embeddings,
+                 attention_mask=decoder_text_mask,
+             ).sample
+
+             if do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
+                 noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
+                 noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
+                 noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+             if i + 1 == decoder_timesteps_tensor.shape[0]:
+                 prev_timestep = None
+             else:
+                 prev_timestep = decoder_timesteps_tensor[i + 1]
+
+             # compute the previous noisy sample x_t -> x_t-1
+             decoder_latents = self.decoder_scheduler.step(
+                 noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
+             ).prev_sample
+
+         decoder_latents = decoder_latents.clamp(-1, 1)
+
+         image_small = decoder_latents
+         # done decoder
+
+         # super res
+
+         self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
+         super_res_timesteps_tensor = self.super_res_scheduler.timesteps
+
+         channels = self.super_res_first.config.in_channels // 2
+         height = self.super_res_first.config.sample_size
+         width = self.super_res_first.config.sample_size
+
+         super_res_latents = self.prepare_latents(
+             (batch_size, channels, height, width),
+             image_small.dtype,
+             device,
+             generator,
+             super_res_latents,
+             self.super_res_scheduler,
+         )
+
+         if device.type == "mps":
+             # MPS does not support many interpolations
+             image_upscaled = F.interpolate(image_small, size=[height, width])
+         else:
+             interpolate_antialias = {}
+             if "antialias" in inspect.signature(F.interpolate).parameters:
+                 interpolate_antialias["antialias"] = True
+
+             image_upscaled = F.interpolate(
+                 image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+             )
+
+         for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
+             # no classifier free guidance
+
+             if i == super_res_timesteps_tensor.shape[0] - 1:
+                 unet = self.super_res_last
+             else:
+                 unet = self.super_res_first
+
+             latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
+
+             noise_pred = unet(
+                 sample=latent_model_input,
+                 timestep=t,
+             ).sample
+
+             if i + 1 == super_res_timesteps_tensor.shape[0]:
+                 prev_timestep = None
+             else:
+                 prev_timestep = super_res_timesteps_tensor[i + 1]
+
+             # compute the previous noisy sample x_t -> x_t-1
+             super_res_latents = self.super_res_scheduler.step(
+                 noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
+             ).prev_sample
+
+         image = super_res_latents
+         # done super res
+
+         # post processing
+
+         image = image * 0.5 + 0.5
+         image = image.clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+         if output_type == "pil":
+             image = self.numpy_to_pil(image)
+
+         if not return_dict:
+             return (image,)
+
+         return ImagePipelineOutput(images=image)
+
+
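
As an aside (not part of vslerp.py), the `image_embeddings` path described in the `__call__` docstring can be driven directly, bypassing PIL inputs. A minimal sketch, assuming a CUDA device and two local files a.png and b.png (both placeholders):

    import torch
    from PIL import Image

    pipe = UnCLIPImageInterpolationPipeline.from_pretrained(
        "kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float16
    ).to("cuda")

    imgs = [Image.open("a.png"), Image.open("b.png")]
    pixels = pipe.feature_extractor(images=imgs, return_tensors="pt").pixel_values
    pixels = pixels.to(device="cuda", dtype=next(pipe.image_encoder.parameters()).dtype)
    with torch.no_grad():
        embeds = pipe.image_encoder(pixels).image_embeds  # shape (2, 768)

    out = pipe(image=None, image_embeddings=embeds, steps=6, decoder_guidance_scale=1, mean_val=1.0)
    out.images[0].save("interp_0.png")
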
+ def main(args):
+     pipe = UnCLIPImageInterpolationPipeline.from_pretrained(
+         "kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float16
+     )
+     pipe.to(device)
+
+     images = [Image.open(args.image_path0), Image.open(args.image_path1)]
+     # One pipeline run per mean value; frame ii of run m_iter is saved to "<ii>/<m_iter>.png"
+     for m_iter, m_val in enumerate(np.linspace(args.vslerp_start_idx, args.vslerp_end_idx, args.vslerp_num_steps)):
+         generator = torch.Generator(device=device)
+         generator.manual_seed(42)
+         out = pipe(image=images, generator=generator, steps=args.slerp_num_steps, decoder_guidance_scale=1, mean_val=m_val)
+         for ii, image in enumerate(out.images):
+             if not os.path.exists(f'{ii}'):
+                 os.makedirs(f'{ii}')
+             image.save(os.path.join(f'{ii}', f'{m_iter}.png'))
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="vSLERP image interpolation")
+     parser.add_argument("--vslerp_start_idx", type=float, default=-1)
+     parser.add_argument("--vslerp_end_idx", type=float, default=3)
+     parser.add_argument("--vslerp_num_steps", type=int, default=16)
+     parser.add_argument("--slerp_num_steps", type=int, default=6)
+     parser.add_argument("--image_path0", type=str, default='path.to.image0')
+     parser.add_argument("--image_path1", type=str, default='path.to.image1')
+     args = parser.parse_args()
+     main(args)
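
The script can also be driven programmatically with the same arguments the parser defines. A minimal sketch (paths and values are placeholders):

    # programmatic equivalent of running vslerp.py from the command line
    from argparse import Namespace
    from vslerp import main

    main(Namespace(
        vslerp_start_idx=-1.0, vslerp_end_idx=3.0, vslerp_num_steps=16,
        slerp_num_steps=6,
        image_path0="img0.png", image_path1="img1.png",
    ))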