import argparse
import os
import random
import sys

import gradio as gr
import torch
import torchvision
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything

sys.path.insert(0, "scripts/evaluation")
from funcs import (
    batch_ddim_sampling_freenoise,
    load_model_checkpoint,
)
from utils.utils import instantiate_from_config
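

# FreeNoise (tuning-free longer video diffusion via noise rescheduling) demo
# on top of the VideoCrafter text-to-video models. infer() lazily downloads
# the checkpoint for the requested resolution, reschedules the initial noise,
# runs FreeNoise DDIM sampling, and encodes the result as an mp4.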
def infer(prompt, output_size, seed, num_frames, ddim_steps, unconditional_guidance_scale, save_fps):
    window_size = 16
    window_stride = 4
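    # FreeNoise window parameters: the noise-shuffling loop below repeats the
    # initial noise with this window size and stride, and the same values are
    # passed (via args) to batch_ddim_sampling_freenoise for window-based
    # temporal fusion during sampling.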
    if output_size == "320x512":
        width = 512
        height = 320
        ckpt_dir_512 = "checkpoints/base_512_v2"
        ckpt_path_512 = "checkpoints/base_512_v2/model.ckpt"
        config_512 = "configs/inference_t2v_tconv512_v2.0_freenoise.yaml"
        config_512 = OmegaConf.load(config_512)
        model_config_512 = config_512.pop("model", OmegaConf.create())
        model_512 = instantiate_from_config(model_config_512)
        model_512 = model_512.cuda()
        if not os.path.exists(ckpt_path_512):
            os.makedirs(ckpt_dir_512, exist_ok=True)
            hf_hub_download(repo_id="VideoCrafter/VideoCrafter2", filename="model.ckpt", local_dir=ckpt_dir_512)
        try:
            model_512 = load_model_checkpoint(model_512, ckpt_path_512)
        except Exception:
            # A cached checkpoint may be truncated; re-download once and retry.
            hf_hub_download(repo_id="VideoCrafter/VideoCrafter2", filename="model.ckpt", local_dir=ckpt_dir_512, force_download=True)
            model_512 = load_model_checkpoint(model_512, ckpt_path_512)
        model_512.eval()
        model = model_512
        fps = 16
    elif output_size == "576x1024":
        width = 1024
        height = 576
        ckpt_dir_1024 = "checkpoints/base_1024_v1"
        ckpt_path_1024 = "checkpoints/base_1024_v1/model.ckpt"
        config_1024 = "configs/inference_t2v_1024_v1.0_freenoise.yaml"
        config_1024 = OmegaConf.load(config_1024)
        model_config_1024 = config_1024.pop("model", OmegaConf.create())
        model_1024 = instantiate_from_config(model_config_1024)
        model_1024 = model_1024.cuda()
        if not os.path.exists(ckpt_path_1024):
            os.makedirs(ckpt_dir_1024, exist_ok=True)
            hf_hub_download(repo_id="VideoCrafter/Text2Video-1024", filename="model.ckpt", local_dir=ckpt_dir_1024)
        try:
            model_1024 = load_model_checkpoint(model_1024, ckpt_path_1024)
        except Exception:
            hf_hub_download(repo_id="VideoCrafter/Text2Video-1024", filename="model.ckpt", local_dir=ckpt_dir_1024, force_download=True)
            model_1024 = load_model_checkpoint(model_1024, ckpt_path_1024)
        model_1024.eval()
        model = model_1024
        fps = 28
        num_frames = min(num_frames, 36)  # the 1024 model is limited to 36 frames in this demo
    elif output_size == "256x256":
        width = 256
        height = 256
        ckpt_dir_256 = "checkpoints/base_256_v1"
        ckpt_path_256 = "checkpoints/base_256_v1/model.ckpt"
        config_256 = "configs/inference_t2v_tconv256_v1.0_freenoise.yaml"
        config_256 = OmegaConf.load(config_256)
        model_config_256 = config_256.pop("model", OmegaConf.create())
        model_256 = instantiate_from_config(model_config_256)
        model_256 = model_256.cuda()
        if not os.path.exists(ckpt_path_256):
            os.makedirs(ckpt_dir_256, exist_ok=True)
            hf_hub_download(repo_id="VideoCrafter/Text2Video-256", filename="model.ckpt", local_dir=ckpt_dir_256)
        try:
            model_256 = load_model_checkpoint(model_256, ckpt_path_256)
        except Exception:
            hf_hub_download(repo_id="VideoCrafter/Text2Video-256", filename="model.ckpt", local_dir=ckpt_dir_256, force_download=True)
            model_256 = load_model_checkpoint(model_256, ckpt_path_256)
        model_256.eval()
        model = model_256
        fps = 8
    else:
        # Guard against unexpected sizes from direct API calls; the dropdown
        # in the UI only offers the three options above.
        raise gr.Error(f"Unsupported output size: {output_size}")
    print('Model Loaded.')

    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")
    print(f"Using seed: {seed}")
    seed_everything(seed)
    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=ddim_steps,
        ddim_eta=0.0,
        bs=1,
        height=height,
        width=width,
        frames=num_frames,
        fps=fps,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        window_size=window_size,
        window_stride=window_stride,
    )
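    # Note: ddim_eta=0.0 selects deterministic DDIM updates, so the output is
    # reproducible for a fixed seed and prompt.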

    ## latent noise shape
    h, w = args.height // 8, args.width // 8
    frames = model.temporal_length if args.frames < 0 else args.frames
    channels = model.channels

    x_T_total = torch.randn(
        [args.n_samples, 1, channels, frames, h, w], device=model.device
    ).repeat(1, args.bs, 1, 1, 1, 1)
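    # FreeNoise noise rescheduling: each new block of `window_stride` frames
    # below reuses the noise from `window_size` frames earlier, in shuffled
    # order, so the initial noise of a long video is a locally shuffled
    # repetition of the first window.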
    for frame_index in range(args.window_size, args.frames, args.window_stride):
        list_index = list(
            range(
                frame_index - args.window_size,
                frame_index + args.window_stride - args.window_size,
            )
        )
        random.shuffle(list_index)
        x_T_total[
            :, :, :, frame_index : frame_index + args.window_stride
        ] = x_T_total[:, :, :, list_index]

    batch_size = 1
    noise_shape = [batch_size, channels, frames, h, w]
    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
    prompts = [prompt]
    text_emb = model.get_learned_conditioning(prompts)
    cond = {"c_crossattn": [text_emb], "fps": fps}  # text cross-attention + fps conditioning

    ## inference
    batch_samples = batch_ddim_sampling_freenoise(
        model,
        cond,
        noise_shape,
        args.n_samples,
        args.ddim_steps,
        args.ddim_eta,
        args.unconditional_guidance_scale,
        args=args,
        x_T_total=x_T_total,
    )
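    # Post-process: samples lie in [-1, 1]; tile the n_samples per frame into
    # a grid, rescale to uint8, and encode an h264 mp4 (crf 10) at save_fps.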
    video_path = "output.mp4"
    vid_tensor = batch_samples[0]
    video = vid_tensor.detach().cpu()
    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(2, 0, 1, 3, 4)  # [n, c, t, h, w] -> [t, n, c, h, w]
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
        for framesheet in video
    ]  # each grid: [3, h, n*w]
    grid = torch.stack(frame_grids, dim=0)  # stack along time: [t, 3, h, n*w]
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)  # [t, h, n*w, 3]
    torchvision.io.write_video(
        video_path,
        grid,
        fps=args.savefps,
        video_codec="h264",
        options={"crf": "10"},
    )
    print(video_path)
    return video_path


examples = [
    ["A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect"],
    ["A corgi is swimming quickly"],
    ["A bigfoot walking in the snowstorm"],
    ["Campfire at night in a snowy forest with starry sky in the background"],
    ["A panda is surfing in the universe"],
]
| css = """ | |
| #col-container {max-width: 640px; margin-left: auto; margin-right: auto;} | |
| a {text-decoration-line: underline; font-weight: 600;} | |
| .animate-spin { | |
| animation: spin 1s linear infinite; | |
| } | |
| @keyframes spin { | |
| from { | |
| transform: rotate(0deg); | |
| } | |
| to { | |
| transform: rotate(360deg); | |
| } | |
| } | |
| #share-btn-container { | |
| display: flex; | |
| padding-left: 0.5rem !important; | |
| padding-right: 0.5rem !important; | |
| background-color: #000000; | |
| justify-content: center; | |
| align-items: center; | |
| border-radius: 9999px !important; | |
| max-width: 15rem; | |
| height: 36px; | |
| } | |
| div#share-btn-container > div { | |
| flex-direction: row; | |
| background: black; | |
| align-items: center; | |
| } | |
| #share-btn-container:hover { | |
| background-color: #060606; | |
| } | |
| #share-btn { | |
| all: initial; | |
| color: #ffffff; | |
| font-weight: 600; | |
| cursor:pointer; | |
| font-family: 'IBM Plex Sans', sans-serif; | |
| margin-left: 0.5rem !important; | |
| padding-top: 0.5rem !important; | |
| padding-bottom: 0.5rem !important; | |
| right:0; | |
| } | |
| #share-btn * { | |
| all: unset; | |
| } | |
| #share-btn-container div:nth-child(-n+2){ | |
| width: auto !important; | |
| min-height: 0px !important; | |
| } | |
| #share-btn-container .wrap { | |
| display: none !important; | |
| } | |
| #share-btn-container.hidden { | |
| display: none!important; | |
| } | |
| img[src*='#center'] { | |
| display: inline-block; | |
| margin: unset; | |
| } | |
| .footer { | |
| margin-bottom: 45px; | |
| margin-top: 10px; | |
| text-align: center; | |
| border-bottom: 1px solid #e5e5e5; | |
| } | |
| .footer>p { | |
| font-size: .8rem; | |
| display: inline-block; | |
| padding: 0 10px; | |
| transform: translateY(10px); | |
| background: white; | |
| } | |
| .dark .footer { | |
| border-color: #303030; | |
| } | |
| .dark .footer>p { | |
| background: #0b0f19; | |
| } | |
| """ | |

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center;">FreeNoise (Longer Text-to-Video)</h1>
            <p style="text-align: center;">
            FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling (ICLR 2024)
            </p>
            <p style="text-align: center;">
            <a href="https://arxiv.org/abs/2310.15169" target="_blank"><b>[arXiv]</b></a>
            <a href="http://haonanqiu.com/projects/FreeNoise.html" target="_blank"><b>[Project Page]</b></a>
            <a href="https://github.com/AILab-CVC/FreeNoise" target="_blank"><b>[Code]</b></a>
            </p>
            """
        )
        prompt_in = gr.Textbox(label="Prompt", placeholder="A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect")
        with gr.Row():
            with gr.Accordion('FreeNoise Parameters (feel free to adjust these based on your prompt):', open=False):
                with gr.Row():
                    output_size = gr.Dropdown(["320x512", "576x1024", "256x256"], value="320x512", label="Output Size", info="About 250s with the 512 model and 900s with the 1024 model at 32 frames. A Space waking from sleep needs extra time to re-download the checkpoint.")
                with gr.Row():
                    num_frames = gr.Slider(label='Frames (a multiple of 4; max 36 for the 1024 model)',
                                           minimum=16,
                                           maximum=64,
                                           step=4,
                                           value=32)
                    ddim_steps = gr.Slider(label='DDIM Steps',
                                           minimum=5,
                                           maximum=200,
                                           step=1,
                                           value=50)
                with gr.Row():
                    unconditional_guidance_scale = gr.Slider(label='Unconditional Guidance Scale',
                                                             minimum=1.0,
                                                             maximum=20.0,
                                                             step=0.1,
                                                             value=12.0)
                    save_fps = gr.Slider(label='Save FPS',
                                         minimum=1,
                                         maximum=30,
                                         step=1,
                                         value=10)
                with gr.Row():
                    seed = gr.Slider(label='Random Seed',
                                     minimum=0,
                                     maximum=10000,
                                     step=1,
                                     value=123)
        submit_btn = gr.Button("Generate", variant='primary')
        video_result = gr.Video(label="Video Output")
        # Each example row holds a single value, so it can only prefill the
        # prompt; the remaining controls keep their defaults.
        gr.Examples(examples=examples, inputs=[prompt_in])
        submit_btn.click(fn=infer,
                         inputs=[prompt_in, output_size, seed, num_frames, ddim_steps, unconditional_guidance_scale, save_fps],
                         outputs=[video_result],
                         api_name="zrscp")

demo.queue(max_size=12).launch(show_api=True)
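
# This file is typically run directly as the Space's app.py. queue() caps
# pending requests at 12, and show_api=True exposes the click handler
# (api_name="zrscp") in the auto-generated API docs.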