chingshuai committed on
Commit fa80dfd · 1 Parent(s): c4d5f5a

merge gradio_app, runtime

gradio_app.py CHANGED
@@ -11,7 +11,7 @@ from typing import List, Optional, Tuple, Union
11
  import gradio as gr
12
  from hymotion.utils.gradio_runtime import ModelInference
13
  from hymotion.utils.gradio_utils import try_to_download_model, try_to_download_text_encoder
14
- from hymotion.utils.gradio_css import get_placeholder_html, APP_CSS, HEADER_BASE_MD, FOOTER_MD
15
  # Import spaces for Hugging Face Zero GPU support
16
  import spaces
17
 
@@ -20,6 +20,155 @@ DATA_SOURCES = {
20
  "example_prompts": "examples/example_prompts/example_subset.json",
21
  }
22
23
  def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12):
24
  """Load examples from txt file."""
25
 
@@ -69,19 +218,19 @@ def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12
69
 
70
  return examples
71
 
 
72
  @spaces.GPU(duration=120) # Request GPU for up to 120 seconds per inference
73
  def generate_motion_func(
74
  # text input
75
  original_text: str,
76
  rewritten_text: str,
77
- use_prompt_engineering: bool,
78
  # model input
79
  seed_input: str,
80
  motion_duration: float,
81
  cfg_scale: float,
82
- # output
83
- output_dir: str,
84
  ) -> Tuple[str, List[str]]:
 
 
85
  # When rewrite is not available, use original_text directly
86
  if use_prompt_engineering:
87
  text_to_use = rewritten_text.strip()
@@ -106,7 +255,7 @@ def generate_motion_func(
106
  cfg_scale=cfg_scale,
107
  output_format=req_format,
108
  original_text=original_text,
109
- output_dir=output_dir
110
  )
111
  print(f"Running inference...after gpu_inference_wrapper")
112
  # Escape HTML content for srcdoc attribute
@@ -128,12 +277,25 @@ def generate_motion_func(
128
  [],
129
  )
130
 
 
131
  class T2MGradioUI:
132
  def __init__(self, args):
133
  self.output_dir = args.output_dir
134
  print(f"[{self.__class__.__name__}] output_dir: {self.output_dir}")
135
  # self.args = args
136
  self.prompt_engineering_available = args.use_prompt_engineering
137
  self.all_example_data = {}
138
  self._init_example_data()
139
 
@@ -162,34 +324,29 @@ class T2MGradioUI:
162
  seeds = [random.randint(0, 999) for _ in range(4)]
163
  return ",".join(map(str, seeds))
164
 
165
- def _prompt_engineering(
166
- self, text: str, duration: float, enable_rewrite: bool = True, enable_duration_est: bool = True
167
- ):
168
  if not text.strip():
169
- return "", gr.update(interactive=False), gr.update()
170
 
171
- call_llm = enable_rewrite or enable_duration_est
172
- if not call_llm:
173
- print(f"\t>>> Using original duration and original text...")
174
- predicted_duration = duration
175
- rewritten_text = text
176
- else:
177
- print(f"\t>>> Using LLM to estimate duration/rewrite text...")
178
- try:
179
- predicted_duration, rewritten_text = model_inference.rewrite_text_and_infer_time(text=text)
180
- except Exception as e:
181
- print(f"\t>>> Text rewriting/duration prediction failed: {e}")
182
- return (
183
- f"❌ Text rewriting/duration prediction failed: {str(e)}",
184
- gr.update(interactive=False),
185
- gr.update(),
186
- )
187
- if not enable_rewrite:
188
- rewritten_text = text
189
- if not enable_duration_est:
190
- predicted_duration = duration
191
 
192
- return rewritten_text, gr.update(interactive=True), gr.update(value=predicted_duration)
193
 
194
  def _get_example_choices(self):
195
  """Get all example choices from all data sources"""
@@ -204,7 +361,10 @@ class T2MGradioUI:
204
  def _on_example_select(self, selected_example):
205
  """When selecting an example, the callback function"""
206
  if selected_example == "Custom Input":
207
- return "", self._generate_random_seeds(), gr.update()
 
 
 
208
  else:
209
  # find the corresponding example from all data sources
210
  for source_name in self.all_example_data:
@@ -212,30 +372,45 @@ class T2MGradioUI:
212
  for text, duration in example_data:
213
  display_text = f"{text[:50]}..." if len(text) > 50 else text
214
  if display_text == selected_example:
215
- return text, self._generate_random_seeds(), gr.update(value=duration)
216
- return "", self._generate_random_seeds(), gr.update()
 
 
 
 
 
 
 
217
 
218
  def build_ui(self):
219
  with gr.Blocks(css=APP_CSS) as demo:
220
  # Create State components for non-UI values that need to be passed to event handlers
221
  self.use_prompt_engineering_state = gr.State(self.prompt_engineering_available)
222
  self.output_dir_state = gr.State(self.output_dir)
223
-
224
  self.header_md = gr.Markdown(HEADER_BASE_MD, elem_classes=["main-header"])
225
 
226
  with gr.Row():
227
  # Left control panel
228
  with gr.Column(scale=2, elem_classes=["left-panel"]):
 
229
  # Input textbox
230
  if self.prompt_engineering_available:
231
- input_place_holder = "Enter text to generate motion, support Chinese and English text input."
232
  else:
233
- input_place_holder = "Enter text to generate motion, please use `A person ...` format to describe the motion"
234
 
235
  self.text_input = gr.Textbox(
236
  label="📝 Input Text",
237
  placeholder=input_place_holder,
 
 
 
238
  )
239
  # Rewritten textbox
240
  self.rewritten_text = gr.Textbox(
241
  label="✏️ Rewritten Text",
@@ -281,18 +456,13 @@ class T2MGradioUI:
281
  interactive=not self.prompt_engineering_available, # Enable directly if rewrite not available
282
  )
283
 
284
- if not self.prompt_engineering_available:
285
- gr.Markdown(
286
- "> ⚠️ **Prompt engineering is not available.** Text rewriting and duration estimation are disabled. Your input text and duration will be used directly."
287
- )
288
-
289
 
290
  # Example selection dropdown
291
  self.example_dropdown = gr.Dropdown(
292
  choices=self._get_example_choices(),
293
  value="Custom Input",
294
- label="📚 Test Examples",
295
- info="Select a preset example or input your own text above",
296
  interactive=True,
297
  )
298
 
@@ -309,6 +479,9 @@ class T2MGradioUI:
309
  self.status_output = gr.Textbox(
310
  label="📊 Status Information",
311
  value=status_msg,
 
 
 
312
  )
313
 
314
  # FBX Download section
@@ -325,11 +498,27 @@ class T2MGradioUI:
325
  # Right display area
326
  with gr.Column(scale=3):
327
  self.output_display = gr.HTML(
328
- value=get_placeholder_html(),
329
- show_label=False,
330
- elem_classes=["flask-display"]
331
  )
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  # Footer
334
  gr.Markdown(FOOTER_MD, elem_classes=["footer"])
335
 
@@ -338,79 +527,73 @@ class T2MGradioUI:
338
  return demo
339
 
340
  def _build_advanced_settings(self):
341
- # Only show rewrite options if rewrite is available
342
- if self.prompt_engineering_available:
343
- with gr.Group():
344
- gr.Markdown("### 🔄 Text Rewriting Options")
345
- with gr.Row():
346
- self.enable_rewrite = gr.Checkbox(
347
- label="Enable Text Rewriting",
348
- value=True,
349
- info="Automatically optimize text prompt to get better motion generation",
350
- )
351
-
352
- with gr.Group():
353
- gr.Markdown("### ⏱️ Duration Settings")
354
- self.enable_duration_est = gr.Checkbox(
355
- label="Enable Duration Estimation",
356
- value=True,
357
- info="Automatically estimate the duration of the motion",
358
- )
359
- else:
360
- # Create hidden placeholders with default values (disabled)
361
- self.enable_rewrite = gr.Checkbox(
362
- label="Enable Text Rewriting",
363
- value=False,
364
- visible=False,
365
  )
366
- self.enable_duration_est = gr.Checkbox(
367
- label="Enable Duration Estimation",
368
- value=False,
369
- visible=False,
 
 
370
  )
371
- with gr.Group():
372
- gr.Markdown("### ⚠️ Prompt Engineering Unavailable")
373
- gr.Markdown(
374
- "Text rewriting and duration estimation are not available. "
375
- "Your input text and duration will be used directly."
376
- )
 
377
 
378
- with gr.Group():
379
- gr.Markdown("### ⚙️ Generation Parameters")
380
- with gr.Row():
381
- with gr.Column(scale=3):
382
- self.seed_input = gr.Textbox(
383
- label="🎯 Random Seed List (comma separated)",
384
- value="0,1,2,3",
385
- placeholder="Enter comma separated seed list (e.g.: 0,1,2,3)",
386
- info="Random seeds control the diversity of generated motions",
387
- )
388
- with gr.Column(scale=1, min_width=60, elem_classes=["dice-container"]):
389
- self.dice_btn = gr.Button(
390
- "🎲 Lucky Button",
391
- variant="secondary",
392
- size="sm",
393
- elem_classes=["dice-button"],
394
- )
395
 
396
- self.cfg_slider = gr.Slider(
397
- minimum=1,
398
- maximum=10,
399
- value=5.0,
400
- step=0.1,
401
- label="⚙️ CFG Strength",
402
- info="Text fidelity: higher = more faithful to the prompt",
 
 
 
 
 
 
 
 
 
 
 
 
403
  )
404
 
405
  def _bind_events(self):
406
  # Generate random seeds
407
  self.dice_btn.click(self._generate_random_seeds, outputs=[self.seed_input])
408
409
  # Bind example selection event
410
  self.example_dropdown.change(
411
  fn=self._on_example_select,
412
  inputs=[self.example_dropdown],
413
- outputs=[self.text_input, self.seed_input, self.duration_slider],
414
  )
415
 
416
  # Rewrite text logic (only bind when rewrite is available)
@@ -420,16 +603,11 @@ class T2MGradioUI:
420
  inputs=[
421
  self.text_input,
422
  self.duration_slider,
423
- self.enable_rewrite,
424
- self.enable_duration_est,
425
  ],
426
- outputs=[self.rewritten_text, self.generate_btn, self.duration_slider],
427
  ).then(
428
- fn=lambda: (
429
- gr.update(visible=True),
430
- "Text rewriting completed! Please check and edit the rewritten text, then click [🚀 Generate Motion]",
431
- ),
432
- outputs=[self.rewritten_text, self.status_output],
433
  )
434
 
435
  # Generate motion logic
@@ -438,16 +616,8 @@ class T2MGradioUI:
438
  outputs=[self.status_output],
439
  ).then(
440
  generate_motion_func,
441
- inputs=[
442
- self.text_input,
443
- self.rewritten_text,
444
- self.use_prompt_engineering_state,
445
- self.seed_input,
446
- self.duration_slider,
447
- self.cfg_slider,
448
- self.output_dir_state,
449
- ],
450
- outputs=[self.output_display, self.fbx_files]
451
  ).then(
452
  fn=lambda fbx_list: (
453
  (
@@ -463,12 +633,22 @@ class T2MGradioUI:
463
 
464
  # Reset logic - different behavior based on rewrite availability
465
  if self.prompt_engineering_available:
466
  self.text_input.change(
467
- fn=lambda: (
468
- gr.update(visible=False),
469
- gr.update(interactive=False),
470
- "Please click the [🔄 Rewrite Text] button to rewrite the text first",
 
 
 
 
471
  ),
 
472
  outputs=[self.rewritten_text, self.generate_btn, self.status_output],
473
  )
474
  else:
@@ -508,11 +688,8 @@ def create_demo(final_model_path):
508
  class Args:
509
  model_path = final_model_path
510
  output_dir = "output/gradio"
511
- use_prompt_engineering = False
512
  use_text_encoder = True
513
- prompt_engineering_host = os.environ.get("PROMPT_HOST", None)
514
- prompt_engineering_model_path = os.environ.get("PROMPT_MODEL_PATH", None)
515
- disable_prompt_engineering = os.environ.get("DISABLE_PROMPT_ENGINEERING", False)
516
 
517
  args = Args()
518
 
@@ -538,11 +715,21 @@ def create_demo(final_model_path):
538
 
539
  if __name__ == "__main__":
540
  # Create demo at module level for Hugging Face Spaces
541
  try_to_download_text_encoder()
542
  # Then download the main model
543
  final_model_path = try_to_download_model()
544
- model_inference = ModelInference(final_model_path,
545
  use_prompt_engineering=False, use_text_encoder=True)
546
  model_inference.initialize_model(device="cpu")
547
  demo = create_demo(final_model_path)
548
- demo.launch(server_name="0.0.0.0")
 
11
  import gradio as gr
12
  from hymotion.utils.gradio_runtime import ModelInference
13
  from hymotion.utils.gradio_utils import try_to_download_model, try_to_download_text_encoder
14
+ from hymotion.utils.gradio_css import get_placeholder_html, APP_CSS, HEADER_BASE_MD, FOOTER_MD, WITHOUT_PROMPT_ENGINEERING_WARNING
15
  # Import spaces for Hugging Face Zero GPU support
16
  import spaces
17
 
 
20
  "example_prompts": "examples/example_prompts/example_subset.json",
21
  }
22
 
23
+ # Pre-generated examples for gallery display (generated on first startup)
24
+ # Add/remove items to control the number of examples
25
+ EXAMPLE_GALLERY_LIST = [
26
+ {
27
+ "prompt": "A person jumps upward with both legs twice.",
28
+ "duration": 4.5,
29
+ "seeds": "792",
30
+ "cfg_scale": 5.0,
31
+ "filename": "jump_twice",
32
+ },
33
+ # Add more examples here as needed:
34
+ {
35
+ "prompt": "A person jumps on their right leg.",
36
+ "duration": 4.5,
37
+ "seeds": "941",
38
+ "cfg_scale": 5.0,
39
+ "filename": "jump_right_leg",
40
+ },
41
+ ]
42
+ EXAMPLE_GALLERY_OUTPUT_DIR = "examples/pregenerated"
43
+
44
+ def ensure_examples_generated(model_inference_obj) -> List[str]:
45
+ """
46
+ Ensure all example motions are generated on first startup.
47
+ Returns a list of successfully generated example filenames.
48
+ """
49
+ example_dir = EXAMPLE_GALLERY_OUTPUT_DIR
50
+ os.makedirs(example_dir, exist_ok=True)
51
+
52
+ generated_examples = []
53
+
54
+ for example in EXAMPLE_GALLERY_LIST:
55
+ example_filename = example["filename"]
56
+ meta_path = os.path.join(example_dir, f"{example_filename}_meta.json")
57
+
58
+ # Check if already generated
59
+ if os.path.exists(meta_path):
60
+ print(f">>> Example already exists: {meta_path}")
61
+ generated_examples.append(example_filename)
62
+ continue
63
+
64
+ # Generate the example
65
+ print(f">>> Generating example motion: {example['prompt']}")
66
+ try:
67
+ html_content, fbx_files = model_inference_obj.run_inference(
68
+ text=example["prompt"],
69
+ seeds_csv=example["seeds"],
70
+ motion_duration=example["duration"],
71
+ cfg_scale=example["cfg_scale"],
72
+ output_format="dict", # Don't generate FBX for example
73
+ original_text=example["prompt"],
74
+ output_dir=example_dir,
75
+ output_filename=example_filename,
76
+ )
77
+ print(f">>> Example '{example_filename}' generated successfully!")
78
+ generated_examples.append(example_filename)
79
+ except Exception as e:
80
+ print(f">>> Failed to generate example '{example_filename}': {e}")
81
+
82
+ return generated_examples
83
+
84
+
85
+ def load_example_gallery_html(example_index: int = 0) -> str:
86
+ """
87
+ Load a specific pre-generated example and return iframe HTML for display.
88
+ Args:
89
+ example_index: Index of the example in EXAMPLE_GALLERY_LIST
90
+ """
91
+ from hymotion.utils.visualize_mesh_web import generate_static_html_content
92
+
93
+ if example_index < 0 or example_index >= len(EXAMPLE_GALLERY_LIST):
94
+ return ""
95
+
96
+ example = EXAMPLE_GALLERY_LIST[example_index]
97
+ example_dir = EXAMPLE_GALLERY_OUTPUT_DIR
98
+ example_filename = example["filename"]
99
+ meta_path = os.path.join(example_dir, f"{example_filename}_meta.json")
100
+
101
+ if not os.path.exists(meta_path):
102
+ return f"""
103
+ <div style='height: 300px; display: flex; justify-content: center; align-items: center;
104
+ background: #2d3748; border-radius: 12px; color: #a0aec0;'>
105
+ <p>Example not generated yet. Please restart the app.</p>
106
+ </div>
107
+ """
108
+
109
+ try:
110
+ html_content = generate_static_html_content(
111
+ folder_name=example_dir,
112
+ file_name=example_filename,
113
+ hide_captions=False,
114
+ )
115
+ escaped_html = html_content.replace('"', "&quot;")
116
+ iframe_html = f"""
117
+ <iframe
118
+ srcdoc="{escaped_html}"
119
+ width="100%"
120
+ height="350px"
121
+ style="border: none; border-radius: 12px; box-shadow: 0 4px 20px rgba(0,0,0,0.1);"
122
+ ></iframe>
123
+ """
124
+ return iframe_html
125
+ except Exception as e:
126
+ print(f">>> Failed to load example gallery: {e}")
127
+ return ""
128
+
129
+
130
+ def get_example_gallery_grid_html() -> str:
131
+ """
132
+ Generate a grid layout HTML for all examples in the gallery.
133
+ """
134
+ if not EXAMPLE_GALLERY_LIST:
135
+ return "<p>No examples configured.</p>"
136
+
137
+ # Calculate grid columns based on number of examples
138
+ num_examples = len(EXAMPLE_GALLERY_LIST)
139
+ if num_examples == 1:
140
+ columns = 1
141
+ elif num_examples == 2:
142
+ columns = 2
143
+ elif num_examples <= 4:
144
+ columns = 2
145
+ else:
146
+ columns = 3
147
+
148
+ grid_items = []
149
+ for idx, example in enumerate(EXAMPLE_GALLERY_LIST):
150
+ iframe_html = load_example_gallery_html(idx)
151
+ prompt_short = example["prompt"][:60] + "..." if len(example["prompt"]) > 60 else example["prompt"]
152
+
153
+ grid_items.append(f"""
154
+ <div class="example-grid-item" style="background: var(--card-bg, #fff); border-radius: 12px;
155
+ padding: 12px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
156
+ <div style="font-size: 14px; font-weight: 600; color: var(--text-primary, #333);
157
+ margin-bottom: 8px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap;">
158
+ {prompt_short}
159
+ </div>
160
+ {iframe_html}
161
+ </div>
162
+ """)
163
+
164
+ grid_html = f"""
165
+ <div style="display: grid; grid-template-columns: repeat({columns}, 1fr); gap: 16px; padding: 8px;">
166
+ {"".join(grid_items)}
167
+ </div>
168
+ """
169
+ return grid_html
170
+
171
+
172
  def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12):
173
  """Load examples from txt file."""
174
 
 
218
 
219
  return examples
220
 
221
+
222
  @spaces.GPU(duration=120) # Request GPU for up to 120 seconds per inference
223
  def generate_motion_func(
224
  # text input
225
  original_text: str,
226
  rewritten_text: str,
 
227
  # model input
228
  seed_input: str,
229
  motion_duration: float,
230
  cfg_scale: float,
 
 
231
  ) -> Tuple[str, List[str]]:
232
+ use_prompt_engineering = USE_PROMPT_ENGINEERING
233
+ output_dir = "output/gradio"
234
  # When rewrite is not available, use original_text directly
235
  if use_prompt_engineering:
236
  text_to_use = rewritten_text.strip()
 
255
  cfg_scale=cfg_scale,
256
  output_format=req_format,
257
  original_text=original_text,
258
+ output_dir=output_dir,
259
  )
260
  print(f"Running inference...after gpu_inference_wrapper")
261
  # Escape HTML content for srcdoc attribute
 
277
  [],
278
  )
279
 
280
+
281
  class T2MGradioUI:
282
  def __init__(self, args):
283
  self.output_dir = args.output_dir
284
  print(f"[{self.__class__.__name__}] output_dir: {self.output_dir}")
285
  # self.args = args
286
  self.prompt_engineering_available = args.use_prompt_engineering
287
+ if self.prompt_engineering_available:
288
+ try:
289
+ from hymotion.prompt_engineering.client import PromptEngineeringClient
290
+ self.prompt_engineering_client = PromptEngineeringClient()
291
+ # Test the client with a simple prompt to verify it works
292
+ self.prompt_engineering_client.rewrite_prompt_and_infer_time("A person walks forward.", max_timeout=30)
293
+ print(f"[{self.__class__.__name__}] Prompt engineering client initialized successfully.")
294
+ except Exception as e:
295
+ print(f"[{self.__class__.__name__}] Prompt engineering client initialization failed: {e}")
296
+ self.prompt_engineering_available = False
297
+
298
+
299
  self.all_example_data = {}
300
  self._init_example_data()
301
 
 
324
  seeds = [random.randint(0, 999) for _ in range(4)]
325
  return ",".join(map(str, seeds))
326
 
327
+ def _prompt_engineering(self, text: str, duration: float):
 
 
328
  if not text.strip():
329
+ return "", gr.update(interactive=False), gr.update(), "⚠️ Please enter text first"
330
 
331
+ print(f"\t>>> Using LLM to estimate duration/rewrite text...")
332
+ try:
333
+ predicted_duration, rewritten_text = self.prompt_engineering_client.rewrite_prompt_and_infer_time(text=text)
334
+ except Exception as e:
335
+ print(f"\t>>> Text rewriting/duration prediction failed: {e}")
336
+ # On failure, use original text and enable generate button
337
+ return (
338
+ text, # Use original text as fallback
339
+ gr.update(interactive=True), # Enable generate button
340
+ gr.update(),
341
+ f"⚠️ Text rewriting failed: {str(e)}\n💡 Using your original input directly. You can click [🚀 Generate Motion] to continue.",
342
+ )
 
343
 
344
+ return (
345
+ rewritten_text,
346
+ gr.update(interactive=True),
347
+ gr.update(value=predicted_duration),
348
+ "✅ Text rewriting completed! Please check and edit the rewritten text, then click [🚀 Generate Motion]",
349
+ )
350
 
351
  def _get_example_choices(self):
352
  """Get all example choices from all data sources"""
 
361
  def _on_example_select(self, selected_example):
362
  """When selecting an example, the callback function"""
363
  if selected_example == "Custom Input":
364
+ if self.prompt_engineering_available:
365
+ return "", self._generate_random_seeds(), gr.update(), gr.update(value="", visible=False), gr.update(interactive=False), "Please enter text or select an example"
366
+ else:
367
+ return "", self._generate_random_seeds(), gr.update(), gr.update(), gr.update(), gr.update()
368
  else:
369
  # find the corresponding example from all data sources
370
  for source_name in self.all_example_data:
 
372
  for text, duration in example_data:
373
  display_text = f"{text[:50]}..." if len(text) > 50 else text
374
  if display_text == selected_example:
375
+ if self.prompt_engineering_available:
376
+ # Set text directly to rewritten_text and enable generate button
377
+ return text, self._generate_random_seeds(), gr.update(value=duration), gr.update(value=text, visible=True), gr.update(interactive=True), "✅ Example selected! Click [🚀 Generate Motion] to start."
378
+ else:
379
+ return text, self._generate_random_seeds(), gr.update(value=duration), gr.update(), gr.update(), gr.update()
380
+ if self.prompt_engineering_available:
381
+ return "", self._generate_random_seeds(), gr.update(), gr.update(value="", visible=False), gr.update(interactive=False), "Please enter text or select an example"
382
+ else:
383
+ return "", self._generate_random_seeds(), gr.update(), gr.update(), gr.update(), gr.update()
384
 
385
  def build_ui(self):
386
  with gr.Blocks(css=APP_CSS) as demo:
387
  # Create State components for non-UI values that need to be passed to event handlers
388
  self.use_prompt_engineering_state = gr.State(self.prompt_engineering_available)
389
  self.output_dir_state = gr.State(self.output_dir)
390
+
391
  self.header_md = gr.Markdown(HEADER_BASE_MD, elem_classes=["main-header"])
392
 
393
  with gr.Row():
394
  # Left control panel
395
  with gr.Column(scale=2, elem_classes=["left-panel"]):
396
+
397
  # Input textbox
398
  if self.prompt_engineering_available:
399
+ input_place_holder = "Enter text to generate motion, support Chinese and English text input. Non-humanoid Characters, Multi-person Interactions and Environment & Camera are not supported. Click [ 📚 Example Prompts ] to see more examples."
400
  else:
401
+ input_place_holder = "Enter English text to generate motion, please use `A person ...` format to describe the motion, better less than 50 words. Non-humanoid Characters, Multi-person Interactions and Environment & Camera are not supported. Click [ 📚 Example Prompts ] to see more examples."
402
 
403
  self.text_input = gr.Textbox(
404
  label="📝 Input Text",
405
  placeholder=input_place_holder,
406
+ lines=3,
407
+ max_lines=10,
408
+ autoscroll=False,
409
  )
410
+ # if not self.prompt_engineering_available:
411
+ # gr.Markdown(
412
+ # "Click [📚 Example Prompts] to see more examples."
413
+ # )
414
  # Rewritten textbox
415
  self.rewritten_text = gr.Textbox(
416
  label="✏️ Rewritten Text",
 
456
  interactive=not self.prompt_engineering_available, # Enable directly if rewrite not available
457
  )
458
459
 
460
  # Example selection dropdown
461
  self.example_dropdown = gr.Dropdown(
462
  choices=self._get_example_choices(),
463
  value="Custom Input",
464
+ label="📚 Example Prompts",
465
+ # info="Select a preset example or input your own text above",
466
  interactive=True,
467
  )
468
 
 
479
  self.status_output = gr.Textbox(
480
  label="📊 Status Information",
481
  value=status_msg,
482
+ lines=1,
483
+ max_lines=10,
484
+ elem_classes=["status-textbox"],
485
  )
486
 
487
  # FBX Download section
 
498
  # Right display area
499
  with gr.Column(scale=3):
500
  self.output_display = gr.HTML(
501
+ value=get_placeholder_html(), show_label=False, elem_classes=["flask-display"]
 
 
502
  )
503
 
504
+ # Example Gallery Section
505
+ with gr.Accordion("🎬 Example Gallery", open=True):
506
+ self.example_gallery_display = gr.HTML(
507
+ value=get_example_gallery_grid_html(),
508
+ show_label=False,
509
+ elem_classes=["example-gallery-display"]
510
+ )
511
+ # Create use example buttons for each example
512
+ with gr.Row():
513
+ self.use_example_btns = []
514
+ for idx, example in enumerate(EXAMPLE_GALLERY_LIST):
515
+ btn = gr.Button(
516
+ f"📋 Use Example {idx + 1}",
517
+ variant="secondary",
518
+ size="sm",
519
+ )
520
+ self.use_example_btns.append((btn, idx))
521
+
522
  # Footer
523
  gr.Markdown(FOOTER_MD, elem_classes=["footer"])
524
 
 
527
  return demo
528
 
529
  def _build_advanced_settings(self):
530
+ with gr.Row():
531
+ self.seed_input = gr.Textbox(
532
+ label="🎯 Random Seeds",
533
+ value="0,1,2,3",
534
+ placeholder="e.g.: 0,1,2,3",
535
+ scale=3,
536
  )
537
+ self.dice_btn = gr.Button(
538
+ "🎲",
539
+ variant="secondary",
540
+ size="sm",
541
+ scale=1,
542
+ min_width=50,
543
  )
544
+ self.cfg_slider = gr.Slider(
545
+ minimum=1,
546
+ maximum=10,
547
+ value=5.0,
548
+ step=0.1,
549
+ label="⚙️ CFG Strength",
550
+ )
551
 
552
+ def _on_use_example(self, example_idx: int):
553
+ """When clicking 'Use This Example' button, fill in the example prompt"""
554
+ if example_idx < 0 or example_idx >= len(EXAMPLE_GALLERY_LIST):
555
+ if self.prompt_engineering_available:
556
+ return ("", "0,1,2,3", gr.update(), gr.update(value="", visible=False), gr.update(interactive=False), "Please select a valid example")
557
+ else:
558
+ return ("", "0,1,2,3", gr.update(), gr.update(), gr.update(), gr.update())
 
 
 
 
 
 
 
 
 
 
559
 
560
+ example = EXAMPLE_GALLERY_LIST[example_idx]
561
+ if self.prompt_engineering_available:
562
+ # Set text directly to rewritten_text and enable generate button
563
+ return (
564
+ example["prompt"],
565
+ example["seeds"],
566
+ gr.update(value=example["duration"]),
567
+ gr.update(value=example["prompt"], visible=True),
568
+ gr.update(interactive=True),
569
+ "✅ Example selected! Click [🚀 Generate Motion] to start.",
570
+ )
571
+ else:
572
+ return (
573
+ example["prompt"],
574
+ example["seeds"],
575
+ gr.update(value=example["duration"]),
576
+ gr.update(),
577
+ gr.update(),
578
+ gr.update(),
579
  )
580
 
581
  def _bind_events(self):
582
  # Generate random seeds
583
  self.dice_btn.click(self._generate_random_seeds, outputs=[self.seed_input])
584
 
585
+ # Use example buttons - bind each button to its example
586
+ for btn, idx in self.use_example_btns:
587
+ btn.click(
588
+ fn=lambda i=idx: self._on_use_example(i),
589
+ outputs=[self.text_input, self.seed_input, self.duration_slider, self.rewritten_text, self.generate_btn, self.status_output],
590
+ )
591
+
592
  # Bind example selection event
593
  self.example_dropdown.change(
594
  fn=self._on_example_select,
595
  inputs=[self.example_dropdown],
596
+ outputs=[self.text_input, self.seed_input, self.duration_slider, self.rewritten_text, self.generate_btn, self.status_output],
597
  )
598
 
599
  # Rewrite text logic (only bind when rewrite is available)
 
603
  inputs=[
604
  self.text_input,
605
  self.duration_slider,
 
 
606
  ],
607
+ outputs=[self.rewritten_text, self.generate_btn, self.duration_slider, self.status_output],
608
  ).then(
609
+ fn=lambda: gr.update(visible=True),
610
+ outputs=[self.rewritten_text],
 
 
 
611
  )
612
 
613
  # Generate motion logic
 
616
  outputs=[self.status_output],
617
  ).then(
618
  generate_motion_func,
619
+ inputs=[self.text_input, self.rewritten_text, self.seed_input, self.duration_slider, self.cfg_slider],
620
+ outputs=[self.output_display, self.fbx_files],
621
  ).then(
622
  fn=lambda fbx_list: (
623
  (
 
633
 
634
  # Reset logic - different behavior based on rewrite availability
635
  if self.prompt_engineering_available:
636
+ # When text_input changes:
637
+ # - If text_input == rewritten_text, it means the change was triggered by example selection,
638
+ # so we should NOT hide the rewritten_text (keep it visible and generate button enabled)
639
+ # - If text_input != rewritten_text, it means user manually edited the input,
640
+ # so we should hide the rewritten_text and require a new rewrite
641
  self.text_input.change(
642
+ fn=lambda text, rewritten: (
643
+ gr.update() if text.strip() == rewritten.strip() else gr.update(visible=False),
644
+ gr.update() if text.strip() == rewritten.strip() else gr.update(interactive=False),
645
+ (
646
+ "✅ Example selected! Click [🚀 Generate Motion] to start."
647
+ if text.strip() == rewritten.strip() and text.strip()
648
+ else "Please click the [🔄 Rewrite Text] button to rewrite the text first"
649
+ ),
650
  ),
651
+ inputs=[self.text_input, self.rewritten_text],
652
  outputs=[self.rewritten_text, self.generate_btn, self.status_output],
653
  )
654
  else:
 
688
  class Args:
689
  model_path = final_model_path
690
  output_dir = "output/gradio"
691
+ use_prompt_engineering = USE_PROMPT_ENGINEERING
692
  use_text_encoder = True
 
 
 
693
 
694
  args = Args()
695
 
 
715
 
716
  if __name__ == "__main__":
717
  # Create demo at module level for Hugging Face Spaces
718
+ import argparse
719
+ parser = argparse.ArgumentParser(description="HY-Motion-1.0 Gradio App")
720
+ parser.add_argument("--port", type=int, default=7860, help="Port to listen on")
721
+ args = parser.parse_args()
722
+
723
+ USE_PROMPT_ENGINEERING = True
724
  try_to_download_text_encoder()
725
  # Then download the main model
726
  final_model_path = try_to_download_model()
727
+ model_inference = ModelInference(final_model_path,
728
  use_prompt_engineering=False, use_text_encoder=True)
729
  model_inference.initialize_model(device="cpu")
730
+
731
+ # Generate examples on first startup (if not exists)
732
+ ensure_examples_generated(model_inference)
733
+
734
  demo = create_demo(final_model_path)
735
+ demo.launch(server_name="0.0.0.0", server_port=args.port)
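With the new `argparse` entry point, the demo can be launched on a custom port, e.g. `python gradio_app.py --port 8080`; if `--port` is omitted it defaults to 7860, and the server still binds to `0.0.0.0` as before (the command assumes the script is invoked directly as `gradio_app.py`, per the file name in this commit).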
hymotion/network/text_encoders/text_encoder.py CHANGED
@@ -99,7 +99,9 @@ class HYTextModel(nn.Module):
99
  padding_side="right",
100
  )
101
  self.llm_text_encoder = LLM_ENCODER_LAYOUT[llm_type]["text_encoder_class"].from_pretrained(
102
- LLM_ENCODER_LAYOUT[llm_type]["module_path"], low_cpu_mem_usage=True
 
 
103
  )
104
  self.llm_text_encoder = self.llm_text_encoder.eval().requires_grad_(False)
105
  self.ctxt_dim = self.llm_text_encoder.config.hidden_size
@@ -150,9 +152,9 @@ class HYTextModel(nn.Module):
150
  )
151
  )
152
  if self.llm_type == "qwen3":
153
- ctxt_raw = llm_outputs.hidden_states[-1]
154
  else:
155
- ctxt_raw = llm_outputs.last_hidden_state
156
 
157
  start = self.crop_start
158
  end = start + self._orig_max_length_llm
 
99
  padding_side="right",
100
  )
101
  self.llm_text_encoder = LLM_ENCODER_LAYOUT[llm_type]["text_encoder_class"].from_pretrained(
102
+ LLM_ENCODER_LAYOUT[llm_type]["module_path"],
103
+ low_cpu_mem_usage=True,
104
+ torch_dtype=torch.bfloat16,
105
  )
106
  self.llm_text_encoder = self.llm_text_encoder.eval().requires_grad_(False)
107
  self.ctxt_dim = self.llm_text_encoder.config.hidden_size
 
152
  )
153
  )
154
  if self.llm_type == "qwen3":
155
+ ctxt_raw = llm_outputs.hidden_states[-1].clone()
156
  else:
157
+ ctxt_raw = llm_outputs.last_hidden_state.clone()
158
 
159
  start = self.crop_start
160
  end = start + self._orig_max_length_llm
hymotion/pipeline/motion_diffusion.py CHANGED
@@ -176,7 +176,6 @@ class MotionGeneration(torch.nn.Module):
176
  def load_in_demo(
177
  self,
178
  ckpt_name: str,
179
- mean_std_name: Optional[str] = None,
180
  build_text_encoder: bool = True,
181
  allow_empty_ckpt: bool = False,
182
  ) -> None:
@@ -188,11 +187,6 @@ class MotionGeneration(torch.nn.Module):
188
  else:
189
  checkpoint = torch.load(ckpt_name, map_location="cpu", weights_only=False)
190
  self.load_state_dict(checkpoint["model_state_dict"], strict=False)
191
- if mean_std_name is not None:
192
- assert os.path.exists(mean_std_name), f"{mean_std_name} not found"
193
- if not os.path.isfile(mean_std_name):
194
- mean_std_name = None
195
- self._load_mean_std(mean_std_name)
196
  self.motion_transformer.eval()
197
  if build_text_encoder and not self.uncondition_mode:
198
  self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg)
@@ -299,11 +293,11 @@ class MotionGeneration(torch.nn.Module):
299
  k3d = torch.zeros(B, L, nj, 3, device=device)
300
 
301
  return dict(
302
- latent_denorm=latent_denorm, # (B, L, 201)
303
- keypoints3d=k3d, # (B, L, J, 3)
304
- rot6d=rot6d_smooth, # (B, L, J, 6)
305
- transl=transl_smooth, # (B, L, 3)
306
- root_rotations_mat=root_rotmat_smooth, # (B, L, 3, 3)
307
  )
308
 
309
  @staticmethod
@@ -584,9 +578,8 @@ class MotionFlowMatching(MotionGeneration):
584
  )
585
  with torch.no_grad():
586
  trajectory = odeint(fn, y0, t, **self._noise_scheduler_cfg)
587
- sampled = trajectory[-1]
588
  assert isinstance(sampled, Tensor), f"sampled must be a Tensor, but got {type(sampled)}"
589
- sampled = sampled[:, :length, ...].clone()
590
 
591
  output_dict = self.decode_motion_from_latent(sampled, should_apply_smooothing=True)
592
 
 
176
  def load_in_demo(
177
  self,
178
  ckpt_name: str,
 
179
  build_text_encoder: bool = True,
180
  allow_empty_ckpt: bool = False,
181
  ) -> None:
 
187
  else:
188
  checkpoint = torch.load(ckpt_name, map_location="cpu", weights_only=False)
189
  self.load_state_dict(checkpoint["model_state_dict"], strict=False)
190
  self.motion_transformer.eval()
191
  if build_text_encoder and not self.uncondition_mode:
192
  self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg)
 
293
  k3d = torch.zeros(B, L, nj, 3, device=device)
294
 
295
  return dict(
296
+ latent_denorm=latent_denorm.cpu().detach(), # (B, L, 201)
297
+ keypoints3d=k3d.cpu().detach(), # (B, L, J, 3)
298
+ rot6d=rot6d_smooth.cpu().detach(), # (B, L, J, 6)
299
+ transl=transl_smooth.cpu().detach(), # (B, L, 3)
300
+ root_rotations_mat=root_rotmat_smooth.cpu().detach(), # (B, L, 3, 3)
301
  )
302
 
303
  @staticmethod
 
578
  )
579
  with torch.no_grad():
580
  trajectory = odeint(fn, y0, t, **self._noise_scheduler_cfg)
581
+ sampled = trajectory[-1][:, :length, ...].clone()
582
  assert isinstance(sampled, Tensor), f"sampled must be a Tensor, but got {type(sampled)}"
 
583
 
584
  output_dict = self.decode_motion_from_latent(sampled, should_apply_smooothing=True)
585
 
hymotion/prompt_engineering/client.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ import json
3
+ import time
4
+ from openai import OpenAI
5
+ import json
6
+
7
+ PROMPT = """
8
+ # Role
9
+ You are an expert in 3D motion analysis, animation timing, and choreography. Your task is to analyze textual action descriptions to estimate execution time and standardize the language for motion generation systems.
10
+
11
+ # Task
12
+ Analyze the user-provided [Input Action] and generate a structured JSON response containing a duration estimate and a refined caption.
13
+
14
+ # Instructions
15
+
16
+ ### 1. Duration Estimation (frame_count)
17
+ - Analyze the complexity, speed, and physical constraints of the described action.
18
+ - Estimate the time required to perform the action in a **smooth, natural, and realistic manner**.
19
+ - Calculate the total duration in frames based on a **30 fps** (frames per second) standard.
20
+ - Output strictly as an Integer.
21
+
22
+ ### 2. Caption Refinement (short_caption)
23
+ - Generate a refined, grammatically correct version of the input description in **English**.
24
+ - **Strict Constraints**:
25
+ - You must **PRESERVE** the original sequence of events (chronological order).
26
+ - You must **RETAIN** all original spatial modifiers (e.g., "left," "upward," "quickly").
27
+ - **DO NOT** add new sub-actions or hallucinate details not present in the input.
28
+ - **DO NOT** delete any specific movements.
29
+ - The goal is to improve clarity and flow while maintaining 100% semantic fidelity to the original request.
30
+
31
+ ### 3. Output Format
32
+ - Return **ONLY** a raw JSON object.
33
+ - Do not use Markdown formatting (i.e., do not use ```json ... ```).
34
+ - Ensure the JSON is valid and parsable.
35
+
36
+ # JSON Structure
37
+ {{
38
+ "duration": <Integer, frames at 30fps>,
39
+ "short_caption": "<String, the refined English description>"
40
+ }}
41
+
42
+ # Input
43
+ {}
44
+ """
45
+
46
+
47
+ class PromptEngineeringClient:
48
+ def __init__(self):
49
+ BASE_URL = os.environ.get("PROMPT_ENGINEERING_BASE_URL", "http://IP:PORT/v1")
50
+ API_KEY = os.environ.get("PROMPT_ENGINEERING_API_KEY", "EMPTY")
51
+ MODEL_NAME = os.environ.get("PROMPT_ENGINEERING_MODEL_NAME", "")
52
+ client = OpenAI(
53
+ api_key=API_KEY,
54
+ base_url=BASE_URL
55
+ )
56
+ self.model_name = MODEL_NAME
57
+ self.client = client
58
+
59
+ def rewrite_prompt_and_infer_time(self, text, max_timeout=30):
60
+ start_time = time.time()
61
+ while True:
62
+ end_time = time.time()
63
+ if end_time - start_time > max_timeout:
64
+ raise Exception("Prompt rewriting timeout")
65
+ try:
66
+ chat_response = self.client.chat.completions.create(
67
+ model=self.model_name,
68
+ messages=[
69
+ {"role": "system", "content": "You are a helpful assistant."},
70
+ {"role": "user", "content": PROMPT.format(text)},
71
+ ]
72
+ )
73
+ chat_response = json.loads(chat_response.choices[0].message.content.strip())
74
+ duration = chat_response["duration"]
75
+ short_caption = chat_response["short_caption"]
76
+ pred_duration = min(12, max(1, int(duration) / 30))
77
+ except Exception as e:
78
+ print(e)
79
+ continue
80
+ else:
81
+ break
82
+
83
+ return pred_duration, short_caption
84
+
85
+ if __name__ == "__main__":
86
+ # python -m hymotion.prompt_engineering.client
87
+ client = PromptEngineeringClient()
88
+ print(client.rewrite_prompt_and_infer_time("A person jumps upward with both legs twice."))
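For illustration, a minimal sketch of the duration post-processing that `rewrite_prompt_and_infer_time` applies to the LLM reply; the reply values below are hypothetical, and only the clamping line mirrors the code above:

    # Hypothetical reply following the PROMPT contract ("duration" is frames at 30 fps).
    llm_reply = {"duration": 135, "short_caption": "A person jumps upward with both legs twice."}
    frames = int(llm_reply["duration"])
    seconds = min(12, max(1, frames / 30))  # clamp to the 1-12 s range used by the demo
    print(seconds, llm_reply["short_caption"])  # 4.5 A person jumps upward with both legs twice.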
hymotion/utils/gradio_css.py CHANGED
@@ -116,6 +116,14 @@ APP_CSS = """
116
  font-weight:500 !important;
117
  }
118
 
119
  /* Button base class and variant */
120
  .generate-button,.rewrite-button,.dice-button{
121
  border:none !important; color:#fff !important; font-weight:600 !important;
@@ -206,6 +214,20 @@ APP_CSS = """
206
  padding:10px !important;
207
  color:var(--text-secondary, #666) !important;
208
  }
 
209
  """
210
 
211
  HEADER_BASE_MD = "# HY-Motion-1.0: Text-to-Motion Playground\n### *Tencent Hunyuan 3D Digital Human Team*"
@@ -248,3 +270,10 @@ def get_placeholder_html() -> str:
248
  </div>
249
  """
250
 
116
  font-weight:500 !important;
117
  }
118
 
119
+ /* Status textbox - dynamic height based on content */
120
+ .status-textbox textarea{
121
+ height:auto !important;
122
+ min-height:2.5em !important;
123
+ resize:none !important;
124
+ overflow-y:hidden !important;
125
+ }
126
+
127
  /* Button base class and variant */
128
  .generate-button,.rewrite-button,.dice-button{
129
  border:none !important; color:#fff !important; font-weight:600 !important;
 
214
  padding:10px !important;
215
  color:var(--text-secondary, #666) !important;
216
  }
217
+
218
+ /* Example Gallery Styles */
219
+ .example-gallery-display{
220
+ padding:0 !important; margin:12px 0 !important; border:none !important;
221
+ box-shadow:none !important; background:var(--iframe-bg) !important;
222
+ border-radius:10px !important; position:relative !important;
223
+ min-height:500px !important;
224
+ }
225
+
226
+ .example-gallery-display iframe{
227
+ width:100% !important; min-height:500px !important;
228
+ border:none !important; border-radius:10px !important; display:block !important;
229
+ background:var(--iframe-bg) !important;
230
+ }
231
  """
232
 
233
  HEADER_BASE_MD = "# HY-Motion-1.0: Text-to-Motion Playground\n### *Tencent Hunyuan 3D Digital Human Team*"
 
270
  </div>
271
  """
272
 
273
+
274
+ WITHOUT_PROMPT_ENGINEERING_WARNING = """
275
+ <div style='color: #ff0000; font-weight: bold;'>
276
+ <p>Prompt engineering is not available. You should use `A person ...` format to describe the motion and manually adjust the duration. Click [📚 Example Prompts] to see more examples.</p>
277
+ <p>Non-humanoid Characters, Multi-person Interactions and Environment & Camera are not supported.</p>
278
+ </div>
279
+ """
hymotion/utils/gradio_runtime.py CHANGED
@@ -26,6 +26,7 @@ def _now():
26
  ms = int((t - int(t)) * 1000)
27
  return time.strftime("%Y%m%d_%H%M%S", time.localtime(t)) + f"{ms:03d}"
28
 
 
29
  _MODEL_CACHE = None
30
 
31
 
@@ -37,19 +38,14 @@ class SimpleRuntime(torch.nn.Module):
37
  # prompt engineering
38
  if self.load_prompt_engineering:
39
  print(f"[{self.__class__.__name__}] Loading prompt engineering...")
40
- self.prompt_rewriter = PromptRewriter(
41
- host=None, model_path=None, device="cpu"
42
- )
43
  else:
44
  self.prompt_rewriter = None
45
  # text encoder
46
  if self.load_text_encoder:
47
  print(f"[{self.__class__.__name__}] Loading text encoder...")
48
  _text_encoder_module = "hymotion/network/text_encoders/text_encoder.HYTextModel"
49
- _text_encoder_cfg = {
50
- "llm_type": "qwen3",
51
- "max_length_llm": 128
52
- }
53
  text_encoder = load_object(_text_encoder_module, _text_encoder_cfg)
54
  else:
55
  text_encoder = None
@@ -66,7 +62,6 @@ class SimpleRuntime(torch.nn.Module):
66
  print(f"[{self.__class__.__name__}] Loading ckpt: {ckpt_name}")
67
  pipeline.load_in_demo(
68
  os.path.join(os.path.dirname(config_path), ckpt_name),
69
- "stats",
70
  build_text_encoder=False,
71
  allow_empty_ckpt=False,
72
  )
@@ -87,7 +82,6 @@ class SimpleRuntime(torch.nn.Module):
87
  self.fbx_converter = None
88
  print(">>> FBX module not found. FBX export will be disabled.")
89
 
90
-
91
  def _generate_html_content(
92
  self,
93
  timestamp: str,
@@ -128,7 +122,6 @@ class SimpleRuntime(torch.nn.Module):
128
  # Return error HTML
129
  return f"<html><body><h1>Error generating visualization</h1><p>{str(e)}</p></body></html>"
130
 
131
-
132
  def _generate_fbx_files(
133
  self,
134
  visualization_data: dict,
@@ -247,6 +240,7 @@ class SimpleRuntime(torch.nn.Module):
247
  else:
248
  raise ValueError(f">>> Invalid output format: {output_format}")
249
 
 
250
  class ModelInference:
251
  """
252
  Handles model inference and data processing for Depth Anything 3.
@@ -288,7 +282,7 @@ class ModelInference:
288
  config_path=os.path.join(self.model_path, "config.yml"),
289
  ckpt_name="latest.ckpt",
290
  load_prompt_engineering=self.use_prompt_engineering,
291
- load_text_encoder=self.use_text_encoder
292
  )
293
  # Load to CPU first (faster, and allows reuse)
294
  _MODEL_CACHE = _MODEL_CACHE.to("cpu")
@@ -306,9 +300,7 @@ class ModelInference:
306
 
307
  return _MODEL_CACHE
308
 
309
- def run_inference(
310
- self, *args, **kwargs
311
- ):
312
  """
313
  Run HY-Motion model inference on text prompts.
314
  Args:
@@ -333,7 +325,6 @@ class ModelInference:
333
  # Initialize model if needed - get model instance (not stored in self)
334
  model = self.initialize_model(device)
335
 
336
-
337
  with torch.no_grad():
338
  print(f"[{self.__class__.__name__}] Running inference with torch.no_grad")
339
  html_content, fbx_files, model_output = model.generate_motion(*args, **kwargs)
@@ -347,7 +338,13 @@ class ModelInference:
347
 
348
  return html_content, fbx_files
349
 
 
350
  if __name__ == "__main__":
351
  # python -m hymotion.utils.gradio_runtime
352
- runtime = SimpleRuntime(config_path="assets/config_simplified.yml", ckpt_name="latest.ckpt", load_prompt_engineering=False, load_text_encoder=False)
353
- print(runtime.pipeline)
26
  ms = int((t - int(t)) * 1000)
27
  return time.strftime("%Y%m%d_%H%M%S", time.localtime(t)) + f"{ms:03d}"
28
 
29
+
30
  _MODEL_CACHE = None
31
 
32
 
 
38
  # prompt engineering
39
  if self.load_prompt_engineering:
40
  print(f"[{self.__class__.__name__}] Loading prompt engineering...")
41
+ self.prompt_rewriter = PromptRewriter(host=None, model_path=None, device="cpu")
 
 
42
  else:
43
  self.prompt_rewriter = None
44
  # text encoder
45
  if self.load_text_encoder:
46
  print(f"[{self.__class__.__name__}] Loading text encoder...")
47
  _text_encoder_module = "hymotion/network/text_encoders/text_encoder.HYTextModel"
48
+ _text_encoder_cfg = {"llm_type": "qwen3", "max_length_llm": 128}
 
 
 
49
  text_encoder = load_object(_text_encoder_module, _text_encoder_cfg)
50
  else:
51
  text_encoder = None
 
62
  print(f"[{self.__class__.__name__}] Loading ckpt: {ckpt_name}")
63
  pipeline.load_in_demo(
64
  os.path.join(os.path.dirname(config_path), ckpt_name),
 
65
  build_text_encoder=False,
66
  allow_empty_ckpt=False,
67
  )
 
82
  self.fbx_converter = None
83
  print(">>> FBX module not found. FBX export will be disabled.")
84
 
 
85
  def _generate_html_content(
86
  self,
87
  timestamp: str,
 
122
  # Return error HTML
123
  return f"<html><body><h1>Error generating visualization</h1><p>{str(e)}</p></body></html>"
124
 
 
125
  def _generate_fbx_files(
126
  self,
127
  visualization_data: dict,
 
240
  else:
241
  raise ValueError(f">>> Invalid output format: {output_format}")
242
 
243
+
244
  class ModelInference:
245
  """
246
  Handles model inference and data processing for Depth Anything 3.
 
282
  config_path=os.path.join(self.model_path, "config.yml"),
283
  ckpt_name="latest.ckpt",
284
  load_prompt_engineering=self.use_prompt_engineering,
285
+ load_text_encoder=self.use_text_encoder,
286
  )
287
  # Load to CPU first (faster, and allows reuse)
288
  _MODEL_CACHE = _MODEL_CACHE.to("cpu")
 
300
 
301
  return _MODEL_CACHE
302
 
303
+ def run_inference(self, *args, **kwargs):
 
 
304
  """
305
  Run HY-Motion model inference on text prompts.
306
  Args:
 
325
  # Initialize model if needed - get model instance (not stored in self)
326
  model = self.initialize_model(device)
327
 
 
328
  with torch.no_grad():
329
  print(f"[{self.__class__.__name__}] Running inference with torch.no_grad")
330
  html_content, fbx_files, model_output = model.generate_motion(*args, **kwargs)
 
338
 
339
  return html_content, fbx_files
340
 
341
+
342
  if __name__ == "__main__":
343
  # python -m hymotion.utils.gradio_runtime
344
+ runtime = SimpleRuntime(
345
+ config_path="assets/config_simplified.yml",
346
+ ckpt_name="latest.ckpt",
347
+ load_prompt_engineering=False,
348
+ load_text_encoder=False,
349
+ )
350
+ print(runtime.pipeline)
hymotion/utils/t2m_runtime.py CHANGED
@@ -128,7 +128,6 @@ class T2MRuntime:
128
  device = torch.device("cpu")
129
  pipeline.load_in_demo(
130
  self.ckpt_name,
131
- os.path.dirname(self.ckpt_name),
132
  build_text_encoder=not self.skip_text,
133
  allow_empty_ckpt=allow_empty_ckpt,
134
  )
@@ -145,7 +144,6 @@ class T2MRuntime:
145
  )
146
  p.load_in_demo(
147
  self.ckpt_name,
148
- os.path.dirname(self.ckpt_name),
149
  build_text_encoder=not self.skip_text,
150
  allow_empty_ckpt=allow_empty_ckpt,
151
  )
@@ -238,6 +236,8 @@ class T2MRuntime:
238
  raise
239
  finally:
240
  self._release_pipeline(pi)
 
 
241
 
242
  def load_text_encoder(self) -> None:
243
  """
 
128
  device = torch.device("cpu")
129
  pipeline.load_in_demo(
130
  self.ckpt_name,
 
131
  build_text_encoder=not self.skip_text,
132
  allow_empty_ckpt=allow_empty_ckpt,
133
  )
 
144
  )
145
  p.load_in_demo(
146
  self.ckpt_name,
 
147
  build_text_encoder=not self.skip_text,
148
  allow_empty_ckpt=allow_empty_ckpt,
149
  )
 
236
  raise
237
  finally:
238
  self._release_pipeline(pi)
239
+ if torch.cuda.is_available():
240
+ torch.cuda.empty_cache()
241
 
242
  def load_text_encoder(self) -> None:
243
  """
requirements.txt CHANGED
@@ -3,11 +3,13 @@ huggingface_hub==0.30.0
3
 
4
  torch==2.5.1
5
  torchvision==0.20.1
 
6
  accelerate==0.30.1
7
  diffusers==0.26.3
8
  transformers==4.53.3
9
  einops==0.8.1
10
  safetensors==0.5.3
 
11
 
12
  numpy>=1.24.0,<2.0
13
  scipy>=1.10.0
@@ -20,5 +22,3 @@ requests==2.32.4
20
  openai==1.78.1
21
 
22
  fbxsdkpy==2020.1.post2
23
-
24
- torchdiffeq==0.2.5
 
3
 
4
  torch==2.5.1
5
  torchvision==0.20.1
6
+ torchdiffeq==0.2.5
7
  accelerate==0.30.1
8
  diffusers==0.26.3
9
  transformers==4.53.3
10
  einops==0.8.1
11
  safetensors==0.5.3
12
+ bitsandbytes==0.49.0
13
 
14
  numpy>=1.24.0,<2.0
15
  scipy>=1.10.0
 
22
  openai==1.78.1
23
 
24
  fbxsdkpy==2020.1.post2