QinOwen
committed on
Commit
·
5098655
1
Parent(s):
2846498
load-base-model-first
Browse files
- VADER-VideoCrafter/scripts/main/train_t2v_lora.py +29 -32
- app.py +21 -49
- app_bk.py +273 -0
- gradio_cached_examples/32/indices.csv +1 -0
- gradio_cached_examples/32/log.csv +2 -0
- gradio_cached_examples/34/indices.csv +1 -0
- gradio_cached_examples/34/log.csv +2 -0
VADER-VideoCrafter/scripts/main/train_t2v_lora.py
CHANGED
|
@@ -567,7 +567,7 @@ def should_sample(global_step, validation_steps, is_sample_preview):
|
|
| 567 |
and is_sample_preview
|
| 568 |
|
| 569 |
|
| 570 |
-
def run_training(args,
|
| 571 |
## ---------------------step 1: accelerator setup---------------------------
|
| 572 |
accelerator = Accelerator( # Initialize Accelerator
|
| 573 |
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
|
@@ -576,6 +576,29 @@ def run_training(args, peft_model, **kwargs):
|
|
| 576 |
)
|
| 577 |
output_dir = args.project_dir
|
| 578 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
# Make one log on every process with the configuration for debugging.
|
| 580 |
create_logging(logging, logger, accelerator)
|
| 581 |
|
|
@@ -698,7 +721,7 @@ def run_training(args, peft_model, **kwargs):
|
|
| 698 |
# ==================================================================
|
| 699 |
|
| 700 |
|
| 701 |
-
def setup_model(
|
| 702 |
parser = get_parser()
|
| 703 |
args = parser.parse_args()
|
| 704 |
|
|
@@ -721,41 +744,13 @@ def setup_model(lora_ckpt_path="huggingface-pickscore", lora_rank=16):
|
|
| 721 |
model.first_stage_model = model.first_stage_model.half()
|
| 722 |
model.cond_stage_model = model.cond_stage_model.half()
|
| 723 |
|
| 724 |
-
# step 2.1: add LoRA using peft
|
| 725 |
-
config = peft.LoraConfig(
|
| 726 |
-
r=lora_rank,
|
| 727 |
-
target_modules=["to_k", "to_v", "to_q"], # only diffusion_model has these modules
|
| 728 |
-
lora_dropout=0.01,
|
| 729 |
-
)
|
| 730 |
|
| 731 |
-
peft_model = peft.get_peft_model(model, config)
|
| 732 |
-
|
| 733 |
-
peft_model.print_trainable_parameters()
|
| 734 |
-
|
| 735 |
-
# load the pretrained LoRA model
|
| 736 |
-
if lora_ckpt_path != "Base Model":
|
| 737 |
-
if lora_ckpt_path == "huggingface-hps-aesthetic": # download the pretrained LoRA model from huggingface
|
| 738 |
-
snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
|
| 739 |
-
lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_hps_aesthetic.pt'
|
| 740 |
-
elif lora_ckpt_path == "huggingface-pickscore": # download the pretrained LoRA model from huggingface
|
| 741 |
-
snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
|
| 742 |
-
lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_pickscore.pt'
|
| 743 |
-
elif lora_ckpt_path == "peft_model_532":
|
| 744 |
-
lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_532.pt'
|
| 745 |
-
elif lora_ckpt_path == "peft_model_548":
|
| 746 |
-
lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_548.pt'
|
| 747 |
-
elif lora_ckpt_path == "peft_model_536":
|
| 748 |
-
lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_536.pt'
|
| 749 |
-
elif lora_ckpt_path == "peft_model_400":
|
| 750 |
-
lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_400.pt'
|
| 751 |
-
# load the pretrained LoRA model
|
| 752 |
-
peft.set_peft_model_state_dict(peft_model, torch.load(lora_ckpt_path))
|
| 753 |
|
| 754 |
print("Model setup complete!")
|
| 755 |
-
return
|
| 756 |
|
| 757 |
|
| 758 |
-
def main_fn(prompt, seed=200, height=320, width=512, unconditional_guidance_scale=12, ddim_steps=25, ddim_eta=1.0,
|
| 759 |
frames=24, savefps=10, model=None):
|
| 760 |
|
| 761 |
parser = get_parser()
|
|
@@ -765,6 +760,8 @@ def main_fn(prompt, seed=200, height=320, width=512, unconditional_guidance_scal
|
|
| 765 |
|
| 766 |
# overwrite the default arguments
|
| 767 |
args.prompt_str = prompt
|
|
|
|
|
|
|
| 768 |
args.seed = seed
|
| 769 |
args.height = height
|
| 770 |
args.width = width
|
|
|
|
| 567 |
and is_sample_preview
|
| 568 |
|
| 569 |
|
| 570 |
+
def run_training(args, model, **kwargs):
|
| 571 |
## ---------------------step 1: accelerator setup---------------------------
|
| 572 |
accelerator = Accelerator( # Initialize Accelerator
|
| 573 |
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
|
|
|
| 576 |
)
|
| 577 |
output_dir = args.project_dir
|
| 578 |
|
| 579 |
+
|
| 580 |
+
# step 2.1: add LoRA using peft
|
| 581 |
+
config = peft.LoraConfig(
|
| 582 |
+
r=args.lora_rank,
|
| 583 |
+
target_modules=["to_k", "to_v", "to_q"], # only diffusion_model has these modules
|
| 584 |
+
lora_dropout=0.01,
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
peft_model = peft.get_peft_model(model, config)
|
| 588 |
+
|
| 589 |
+
peft_model.print_trainable_parameters()
|
| 590 |
+
|
| 591 |
+
# load the pretrained LoRA model
|
| 592 |
+
if args.lora_ckpt_path != "Base Model":
|
| 593 |
+
if args.lora_ckpt_path == "huggingface-hps-aesthetic": # download the pretrained LoRA model from huggingface
|
| 594 |
+
snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
|
| 595 |
+
args.lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_hps_aesthetic.pt'
|
| 596 |
+
elif args.lora_ckpt_path == "huggingface-pickscore": # download the pretrained LoRA model from huggingface
|
| 597 |
+
snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
|
| 598 |
+
args.lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_pickscore.pt'
|
| 599 |
+
# load the pretrained LoRA model
|
| 600 |
+
peft.set_peft_model_state_dict(peft_model, torch.load(args.lora_ckpt_path))
|
| 601 |
+
|
| 602 |
# Make one log on every process with the configuration for debugging.
|
| 603 |
create_logging(logging, logger, accelerator)
|
| 604 |
|
|
|
|
| 721 |
# ==================================================================
|
| 722 |
|
| 723 |
|
| 724 |
+
def setup_model():
|
| 725 |
parser = get_parser()
|
| 726 |
args = parser.parse_args()
|
| 727 |
|
|
|
|
| 744 |
model.first_stage_model = model.first_stage_model.half()
|
| 745 |
model.cond_stage_model = model.cond_stage_model.half()
|
| 746 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 748 |
|
| 749 |
print("Model setup complete!")
|
| 750 |
+
return model
|
| 751 |
|
| 752 |
|
| 753 |
+
def main_fn(prompt, lora_model, lora_rank, seed=200, height=320, width=512, unconditional_guidance_scale=12, ddim_steps=25, ddim_eta=1.0,
|
| 754 |
frames=24, savefps=10, model=None):
|
| 755 |
|
| 756 |
parser = get_parser()
|
|
|
|
| 760 |
|
| 761 |
# overwrite the default arguments
|
| 762 |
args.prompt_str = prompt
|
| 763 |
+
args.lora_ckpt_path = lora_model
|
| 764 |
+
args.lora_rank = lora_rank
|
| 765 |
args.seed = seed
|
| 766 |
args.height = height
|
| 767 |
args.width = width
|
app.py
CHANGED
|
@@ -2,6 +2,7 @@ import gradio as gr
|
|
| 2 |
import os
|
| 3 |
import spaces
|
| 4 |
import sys
|
|
|
|
| 5 |
sys.path.append('./VADER-VideoCrafter/scripts/main')
|
| 6 |
sys.path.append('./VADER-VideoCrafter/scripts')
|
| 7 |
sys.path.append('./VADER-VideoCrafter')
|
|
@@ -19,24 +20,26 @@ examples = [
|
|
| 19 |
"huggingface-pickscore", 16, 204, 384, 512, 12.0, 25, 1.0, 24, 10]
|
| 20 |
]
|
| 21 |
|
| 22 |
-
model =
|
| 23 |
|
| 24 |
@spaces.GPU(duration=70)
|
| 25 |
-
def gradio_main_fn(prompt, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
|
| 26 |
frames, savefps):
|
| 27 |
global model
|
| 28 |
if model is None:
|
| 29 |
return "Model is not loaded. Please load the model first."
|
| 30 |
video_path = main_fn(prompt=prompt,
|
|
|
|
|
|
|
| 31 |
seed=int(seed),
|
| 32 |
height=int(height),
|
| 33 |
-
width=int(width),
|
| 34 |
-
unconditional_guidance_scale=float(unconditional_guidance_scale),
|
| 35 |
-
ddim_steps=int(ddim_steps),
|
| 36 |
ddim_eta=float(ddim_eta),
|
| 37 |
-
frames=int(frames),
|
| 38 |
savefps=int(savefps),
|
| 39 |
-
model=model)
|
| 40 |
|
| 41 |
return video_path
|
| 42 |
|
|
@@ -60,35 +63,6 @@ def update_dropdown(lora_rank):
|
|
| 60 |
else: # 0
|
| 61 |
return gr.update(value="Base Model")
|
| 62 |
|
| 63 |
-
@spaces.GPU(duration=180)
|
| 64 |
-
def setup_model_progress(lora_model, lora_rank):
|
| 65 |
-
global model
|
| 66 |
-
|
| 67 |
-
# Disable buttons and show loading indicator
|
| 68 |
-
yield (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), "Loading model...")
|
| 69 |
-
|
| 70 |
-
model = setup_model(lora_model, lora_rank) # Ensure you pass the necessary parameters to the setup_model function
|
| 71 |
-
|
| 72 |
-
# Enable buttons after loading and update indicator
|
| 73 |
-
yield (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), "Model loaded successfully")
|
| 74 |
-
|
| 75 |
-
@spaces.GPU(duration=300)
|
| 76 |
-
def generate_example(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
|
| 77 |
-
frames, savefps):
|
| 78 |
-
global model
|
| 79 |
-
model = setup_model(lora_model, lora_rank)
|
| 80 |
-
video_path = main_fn(prompt=prompt,
|
| 81 |
-
seed=int(seed),
|
| 82 |
-
height=int(height),
|
| 83 |
-
width=int(width),
|
| 84 |
-
unconditional_guidance_scale=float(unconditional_guidance_scale),
|
| 85 |
-
ddim_steps=int(ddim_steps),
|
| 86 |
-
ddim_eta=float(ddim_eta),
|
| 87 |
-
frames=int(frames),
|
| 88 |
-
savefps=int(savefps),
|
| 89 |
-
model=model)
|
| 90 |
-
return video_path
|
| 91 |
-
|
| 92 |
custom_css = """
|
| 93 |
#centered {
|
| 94 |
display: flex;
|
|
@@ -215,23 +189,19 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 215 |
value="huggingface-pickscore"
|
| 216 |
)
|
| 217 |
lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step = 8, value=16)
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
|
| 222 |
with gr.Column(scale=0.3):
|
| 223 |
output_video = gr.Video(elem_id="image-upload")
|
| 224 |
|
| 225 |
with gr.Row(elem_id="centered"):
|
| 226 |
with gr.Column(scale=0.6):
|
| 227 |
-
|
| 228 |
-
value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
|
| 229 |
|
| 230 |
seed = gr.Slider(minimum=0, maximum=65536, label="Seed", step = 1, value=200)
|
| 231 |
|
| 232 |
-
run_btn = gr.Button("Run Inference")
|
| 233 |
-
|
| 234 |
-
|
| 235 |
with gr.Row():
|
| 236 |
height = gr.Slider(minimum=0, maximum=1024, label="Height", step = 16, value=384)
|
| 237 |
width = gr.Slider(minimum=0, maximum=1024, label="Width", step = 16, value=512)
|
|
@@ -252,10 +222,10 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 252 |
reset_btn.click(fn=reset_fn, outputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, lora_rank, savefps, lora_model])
|
| 253 |
|
| 254 |
|
| 255 |
-
|
| 256 |
-
load_btn.click(fn=setup_model_progress, inputs=[lora_model, lora_rank], outputs=[load_btn, run_btn, reset_btn, loading_indicator])
|
| 257 |
run_btn.click(fn=gradio_main_fn,
|
| 258 |
-
inputs=[prompt,
|
|
|
|
|
|
|
| 259 |
outputs=output_video
|
| 260 |
)
|
| 261 |
|
|
@@ -263,9 +233,11 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 263 |
lora_rank.change(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)
|
| 264 |
|
| 265 |
gr.Examples(examples=examples,
|
| 266 |
-
inputs=[prompt, lora_model, lora_rank, seed,
|
|
|
|
|
|
|
| 267 |
outputs=output_video,
|
| 268 |
-
fn=
|
| 269 |
run_on_click=False,
|
| 270 |
cache_examples="lazy",
|
| 271 |
)
|
|
|
|
| 2 |
import os
|
| 3 |
import spaces
|
| 4 |
import sys
|
| 5 |
+
from copy import deepcopy
|
| 6 |
sys.path.append('./VADER-VideoCrafter/scripts/main')
|
| 7 |
sys.path.append('./VADER-VideoCrafter/scripts')
|
| 8 |
sys.path.append('./VADER-VideoCrafter')
|
|
|
|
| 20 |
"huggingface-pickscore", 16, 204, 384, 512, 12.0, 25, 1.0, 24, 10]
|
| 21 |
]
|
| 22 |
|
| 23 |
+
model = setup_model()
|
| 24 |
|
| 25 |
@spaces.GPU(duration=70)
|
| 26 |
+
def gradio_main_fn(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
|
| 27 |
frames, savefps):
|
| 28 |
global model
|
| 29 |
if model is None:
|
| 30 |
return "Model is not loaded. Please load the model first."
|
| 31 |
video_path = main_fn(prompt=prompt,
|
| 32 |
+
lora_model=lora_model,
|
| 33 |
+
lora_rank=int(lora_rank),
|
| 34 |
seed=int(seed),
|
| 35 |
height=int(height),
|
| 36 |
+
width=int(width),
|
| 37 |
+
unconditional_guidance_scale=float(unconditional_guidance_scale),
|
| 38 |
+
ddim_steps=int(ddim_steps),
|
| 39 |
ddim_eta=float(ddim_eta),
|
| 40 |
+
frames=int(frames),
|
| 41 |
savefps=int(savefps),
|
| 42 |
+
model=deepcopy(model))
|
| 43 |
|
| 44 |
return video_path
|
| 45 |
|
|
|
|
| 63 |
else: # 0
|
| 64 |
return gr.update(value="Base Model")
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
custom_css = """
|
| 67 |
#centered {
|
| 68 |
display: flex;
|
|
|
|
| 189 |
value="huggingface-pickscore"
|
| 190 |
)
|
| 191 |
lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step = 8, value=16)
|
| 192 |
+
prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
|
| 193 |
+
value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
|
| 194 |
+
run_btn = gr.Button("Run Inference")
|
| 195 |
|
| 196 |
with gr.Column(scale=0.3):
|
| 197 |
output_video = gr.Video(elem_id="image-upload")
|
| 198 |
|
| 199 |
with gr.Row(elem_id="centered"):
|
| 200 |
with gr.Column(scale=0.6):
|
| 201 |
+
|
|
|
|
| 202 |
|
| 203 |
seed = gr.Slider(minimum=0, maximum=65536, label="Seed", step = 1, value=200)
|
| 204 |
|
|
|
|
|
|
|
|
|
|
| 205 |
with gr.Row():
|
| 206 |
height = gr.Slider(minimum=0, maximum=1024, label="Height", step = 16, value=384)
|
| 207 |
width = gr.Slider(minimum=0, maximum=1024, label="Width", step = 16, value=512)
|
|
|
|
| 222 |
reset_btn.click(fn=reset_fn, outputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, lora_rank, savefps, lora_model])
|
| 223 |
|
| 224 |
|
|
|
|
|
|
|
| 225 |
run_btn.click(fn=gradio_main_fn,
|
| 226 |
+
inputs=[prompt, lora_model, lora_rank,
|
| 227 |
+
seed, height, width, unconditional_guidance_scale,
|
| 228 |
+
DDIM_Steps, DDIM_Eta, frames, savefps],
|
| 229 |
outputs=output_video
|
| 230 |
)
|
| 231 |
|
|
|
|
| 233 |
lora_rank.change(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)
|
| 234 |
|
| 235 |
gr.Examples(examples=examples,
|
| 236 |
+
inputs=[prompt, lora_model, lora_rank, seed,
|
| 237 |
+
height, width, unconditional_guidance_scale,
|
| 238 |
+
DDIM_Steps, DDIM_Eta, frames, savefps],
|
| 239 |
outputs=output_video,
|
| 240 |
+
fn=gradio_main_fn,
|
| 241 |
run_on_click=False,
|
| 242 |
cache_examples="lazy",
|
| 243 |
)
|
app_bk.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import spaces
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append('./VADER-VideoCrafter/scripts/main')
|
| 6 |
+
sys.path.append('./VADER-VideoCrafter/scripts')
|
| 7 |
+
sys.path.append('./VADER-VideoCrafter')
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
from train_t2v_lora import main_fn, setup_model
|
| 11 |
+
|
| 12 |
+
examples = [
|
| 13 |
+
["A fairy tends to enchanted, glowing flowers.", 'huggingface-hps-aesthetic', 8, 400, 384, 512, 12.0, 25, 1.0, 24, 10],
|
| 14 |
+
["A cat playing an electric guitar in a loft with industrial-style decor and soft, multicolored lights.", 'huggingface-hps-aesthetic', 8, 206, 384, 512, 12.0, 25, 1.0, 24, 10],
|
| 15 |
+
["A raccoon playing a guitar under a blossoming cherry tree.", 'huggingface-hps-aesthetic', 8, 204, 384, 512, 12.0, 25, 1.0, 24, 10],
|
| 16 |
+
["A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
|
| 17 |
+
"huggingface-pickscore", 16, 205, 384, 512, 12.0, 25, 1.0, 24, 10],
|
| 18 |
+
["A talking bird with shimmering feathers and a melodious voice leads an adventure to find a legendary treasure, guiding through enchanted forests, ancient ruins, and mystical challenges.",
|
| 19 |
+
"huggingface-pickscore", 16, 204, 384, 512, 12.0, 25, 1.0, 24, 10]
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
model = None # Placeholder for model
|
| 23 |
+
|
| 24 |
+
@spaces.GPU(duration=70)
|
| 25 |
+
def gradio_main_fn(prompt, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
|
| 26 |
+
frames, savefps):
|
| 27 |
+
global model
|
| 28 |
+
if model is None:
|
| 29 |
+
return "Model is not loaded. Please load the model first."
|
| 30 |
+
video_path = main_fn(prompt=prompt,
|
| 31 |
+
seed=int(seed),
|
| 32 |
+
height=int(height),
|
| 33 |
+
width=int(width),
|
| 34 |
+
unconditional_guidance_scale=float(unconditional_guidance_scale),
|
| 35 |
+
ddim_steps=int(ddim_steps),
|
| 36 |
+
ddim_eta=float(ddim_eta),
|
| 37 |
+
frames=int(frames),
|
| 38 |
+
savefps=int(savefps),
|
| 39 |
+
model=model)
|
| 40 |
+
|
| 41 |
+
return video_path
|
| 42 |
+
|
| 43 |
+
def reset_fn():
|
| 44 |
+
return ("A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
|
| 45 |
+
200, 384, 512, 12.0, 25, 1.0, 24, 16, 10, "huggingface-pickscore")
|
| 46 |
+
|
| 47 |
+
def update_lora_rank(lora_model):
|
| 48 |
+
if lora_model == "huggingface-pickscore":
|
| 49 |
+
return gr.update(value=16)
|
| 50 |
+
elif lora_model == "huggingface-hps-aesthetic":
|
| 51 |
+
return gr.update(value=8)
|
| 52 |
+
else: # "Base Model"
|
| 53 |
+
return gr.update(value=8)
|
| 54 |
+
|
| 55 |
+
def update_dropdown(lora_rank):
|
| 56 |
+
if lora_rank == 16:
|
| 57 |
+
return gr.update(value="huggingface-pickscore")
|
| 58 |
+
elif lora_rank == 8:
|
| 59 |
+
return gr.update(value="huggingface-hps-aesthetic")
|
| 60 |
+
else: # 0
|
| 61 |
+
return gr.update(value="Base Model")
|
| 62 |
+
|
| 63 |
+
@spaces.GPU(duration=120)
|
| 64 |
+
def setup_model_progress(lora_model, lora_rank):
|
| 65 |
+
global model
|
| 66 |
+
|
| 67 |
+
# Disable buttons and show loading indicator
|
| 68 |
+
yield (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), "Loading model...")
|
| 69 |
+
|
| 70 |
+
model = setup_model(lora_model, lora_rank) # Ensure you pass the necessary parameters to the setup_model function
|
| 71 |
+
|
| 72 |
+
# Enable buttons after loading and update indicator
|
| 73 |
+
yield (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), "Model loaded successfully")
|
| 74 |
+
|
| 75 |
+
@spaces.GPU(duration=180)
|
| 76 |
+
def generate_example(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
|
| 77 |
+
frames, savefps):
|
| 78 |
+
global model
|
| 79 |
+
model = setup_model(lora_model, lora_rank)
|
| 80 |
+
video_path = main_fn(prompt=prompt,
|
| 81 |
+
seed=int(seed),
|
| 82 |
+
height=int(height),
|
| 83 |
+
width=int(width),
|
| 84 |
+
unconditional_guidance_scale=float(unconditional_guidance_scale),
|
| 85 |
+
ddim_steps=int(ddim_steps),
|
| 86 |
+
ddim_eta=float(ddim_eta),
|
| 87 |
+
frames=int(frames),
|
| 88 |
+
savefps=int(savefps),
|
| 89 |
+
model=model)
|
| 90 |
+
return video_path
|
| 91 |
+
|
| 92 |
+
custom_css = """
|
| 93 |
+
#centered {
|
| 94 |
+
display: flex;
|
| 95 |
+
justify-content: center;
|
| 96 |
+
}
|
| 97 |
+
.column-centered {
|
| 98 |
+
display: flex;
|
| 99 |
+
flex-direction: column;
|
| 100 |
+
align-items: center;
|
| 101 |
+
width: 60%;
|
| 102 |
+
}
|
| 103 |
+
#image-upload {
|
| 104 |
+
flex-grow: 1;
|
| 105 |
+
}
|
| 106 |
+
#params .tabs {
|
| 107 |
+
display: flex;
|
| 108 |
+
flex-direction: column;
|
| 109 |
+
flex-grow: 1;
|
| 110 |
+
}
|
| 111 |
+
#params .tabitem[style="display: block;"] {
|
| 112 |
+
flex-grow: 1;
|
| 113 |
+
display: flex !important;
|
| 114 |
+
}
|
| 115 |
+
#params .gap {
|
| 116 |
+
flex-grow: 1;
|
| 117 |
+
}
|
| 118 |
+
#params .form {
|
| 119 |
+
flex-grow: 1 !important;
|
| 120 |
+
}
|
| 121 |
+
#params .form > :last-child{
|
| 122 |
+
flex-grow: 1;
|
| 123 |
+
}
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
with gr.Blocks(css=custom_css) as demo:
|
| 127 |
+
with gr.Row():
|
| 128 |
+
with gr.Column():
|
| 129 |
+
gr.HTML(
|
| 130 |
+
"""
|
| 131 |
+
<h1 style='text-align: center; font-size: 3.2em; margin-bottom: 0.5em; font-family: Arial, sans-serif; margin: 20px;'>
|
| 132 |
+
Video Diffusion Alignment via Reward Gradient
|
| 133 |
+
</h1>
|
| 134 |
+
"""
|
| 135 |
+
)
|
| 136 |
+
gr.HTML(
|
| 137 |
+
"""
|
| 138 |
+
<style>
|
| 139 |
+
body {
|
| 140 |
+
font-family: Arial, sans-serif;
|
| 141 |
+
text-align: center;
|
| 142 |
+
margin: 50px;
|
| 143 |
+
}
|
| 144 |
+
a {
|
| 145 |
+
text-decoration: none !important;
|
| 146 |
+
color: black !important;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
</style>
|
| 150 |
+
<body>
|
| 151 |
+
<div style="font-size: 1.4em; margin-bottom: 0.5em; ">
|
| 152 |
+
<a href="https://mihirp1998.github.io">Mihir Prabhudesai</a><sup>*</sup>
|
| 153 |
+
<a href="https://russellmendonca.github.io/">Russell Mendonca</a><sup>*</sup>
|
| 154 |
+
<a href="mailto: [email protected]">Zheyang Qin</a><sup>*</sup>
|
| 155 |
+
<a href="https://www.cs.cmu.edu/~katef/">Katerina Fragkiadaki</a><sup></sup>
|
| 156 |
+
<a href="https://www.cs.cmu.edu/~dpathak/">Deepak Pathak</a><sup></sup>
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
</div>
|
| 160 |
+
<div style="font-size: 1.3em; font-style: italic;">
|
| 161 |
+
Carnegie Mellon University
|
| 162 |
+
</div>
|
| 163 |
+
</body>
|
| 164 |
+
"""
|
| 165 |
+
)
|
| 166 |
+
gr.HTML(
|
| 167 |
+
"""
|
| 168 |
+
<head>
|
| 169 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
|
| 170 |
+
|
| 171 |
+
<style>
|
| 172 |
+
.button-container {
|
| 173 |
+
display: flex;
|
| 174 |
+
justify-content: center;
|
| 175 |
+
gap: 10px;
|
| 176 |
+
margin-top: 10px;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
.button-container a {
|
| 180 |
+
display: inline-flex;
|
| 181 |
+
align-items: center;
|
| 182 |
+
padding: 10px 20px;
|
| 183 |
+
border-radius: 30px;
|
| 184 |
+
border: 1px solid #ccc;
|
| 185 |
+
text-decoration: none;
|
| 186 |
+
color: #333 !important;
|
| 187 |
+
font-size: 16px;
|
| 188 |
+
text-decoration: none !important;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
.button-container a i {
|
| 192 |
+
margin-right: 8px;
|
| 193 |
+
}
|
| 194 |
+
</style>
|
| 195 |
+
</head>
|
| 196 |
+
|
| 197 |
+
<div class="button-container">
|
| 198 |
+
<a href="https://arxiv.org/abs/2407.08737" class="btn btn-outline-primary">
|
| 199 |
+
<i class="fa-solid fa-file-pdf"></i> Paper
|
| 200 |
+
</a>
|
| 201 |
+
<a href="https://vader-vid.github.io/" class="btn btn-outline-danger">
|
| 202 |
+
<i class="fa-solid fa-video"></i> Website
|
| 203 |
+
<a href="https://github.com/mihirp1998/VADER" class="btn btn-outline-secondary">
|
| 204 |
+
<i class="fa-brands fa-github"></i> Code
|
| 205 |
+
</a>
|
| 206 |
+
</div>
|
| 207 |
+
"""
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
with gr.Row(elem_id="centered"):
|
| 211 |
+
with gr.Column(scale=0.3, elem_id="params"):
|
| 212 |
+
lora_model = gr.Dropdown(
|
| 213 |
+
label="VADER Model",
|
| 214 |
+
choices=["huggingface-pickscore", "huggingface-hps-aesthetic", "Base Model"],
|
| 215 |
+
value="huggingface-pickscore"
|
| 216 |
+
)
|
| 217 |
+
lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step = 8, value=16)
|
| 218 |
+
load_btn = gr.Button("Load Model")
|
| 219 |
+
# Add a label to show the loading indicator
|
| 220 |
+
loading_indicator = gr.Label(value="", label="Loading Indicator")
|
| 221 |
+
|
| 222 |
+
with gr.Column(scale=0.3):
|
| 223 |
+
output_video = gr.Video(elem_id="image-upload")
|
| 224 |
+
|
| 225 |
+
with gr.Row(elem_id="centered"):
|
| 226 |
+
with gr.Column(scale=0.6):
|
| 227 |
+
prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
|
| 228 |
+
value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
|
| 229 |
+
|
| 230 |
+
seed = gr.Slider(minimum=0, maximum=65536, label="Seed", step = 1, value=200)
|
| 231 |
+
|
| 232 |
+
run_btn = gr.Button("Run Inference")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
with gr.Row():
|
| 236 |
+
height = gr.Slider(minimum=0, maximum=1024, label="Height", step = 16, value=384)
|
| 237 |
+
width = gr.Slider(minimum=0, maximum=1024, label="Width", step = 16, value=512)
|
| 238 |
+
|
| 239 |
+
with gr.Row():
|
| 240 |
+
frames = gr.Slider(minimum=0, maximum=50, label="Frames", step = 1, value=24)
|
| 241 |
+
savefps = gr.Slider(minimum=0, maximum=60, label="Save FPS", step = 1, value=10)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
with gr.Row():
|
| 245 |
+
DDIM_Steps = gr.Slider(minimum=0, maximum=100, label="DDIM Steps", step = 1, value=25)
|
| 246 |
+
unconditional_guidance_scale = gr.Slider(minimum=0, maximum=50, label="Guidance Scale", step = 0.1, value=12.0)
|
| 247 |
+
DDIM_Eta = gr.Slider(minimum=0, maximum=1, label="DDIM Eta", step = 0.01, value=1.0)
|
| 248 |
+
|
| 249 |
+
# reset button
|
| 250 |
+
reset_btn = gr.Button("Reset")
|
| 251 |
+
|
| 252 |
+
reset_btn.click(fn=reset_fn, outputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, lora_rank, savefps, lora_model])
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
load_btn.click(fn=setup_model_progress, inputs=[lora_model, lora_rank], outputs=[load_btn, run_btn, reset_btn, loading_indicator])
|
| 257 |
+
run_btn.click(fn=gradio_main_fn,
|
| 258 |
+
inputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, savefps],
|
| 259 |
+
outputs=output_video
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
lora_model.change(fn=update_lora_rank, inputs=lora_model, outputs=lora_rank)
|
| 263 |
+
lora_rank.change(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)
|
| 264 |
+
|
| 265 |
+
gr.Examples(examples=examples,
|
| 266 |
+
inputs=[prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, savefps],
|
| 267 |
+
outputs=output_video,
|
| 268 |
+
fn=generate_example,
|
| 269 |
+
run_on_click=False,
|
| 270 |
+
cache_examples="lazy",
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
demo.launch(share=True)
|
gradio_cached_examples/32/indices.csv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0
|
gradio_cached_examples/32/log.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
component 0,flag,username,timestamp
|
| 2 |
+
"{""video"": {""path"": ""gradio_cached_examples/32/component 0/fd156c6a458fa048724e/temporal.mp4"", ""url"": ""/file=/tmp/gradio/4bc133becbc469de8da700250f7f7df1103c6f56/temporal.mp4"", ""size"": null, ""orig_name"": ""temporal.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-19 00:00:10.509808
|
gradio_cached_examples/34/indices.csv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0
|
gradio_cached_examples/34/log.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
component 0,flag,username,timestamp
|
| 2 |
+
"{""video"": {""path"": ""gradio_cached_examples/34/component 0/d2ac1c9664e80f60d50f/temporal.mp4"", ""url"": ""/file=/tmp/gradio/4bc133becbc469de8da700250f7f7df1103c6f56/temporal.mp4"", ""size"": null, ""orig_name"": ""temporal.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-18 23:33:26.912888
|