import os
import time
import uuid
from typing import List, Optional, Tuple

import torch
import yaml

from ..prompt_engineering.prompt_rewrite import PromptRewriter
from .loaders import load_object
from .visualize_mesh_web import save_visualization_data, generate_static_html_content

# The Autodesk FBX Python SDK generally has to be installed manually rather
# than via pip, so FBX export is treated as an optional capability.
try:
    import fbx

    FBX_AVAILABLE = True
    print(">>> FBX module found.")
except ImportError:
    FBX_AVAILABLE = False
    print(">>> FBX module not found.")


def _now():
    """Return a timestamp such as "20251106_153000042" (local date and time,
    with a 3-digit millisecond suffix appended without a separator)."""
    t = time.time()
    ms = int((t - int(t)) * 1000)
    return time.strftime("%Y%m%d_%H%M%S", time.localtime(t)) + f"{ms:03d}"


# Module-level cache: safe with @spaces.GPU because each GPU call runs in an
# isolated subprocess with its own global namespace (see ModelInference).
_MODEL_CACHE = None


class SimpleRuntime(torch.nn.Module):
    def __init__(self, config_path, ckpt_name, load_prompt_engineering=False, load_text_encoder=False):
        super().__init__()
        self.load_prompt_engineering = load_prompt_engineering
        self.load_text_encoder = load_text_encoder

        if self.load_prompt_engineering:
            print(f"[{self.__class__.__name__}] Loading prompt engineering...")
            self.prompt_rewriter = PromptRewriter(host=None, model_path=None, device="cpu")
        else:
            self.prompt_rewriter = None

        if self.load_text_encoder:
            print(f"[{self.__class__.__name__}] Loading text encoder...")
            _text_encoder_module = "hymotion/network/text_encoders/text_encoder.HYTextModel"
            _text_encoder_cfg = {"llm_type": "qwen3", "max_length_llm": 128}
            text_encoder = load_object(_text_encoder_module, _text_encoder_cfg)
        else:
            text_encoder = None

        print(f"[{self.__class__.__name__}] Loading model...")
        with open(config_path, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        pipeline = load_object(
            config["train_pipeline"],
            config["train_pipeline_args"],
            network_module=config["network_module"],
            network_module_args=config["network_module_args"],
        )
        print(f"[{self.__class__.__name__}] Loading ckpt: {ckpt_name}")
        pipeline.load_in_demo(
            os.path.join(os.path.dirname(config_path), ckpt_name),
            build_text_encoder=False,
            allow_empty_ckpt=False,
        )
        pipeline.text_encoder = text_encoder
        self.pipeline = pipeline

        self.fbx_available = FBX_AVAILABLE
        if self.fbx_available:
            try:
                from .smplh2woodfbx import SMPLH2WoodFBX

                self.fbx_converter = SMPLH2WoodFBX()
            except Exception as e:
                print(f">>> Failed to initialize FBX converter: {e}")
                self.fbx_available = False
                self.fbx_converter = None
        else:
            self.fbx_converter = None
            print(">>> FBX module not found. FBX export will be disabled.")

    def _generate_html_content(
        self,
        timestamp: str,
        file_path: str,
        output_dir: Optional[str] = None,
    ) -> str:
        """
        Generate static HTML content with embedded data for iframe srcdoc.

        All JavaScript code is embedded directly in the HTML, so no external
        static resources are needed.

        Args:
            timestamp: Timestamp string for logging
            file_path: Base filename (without extension)
            output_dir: Directory where NPZ/meta files are stored

        Returns:
            HTML content string (to be used in iframe srcdoc)
        """
        print(f">>> Generating static HTML content, timestamp: {timestamp}")
        gradio_dir = output_dir if output_dir is not None else "output/gradio"

        try:
            html_content = generate_static_html_content(
                folder_name=gradio_dir,
                file_name=file_path,
                hide_captions=False,
            )
            print(f">>> Static HTML content generated for: {file_path}")
            return html_content
        except Exception as e:
            print(f">>> Failed to generate static HTML content: {e}")
            import traceback

            traceback.print_exc()
            return f"<html><body><h1>Error generating visualization</h1><p>{str(e)}</p></body></html>"

    def _generate_fbx_files(
        self,
        visualization_data: dict,
        output_dir: Optional[str] = None,
        fbx_filename: Optional[str] = None,
    ) -> List[str]:
        assert "smpl_data" in visualization_data, "smpl_data not found in visualization_data"
        fbx_files = []
        if output_dir is None:
            root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
            output_dir = os.path.join(root_dir, "output", "gradio")
        os.makedirs(output_dir, exist_ok=True)

        smpl_data_list = visualization_data["smpl_data"]

        unique_id = str(uuid.uuid4())[:8]
        text = visualization_data["text"]
        timestamp = visualization_data["timestamp"]
        for bb, smpl_data in enumerate(smpl_data_list):
            if fbx_filename is None:
                fbx_filename_bb = f"{timestamp}_{unique_id}_{bb:03d}.fbx"
            else:
                fbx_filename_bb = f"{fbx_filename}_{bb:03d}.fbx"
            fbx_path = os.path.join(output_dir, fbx_filename_bb)
            success = self.fbx_converter.convert_npz_to_fbx(smpl_data, fbx_path)
            if success:
                fbx_files.append(fbx_path)
                print(f"\t>>> FBX file generated: {fbx_path}")
                # Save the prompt next to the FBX so downstream tools keep the pairing.
                txt_path = fbx_path.replace(".fbx", ".txt")
                with open(txt_path, "w", encoding="utf-8") as f:
                    f.write(text)
                fbx_files.append(txt_path)

        return fbx_files

    def generate_motion(
        self,
        text: str,
        seeds_csv: str,
        motion_duration: float,
        cfg_scale: float,
        output_format: str = "fbx",
        output_dir: Optional[str] = None,
        output_filename: Optional[str] = None,
        original_text: Optional[str] = None,
        use_special_game_feat: bool = False,
    ) -> Tuple[str, List[str], dict]:
        seeds = [int(s.strip()) for s in seeds_csv.split(",") if s.strip() != ""]

        print(f"[{self.__class__.__name__}] Generating motion...")
        print(f"[{self.__class__.__name__}] text: {text}")
        if self.load_prompt_engineering:
            duration, rewritten_text = self.prompt_rewriter.rewrite_prompt_and_infer_time(text)
        else:
            rewritten_text = text
            duration = motion_duration

        pipeline = self.pipeline
        pipeline.eval()

        if not self.load_text_encoder:
            print(">>> [Debug Mode] Using blank text features (skip_text=True)")
            device = next(pipeline.parameters()).device
            batch_size = len(seeds) if seeds else 1

            # Without a text encoder, fall back to the pipeline's null text
            # features so the generator receives correctly shaped inputs.
            hidden_state_dict = {
                "text_vec_raw": pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device),
                "text_ctxt_raw": pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device),
                "text_ctxt_raw_length": torch.tensor([1] * batch_size, device=device),
            }

            model_output = pipeline.generate(
                rewritten_text,
                seeds,
                duration,
                cfg_scale=1.0,
                use_special_game_feat=False,
                hidden_state_dict=hidden_state_dict,
            )
        else:
            model_output = pipeline.generate(
                rewritten_text, seeds, duration, cfg_scale=cfg_scale, use_special_game_feat=use_special_game_feat
            )

        ts = _now()
        save_data, base_filename = save_visualization_data(
            output=model_output,
            text=text if original_text is None else original_text,
            rewritten_text=rewritten_text,
            timestamp=ts,
            output_dir=output_dir,
            output_filename=output_filename,
        )

        html_content = self._generate_html_content(
            timestamp=ts,
            file_path=base_filename,
            output_dir=output_dir,
        )

        if output_format == "fbx" and not self.fbx_available:
            print(">>> Warning: FBX export requested but FBX SDK is not available. Falling back to dict format.")
            output_format = "dict"

        if output_format == "fbx" and self.fbx_available:
            fbx_files = self._generate_fbx_files(
                visualization_data=save_data,
                output_dir=output_dir,
                fbx_filename=output_filename,
            )
            return html_content, fbx_files, model_output
        elif output_format == "dict":
            return html_content, [], model_output
        else:
            raise ValueError(f">>> Invalid output format: {output_format}")


class ModelInference:
    """
    Handles model inference and data processing for the text-to-motion pipeline.
    """

    def __init__(self, model_path, use_prompt_engineering, use_text_encoder):
        """Initialize the model inference handler.

        Note: Do not store the model in an instance variable, to avoid
        cross-process state issues with the @spaces.GPU decorator.
        """
        self.model_path = model_path
        self.use_prompt_engineering = use_prompt_engineering
        self.use_text_encoder = use_text_encoder
        self.fbx_available = FBX_AVAILABLE

    def initialize_model(self, device: str = "cuda"):
        """
        Initialize the motion generation model using the global cache.

        Optimization: load the model to CPU first, then move it to the GPU when
        needed. This is faster than reloading from disk each time.

        Using a global variable here is safe because @spaces.GPU runs in an
        isolated subprocess, each with its own global namespace.

        Args:
            device: Device to run inference on (the model is moved to this device)

        Returns:
            Model instance ready for inference on the specified device
        """
        global _MODEL_CACHE

        if _MODEL_CACHE is None:
            _MODEL_CACHE = SimpleRuntime(
                config_path=os.path.join(self.model_path, "config.yml"),
                ckpt_name="latest.ckpt",
                load_prompt_engineering=self.use_prompt_engineering,
                load_text_encoder=self.use_text_encoder,
            )
            _MODEL_CACHE = _MODEL_CACHE.to("cpu")
            _MODEL_CACHE.eval()
            print(">>> Model loaded to CPU memory (cached in subprocess)")

        # Normalize so both strings and torch.device values are accepted.
        device = torch.device(device)
        if device.type != "cpu" and next(_MODEL_CACHE.parameters()).device.type != device.type:
            print(f">>> Moving model from {next(_MODEL_CACHE.parameters()).device} to {device}...")
            _MODEL_CACHE = _MODEL_CACHE.to(device)
            print(f">>> Model ready on {device}")

        return _MODEL_CACHE

    def run_inference(self, *args, device: Optional[str] = None, **kwargs):
        """
        Run model inference for motion generation.

        Args:
            device: Device to run inference on. If None, auto-detect.
                Use "cpu" to force CPU inference (e.g., when not in a
                @spaces.GPU context).

        Returns:
            Tuple of (html_content, fbx_files)
        """
        print(f"[{self.__class__.__name__}] Running inference...")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[{self.__class__.__name__}] Using device: {device}")

        model = self.initialize_model(device)

        with torch.no_grad():
            print(f"[{self.__class__.__name__}] Running inference with torch.no_grad")
            html_content, fbx_files, model_output = model.generate_motion(*args, **kwargs)

        # Move any tensors to CPU so the results do not hold GPU memory.
        for k, val in model_output.items():
            if isinstance(val, torch.Tensor):
                model_output[k] = val.detach().cpu()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return html_content, fbx_files


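# A minimal usage sketch for Hugging Face Spaces (the `spaces` import and the
# model path are assumptions, not part of this module). Because @spaces.GPU
# runs each call in its own subprocess, _MODEL_CACHE is built once per
# subprocess and reused across calls within it:
#
#     import spaces  # provided by the HF Spaces runtime
#
#     inference = ModelInference(
#         model_path="weights/hymotion",  # hypothetical path
#         use_prompt_engineering=False,
#         use_text_encoder=True,
#     )
#
#     @spaces.GPU
#     def generate(text: str):
#         # (text, seeds_csv, motion_duration, cfg_scale) are positional args
#         # forwarded to SimpleRuntime.generate_motion.
#         return inference.run_inference(text, "42,43", 5.0, 5.0)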
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    runtime = SimpleRuntime(
        config_path="assets/config_simplified.yml",
        ckpt_name="latest.ckpt",
        load_prompt_engineering=False,
        load_text_encoder=False,
    )
    print(runtime.pipeline)
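
    # A minimal end-to-end smoke test (illustrative values; assumes the config
    # and checkpoint above exist locally). With load_text_encoder=False the
    # pipeline runs in debug mode with blank text features.
    html_content, fbx_files, _ = runtime.generate_motion(
        text="a person walks forward",
        seeds_csv="42",
        motion_duration=4.0,
        cfg_scale=5.0,
        output_format="dict",
    )
    print(f">>> HTML length: {len(html_content)}, FBX files: {fbx_files}")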
|
|
|