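# Gradio demo for Motion Inversion: a motion embedding is trained on a short
# reference video, then reused at inference time to generate new videos that
# follow the same camera or object motion under a different text prompt.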
import os
import shutil

import gradio as gr
import spaces
import transformers
from omegaconf import OmegaConf, ListConfig

from train import main as train_main
from inference import inference as inference_main

# Migrate any legacy Hugging Face cache layout before models are loaded.
transformers.utils.move_cache()
@spaces.GPU  # request ZeroGPU hardware for the generation call
def inference_app(
        embedding_dir,
        prompt,
        video_round,
        save_dir,
        motion_type,
        seed,
        inference_steps):
    print('inference info:')
    print('ref video:', embedding_dir)
    print('prompt:', prompt)
    print('motion type:', motion_type)
    print('infer steps:', inference_steps)

    return inference_main(
        embedding_dir=embedding_dir,
        prompt=prompt,
        video_round=video_round,
        save_dir=save_dir,
        motion_type=motion_type,
        seed=seed,
        inference_steps=inference_steps,
    )
def train_model(video, config):
    # Train a motion embedding on the uploaded reference video and return the
    # directory its checkpoints are written to.
    cur_save_dir = os.path.join('results', 'custom')
    os.makedirs(cur_save_dir, exist_ok=True)

    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir

    # Keep a copy of the reference video next to the trained embedding.
    shutil.copy(video, os.path.join(cur_save_dir, 'source.mp4'))

    train_main(config)
    return cur_save_dir
def inference_model(text, checkpoint, inference_steps, video_type, seed):
    # `checkpoint` has the form '<video_name>/<round>': the embedding lives in the
    # parent directory and the last component selects the training round.
    checkpoint = os.path.join('results', checkpoint)
    embedding_dir = os.path.dirname(checkpoint)
    video_round = os.path.basename(checkpoint)

    video_path = inference_app(
        embedding_dir=embedding_dir,
        prompt=text,
        video_round=video_round,
        save_dir=os.path.join('outputs', os.path.basename(embedding_dir)),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps,
    )
    return video_path
def get_checkpoints(checkpoint_dir):
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints
def extract_combinations(motion_embeddings_combinations):
    # Each entry is a string of the form '<block> <resolution>', e.g. 'down 1280'.
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations
def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config


def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config
def update_preview_video(checkpoint_dir):
    # get the parent dir of the checkpoint
    parent_dir = '/'.join(checkpoint_dir.split('/')[:-1])
    return gr.update(value=f'results/{parent_dir}/source.mp4')


def update_generated_prompt(text):
    return gr.update(value=text)
if __name__ == "__main__":
    # Start from a clean slate: drop any previously trained custom embedding and old outputs.
    if os.path.exists('results/custom'):
        shutil.rmtree('results/custom')
    if os.path.exists('outputs'):
        shutil.rmtree('outputs')

    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
    default_motion_embeddings_combinations = ['down 1280', 'up 1280']

    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint'],
        ['results/orbit_shot/source.mp4', 'A micro garden with orbit shot', 'camera', 'orbit_shot/checkpoint'],
        ['results/walk/source.mp4', 'An elephant walking in desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in suit is dancing with his hands', 'object', 'santa_dance/checkpoint'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint'],
    ]
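    # Build the Gradio UI: reference-video preview, prompt box, checkpoint picker,
    # seed, and the generated-video panel, plus the ready-made example checkpoints.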
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Motion Inversion",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1, h2, h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        gr.Markdown(
            """
            # Motion Inversion for Video Customization
            <p align="center">
                <a href="https://arxiv.org/abs/2403.20193"><img src='https://img.shields.io/badge/arXiv-2403.20193-b31b1b.svg'></a>
                <a href=''><img src='https://img.shields.io/badge/Project_Page-MotionInversion(Coming soon)-blue'></a>
                <a href='https://github.com/EnVision-Research/MotionInversion'><img src='https://img.shields.io/github/stars/EnVision-Research/MotionInversion?label=GitHub%20%E2%98%85&logo=github&color=C8C'></a>
                <br>
                <strong>Please consider starring <span style="color: orange">★</span> the <a href="https://github.com/EnVision-Research/MotionInversion" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this useful!</strong>
            </p>
            """
        )
| with gr.Tabs(elem_classes=["tabs"]): | |
| with gr.Row(): | |
| with gr.Column(): | |
| preview_video = gr.Video(label="Preview Video") | |
| text_input = gr.Textbox(label="Input Text") | |
| checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results')) | |
| seed = gr.Number(label="Seed", value=0) | |
| inference_button = gr.Button("Generate Video") | |
| with gr.Column(): | |
| output_video = gr.Video(label="Output Video") | |
| generated_prompt = gr.Textbox(label="Generated Prompt") | |
| with gr.Accordion('Encounter Errors', open=False): | |
| gr.Markdown(''' | |
| <strong>Generally, inference time for one video often takes 45~50s on ZeroGPU</strong>. | |
| <br> | |
| <strong>You have exceeded your GPU quota</strong>: A limitation set by HF. Retry in an hour. | |
| <br> | |
| <strong>GPU task aborted</strong>: Possibly caused by ZeroGPU being used by too many people, the inference time excceeds the time limit. You may try again later, or clone the repo and run it locally. | |
| <br> | |
| If any other issues occur, please feel free to contact us through the community or by email ([email protected]). We will try our best to help you :) | |
| ''') | |
| with gr.Accordion("Advanced Settings", open=False): | |
| with gr.Row(): | |
| inference_steps = gr.Number(label="Inference Steps", value=30) | |
| motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object") | |
| gr.Examples(examples=examples_inference,inputs=[preview_video,text_input,motion_type,checkpoint_dropdown]) | |
| checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video) | |
| inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown,inference_steps,motion_type, seed], outputs=output_video) | |
| output_video.change(fn=update_generated_prompt, inputs=[text_input], outputs=generated_prompt) | |
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )