shaocong committed on
Commit
d041699
·
1 Parent(s): 8adf7c4

save GPU usage

Files changed (3)
  1. app.py +61 -219
  2. app_old.py +756 -0
  3. dkt/pipelines/{wan_video_new.py → pipeline.py} +320 -6
app.py CHANGED
@@ -1,10 +1,7 @@
1
 
2
  import os
3
-
4
-
5
-
6
  import gradio as gr
7
- # gr.set_config(schema_inference=False)
8
 
9
  import numpy as np
10
  import torch
@@ -12,46 +9,8 @@ from PIL import Image
12
  from loguru import logger
13
  from tqdm import tqdm
14
  from tools.common_utils import save_video
15
- from dkt.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
16
-
17
- # try:
18
- # import gradio_client.utils as _gc_utils
19
- # if hasattr(_gc_utils, "get_type"):
20
- # _orig_get_type = _gc_utils.get_type
21
- # def _get_type_safe(schema):
22
- # if not isinstance(schema, dict):
23
- # return "Any"
24
- # return _orig_get_type(schema)
25
- # _gc_utils.get_type = _get_type_safe
26
- # except Exception:
27
- # pass
28
-
29
- # # Additional guard: handle boolean JSON Schemas and parsing errors
30
- # try:
31
- # import gradio_client.utils as _gc_utils
32
- # # Wrap the internal _json_schema_to_python_type if present
33
- # if hasattr(_gc_utils, "_json_schema_to_python_type"):
34
- # _orig_internal = _gc_utils._json_schema_to_python_type
35
- # def _json_schema_to_python_type_safe(schema, defs=None):
36
- # if isinstance(schema, bool):
37
- # return "Any"
38
- # try:
39
- # return _orig_internal(schema, defs)
40
- # except Exception:
41
- # return "Any"
42
- # _gc_utils._json_schema_to_python_type = _json_schema_to_python_type_safe
43
-
44
- # # Also wrap the public json_schema_to_python_type to be extra defensive
45
- # if hasattr(_gc_utils, "json_schema_to_python_type"):
46
- # _orig_public = _gc_utils.json_schema_to_python_type
47
- # def json_schema_to_python_type_safe(schema):
48
- # try:
49
- # return _orig_public(schema)
50
- # except Exception:
51
- # return "Any"
52
- # _gc_utils.json_schema_to_python_type = json_schema_to_python_type_safe
53
- # except Exception:
54
- # pass
55
 
56
  import cv2
57
  import copy
@@ -59,7 +18,7 @@ import trimesh
59
 
60
  from os.path import join
61
  from tools.depth2pcd import depth2pcd
62
- from moge.model.v2 import MoGeModel
63
 
64
 
65
  from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map
@@ -70,12 +29,18 @@ import tempfile
70
  import spaces
71
 
72
 
73
- PIPE_1_3B = None
74
- MOGE_MODULE = None
75
  #* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors
76
  PROMPT = 'depth'
77
  NEGATIVE_PROMPT = ''
78
 
79
  example_inputs = [
80
 
81
  ["examples/1.mp4", "1.3B", 5, 3],
@@ -116,15 +81,6 @@ example_inputs = [
116
  # ["examples/b68045aa2128ab63d9c7518f8d62eafe.mp4", "1.3B", 5, 3],
117
  ]
118
 
119
-
120
-
121
-
122
-
123
- height = 480
124
- width = 832
125
- window_size = 21
126
-
127
-
128
 
129
  def resize_frame(frame, height, width):
130
  frame = np.array(frame)
@@ -287,173 +243,74 @@ def get_model(model_size):
287
 
288
 
289
 
290
- @spaces.GPU(duration=120)
291
- @torch.inference_mode()
 
292
  def process_video(
293
  video_file,
294
  model_size,
295
  num_inference_steps,
296
  overlap
297
  ):
298
- print('process_video called')
299
- try:
300
- pipe = get_model(model_size)
301
- if pipe is None:
302
- return None, f"Model {model_size} not initialized. Please restart the application."
303
-
304
- tmp_video_path = video_file
305
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
306
-
307
- # Use a temporary directory to store all files
308
- cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')
309
-
310
-
311
- original_filename = f"input_{timestamp}.mp4"
312
- dst_path = os.path.join(cur_save_dir, original_filename)
313
- shutil.copy2(tmp_video_path, dst_path)
314
- origin_frames, input_fps = extract_frames_from_video_file(tmp_video_path)
315
-
316
- if not origin_frames:
317
- return None, "Failed to extract frames from video"
318
-
319
- logger.info(f"Extracted {len(origin_frames)} frames from video")
320
 
321
-
322
- original_width, original_height = origin_frames[0].size
323
- ROTATE = False
324
- if original_width < original_height:
325
- ROTATE = True
326
- origin_frames = [x.transpose(Image.ROTATE_90) for x in origin_frames]
327
- tmp = original_width
328
- original_width = original_height
329
- original_height = tmp
330
 
331
-
332
-
333
- global height
334
- global width
335
- global window_size
336
 
337
- frames = [resize_frame(frame, height, width) for frame in origin_frames]
338
- frame_length = len(frames)
339
- if (frame_length - 1) % 4 != 0:
340
- new_len = ((frame_length - 1) // 4 + 1) * 4 + 1
341
- frames = frames + [copy.deepcopy(frames[-1]) for _ in range(new_len - frame_length)]
 
 
342
 
343
-
344
- control_video = frames
345
- video, vae_outs = pipe(
346
- prompt=PROMPT,
347
- negative_prompt=NEGATIVE_PROMPT,
348
- control_video=control_video,
349
- height=height,
350
- width=width,
351
- num_frames=len(control_video),
352
- seed=1,
353
- tiled=False,
354
- num_inference_steps=num_inference_steps,
355
- sliding_window_size=window_size,
356
- sliding_window_stride=window_size - overlap,
357
- cfg_scale=1.0,
358
- )
359
-
360
- #* moge process
361
- torch.cuda.empty_cache()
362
- processed_video = video[:frame_length]
363
-
364
-
365
- processed_video = [resize_frame(frame, original_height, original_width) for frame in processed_video]
366
- if ROTATE:
367
- processed_video = [x.transpose(Image.ROTATE_270) for x in processed_video]
368
- origin_frames = [x.transpose(Image.ROTATE_270) for x in origin_frames]
369
-
370
 
371
- output_filename = f"output_{timestamp}.mp4"
372
- output_path = os.path.join(cur_save_dir, output_filename)
373
- color_predictions = []
374
- if PROMPT == 'depth':
375
- prediced_depth_map_np = [np.array(item).astype(np.float32).mean(-1) for item in processed_video]
376
- prediced_depth_map_np = np.stack(prediced_depth_map_np)
377
- prediced_depth_map_np = prediced_depth_map_np/ 255.0
378
- __min = prediced_depth_map_np.min()
379
- __max = prediced_depth_map_np.max()
380
- prediced_depth_map_np = (prediced_depth_map_np - __min) / (__max - __min)
381
- color_predictions = [colorize_depth_map(item) for item in prediced_depth_map_np]
382
- else:
383
- color_predictions = processed_video
384
- save_video(color_predictions, output_path, fps=input_fps, quality=5)
385
 
386
 
387
 
388
- # todo, inference MoGe only once
389
- frame_num = len(origin_frames)
390
- resize_W,resize_H = origin_frames[0].size
391
-
392
- vis_pc_num = 4
393
- indices = np.linspace(0, frame_num-1, vis_pc_num)
394
- indices = np.round(indices).astype(np.int32)
395
- pc_save_dir = os.path.join(cur_save_dir, 'pointclouds')
396
- os.makedirs(pc_save_dir, exist_ok=True)
397
-
398
- glb_files = []
399
- moge_device = MOGE_MODULE.device if MOGE_MODULE is not None else torch.device("cuda:0")
400
- for idx in tqdm(indices):
401
- orgin_rgb_frame = origin_frames[idx]
402
- predicted_depth = processed_video[idx]
403
-
404
- # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
405
- input_image_np = np.array(orgin_rgb_frame) # Convert PIL Image to numpy array
406
- input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
407
-
408
- output = MOGE_MODULE.infer(input_image)
409
- #* "dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])"
410
- moge_intrinsics = output['intrinsics'].cpu().numpy()
411
- moge_mask = output['mask'].cpu().numpy()
412
- moge_depth = output['depth'].cpu().numpy()
413
 
414
- predicted_depth = np.array(predicted_depth)
415
- predicted_depth = predicted_depth.mean(-1) / 255.0
416
-
417
- metric_depth = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask)
418
-
419
- moge_intrinsics[0, 0] *= resize_W
420
- moge_intrinsics[1, 1] *= resize_H
421
- moge_intrinsics[0, 2] *= resize_W
422
- moge_intrinsics[1, 2] *= resize_H
423
-
424
- # pcd = depth2pcd(metric_depth, moge_intrinsics, color=cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB), input_mask=moge_mask, ret_pcd=True)
425
- pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=True)
426
-
427
- # pcd.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) * np.array([1, -1, -1], dtype=np.float32))
428
-
429
- apply_filter = True
430
- if apply_filter:
431
- cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=3.0)
432
- pcd = pcd.select_by_index(ind)
433
 
434
- #* save pcd: o3d.io.write_point_cloud(f'{pc_save_dir}/{timestamp}_{idx:02d}.ply', pcd)
435
- points = np.asarray(pcd.points)
436
- colors = np.asarray(pcd.colors) if pcd.has_colors() else None
 
437
 
 
438
 
439
- # ==== New: flip the point cloud upside down ====
440
- points[:, 2] = -points[:, 2]
441
- points[:, 0] = -points[:, 0]
442
- # =========================
 
 
443
 
 
 
 
 
444
 
445
- glb_filename = os.path.join(pc_save_dir, f'{timestamp}_{idx:02d}.glb')
446
- success = create_simple_glb_from_pointcloud(points, colors, glb_filename)
447
- if not success:
448
- logger.warning(f"Failed to save GLB file: {glb_filename}")
449
 
450
- glb_files.append(glb_filename)
451
 
452
- return output_path, glb_files
453
-
454
- except Exception as e:
455
- logger.error(f"Error processing video: {str(e)}")
456
- return None, f"Error: {str(e)}"
457
 
458
 
459
 
@@ -669,20 +526,16 @@ with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
669
  )
670
 
671
  def on_submit(video_file, model_size, num_inference_steps, overlap):
672
- print('on_submit is calling')
673
  logger.info('on_submit is calling')
674
-
675
  if video_file is None:
676
  return None, None, None, None, None, None, "Please upload a video file"
677
 
678
  try:
679
-
680
  output_path, glb_files = process_video(
681
  video_file, model_size, num_inference_steps, overlap
682
  )
683
 
684
 
685
-
686
  if output_path is None:
687
  return None, None, None, None, None, None, glb_files
688
 
@@ -732,23 +585,12 @@ with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
732
  if __name__ == '__main__':
733
 
734
  #* main code, model and moge model initialization
735
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
736
- logger.info(f"device = {device}")
737
- print(f"device = {device}")
738
 
739
- load_model_1_3b(device=device)
740
- load_moge_model(device=device)
741
- # torch.cuda.empty_cache()
742
- logger.info('model init done!')
743
- print('model init done!')
744
 
745
- demo.queue().launch(share = True)
746
-
747
- # demo.queue(
748
- # api_open=False,
749
- # ).launch()
750
 
 
 
751
 
752
- # server_name="0.0.0.0", server_port=7860
753
 
754
 
 
1
 
2
  import os
 
 
 
3
  import gradio as gr
4
+
5
 
6
  import numpy as np
7
  import torch
 
9
  from loguru import logger
10
  from tqdm import tqdm
11
  from tools.common_utils import save_video
12
+ from dkt.pipelines.pipeline import DKTPipeline, ModelConfig
13
+
14
 
15
  import cv2
16
  import copy
 
18
 
19
  from os.path import join
20
  from tools.depth2pcd import depth2pcd
21
+ # from moge.model.v2 import MoGeModel
22
 
23
 
24
  from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map
 
29
  import spaces
30
 
31
 
32
+
 
33
  #* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors
34
  PROMPT = 'depth'
35
  NEGATIVE_PROMPT = ''
36
 
37
+ height = 480
38
+ width = 832
39
+ window_size = 21
40
+
41
+
42
+ DKT_PIPELINE = DKTPipeline()
43
+
44
  example_inputs = [
45
 
46
  ["examples/1.mp4", "1.3B", 5, 3],
 
81
  # ["examples/b68045aa2128ab63d9c7518f8d62eafe.mp4", "1.3B", 5, 3],
82
  ]
83
 
84
 
85
  def resize_frame(frame, height, width):
86
  frame = np.array(frame)
 
243
 
244
 
245
 
246
+
247
+
248
+
249
  def process_video(
250
  video_file,
251
  model_size,
252
  num_inference_steps,
253
  overlap
254
  ):
255
+ global height
256
+ global width
257
+ global window_size
258
 
259
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
260
+ cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')
261
+
 
 
 
 
 
 
262
 
 
 
 
 
 
263
 
264
+ prediction_result = DKT_PIPELINE(video_file, prompt=PROMPT, \
265
+ negative_prompt=NEGATIVE_PROMPT,\
266
+ height=height,width=width,num_inference_steps=num_inference_steps,\
267
+ overlap=overlap, return_rgb=True)
268
+
269
+
270
+
271
 
272
+ frame_length = len(prediction_result['rgb_frames'])
273
+ vis_pc_num = 4
274
+ indices = np.linspace(0, frame_length-1, vis_pc_num)
275
+ indices = np.round(indices).astype(np.int32)
276
+
277
+
278
+ pcds = DKT_PIPELINE.prediction2pc_v2(prediction_result['depth_map'], prediction_result['rgb_frames'], indices, return_pcd=True)
279
+ glb_files = []
280
 
281
+ for idx, pcd in enumerate(pcds):
282
+ points = np.asarray(pcd.points)
283
+ colors = np.asarray(pcd.colors) if pcd.has_colors() else None
284
 
285
 
286
 
287
+ points[:, 2] = -points[:, 2]
288
+ points[:, 0] = -points[:, 0]
289
 
290
 
291
+ glb_filename = os.path.join(cur_save_dir, f'{timestamp}_{idx:02d}.glb')
292
+ success = create_simple_glb_from_pointcloud(points, colors, glb_filename)
293
+ if not success:
294
+ logger.warning(f"Failed to save GLB file: {glb_filename}")
295
 
296
+ glb_files.append(glb_filename)
297
 
298
+
299
+
300
+
301
+ #* save depth predictions video
302
+ output_filename = f"output_{timestamp}.mp4"
303
+ output_path = os.path.join(cur_save_dir, output_filename)
304
 
305
+
306
+ cap = cv2.VideoCapture(video_file)
307
+ input_fps = cap.get(cv2.CAP_PROP_FPS)
308
+ cap.release()
309
 
310
+ save_video(prediction_result['colored_depth_map'], output_path, fps=input_fps, quality=8)
311
+ return output_path, glb_files
 
 
312
 
 
313
 
 
 
 
 
 
314
 
315
 
316
 
 
526
  )
527
 
528
  def on_submit(video_file, model_size, num_inference_steps, overlap):
 
529
  logger.info('on_submit is calling')
 
530
  if video_file is None:
531
  return None, None, None, None, None, None, "Please upload a video file"
532
 
533
  try:
 
534
  output_path, glb_files = process_video(
535
  video_file, model_size, num_inference_steps, overlap
536
  )
537
 
538
 
 
539
  if output_path is None:
540
  return None, None, None, None, None, None, glb_files
541
 
 
585
  if __name__ == '__main__':
586
 
587
  #* main code, model and moge model initialization
 
 
 
588
 
 
 
 
 
 
589
 
 
 
 
 
 
590
 
591
+
592
+ demo.queue().launch(share = True)
593
 
594
+
595
 
596
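
The refactored app.py above hands all GPU-bound work to the new DKTPipeline from dkt/pipelines/pipeline.py (the renamed wan_video_new.py, diffed below), so the @spaces.GPU and @torch.inference_mode() decorators now sit on the pipeline methods rather than on process_video. A minimal sketch of the resulting call pattern, assuming the API introduced in this commit; the input path is only an example:

    import numpy as np
    from dkt.pipelines.pipeline import DKTPipeline

    # Construction loads the Wan2.1-Fun-1.3B-Control models plus MoGe once
    # (app.py does this at module level with DKT_PIPELINE = DKTPipeline()).
    pipe = DKTPipeline()

    # The heavy work happens inside DKTPipeline.__call__, which carries the
    # @spaces.GPU(duration=120) and @torch.inference_mode() decorators.
    result = pipe(
        "examples/1.mp4",        # example input video
        prompt="depth",
        negative_prompt="",
        height=480,
        width=832,
        num_inference_steps=5,
        overlap=3,
        return_rgb=True,         # also return the extracted RGB frames
    )

    depth_maps = result["depth_map"]        # per-frame maps scaled to [0, 1]
    colored = result["colored_depth_map"]   # colorized frames, ready for save_video
    frames = result["rgb_frames"]           # PIL RGB frames at the original orientation

    # Key frames -> point clouds; prediction2pc_v2 runs MoGe only on the first
    # selected frame and reuses its scale/shift for the rest.
    indices = np.round(np.linspace(0, len(frames) - 1, 4)).astype(np.int32)
    pcds = pipe.prediction2pc_v2(depth_maps, frames, indices, return_pcd=True)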
 
app_old.py ADDED
@@ -0,0 +1,756 @@
1
+
2
+ import os
3
+
4
+
5
+
6
+ import gradio as gr
7
+ # gr.set_config(schema_inference=False)
8
+
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+ from loguru import logger
13
+ from tqdm import tqdm
14
+ from tools.common_utils import save_video
15
+ from dkt.pipelines.pipeline import WanVideoPipeline, ModelConfig
16
+
17
+ # try:
18
+ # import gradio_client.utils as _gc_utils
19
+ # if hasattr(_gc_utils, "get_type"):
20
+ # _orig_get_type = _gc_utils.get_type
21
+ # def _get_type_safe(schema):
22
+ # if not isinstance(schema, dict):
23
+ # return "Any"
24
+ # return _orig_get_type(schema)
25
+ # _gc_utils.get_type = _get_type_safe
26
+ # except Exception:
27
+ # pass
28
+
29
+ # # Additional guard: handle boolean JSON Schemas and parsing errors
30
+ # try:
31
+ # import gradio_client.utils as _gc_utils
32
+ # # Wrap the internal _json_schema_to_python_type if present
33
+ # if hasattr(_gc_utils, "_json_schema_to_python_type"):
34
+ # _orig_internal = _gc_utils._json_schema_to_python_type
35
+ # def _json_schema_to_python_type_safe(schema, defs=None):
36
+ # if isinstance(schema, bool):
37
+ # return "Any"
38
+ # try:
39
+ # return _orig_internal(schema, defs)
40
+ # except Exception:
41
+ # return "Any"
42
+ # _gc_utils._json_schema_to_python_type = _json_schema_to_python_type_safe
43
+
44
+ # # Also wrap the public json_schema_to_python_type to be extra defensive
45
+ # if hasattr(_gc_utils, "json_schema_to_python_type"):
46
+ # _orig_public = _gc_utils.json_schema_to_python_type
47
+ # def json_schema_to_python_type_safe(schema):
48
+ # try:
49
+ # return _orig_public(schema)
50
+ # except Exception:
51
+ # return "Any"
52
+ # _gc_utils.json_schema_to_python_type = json_schema_to_python_type_safe
53
+ # except Exception:
54
+ # pass
55
+
56
+ import cv2
57
+ import copy
58
+ import trimesh
59
+
60
+ from os.path import join
61
+ from tools.depth2pcd import depth2pcd
62
+ from moge.model.v2 import MoGeModel
63
+
64
+
65
+ from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map
66
+ import glob
67
+ import datetime
68
+ import shutil
69
+ import tempfile
70
+ import spaces
71
+
72
+
73
+ PIPE_1_3B = None
74
+ MOGE_MODULE = None
75
+ #* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors
76
+ PROMPT = 'depth'
77
+ NEGATIVE_PROMPT = ''
78
+
79
+ example_inputs = [
80
+
81
+ ["examples/1.mp4", "1.3B", 5, 3],
82
+ ["examples/33.mp4", "1.3B", 5, 3],
83
+
84
+
85
+
86
+ ["examples/7.mp4", "1.3B", 5, 3],
87
+ ["examples/8.mp4", "1.3B", 5, 3],
88
+ ["examples/9.mp4", "1.3B", 5, 3],
89
+
90
+ # ["examples/178db6e89ab682bfc612a3290fec58dd.mp4", "1.3B", 5, 3],
91
+ ["examples/36.mp4", "1.3B", 5, 3],
92
+ ["examples/39.mp4", "1.3B", 5, 3],
93
+
94
+ # ["examples/b1f1fa44f414d7731cd7d77751093c44.mp4", "1.3B", 5, 3],
95
+
96
+ ["examples/10.mp4", "1.3B", 5, 3],
97
+ ["examples/30.mp4", "1.3B", 5, 3],
98
+ ["examples/3.mp4", "1.3B", 5, 3],
99
+
100
+ ["examples/32.mp4", "1.3B", 5, 3],
101
+
102
+ ["examples/35.mp4", "1.3B", 5, 3],
103
+
104
+ ["examples/40.mp4", "1.3B", 5, 3],
105
+ ["examples/2.mp4", "1.3B", 5, 3],
106
+
107
+ # ["examples/31.mp4", "1.3B", 5, 3],
108
+ # ["examples/DJI_20250912164311_0007_D.mp4", "1.3B", 5, 3],
109
+ # ["examples/DJI_20250912163642_0003_D.mp4", "1.3B", 5, 3],
110
+
111
+ # ["examples/5.mp4", "1.3B", 5, 3],
112
+
113
+ # ["examples/1b0daeb776471c7389b36cee53049417.mp4", "1.3B", 5, 3],
114
+ # ["examples/8a6dfb8cfe80634f4f77ae9aa830d075.mp4", "1.3B", 5, 3],
115
+ # ["examples/69230f105ad8740e08d743a8ee11c651.mp4", "1.3B", 5, 3],
116
+ # ["examples/b68045aa2128ab63d9c7518f8d62eafe.mp4", "1.3B", 5, 3],
117
+ ]
118
+
119
+
120
+
121
+
122
+
123
+ height = 480
124
+ width = 832
125
+ window_size = 21
126
+
127
+
128
+
129
+ def resize_frame(frame, height, width):
130
+ frame = np.array(frame)
131
+ frame = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
132
+ frame = torch.nn.functional.interpolate(frame, (height, width), mode="bicubic", align_corners=False, antialias=True)
133
+ frame = (frame.squeeze(0).permute(1, 2, 0).clamp(0, 1) * 255).byte().numpy()
134
+ frame = Image.fromarray(frame)
135
+ return frame
136
+
137
+
138
+
139
+ def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
140
+ pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
141
+ pts_rgb = frame[valid_mask]
142
+
143
+ # Initialize a 3D scene
144
+ scene_3d = trimesh.Scene()
145
+
146
+ # Add point cloud data to the scene
147
+ point_cloud_data = trimesh.PointCloud(
148
+ vertices=pts_3d, colors=pts_rgb
149
+ )
150
+
151
+ scene_3d.add_geometry(point_cloud_data)
152
+ return scene_3d
153
+
154
+
155
+
156
+ def create_simple_glb_from_pointcloud(points, colors, glb_filename):
157
+ try:
158
+ if len(points) == 0:
159
+ logger.warning(f"No valid points to create GLB for {glb_filename}")
160
+ return False
161
+
162
+ if colors is not None:
163
+ # logger.info(f"Adding colors to GLB: shape={colors.shape}, range=[{colors.min():.3f}, {colors.max():.3f}]")
164
+ pts_rgb = colors
165
+ else:
166
+ logger.info("No colors provided, adding default white colors")
167
+ pts_rgb = np.ones((len(points), 3))
168
+
169
+ valid_mask = np.ones(len(points), dtype=bool)
170
+
171
+ scene_3d = pmap_to_glb(points, valid_mask, pts_rgb)
172
+
173
+ scene_3d.export(glb_filename)
174
+ # logger.info(f"Saved GLB file using trimesh: {glb_filename}")
175
+
176
+ return True
177
+
178
+ except Exception as e:
179
+ logger.error(f"Error creating GLB from pointcloud using trimesh: {str(e)}")
180
+ return False
181
+
182
+
183
+
184
+
185
+
186
+ def extract_frames_from_video_file(video_path):
187
+ try:
188
+ cap = cv2.VideoCapture(video_path)
189
+ frames = []
190
+
191
+ fps = cap.get(cv2.CAP_PROP_FPS)
192
+ if fps <= 0:
193
+ fps = 15.0
194
+
195
+ while True:
196
+ ret, frame = cap.read()
197
+ if not ret:
198
+ break
199
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
200
+ frame_rgb = Image.fromarray(frame_rgb)
201
+ frames.append(frame_rgb)
202
+
203
+ cap.release()
204
+ return frames, fps
205
+ except Exception as e:
206
+ logger.error(f"Error extracting frames from {video_path}: {str(e)}")
207
+ return [], 15.0
208
+
209
+
210
+
211
+ def load_moge_model(device="cuda:0"):
212
+ global MOGE_MODULE
213
+ if MOGE_MODULE is not None:
214
+ return MOGE_MODULE
215
+ logger.info(f"Loading MoGe model on {device}...")
216
+ MOGE_MODULE = MoGeModel.from_pretrained('Ruicheng/moge-2-vitl-normal').to(device)
217
+ return MOGE_MODULE
218
+
219
+
220
+ def load_model_1_3b(device="cuda:0"):
221
+ global PIPE_1_3B
222
+
223
+ if PIPE_1_3B is not None:
224
+ return PIPE_1_3B
225
+
226
+ logger.info(f"Loading 1.3B model on {device}...")
227
+
228
+ pipe = WanVideoPipeline.from_pretrained(
229
+ torch_dtype=torch.bfloat16,
230
+ device=device,
231
+ model_configs=[
232
+ ModelConfig(
233
+ model_id="PAI/Wan2.1-Fun-1.3B-Control",
234
+ origin_file_pattern="diffusion_pytorch_model*.safetensors",
235
+ offload_device="cpu",
236
+ ),
237
+ ModelConfig(
238
+ model_id="PAI/Wan2.1-Fun-1.3B-Control",
239
+ origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
240
+ offload_device="cpu",
241
+ ),
242
+ ModelConfig(
243
+ model_id="PAI/Wan2.1-Fun-1.3B-Control",
244
+ origin_file_pattern="Wan2.1_VAE.pth",
245
+ offload_device="cpu",
246
+ ),
247
+ ModelConfig(
248
+ model_id="PAI/Wan2.1-Fun-1.3B-Control",
249
+ origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
250
+ offload_device="cpu",
251
+ ),
252
+ ],
253
+ training_strategy="origin",
254
+ )
255
+
256
+
257
+ lora_config = ModelConfig(
258
+ model_id="Daniellesry/DKT-Depth-1-3B",
259
+ origin_file_pattern="dkt-1-3B.safetensors",
260
+ offload_device="cpu",
261
+ )
262
+
263
+ lora_config.download_if_necessary(use_usp=False)
264
+
265
+ pipe.load_lora(pipe.dit, lora_config.path, alpha=1.0)  # TODO: verify the LoRA weights are actually applied
266
+ pipe.enable_vram_management()
267
+
268
+
269
+ PIPE_1_3B = pipe
270
+
271
+ return pipe
272
+
273
+
274
+
275
+
276
+
277
+ def get_model(model_size):
278
+ if model_size == "1.3B":
279
+ assert PIPE_1_3B is not None, "1.3B model not initialized"
280
+ return PIPE_1_3B
281
+ else:
282
+ raise ValueError(f"Unsupported model size: {model_size}")
283
+
284
+
285
+
286
+
287
+
288
+
289
+
290
+
291
+
292
+ @spaces.GPU(duration=120)
293
+ @torch.inference_mode()
294
+ def process_video(
295
+ video_file,
296
+ model_size,
297
+ num_inference_steps,
298
+ overlap
299
+ ):
300
+
301
+ pipe = get_model(model_size)
302
+
303
+ if pipe is None:
304
+ return None, f"Model {model_size} not initialized. Please restart the application."
305
+
306
+ tmp_video_path = video_file
307
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
308
+
309
+
310
+ cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')
311
+ origin_frames, input_fps = extract_frames_from_video_file(tmp_video_path)
312
+
313
+ if not origin_frames:
314
+ return None, "Failed to extract frames from video"
315
+
316
+ logger.info(f"Extracted {len(origin_frames)} frames from video")
317
+
318
+ original_width, original_height = origin_frames[0].size
319
+ ROTATE = False
320
+ if original_width < original_height:
321
+ ROTATE = True
322
+ origin_frames = [x.transpose(Image.ROTATE_90) for x in origin_frames]
323
+ tmp = original_width
324
+ original_width = original_height
325
+ original_height = tmp
326
+
327
+
328
+
329
+ global height
330
+ global width
331
+ global window_size
332
+
333
+ frames = [resize_frame(frame, height, width) for frame in origin_frames]
334
+ frame_length = len(frames)
335
+ if (frame_length - 1) % 4 != 0:
336
+ new_len = ((frame_length - 1) // 4 + 1) * 4 + 1
337
+ frames = frames + [copy.deepcopy(frames[-1]) for _ in range(new_len - frame_length)]
338
+
339
+
340
+ control_video = frames
341
+ video, vae_outs = pipe(
342
+ prompt=PROMPT,
343
+ negative_prompt=NEGATIVE_PROMPT,
344
+ control_video=control_video,
345
+ height=height,
346
+ width=width,
347
+ num_frames=len(control_video),
348
+ seed=1,
349
+ tiled=False,
350
+ num_inference_steps=num_inference_steps,
351
+ sliding_window_size=window_size,
352
+ sliding_window_stride=window_size - overlap,
353
+ cfg_scale=1.0,
354
+ )
355
+
356
+ #* moge process
357
+ torch.cuda.empty_cache()
358
+ processed_video = video[:frame_length]
359
+
360
+
361
+ processed_video = [resize_frame(frame, original_height, original_width) for frame in processed_video]
362
+ if ROTATE:
363
+ processed_video = [x.transpose(Image.ROTATE_270) for x in processed_video]
364
+ origin_frames = [x.transpose(Image.ROTATE_270) for x in origin_frames]
365
+
366
+
367
+ color_predictions = []
368
+ if PROMPT == 'depth':
369
+ prediced_depth_map_np = [np.array(item).astype(np.float32).mean(-1) for item in processed_video]
370
+ prediced_depth_map_np = np.stack(prediced_depth_map_np)
371
+ prediced_depth_map_np = prediced_depth_map_np/ 255.0
372
+ __min = prediced_depth_map_np.min()
373
+ __max = prediced_depth_map_np.max()
374
+ prediced_depth_map_np = (prediced_depth_map_np - __min) / (__max - __min)
375
+ color_predictions = [colorize_depth_map(item) for item in prediced_depth_map_np]
376
+ else:
377
+ color_predictions = processed_video
378
+
379
+
380
+
381
+
382
+
383
+ #* required parameters for MoGe
384
+
385
+ # todo, inference MoGe only once
386
+
387
+ resize_W,resize_H = origin_frames[0].size
388
+
389
+ vis_pc_num = 4
390
+ indices = np.linspace(0, frame_length-1, vis_pc_num)
391
+ indices = np.round(indices).astype(np.int32)
392
+ pc_save_dir = os.path.join(cur_save_dir, 'pointclouds')
393
+ os.makedirs(pc_save_dir, exist_ok=True)
394
+
395
+ glb_files = []
396
+ moge_device = MOGE_MODULE.device if MOGE_MODULE is not None else torch.device("cuda:0")
397
+
398
+ for idx in tqdm(indices):
399
+ orgin_rgb_frame = origin_frames[idx]
400
+ predicted_depth = processed_video[idx]
401
+
402
+ # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
403
+ input_image_np = np.array(orgin_rgb_frame) # Convert PIL Image to numpy array
404
+ input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
405
+
406
+ output = MOGE_MODULE.infer(input_image)
407
+ #* "dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])"
408
+ moge_intrinsics = output['intrinsics'].cpu().numpy()
409
+ moge_mask = output['mask'].cpu().numpy()
410
+ moge_depth = output['depth'].cpu().numpy()
411
+
412
+ predicted_depth = np.array(predicted_depth)
413
+ predicted_depth = predicted_depth.mean(-1) / 255.0
414
+
415
+ metric_depth = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask)
416
+
417
+ moge_intrinsics[0, 0] *= resize_W
418
+ moge_intrinsics[1, 1] *= resize_H
419
+ moge_intrinsics[0, 2] *= resize_W
420
+ moge_intrinsics[1, 2] *= resize_H
421
+
422
+ # pcd = depth2pcd(metric_depth, moge_intrinsics, color=cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB), input_mask=moge_mask, ret_pcd=True)
423
+ pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=True)
424
+
425
+ # pcd.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) * np.array([1, -1, -1], dtype=np.float32))
426
+
427
+ apply_filter = True
428
+ if apply_filter:
429
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=3.0)
430
+ pcd = pcd.select_by_index(ind)
431
+
432
+ #* save pcd: o3d.io.write_point_cloud(f'{pc_save_dir}/{timestamp}_{idx:02d}.ply', pcd)
433
+ points = np.asarray(pcd.points)
434
+ colors = np.asarray(pcd.colors) if pcd.has_colors() else None
435
+
436
+
437
+ # ==== New: flip the point cloud upside down ====
438
+ points[:, 2] = -points[:, 2]
439
+ points[:, 0] = -points[:, 0]
440
+ # =========================
441
+
442
+
443
+ glb_filename = os.path.join(pc_save_dir, f'{timestamp}_{idx:02d}.glb')
444
+ success = create_simple_glb_from_pointcloud(points, colors, glb_filename)
445
+ if not success:
446
+ logger.warning(f"Failed to save GLB file: {glb_filename}")
447
+
448
+ glb_files.append(glb_filename)
449
+
450
+
451
+
452
+ #* save depth predictions video
453
+ output_filename = f"output_{timestamp}.mp4"
454
+ output_path = os.path.join(cur_save_dir, output_filename)
455
+ save_video(color_predictions, output_path, fps=input_fps, quality=5)
456
+ return output_path, glb_files
457
+
458
+
459
+
460
+
461
+
462
+
463
+ #* gradio creation and initialization
464
+
465
+
466
+ css = """
467
+ #video-display-container {
468
+ max-height: 100vh;
469
+ }
470
+ #video-display-input {
471
+ max-height: 80vh;
472
+ }
473
+ #video-display-output {
474
+ max-height: 80vh;
475
+ }
476
+ #download {
477
+ height: 62px;
478
+ }
479
+ .title {
480
+ text-align: center;
481
+ }
482
+ .description {
483
+ text-align: center;
484
+ }
485
+ .gradio-examples {
486
+ max-height: 400px;
487
+ overflow-y: auto;
488
+ }
489
+ .gradio-examples .examples-container {
490
+ display: grid;
491
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
492
+ gap: 10px;
493
+ padding: 10px;
494
+ }
495
+ .gradio-container .gradio-examples .pagination,
496
+ .gradio-container .gradio-examples .pagination button,
497
+ div[data-testid="examples"] .pagination,
498
+ div[data-testid="examples"] .pagination button {
499
+ font-size: 28px !important;
500
+ font-weight: bold !important;
501
+ padding: 15px 20px !important;
502
+ min-width: 60px !important;
503
+ height: 60px !important;
504
+ border-radius: 10px !important;
505
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
506
+ color: white !important;
507
+ border: none !important;
508
+ cursor: pointer !important;
509
+ margin: 8px !important;
510
+ display: inline-block !important;
511
+ box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
512
+ transition: all 0.3s ease !important;
513
+ }
514
+
515
+ div[data-testid="examples"] .pagination button:not(.active),
516
+ .gradio-container .gradio-examples .pagination button:not(.active) {
517
+ font-size: 32px !important;
518
+ font-weight: bold !important;
519
+ padding: 15px 20px !important;
520
+ min-width: 60px !important;
521
+ height: 60px !important;
522
+ background: linear-gradient(135deg, #8a9cf0 0%, #9a6bb2 100%) !important;
523
+ opacity: 0.8 !important;
524
+ }
525
+
526
+ div[data-testid="examples"] .pagination button:hover,
527
+ .gradio-container .gradio-examples .pagination button:hover {
528
+ background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%) !important;
529
+ transform: translateY(-2px) !important;
530
+ box-shadow: 0 6px 12px rgba(0,0,0,0.3) !important;
531
+ opacity: 1 !important;
532
+ }
533
+
534
+ div[data-testid="examples"] .pagination button.active,
535
+ .gradio-container .gradio-examples .pagination button.active {
536
+ background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%) !important;
537
+ box-shadow: 0 4px 8px rgba(17,153,142,0.4) !important;
538
+ opacity: 1 !important;
539
+ }
540
+
541
+ button[class*="pagination"],
542
+ button[class*="page"] {
543
+ font-size: 28px !important;
544
+ font-weight: bold !important;
545
+ padding: 15px 20px !important;
546
+ min-width: 60px !important;
547
+ height: 60px !important;
548
+ border-radius: 10px !important;
549
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
550
+ color: white !important;
551
+ border: none !important;
552
+ cursor: pointer !important;
553
+ margin: 8px !important;
554
+ box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
555
+ transition: all 0.3s ease !important;
556
+ }
557
+ """
558
+
559
+
560
+
561
+ head_html = """
562
+ <link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Ctext y='.9em' font-size='90'%3E🦾%3C/text%3E%3C/svg%3E">
563
+ <link rel="shortcut icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Ctext y='.9em' font-size='90'%3E🦾%3C/text%3E%3C/svg%3E">
564
+ <link rel="icon" type="image/png" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Ctext y='.9em' font-size='90'%3E🦾%3C/text%3E%3C/svg%3E">
565
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
566
+ """
567
+
568
+
569
+
570
+ # description = """Official demo for **DKT **."""
571
+
572
+ # with gr.Blocks(css=css, title="DKT - Diffusion Knows Transparency", favicon_path="favicon.ico") as demo:
573
+
574
+
575
+ with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
576
+ # gr.Markdown(title, elem_classes=["title"])
577
+ """
578
+
579
+ <a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
580
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
581
+ </a>
582
+ <a title="arXiv" href="https://arxiv.org/abs/2406.16864" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
583
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
584
+ </a>
585
+ <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
586
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
587
+ </a>
588
+
589
+
590
+ """
591
+
592
+ gr.Markdown(
593
+ """
594
+ # Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation
595
+ <p align="center">
596
+ <a title="Github" href="https://github.com/Daniellli/DKT" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
597
+ <img src="https://img.shields.io/github/stars/Daniellli/DKT?style=social" alt="badge-github-stars">
598
+ </a>
599
+ """
600
+ )
601
+ # gr.Markdown(description, elem_classes=["description"])
602
+ # gr.Markdown("### Video Processing Demo", elem_classes=["description"])
603
+
604
+ with gr.Row():
605
+ with gr.Column():
606
+ input_video = gr.Video(label="Input Video", elem_id='video-display-input')
607
+
608
+ model_size = gr.Radio(
609
+ # choices=["1.3B", "14B"],
610
+ choices=["1.3B"],
611
+ value="1.3B",
612
+ label="Model Size"
613
+ )
614
+
615
+
616
+ with gr.Accordion("Advanced Parameters", open=False):
617
+ num_inference_steps = gr.Slider(
618
+ minimum=1, maximum=50, value=5, step=1,
619
+ label="Number of Inference Steps"
620
+ )
621
+ overlap = gr.Slider(
622
+ minimum=1, maximum=20, value=3, step=1,
623
+ label="Overlap"
624
+ )
625
+
626
+ submit = gr.Button(value="Compute Depth", variant="primary")
627
+
628
+ with gr.Column():
629
+ output_video = gr.Video(
630
+ label="Depth Outputs",
631
+ elem_id='video-display-output',
632
+ autoplay=True
633
+ )
634
+ vis_video = gr.Video(
635
+ label="Visualization Video",
636
+ visible=False,
637
+ autoplay=True
638
+ )
639
+
640
+ with gr.Row():
641
+ gr.Markdown("### 3D Point Cloud Visualization", elem_classes=["title"])
642
+
643
+ with gr.Row(equal_height=True):
644
+ with gr.Column(scale=1):
645
+ output_point_map0 = gr.Model3D(
646
+ label="Point Cloud Key Frame 1",
647
+ clear_color=[1.0, 1.0, 1.0, 1.0],
648
+ interactive=False,
649
+ )
650
+ with gr.Column(scale=1):
651
+ output_point_map1 = gr.Model3D(
652
+ label="Point Cloud Key Frame 2",
653
+ clear_color=[1.0, 1.0, 1.0, 1.0],
654
+ interactive=False
655
+ )
656
+
657
+
658
+ with gr.Row(equal_height=True):
659
+
660
+ with gr.Column(scale=1):
661
+ output_point_map2 = gr.Model3D(
662
+ label="Point Cloud Key Frame 3",
663
+ clear_color=[1.0, 1.0, 1.0, 1.0],
664
+ interactive=False
665
+ )
666
+ with gr.Column(scale=1):
667
+ output_point_map3 = gr.Model3D(
668
+ label="Point Cloud Key Frame 4",
669
+ clear_color=[1.0, 1.0, 1.0, 1.0],
670
+ interactive=False
671
+ )
672
+
673
+ def on_submit(video_file, model_size, num_inference_steps, overlap):
674
+ print('on_submit is calling')
675
+ logger.info('on_submit is calling')
676
+
677
+ if video_file is None:
678
+ return None, None, None, None, None, None, "Please upload a video file"
679
+
680
+ try:
681
+
682
+ output_path, glb_files = process_video(
683
+ video_file, model_size, num_inference_steps, overlap
684
+ )
685
+
686
+
687
+
688
+ if output_path is None:
689
+ return None, None, None, None, None, None, glb_files
690
+
691
+ model3d_outputs = [None] * 4
692
+ if glb_files:
693
+ for i, glb_file in enumerate(glb_files[:4]):
694
+ if os.path.exists(glb_file):
695
+ model3d_outputs[i] = glb_file
696
+
697
+
698
+
699
+ return output_path, None, *model3d_outputs
700
+
701
+ except Exception as e:
702
+ logger.error(e)
703
+ return None, None, None, None, None, None
704
+
705
+
706
+ submit.click(
707
+ on_submit,
708
+ inputs=[
709
+ input_video, model_size, num_inference_steps, overlap
710
+ ],
711
+ outputs=[
712
+ output_video, vis_video, output_point_map0, output_point_map1, output_point_map2, output_point_map3
713
+ ]
714
+ )
715
+
716
+
717
+
718
+ logger.info(f'there are {len(example_inputs)} demo files')
719
+ print(f'there are {len(example_inputs)} demo files')
720
+
721
+ examples = gr.Examples(
722
+ examples=example_inputs,
723
+ inputs=[input_video, model_size, num_inference_steps, overlap],
724
+ outputs=[
725
+ output_video, vis_video,
726
+ output_point_map0, output_point_map1, output_point_map2, output_point_map3
727
+ ],
728
+ fn=on_submit,
729
+ examples_per_page=12,
730
+ cache_examples=False
731
+ )
732
+
733
+
734
+ if __name__ == '__main__':
735
+
736
+ #* main code, model and moge model initialization
737
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
738
+ logger.info(f"device = {device}")
739
+ print(f"device = {device}")
740
+
741
+ load_model_1_3b(device=device)
742
+ load_moge_model(device=device)
743
+ # torch.cuda.empty_cache()
744
+ logger.info('model init done!')
745
+ print('model init done!')
746
+
747
+ demo.queue().launch(share = True)
748
+
749
+ # demo.queue(
750
+ # api_open=False,
751
+ # ).launch()
752
+
753
+
754
+ # server_name="0.0.0.0", server_port=7860
755
+
756
+
dkt/pipelines/{wan_video_new.py → pipeline.py} RENAMED
@@ -29,7 +29,7 @@ from ..lora import GeneralLoRALoader
29
 
30
  from loguru import logger
31
 
32
-
33
 
34
  class BasePipeline(torch.nn.Module):
35
 
@@ -222,7 +222,7 @@ class ModelConfig:
222
  allow_patterns=allow_file_pattern,
223
  ignore_patterns=downloaded_files if downloaded_files else None
224
  )
225
-
226
  # Let rank 1, 2, ... wait for rank 0
227
  if use_usp:
228
  import torch.distributed as dist
@@ -716,6 +716,323 @@ class WanVideoPipeline(BasePipeline):
716
 
717
 
718
 
719
 
720
 
721
 
@@ -1480,7 +1797,4 @@ def model_fn_wan_video(
1480
 
1481
  #* unpatchify, from [1, ( (F-1)/4 * H/16 * W/16), 64] to [1, 16, (F-1)/4, H/8, W/8]
1482
  x = dit.unpatchify(x, (f, h, w))
1483
- return x
1484
-
1485
-
1486
-
 
29
 
30
  from loguru import logger
31
 
32
+ import spaces
33
 
34
  class BasePipeline(torch.nn.Module):
35
 
 
222
  allow_patterns=allow_file_pattern,
223
  ignore_patterns=downloaded_files if downloaded_files else None
224
  )
225
+
226
  # Let rank 1, 2, ... wait for rank 0
227
  if use_usp:
228
  import torch.distributed as dist
 
716
 
717
 
718
 
719
+ def extract_frames_from_video_file(video_path):
720
+ try:
721
+ cap = cv2.VideoCapture(video_path)
722
+ frames = []
723
+
724
+ fps = cap.get(cv2.CAP_PROP_FPS)
725
+ if fps <= 0:
726
+ fps = 15.0
727
+
728
+ while True:
729
+ ret, frame = cap.read()
730
+ if not ret:
731
+ break
732
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
733
+ frame_rgb = Image.fromarray(frame_rgb)
734
+ frames.append(frame_rgb)
735
+
736
+ cap.release()
737
+ return frames, fps
738
+ except Exception as e:
739
+ logger.error(f"Error extracting frames from {video_path}: {str(e)}")
740
+ return [], 15.0
741
+
742
+
743
+ def resize_frame(frame, height, width):
744
+ frame = np.array(frame)
745
+ frame = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
746
+ frame = torch.nn.functional.interpolate(frame, (height, width), mode="bicubic", align_corners=False, antialias=True)
747
+ frame = (frame.squeeze(0).permute(1, 2, 0).clamp(0, 1) * 255).byte().numpy()
748
+ frame = Image.fromarray(frame)
749
+ return frame
750
+
751
+
752
+
753
+ from moge.model.v2 import MoGeModel
754
+ from tools.eval_utils import transfer_pred_disp2depth, transfer_pred_disp2depth_v2, colorize_depth_map
755
+ from tools.depth2pcd import depth2pcd
756
+ import cv2, copy
757
+
758
+ class DKTPipeline:
759
+ def __init__(self, ):
760
+
761
+ self.main_pipe = self.init_model()
762
+
763
+ self.moge_pipe = self.load_moge_model()
764
+
765
+
766
+
767
+
768
+
769
+ def init_model(self ):
770
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
771
+
772
+ pipe = WanVideoPipeline.from_pretrained(
773
+ torch_dtype=torch.bfloat16,
774
+ device=device,
775
+ model_configs=[
776
+ ModelConfig(
777
+ model_id="PAI/Wan2.1-Fun-1.3B-Control",
778
+ origin_file_pattern="diffusion_pytorch_model*.safetensors",
779
+ offload_device="cpu",
780
+ ),
781
+ ModelConfig(
782
+ model_id="PAI/Wan2.1-Fun-1.3B-Control",
783
+ origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
784
+ offload_device="cpu",
785
+ ),
786
+ ModelConfig(
787
+ model_id="PAI/Wan2.1-Fun-1.3B-Control",
788
+ origin_file_pattern="Wan2.1_VAE.pth",
789
+ offload_device="cpu",
790
+ ),
791
+ ModelConfig(
792
+ model_id="PAI/Wan2.1-Fun-1.3B-Control",
793
+ origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
794
+ offload_device="cpu",
795
+ ),
796
+ ],
797
+ training_strategy="origin",
798
+ )
799
+
800
+
801
+ lora_config = ModelConfig(
802
+ model_id="Daniellesry/DKT-Depth-1-3B",
803
+ origin_file_pattern="dkt-1-3B.safetensors",
804
+ offload_device="cpu",
805
+ )
806
+ lora_config.download_if_necessary(use_usp=False)
807
+
808
+ pipe.load_lora(pipe.dit, lora_config.path, alpha=1.0)#todo is it work?
809
+ pipe.enable_vram_management()
810
+ return pipe
811
+
812
+ def load_moge_model(self):
813
+ device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
814
+
815
+ cached_model_path = 'checkpoints/moge_ckpt/moge-2-vitl-normal/model.pt'
816
+ if os.path.exists(cached_model_path):
817
+ logger.info(f"Found cached model at {cached_model_path}, loading from cache...")
818
+ moge_pipe = MoGeModel.from_pretrained(cached_model_path).to(device)
819
+ else:
820
+ logger.info(f"Cache not found at {cached_model_path}, downloading from HuggingFace...")
821
+ os.makedirs(os.path.dirname(cached_model_path), exist_ok=True)
822
+ moge_pipe = MoGeModel.from_pretrained('Ruicheng/moge-2-vitl-normal', cache_dir=os.path.dirname(cached_model_path)).to(device)
823
+
824
+ return moge_pipe
825
+
826
+
827
+
828
+ @spaces.GPU(duration=120)
829
+ @torch.inference_mode()
830
+ def __call__(self, video_file, prompt='depth', \
831
+ negative_prompt='', height=480, width=832, \
832
+ num_inference_steps=5, window_size=21, \
833
+ overlap=3, vis_pc = False, return_rgb = False):
834
+
835
+
836
+ origin_frames, input_fps = extract_frames_from_video_file(video_file)
837
+
838
+ frame_length = len(origin_frames)
839
+
840
+ original_width, original_height = origin_frames[0].size
841
+
842
+ ROTATE = False
843
+ if original_width < original_height:#* ensure the width is the longer side
844
+ ROTATE = True
845
+ origin_frames = [x.transpose(Image.ROTATE_90) for x in origin_frames]
846
+ tmp = original_width
847
+ original_width = original_height
848
+ original_height = tmp
849
+
850
+
851
+ frames = [resize_frame(frame, height, width) for frame in origin_frames]
852
+
853
+
854
+ if (frame_length - 1) % 4 != 0:
855
+ new_len = ((frame_length - 1) // 4 + 1) * 4 + 1
856
+ frames = frames + [copy.deepcopy(frames[-1]) for _ in range(new_len - frame_length)]
857
+
858
+
859
+
860
+
861
+ video, vae_outs = self.main_pipe(
862
+ prompt=prompt,
863
+ negative_prompt=negative_prompt,
864
+ control_video=frames,
865
+ height=height,
866
+ width=width,
867
+ num_frames=len(frames),
868
+ seed=1,
869
+ tiled=False,
870
+ num_inference_steps=num_inference_steps,
871
+ sliding_window_size=window_size,
872
+ sliding_window_stride=window_size - overlap,
873
+ cfg_scale=1.0,
874
+ )
875
+ torch.cuda.empty_cache()
876
+
877
+ processed_video = video[:frame_length]
878
+ processed_video = [resize_frame(frame, original_height, original_width) for frame in processed_video]
879
+
880
+ if ROTATE:
881
+ processed_video = [x.transpose(Image.ROTATE_270) for x in processed_video]
882
+ origin_frames = [x.transpose(Image.ROTATE_270) for x in origin_frames]
883
+
884
+ color_predictions = []
885
+ if prompt == 'depth':
886
+ prediced_depth_map_np = [np.array(item).astype(np.float32).mean(-1) for item in processed_video]
887
+ prediced_depth_map_np = np.stack(prediced_depth_map_np)
888
+ prediced_depth_map_np = prediced_depth_map_np / 255.0
889
+
890
+ __min = prediced_depth_map_np.min()
891
+ __max = prediced_depth_map_np.max()
892
+
893
+ prediced_depth_map_np_normalized = (prediced_depth_map_np - __min) / (__max - __min)
894
+ color_predictions = [colorize_depth_map(item) for item in prediced_depth_map_np_normalized]
895
+ else:
896
+ color_predictions = processed_video
897
+
898
+ return_dict = {}
899
+
900
+ return_dict['depth_map'] = prediced_depth_map_np
901
+ return_dict['colored_depth_map'] = color_predictions
902
+
903
+
904
+
905
+ if vis_pc and prompt == 'depth':
906
+ vis_pc_num = 4
907
+ indices = np.linspace(0, frame_length-1, vis_pc_num)
908
+ indices = np.round(indices).astype(np.int32)
909
+ return_dict['point_clouds'] = self.prediction2pc(prediced_depth_map_np, origin_frames, indices)
910
+
911
+ if return_rgb:
912
+ return_dict['rgb_frames'] = origin_frames
913
+ return return_dict
914
+
915
+
916
+
917
+
918
+
919
+ def prediction2pc(self, prediction_depth_map, RGB_frames, indices, return_pcd = True,nb_neighbors = 20, std_ratio = 3.0):
920
+ resize_W,resize_H = RGB_frames[0].size
921
+ pcds = []
922
+ moge_device = self.moge_pipe.device if self.moge_pipe is not None else torch.device("cuda:0")
923
+
924
+ for idx in tqdm(indices):
925
+ orgin_rgb_frame = RGB_frames[idx]
926
+ predicted_depth = prediction_depth_map[idx]
927
+
928
+ # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
929
+ input_image_np = np.array(orgin_rgb_frame) # Convert PIL Image to numpy array
930
+ input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
931
+ output = self.moge_pipe.infer(input_image)
932
+
933
+ #* "dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])"
934
+ moge_intrinsics = output['intrinsics'].cpu().numpy()
935
+ moge_mask = output['mask'].cpu().numpy()
936
+ moge_depth = output['depth'].cpu().numpy()
937
+
938
+
939
+ metric_depth = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask)
940
+
941
+ moge_intrinsics[0, 0] *= resize_W
942
+ moge_intrinsics[1, 1] *= resize_H
943
+ moge_intrinsics[0, 2] *= resize_W
944
+ moge_intrinsics[1, 2] *= resize_H
945
+
946
+ pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=return_pcd)
947
+
948
+ if return_pcd:
949
+ #* [15,50], [2,3]
950
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
951
+ pcd = pcd.select_by_index(ind)
952
+ #todo downsample
953
+
954
+ pcds.append(pcd)
955
+
956
+ return pcds
957
+
958
+
959
+
960
+
961
+ @spaces.GPU()
962
+ @torch.inference_mode()
963
+ def moge_infer(self, input_image):
964
+ return self.moge_pipe.infer(input_image)
965
+
966
+
967
+
968
+ def prediction2pc_v2(self, prediction_depth_map, RGB_frames, indices, return_pcd = True,nb_neighbors = 20, std_ratio = 3.0):
969
+ """
970
+ call MoGe once
971
+ """
972
+ resize_W,resize_H = RGB_frames[0].size
973
+ pcds = []
974
+ moge_device = self.moge_pipe.device if self.moge_pipe is not None else torch.device("cuda:0")
975
+
976
+ for iidx, idx in enumerate(tqdm(indices)):
977
+
978
+ orgin_rgb_frame = RGB_frames[idx]
979
+ predicted_depth = prediction_depth_map[idx]
980
+ input_image_np = np.array(orgin_rgb_frame) # Convert PIL Image to numpy array
981
+
982
+
983
+ if iidx == 0:
984
+ # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
985
+ input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
986
+ output = self.moge_infer(input_image)
987
+
988
+ #* "dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])"
989
+ moge_intrinsics = output['intrinsics'].cpu().numpy()
990
+ moge_mask = output['mask'].cpu().numpy()
991
+ moge_depth = output['depth'].cpu().numpy()
992
+
993
+ metric_depth, scale, shift = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask, return_scale_shift=True)
994
+
995
+ moge_intrinsics[0, 0] *= resize_W
996
+ moge_intrinsics[1, 1] *= resize_H
997
+ moge_intrinsics[0, 2] *= resize_W
998
+ moge_intrinsics[1, 2] *= resize_H
999
+ else:
1000
+ metric_depth = transfer_pred_disp2depth_v2(predicted_depth, scale, shift)
1001
+
1002
+
1003
+ pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=return_pcd)
1004
+
1005
+ if return_pcd:
1006
+ #* [15,50], [2,3]
1007
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
1008
+ pcd = pcd.select_by_index(ind)
1009
+ #todo downsample
1010
+
1011
+ pcds.append(pcd)
1012
+
1013
+ return pcds
1014
+
1015
+
1016
+
1017
+
1018
+
1019
+
1020
+
1021
+
1022
+
1023
+
1024
+
1025
+
1026
+
1027
+
1028
+
1029
+
1030
+
1031
+
1032
+
1033
+
1034
+
1035
+
1036
 
1037
 
1038
 
 
1797
 
1798
  #* unpatchify, from [1, ( (F-1)/4 * H/16 * W/16), 64] to [1, 16, (F-1)/4, H/8, W/8]
1799
  x = dit.unpatchify(x, (f, h, w))
1800
+ return x
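
prediction2pc_v2 above is where the GPU saving for point-cloud export comes from: MoGe runs only on the first selected key frame, transfer_pred_disp2depth(..., return_scale_shift=True) recovers a (scale, shift) pair, and transfer_pred_disp2depth_v2 reapplies that pair to the remaining frames without calling MoGe again. A rough sketch of that affine alignment, assuming the helpers perform a least-squares fit of the predicted map against MoGe's metric depth over the valid mask (the actual repo helpers may differ, e.g. by fitting in disparity space):

    import numpy as np

    def fit_scale_shift(pred, ref_depth, mask):
        # Assumed behaviour of transfer_pred_disp2depth(..., return_scale_shift=True):
        # solve ref_depth ~ scale * pred + shift over the valid pixels.
        x = pred[mask].ravel()
        y = ref_depth[mask].ravel()
        A = np.stack([x, np.ones_like(x)], axis=1)
        (scale, shift), *_ = np.linalg.lstsq(A, y, rcond=None)
        return scale, shift

    def apply_scale_shift(pred, scale, shift):
        # Assumed counterpart of transfer_pred_disp2depth_v2: reuse the fit on later frames.
        return scale * pred + shift

    # Key frame 0 (the only MoGe call): fit once, keep scale/shift and the intrinsics.
    # scale, shift = fit_scale_shift(pred_maps[indices[0]], moge_depth, moge_mask)
    # Remaining key frames: no MoGe call, just reuse the fit.
    # metric_depth = apply_scale_shift(pred_maps[idx], scale, shift)

Besides skipping repeated MoGe inference, reusing one fit keeps the exported key-frame point clouds on a single consistent scale.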