multimodalart (HF Staff) committed
Commit b118cac · verified · Parent: 715bd60

Add generate tab
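
This commit adds a "Generate" tab beside the End Frame upload: a button that asks the external multimodalart/nano-banana Space to imagine the scene five seconds later, then feeds the generated frame straight into video generation.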

Files changed (1): app.py (+74 −10)
app.py CHANGED
@@ -15,6 +15,7 @@ import numpy as np
 from PIL import Image
 import random
 import gc
+from gradio_client import Client, handle_file # Import for API call
 
 # Import the optimization function from the separate file
 from optimization import optimize_pipeline_
@@ -66,9 +67,8 @@ for i in range(3):
     torch.cuda.synchronize()
     torch.cuda.empty_cache()
 
-# Calling the imported optimization function with a placeholder image for compilation tracing
 optimize_pipeline_(pipe,
-    image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)), # Use representative dims
+    image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),
     prompt='prompt',
     height=MIN_DIMENSION,
     width=MAX_DIMENSION,
@@ -78,6 +78,43 @@ print("All models loaded and optimized. Gradio app is ready.")
 
 
 # --- 2. Image Processing and Application Logic ---
+def generate_end_frame(start_img, gen_prompt, progress=gr.Progress(track_tqdm=True)):
+    """Calls an external Gradio API to generate an image."""
+    if start_img is None:
+        raise gr.Error("Please provide a Start Frame first.")
+
+    hf_token = os.getenv("HF_TOKEN")
+    if not hf_token:
+        raise gr.Error("HF_TOKEN not found in environment variables. Please set it in your Space secrets.")
+
+    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
+        start_img.save(tmpfile.name)
+        tmp_path = tmpfile.name
+
+    progress(0.1, desc="Connecting to image generation API...")
+    client = Client("multimodalart/nano-banana")
+
+    progress(0.5, desc=f"Generating with prompt: '{gen_prompt}'...")
+    try:
+        result = client.predict(
+            prompt=gen_prompt,
+            images=[
+                {"image": handle_file(tmp_path)}
+            ],
+            manual_token=hf_token,
+            api_name="/unified_image_generator"
+        )
+    finally:
+        os.remove(tmp_path)
+
+    progress(1.0, desc="Done!")
+    print(result)
+    return result
+
+def switch_to_upload_tab():
+    """Returns a gr.Tabs update to switch to the first tab."""
+    return gr.Tabs(selected="upload_tab")
+
 
 def process_image_for_video(image: Image.Image) -> Image.Image:
     """
@@ -199,23 +236,37 @@ def generate_video(
     return video_path, current_seed
 
 
-# --- 3. Gradio User Interface --- (No changes needed here)
+# --- 3. Gradio User Interface ---
 
 css = '''
 .fillable{max-width: 1100px !important}
 .dark .progress-text {color: white}
+#general_items{margin-top: 2em}
+#group_all{overflow:visible}
+#group_all .styler{overflow:visible}
+#group_tabs .tabitem{padding: 0}
+.tab-wrapper{margin-top: -33px;z-index: 999;position: absolute;width: 100%;background-color: var(--block-background-fill);padding: 0;}
+#component-9-button{width: 50%;justify-content: center}
+#component-11-button{width: 50%;justify-content: center}
+#or_item{text-align: center; padding-top: 1em; padding-bottom: 1em; font-size: 1.1em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
+#fivesec{margin-top: 5em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
 '''
 with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
     gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
     gr.Markdown("Based on the [Wan 2.2 First/Last Frame workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/), applied to 🧨 Diffusers + [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA")
 
-    with gr.Row():
+    with gr.Row(elem_id="general_items"):
         with gr.Column():
-            with gr.Group():
+            with gr.Group(elem_id="group_all"):
                 with gr.Row():
                     start_image = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"])
-                    end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"])
-
+                    # Capture the Tabs component in a variable and assign IDs to tabs
+                    with gr.Tabs(elem_id="group_tabs") as tabs:
+                        with gr.TabItem("Upload", id="upload_tab"):
+                            end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"])
+                        with gr.TabItem("Generate", id="generate_tab"):
+                            generate_5seconds = gr.Button("Generate scene 5 seconds in the future", elem_id="fivesec")
+                            gr.Markdown("Generate a custom end-frame with an edit model like [Nano Banana](https://huggingface.co/spaces/multimodalart/nano-banana) or [Qwen Image Edit](https://huggingface.co/spaces/multimodalart/Qwen-Image-Edit-Fast)", elem_id="or_item")
                 prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
 
         with gr.Accordion("Advanced Settings", open=False):
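
The End Frame column is now a two-tab group ("Upload" / "Generate"). Programmatic tab switching works because the `gr.Tabs` container is captured in a variable and each `gr.TabItem` carries an explicit `id`; an event handler can then return `gr.Tabs(selected=...)` as an update, which is exactly what `switch_to_upload_tab` does. A toy sketch of the pattern in isolation (hypothetical ids, not from this repo):

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Tabs() as tabs:                      # capture the container
        with gr.TabItem("First", id="first"):    # ids make tabs addressable
            gr.Markdown("Tab one")
        with gr.TabItem("Second", id="second"):
            btn = gr.Button("Jump to first tab")
    # Returning gr.Tabs(selected=<id>) from a handler switches the visible tab
    btn.click(fn=lambda: gr.Tabs(selected="first"), inputs=None, outputs=[tabs])

demo.launch()
```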
@@ -233,7 +284,7 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
         with gr.Column():
             output_video = gr.Video(label="Generated Video", autoplay=True)
 
-    # Define the inputs list for the click event
+    # Main video generation button
     ui_inputs = [
         start_image,
         end_image,
@@ -246,7 +297,6 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
         seed_input,
         randomize_seed_checkbox
     ]
-    # The seed_input is both an input and an output to reflect the randomly generated seed
     ui_outputs = [output_video, seed_input]
 
     generate_button.click(
@@ -255,6 +305,20 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
         outputs=ui_outputs
     )
 
+    generate_5seconds.click(
+        fn=switch_to_upload_tab,
+        inputs=None,
+        outputs=[tabs]
+    ).then(
+        fn=lambda img: generate_end_frame(img, "this image is a still frame from a movie. generate a new frame with what happens on this scene 5 seconds in the future"),
+        inputs=[start_image],
+        outputs=[end_image]
+    ).then(
+        fn=generate_video,
+        inputs=ui_inputs,
+        outputs=ui_outputs
+    )
+
     gr.Examples(
         examples=[
             ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
@@ -268,4 +332,4 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
     )
 
 if __name__ == "__main__":
-    app.launch(share=True, show_error=True)
+    app.launch(share=True)
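
One behavioral note on the last hunk: assuming Gradio's default of `show_error=False`, removing `show_error=True` means unhandled exceptions are no longer surfaced as error modals in the UI; explicit `gr.Error` messages (like the ones raised in `generate_end_frame`) are still shown.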