imlixinyang committed
Commit 62432f1 · 1 Parent(s): 854d14d
Files changed (10)
  1. app.py +31 -37
  2. app_gradio copy.py +682 -0
  3. app_gradio.py +139 -90
  4. index.html +293 -181
  5. models/render.py +4 -2
  6. packages.txt +3 -1
  7. pre-requirements.txt +2 -1
  8. quant.py +1 -2
  9. requirements.txt +2 -1
  10. utils.py +45 -19
app.py CHANGED
@@ -9,42 +9,36 @@ except ImportError:
 import os
 import subprocess
 
-# def install_cuda_toolkit():
-#     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
-#     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
-#     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
-#     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
-#     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
-#     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
-
-#     os.environ["CUDA_HOME"] = "/usr/local/cuda"
-#     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
-#     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
-#         os.environ["CUDA_HOME"],
-#         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
-#     )
-#     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
-#     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
-
-#     print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])
-
-# subprocess.call('rm /usr/bin/gcc', shell=True)
-# subprocess.call('rm /usr/bin/g++', shell=True)
-# subprocess.call('rm /usr/local/cuda/bin/gcc', shell=True)
-# subprocess.call('rm /usr/local/cuda/bin/g++', shell=True)
-
-# subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
-# subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)
-
-# subprocess.call('ln -s /usr/bin/gcc-11 /usr/local/cuda/bin/gcc', shell=True)
-# subprocess.call('ln -s /usr/bin/g++-11 /usr/local/cuda/bin/g++', shell=True)
-
-# subprocess.call('gcc --version', shell=True)
-# subprocess.call('g++ --version', shell=True)
-
-# install_cuda_toolkit()
-
-# subprocess.run('pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712 --no-build-isolation --use-pep517', env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "8.0;8.6"}, shell=True)
+try:
+    import gsplat
+except ImportError:
+    def install_cuda_toolkit():
+        # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
+        CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
+        CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
+        subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
+        subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
+        subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
+        os.environ["CUDA_HOME"] = "/usr/local/cuda"
+        os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
+        os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
+            os.environ["CUDA_HOME"],
+            "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
+        )
+        # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
+        print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])
+        subprocess.call('rm /usr/bin/gcc', shell=True)
+        subprocess.call('rm /usr/bin/g++', shell=True)
+        subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
+        subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)
+        subprocess.call('gcc --version', shell=True)
+        subprocess.call('g++ --version', shell=True)
+
+    install_cuda_toolkit()
+
+    subprocess.run('pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712 --no-build-isolation --use-pep517', env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "8.0;8.6", "PATH": "/usr/local/cuda/bin/:" + os.environ["PATH"]}, shell=True)
 
 from flask import Flask, jsonify, request, send_file, render_template
 import base64
@@ -349,7 +343,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--port', type=int, default=7860)
     parser.add_argument("--ckpt", default=None)
-    parser.add_argument("--gpu", type=int, default=2)
+    parser.add_argument("--gpu", type=int, default=0)
     parser.add_argument("--cache_dir", type=str, default="./tmpfiles")
     parser.add_argument("--offload_t5", type=bool, default=False)
     parser.add_argument("--max_concurrent", type=int, default=1, help="Maximum concurrent generation tasks")
@@ -380,7 +374,7 @@ if __name__ == "__main__":
         response.headers.add('Access-Control-Allow-Methods', 'GET,PUT,POST,DELETE,OPTIONS')
         return response
 
-    @spaces.GPU
+    @GPU
    def generate_wrapper(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=None):
        """Wrapper around the generation function, used for concurrency control."""
        return generation_system.generate(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path)
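
The `@spaces.GPU` → `@GPU` swap above relies on a module-level alias that falls back to a no-op decorator when the `spaces` package is unavailable; the same guard appears verbatim at the top of `app_gradio copy.py` below. A minimal sketch of the pattern:

```python
# Fallback decorator pattern used in this commit: use Hugging Face's
# spaces.GPU on ZeroGPU Spaces, otherwise a pass-through decorator so
# the same code also runs outside of Spaces.
try:
    import spaces
    GPU = spaces.GPU
except ImportError:
    def GPU(func):
        return func

@GPU  # no-op outside Spaces; requests a GPU slice on ZeroGPU
def generate_wrapper(*args, **kwargs):
    ...
```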
app_gradio copy.py ADDED
@@ -0,0 +1,682 @@
+try:
+    import spaces
+    GPU = spaces.GPU
+    print("spaces GPU is available")
+except ImportError:
+    def GPU(func):
+        return func
+
+import os
+import subprocess
+
+try:
+    import gsplat
+except ImportError:
+    def install_cuda_toolkit():
+        # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
+        CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
+        CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
+        subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
+        subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
+        subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
+
+        os.environ["CUDA_HOME"] = "/usr/local/cuda"
+        os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
+        os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
+            os.environ["CUDA_HOME"],
+            "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
+        )
+        # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0+PTX"
+
+        print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])
+
+        subprocess.call('rm /usr/bin/gcc', shell=True)
+        subprocess.call('rm /usr/bin/g++', shell=True)
+
+        subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
+        subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)
+
+        subprocess.call('gcc --version', shell=True)
+        subprocess.call('g++ --version', shell=True)
+
+    install_cuda_toolkit()
+
+    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0+PTX"
+    os.environ["CUDA_HOME"] = "/usr/local/cuda"
+    os.environ["PATH"] = "/usr/local/cuda/bin/:" + os.environ["PATH"]
+
+    subprocess.run('pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712',
+                   env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "9.0+PTX", "PATH": "/usr/local/cuda/bin/:" + os.environ["PATH"]}, shell=True)
+
+from fastapi import FastAPI
+from fastapi.staticfiles import StaticFiles
+import gradio as gr
+import base64
+import io
+from PIL import Image
+import torch
+import numpy as np
+import os
+import argparse
+import imageio
+import json
+import time
+import tempfile
+import shutil
+
+from huggingface_hub import hf_hub_download
+
+import einops
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+import imageio
+
+from models import *
+from utils import *
+
+from transformers import T5TokenizerFast, UMT5EncoderModel
+
+from diffusers import FlowMatchEulerDiscreteScheduler
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+class MyFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        return torch.argmin(
+            (timestep - schedule_timesteps.to(timestep.device)).abs(), dim=0).item()
+
+class GenerationSystem(nn.Module):
+    def __init__(self, ckpt_path=None, device="cuda:0", offload_t5=False, offload_vae=False):
+        super().__init__()
+        self.device = device
+        self.offload_t5 = offload_t5
+        self.offload_vae = offload_vae
+
+        self.latent_dim = 48
+        self.temporal_downsample_factor = 4
+        self.spatial_downsample_factor = 16
+
+        self.feat_dim = 1024
+
+        self.latent_patch_size = 2
+
+        self.denoising_steps = [0, 250, 500, 750]
+
+        model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
+
+        self.vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float).eval()
+
+        from models.autoencoder_kl_wan import WanCausalConv3d
+        with torch.no_grad():
+            for name, module in self.vae.named_modules():
+                if isinstance(module, WanCausalConv3d):
+                    time_pad = module._padding[4]
+                    module.padding = (0, module._padding[2], module._padding[0])
+                    module._padding = (0, 0, 0, 0, 0, 0)
+                    module.weight = torch.nn.Parameter(module.weight[:, :, time_pad:].clone())
+
+        self.vae.requires_grad_(False)
+
+        self.register_buffer('latents_mean', torch.tensor(self.vae.config.latents_mean).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
+        self.register_buffer('latents_std', torch.tensor(self.vae.config.latents_std).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
+
+        self.tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer")
+
+        self.text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float32).eval().requires_grad_(False).to(self.device if not self.offload_t5 else "cpu")
+
+        self.transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float32).train().requires_grad_(False)
+
+        self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, 6 + self.latent_dim)))
+        # self.transformer.rope.freqs_f[:] = self.transformer.rope.freqs_f[:1]
+
+        weight = self.transformer.proj_out.weight.reshape(self.latent_patch_size ** 2, self.latent_dim, self.transformer.proj_out.weight.shape[1])
+        bias = self.transformer.proj_out.bias.reshape(self.latent_patch_size ** 2, self.latent_dim)
+
+        extra_weight = torch.randn(self.latent_patch_size ** 2, self.feat_dim, self.transformer.proj_out.weight.shape[1]) * 0.02
+        extra_bias = torch.zeros(self.latent_patch_size ** 2, self.feat_dim)
+
+        self.transformer.proj_out.weight = nn.Parameter(torch.cat([weight, extra_weight], dim=1).flatten(0, 1).detach().clone())
+        self.transformer.proj_out.bias = nn.Parameter(torch.cat([bias, extra_bias], dim=1).flatten(0, 1).detach().clone())
+
+        self.recon_decoder = WANDecoderPixelAligned3DGSReconstructionModel(self.vae, self.feat_dim, use_render_checkpointing=True, use_network_checkpointing=False).train().requires_grad_(False).to(self.device)
+
+        self.scheduler = MyFlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3)
+
+        self.register_buffer('timesteps', self.scheduler.timesteps.clone().to(self.device))
+
+        self.transformer.disable_gradient_checkpointing()
+        self.transformer.gradient_checkpointing = False
+
+        self.add_feedback_for_transformer()
+
+        if ckpt_path is not None:
+            state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+            self.transformer.load_state_dict(state_dict["transformer"])
+            self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
+            print(f"Loaded {ckpt_path}.")
+
+        from quant import FluxFp8GeMMProcessor
+
+        FluxFp8GeMMProcessor(self.transformer)
+
+        del self.vae.post_quant_conv, self.vae.decoder
+        self.vae.to(self.device if not self.offload_vae else "cpu")
+        self.vae.to(torch.bfloat16)
+
+        self.transformer.to(self.device)
+
+    def latent_scale_fn(self, x):
+        return (x - self.latents_mean) / self.latents_std
+
+    def latent_unscale_fn(self, x):
+        return x * self.latents_std + self.latents_mean
+
+    def add_feedback_for_transformer(self):
+        self.use_feedback = True
+        self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, self.feat_dim + self.latent_dim)))
+
+    def encode_text(self, texts):
+        max_sequence_length = 512
+
+        text_inputs = self.tokenizer(
+            texts,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_attention_mask=True,
+            return_tensors="pt",
+        )
+        if getattr(self, "offload_t5", False):
+            text_input_ids = text_inputs.input_ids.to("cpu")
+            mask = text_inputs.attention_mask.to("cpu")
+        else:
+            text_input_ids = text_inputs.input_ids.to(self.device)
+            mask = text_inputs.attention_mask.to(self.device)
+        seq_lens = mask.gt(0).sum(dim=1).long()
+
+        if getattr(self, "offload_t5", False):
+            with torch.no_grad():
+                text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state.to(self.device)
+        else:
+            text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
+        text_embeds = [u[:v] for u, v in zip(text_embeds, seq_lens)]
+        text_embeds = torch.stack(
+            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in text_embeds], dim=0
+        )
+        return text_embeds.float()
+
+    def forward_generator(self, noisy_latents, raymaps, condition_latents, t, text_embeds, cameras, render_cameras, image_height, image_width, need_3d_mode=True):
+
+        out = self.transformer(
+            hidden_states=torch.cat([noisy_latents, raymaps, condition_latents], dim=1),
+            timestep=t,
+            encoder_hidden_states=text_embeds,
+            return_dict=False,
+        )[0]
+
+        v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
+
+        sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
+        latents_pred_2d = noisy_latents - sigma * v_pred
+
+        if need_3d_mode:
+            scene_params = self.recon_decoder(
+                einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
+                einops.rearrange(self.latent_unscale_fn(latents_pred_2d.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
+                cameras
+            ).flatten(1, -2)
+
+            images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
+
+            latents_pred_3d = einops.rearrange(self.latent_scale_fn(self.vae.encode(
+                einops.rearrange(images_pred, 'B T C H W -> (B T) C H W', T=images_pred.shape[1]).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
+            ).latent_dist.sample().to(self.device)).squeeze(2), '(B T) C H W -> B C T H W', T=images_pred.shape[1]).to(noisy_latents.dtype)
+
+        return {
+            '2d': latents_pred_2d,
+            '3d': latents_pred_3d if need_3d_mode else None,
+            'rgb_3d': images_pred if need_3d_mode else None,
+            'scene': scene_params if need_3d_mode else None,
+            'feat': feats
+        }
+
+    @torch.no_grad()
+    def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.vae.to(self.device)
+        self.text_encoder.to(self.device if not self.offload_t5 else "cpu")
+        self.transformer.to(self.device)
+        self.recon_decoder.to(self.device)
+        self.timesteps = self.timesteps.to(self.device)
+        self.latents_mean = self.latents_mean.to(self.device)
+        self.latents_std = self.latents_std.to(self.device)
+
+        with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
+            batch_size = 1
+
+            cameras = cameras.to(self.device).unsqueeze(0)
+
+            if cameras.shape[1] != n_frame:
+                render_cameras = cameras.clone()
+                cameras = sample_from_dense_cameras(cameras.squeeze(0), torch.linspace(0, 1, n_frame, device=self.device)).unsqueeze(0)
+            else:
+                render_cameras = cameras
+
+            cameras, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True, n_frame=None)
+
+            render_cameras = normalize_cameras(render_cameras, ref_w2c=ref_w2c, T_norm=T_norm, n_frame=None)
+
+            text = "[Static] " + text
+
+            text_embeds = self.encode_text([text])
+            # neg_text_embeds = self.encode_text([""]).repeat(batch_size, 1, 1)
+
+            masks = torch.zeros(batch_size, n_frame, device=self.device)
+
+            condition_latents = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+            if image is not None:
+                image = image.to(self.device)
+
+                latent = self.latent_scale_fn(self.vae.encode(
+                    image.unsqueeze(0).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
+                ).latent_dist.sample().to(self.device)).squeeze(2)
+
+                masks[:, image_index] = 1
+                condition_latents[:, :, image_index] = latent
+
+            raymaps = create_raymaps(cameras, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor)
+            raymaps = einops.rearrange(raymaps, 'B T H W C -> B C T H W', T=n_frame)
+
+            noise = torch.randn(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+            noisy_latents = noise
+
+            torch.cuda.empty_cache()
+
+            if self.use_feedback:
+                prev_latents_pred = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+                prev_feats = torch.zeros(batch_size, self.feat_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+            for i in range(len(self.denoising_steps)):
+                t_ids = torch.full((noisy_latents.shape[0],), self.denoising_steps[i], device=self.device)
+
+                t = self.timesteps[t_ids]
+
+                if self.use_feedback:
+                    _condition_latents = torch.cat([condition_latents, prev_feats, prev_latents_pred], dim=1)
+                else:
+                    _condition_latents = condition_latents
+
+                if i < len(self.denoising_steps) - 1:
+                    out = self.forward_generator(noisy_latents, raymaps, _condition_latents, t, text_embeds, cameras, cameras, image_height, image_width, need_3d_mode=True)
+
+                    latents_pred = out["3d"]
+
+                    if self.use_feedback:
+                        prev_latents_pred = latents_pred
+                        prev_feats = out['feat']
+
+                    noisy_latents = self.scheduler.scale_noise(latents_pred, self.timesteps[torch.full((noisy_latents.shape[0],), self.denoising_steps[i + 1], device=self.device)], torch.randn_like(noise))
+
+                else:
+                    out = self.transformer(
+                        hidden_states=torch.cat([noisy_latents, raymaps, _condition_latents], dim=1),
+                        timestep=t,
+                        encoder_hidden_states=text_embeds,
+                        return_dict=False,
+                    )[0]
+
+                    v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
+
+                    sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
+                    latents_pred = noisy_latents - sigma * v_pred
+
+                    scene_params = self.recon_decoder(
+                        einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
+                        einops.rearrange(self.latent_unscale_fn(latents_pred.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
+                        cameras
+                    ).flatten(1, -2)
+
+            if video_output_path is not None:
+                interpolated_images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
+
+                interpolated_images_pred = einops.rearrange(interpolated_images_pred[0].clamp(-1, 1).add(1).div(2), 'T C H W -> T H W C')
+
+                interpolated_images_pred = [torch.cat([img], dim=1).detach().cpu().mul(255).numpy().astype(np.uint8) for i, img in enumerate(interpolated_images_pred.unbind(0))]
+
+                imageio.mimwrite(video_output_path, interpolated_images_pred, fps=15, quality=8, macro_block_size=1)
+
+            scene_params = scene_params[0]
+
+            scene_params = scene_params.detach().cpu()
+
+        return scene_params, ref_w2c, T_norm
+
+@GPU
+def process_generation_request(data, generation_system, cache_dir):
+    """
+    Process the generation request with the same logic as Flask version
+    """
+    try:
+        image_prompt = data.get('image_prompt', None)
+        text_prompt = data.get('text_prompt', "")
+        cameras = data.get('cameras')
+        resolution = data.get('resolution')
+        image_index = data.get('image_index', 0)
+
+        n_frame, image_height, image_width = resolution
+
+        if not image_prompt and text_prompt == "":
+            return {'error': 'No Prompts provided'}
+
+        if image_prompt:
+            # image_prompt can be a file path or base64 data
+            if os.path.exists(image_prompt):
+                image_prompt = Image.open(image_prompt)
+            else:
+                # image_prompt may be "data:image/png;base64,...."
+                if ',' in image_prompt:
+                    image_prompt = image_prompt.split(',', 1)[1]
+
+                try:
+                    image_bytes = base64.b64decode(image_prompt)
+                    image_prompt = Image.open(io.BytesIO(image_bytes))
+                except Exception as img_e:
+                    return {'error': f'Image decode error: {str(img_e)}'}
+
+            image = image_prompt.convert('RGB')
+
+            w, h = image.size
+
+            # center crop
+            if image_height / h > image_width / w:
+                scale = image_height / h
+            else:
+                scale = image_width / w
+
+            new_h = int(image_height / scale)
+            new_w = int(image_width / scale)
+
+            image = image.crop(((w - new_w) // 2, (h - new_h) // 2,
+                                new_w + (w - new_w) // 2, new_h + (h - new_h) // 2)).resize((image_width, image_height))
+
+            for camera in cameras:
+                camera['fx'] = camera['fx'] * scale
+                camera['fy'] = camera['fy'] * scale
+                camera['cx'] = (camera['cx'] - (w - new_w) // 2) * scale
+                camera['cy'] = (camera['cy'] - (h - new_h) // 2) * scale
+
+            image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0 * 2 - 1
+        else:
+            image = None
+
+        cameras = torch.stack([
+            torch.from_numpy(np.array([camera['quaternion'][0], camera['quaternion'][1], camera['quaternion'][2], camera['quaternion'][3], camera['position'][0], camera['position'][1], camera['position'][2], camera['fx'] / image_width, camera['fy'] / image_height, camera['cx'] / image_width, camera['cy'] / image_height], dtype=np.float32))
+            for camera in cameras
+        ], dim=0)
+
+        file_id = str(int(time.time() * 1000))
+
+        start_time = time.time()
+        scene_params, ref_w2c, T_norm = generation_system.generate(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=os.path.join(cache_dir, f'{file_id}.mp4'))
+        end_time = time.time()
+        print(f'Generation time: {end_time - start_time} s')
+
+        with open(os.path.join(cache_dir, f'{file_id}.json'), 'w') as f:
+            json.dump(data, f)
+
+        splat_path = os.path.join(cache_dir, f'{file_id}.ply')
+
+        export_ply_for_gaussians(splat_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)
+
+        if not os.path.exists(splat_path):
+            return {'error': f'{splat_path} not found'}
+
+        file_size = os.path.getsize(splat_path)
+
+        response_data = {
+            'success': True,
+            'file_id': file_id,
+            'file_path': splat_path,
+            'file_size': file_size,
+            'download_url': f'/download/{file_id}',
+            'generation_time': end_time - start_time,
+        }
+
+        return response_data
+
+    except Exception as e:
+        return {'error': f'Processing error: {str(e)}'}
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--port', type=int, default=7860)
+    parser.add_argument("--ckpt", default=None)
+    parser.add_argument("--cache_dir", type=str, default=None)
+    parser.add_argument("--offload_t5", type=bool, default=False)
+    parser.add_argument("--max_concurrent", type=int, default=1, help="Maximum concurrent generation tasks")
+    args, _ = parser.parse_known_args()
+
+    # Ensure model.ckpt exists, download if not present
+    if args.ckpt is None:
+        from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+        ckpt_path = os.path.join(HUGGINGFACE_HUB_CACHE, "models--imlixinyang--FlashWorld", "snapshots", "6a8e88c6f88678ac098e4c82675f0aee555d6e5d", "model.ckpt")
+        if not os.path.exists(ckpt_path):
+            hf_hub_download(repo_id="imlixinyang/FlashWorld", filename="model.ckpt", local_dir_use_symlinks=False)
+    else:
+        ckpt_path = args.ckpt
+
+    if args.cache_dir is None or args.cache_dir == "":
+        GRADIO_TEMP_DIR = tempfile.gettempdir()
+        cache_dir = os.path.join(GRADIO_TEMP_DIR, "flashworld_gradio")
+    else:
+        cache_dir = args.cache_dir
+
+    # Create cache directory
+    os.makedirs(cache_dir, exist_ok=True)
+
+    # Initialize GenerationSystem
+    device = torch.device("cpu")
+    generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device)
+
+    # Create Gradio interface
+    with gr.Blocks(title="FlashWorld Backend") as demo:
+        gr.Markdown("# FlashWorld Generation Backend")
+        gr.Markdown("This backend processes JSON requests for 3D scene generation.")
+
+        with gr.Row():
+            with gr.Column():
+                json_input = gr.Textbox(
+                    label="JSON Input",
+                    placeholder="Enter JSON request here...",
+                    lines=10,
+                    value='{"image_prompt": null, "text_prompt": "A beautiful landscape", "cameras": [...], "resolution": [16, 480, 704], "image_index": 0}'
+                )
+
+                generate_btn = gr.Button("Generate", variant="primary")
+
+            with gr.Column():
+                json_output = gr.Textbox(
+                    label="JSON Output",
+                    lines=10,
+                    interactive=False
+                )
+
+        # File download section
+        gr.Markdown("## File Download")
+        with gr.Row():
+            file_id_input = gr.Textbox(
+                label="File ID",
+                placeholder="Enter file ID to download..."
+            )
+            download_btn = gr.Button("Download PLY File")
+            download_output = gr.File(label="Downloaded File")
+
+
+        def gradio_generate(json_input):
+            """
+            Gradio interface function that processes JSON input and returns JSON output
+            """
+            try:
+                # Parse JSON input
+                if isinstance(json_input, str):
+                    data = json.loads(json_input)
+                else:
+                    data = json_input
+
+                # Process the request
+                result = process_generation_request(data, generation_system, cache_dir)
+
+                # Return JSON response
+                return json.dumps(result, indent=2)
+
+            except Exception as e:
+                error_response = {'error': f'JSON processing error: {str(e)}'}
+                return json.dumps(error_response, indent=2)
+
+        def download_file(file_id):
+            """
+            Download generated PLY file
+            """
+            file_path = os.path.join(cache_dir, f'{file_id}.ply')
+
+            if not os.path.exists(file_path):
+                return None
+
+            return file_path
+
+        # Event handlers
+        generate_btn.click(
+            fn=gradio_generate,
+            inputs=[json_input],
+            outputs=[json_output]
+        )
+
+        download_btn.click(
+            fn=download_file,
+            inputs=[file_id_input],
+            outputs=[download_output]
+        )
+
+        # Example JSON format
+        gr.Markdown("""
+        ## Example JSON Input Format:
+        ```json
+        {
+            "image_prompt": null,
+            "text_prompt": "A beautiful landscape with mountains and trees",
+            "cameras": [
+                {
+                    "quaternion": [0, 0, 0, 1],
+                    "position": [0, 0, 5],
+                    "fx": 500,
+                    "fy": 500,
+                    "cx": 240,
+                    "cy": 240
+                },
+                {
+                    "quaternion": [0, 0, 0, 1],
+                    "position": [0, 0, 5],
+                    "fx": 500,
+                    "fy": 500,
+                    "cx": 240,
+                    "cy": 240
+                }
+            ],
+            "resolution": [16, 480, 704],
+            "image_index": 0
+        }
+        ```
+        """)
+
+    from contextlib import asynccontextmanager
+
+    @asynccontextmanager
+    async def lifespan_ctx(app):
+        app.state._cleanup_stop_event = asyncio.Event()
+        app.state._cleanup_task = asyncio.create_task(periodic_cache_cleanup(app.state._cleanup_stop_event, cache_dir))
+        try:
+            yield
+        finally:
+            if getattr(app.state, "_cleanup_stop_event", None):
+                app.state._cleanup_stop_event.set()
+            if getattr(app.state, "_cleanup_task", None):
+                try:
+                    await app.state._cleanup_task
+                except Exception:
+                    pass
+
+    app = FastAPI(lifespan=lifespan_ctx)
+
+    from starlette.responses import FileResponse
+
+    @app.get("/app")
+    async def read_index():
+        return FileResponse('index.html')
+
+    app = gr.mount_gradio_app(app, demo, path="/")
+
+    import uvicorn
+
+    from fastapi.staticfiles import StaticFiles
+    from fastapi import HTTPException
+    import asyncio
+
+    # Mount the static file directory so it can be accessed, e.g. /cache/<filename>
+    app.mount("/cache", StaticFiles(directory=cache_dir), name="cache")
+
+    # Delete the generated files for a given file_id (and related intermediate files)
+    @app.post("/delete/{file_id}")
+    async def delete_generated_file(file_id: str):
+        try:
+            deleted = False
+            # Possible associated files: .ply, .json, .mp4
+            for ext in (".ply", ".json", ".mp4"):
+                p = os.path.join(cache_dir, f"{file_id}{ext}")
+                if os.path.exists(p):
+                    try:
+                        os.remove(p)
+                        deleted = True
+                    except Exception:
+                        pass
+            return {"success": True, "deleted": deleted}
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    # Periodically clean up files whose creation/modification time is older than 15 minutes
+    async def periodic_cache_cleanup(stop_event: asyncio.Event, directory: str, max_age_seconds: int = 15 * 60, interval_seconds: int = 300):
+        while not stop_event.is_set():
+            try:
+                now = time.time()
+                for name in os.listdir(directory):
+                    path = os.path.join(directory, name)
+                    try:
+                        if os.path.isfile(path):
+                            mtime = os.path.getmtime(path)
+                            if (now - mtime) > max_age_seconds:
+                                try:
+                                    os.remove(path)
+                                except Exception:
+                                    pass
+                    except Exception:
+                        pass
+            except Exception:
+                pass
+            try:
+                await asyncio.wait_for(stop_event.wait(), timeout=interval_seconds)
+            except asyncio.TimeoutError:
+                continue
+
+    uvicorn.run(app, host="0.0.0.0", port=7860)
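
Callers reach this backend through Gradio's two-step call API: a POST to `/gradio_api/call/<fn_name>` queues the job and returns an `event_id`, and a follow-up GET on the same path streams the result as server-sent events — the same flow that `index.html` implements in JavaScript further below. A hedged Python sketch of that round trip (the base URL is an assumption; endpoint names come from this commit):

```python
import json
import requests  # assumption: requests is available in the client environment

BASE = "http://localhost:7860"  # assumed local deployment of this backend

request_data = {
    "image_prompt": "",
    "text_prompt": "A beautiful landscape",
    "cameras": [],  # fill with camera dicts as in the example JSON above
    "resolution": [16, 480, 704],
    "image_index": 0,
}

# Step 1: queue the call; Gradio responds with an event_id.
resp = requests.post(
    f"{BASE}/gradio_api/call/gradio_generate",
    json={"data": [json.dumps(request_data)]},
)
event_id = resp.json()["event_id"]

# Step 2: read the SSE stream; the final "data:" line of the
# "complete" event carries the function's return value.
result = None
with requests.get(f"{BASE}/gradio_api/call/gradio_generate/{event_id}", stream=True) as sse:
    for line in sse.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            result = json.loads(line[len("data: "):])

print(result)  # e.g. ['{"success": true, "file_id": "...", "download_url": "..."}']
```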
app_gradio.py CHANGED
@@ -9,43 +9,48 @@ except ImportError:
 import os
 import subprocess
 
-# def install_cuda_toolkit():
-#     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
-#     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
-#     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
-#     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
-#     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
-#     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
-
-#     os.environ["CUDA_HOME"] = "/usr/local/cuda"
-#     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
-#     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
-#         os.environ["CUDA_HOME"],
-#         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
-#     )
-#     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
-#     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
-
-#     print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])
-
-# subprocess.call('rm /usr/bin/gcc', shell=True)
-# subprocess.call('rm /usr/bin/g++', shell=True)
-# subprocess.call('rm /usr/local/cuda/bin/gcc', shell=True)
-# subprocess.call('rm /usr/local/cuda/bin/g++', shell=True)
-
-# subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
-# subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)
-
-# subprocess.call('ln -s /usr/bin/gcc-11 /usr/local/cuda/bin/gcc', shell=True)
-# subprocess.call('ln -s /usr/bin/g++-11 /usr/local/cuda/bin/g++', shell=True)
-
-# subprocess.call('gcc --version', shell=True)
-# subprocess.call('g++ --version', shell=True)
-
-# install_cuda_toolkit()
-
-# subprocess.run('pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712 --no-build-isolation --use-pep517', env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "8.0;8.6"}, shell=True)
+try:
+    import gsplat
+except ImportError:
+    def install_cuda_toolkit():
+        # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
+        CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
+        CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
+        subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
+        subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
+        subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
+
+        os.environ["CUDA_HOME"] = "/usr/local/cuda"
+        os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
+        os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
+            os.environ["CUDA_HOME"],
+            "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
+        )
+        # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0+PTX"
+
+        print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])
+
+        subprocess.call('rm /usr/bin/gcc', shell=True)
+        subprocess.call('rm /usr/bin/g++', shell=True)
+
+        subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
+        subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)
+
+        subprocess.call('gcc --version', shell=True)
+        subprocess.call('g++ --version', shell=True)
+
+    install_cuda_toolkit()
+
+    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0+PTX"
+    os.environ["CUDA_HOME"] = "/usr/local/cuda"
+    os.environ["PATH"] = "/usr/local/cuda/bin/:" + os.environ["PATH"]
+
+    subprocess.run('pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712',
+                   env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "9.0+PTX", "PATH": "/usr/local/cuda/bin/:" + os.environ["PATH"]}, shell=True)
+
+from fastapi import FastAPI
+from fastapi.staticfiles import StaticFiles
 import gradio as gr
 import base64
 import io
@@ -59,6 +64,7 @@ import json
 import time
 import tempfile
 import shutil
+import threading
 
 from huggingface_hub import hf_hub_download
 
@@ -78,7 +84,6 @@ from transformers import T5TokenizerFast, UMT5EncoderModel
 from diffusers import FlowMatchEulerDiscreteScheduler
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 class MyFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -152,11 +157,11 @@ class GenerationSystem(nn.Module):
 
         self.add_feedback_for_transformer()
 
-        # if ckpt_path is not None:
-        #     state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
-        #     self.transformer.load_state_dict(state_dict["transformer"])
-        #     self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
-        #     print(f"Loaded {ckpt_path}.")
+        if ckpt_path is not None:
+            state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+            self.transformer.load_state_dict(state_dict["transformer"])
+            self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
+            print(f"Loaded {ckpt_path}.")
 
         from quant import FluxFp8GeMMProcessor
 
@@ -164,6 +169,7 @@ class GenerationSystem(nn.Module):
 
         del self.vae.post_quant_conv, self.vae.decoder
         self.vae.to(self.device if not self.offload_vae else "cpu")
+        self.vae.to(torch.bfloat16)
 
         self.transformer.to(self.device)
 
@@ -243,11 +249,12 @@ class GenerationSystem(nn.Module):
             'feat': feats
         }
 
-    @GPU
     @torch.no_grad()
-    @torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda")
-    def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
+    def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.vae.to(self.device)
         self.text_encoder.to(self.device if not self.offload_t5 else "cpu")
         self.transformer.to(self.device)
         self.recon_decoder.to(self.device)
@@ -255,7 +262,7 @@ class GenerationSystem(nn.Module):
         self.latents_mean = self.latents_mean.to(self.device)
         self.latents_std = self.latents_std.to(self.device)
 
-        with torch.no_grad():
+        with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
             batch_size = 1
 
             cameras = cameras.to(self.device).unsqueeze(0)
@@ -358,6 +365,7 @@ class GenerationSystem(nn.Module):
 
         return scene_params, ref_w2c, T_norm
 
+@GPU
 def process_generation_request(data, generation_system, cache_dir):
     """
     Process the generation request with the same logic as Flask version
@@ -430,9 +438,9 @@ def process_generation_request(data, generation_system, cache_dir):
         with open(os.path.join(cache_dir, f'{file_id}.json'), 'w') as f:
             json.dump(data, f)
 
-        splat_path = os.path.join(cache_dir, f'{file_id}.ply')
+        splat_path = os.path.join(cache_dir, f'{file_id}.spz')
 
-        export_ply_for_gaussians(splat_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)
+        export_gaussians(splat_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)
 
         if not os.path.exists(splat_path):
             return {'error': f'{splat_path} not found'}
@@ -453,43 +461,10 @@ def process_generation_request(data, generation_system, cache_dir):
     except Exception as e:
         return {'error': f'Processing error: {str(e)}'}
 
-def gradio_generate(json_input, generation_system, cache_dir):
-    """
-    Gradio interface function that processes JSON input and returns JSON output
-    """
-    try:
-        # Parse JSON input
-        if isinstance(json_input, str):
-            data = json.loads(json_input)
-        else:
-            data = json_input
-
-        # Process the request
-        result = process_generation_request(data, generation_system, cache_dir)
-
-        # Return JSON response
-        return json.dumps(result, indent=2)
-
-    except Exception as e:
-        error_response = {'error': f'JSON processing error: {str(e)}'}
-        return json.dumps(error_response, indent=2)
-
-def download_file(file_id, cache_dir):
-    """
-    Download generated PLY file
-    """
-    file_path = os.path.join(cache_dir, f'{file_id}.ply')
-
-    if not os.path.exists(file_path):
-        return None
-
-    return file_path
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--port', type=int, default=7860)
     parser.add_argument("--ckpt", default=None)
-    parser.add_argument("--gpu", type=int, default=0)
     parser.add_argument("--cache_dir", type=str, default=None)
     parser.add_argument("--offload_t5", type=bool, default=False)
     parser.add_argument("--max_concurrent", type=int, default=1, help="Maximum concurrent generation tasks")
@@ -514,15 +489,14 @@ if __name__ == "__main__":
     os.makedirs(cache_dir, exist_ok=True)
 
     # Initialize GenerationSystem
-    device = f"cuda:{args.gpu}" # if torch.cuda.is_available() else "cpu"
+    device = torch.device("cpu")
    generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device)
 
     # Create Gradio interface
     with gr.Blocks(title="FlashWorld Backend") as demo:
-        gr.Markdown("# FlashWorld Generation Backend")
-        gr.Markdown("This backend processes JSON requests for 3D scene generation.")
+        gr.Markdown("FlashWorld Generation Backend — API only. This service powers the FlashWorld Web Demo and is intended for programmatic/API access. The UI is intentionally hidden.")
 
-        with gr.Row():
+        with gr.Row(visible=False):
            with gr.Column():
                json_input = gr.Textbox(
                    label="JSON Input",
@@ -541,27 +515,83 @@ if __name__ == "__main__":
                )
 
         # File download section
-        gr.Markdown("## File Download")
-        with gr.Row():
+        gr.Markdown("## File Download", visible=False)
+        with gr.Row(visible=False):
             file_id_input = gr.Textbox(
                 label="File ID",
                 placeholder="Enter file ID to download..."
             )
-            download_btn = gr.Button("Download PLY File")
+            download_btn = gr.Button("Download SPZ File")
             download_output = gr.File(label="Downloaded File")
+
+
+        def gradio_generate(json_input):
+            """
+            Gradio interface function that processes JSON input and returns JSON output
+            """
+            try:
+                # Parse JSON input
+                if isinstance(json_input, str):
+                    data = json.loads(json_input)
+                else:
+                    data = json_input
+
+                # Process the request
+                result = process_generation_request(data, generation_system, cache_dir)
+
+                # Return JSON response
+                return json.dumps(result, indent=2)
+
+            except Exception as e:
+                error_response = {'error': f'JSON processing error: {str(e)}'}
+                return json.dumps(error_response, indent=2)
+
+        def download_file(file_id):
+            """
+            Download generated SPZ file
+            """
+            file_path = os.path.join(cache_dir, f'{file_id}.spz')
+
+            if not os.path.exists(file_path):
+                return None
+
+            return file_path
+
+        def gradio_delete(file_id):
+            """
+            Delete generated artifacts by file_id (.spz/.json/.mp4)
+            """
+            deleted = False
+            try:
+                for ext in (".spz", ".json", ".mp4"):
+                    p = os.path.join(cache_dir, f"{file_id}{ext}")
+                    if os.path.exists(p):
+                        try:
+                            os.remove(p)
+                            deleted = True
+                        except Exception:
+                            pass
+                return {"success": True, "deleted": deleted}
+            except Exception as e:
+                return {"success": False, "error": str(e)}
 
         # Event handlers
         generate_btn.click(
-            fn=lambda json_input: gradio_generate(json_input, generation_system, cache_dir),
+            fn=gradio_generate,
             inputs=[json_input],
             outputs=[json_output]
         )
 
         download_btn.click(
-            fn=lambda file_id: download_file(file_id, cache_dir),
+            fn=download_file,
             inputs=[file_id_input],
             outputs=[download_output]
        )
+
+        # Hidden API hook for deletion to expose /gradio_api/call/gradio_delete
+        _hidden_delete_in = gr.Textbox(visible=False)
+        _hidden_delete_btn = gr.Button(visible=False)
+        _hidden_delete_btn.click(fn=gradio_delete, inputs=[_hidden_delete_in], outputs=[])
 
         # Example JSON format
         gr.Markdown("""
@@ -592,10 +622,29 @@ if __name__ == "__main__":
            "image_index": 0
        }
        ```
-        """)
-
-    # Launch the interface
-    demo.launch(
-        ssr_mode=False,
-        allowed_paths=[cache_dir]
-    )
+        """, visible=False)
+
+    # Background periodic cleanup thread (no FastAPI app lifecycle)
+    def _cleanup_loop(directory: str, max_age_seconds: int = 15 * 60, interval_seconds: int = 300):
+        while True:
+            try:
+                now = time.time()
+                for name in os.listdir(directory):
+                    path = os.path.join(directory, name)
+                    try:
+                        if os.path.isfile(path):
+                            mtime = os.path.getmtime(path)
+                            if (now - mtime) > max_age_seconds:
+                                try:
+                                    os.remove(path)
+                                except Exception:
+                                    pass
+                    except Exception:
+                        pass
+            except Exception:
+                pass
+            time.sleep(interval_seconds)
+
+    threading.Thread(target=_cleanup_loop, args=(cache_dir,), daemon=True).start()
+
+    demo.launch(allowed_paths=[cache_dir])
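
On the wire, each entry of `cameras` is a dict with `quaternion`, `position`, and pixel-space intrinsics; `process_generation_request` packs it into an 11-dimensional float vector per frame, with the intrinsics normalized by the target width and height. A small sketch of that packing (values are illustrative; the quaternion component order follows whatever the client sends):

```python
import numpy as np

image_width, image_height = 704, 480  # from the request's "resolution"

# Illustrative camera entry, matching the example JSON above.
camera = {"quaternion": [0, 0, 0, 1], "position": [0, 0, 5],
          "fx": 500, "fy": 500, "cx": 240, "cy": 240}

# Same 11-value layout used in process_generation_request:
# quaternion (4), position (3), fx/W, fy/H, cx/W, cy/H.
vec = np.array(
    camera["quaternion"]
    + camera["position"]
    + [camera["fx"] / image_width, camera["fy"] / image_height,
       camera["cx"] / image_width, camera["cy"] / image_height],
    dtype=np.float32,
)
assert vec.shape == (11,)
```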
index.html CHANGED
@@ -67,7 +67,7 @@
67
  .content-container {
68
  display: flex;
69
  flex: 1;
70
- overflow: hidden;
71
  }
72
 
73
  .left-panel {
@@ -76,6 +76,7 @@
76
  border-right: 1px solid rgba(255, 255, 255, 0.1);
77
  padding: 20px;
78
  overflow-y: auto;
 
79
  flex-shrink: 0;
80
  }
81
 
@@ -86,6 +87,7 @@
86
  display: flex;
87
  justify-content: center;
88
  align-items: center;
 
89
  }
90
 
91
  .right-panel {
@@ -95,6 +97,7 @@
95
  padding: 20px;
96
  overflow-y: auto;
97
  flex-shrink: 0;
 
98
  }
99
 
100
  .guidance {
@@ -222,6 +225,7 @@
222
  font-size: 12px;
223
  cursor: default;
224
  user-select: none;
 
225
  }
226
  .info-tip .tooltip {
227
  display: none;
@@ -229,16 +233,17 @@
229
  left: 0;
230
  top: calc(100% + 8px); /* show below the icon */
231
  transform: none;
232
- background: rgba(0,0,0,0.9);
233
  color: #e5e7eb;
234
- border: 1px solid rgba(255,255,255,0.15);
235
  border-radius: 8px;
236
  padding: 10px 12px;
237
  font-size: 12px;
238
- width: 360px; /* wider tooltip */
239
  white-space: normal;
240
- z-index: 2000; /* above GUI and other elements */
241
- box-shadow: 0 4px 12px rgba(0,0,0,0.4);
 
242
  }
243
  .info-tip:hover .tooltip {
244
  display: block;
@@ -430,8 +435,8 @@
430
  <script type="importmap">
431
  {
432
  "imports": {
433
- "three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.174.0/three.module.js",
434
- "@sparkjsdev/spark": "https://sparkjs.dev/releases/spark/0.1.6/spark.module.js",
435
  "lil-gui": "https://cdn.jsdelivr.net/npm/[email protected]/+esm"
436
  }
437
  }
@@ -469,6 +474,7 @@
469
  <div class="step">
470
  <h3>1. Configure</h3>
471
  <p>Set FOV and Resolution and Click "Fix Configurations"</p>
 
472
  </div>
473
 
474
 
@@ -640,6 +646,31 @@
640
  function updateStatus(message, cameraCount = null) {
641
  const cameraText = cameraCount !== null ? `Cameras: ${cameraCount}` : `Cameras: ${cameraParams.length}`;
642
  statusBar.textContent = `${message} | ${cameraText} | Status: ${fixGenerationFOV ? 'Ready to record' : 'Configure settings'}`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  }
644
 
645
  // Show/hide loading
@@ -685,7 +716,151 @@
685
  if (progressText) progressText.textContent = text;
686
  }
687
 
688
- // Gradio handles concurrency automatically, no need for queue polling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689
 
690
  // Hide download progress
691
  function hideDownloadProgress() {
@@ -741,8 +916,9 @@
741
 
742
  // GUI Options - declare early
743
  const guiOptions = {
744
- // Gradio后端地址,默认为本页面ip:7860
745
  BackendAddress: `${window.location.protocol}//${window.location.hostname}:7860`,
 
746
  FOV: 60,
747
  LoadFromJson: () => {
748
  const jsonInput = document.querySelector("#json-input");
@@ -805,11 +981,6 @@
805
  generateCameraTrajectory(guiOptions.templateType);
806
  },
807
  saveTrajectoryToJson: () => {
808
- if (cameraParams.length === 0) {
809
- updateStatus('No cameras to save.', cameraParams.length);
810
- console.warn('No cameras to save');
811
- return;
812
- }
813
 
814
  // Build JSON payload compatible with loader
815
  const [nStr, hStr, wStr] = guiOptions.Resolution.split('x');
@@ -913,14 +1084,15 @@
913
  console.log('Interpolated cameras:', interpolatedCameras.length);
914
  updateStatus('Sending request to backend...', cameraParams.length);
915
 
916
- // Gradio后端:使用Gradio API
 
917
  const requestData = {
918
  image_prompt: inputImageBase64 ? inputImageBase64 : "",
919
  text_prompt: guiOptions.inputTextPrompt,
920
  image_index: 0,
921
  resolution: [
922
- parseInt(guiOptions.Resolution.split('x')[0]),
923
- parseInt(guiOptions.Resolution.split('x')[1]),
924
  parseInt(guiOptions.Resolution.split('x')[2])
925
  ],
926
  cameras: interpolatedCameras.map(cam => ({
@@ -937,191 +1109,125 @@
       }))
     };

-    // Request generation from the Gradio backend
-    fetch(guiOptions.BackendAddress + '/gradio_api/call/gradio_generate', {
+    fetchWithAuth(requestUrl, {
       method: 'POST',
       headers: { 'Content-Type': 'application/json' },
       mode: 'cors',
-      body: JSON.stringify({
-        data: [JSON.stringify(requestData)]
-      })
+      body: JSON.stringify({ data: [JSON.stringify(requestData)] })
     })
     .then(response => response.json())
     .then(data => {
-      console.log('Gradio response:', data);
-
-      // Gradio always returns an event_id; the generation result must be fetched separately
-      if (data.event_id) {
-        console.log('Got EVENT_ID from generation call:', data.event_id);
-
-        // Use the EVENT_ID to fetch the generation result (SSE format)
-        return fetch(guiOptions.BackendAddress + `/gradio_api/call/gradio_generate/${data.event_id}`)
-          .then(response => {
-            if (!response.ok) {
-              throw new Error(`HTTP error! status: ${response.status}`);
-            }
-            return response.text();
-          })
-          .then(sseText => {
-            console.log('SSE response:', sseText);
-
-            // Parse the SSE-formatted response
-            const lines = sseText.split('\n');
-            let eventType = null;
-            let dataContent = null;
-
-            for (const line of lines) {
-              if (line.startsWith('event: ')) {
-                eventType = line.substring(7);
-              } else if (line.startsWith('data: ')) {
-                dataContent = line.substring(6);
-              }
-            }
-
-            console.log('Event type:', eventType, 'Data:', dataContent);
-
-            if (eventType === 'complete' && dataContent) {
-              // Parse the JSON payload
-              const resultData = JSON.parse(dataContent);
-              console.log('Generation result:', resultData);
-
-              // Parse the generation result
-              if (resultData && resultData.length > 0) {
-                const responseData = JSON.parse(resultData[0]);
-                console.log('Gradio generation successful:', responseData);
-
-                if (responseData.success && responseData.download_url) {
-                  console.log('Generation time:', responseData.generation_time, 'seconds');
-                  console.log('File size:', responseData.file_size, 'bytes');
-
-                  // Show generation info
-                  showGenerationInfo(responseData.generation_time, responseData.file_size);
-                  showDownloadProgress();
-                  updateStatus('Downloading generated scene...', cameraParams.length);
-
-                  // Now download the file; this also takes two steps: get the download EVENT_ID, then fetch the file
-                  return fetch(guiOptions.BackendAddress + '/gradio_api/call/download_file', {
-                    method: 'POST',
-                    headers: { 'Content-Type': 'application/json' },
-                    body: JSON.stringify({
-                      data: [responseData.file_id]
-                    })
-                  })
-                  .then(response => response.json())
-                  .then(downloadEventData => {
-                    console.log('Download EVENT_ID:', downloadEventData.event_id);
-
-                    // Use the download EVENT_ID to fetch the file info (SSE format)
-                    return fetch(guiOptions.BackendAddress + `/gradio_api/call/download_file/${downloadEventData.event_id}`)
-                      .then(response => {
-                        if (!response.ok) {
-                          throw new Error(`HTTP error! status: ${response.status}`);
-                        }
-                        return response.text();
-                      })
-                      .then(sseText => {
-                        console.log('Download SSE response:', sseText);
-
-                        // Parse the SSE-formatted response
-                        const lines = sseText.split('\n');
-                        let eventType = null;
-                        let dataContent = null;
-
-                        for (const line of lines) {
-                          if (line.startsWith('event: ')) {
-                            eventType = line.substring(7);
-                          } else if (line.startsWith('data: ')) {
-                            dataContent = line.substring(6);
-                          }
-                        }
-
-                        console.log('Download event type:', eventType, 'Data:', dataContent);
-
-                        if (eventType === 'complete' && dataContent) {
-                          // Parse the file info
-                          const fileData = JSON.parse(dataContent);
-                          console.log('File data:', fileData);
-
-                          if (fileData && fileData.length > 0 && fileData[0].url) {
-                            const fileUrl = fileData[0].url;
-                            console.log('File URL:', fileUrl);
-
-                            // Download the actual file from the returned URL
-                            return fetch(fileUrl)
-                              .then(response => {
-                                if (!response.ok) {
-                                  throw new Error(`HTTP error! status: ${response.status}`);
-                                }
-
-                                const contentLength = response.headers.get('content-length');
-                                const total = parseInt(contentLength, 10);
-                                let loaded = 0;
-
-                                const reader = response.body.getReader();
-                                const chunks = [];
-
-                                function pump() {
-                                  return reader.read().then(({ done, value }) => {
-                                    if (done) {
-                                      return new Blob(chunks);
-                                    }
-
-                                    chunks.push(value);
-                                    loaded += value.length;
-
-                                    if (total) {
-                                      const percentage = (loaded / total) * 100;
-                                      updateProgressBar(percentage);
-                                    }
-
-                                    return pump();
-                                  });
-                                }
-
-                                return pump().then(blob => {
-                                  const url = URL.createObjectURL(blob);
-                                  return { url };
-                                });
-                              });
-                          } else {
-                            throw new Error('Invalid file data format from Gradio');
-                          }
-                        } else {
-                          throw new Error('Gradio download SSE response not complete or missing data');
-                        }
-                      });
-                  });
-                } else {
-                  throw new Error('Gradio generation failed: ' + (responseData.error || 'Unknown error'));
-                }
-              } else {
-                throw new Error('Invalid Gradio generation result format');
-              }
-            } else {
-              throw new Error('Gradio SSE response not complete or missing data');
-            }
-          });
-      } else {
-        throw new Error('Invalid Gradio response format - no event_id');
-      }
+      // Gradio always returns an event_id; fetch the generation result over SSE
+      if (!data || !data.event_id) {
+        throw new Error('Invalid Gradio response format - no event_id');
+      }
+      return fetchWithAuth(guiOptions.BackendAddress + `/gradio_api/call/gradio_generate/${data.event_id}`)
+        .then(resp => {
+          if (!resp.ok) throw new Error(`HTTP error! status: ${resp.status}`);
+          return resp.text();
+        })
+        .then(sseText => {
+          const lines = sseText.split('\n');
+          let eventType = null;
+          let dataContent = null;
+          for (const line of lines) {
+            if (line.startsWith('event: ')) eventType = line.substring(7);
+            else if (line.startsWith('data: ')) dataContent = line.substring(6);
+          }
+          if (eventType !== 'complete' || !dataContent) {
+            throw new Error('Gradio SSE response not complete or missing data');
+          }
+          const resultData = JSON.parse(dataContent);
+          if (!resultData || resultData.length === 0) {
+            throw new Error('Invalid Gradio generation result format');
+          }
+          const responseData = JSON.parse(resultData[0]);
+          if (!responseData.success) {
+            throw new Error('Gradio generation failed: ' + (responseData.error || 'Unknown error'));
+          }
+
+          // Show generation info
+          showGenerationInfo(responseData.generation_time, responseData.file_size);
+          showDownloadProgress();
+          updateStatus('Downloading generated scene...', cameraParams.length);
+
+          // Download the file: call download_file to get a download event_id, read the URL over SSE, then fetch the file itself
+          return fetchWithAuth(guiOptions.BackendAddress + '/gradio_api/call/download_file', {
+            method: 'POST',
+            headers: { 'Content-Type': 'application/json' },
+            body: JSON.stringify({ data: [responseData.file_id] })
+          })
+          .then(r => r.json())
+          .then(downloadEvent => {
+            return fetchWithAuth(guiOptions.BackendAddress + `/gradio_api/call/download_file/${downloadEvent.event_id}`)
+              .then(r => {
+                if (!r.ok) throw new Error(`HTTP error! status: ${r.status}`);
+                return r.text();
+              })
+              .then(downloadSseText => {
+                const lines = downloadSseText.split('\n');
+                let eventType = null;
+                let dataContent = null;
+                for (const line of lines) {
+                  if (line.startsWith('event: ')) eventType = line.substring(7);
+                  else if (line.startsWith('data: ')) dataContent = line.substring(6);
+                }
+                if (eventType !== 'complete' || !dataContent) {
+                  throw new Error('Gradio download SSE response not complete or missing data');
+                }
+                const fileData = JSON.parse(dataContent);
+                if (!fileData || fileData.length === 0 || !fileData[0].url) {
+                  throw new Error('Invalid file data format from Gradio');
+                }
+                return fileData[0].url;
+              });
+          })
+          .then(fileUrl => {
+            return fetchWithAuth(fileUrl).then(response => {
+              if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
+              const contentLength = response.headers.get('content-length');
+              const total = parseInt(contentLength || '0', 10);
+              let loaded = 0;
+              const reader = response.body.getReader();
+              const chunks = [];
+              function pump() {
+                return reader.read().then(({ done, value }) => {
+                  if (done) return new Blob(chunks);
+                  chunks.push(value);
+                  loaded += value.length;
+                  if (total) updateProgressBar((loaded / total) * 100);
+                  return pump();
+                });
+              }
+              return pump().then(blob => {
+                const url = URL.createObjectURL(blob);
+                return { url, __deleteAfterDownloadFileId: (typeof responseData !== 'undefined' ? responseData.file_id : null) };
+              });
+            });
+          });
+        });
     })
     .then(data => {
-      if (data.url) {
+      if (data && data.url) {
         updateStatus('Loading 3D scene...', cameraParams.length);
-
-        // Remove the instruction splat when generation is complete
         if (instructionSplat) {
           scene.remove(instructionSplat);
           console.log('Instruction splat removed');
         }
-
         const GeneratedSplat = new SplatMesh({ url: data.url });
         scene.add(GeneratedSplat);
-        currentGeneratedSplat = GeneratedSplat; // keep a reference to the newly generated scene
+        currentGeneratedSplat = GeneratedSplat;
         console.log('3D scene loaded successfully!');
         updateStatus('Scene generated successfully!', cameraParams.length);
         hideDownloadProgress();
         showLoading(false);
+
+        // Notify the backend to delete the file (if a file_id is present)
+        if (data.__deleteAfterDownloadFileId) {
+          fetchWithAuth(guiOptions.BackendAddress + '/delete/' + data.__deleteAfterDownloadFileId, { method: 'POST' })
+            .then(() => console.log('Delete notify sent'))
+            .catch(err => console.warn('Delete notify failed', err));
+        }
       }
     })
     .catch(error => {
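The hunk above keeps Gradio's two-step call protocol: a POST to /gradio_api/call/gradio_generate returns an event_id, and a follow-up GET on /gradio_api/call/gradio_generate/<event_id> streams an SSE body whose final data: line carries the JSON result. A Python sketch of that protocol (endpoint names and payload shape come from the JS above; everything else is illustrative):

# Sketch: the two-step /gradio_api/call protocol used by the page.
import json
import requests

def gradio_call(backend, api_name, payload, token=""):
    headers = {"Authorization": f"Bearer {token}"} if token.strip() else {}
    # Step 1: enqueue the call; Gradio responds with an event_id.
    r = requests.post(
        f"{backend}/gradio_api/call/{api_name}",
        headers={**headers, "Content-Type": "application/json"},
        json={"data": payload},
    )
    event_id = r.json()["event_id"]
    # Step 2: read the SSE stream and keep the last 'event:'/'data:' pair.
    r = requests.get(f"{backend}/gradio_api/call/{api_name}/{event_id}", headers=headers)
    r.raise_for_status()
    event_type, data_line = None, None
    for line in r.text.splitlines():
        if line.startswith("event: "):
            event_type = line[len("event: "):]
        elif line.startswith("data: "):
            data_line = line[len("data: "):]
    if event_type != "complete" or data_line is None:
        raise RuntimeError("SSE response not complete or missing data")
    return json.loads(data_line)

# Usage mirroring the page, where request_data matches the requestData object above:
# result = gradio_call(backend, "gradio_generate", [json.dumps(request_data)])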
@@ -1499,7 +1605,8 @@

   // Step 1: Configure Generation Settings
   const step1Folder = gui.addFolder('1. Configure Settings');
-  step1Folder.add(guiOptions, "BackendAddress").name("Gradio Backend Address");
+  step1Folder.add(guiOptions, "BackendAddress").name("Backend Address");
+  step1Folder.add(guiOptions, "HF_TOKEN").name("HF Token");

   // FOV and Resolution controllers, enabled initially
   const fovController = step1Folder.add(guiOptions, "FOV", 0, 120, 1).name("FOV").onChange((value) => {
@@ -1546,6 +1653,9 @@
   const loadTrajectoryController = trajectoryFolder.add(guiOptions, "LoadTrajectoryFromJson").name("Load from JSON");
   const saveTrajectoryController = trajectoryFolder.add(guiOptions, "saveTrajectoryToJson").name("Save Trajectory");

+  // Initial state: disable the save button (not enough cameras yet)
+  saveTrajectoryController.disable();
+
   // Clear cameras button
   const clearAllCamerasController = trajectoryFolder.add(guiOptions, "clearAllCameras").name("Clear All Cameras");

1612
 
1613
  // Store controllers globally so they can be accessed from guiOptions
1614
  window.fixGenerationFOVController = fixGenerationFOVController;
 
1615
 
1616
  // Step 3: Add Scene Prompts
1617
  const step3Folder = gui.addFolder('3. Add Scene Prompts');
@@ -2025,6 +2136,7 @@
2025
  if (loadTrajectoryOnly) {
2026
  updateStatus(`Trajectory loaded: ${cameras.length} cameras`, cameraParams.length);
2027
  } else {
 
2028
  }
2029
  } catch (error) {
2030
  console.error("JSON data processing error:", error);
 

   .content-container {
     display: flex;
     flex: 1;
+    overflow: visible; /* Allow tooltips to extend beyond container */
   }

   .left-panel {

     border-right: 1px solid rgba(255, 255, 255, 0.1);
     padding: 20px;
     overflow-y: auto;
+    overflow-x: visible; /* Allow tooltips to extend beyond panel */
     flex-shrink: 0;
   }

     display: flex;
     justify-content: center;
     align-items: center;
+    z-index: 1; /* Lower z-index to allow tooltips to appear above */
   }

   .right-panel {

     padding: 20px;
     overflow-y: auto;
     flex-shrink: 0;
+    z-index: 1; /* Lower z-index to allow tooltips to appear above */
   }

   .guidance {

     font-size: 12px;
     cursor: default;
     user-select: none;
+    z-index: 100000; /* Ensure the tip itself is above everything */
   }
   .info-tip .tooltip {
     display: none;

     left: 0;
     top: calc(100% + 8px); /* show below the icon */
     transform: none;
+    background: rgba(0,0,0,0.95);
     color: #e5e7eb;
+    border: 1px solid rgba(255,255,255,0.2);
     border-radius: 8px;
     padding: 10px 12px;
     font-size: 12px;
+    width: 480px;
     white-space: normal;
+    z-index: 999999; /* Even higher z-index to ensure it's above everything */
+    box-shadow: 0 8px 24px rgba(0,0,0,0.6);
+    text-align: left;
   }
   .info-tip:hover .tooltip {
     display: block;

   <script type="importmap">
   {
     "imports": {
+      "three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.178.0/three.module.js",
+      "@sparkjsdev/spark": "https://sparkjs.dev/releases/spark/0.1.9/spark.module.js",
       "lil-gui": "https://cdn.jsdelivr.net/npm/[email protected]/+esm"
     }
   }
 
474
  <div class="step">
475
  <h3>1. Configure</h3>
476
  <p>Set FOV and Resolution and Click "Fix Configurations"</p>
477
+ <p><strong>Important: You also need to specify your Hugging Face Access Token with READ permission to use the online free ZeroGPU service.</strong></p>
478
  </div>
479
 
480
 
 
646
  function updateStatus(message, cameraCount = null) {
647
  const cameraText = cameraCount !== null ? `Cameras: ${cameraCount}` : `Cameras: ${cameraParams.length}`;
648
  statusBar.textContent = `${message} | ${cameraText} | Status: ${fixGenerationFOV ? 'Ready to record' : 'Configure settings'}`;
649
+
650
+ // Update save trajectory button state
651
+ updateSaveTrajectoryButton();
652
+ }
653
+
654
+ // Update save trajectory button state based on camera count
655
+ function updateSaveTrajectoryButton() {
656
+ if (window.saveTrajectoryController) {
657
+ if (cameraParams.length >= 2) {
658
+ window.saveTrajectoryController.enable();
659
+ } else {
660
+ window.saveTrajectoryController.disable();
661
+ }
662
+ }
663
+ }
664
+
665
+ // Auth-aware fetch helper that injects Authorization header when HF_TOKEN is set
666
+ function fetchWithAuth(url, options = {}) {
667
+ const mergedOptions = { ...options };
668
+ const headers = new Headers(options && options.headers ? options.headers : undefined);
669
+ if (guiOptions && guiOptions.HF_TOKEN && String(guiOptions.HF_TOKEN).trim().length > 0) {
670
+ headers.set('Authorization', `Bearer ${guiOptions.HF_TOKEN}`);
671
+ }
672
+ mergedOptions.headers = headers;
673
+ return fetch(url, mergedOptions);
674
  }
675
 
676
  // Show/hide loading
 
716
  if (progressText) progressText.textContent = text;
717
  }
718
 
719
+ // ==============
720
+ // Queue handling
721
+ // ==============
722
+ let queuePollTimer = null;
723
+ let currentTaskId = null;
724
+ let initialQueuePosition = null;
725
+ let latestGenerationTime = null;
726
+ let lastDownloadPct = 0;
727
+ let lastDownloadUpdateTs = 0;
728
+
729
+ function showQueueWaiting(position, runningCount, queuedCount) {
730
+ // Use only the progress bar to show queue progress (from initial position to 0)
731
+ showDownloadProgress();
732
+ if (initialQueuePosition === null) {
733
+ // Initialize from first seen position; ensure >= 1 so 0 -> 100%
734
+ const initPos = (typeof position === 'number') ? position : 0;
735
+ initialQueuePosition = Math.max(initPos, 1);
736
+ }
737
+ const percent = initialQueuePosition && initialQueuePosition > 0
738
+ ? Math.max(0, Math.min(100, ((initialQueuePosition - (position || 0)) / initialQueuePosition) * 100))
739
+ : 0;
740
+ updateProgressBar(percent);
741
+ const totalWaiting = (position || 0) + (queuedCount || 0);
742
+ if (position !== null && position !== undefined) {
743
+ const pctText = `${Math.round(percent)}%`;
744
+ if (totalWaiting > 0) {
745
+ setProgressLabel(`Queued ${position}/${totalWaiting} (${pctText})`);
746
+ } else {
747
+ setProgressLabel(`Queued ${position} (${pctText})`);
748
+ }
749
+ } else {
750
+ setProgressLabel('Queued');
751
+ }
752
+ }
753
+
754
+ async function pollTaskUntilReady(taskId) {
755
+ currentTaskId = taskId;
756
+ initialQueuePosition = null;
757
+ if (queuePollTimer) {
758
+ clearInterval(queuePollTimer);
759
+ queuePollTimer = null;
760
+ }
761
+ const queueStartTs = Date.now();
762
+
763
+ const pollOnce = async () => {
764
+ try {
765
+ const resp = await fetchWithAuth(`${guiOptions.BackendAddress}/task/${taskId}`);
766
+ if (!resp.ok) return;
767
+ const info = await resp.json();
768
+ if (!info || !info.success) return;
769
+
770
+ const pos = info.queue && typeof info.queue.position === 'number' ? info.queue.position : 0;
771
+ const running = info.queue ? info.queue.running_count : 0;
772
+ const queued = info.queue ? info.queue.queued_count : 0;
773
+ if (info.status === 'queued' || info.status === 'running') {
774
+ // Only progress bar; set stage label
775
+ if (info.status === 'queued') {
776
+ showQueueWaiting(pos, running, queued);
777
+ } else {
778
+ // Transitioned to running: finalize queue progress visually
779
+ updateProgressBar(100);
780
+ showDownloadProgress();
781
+ setProgressLabel('Generating...');
782
+ }
783
+ }
784
+
785
+ if (info.status === 'completed' && info.download_url) {
786
+ clearInterval(queuePollTimer);
787
+ queuePollTimer = null;
788
+ latestGenerationTime = typeof info.generation_time === 'number' ? info.generation_time : null;
789
+ // Proceed to download the generated file like the normal path
790
+ updateStatus('Downloading generated scene...', cameraParams.length);
791
+ const response = await fetchWithAuth(guiOptions.BackendAddress + info.download_url);
792
+ if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
793
+ const contentLength = response.headers.get('content-length');
794
+ const total = parseInt(contentLength || '0', 10);
795
+ // Show generation info immediately once we know it and total size from headers
796
+ showGenerationInfo(latestGenerationTime || 0, total);
797
+ let loaded = 0;
798
+ const reader = response.body.getReader();
799
+ const chunks = [];
800
+ updateProgressBar(0);
801
+ setProgressLabel('Downloading 0%');
802
+ lastDownloadPct = 0;
803
+ lastDownloadUpdateTs = 0;
804
+ while (true) {
805
+ const { done, value } = await reader.read();
806
+ if (done) break;
807
+ chunks.push(value);
808
+ loaded += value.length;
809
+ if (total) {
810
+ const pct = Math.min(100, (loaded / total) * 100);
811
+ const now = Date.now();
812
+ const rounded = Math.round(pct);
813
+ // Throttle and enforce monotonic increase
814
+ if (rounded > Math.round(lastDownloadPct) || (now - lastDownloadUpdateTs) > 200) {
815
+ lastDownloadPct = Math.max(lastDownloadPct, pct);
816
+ updateProgressBar(lastDownloadPct);
817
+ setProgressLabel(`Downloading ${Math.round(lastDownloadPct)}%`);
818
+ lastDownloadUpdateTs = now;
819
+ }
820
+ }
821
+ }
822
+
823
+ if (instructionSplat) {
824
+ scene.remove(instructionSplat);
825
+ console.log('Instruction splat removed');
826
+ instructionSplat = null;
827
+ }
828
+
829
+ const blob = new Blob(chunks);
830
+ const url = URL.createObjectURL(blob);
831
+ // Continue to load the splat
832
+ updateStatus('Loading generated scene...', cameraParams.length);
833
+
834
+ const GeneratedSplat = new SplatMesh({ url });
835
+ scene.add(GeneratedSplat);
836
+ currentGeneratedSplat = GeneratedSplat;
837
+ updateStatus('Scene generated successfully!', cameraParams.length);
838
+ // Show generation time and total file size (MB)
839
+ showGenerationInfo(latestGenerationTime || 0, total || blob.size);
840
+ // Notify backend to delete the server file after client has downloaded it
841
+ try {
842
+ if (info.file_id) {
843
+ const resp = await fetchWithAuth(`${guiOptions.BackendAddress}/delete/${info.file_id}`, { method: 'POST' });
844
+ if (!resp.ok) console.warn('Delete notify failed');
845
+ }
846
+ } catch (e) {
847
+ console.warn('Delete notify error', e);
848
+ }
849
+ hideDownloadProgress();
850
+ showLoading(false);
851
+ } else if (info.status === 'failed') {
852
+ clearInterval(queuePollTimer);
853
+ queuePollTimer = null;
854
+ throw new Error(info.error || 'Generation failed');
855
+ }
856
+ } catch (e) {
857
+ console.debug('Polling error:', e);
858
+ }
859
+ };
860
+
861
+ await pollOnce();
862
+ queuePollTimer = setInterval(pollOnce, 2000);
863
+ }
864
 
865
  // Hide download progress
866
  function hideDownloadProgress() {
 
916
 
917
  // GUI Options - declare early
918
  const guiOptions = {
919
+ // 后端地址,默认为本页面ip
920
  BackendAddress: `${window.location.protocol}//${window.location.hostname}:7860`,
921
+ HF_TOKEN: "",
922
  FOV: 60,
923
  LoadFromJson: () => {
924
  const jsonInput = document.querySelector("#json-input");
 
981
  generateCameraTrajectory(guiOptions.templateType);
982
  },
983
  saveTrajectoryToJson: () => {
 
 
 
 
 
984
 
985
  // Build JSON payload compatible with loader
986
  const [nStr, hStr, wStr] = guiOptions.Resolution.split('x');
 
1084
  console.log('Interpolated cameras:', interpolatedCameras.length);
1085
  updateStatus('Sending request to backend...', cameraParams.length);
1086
 
1087
+ // 调用 Gradio 后端:POST 到 /gradio_api/call/gradio_generate,然后通过 SSE 获取结果
1088
+ const requestUrl = guiOptions.BackendAddress + '/gradio_api/call/gradio_generate';
1089
  const requestData = {
1090
  image_prompt: inputImageBase64 ? inputImageBase64 : "",
1091
  text_prompt: guiOptions.inputTextPrompt,
1092
  image_index: 0,
1093
  resolution: [
1094
+ parseInt(guiOptions.Resolution.split('x')[0]),
1095
+ parseInt(guiOptions.Resolution.split('x')[1]),
1096
  parseInt(guiOptions.Resolution.split('x')[2])
1097
  ],
1098
  cameras: interpolatedCameras.map(cam => ({
 
1109
  }))
1110
  };
1111
 
1112
+ fetchWithAuth(requestUrl, {
 
1113
  method: 'POST',
1114
  headers: { 'Content-Type': 'application/json' },
1115
  mode: 'cors',
1116
+ body: JSON.stringify({ data: [JSON.stringify(requestData)] })
 
 
1117
  })
1118
  .then(response => response.json())
1119
  .then(data => {
1120
+ // Gradio 总是返回 event_id,需要使用 SSE 获取生成结果
1121
+ if (!data || !data.event_id) {
1122
+ throw new Error('Invalid Gradio response format - no event_id');
1123
+ }
1124
+ return fetchWithAuth(guiOptions.BackendAddress + `/gradio_api/call/gradio_generate/${data.event_id}`)
1125
+ .then(resp => {
1126
+ if (!resp.ok) throw new Error(`HTTP error! status: ${resp.status}`);
1127
+ return resp.text();
1128
+ })
1129
+ .then(sseText => {
1130
+ const lines = sseText.split('\n');
1131
+ let eventType = null;
1132
+ let dataContent = null;
1133
+ for (const line of lines) {
1134
+ if (line.startsWith('event: ')) eventType = line.substring(7);
1135
+ else if (line.startsWith('data: ')) dataContent = line.substring(6);
1136
+ }
1137
+ if (eventType !== 'complete' || !dataContent) {
1138
+ throw new Error('Gradio SSE response not complete or missing data');
1139
+ }
1140
+ const resultData = JSON.parse(dataContent);
1141
+ if (!resultData || resultData.length === 0) {
1142
+ throw new Error('Invalid Gradio generation result format');
1143
+ }
1144
+ const responseData = JSON.parse(resultData[0]);
1145
+ if (!responseData.success) {
1146
+ throw new Error('Gradio generation failed: ' + (responseData.error || 'Unknown error'));
1147
+ }
1148
+
1149
+ // 显示生成信息
1150
+ showGenerationInfo(responseData.generation_time, responseData.file_size);
1151
+ showDownloadProgress();
1152
+ updateStatus('Downloading generated scene...', cameraParams.length);
1153
+
1154
+ // ��载文件:调用 download_file 获取下载 event_id,然后通过 SSE 拿到 URL,再实际下载
1155
+ return fetchWithAuth(guiOptions.BackendAddress + '/gradio_api/call/download_file', {
1156
+ method: 'POST',
1157
+ headers: { 'Content-Type': 'application/json' },
1158
+ body: JSON.stringify({ data: [responseData.file_id] })
1159
  })
1160
+ .then(r => r.json())
1161
+ .then(downloadEvent => {
1162
+ return fetchWithAuth(guiOptions.BackendAddress + `/gradio_api/call/download_file/${downloadEvent.event_id}`)
1163
+ .then(r => {
1164
+ if (!r.ok) throw new Error(`HTTP error! status: ${r.status}`);
1165
+ return r.text();
1166
+ })
1167
+ .then(downloadSseText => {
1168
+ const lines = downloadSseText.split('\n');
1169
+ let eventType = null;
1170
+ let dataContent = null;
1171
+ for (const line of lines) {
1172
+ if (line.startsWith('event: ')) eventType = line.substring(7);
1173
+ else if (line.startsWith('data: ')) dataContent = line.substring(6);
1174
+ }
1175
+ if (eventType !== 'complete' || !dataContent) {
1176
+ throw new Error('Gradio download SSE response not complete or missing data');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1177
  }
1178
+ const fileData = JSON.parse(dataContent);
1179
+ if (!fileData || fileData.length === 0 || !fileData[0].url) {
1180
+ throw new Error('Invalid file data format from Gradio');
1181
+ }
1182
+ return fileData[0].url;
1183
+ });
1184
+ });
1185
+ })
1186
+ .then(fileUrl => {
1187
+ return fetchWithAuth(fileUrl).then(response => {
1188
+ if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
1189
+ const contentLength = response.headers.get('content-length');
1190
+ const total = parseInt(contentLength || '0', 10);
1191
+ let loaded = 0;
1192
+ const reader = response.body.getReader();
1193
+ const chunks = [];
1194
+ function pump() {
1195
+ return reader.read().then(({ done, value }) => {
1196
+ if (done) return new Blob(chunks);
1197
+ chunks.push(value);
1198
+ loaded += value.length;
1199
+ if (total) updateProgressBar((loaded / total) * 100);
1200
+ return pump();
1201
+ });
1202
  }
1203
+ return pump().then(blob => {
1204
+ const url = URL.createObjectURL(blob);
1205
+ return { url, __deleteAfterDownloadFileId: (typeof responseData !== 'undefined' ? responseData.file_id : null) };
1206
+ });
1207
  });
1208
+ });
 
 
1209
  })
1210
  .then(data => {
1211
+ if (data && data.url) {
1212
  updateStatus('Loading 3D scene...', cameraParams.length);
 
 
1213
  if (instructionSplat) {
1214
  scene.remove(instructionSplat);
1215
  console.log('Instruction splat removed');
1216
  }
 
1217
  const GeneratedSplat = new SplatMesh({ url: data.url });
1218
  scene.add(GeneratedSplat);
1219
+ currentGeneratedSplat = GeneratedSplat;
1220
  console.log('3D scene loaded successfully!');
1221
  updateStatus('Scene generated successfully!', cameraParams.length);
1222
  hideDownloadProgress();
1223
  showLoading(false);
1224
+
1225
+ // 通知后端删除文件(如果有 file_id)
1226
+ if (data.__deleteAfterDownloadFileId) {
1227
+ fetchWithAuth(guiOptions.BackendAddress + '/delete/' + data.__deleteAfterDownloadFileId, { method: 'POST' })
1228
+ .then(() => console.log('Delete notify sent'))
1229
+ .catch(err => console.warn('Delete notify failed', err));
1230
+ }
1231
  }
1232
  })
1233
  .catch(error => {
 
1605
 
1606
  // Step 1: Configure Generation Settings
1607
  const step1Folder = gui.addFolder('1. Configure Settings');
1608
+ step1Folder.add(guiOptions, "BackendAddress").name("Backend Address");
1609
+ step1Folder.add(guiOptions, "HF_TOKEN").name("HF Token");
1610
 
1611
  // FOV和Resolution控制器,初始时启用
1612
  const fovController = step1Folder.add(guiOptions, "FOV", 0, 120, 1).name("FOV").onChange((value) => {
 
1653
  const loadTrajectoryController = trajectoryFolder.add(guiOptions, "LoadTrajectoryFromJson").name("Load from JSON");
1654
  const saveTrajectoryController = trajectoryFolder.add(guiOptions, "saveTrajectoryToJson").name("Save Trajectory");
1655
 
1656
+ // 初始状态:禁用保存按钮(相机数量不够)
1657
+ saveTrajectoryController.disable();
1658
+
1659
  // 清理相机按钮
1660
  const clearAllCamerasController = trajectoryFolder.add(guiOptions, "clearAllCameras").name("Clear All Cameras");
1661
 
 
1722
 
1723
  // Store controllers globally so they can be accessed from guiOptions
1724
  window.fixGenerationFOVController = fixGenerationFOVController;
1725
+ window.saveTrajectoryController = saveTrajectoryController;
1726
 
1727
  // Step 3: Add Scene Prompts
1728
  const step3Folder = gui.addFolder('3. Add Scene Prompts');
 
2136
  if (loadTrajectoryOnly) {
2137
  updateStatus(`Trajectory loaded: ${cameras.length} cameras`, cameraParams.length);
2138
  } else {
2139
+ updateStatus(`JSON loaded: ${cameras.length} cameras`, cameraParams.length);
2140
  }
2141
  } catch (error) {
2142
  console.error("JSON data processing error:", error);
models/render.py CHANGED
@@ -6,8 +6,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from gsplat import rasterization
-
 # torch.backends.cuda.preferred_linalg_library(backend="magma")

 """"
@@ -17,6 +15,9 @@ class GaussianRendererWithCheckpoint(torch.autograd.Function):
     @staticmethod
     def render(xyz, feature, scale, rotation, opacity, test_c2w, test_intr,
                W, H, sh_degree, near_plane, far_plane, backgrounds):
+
+        from gsplat import rasterization
+
         test_w2c = test_c2w.float().inverse().unsqueeze(0) # (1, 4, 4)
         test_intr_i = torch.zeros(3, 3).to(test_intr.device)
         test_intr_i[0, 0] = test_intr[0]
@@ -29,6 +30,7 @@ class GaussianRendererWithCheckpoint(torch.autograd.Function):
             test_w2c, test_intr_i, W, H, sh_degree=sh_degree,
             near_plane=near_plane, far_plane=far_plane,
             render_mode="RGB+D",
+            tile_size=16,
             backgrounds=backgrounds[None],
             rasterize_mode='classic') # (1, H, W, 4)
         # rendering[..., 3:] = rendering[..., 3:] + far_plane * (1 - alpha)
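Moving from gsplat import rasterization inside render defers loading gsplat's CUDA extension until the first render call, so importing models.render no longer requires a working CUDA toolchain at import time (the hunk also pins tile_size=16 in the rasterization call). A minimal sketch of the deferred-import pattern, with hypothetical names:

# Sketch: import the CUDA-backed rasterizer on first use, not at module import.
def render_frame(*args, **kwargs):
    from gsplat import rasterization  # loaded only when rendering actually happens
    return rasterization(*args, **kwargs)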
packages.txt CHANGED
@@ -1,2 +1,4 @@
 libglm-dev
-ffmpeg
+ffmpeg
+gcc-11
+g++-11
pre-requirements.txt CHANGED
@@ -14,4 +14,5 @@ ftfy==6.3.1
 flask==3.1.2
 gradio==5.49.1
 gsplat==1.5.2
-accelerate==1.10.1
+accelerate==1.10.1
+nanobind==2.9.2
quant.py CHANGED
@@ -138,7 +138,7 @@ class FP8DynamicLinear(torch.nn.Module):
         super().__init__()
         self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.bias = bias
+        self.bias = torch.nn.Parameter(bias.to(dtype), requires_grad=False)
         self.native_fp8_support = native_fp8_support
         self.dtype = dtype

@@ -186,7 +186,6 @@ def FluxFp8GeMMProcessor(model: torch.nn.Module):
                 weight_scale=weight_scale,
                 bias=bias,
                 native_fp8_support=native_fp8_support,
-                dtype=linear.weight.dtype
             )
             replace_module(model, name, quant_linear)
             del linear.weight
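Wrapping bias in torch.nn.Parameter(bias.to(dtype), requires_grad=False) registers it on the module, so .to(device)/.cuda() and state_dict() handle it together with weight; a bare tensor attribute would be skipped. A small standalone illustration (hypothetical modules, not the repo's class):

# Sketch: plain tensor attributes are invisible to Module bookkeeping,
# registered Parameters are not.
import torch

class WithPlainAttr(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.zeros(4)  # plain attribute: not in state_dict, not moved by .to()

class WithParameter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16),
                                       requires_grad=False)

print("bias" in WithPlainAttr().state_dict())   # False
print("bias" in WithParameter().state_dict())   # True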
requirements.txt CHANGED
@@ -1 +1,2 @@
-git+https://github.com/huggingface/diffusers.git@447e8322f76efea55d4769cd67c372edbf0715b8
+git+https://github.com/huggingface/diffusers.git@447e8322f76efea55d4769cd67c372edbf0715b8
+git+https://github.com/nianticlabs/spz.git@a4fc69e7948c7152e807e6501d73ddc9c149ce37
utils.py CHANGED
@@ -125,7 +125,7 @@ class TimestepEmbedding(nn.Module):
         else:
            return timestep_embedding(t, self.dim, self.max_period, self.time_factor) * self.weight.unsqueeze(0)

-@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
+# @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
 def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0):
     """
     Create sinusoidal timestep embeddings.
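For reference, timestep_embedding is the standard sinusoidal embedding; commenting out the @torch.compile decorator avoids compile-at-first-call overhead for the dynamic shapes passed here. A conventional sketch of such a function (the repo's exact body is not shown in this diff; even dim assumed):

# Sketch: standard sinusoidal timestep embedding.
import math
import torch

def sinusoidal_embedding(t, dim, max_period=10000, time_factor=1000.0):
    # t: (N,) timesteps; returns (N, dim) embeddings (dim assumed even).
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(half, dtype=torch.float32, device=t.device) / half
    )
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)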
@@ -341,7 +341,7 @@ def matrix_to_square(mat):
     elif l==4:
         return torch.cat([mat, torch.tensor([0,0,0,1]).repeat(mat.shape[0],mat.shape[1],1,1).to(mat.device)],dim=2)

-def export_ply_for_gaussians(path, gaussians, opacity_threshold=0.00, T_norm=None):
+def export_gaussians(path, gaussians, opacity_threshold=0.00, T_norm=None):

     sh_degree = int(math.sqrt((gaussians.shape[-1] - sum([3, 1, 3, 4])) / 3 - 1))
@@ -380,28 +380,54 @@
     scales = scales.detach() #.cpu().numpy()
     rotations = rotations.detach() #.cpu().numpy()

-    l = ['x', 'y', 'z']
-    # All channels except the 3 DC
-    for i in range(f_dc.shape[1]):
-        l.append('f_dc_{}'.format(i))
-    l.append('opacity')
-    for i in range(scales.shape[1]):
-        l.append('scale_{}'.format(i))
-    for i in range(rotations.shape[1]):
-        l.append('rot_{}'.format(i))
-
-    dtype_full = [(attribute, 'f4') for attribute in l]
-
-    # Optimal approach: build the recarray directly with numpy
-    attributes = torch.cat((xyzs, f_dc, opacities, scales, rotations), dim=1).cpu().numpy()
-
-    # Create the recarray directly, avoiding loops and type conversion
-    elements = np.rec.fromarrays([attributes[:, i] for i in range(attributes.shape[1])], names=l, formats=['f4'] * len(l))
-    el = PlyElement.describe(elements, 'vertex')
-
-    print(path)
-
-    PlyData([el]).write(path)
-
+    """spz
+    Data Layout
+    The Python bindings maintain the same data layout as the C++ library:
+
+    Positions: [x1, y1, z1, x2, y2, z2, ...]
+    Scales: [sx1, sy1, sz1, sx2, sy2, sz2, ...] (log-scale)
+    Rotations: [x1, y1, z1, w1, x2, y2, z2, w2, ...] (quaternions)
+    Alphas: [a1, a2, a3, ...] (before sigmoid activation)
+    Colors: [r1, g1, b1, r2, g2, b2, ...] (base RGB)
+    Spherical Harmonics: Coefficient-major order, e.g., for degree 1: [sh1n1_r, sh1n1_g, sh1n1_b, sh10_r, sh10_g, sh10_b, sh1p1_r, sh1p1_g, sh1p1_b, ...]
+    """
+
+    import spz
+
+    cloud = spz.GaussianCloud()
+    cloud.sh_degree = sh_degree
+
+    cloud.positions = xyzs.flatten().cpu().numpy()
+    cloud.scales = scales.flatten().cpu().numpy()
+    cloud.rotations = rotations[:, [3, 0, 1, 2]].flatten().cpu().numpy()
+    cloud.alphas = opacities.flatten().cpu().numpy()
+    cloud.colors = f_dc[..., :3].flatten().cpu().numpy()
+    cloud.sh = f_dc[..., 3:].flatten().cpu().numpy()
+
+    spz.save_spz(cloud, spz.PackOptions(), path)
+
+    # l = ['x', 'y', 'z']
+    # # All channels except the 3 DC
+    # for i in range(f_dc.shape[1]):
+    #     l.append('f_dc_{}'.format(i))
+    # l.append('opacity')
+    # for i in range(scales.shape[1]):
+    #     l.append('scale_{}'.format(i))
+    # for i in range(rotations.shape[1]):
+    #     l.append('rot_{}'.format(i))
+
+    # dtype_full = [(attribute, 'f4') for attribute in l]

+    # # Optimal approach: build the recarray directly with numpy
+    # attributes = torch.cat((xyzs, f_dc, opacities, scales, rotations), dim=1).cpu().numpy()

+    # # Create the recarray directly, avoiding loops and type conversion
+    # elements = np.rec.fromarrays([attributes[:, i] for i in range(attributes.shape[1])], names=l, formats=['f4'] * len(l))
+    # el = PlyElement.describe(elements, 'vertex')

+    # print(path)

+    # PlyData([el]).write(path)

 # plydata = PlyData([el])
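The spz bindings used by export_gaussians can be exercised standalone; a sketch that saves a random cloud using only the calls appearing in the hunk above (degree-0 SH, so cloud.sh stays empty; the assumption that the bindings accept flat float32 numpy arrays mirrors the .flatten().cpu().numpy() assignments in the commit):

# Sketch: build and save a minimal spz GaussianCloud (illustrative data).
import numpy as np
import spz  # installed from the nianticlabs/spz revision pinned in requirements.txt

n = 1000
cloud = spz.GaussianCloud()
cloud.sh_degree = 0

cloud.positions = np.random.randn(n, 3).astype(np.float32).flatten()
cloud.scales = np.full((n, 3), -4.0, dtype=np.float32).flatten()  # log-scale
# Identity quaternion in the x,y,z,w layout quoted in the docstring above.
cloud.rotations = np.tile(np.array([0, 0, 0, 1], dtype=np.float32), (n, 1)).flatten()
cloud.alphas = np.zeros(n, dtype=np.float32)                      # pre-sigmoid
cloud.colors = np.random.rand(n, 3).astype(np.float32).flatten()  # base RGB
cloud.sh = np.zeros(0, dtype=np.float32)                          # no higher-order SH

spz.save_spz(cloud, spz.PackOptions(), "scene.spz")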