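"""Video generation inference script.

Loads a pretrained DiffusionPipeline (optionally with the VideoCrafter2 UNet),
injects learned motion embeddings from `embedding_dir`, seeds sampling with a
blend of cached inversion noise and fresh Gaussian noise, and exports the
generated frames to an MP4.
"""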
import os

import torch
from omegaconf import OmegaConf
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

from train import export_to_video
from models.unet.motion_embeddings import load_motion_embeddings
# Only BlendInit is used below; noise_init also provides BlendFreqInit,
# FFTInit, and FreqInit as alternative noise-initialization strategies.
from noise_init.blend_init import BlendInit
from attn_ctrl import register_attention_control

def get_pipe(embedding_dir='baseline', config=None, noisy_latent=None, video_round=None):
    # Load the base video generation model.
    pipe = DiffusionPipeline.from_pretrained(
        config.model.pretrained_model_path, torch_dtype=torch.float16
    )

    # Optionally swap in the VideoCrafter2 UNet.
    if config.model.unet == 'videoCrafter2':
        from models.unet.unet_3d_condition import UNet3DConditionModel
        unet = UNet3DConditionModel.from_pretrained(
            "adamdad/videocrafterv2_diffusers", torch_dtype=torch.float16
        )
        pipe.unet = unet

    # Optional alternative scheduler:
    # pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    # Memory optimization: decode the VAE latents in slices.
    pipe.enable_vae_slicing()

    # Cached inversion noise from training, used to initialize sampling.
    noisy_latent = torch.load(f'{embedding_dir}/cached_latents/cached_0.pt')['inversion_noise'][None,]

    # Load the learned motion embeddings (optionally from a specific training
    # round) and inject them into the UNet.
    if video_round is None:
        motion_embed = torch.load(f'{embedding_dir}/motion_embed.pt')
    else:
        motion_embed = torch.load(f'{embedding_dir}/{video_round}/motion_embed.pt')
    load_motion_embeddings(pipe.unet, motion_embed)
    config.model['embedding_layers'] = list(motion_embed.keys())

    return pipe, config, noisy_latent

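# Example (sketch, not part of the original script): calling get_pipe directly,
# assuming `./results` holds the `config.yaml`, `cached_latents/`, and
# `motion_embed.pt` files produced by training:
#
#   cfg = OmegaConf.load('./results/config.yaml')
#   pipe, cfg, latent = get_pipe(embedding_dir='./results', config=cfg)
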
def inference(embedding_dir='vanilla',
              video_round=None,
              prompt=None,
              save_dir=None,
              seed=None,
              motion_type=None,
              inference_steps=30):
    # Validate the motion type.
    if motion_type not in ('camera', 'object', 'hybrid'):
        raise ValueError(f'Invalid motion type: {motion_type!r}')
    if seed is None:
        seed = 0

    # Load the config saved alongside the motion embedding.
    noisy_latent = None
    config = OmegaConf.load(f'{embedding_dir}/config.yaml')

    # Each motion type enables a different embedding-injection strategy.
    if motion_type == 'camera':
        config['strategy']['removeMFromV'] = True
    elif motion_type in ('object', 'hybrid'):
        config['strategy']['vSpatial_frameSubtraction'] = True

    pipe, config, noisy_latent = get_pipe(
        embedding_dir=embedding_dir,
        config=config,
        noisy_latent=noisy_latent,
        video_round=video_round,
    )
    n_frames = config.val.num_frames
    shape = (config.val.height, config.val.width)

    os.makedirs(save_dir, exist_ok=True)
    save_path = f'{save_dir}/{"_".join(prompt.split())}.mp4'

    register_attention_control(pipe.unet, config=config)

    # Initialize the sampling latents by blending the cached inversion noise
    # with fresh Gaussian noise (noise_prior controls the mix).
    if noisy_latent is not None:
        torch.manual_seed(seed)
        noise = torch.randn_like(noisy_latent)
        init_noise = BlendInit(noisy_latent, noise, noise_prior=0.5)
    else:
        init_noise = None
    input_init_noise = init_noise.clone() if init_noise is not None else None
    video_frames = pipe(
        prompt=prompt,
        num_inference_steps=inference_steps,
        guidance_scale=12,
        height=shape[0],
        width=shape[1],
        num_frames=n_frames,
        generator=torch.Generator("cuda").manual_seed(seed),
        latents=input_init_noise,
    ).frames[0]
    video_path = export_to_video(video_frames, output_video_path=save_path, fps=8)
    return video_path

if __name__ == "__main__":
    prompts = ["A skateboard slides along a city lane",
               "A tank is running in the desert.",
               "A toy train chugs around a roundabout tree"]
    embedding_dir = './results'
    video_round = 'checkpoint-250'
    save_dir = 'outputs'
    # `inference` expects a single prompt string (it builds the output filename
    # from it), so run it once per prompt rather than passing the whole list.
    for prompt in prompts:
        inference(
            embedding_dir=embedding_dir,
            prompt=prompt,
            video_round=video_round,
            save_dir=save_dir,
            motion_type='hybrid',
            seed=100,
        )
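# Example (sketch): camera-motion generation from a specific checkpoint, under
# the same assumed directory layout; the prompt below is illustrative, not
# from the original script:
#
#   inference(embedding_dir='./results',
#             prompt='A drone flies over a canyon',
#             video_round='checkpoint-250',
#             save_dir='outputs',
#             motion_type='camera',
#             seed=0)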