chingshuai committed
Commit f6152b4 · 1 Parent(s): 848f72a
gradio_app.py CHANGED
@@ -1,4 +1,3 @@
- # we should use gradio==5.38.2
  import argparse
  import codecs as cs
  import json
@@ -10,47 +9,33 @@ import textwrap
  from typing import List, Optional, Tuple, Union

  import gradio as gr
- import torch
- from huggingface_hub import snapshot_download
-
- def try_to_download_model():
-     repo_id = "tencent/HY-Motion-1.0"
-     target_folder = "HY-Motion-1.0-Lite"
-     print(f">>> start download ", repo_id, target_folder)
-     local_dir = snapshot_download(
-         repo_id=repo_id,
-         allow_patterns=f"{target_folder}/*",
-         local_dir="./downloaded_models"
-     )
-     final_model_path = os.path.join(local_dir, target_folder)
-     print(f">>> Final model path: {final_model_path}")
-     return final_model_path
-
-
+ from hymotion.utils.gradio_runtime import ModelInference
+ from hymotion.utils.gradio_utils import try_to_download_model, try_to_download_text_encoder
+ from hymotion.utils.gradio_css import get_placeholder_html, APP_CSS, HEADER_BASE_MD, FOOTER_MD
  # Import spaces for Hugging Face Zero GPU support
- try:
-     import spaces
-     SPACES_AVAILABLE = True
- except ImportError:
-     SPACES_AVAILABLE = False
-     # Create a dummy decorator when spaces is not available
-     class spaces:
-         @staticmethod
-         def GPU(func=None, duration=None):
-             def decorator(fn):
-                 return fn
-             if func is not None:
-                 return func
-             return decorator
+ import spaces

- from hymotion.utils.t2m_runtime import T2MRuntime
+ # Apply @spaces.GPU decorator to run_inference method
+ # This ensures GPU operations happen in isolated subprocess
+ # Model loading and inference will occur in GPU subprocess, not main process
+ original_run_inference = ModelInference.run_inference

- NUM_WORKERS = torch.cuda.device_count() if torch.cuda.is_available() else 1
+ @spaces.GPU(duration=120)  # Request GPU for up to 120 seconds per inference
+ def gpu_run_inference(self, *args, **kwargs):
+     """
+     GPU-accelerated inference with Spaces decorator.

- # Global runtime instance for Zero GPU lazy loading
- _global_runtime = None
- _global_args = None
+     This function runs in a GPU subprocess where:
+     - Model is loaded and moved to GPU (safe)
+     - CUDA operations are allowed
+     - All CUDA tensors are moved to CPU before return (for pickle safety)
+     """
+     return original_run_inference(self, *args, **kwargs)
+
+ # Replace the original method with the GPU-decorated version
+ ModelInference.run_inference = gpu_run_inference
+
+ from hymotion.utils.t2m_runtime import T2MRuntime

  def _init_runtime_if_needed():
      """Initialize runtime lazily for Zero GPU support."""
@@ -81,267 +66,18 @@ def _init_runtime_if_needed():
          ckpt_name=ckpt,
          skip_text=skip_text,
          device_ids=None,
-         prompt_engineering_host=args.prompt_engineering_host,
          skip_model_loading=skip_model_loading,
+         disable_prompt_engineering=args.disable_prompt_engineering,
+         prompt_engineering_host=args.prompt_engineering_host,
+         prompt_engineering_model_path=args.prompt_engineering_model_path,
      )
      return _global_runtime

-
- @spaces.GPU(duration=120)
- def generate_motion_on_gpu(
-     text: str,
-     seeds_csv: str,
-     motion_duration: float,
-     cfg_scale: float,
-     output_format: str,
-     original_text: str,
-     output_dir: str,
- ) -> Tuple[str, List[str]]:
-     """
-     GPU-decorated function for motion generation.
-     This function will request GPU allocation on Hugging Face Zero GPU.
-     """
-     runtime = _init_runtime_if_needed()
-
-     html_content, fbx_files, _ = runtime.generate_motion(
-         text=text,
-         seeds_csv=seeds_csv,
-         duration=motion_duration,
-         cfg_scale=cfg_scale,
-         output_format=output_format,
-         original_text=original_text,
-         output_dir=output_dir,
-     )
-     return html_content, fbx_files
-
-
  # define data sources
  DATA_SOURCES = {
      "example_prompts": "examples/example_prompts/example_subset.json",
  }

- # create interface
- APP_CSS = """
- :root{
-     --primary-start:#667eea; --primary-end:#764ba2;
-     --secondary-start:#4facfe; --secondary-end:#00f2fe;
-     --accent-start:#f093fb; --accent-end:#f5576c;
-     --page-bg:linear-gradient(135deg,#f5f7fa 0%,#c3cfe2 100%);
-     --card-bg:linear-gradient(135deg,#ffffff 0%,#f8f9fa 100%);
-     --radius:12px;
-     --iframe-bg:#ffffff;
- }
-
- /* Dark mode variables */
- [data-theme="dark"], .dark {
-     --page-bg:linear-gradient(135deg,#1a1a1a 0%,#2d3748 100%);
-     --card-bg:linear-gradient(135deg,#2d3748 0%,#374151 100%);
-     --text-primary:#f7fafc;
-     --text-secondary:#e2e8f0;
-     --border-color:#4a5568;
-     --input-bg:#374151;
-     --input-border:#4a5568;
-     --iframe-bg:#1a1a2e;
- }
-
- /* Page and card */
- .gradio-container{
-     background:var(--page-bg) !important;
-     min-height:100vh !important;
-     color:var(--text-primary, #333) !important;
- }
-
- .main-header{
-     background:transparent !important; border:none !important; box-shadow:none !important;
-     padding:0 !important; margin:10px 0 16px !important;
-     text-align:center !important;
- }
-
- .main-header h1, .main-header p, .main-header li {
-     color:var(--text-primary, #333) !important;
- }
-
- .left-panel,.right-panel{
-     background:var(--card-bg) !important;
-     border:1px solid var(--border-color, #e9ecef) !important;
-     border-radius:15px !important;
-     box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
-     padding:24px !important;
- }
-
- .gradio-accordion{
-     border:1px solid var(--border-color, #e1e5e9) !important;
-     border-radius:var(--radius) !important;
-     margin:12px 0 !important; background:transparent !important;
- }
-
- .gradio-accordion summary{
-     background:transparent !important;
-     padding:14px 18px !important;
-     font-weight:600 !important;
-     color:var(--text-primary, #495057) !important;
- }
-
- .gradio-group{
-     background:transparent !important; border:none !important;
-     border-radius:8px !important; padding:12px 0 !important; margin:8px 0 !important;
- }
-
- /* Input class style - dark mode adaptation */
- .gradio-textbox input,.gradio-textbox textarea,.gradio-dropdown .wrap{
-     border-radius:8px !important;
-     border:2px solid var(--input-border, #e9ecef) !important;
-     background:var(--input-bg, #fff) !important;
-     color:var(--text-primary, #333) !important;
-     transition:.2s all !important;
- }
-
- .gradio-textbox input:focus,.gradio-textbox textarea:focus,.gradio-dropdown .wrap:focus-within{
-     border-color:var(--primary-start) !important;
-     box-shadow:0 0 0 3px rgba(102,126,234,.1) !important;
- }
-
- .gradio-slider input[type="range"]{
-     background:linear-gradient(to right,var(--primary-start),var(--primary-end)) !important;
-     border-radius:10px !important;
- }
-
- .gradio-checkbox input[type="checkbox"]{
-     border-radius:4px !important;
-     border:2px solid var(--input-border, #e9ecef) !important;
-     transition:.2s all !important;
- }
-
- .gradio-checkbox input[type="checkbox"]:checked{
-     background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important;
-     border-color:var(--primary-start) !important;
- }
-
- /* Label text color adaptation */
- .gradio-textbox label, .gradio-dropdown label, .gradio-slider label,
- .gradio-checkbox label, .gradio-html label {
-     color:var(--text-primary, #333) !important;
- }
-
- .gradio-textbox .info, .gradio-dropdown .info, .gradio-slider .info,
- .gradio-checkbox .info {
-     color:var(--text-secondary, #666) !important;
- }
-
- /* Status information - dark mode adaptation */
- .gradio-textbox[data-testid*="状态信息"] input{
-     background:var(--input-bg, linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%)) !important;
-     border:2px solid var(--input-border, #dee2e6) !important;
-     color:var(--text-primary, #495057) !important;
-     font-weight:500 !important;
- }
-
- /* Button base class and variant */
- .generate-button,.rewrite-button,.dice-button{
-     border:none !important; color:#fff !important; font-weight:600 !important;
-     border-radius:8px !important; transition:.3s all !important;
-     box-shadow:0 4px 15px rgba(0,0,0,.12) !important;
- }
-
- .generate-button{ background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important; }
- .rewrite-button{ background:linear-gradient(45deg,var(--secondary-start),var(--secondary-end)) !important; }
- .dice-button{
-     background:linear-gradient(45deg,var(--accent-start),var(--accent-end)) !important;
-     height:40px !important;
- }
-
- .generate-button:hover,.rewrite-button:hover{ transform:translateY(-2px) !important; }
- .dice-button:hover{
-     transform:scale(1.05) !important;
-     box-shadow:0 4px 12px rgba(240,147,251,.28) !important;
- }
-
- .dice-container{
-     display:flex !important;
-     align-items:flex-end !important;
-     justify-content:center !important;
- }
-
- /* Right panel clipping overflow, avoid double scrollbars */
- .right-panel{
-     background:var(--card-bg) !important;
-     border:1px solid var(--border-color, #e9ecef) !important;
-     border-radius:15px !important;
-     box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
-     padding:24px !important; overflow:hidden !important;
- }
-
- /* Main content row - ensure equal heights */
- .main-row {
-     display: flex !important;
-     align-items: stretch !important;
- }
-
- /* Flask area - match left panel height */
- .flask-display{
-     padding:0 !important; margin:0 !important; border:none !important;
-     box-shadow:none !important; background:var(--iframe-bg) !important;
-     border-radius:10px !important; position:relative !important;
-     height:100% !important; min-height:750px !important;
-     display:flex !important; flex-direction:column !important;
- }
-
- .flask-display iframe{
-     width:100% !important; flex:1 !important; min-height:750px !important;
-     border:none !important; border-radius:10px !important; display:block !important;
-     background:var(--iframe-bg) !important;
- }
-
- /* Right panel should stretch to match left panel */
- .right-panel{
-     background:var(--card-bg) !important;
-     border:1px solid var(--border-color, #e9ecef) !important;
-     border-radius:15px !important;
-     box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
-     padding:24px !important; overflow:hidden !important;
-     display:flex !important; flex-direction:column !important;
- }
-
- /* Ensure dropdown menu is visible in dark mode */
- [data-theme="dark"] .gradio-dropdown .wrap,
- .dark .gradio-dropdown .wrap {
-     background:var(--input-bg) !important;
-     color:var(--text-primary) !important;
- }
-
- [data-theme="dark"] .gradio-dropdown .option,
- .dark .gradio-dropdown .option {
-     background:var(--input-bg) !important;
-     color:var(--text-primary) !important;
- }
-
- [data-theme="dark"] .gradio-dropdown .option:hover,
- .dark .gradio-dropdown .option:hover {
-     background:var(--border-color) !important;
- }
-
- .footer{
-     text-align:center !important;
-     margin-top:20px !important;
-     padding:10px !important;
-     color:var(--text-secondary, #666) !important;
- }
- """
-
- HEADER_BASE_MD = "# HY-Motion-1.0: Text-to-Motion Playground"
-
- FOOTER_MD = "*This is a Beta version, any issues or feedback are welcome!*"
-
- HTML_OUTPUT_PLACEHOLDER = """
- <div style='height: 750px; width: 100%; border-radius: 8px; border-color: #e5e7eb; border-style: solid; border-width: 1px; display: flex; justify-content: center; align-items: center;'>
-     <div style='text-align: center; font-size: 16px; color: #6b7280;'>
-         <p style="color: #8d8d8d;">Welcome to HY-Motion-1.0!</p>
-         <p style="color: #8d8d8d;">No motion visualization here yet.</p>
-     </div>
- </div>
- """
-
-
  def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12):
      """Load examples from txt file."""

@@ -393,20 +129,12 @@ def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12


  class T2MGradioUI:
-     def __init__(self, runtime: T2MRuntime, args: argparse.Namespace):
-         self.runtime = runtime
-         self.args = args
-
-         # Check if rewrite is available:
-         # - prompt_engineering_host must be provided
-         # - disable_rewrite must not be set
-         print(f">>> args: {vars(args)}")
-         self.rewrite_available = (
-             args.prompt_engineering_host is not None
-             and args.prompt_engineering_host.strip() != ""
-             and not args.disable_rewrite
-         )
-
+     def __init__(self, args):
+         self.output_dir = args.output_dir
+         print(f"[{self.__class__.__name__}] output_dir: {self.output_dir}")
+         self.model_inference = ModelInference(args.model_path, use_prompt_engineering=args.use_prompt_engineering, use_text_encoder=args.use_text_encoder)
+         # self.args = args
+         self.prompt_engineering_available = args.use_prompt_engineering
          self.all_example_data = {}
          self._init_example_data()

@@ -449,7 +177,7 @@ class T2MGradioUI:
          else:
              print(f"\t>>> Using LLM to estimate duration/rewrite text...")
              try:
-                 predicted_duration, rewritten_text = self.runtime.rewrite_text_and_infer_time(text=text)
+                 predicted_duration, rewritten_text = self.model_inference.rewrite_text_and_infer_time(text=text)
              except Exception as e:
                  print(f"\t>>> Text rewriting/duration prediction failed: {e}")
                  return (
@@ -473,7 +201,7 @@
          cfg_scale: float,
      ) -> Tuple[str, List[str]]:
          # When rewrite is not available, use original_text directly
-         if not self.rewrite_available:
+         if not self.prompt_engineering_available:
              text_to_use = original_text.strip()
              if not text_to_use:
                  return "Error: Input text is empty, please enter text first", []
@@ -484,31 +212,30 @@

          try:
              # Use runtime from global if available (for Zero GPU), otherwise use self.runtime
-             runtime = _global_runtime if _global_runtime is not None else self.runtime
-             fbx_ok = getattr(runtime, "fbx_available", False)
+             fbx_ok = getattr(self.model_inference, "fbx_available", False)
              req_format = "fbx" if fbx_ok else "dict"

              # Use GPU-decorated function for Zero GPU support
-             html_content, fbx_files = generate_motion_on_gpu(
+             html_content, fbx_files = self.model_inference.run_inference(
                  text=text_to_use,
                  seeds_csv=seed_input,
-                 motion_duration=duration,
+                 duration=duration,
                  cfg_scale=cfg_scale,
                  output_format=req_format,
                  original_text=original_text,
-                 output_dir=self.args.output_dir,
+                 output_dir=self.output_dir
              )
              # Escape HTML content for srcdoc attribute
-             escaped_html = html_content.replace('"', '&quot;')
+             escaped_html = html_content.replace('"', "&quot;")
              # Return iframe with srcdoc - directly embed HTML content
-             iframe_html = f'''
+             iframe_html = f"""
              <iframe
                  srcdoc="{escaped_html}"
                  width="100%"
                  height="750px"
                  style="border: none; border-radius: 12px; box-shadow: 0 4px 20px rgba(0,0,0,0.1);"
              ></iframe>
-             '''
+             """
              return iframe_html, fbx_files
          except Exception as e:
              print(f"\t>>> Motion generation failed: {e}")
@@ -549,9 +276,14 @@
          # Left control panel
          with gr.Column(scale=2, elem_classes=["left-panel"]):
              # Input textbox
+             if self.prompt_engineering_available:
+                 input_place_holder = "Enter text to generate motion, support Chinese and English text input."
+             else:
+                 input_place_holder = "Enter text to generate motion, please use `A person ...` format to describe the motion"
+
              self.text_input = gr.Textbox(
                  label="📝 Input Text",
-                 placeholder="Enter text to generate motion, support Chinese and English text input.",
+                 placeholder=input_place_holder,
              )
              # Rewritten textbox
              self.rewritten_text = gr.Textbox(
@@ -572,7 +304,7 @@

              # Execute buttons
              with gr.Row():
-                 if self.rewrite_available:
+                 if self.prompt_engineering_available:
                      self.rewrite_btn = gr.Button(
                          "🔄 Rewrite Text",
                          variant="secondary",
@@ -595,17 +327,14 @@
                  variant="primary",
                  size="lg",
                  elem_classes=["generate-button"],
-                 interactive=not self.rewrite_available,  # Enable directly if rewrite not available
+                 interactive=not self.prompt_engineering_available,  # Enable directly if rewrite not available
              )

-             if not self.rewrite_available:
+             if not self.prompt_engineering_available:
                  gr.Markdown(
                      "> ⚠️ **Prompt engineering is not available.** Text rewriting and duration estimation are disabled. Your input text and duration will be used directly."
                  )

-             # Advanced settings
-             with gr.Accordion("🔧 Advanced Settings", open=False):
-                 self._build_advanced_settings()

              # Example selection dropdown
              self.example_dropdown = gr.Dropdown(
@@ -616,8 +345,12 @@
                  interactive=True,
              )

+             # Advanced settings
+             with gr.Accordion("🔧 Advanced Settings", open=False):
+                 self._build_advanced_settings()
+
              # Status message depends on whether rewrite is available
-             if self.rewrite_available:
+             if self.prompt_engineering_available:
                  status_msg = "Please click the [🔄 Rewrite Text] button to rewrite the text first"
              else:
                  status_msg = "Enter your text and click [🚀 Generate Motion] directly."
@@ -629,7 +362,7 @@

              # FBX Download section
              with gr.Row(visible=False) as self.fbx_download_row:
-                 if getattr(self.runtime, "fbx_available", False):
+                 if getattr(self.model_inference, "fbx_available", False):
                      self.fbx_files = gr.File(
                          label="📦 Download FBX Files",
                          file_count="multiple",
@@ -641,7 +374,7 @@
          # Right display area
          with gr.Column(scale=3):
              self.output_display = gr.HTML(
-                 value=HTML_OUTPUT_PLACEHOLDER,
+                 value=get_placeholder_html(),
                  show_label=False,
                  elem_classes=["flask-display"]
              )
@@ -655,7 +388,7 @@

      def _build_advanced_settings(self):
          # Only show rewrite options if rewrite is available
-         if self.rewrite_available:
+         if self.prompt_engineering_available:
              with gr.Group():
                  gr.Markdown("### 🔄 Text Rewriting Options")
                  with gr.Row():
@@ -730,7 +463,7 @@
          )

          # Rewrite text logic (only bind when rewrite is available)
-         if self.rewrite_available:
+         if self.prompt_engineering_available:
              self.rewrite_btn.click(fn=lambda: "Rewriting text, please wait...", outputs=[self.status_output]).then(
                  self._prompt_engineering,
                  inputs=[
@@ -750,7 +483,7 @@

          # Generate motion logic
          self.generate_btn.click(
-             fn=lambda: "Generating motion, please wait... (It takes some extra time to start the renderer for the first generation)",
+             fn=lambda: "Generating motion, please wait... (It takes some extra time for the first generation)",
              outputs=[self.status_output],
          ).then(
              self._generate_motion,
@@ -761,8 +494,7 @@
                  self.duration_slider,
                  self.cfg_slider,
              ],
-             outputs=[self.output_display, self.fbx_files],
-             concurrency_limit=NUM_WORKERS,
+             outputs=[self.output_display, self.fbx_files]
          ).then(
              fn=lambda fbx_list: (
                  (
@@ -777,7 +509,7 @@
          )

          # Reset logic - different behavior based on rewrite availability
-         if self.rewrite_available:
+         if self.prompt_engineering_available:
              self.text_input.change(
                  fn=lambda: (
                      gr.update(visible=False),
@@ -802,7 +534,7 @@
              outputs=[self.rewritten_text, self.generate_btn, self.status_output],
          )
          # Only bind rewritten_text change when rewrite is available
-         if self.rewrite_available:
+         if self.prompt_engineering_available:
              self.rewritten_text.change(
                  fn=lambda text: (
                      gr.update(interactive=bool(text.strip())),
@@ -819,16 +551,17 @@

  def create_demo(final_model_path):
      """Create the Gradio demo with Zero GPU support."""
-     global _global_runtime, _global_args

      class Args:
          model_path = final_model_path
          output_dir = "output/gradio"
+         use_prompt_engineering = False
+         use_text_encoder = True
          prompt_engineering_host = os.environ.get("PROMPT_HOST", None)
-         disable_rewrite = False
+         prompt_engineering_model_path = os.environ.get("PROMPT_MODEL_PATH", None)
+         disable_prompt_engineering = os.environ.get("DISABLE_PROMPT_ENGINEERING", False)

      args = Args()
-     _global_args = args  # Set global args for lazy loading

      # Check required files:
      cfg = osp.join(args.model_path, "config.yml")
@@ -841,55 +574,19 @@ def create_demo(final_model_path):

      # For Zero GPU: Don't load model at startup, use lazy loading
      # Create a minimal runtime for UI initialization (without model loading)
-     if SPACES_AVAILABLE:
-         print(">>> Hugging Face Spaces detected. Using Zero GPU lazy loading.")
-         print(">>> Model will be loaded on first GPU request.")
-
-         # Create a placeholder runtime with minimal initialization for UI
-         class PlaceholderRuntime:
-             def __init__(self):
-                 self.fbx_available = False
-                 self.prompt_engineering_host = args.prompt_engineering_host
-
-             def rewrite_text_and_infer_time(self, text: str):
-                 # For prompt rewriting, we don't need GPU
-                 from hymotion.prompt_engineering.prompt_rewrite import PromptRewriter
-                 rewriter = PromptRewriter(host=self.prompt_engineering_host)
-                 return rewriter.rewrite_prompt_and_infer_time(text)
-
-         runtime = PlaceholderRuntime()
-     else:
-         # Local development: load model immediately
-         print(">>> Local environment detected. Loading model at startup.")
-         skip_model_loading = False
-         if not os.path.exists(ckpt):
-             print(f">>> [WARNING] Checkpoint file not found: {ckpt}")
-             print(f">>> [WARNING] Model loading will be skipped. Motion generation will not be available.")
-             skip_model_loading = True
-
-         print(">>> Initializing T2MRuntime...")
-         if "USE_HF_MODELS" not in os.environ:
-             os.environ["USE_HF_MODELS"] = "1"
-
-         skip_text = False
-         runtime = T2MRuntime(
-             config_path=cfg,
-             ckpt_name=ckpt,
-             skip_text=skip_text,
-             device_ids=None,
-             prompt_engineering_host=args.prompt_engineering_host,
-             skip_model_loading=skip_model_loading,
-         )
-         _global_runtime = runtime  # Set global runtime for GPU function
-
-     ui = T2MGradioUI(runtime=runtime, args=args)
+     ui = T2MGradioUI(args=args)
      demo = ui.build_ui()
      return demo


  # Create demo at module level for Hugging Face Spaces
- final_model_path = try_to_download_model()
- demo = create_demo(final_model_path)
+ # Pre-download text encoder models first (without loading)
+

  if __name__ == "__main__":
+     # Create demo at module level for Hugging Face Spaces
+     try_to_download_text_encoder()
+     # Then download the main model
+     final_model_path = try_to_download_model()
+     demo = create_demo(final_model_path)
      demo.launch()
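Note on the pattern above: instead of a standalone `@spaces.GPU` function, the commit re-binds `ModelInference.run_inference` to a decorated wrapper so every call runs in a Zero GPU subprocess. Below is a minimal, self-contained sketch of that method-patching pattern (not part of the commit): the `Worker` class is a hypothetical stand-in for `ModelInference`, and a no-op decorator is substituted when `spaces` is unavailable locally.

```python
# Sketch only: `Worker.run` stands in for ModelInference.run_inference.
try:
    import spaces
    gpu = spaces.GPU(duration=120)  # decorator factory, as used in the diff
except ImportError:
    def gpu(fn):  # local fallback: identity decorator
        return fn

class Worker:
    def run(self, x):
        return x * 2  # placeholder for the real GPU-bound work

_original_run = Worker.run  # keep a reference to the undecorated method

@gpu
def _gpu_run(self, *args, **kwargs):
    # On Spaces this body executes in an isolated GPU subprocess;
    # locally it is a plain function call.
    return _original_run(self, *args, **kwargs)

Worker.run = _gpu_run  # re-bind: all instances now go through the wrapper

if __name__ == "__main__":
    print(Worker().run(21))  # -> 42
```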
hymotion/prompt_engineering/prompt_rewrite.py CHANGED
@@ -13,8 +13,10 @@ import uuid
  from dataclasses import dataclass
  from typing import Any, Dict, List, Literal, Optional, Tuple, Union

+ import torch
  from openai import OpenAI
  from requests import exceptions as req_exc
+ from transformers import AutoModelForCausalLM, AutoTokenizer

  from .model_constants import REWRITE_AND_INFER_TIME_PROMPT_FORMAT

@@ -242,18 +244,39 @@ class ResponseParser:


  class PromptRewriter:
-     def __init__(self, host: Optional[str] = None, parser: Optional[ResponseParser] = None):
+     def __init__(
+         self, host: Optional[str] = None, model_path: Optional[str] = None, parser: Optional[ResponseParser] = None, device="auto"
+     ):
          self.parser = parser or ResponseParser()
          self.logger = logging.getLogger(__name__)
-         self.api = OpenAIChatApi(
-             ApiConfig(
-                 host=host,
-                 user="",
-                 apikey="EMPTY",
-                 model="Qwen3-30B-A3B-SFT",
-                 api_version="",
-             )
-         )
+         self.host = host
+         if host:
+             self.api = OpenAIChatApi(
+                 ApiConfig(
+                     host=host,
+                     user="",
+                     apikey="EMPTY",
+                     model="Qwen3-30B-A3B-SFT",
+                     api_version="",
+                 )
+             )
+         else:
+             self.model_path = model_path or "Text2MotionPrompter/Text2MotionPrompter"
+             self.tokenizer = None
+             self.model = None
+             self._load_model(device)
+
+     def _load_model(self, device="auto"):
+         if self.model is None:
+             print(f">>> Loading prompter model from {self.model_path}")
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.model_path,
+                 torch_dtype=torch.float16,
+                 device_map=device,
+                 load_in_4bit=True,
+             )
+             self.model.eval()

      def rewrite_prompt_and_infer_time(
          self,
@@ -261,17 +284,36 @@
          prompt_format: str = REWRITE_AND_INFER_TIME_PROMPT_FORMAT,
          retry_config: Optional[RetryConfig] = None,
      ) -> Tuple[float, str]:
-         self.logger.info("Start rewriting prompt...")
-         try:
-             result, cost, elapsed = self.parser.call_data_eval_with_retry(
-                 self.api, prompt_format.format(text), retry_config
-             )
-             self.logger.info(f"Rewriting completed - cost: {cost:.6f}, time: {elapsed:.2f}s")
-             return round(float(result["duration"]) / 30.0, 2), result["short_caption"]
-
-         except Exception as e:
-             self.logger.error(f"Prompt rewriting failed: {e}")
-             raise
+         if self.host:
+             self.logger.info("Start rewriting prompt...")
+             try:
+                 result, cost, elapsed = self.parser.call_data_eval_with_retry(
+                     self.api, prompt_format.format(text), retry_config
+                 )
+                 self.logger.info(f"Rewriting completed - cost: {cost:.6f}, time: {elapsed:.2f}s")
+                 return round(float(result["duration"]) / 30.0, 2), result["short_caption"]
+
+             except Exception as e:
+                 self.logger.error(f"Prompt rewriting failed: {e}")
+                 raise
+         else:
+             messages = [{"role": "user", "content": prompt_format.format(text)}]
+             full_prompt = self.tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True,
+             )
+             inputs = self.tokenizer([full_prompt], return_tensors="pt").to(self.model.device)
+             with torch.no_grad():
+                 outputs = self.model.generate(**inputs, max_new_tokens=8192)
+             response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:].tolist(), skip_special_tokens=True)

+             try:
+                 json_str = re.search(r"\{.*\}", response, re.DOTALL).group()
+                 result = json.loads(json_str)
+                 return round(float(result["duration"]) / 30.0, 2), result["short_caption"]
+             except:
+                 return 5.0, text


  if __name__ == "__main__":
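The reworked `PromptRewriter` now has two code paths: with `host` set it keeps the original OpenAI-compatible API client, and with `host=None` it loads a local causal LM (`Text2MotionPrompter/Text2MotionPrompter` by default, fp16 with `load_in_4bit=True`, which requires `bitsandbytes` and a GPU) and parses a JSON object out of the generation. A hedged usage sketch (not from the commit; the endpoint URL is hypothetical):

```python
from hymotion.prompt_engineering.prompt_rewrite import PromptRewriter

# Remote path: rewrite via an OpenAI-compatible server (hypothetical URL).
remote = PromptRewriter(host="http://localhost:8000/v1")

# Local path: host=None triggers _load_model() with the default checkpoint.
local = PromptRewriter(host=None, model_path=None, device="auto")

# Returns (duration_seconds, rewritten_caption); duration is frames / 30.0,
# with a (5.0, original_text) fallback if the local model emits no valid JSON.
duration_s, caption = local.rewrite_prompt_and_infer_time("A person waves with the right hand.")
print(duration_s, caption)
```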
hymotion/utils/gradio_css.py ADDED
@@ -0,0 +1,250 @@
+ import os.path as osp
+
+ # create interface
+ APP_CSS = """
+ :root{
+     --primary-start:#667eea; --primary-end:#764ba2;
+     --secondary-start:#4facfe; --secondary-end:#00f2fe;
+     --accent-start:#f093fb; --accent-end:#f5576c;
+     --page-bg:linear-gradient(135deg,#f5f7fa 0%,#c3cfe2 100%);
+     --card-bg:linear-gradient(135deg,#ffffff 0%,#f8f9fa 100%);
+     --radius:12px;
+     --iframe-bg:#ffffff;
+ }
+
+ /* Dark mode variables */
+ [data-theme="dark"], .dark {
+     --page-bg:linear-gradient(135deg,#1a1a1a 0%,#2d3748 100%);
+     --card-bg:linear-gradient(135deg,#2d3748 0%,#374151 100%);
+     --text-primary:#f7fafc;
+     --text-secondary:#e2e8f0;
+     --border-color:#4a5568;
+     --input-bg:#374151;
+     --input-border:#4a5568;
+     --iframe-bg:#1a1a2e;
+ }
+
+ /* Page and card */
+ .gradio-container{
+     background:var(--page-bg) !important;
+     min-height:100vh !important;
+     color:var(--text-primary, #333) !important;
+ }
+
+ .main-header{
+     background:transparent !important; border:none !important; box-shadow:none !important;
+     padding:0 !important; margin:10px 0 16px !important;
+     text-align:center !important;
+ }
+
+ .main-header h1, .main-header p, .main-header li {
+     color:var(--text-primary, #333) !important;
+ }
+
+ .left-panel,.right-panel{
+     background:var(--card-bg) !important;
+     border:1px solid var(--border-color, #e9ecef) !important;
+     border-radius:15px !important;
+     box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
+     padding:24px !important;
+ }
+
+ .gradio-accordion{
+     border:1px solid var(--border-color, #e1e5e9) !important;
+     border-radius:var(--radius) !important;
+     margin:12px 0 !important; background:transparent !important;
+ }
+
+ .gradio-accordion summary{
+     background:transparent !important;
+     padding:14px 18px !important;
+     font-weight:600 !important;
+     color:var(--text-primary, #495057) !important;
+ }
+
+ .gradio-group{
+     background:transparent !important; border:none !important;
+     border-radius:8px !important; padding:12px 0 !important; margin:8px 0 !important;
+ }
+
+ /* Input class style - dark mode adaptation */
+ .gradio-textbox input,.gradio-textbox textarea,.gradio-dropdown .wrap{
+     border-radius:8px !important;
+     border:2px solid var(--input-border, #e9ecef) !important;
+     background:var(--input-bg, #fff) !important;
+     color:var(--text-primary, #333) !important;
+     transition:.2s all !important;
+ }
+
+ .gradio-textbox input:focus,.gradio-textbox textarea:focus,.gradio-dropdown .wrap:focus-within{
+     border-color:var(--primary-start) !important;
+     box-shadow:0 0 0 3px rgba(102,126,234,.1) !important;
+ }
+
+ .gradio-slider input[type="range"]{
+     background:linear-gradient(to right,var(--primary-start),var(--primary-end)) !important;
+     border-radius:10px !important;
+ }
+
+ .gradio-checkbox input[type="checkbox"]{
+     border-radius:4px !important;
+     border:2px solid var(--input-border, #e9ecef) !important;
+     transition:.2s all !important;
+ }
+
+ .gradio-checkbox input[type="checkbox"]:checked{
+     background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important;
+     border-color:var(--primary-start) !important;
+ }
+
+ /* Label text color adaptation */
+ .gradio-textbox label, .gradio-dropdown label, .gradio-slider label,
+ .gradio-checkbox label, .gradio-html label {
+     color:var(--text-primary, #333) !important;
+ }
+
+ .gradio-textbox .info, .gradio-dropdown .info, .gradio-slider .info,
+ .gradio-checkbox .info {
+     color:var(--text-secondary, #666) !important;
+ }
+
+ /* Status information - dark mode adaptation */
+ .gradio-textbox[data-testid*="状态信息"] input{
+     background:var(--input-bg, linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%)) !important;
+     border:2px solid var(--input-border, #dee2e6) !important;
+     color:var(--text-primary, #495057) !important;
+     font-weight:500 !important;
+ }
+
+ /* Button base class and variant */
+ .generate-button,.rewrite-button,.dice-button{
+     border:none !important; color:#fff !important; font-weight:600 !important;
+     border-radius:8px !important; transition:.3s all !important;
+     box-shadow:0 4px 15px rgba(0,0,0,.12) !important;
+ }
+
+ .generate-button{ background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important; }
+ .rewrite-button{ background:linear-gradient(45deg,var(--secondary-start),var(--secondary-end)) !important; }
+ .dice-button{
+     background:linear-gradient(45deg,var(--accent-start),var(--accent-end)) !important;
+     height:40px !important;
+ }
+
+ .generate-button:hover,.rewrite-button:hover{ transform:translateY(-2px) !important; }
+ .dice-button:hover{
+     transform:scale(1.05) !important;
+     box-shadow:0 4px 12px rgba(240,147,251,.28) !important;
+ }
+
+ .dice-container{
+     display:flex !important;
+     align-items:flex-end !important;
+     justify-content:center !important;
+ }
+
+ /* Right panel clipping overflow, avoid double scrollbars */
+ .right-panel{
+     background:var(--card-bg) !important;
+     border:1px solid var(--border-color, #e9ecef) !important;
+     border-radius:15px !important;
+     box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
+     padding:24px !important; overflow:hidden !important;
+ }
+
+ /* Main content row - ensure equal heights */
+ .main-row {
+     display: flex !important;
+     align-items: stretch !important;
+ }
+
+ /* Flask area - match left panel height */
+ .flask-display{
+     padding:0 !important; margin:0 !important; border:none !important;
+     box-shadow:none !important; background:var(--iframe-bg) !important;
+     border-radius:10px !important; position:relative !important;
+     height:100% !important; min-height:750px !important;
+     display:flex !important; flex-direction:column !important;
+ }
+
+ .flask-display iframe{
+     width:100% !important; flex:1 !important; min-height:750px !important;
+     border:none !important; border-radius:10px !important; display:block !important;
+     background:var(--iframe-bg) !important;
+ }
+
+ /* Right panel should stretch to match left panel */
+ .right-panel{
+     background:var(--card-bg) !important;
+     border:1px solid var(--border-color, #e9ecef) !important;
+     border-radius:15px !important;
+     box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
+     padding:24px !important; overflow:hidden !important;
+     display:flex !important; flex-direction:column !important;
+ }
+
+ /* Ensure dropdown menu is visible in dark mode */
+ [data-theme="dark"] .gradio-dropdown .wrap,
+ .dark .gradio-dropdown .wrap {
+     background:var(--input-bg) !important;
+     color:var(--text-primary) !important;
+ }
+
+ [data-theme="dark"] .gradio-dropdown .option,
+ .dark .gradio-dropdown .option {
+     background:var(--input-bg) !important;
+     color:var(--text-primary) !important;
+ }
+
+ [data-theme="dark"] .gradio-dropdown .option:hover,
+ .dark .gradio-dropdown .option:hover {
+     background:var(--border-color) !important;
+ }
+
+ .footer{
+     text-align:center !important;
+     margin-top:20px !important;
+     padding:10px !important;
+     color:var(--text-secondary, #666) !important;
+ }
+ """
+
+ HEADER_BASE_MD = "# HY-Motion-1.0: Text-to-Motion Playground\n### *Tencent Hunyuan 3D Digital Human Team*"
+
+ FOOTER_MD = "*This is a Beta version, any issues or feedback are welcome!*"
+
+ # Path to placeholder scene HTML template
+ PLACEHOLDER_SCENE_TEMPLATE = osp.join(osp.dirname(__file__), "..", "..", "scripts/gradio/templates/placeholder_scene.html")
+
+
+ def get_placeholder_html() -> str:
+     """
+     Load the placeholder scene HTML and wrap it in an iframe for display.
+     Returns an iframe HTML string with the embedded placeholder scene.
+     """
+     try:
+         with open(PLACEHOLDER_SCENE_TEMPLATE, "r", encoding="utf-8") as f:
+             html_content = f.read()
+         # Escape HTML content for srcdoc attribute
+         escaped_html = html_content.replace('"', '&quot;')
+         iframe_html = f'''
+         <iframe
+             srcdoc="{escaped_html}"
+             width="100%"
+             height="750px"
+             style="border: none; border-radius: 12px; box-shadow: 0 4px 20px rgba(0,0,0,0.1);"
+         ></iframe>
+         '''
+         return iframe_html
+     except Exception as e:
+         print(f">>> Failed to load placeholder scene HTML: {e}")
+         # Fallback to simple placeholder
+         return """
+         <div style='height: 750px; width: 100%; border-radius: 8px; border-color: #e5e7eb; border-style: solid; border-width: 1px; display: flex; justify-content: center; align-items: center; background: #424242;'>
+             <div style='text-align: center; font-size: 16px; color: #a0aec0;'>
+                 <p>Welcome to HY-Motion-1.0!</p>
+                 <p>Enter a text description and generate motion to see the 3D visualization here.</p>
+             </div>
+         </div>
+         """
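`get_placeholder_html()` embeds a full HTML page into Gradio via `<iframe srcdoc="...">` and escapes only double quotes, which is sufficient because `srcdoc` is a double-quoted attribute here. A stricter variant (a sketch, not from the commit) would escape `&` first so literal entities in the template survive the round trip:

```python
# Hedged sketch of a stricter srcdoc escape than the one in get_placeholder_html().
def wrap_in_iframe(html: str, height_px: int = 750) -> str:
    # Escape '&' before '"' so existing entities are not corrupted.
    escaped = html.replace("&", "&amp;").replace('"', "&quot;")
    return (
        f'<iframe srcdoc="{escaped}" width="100%" height="{height_px}px" '
        f'style="border: none; border-radius: 12px;"></iframe>'
    )

print(wrap_in_iframe("<p>Hello &amp; welcome</p>"))
```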
hymotion/utils/gradio_runtime.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import threading
3
+ import time
4
+ import uuid
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import yaml
9
+
10
+ from ..prompt_engineering.prompt_rewrite import PromptRewriter
11
+ from .loaders import load_object
12
+ from .visualize_mesh_web import save_visualization_data, generate_static_html_content
13
+
14
+ try:
15
+ import fbx
16
+
17
+ FBX_AVAILABLE = True
18
+ print(">>> FBX module found.")
19
+ except ImportError:
20
+ FBX_AVAILABLE = False
21
+ print(">>> FBX module not found.")
22
+
23
+
24
+ def _now():
25
+ t = time.time()
26
+ ms = int((t - int(t)) * 1000)
27
+ return time.strftime("%Y%m%d_%H%M%S", time.localtime(t)) + f"{ms:03d}"
28
+
29
+ _MODEL_CACHE = None
30
+
31
+
32
+ class SimpleRuntime(torch.nn.Module):
33
+ def __init__(self, config_path, ckpt_name, load_prompt_engineering=False, load_text_encoder=False):
34
+ super().__init__()
35
+ self.load_prompt_engineering = load_prompt_engineering
36
+ self.load_text_encoder = load_text_encoder
37
+ # prompt engineering
38
+ if self.load_prompt_engineering:
39
+ print(f"[{self.__class__.__name__}] Loading prompt engineering...")
40
+ self.prompt_rewriter = PromptRewriter(
41
+ host=None, model_path=None, device="cpu"
42
+ )
43
+ else:
44
+ self.prompt_rewriter = None
45
+ # text encoder
46
+ if self.load_text_encoder:
47
+ print(f"[{self.__class__.__name__}] Loading text encoder...")
48
+ _text_encoder_module = "hymotion/network/text_encoders/text_encoder.HYTextModel"
49
+ _text_encoder_cfg = {
50
+ "llm_type": "qwen3",
51
+ "max_length_llm": 128
52
+ }
53
+ text_encoder = load_object(_text_encoder_module, _text_encoder_cfg)
54
+ else:
55
+ text_encoder = None
56
+ # 2. load model
57
+ print(f"[{self.__class__.__name__}] Loading model...")
58
+ with open(config_path, "r") as f:
59
+ config = yaml.load(f, Loader=yaml.FullLoader)
60
+ pipeline = load_object(
61
+ config["train_pipeline"],
62
+ config["train_pipeline_args"],
63
+ network_module=config["network_module"],
64
+ network_module_args=config["network_module_args"],
65
+ )
66
+ print(f"[{self.__class__.__name__}] Loading ckpt: {ckpt_name}")
67
+ pipeline.load_in_demo(
68
+ os.path.join(os.path.dirname(config_path), ckpt_name),
69
+ "stats",
70
+ build_text_encoder=False,
71
+ allow_empty_ckpt=False,
72
+ )
73
+ pipeline.text_encoder = text_encoder
74
+ self.pipeline = pipeline
75
+ #
76
+ self.fbx_available = FBX_AVAILABLE
77
+ if self.fbx_available:
78
+ try:
79
+ from .smplh2woodfbx import SMPLH2WoodFBX
80
+
81
+ self.fbx_converter = SMPLH2WoodFBX()
82
+ except Exception as e:
83
+ print(f">>> Failed to initialize FBX converter: {e}")
84
+ self.fbx_available = False
85
+ self.fbx_converter = None
86
+ else:
87
+ self.fbx_converter = None
88
+ print(">>> FBX module not found. FBX export will be disabled.")
89
+
90
+
91
+ def _generate_html_content(
92
+ self,
93
+ timestamp: str,
94
+ file_path: str,
95
+ output_dir: Optional[str] = None,
96
+ ) -> str:
97
+ """
98
+ Generate static HTML content with embedded data for iframe srcdoc.
99
+ All JavaScript code is embedded directly in the HTML, no external static resources needed.
100
+
101
+ Args:
102
+ timestamp: Timestamp string for logging
103
+ file_path: Base filename (without extension)
104
+ output_dir: Directory where NPZ/meta files are stored
105
+
106
+ Returns:
107
+ HTML content string (to be used in iframe srcdoc)
108
+ """
109
+ print(f">>> Generating static HTML content, timestamp: {timestamp}")
110
+ gradio_dir = output_dir if output_dir is not None else "output/gradio"
111
+
112
+ try:
113
+ # Generate static HTML content with embedded data (all JS is embedded in template)
114
+ html_content = generate_static_html_content(
115
+ folder_name=gradio_dir,
116
+ file_name=file_path,
117
+ hide_captions=False,
118
+ )
119
+
120
+ print(f">>> Static HTML content generated for: {file_path}")
121
+ return html_content
122
+
123
+ except Exception as e:
124
+ print(f">>> Failed to generate static HTML content: {e}")
125
+ import traceback
126
+
127
+ traceback.print_exc()
128
+ # Return error HTML
129
+ return f"<html><body><h1>Error generating visualization</h1><p>{str(e)}</p></body></html>"
130
+
131
+
132
+ def _generate_fbx_files(
133
+ self,
134
+ visualization_data: dict,
135
+ output_dir: Optional[str] = None,
136
+ fbx_filename: Optional[str] = None,
137
+ ) -> List[str]:
138
+ assert "smpl_data" in visualization_data, "smpl_data not found in visualization_data"
139
+ fbx_files = []
140
+ if output_dir is None:
141
+ root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
142
+ output_dir = os.path.join(root_dir, "output", "gradio")
143
+
144
+ smpl_data_list = visualization_data["smpl_data"]
145
+
146
+ unique_id = str(uuid.uuid4())[:8]
147
+ text = visualization_data["text"]
148
+ timestamp = visualization_data["timestamp"]
149
+ for bb in range(len(smpl_data_list)):
150
+ smpl_data = smpl_data_list[bb]
151
+ if fbx_filename is None:
152
+ fbx_filename_bb = f"{timestamp}_{unique_id}_{bb:03d}.fbx"
153
+ else:
154
+ fbx_filename_bb = f"{fbx_filename}_{bb:03d}.fbx"
155
+ fbx_path = os.path.join(output_dir, fbx_filename_bb)
156
+ success = self.fbx_converter.convert_npz_to_fbx(smpl_data, fbx_path)
157
+ if success:
158
+ fbx_files.append(fbx_path)
159
+ print(f"\t>>> FBX file generated: {fbx_path}")
160
+ txt_path = fbx_path.replace(".fbx", ".txt")
161
+ with open(txt_path, "w", encoding="utf-8") as f:
162
+ f.write(text)
163
+ fbx_files.append(txt_path)
164
+
165
+ return fbx_files
166
+
167
+ def generate_motion(
168
+ self,
169
+ text: str,
170
+ seeds_csv: str,
171
+ duration: float,
172
+ cfg_scale: float,
173
+ output_format: str = "fbx",
174
+ output_dir: Optional[str] = None,
175
+ output_filename: Optional[str] = None,
176
+ original_text: Optional[str] = None,
177
+ use_special_game_feat: bool = False,
178
+ ) -> Tuple[Union[str, list[str]], dict]:
179
+ seeds = [int(s.strip()) for s in seeds_csv.split(",") if s.strip() != ""]
180
+
181
+ print(f"[{self.__class__.__name__}] Generating motion...")
182
+ print(f"[{self.__class__.__name__}] text: {text}")
183
+ if self.load_prompt_engineering:
184
+ duration, rewritten_text = self.prompt_rewriter.rewrite_prompt_and_infer_time(f"{text}")
185
+ else:
186
+ rewritten_text = text
187
+ duration = duration
188
+
189
+ pipeline = self.pipeline
190
+ pipeline.eval()
191
+
192
+ # When skip_text=True (debug mode), use blank text features
193
+ if not self.load_text_encoder:
194
+ print(">>> [Debug Mode] Using blank text features (skip_text=True)")
195
+ device = next(pipeline.parameters()).device
196
+ batch_size = len(seeds) if seeds else 1
197
+ # Create blank hidden_state_dict using null features
198
+ hidden_state_dict = {
199
+ "text_vec_raw": pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device),
200
+ "text_ctxt_raw": pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device),
201
+ "text_ctxt_raw_length": torch.tensor([1] * batch_size, device=device),
202
+ }
203
+ # Disable CFG in debug mode (use cfg_scale=1.0)
204
+ model_output = pipeline.generate(
205
+ rewritten_text,
206
+ seeds,
207
+ duration,
208
+ cfg_scale=1.0,
209
+ use_special_game_feat=False,
210
+ hidden_state_dict=hidden_state_dict,
211
+ )
212
+ else:
213
+ model_output = pipeline.generate(
214
+ rewritten_text, seeds, duration, cfg_scale=cfg_scale, use_special_game_feat=use_special_game_feat
215
+ )
216
+
217
+ ts = _now()
218
+ save_data, base_filename = save_visualization_data(
219
+ output=model_output,
220
+ text=text if original_text is None else original_text,
221
+ rewritten_text=rewritten_text,
222
+ timestamp=ts,
223
+ output_dir=output_dir,
224
+             output_filename=output_filename,
+         )
+
+         html_content = self._generate_html_content(
+             timestamp=ts,
+             file_path=base_filename,
+             output_dir=output_dir,
+         )
+
+         if output_format == "fbx" and not self.fbx_available:
+             print(">>> Warning: FBX export requested but the FBX SDK is not available. Falling back to dict format.")
+             output_format = "dict"
+
+         if output_format == "fbx" and self.fbx_available:
+             fbx_files = self._generate_fbx_files(
+                 visualization_data=save_data,
+                 output_dir=output_dir,
+                 fbx_filename=output_filename,
+             )
+             return html_content, fbx_files, model_output
+         elif output_format == "dict":
+             # Return the HTML content and an empty fbx_files list for the dict format
+             return html_content, [], model_output
+         else:
+             raise ValueError(f">>> Invalid output format: {output_format}")
+
+
+ class ModelInference:
+     """Handles model inference and data processing for HY-Motion.
+
+     Note: the model is deliberately not stored in an instance variable, to
+     avoid cross-process state issues with the @spaces.GPU decorator; it is
+     cached in a module-level global instead.
+     """
+
+     def __init__(self, model_path, use_prompt_engineering, use_text_encoder):
+         self.model_path = model_path
+         self.use_prompt_engineering = use_prompt_engineering
+         self.use_text_encoder = use_text_encoder
+         self.fbx_available = FBX_AVAILABLE
+
+     def initialize_model(self, device: str = "cuda"):
+         """Initialize the HY-Motion runtime using the global cache.
+
+         Optimization: load the model to CPU first, then move it to the GPU
+         when needed; this is faster than reloading from disk each time.
+         Using a global variable is safe because @spaces.GPU runs in an
+         isolated subprocess, each with its own global namespace.
+
+         Args:
+             device: Device to run inference on (the model is moved there).
+
+         Returns:
+             Model instance ready for inference on the specified device.
+         """
+         global _MODEL_CACHE
+
+         if _MODEL_CACHE is None:
+             # First load in this subprocess: build the runtime once, then
+             # keep it cached on CPU so later calls can reuse it.
+             _MODEL_CACHE = SimpleRuntime(
+                 config_path=os.path.join(self.model_path, "config.yml"),
+                 ckpt_name="latest.ckpt",
+                 load_prompt_engineering=self.use_prompt_engineering,
+                 load_text_encoder=self.use_text_encoder,
+             )
+             _MODEL_CACHE = _MODEL_CACHE.to("cpu")
+             _MODEL_CACHE.eval()
+             print("✅ Model loaded to CPU memory (cached in subprocess)")
+
+         # Move to the target device for inference
+         if device != "cpu" and next(_MODEL_CACHE.parameters()).device.type != device:
+             print(f"🚀 Moving model from {next(_MODEL_CACHE.parameters()).device} to {device}...")
+             _MODEL_CACHE = _MODEL_CACHE.to(device)
+             print(f"✅ Model ready on {device}")
+
+         return _MODEL_CACHE
+
+     def run_inference(self, *args, **kwargs):
+         """Run HY-Motion inference.
+
+         All positional and keyword arguments are forwarded to
+         SimpleRuntime.generate_motion.
+
+         Returns:
+             Tuple of (html_content, fbx_files).
+         """
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Initialize the model if needed; the instance comes from the global
+         # cache and is intentionally not stored on self.
+         model = self.initialize_model(device)
+
+         with torch.no_grad():
+             print(f"[{self.__class__.__name__}] Running inference...")
+             html_content, fbx_files, model_output = model.generate_motion(*args, **kwargs)
+             # CRITICAL: move all CUDA tensors to CPU before returning.
+             # This prevents CUDA initialization in the main process during unpickling.
+             for k, val in model_output.items():
+                 if isinstance(val, torch.Tensor):
+                     model_output[k] = val.detach().cpu()
+             # Clean up
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+         return html_content, fbx_files
+
+
+ if __name__ == "__main__":
+     # python -m hymotion.utils.gradio_runtime
+     runtime = SimpleRuntime(
+         config_path="assets/config_simplified.yml",
+         ckpt_name="latest.ckpt",
+         load_prompt_engineering=False,
+         load_text_encoder=False,
+     )
+     print(runtime.pipeline)
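For context, the caching contract that ModelInference depends on boils down to a few lines. The sketch below is illustrative, not part of this commit: `factory` is a stand-in for the SimpleRuntime construction above, and `_MODEL_CACHE` mirrors the module-level global referenced in the diff. Because each @spaces.GPU call runs in its own subprocess with its own global namespace, the cache is rebuilt at most once per worker:

_MODEL_CACHE = None  # one cache per @spaces.GPU subprocess


def get_cached_model(factory, device: str = "cuda"):
    """Build the model once per subprocess; park it on CPU, move it on demand.

    `factory` is any zero-argument callable returning an nn.Module-like
    object (hypothetical stand-in for the SimpleRuntime construction).
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is None:
        _MODEL_CACHE = factory().to("cpu").eval()  # load once, keep on CPU
    if next(_MODEL_CACHE.parameters()).device.type != device:
        _MODEL_CACHE = _MODEL_CACHE.to(device)  # hop devices only when needed
    return _MODEL_CACHE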
hymotion/utils/gradio_utils.py ADDED
@@ -0,0 +1,72 @@
+ import os
+
+ from huggingface_hub import snapshot_download
+
+ # Local model paths (if the models are already downloaded, use these paths directly)
+ QWEN_LOCAL_PATH = "ckpts/Qwen3-8B"
+ CLIP_LOCAL_PATH = "ckpts/clip-vit-large-patch14"
+
+
+ def try_to_download_text_encoder():
+     """
+     Pre-download the text encoder models (Qwen3-8B and CLIP) to the local cache.
+     This ensures the models are cached locally before they are needed,
+     so later loading will not require downloading again.
+
+     If the models already exist in the local paths (ckpts/), skip downloading.
+     """
+     # Text encoder model IDs (same as in hymotion/network/text_encoders/text_encoder.py)
+     QWEN_REPO_ID = "Qwen/Qwen3-8B"
+     CLIP_REPO_ID = "openai/clip-vit-large-patch14"
+
+     token = os.environ.get("HF_TOKEN") or None  # treat a missing/empty token as anonymous
+
+     # Check whether Qwen3-8B already exists locally
+     if os.path.isdir(QWEN_LOCAL_PATH):
+         print(f">>> Found local Qwen model at: {QWEN_LOCAL_PATH}, skipping download.")
+     else:
+         print(f">>> Pre-downloading text encoder: {QWEN_REPO_ID} to {QWEN_LOCAL_PATH}")
+         try:
+             snapshot_download(
+                 repo_id=QWEN_REPO_ID,
+                 local_dir=QWEN_LOCAL_PATH,
+                 token=token,
+             )
+             print(f">>> Successfully pre-downloaded: {QWEN_REPO_ID}")
+         except Exception as e:
+             print(f">>> [WARNING] Failed to pre-download {QWEN_REPO_ID}: {e}")
+
+     # Check whether CLIP already exists locally
+     if os.path.isdir(CLIP_LOCAL_PATH):
+         print(f">>> Found local CLIP model at: {CLIP_LOCAL_PATH}, skipping download.")
+     else:
+         print(f">>> Pre-downloading text encoder: {CLIP_REPO_ID} to {CLIP_LOCAL_PATH}")
+         try:
+             snapshot_download(
+                 repo_id=CLIP_REPO_ID,
+                 local_dir=CLIP_LOCAL_PATH,
+                 token=token,
+             )
+             print(f">>> Successfully pre-downloaded: {CLIP_REPO_ID}")
+         except Exception as e:
+             print(f">>> [WARNING] Failed to pre-download {CLIP_REPO_ID}: {e}")
+
+     print(">>> Text encoder pre-download complete.")
+
+
+ def try_to_download_model():
+     repo_id = "tencent/HY-Motion-1.0"
+     target_folder = "HY-Motion-1.0-Lite"
+     print(f">>> Start downloading {repo_id}/{target_folder}")
+     token = os.environ.get("HF_TOKEN") or None  # treat a missing/empty token as anonymous
+     local_dir = snapshot_download(
+         repo_id=repo_id,
+         allow_patterns=f"{target_folder}/*",
+         local_dir="./downloaded_models",
+         token=token,
+     )
+     final_model_path = os.path.join(local_dir, target_folder)
+     print(f">>> Final model path: {final_model_path}")
+     return final_model_path
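One plausible way to wire these helpers into app startup (the import path matches the file added above; the `__main__` guard and the final print are illustrative, not part of this commit) is to resolve all weights before building the UI, so the first request never blocks on a download:

from hymotion.utils.gradio_utils import try_to_download_model, try_to_download_text_encoder

if __name__ == "__main__":
    # Pre-fetch text encoders and the motion model before the UI comes up.
    try_to_download_text_encoder()
    model_path = try_to_download_model()
    print(f"Models ready under: {model_path}")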
hymotion/utils/smplh2fbx.py DELETED
@@ -1,585 +0,0 @@
- import glob
- import os
- import shutil
- import sys
- import tempfile
-
- import fbx
- import numpy as np
- import torch
- from transforms3d.euler import mat2euler
-
- from .geometry import angle_axis_to_rotation_matrix, rot_mat2trans_mat, trans2trans_mat
-
- # yapf: disable
- SMPLH_JOINT2NUM = {
-     "Pelvis": 0, "L_Hip": 1, "R_Hip": 2, "Spine1": 3,
-     "L_Knee": 4, "R_Knee": 5, "Spine2": 6,
-     "L_Ankle": 7, "R_Ankle": 8,
-     "Spine3": 9,
-     "L_Foot": 10, "R_Foot": 11,
-     "Neck": 12, "L_Collar": 13, "R_Collar": 14, "Head": 15,
-     "L_Shoulder": 16, "R_Shoulder": 17,
-     "L_Elbow": 18, "R_Elbow": 19,
-     "L_Wrist": 20, "R_Wrist": 21,
-     # "Jaw": 22, "L_Eye": 23, "R_Eye": 24,
-     "L_Index1": 22, "L_Index2": 23, "L_Index3": 24,
-     "L_Middle1": 25, "L_Middle2": 26, "L_Middle3": 27,
-     "L_Pinky1": 28, "L_Pinky2": 29, "L_Pinky3": 30,
-     "L_Ring1": 31, "L_Ring2": 32, "L_Ring3": 33,
-     "L_Thumb1": 34, "L_Thumb2": 35, "L_Thumb3": 36,
-     "R_Index1": 37, "R_Index2": 38, "R_Index3": 39,
-     "R_Middle1": 40, "R_Middle2": 41, "R_Middle3": 42,
-     "R_Pinky1": 43, "R_Pinky2": 44, "R_Pinky3": 45,
-     "R_Ring1": 46, "R_Ring2": 47, "R_Ring3": 48,
-     "R_Thumb1": 49, "R_Thumb2": 50, "R_Thumb3": 51,
- }
- # yapf: enable
-
-
- def _parse_obj_file(obj_path):
-     vertices = []
-     uv_coords = []
-     faces = []
-     uv_faces = []
-
-     with open(obj_path, "r") as f:
-         for line in f:
-             line = line.strip()
-             if line.startswith("v "):
-                 parts = line.split()
-                 vertices.append([float(parts[1]), float(parts[2]), float(parts[3])])
-             elif line.startswith("vt "):
-                 parts = line.split()
-                 uv_coords.append([float(parts[1]), float(parts[2])])
-             elif line.startswith("f "):
-                 parts = line.split()
-                 face_vertices = []
-                 face_uvs = []
-                 for part in parts[1:]:
-                     indices = part.split("/")
-                     face_vertices.append(int(indices[0]) - 1)
-                     if len(indices) > 1 and indices[1]:
-                         face_uvs.append(int(indices[1]) - 1)
-
-                 if len(face_vertices) == 3:
-                     faces.append(face_vertices)
-                 if len(face_uvs) == 3:
-                     uv_faces.append(face_uvs)
-
-     return np.array(vertices), np.array(uv_coords), np.array(faces), np.array(uv_faces)
-
-
- def _blend_shapes(betas: torch.Tensor, shape_disps: torch.Tensor) -> torch.Tensor:
-     """Calculates the per vertex displacement due to the blend shapes.
-
-     Parameters
-     ----------
-     betas : torch.tensor Bx(num_betas)
-         Blend shape coefficients
-     shape_disps: torch.tensor Vx3x(num_betas)
-         Blend shapes
-
-     Returns
-     -------
-     torch.tensor BxVx3
-         The per-vertex displacement due to shape deformation
-     """
-
-     # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
-     # i.e. Multiply each shape displacement by its corresponding beta and
-     # then sum them.
-     blend_shape = torch.einsum("bl,mkl->bmk", [betas, shape_disps])
-     return blend_shape
-
-
- def _vertices2joints(J_regressor: torch.Tensor, vertices: torch.Tensor) -> torch.Tensor:
-     """Calculates the 3D joint locations from the vertices.
-
-     Parameters
-     ----------
-     J_regressor : torch.tensor JxV
-         The regressor array that is used to calculate the joints from the
-         position of the vertices
-     vertices : torch.tensor BxVx3
-         The tensor of mesh vertices
-
-     Returns
-     -------
-     torch.tensor BxJx3
-         The location of the joints
-     """
-
-     return torch.einsum("bik,ji->bjk", [vertices, J_regressor])
-
-
- def _addSmplXMesh(fbxScene, v_posed, faces, uv_coords=None, uv_faces=None):
-     # Obtain a reference to the scene's root node.
-     rootNode = fbxScene.GetRootNode()
-
-     # Create a new node in the scene.
-     geometryNode = fbx.FbxNode.Create(fbxScene, "Geometry")
-     rootNode.AddChild(geometryNode)
-
-     # Create a new mesh node attribute in the scene, and
-     # set it as the new node's attribute
-     mesh = fbx.FbxMesh.Create(fbxScene, "body")
-     geometryNode.SetNodeAttribute(mesh)
-
-     # Define the new mesh's control points.
-     # v_posed, faces = smplx['v_posed'], smplx['faces']
-     v_posed = np.array(v_posed)
-     faces = np.array(faces)
-
-     minValue = np.min(v_posed)
-     maxValue = np.max(v_posed)
-     # print(f"min = {minValue}, max = {maxValue}")
-     # print("min = {}, max = {}".format(minValue, maxValue))
-
-     # m = axangle2mat((1, 0, 0), np.radians(180))
-
-     mesh.InitControlPoints(v_posed.shape[0])
-     for i in range(v_posed.shape[0]):
-         v = v_posed[i, :]
-         # v = np.matmul(m, v)
-         vertex = fbx.FbxVector4(v[0], v[1], v[2])
-         mesh.SetControlPointAt(vertex, i)
-
-     for i in range(faces.shape[0]):
-         mesh.BeginPolygon(i)
-         mesh.AddPolygon(faces[i, 0])
-         mesh.AddPolygon(faces[i, 1])
-         mesh.AddPolygon(faces[i, 2])
-         mesh.EndPolygon()
-
-     if uv_coords is not None and uv_faces is not None:
-         uv_layer = mesh.CreateElementUV("UVSet")
-         uv_layer.SetMappingMode(fbx.FbxLayerElement.EMappingMode.eByPolygonVertex)
-         uv_layer.SetReferenceMode(fbx.FbxLayerElement.EReferenceMode.eIndexToDirect)
-
-         uv_array = uv_layer.GetDirectArray()
-         for i in range(len(uv_coords)):
-             uv_array.Add(fbx.FbxVector2(uv_coords[i][0], uv_coords[i][1]))
-
-         uv_index_array = uv_layer.GetIndexArray()
-         for i in range(len(uv_faces)):
-             for j in range(3):
-                 uv_index_array.Add(uv_faces[i][j])
-     return geometryNode
-
-
- def _addSmplXSkeleton(fbxManager, fbxScene, trans, joint2num, kintree_table):
-     num2joint = ["" for key in joint2num]
-     for key, value in joint2num.items():
-         num2joint[value] = key
-
-     # trans = np.array(trans)
-
-     # Obtain a reference to the scene's root node.
-     rootNode = fbxScene.GetRootNode()
-
-     # Create a new node in the scene.
-     referenceNode = fbx.FbxNode.Create(fbxScene, "Reference")
-     rootNode.AddChild(referenceNode)
-
-     # Create skeletons
-     skeletonNodes = []
-     for nth in range(len(kintree_table)):
-         skeleton = fbx.FbxSkeleton.Create(fbxManager, "")
-         skeleton.SetSkeletonType(fbx.FbxSkeleton.EType.eRoot if nth == -1 else fbx.FbxSkeleton.EType.eLimbNode)
-
-         node = fbx.FbxNode.Create(fbxScene, num2joint[nth])
-         node.SetNodeAttribute(skeleton)
-
-         node.LclTranslation.Set(fbx.FbxDouble3(trans[nth, 0], trans[nth, 1], trans[nth, 2]))
-
-         skeletonNodes.append(node)
-
-         if kintree_table[nth] != -1:
-             skeletonNodes[kintree_table[nth]].AddChild(node)
-
-     referenceNode.AddChild(skeletonNodes[0])
-     return referenceNode, skeletonNodes
-
-
- def _addSkiningWeight(fbxScene, lbs_weights, geometryNode, skeletonNodes):
-     clusters = []
-     for i in range(lbs_weights.shape[1]):
-         cluster = fbx.FbxCluster.Create(fbxScene, "")
-         cluster.SetLink(skeletonNodes[i])
-         cluster.SetLinkMode(fbx.FbxCluster.ELinkMode.eTotalOne)
-
-         for j in range(lbs_weights.shape[0]):
-             weight = lbs_weights[j, i]
-             if weight > 0:
-                 cluster.AddControlPointIndex(j, weight)
-
-         clusters.append(cluster)
-
-     # Now we have the Geometry and the skeleton correctly positioned,
-     # set the transform and TransformLink matrix accordingly.
-     matrix = fbxScene.GetAnimationEvaluator().GetNodeGlobalTransform(geometryNode)
-     for cluster in clusters:
-         cluster.SetTransformMatrix(matrix)
-
-     for i in range(len(skeletonNodes)):
-         matrix = fbxScene.GetAnimationEvaluator().GetNodeGlobalTransform(skeletonNodes[i])
-         clusters[i].SetTransformLinkMatrix(matrix)
-
-     # Add the clusters to the patch by creating a skin and adding those clusters to that skin.
-     skin = fbx.FbxSkin.Create(fbxScene, "")
-     for cluster in clusters:
-         skin.AddCluster(cluster)
-     geometryNode.GetNodeAttribute().AddDeformer(skin)
-
-
- def _storeBindPose(fbxScene, geometryNode):
-     # In the bind pose, we must store all the link's global matrix at the
-     # time of the bind.
-     # Plus, we must store all the parent(s) global matrix of a link, even
-     # if they are not themselves deforming any model.
-
-     clusteredNodes = []
-     if geometryNode and geometryNode.GetNodeAttribute():
-         skinCount = 0
-         clusterCount = 0
-         attributeType = geometryNode.GetNodeAttribute().GetAttributeType()
-         if attributeType in (
-             fbx.FbxNodeAttribute.EType.eMesh,
-             fbx.FbxNodeAttribute.EType.eNurbs,
-             fbx.FbxNodeAttribute.EType.ePatch,
-         ):
-             skinCount = geometryNode.GetNodeAttribute().GetDeformerCount(fbx.FbxDeformer.EDeformerType.eSkin)
-             for i in range(skinCount):
-                 skin = geometryNode.GetNodeAttribute().GetDeformer(i, fbx.FbxDeformer.EDeformerType.eSkin)
-                 clusterCount += skin.GetClusterCount()
-
-         if clusterCount:
-             for i in range(skinCount):
-                 skin = geometryNode.GetNodeAttribute().GetDeformer(i, fbx.FbxDeformer.EDeformerType.eSkin)
-                 clusterCount = skin.GetClusterCount()
-                 for j in range(clusterCount):
-                     link = skin.GetCluster(j).GetLink()
-                     _addNodeRecursively(clusteredNodes, link)
-
-         # Add the geometry to the pose
-         clusteredNodes += [geometryNode]
-
-     # Now create a bind pose with the link list
-     if len(clusteredNodes):
-         # A pose must be named. Arbitrarily use the name of the geometry node.
-         pose = fbx.FbxPose.Create(fbxScene, geometryNode.GetName())
-         pose.SetIsBindPose(True)
-
-         for node in clusteredNodes:
-             bindMatrix = fbxScene.GetAnimationEvaluator().GetNodeGlobalTransform(node)
-             pose.Add(node, fbx.FbxMatrix(bindMatrix))
-
-         fbxScene.AddPose(pose)
-
-
- def _addNodeRecursively(nodeArray, node):
-     """Add the specified node to the node array.
-
-     Also, add recursively all the parent node of the specified node to the array.
-     """
-     if node:
-         _addNodeRecursively(nodeArray, node.GetParent())
-         found = False
-         if node in nodeArray:
-             if node.GetName() == node.GetName():
-                 found = True
-         if not found:
-             nodeArray += [node]
-
-
- def _animateGlobalTransformsFromTransMat(animLayer, referenceNode, global_translation, frameDuration):
-     _animateSingleChannel(animLayer, referenceNode.LclTranslation, "X", global_translation, frameDuration)
-     _animateSingleChannel(animLayer, referenceNode.LclTranslation, "Y", global_translation, frameDuration)
-     _animateSingleChannel(animLayer, referenceNode.LclTranslation, "Z", global_translation, frameDuration)
-
-
- def _animateSingleChannel(animLayer, component, name, values, frameDuration):
-     ncomp = 0
-
-     if name == "X":
-         ncomp = 0
-     elif name == "Y":
-         ncomp = 1
-     elif name == "Z":
-         ncomp = 2
-
-     time = fbx.FbxTime()
-     curve = component.GetCurve(animLayer, name, True)
-     curve.KeyModifyBegin()
-     for nth in range(len(values)):
-         time.SetSecondDouble(nth * frameDuration)
-         keyIndex = curve.KeyAdd(time)[0]
-         curve.KeySetValue(keyIndex, values[nth][ncomp])
-         curve.KeySetInterpolation(
-             keyIndex, fbx.FbxAnimCurveDef.EInterpolationType.eInterpolationConstant
-         )  # NOTE: using eInterpolationCubic to do interpolation causes error.
-     curve.KeyModifyEnd()
-
-
- def _animateRotationKeyFrames(animLayer, node, transforms_mat, frameDuration):
-     rotations = []
-     for nth in range(len(transforms_mat)):
-         rotations.append(np.rad2deg(mat2euler(transforms_mat[nth][0:3, 0:3], axes="sxyz")))
-
-     _animateSingleChannel(animLayer, node.LclRotation, "X", rotations, frameDuration)
-     _animateSingleChannel(animLayer, node.LclRotation, "Y", rotations, frameDuration)
-     _animateSingleChannel(animLayer, node.LclRotation, "Z", rotations, frameDuration)
-
-
- def _animateTranslationKeyFrames(animLayer, node, transforms_mat, frameDuration):
-     translations = []
-     for nth in range(len(transforms_mat)):
-         translations.append(transforms_mat[nth][0:3, 3])
-
-     _animateSingleChannel(animLayer, node.LclTranslation, "X", translations, frameDuration)
-     _animateSingleChannel(animLayer, node.LclTranslation, "Y", translations, frameDuration)
-     _animateSingleChannel(animLayer, node.LclTranslation, "Z", translations, frameDuration)
-
-
- def _animateScalingKeyFrames(animLayer, node, transforms_mat, frameDuration):
-     scalings = []
-     for nth in range(len(transforms_mat)):
-         scalings.append(
-             np.array(
-                 (
-                     transforms_mat[nth][0, 0],
-                     transforms_mat[nth][1, 1],
-                     transforms_mat[nth][2, 2],
-                 )
-             )
-         )
-
-     _animateSingleChannel(animLayer, node.LclTranslation, "X", scalings, frameDuration)
-     _animateSingleChannel(animLayer, node.LclTranslation, "Y", scalings, frameDuration)
-     _animateSingleChannel(animLayer, node.LclTranslation, "Z", scalings, frameDuration)
-
-
- def _animateSkeleton(fbxScene, skeletonNodes, frames, frameRate, name="Take1"):
-     frameDuration = 1.0 / frameRate
-
-     if name != "Take1":
-         subs = name.split("/")
-         name = subs[-1][:-5]
-
-     animStack = fbx.FbxAnimStack.Create(fbxScene, name)
-     animLayer = fbx.FbxAnimLayer.Create(fbxScene, "Base Layer")
-     animStack.AddMember(animLayer)
-     _animateGlobalTransformsFromTransMat(
-         animLayer=animLayer,
-         referenceNode=skeletonNodes[0],
-         global_translation=frames[:, 0, :3, 3],
-         frameDuration=frameDuration,
-     )
-
-     for nId in range(len(skeletonNodes)):
-         _animateRotationKeyFrames(
-             animLayer=animLayer,
-             node=skeletonNodes[nId],
-             transforms_mat=frames[:, nId],
-             frameDuration=frameDuration,
-         )
-
-
- def _saveScene(filename, fbxManager, fbxScene):
-     exporter = fbx.FbxExporter.Create(fbxManager, "")
-     isInitialized = exporter.Initialize(filename)
-
-     if isInitialized is False:
-         raise Exception(
-             "Exporter failed to initialized. Error returned: {}".format(exporter.GetStatus().GetErrorString())
-         )
-
-     exporter.Export(fbxScene)
-     exporter.Destroy()
-
-
- def _get_offsets_from_beta(beta, smplx_params, return_template_mesh=True):
-     v_template = torch.FloatTensor(smplx_params["v_template"]).unsqueeze(0)
-     shape_dirs = torch.FloatTensor(smplx_params["shapedirs"])
-     J_regressor = torch.FloatTensor(smplx_params["J_regressor"])
-
-     v_shaped = v_template + _blend_shapes(beta, shape_dirs)
-     J = _vertices2joints(J_regressor, v_shaped).squeeze(0).numpy()
-
-     parents = smplx_params["kintree_table"][()][0]
-     parents[0] = -1
-     Translates = J[()].copy()
-     Translates[1:] -= J[parents[1:]]
-     if not return_template_mesh:
-         return Translates
-     else:
-         return Translates, v_shaped
-
-
- def _preprocess_smplx(smplx_params, source_anim_data, scale=1, debug=False):
-     Translates, v_shaped = _get_offsets_from_beta(
-         torch.FloatTensor(source_anim_data["betas"]),
-         smplx_params,
-         return_template_mesh=True,
-     )
-
-     parents = smplx_params["kintree_table"][()][0]
-     parents[0] = -1
-
-     poses = torch.FloatTensor(source_anim_data["poses"])
-     source_LclRotation = angle_axis_to_rotation_matrix(poses).numpy()
-     source_LclTranslation = np.tile(Translates, (source_LclRotation.shape[0], 1, 1))
-     source_LclTranslation[:, 0] += source_anim_data["trans"]
-
-     source_skeleton = {
-         "parent": parents,
-         "LclRotation": source_LclRotation,
-         "LclTranslation": source_LclTranslation * scale,
-         "Translate": Translates * scale,
-         "v_shaped": v_shaped.squeeze(0).numpy() * scale,
-     }
-     return source_skeleton
-
-
- def _convert_npz_to_fbx(smplh_params, npz_data, save_fn, fps=30, uv_coords=None, uv_faces=None):
-     kintree = smplh_params["kintree_table"][0]
-     kintree[0] = -1
-
-     source_anim_data = {
-         "betas": npz_data["betas"],
-         "poses": npz_data["poses"].reshape(npz_data["poses"].shape[0], -1, 3),
-         "trans": npz_data["trans"],
-     }
-     source_skeleton = _preprocess_smplx(smplh_params, source_anim_data, scale=100)
-     rot = rot_mat2trans_mat(source_skeleton["LclRotation"])
-     trans = trans2trans_mat(source_skeleton["LclTranslation"])
-     frame_data = np.einsum("Btnk,Btkm ->Btnm", trans, rot)
-
-     fbxManager = fbx.FbxManager.Create()
-     fbxScene = fbx.FbxScene.Create(fbxManager, "")
-     timeMode = fbx.FbxTime().ConvertFrameRateToTimeMode(fps)
-     fbxScene.GetGlobalSettings().SetTimeMode(timeMode)
-
-     geometryNode = _addSmplXMesh(
-         fbxScene,
-         source_skeleton["v_shaped"],
-         smplh_params["f"],
-         uv_coords=uv_coords,
-         uv_faces=uv_faces,
-     )
-     referenceNode, skeletonNodes = _addSmplXSkeleton(
-         fbxManager,
-         fbxScene=fbxScene,
-         trans=source_skeleton["Translate"],
-         joint2num=SMPLH_JOINT2NUM,
-         kintree_table=kintree,
-     )
-
-     _addSkiningWeight(fbxScene, smplh_params["weights"], geometryNode, skeletonNodes)
-     _storeBindPose(fbxScene, geometryNode)
-     _animateSkeleton(
-         fbxScene=fbxScene,
-         skeletonNodes=skeletonNodes,
-         frames=frame_data,
-         frameRate=fps,
-     )
-
-     with tempfile.NamedTemporaryFile(suffix=".fbx", delete=False) as tmp_f:
-         temp_file = tmp_f.name
-
-     try:
-         # Save to temporary location
-         _saveScene(temp_file, fbxManager, fbxScene)
-         # If successful, copy to final destination
-         shutil.copy2(temp_file, save_fn)
-     except Exception as e:
-         print(f"Error saving FBX file: {e}")
-     finally:
-         # Remove temporary file
-         if os.path.exists(temp_file):
-             os.remove(temp_file)
-
-     # CLEANUP
-     fbxManager.Destroy()
-     del fbxManager, fbxScene
-
-
- def _read_uv(obj_template):
-     uv_coords = None
-     uv_faces = None
-     if obj_template and os.path.isfile(obj_template):
-         try:
-             print("Loading UV coordinates from OBJ template: {}".format(obj_template))
-             obj_vertices, uv_coords, obj_faces, uv_faces = _parse_obj_file(obj_template)
-             print("Loaded {} UV coordinates and {} UV faces".format(len(uv_coords), len(uv_faces)))
-         except Exception as e:
-             print("Warning: Failed to load UV coordinates from OBJ file: {}".format(e))
-             uv_coords = None
-             uv_faces = None
-     return uv_coords, uv_faces
-
-
- class SMPLH2FBX:
-     def __init__(
-         self,
-         obj_template="./assets/smpl_family_models/smplh/textures/male_smplh.obj",
-         smplh_model_path="./assets/body_models/smplh/neutral/model.npz",
-     ):
-         print(f"[{self.__class__.__name__}] Load obj_template: {obj_template}")
-         self.uv_coords, self.uv_faces = _read_uv(obj_template)
-         print(f"[{self.__class__.__name__}] Load smplh_model_path: {smplh_model_path}")
-         self.smplh_params = dict(np.load(smplh_model_path, allow_pickle=True))
-
-     def convert_npz_to_fbx(self, npz_file, outname, fps=30):
-         os.makedirs(os.path.dirname(outname), exist_ok=True)
-         if isinstance(npz_file, str) and os.path.isfile(npz_file):
-             npz_data = dict(np.load(npz_file, allow_pickle=True))
-         else:
-             npz_data = npz_file
-         _convert_npz_to_fbx(
-             self.smplh_params,
-             npz_data,
-             outname,
-             uv_coords=self.uv_coords,
-             uv_faces=self.uv_faces,
-         )
-         return os.path.exists(outname)
-
-     def convert_params_to_fbx(self, params, outname):
-         fps = params.get("mocap_framerate", 30)
-         os.makedirs(os.path.dirname(outname), exist_ok=True)
-         assert len(params["poses"].shape) == 3, f"poses shape should be (F, 52, 3), but got {params['poses'].shape}"
-         assert len(params["betas"].shape) == 2, f"betas shape should be (1, 16), but got {params['betas'].shape}"
-         assert len(params["trans"].shape) == 2, f"trans shape should be (1, 3), but got {params['trans'].shape}"
-         _convert_npz_to_fbx(
-             self.smplh_params,
-             params,
-             outname,
-             fps=fps,
-             uv_coords=self.uv_coords,
-             uv_faces=self.uv_faces,
-         )
-         return os.path.exists(outname)
-
-
- if __name__ == "__main__":
-     # python hymotion/utils/smplh2fbx.py
-     import argparse
-
-     parser = argparse.ArgumentParser()
-     parser.add_argument("root", type=str)
-     args = parser.parse_args()
-
-     converter = SMPLH2FBX()
-
-     if os.path.isdir(args.root):
-         npzfiles = sorted(glob.glob(os.path.join(args.root, "*.npz")))
-     else:
-         if args.root.endswith(".npz"):
-             npzfiles = [args.root]
-         else:
-             raise ValueError(f"Unknown file type: {args.root}")
-
-     for npzfile in npzfiles:
-         converter.convert_npz_to_fbx(npzfile, npzfile.replace(".npz", ".fbx").replace("motions", "motions_fbx"))
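With this converter deleted, the FBX path becomes optional at runtime: the ModelInference class above branches on an FBX_AVAILABLE flag and falls back to dict output when the SDK is missing. A plausible definition of that flag (assumed, not shown in this commit) probes the import once at module load:

# Hypothetical guard matching the `self.fbx_available = FBX_AVAILABLE`
# check in gradio_runtime: probe the Autodesk FBX Python SDK once and
# expose a flag the runtime can branch on.
try:
    import fbx  # often unavailable on hosted Spaces
    FBX_AVAILABLE = True
except ImportError:
    fbx = None
    FBX_AVAILABLE = False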
hymotion/utils/t2m_runtime.py CHANGED
@@ -46,14 +46,18 @@ class T2MRuntime:
         ckpt_name: str = "latest.ckpt",
         skip_text: bool = False,
         device_ids: Union[list[int], None] = None,
-        prompt_engineering_host: Optional[str] = None,
         skip_model_loading: bool = False,
         force_cpu: bool = False,
+        disable_prompt_engineering: bool = False,
+        prompt_engineering_host: Optional[str] = None,
+        prompt_engineering_model_path: Optional[str] = None,
     ):
         self.config_path = config_path
         self.ckpt_name = ckpt_name
         self.skip_text = skip_text
         self.prompt_engineering_host = prompt_engineering_host
+        self.prompt_engineering_model_path = prompt_engineering_model_path
+        self.disable_prompt_engineering = disable_prompt_engineering
         self.skip_model_loading = skip_model_loading
         self.local_ip = _get_local_ip()
 
@@ -71,7 +75,12 @@ class T2MRuntime:
         self._lock = threading.Lock()
         self._loaded = False
 
-        self.prompt_rewriter = PromptRewriter(host=self.prompt_engineering_host)
+        if self.disable_prompt_engineering:
+            self.prompt_rewriter = None
+        else:
+            self.prompt_rewriter = PromptRewriter(
+                host=self.prompt_engineering_host, model_path=self.prompt_engineering_model_path
+            )
         # Skip model loading if checkpoint not found
         if self.skip_model_loading:
             print(">>> [WARNING] Checkpoint not found, will use randomly initialized model weights")
@@ -92,7 +101,9 @@ class T2MRuntime:
 
         device_info = self.device_ids if self.device_ids else "cpu"
         if self.skip_model_loading:
-            print(f">>> T2MRuntime initialized (using randomly initialized weights) in IP {self.local_ip}, devices={device_info}")
+            print(
+                f">>> T2MRuntime initialized (using randomly initialized weights) in IP {self.local_ip}, devices={device_info}"
+            )
         else:
             print(f">>> T2MRuntime loaded in IP {self.local_ip}, devices={device_info}")
 
@@ -116,7 +127,10 @@ class T2MRuntime:
             )
             device = torch.device("cpu")
             pipeline.load_in_demo(
-                self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text, allow_empty_ckpt=allow_empty_ckpt
+                self.ckpt_name,
+                os.path.dirname(self.ckpt_name),
+                build_text_encoder=not self.skip_text,
+                allow_empty_ckpt=allow_empty_ckpt,
             )
             pipeline.to(device)
             self.pipelines = [pipeline]
@@ -129,7 +143,12 @@ class T2MRuntime:
                 network_module=config["network_module"],
                 network_module_args=config["network_module_args"],
             )
-            p.load_in_demo(self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text, allow_empty_ckpt=allow_empty_ckpt)
+            p.load_in_demo(
+                self.ckpt_name,
+                os.path.dirname(self.ckpt_name),
+                build_text_encoder=not self.skip_text,
+                allow_empty_ckpt=allow_empty_ckpt,
+            )
             p.to(torch.device(f"cuda:{gid}"))
             self.pipelines.append(p)
             self._gpu_load = [0] * len(self.pipelines)
@@ -360,6 +379,7 @@ class T2MRuntime:
         except Exception as e:
             print(f">>> Failed to generate static HTML content: {e}")
             import traceback
+
            traceback.print_exc()
             # Return error HTML
             return f"<html><body><h1>Error generating visualization</h1><p>{str(e)}</p></body></html>"
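Taken together, the new keyword arguments give the runtime three prompt-engineering modes: disabled, remote host, or local model. A minimal construction sketch (the paths below are placeholders, not values from this commit):

from hymotion.utils.t2m_runtime import T2MRuntime

runtime = T2MRuntime(
    config_path="path/to/config.yml",                 # placeholder
    ckpt_name="latest.ckpt",
    disable_prompt_engineering=False,                 # True skips the rewriter entirely
    prompt_engineering_host=None,                     # or "http://host:port" for a remote rewriter
    prompt_engineering_model_path="ckpts/Qwen3-8B",   # local rewriter weights (placeholder)
)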
scripts/gradio/templates/placeholder_scene.html ADDED
@@ -0,0 +1,331 @@
+ <!DOCTYPE html>
+ <html>
+ <head>
+     <title>Motion Visualization</title>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
+     <style>
+         html, body {
+             background: #424242 !important;
+             color: #e2e8f0;
+             margin: 0;
+             padding: 0;
+             width: 100%;
+             height: 100%;
+             overflow: hidden;
+         }
+         * {
+             margin: 0;
+             padding: 0;
+             box-sizing: border-box;
+         }
+         .fullscreen-container {
+             position: fixed;
+             top: 0;
+             left: 0;
+             width: 100vw;
+             height: 100vh;
+             background: #424242;
+             overflow: hidden;
+         }
+         #vis3d {
+             position: absolute;
+             top: 0;
+             left: 0;
+             width: 100%;
+             height: 100%;
+             background: #424242;
+         }
+         #vis3d canvas {
+             display: block;
+             width: 100% !important;
+             height: 100% !important;
+         }
+         .welcome-overlay {
+             position: absolute;
+             top: 50%;
+             left: 50%;
+             transform: translate(-50%, -50%);
+             background: rgba(0, 0, 0, 0.6);
+             backdrop-filter: blur(10px);
+             -webkit-backdrop-filter: blur(10px);
+             color: white;
+             padding: 30px 50px;
+             border-radius: 16px;
+             font-size: 16px;
+             z-index: 200;
+             text-align: center;
+             box-shadow: 0 8px 32px rgba(0, 0, 0, 0.3);
+         }
+         .welcome-overlay h2 {
+             font-size: 20px;
+             font-weight: 600;
+             margin-bottom: 12px;
+             color: #4a9eff;
+         }
+         .welcome-overlay p {
+             color: #a0aec0;
+             font-size: 14px;
+             line-height: 1.6;
+         }
+         .control-overlay {
+             position: absolute;
+             bottom: 30px;
+             left: 50%;
+             transform: translateX(-50%);
+             width: 80%;
+             max-width: 600px;
+             z-index: 100;
+             background: rgba(0, 0, 0, 0.4);
+             backdrop-filter: blur(8px);
+             -webkit-backdrop-filter: blur(8px);
+             padding: 15px 20px;
+             border-radius: 12px;
+         }
+         .control-row-minimal {
+             display: flex;
+             align-items: center;
+             gap: 20px;
+         }
+         .progress-container {
+             flex: 1;
+         }
+         .progress-slider-minimal {
+             width: 100%;
+             height: 8px;
+             border-radius: 4px;
+             background: rgba(255, 255, 255, 0.3);
+             outline: none;
+             cursor: not-allowed;
+             -webkit-appearance: none;
+             appearance: none;
+             opacity: 0.5;
+         }
+         .progress-slider-minimal::-webkit-slider-runnable-track {
+             width: 100%;
+             height: 8px;
+             border-radius: 4px;
+             background: rgba(255, 255, 255, 0.3);
+         }
+         .progress-slider-minimal::-webkit-slider-thumb {
+             -webkit-appearance: none;
+             appearance: none;
+             width: 20px;
+             height: 20px;
+             border-radius: 50%;
+             background: #4a9eff;
+             cursor: not-allowed;
+             border: 2px solid white;
+             box-shadow: 0 2px 8px rgba(0, 0, 0, 0.4);
+             margin-top: -6px;
+         }
+         .progress-slider-minimal::-moz-range-track {
+             width: 100%;
+             height: 8px;
+             border-radius: 4px;
+             background: rgba(255, 255, 255, 0.3);
+         }
+         .progress-slider-minimal::-moz-range-thumb {
+             width: 20px;
+             height: 20px;
+             border-radius: 50%;
+             background: #4a9eff;
+             cursor: not-allowed;
+             border: 2px solid white;
+             box-shadow: 0 2px 8px rgba(0, 0, 0, 0.4);
+         }
+         .frame-counter {
+             font-family: 'SF Mono', 'Consolas', monospace;
+             font-size: 14px;
+             font-weight: 500;
+             color: rgba(255, 255, 255, 0.5);
+             text-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
+             white-space: nowrap;
+             min-width: 80px;
+             text-align: right;
+         }
+     </style>
+ </head>
+ <body>
+     <div class="fullscreen-container">
+         <div id="vis3d"></div>
+         <div class="welcome-overlay">
+             <h2>Welcome to HY-Motion-1.0!</h2>
+             <p>Enter a text description and generate motion<br>to see the 3D visualization here.</p>
+         </div>
+         <div class="control-overlay">
+             <div class="control-row-minimal">
+                 <div class="progress-container">
+                     <input type="range" class="progress-slider-minimal" min="0" max="100" value="0" disabled>
+                 </div>
+                 <div class="frame-counter">
+                     <span>0</span> / <span>0</span>
+                 </div>
+             </div>
+         </div>
+     </div>
+
+     <script type="importmap">
+     {
+         "imports": {
+             "three": "https://cdn.jsdelivr.net/npm/[email protected]/build/three.module.js",
+             "three/addons/": "https://cdn.jsdelivr.net/npm/[email protected]/examples/jsm/"
+         }
+     }
+     </script>
+
+     <script type="module">
+         import * as THREE from 'three';
+         import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
+
+         function createBaseChessboard(
+             grid_size = 50,
+             divisions = 50,
+             white = "#ffffff",
+             black = "#3a3a3a",
+             texture_size = 1024
+         ) {
+             var adjusted_texture_size = Math.floor(texture_size / divisions) * divisions;
+             var canvas = document.createElement("canvas");
+             canvas.width = canvas.height = adjusted_texture_size;
+             var context = canvas.getContext("2d");
+             context.imageSmoothingEnabled = false;
+
+             var step = adjusted_texture_size / divisions;
+             for (var i = 0; i < divisions; i++) {
+                 for (var j = 0; j < divisions; j++) {
+                     context.fillStyle = (i + j) % 2 === 0 ? white : black;
+                     context.fillRect(i * step, j * step, step, step);
+                 }
+             }
+
+             var texture = new THREE.CanvasTexture(canvas);
+             texture.wrapS = THREE.RepeatWrapping;
+             texture.wrapT = THREE.RepeatWrapping;
+             texture.magFilter = THREE.NearestFilter;
+             texture.minFilter = THREE.NearestFilter;
+             texture.generateMipmaps = false;
+
+             var planeGeometry = new THREE.PlaneGeometry(grid_size, grid_size);
+
+             var planeMaterial = new THREE.MeshStandardMaterial({
+                 map: texture,
+                 side: THREE.DoubleSide,
+                 transparent: true,
+                 opacity: 0.85,
+                 roughness: 0.9,
+                 metalness: 0.1,
+                 emissiveIntensity: 0.05,
+             });
+
+             var plane = new THREE.Mesh(planeGeometry, planeMaterial);
+             plane.receiveShadow = true;
+
+             return plane;
+         }
+
+         function getChessboardXZ() {
+             var plane = createBaseChessboard();
+             plane.rotation.x = -Math.PI / 2;
+             plane.name = 'ground';
+             plane.receiveShadow = true;
+             return plane;
+         }
+
+         let scene, camera, renderer, controls;
+
+         function init() {
+             const width = window.innerWidth;
+             const height = window.innerHeight;
+
+             scene = new THREE.Scene();
+             camera = new THREE.PerspectiveCamera(45, width / height, 0.1, 50);
+             renderer = new THREE.WebGLRenderer({ antialias: true, logarithmicDepthBuffer: true });
+
+             // Camera setup
+             camera.up.set(0, 1, 0);
+             camera.position.set(3, 2.5, 5);
+             camera.lookAt(new THREE.Vector3(0, 1, 0));
+
+             // Scene background and fog
+             scene.background = new THREE.Color(0x424242);
+             scene.fog = new THREE.FogExp2(0x424242, 0.06);
+
+             // Renderer setup
+             renderer.shadowMap.enabled = true;
+             renderer.shadowMap.type = THREE.PCFSoftShadowMap;
+             renderer.toneMapping = THREE.ACESFilmicToneMapping;
+             renderer.toneMappingExposure = 1.0;
+             renderer.outputColorSpace = THREE.SRGBColorSpace;
+             renderer.setPixelRatio(window.devicePixelRatio);
+             renderer.setSize(width, height);
+
+             // Lights
+             const hemisphereLight = new THREE.HemisphereLight(0xffffff, 0x444444, 1.2);
+             hemisphereLight.position.set(0, 2, 0);
+             scene.add(hemisphereLight);
+
+             const directionalLight = new THREE.DirectionalLight(0xffffff, 1.5);
+             directionalLight.position.set(3, 5, 4);
+             directionalLight.castShadow = true;
+             directionalLight.shadow.mapSize.width = 2048;
+             directionalLight.shadow.mapSize.height = 2048;
+             directionalLight.shadow.camera.near = 0.5;
+             directionalLight.shadow.camera.far = 50;
+             directionalLight.shadow.camera.left = -10;
+             directionalLight.shadow.camera.right = 10;
+             directionalLight.shadow.camera.top = 10;
+             directionalLight.shadow.camera.bottom = -10;
+             directionalLight.shadow.bias = -0.0001;
+             scene.add(directionalLight);
+
+             const fillLight = new THREE.DirectionalLight(0xaaccff, 0.5);
+             fillLight.position.set(-3, 3, -2);
+             scene.add(fillLight);
+
+             const rimLight = new THREE.DirectionalLight(0xffeedd, 0.4);
+             rimLight.position.set(0, 4, -5);
+             scene.add(rimLight);
+
+             // Ground
+             scene.add(getChessboardXZ());
+
+             // Add to DOM
+             var container = document.getElementById('vis3d');
+             container.appendChild(renderer.domElement);
+
+             // Controls
+             controls = new OrbitControls(camera, renderer.domElement);
+             controls.minDistance = 1;
+             controls.maxDistance = 15;
+             controls.enableDamping = true;
+             controls.dampingFactor = 0.05;
+             controls.target.set(0, 0.5, 0);
+             controls.update();
+
+             window.addEventListener('resize', onWindowResize);
+             animate();
+         }
+
+         function animate() {
+             requestAnimationFrame(animate);
+             if (controls && controls.enableDamping) {
+                 controls.update();
+             }
+             renderer.render(scene, camera);
+         }
+
+         function onWindowResize() {
+             const width = window.innerWidth;
+             const height = window.innerHeight;
+             camera.aspect = width / height;
+             camera.updateProjectionMatrix();
+             renderer.setSize(width, height);
+         }
+
+         init();
+     </script>
+ </body>
+ </html>
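A plausible way for the Gradio app to surface this template before any motion has been generated (the file path is from this commit; the iframe wiring itself is an assumption, since the self-contained <script type="module"> needs a real document to run in):

import base64

import gradio as gr

# Read the static placeholder scene and embed it as a data-URI iframe so the
# module script executes in its own document.
with open("scripts/gradio/templates/placeholder_scene.html", "r", encoding="utf-8") as f:
    placeholder_html = f.read()

encoded = base64.b64encode(placeholder_html.encode("utf-8")).decode("ascii")
iframe = f'<iframe src="data:text/html;base64,{encoded}" style="width:100%;height:520px;border:none;"></iframe>'

with gr.Blocks() as demo:
    viewer = gr.HTML(iframe)  # replaced by the generated scene after the first run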