"""FlashWorld generation backend.

Bootstraps gsplat (installing a CUDA toolkit when the import fails), defines
the ``GenerationSystem`` pipeline, and serves it through a Gradio UI mounted
on a FastAPI application.
"""
try:
    import spaces
    GPU = spaces.GPU
    print("spaces GPU is available")
except ImportError:
    def GPU(func):
        """No-op stand-in for ``spaces.GPU`` when the package is missing."""
        return func

import os
import subprocess

try:
    import gsplat
except ImportError:
    def install_cuda_toolkit():
        """Download and install the CUDA toolkit, then point gcc/g++ at v11."""
        # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
        CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
        CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
        subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
        subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
        subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

        os.environ["CUDA_HOME"] = "/usr/local/cuda"
        os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
        os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
            os.environ["CUDA_HOME"],
            "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
        )
        # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
        os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0+PTX"
        print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])

        # Replace the default compiler links with gcc-11 / g++-11.
        # shell=True is kept so a missing binary yields a non-zero exit code
        # rather than raising FileNotFoundError (best-effort bootstrap).
        subprocess.call('rm /usr/bin/gcc', shell=True)
        subprocess.call('rm /usr/bin/g++', shell=True)
        subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
        subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)
        subprocess.call('gcc --version', shell=True)
        subprocess.call('g++ --version', shell=True)

    install_cuda_toolkit()

    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0+PTX"
    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "/usr/local/cuda/bin/:" + os.environ["PATH"]

    # NOTE: ``env`` replaces the whole child environment; only these three
    # variables are visible to pip and the extension build.
    subprocess.run(
        'pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712',
        env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "9.0+PTX", "PATH": "/usr/local/cuda/bin/:" + os.environ["PATH"]},
        shell=True,
    )

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import gradio as gr
import base64
import io
from PIL import Image
import torch
import numpy as np
import argparse
import imageio
import json
import time
import tempfile
import shutil
import asyncio
import logging
import re
from contextlib import asynccontextmanager

from huggingface_hub import hf_hub_download

import einops
import torch.nn as nn
import torch.nn.functional as F

from models import *
from utils import *

from transformers import T5TokenizerFast, UMT5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler

from starlette.responses import FileResponse
from fastapi import HTTPException
import uvicorn

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger = logging.getLogger(__name__)

# File ids are generated from a millisecond timestamp; anything containing
# path separators or dots must never be joined into a filesystem path.
_FILE_ID_PATTERN = re.compile(r"[0-9A-Za-z_-]+")


def _is_safe_file_id(file_id):
    """Return True when *file_id* is a plain token usable as a file stem."""
    return isinstance(file_id, str) and _FILE_ID_PATTERN.fullmatch(file_id) is not None


def _str2bool(value):
    """Parse a command-line boolean ("true"/"false", "1"/"0", "yes"/"no")."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


class MyFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    """Scheduler whose timestep lookup snaps to the nearest schedule entry."""

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index of the schedule timestep closest to *timestep*."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        return torch.argmin(
            (timestep - schedule_timesteps.to(timestep.device)).abs(), dim=0).item()


class GenerationSystem(nn.Module):
    """Text/image-conditioned generator producing 3D Gaussian scene parameters.

    Combines the Wan VAE (encoder only), a UMT5 text encoder, a Wan
    transformer with widened input/output projections, and a pixel-aligned
    3DGS reconstruction decoder.
    """

    def __init__(self, ckpt_path=None, device="cuda:0", offload_t5=False, offload_vae=False):
        """Build all sub-models and optionally load weights from *ckpt_path*.

        Args:
            ckpt_path: checkpoint holding "transformer" and "recon_decoder"
                state dicts, or None to keep the initial weights.
            device: device for the compute-heavy modules.
            offload_t5: keep the text encoder on the CPU.
            offload_vae: keep the VAE on the CPU.
        """
        super().__init__()
        self.device = device
        self.offload_t5 = offload_t5
        self.offload_vae = offload_vae

        self.latent_dim = 48
        self.temporal_downsample_factor = 4
        self.spatial_downsample_factor = 16

        self.feat_dim = 1024

        self.latent_patch_size = 2

        self.denoising_steps = [0, 250, 500, 750]

        model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"

        self.vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float).eval()

        from models.autoencoder_kl_wan import WanCausalConv3d
        with torch.no_grad():
            # Drop the temporal padding of every causal conv and trim the
            # matching leading temporal kernel taps.
            for name, module in self.vae.named_modules():
                if isinstance(module, WanCausalConv3d):
                    time_pad = module._padding[4]
                    module.padding = (0, module._padding[2], module._padding[0])
                    module._padding = (0, 0, 0, 0, 0, 0)
                    module.weight = torch.nn.Parameter(module.weight[:, :, time_pad:].clone())

        self.vae.requires_grad_(False)

        self.register_buffer('latents_mean', torch.tensor(self.vae.config.latents_mean).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
        self.register_buffer('latents_std', torch.tensor(self.vae.config.latents_std).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))

        self.tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer")

        self.text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float32).eval().requires_grad_(False).to(self.device if not self.offload_t5 else "cpu")

        self.transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float32).train().requires_grad_(False)

        # Zero-pad the input-channel axis so the patch embedding also accepts
        # the raymap (6 channels) and the condition latents.
        self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, 6 + self.latent_dim)))
        # self.transformer.rope.freqs_f[:] = self.transformer.rope.freqs_f[:1]

        # Extend the output projection with feat_dim extra channels per patch
        # position (random small weights, zero bias).
        weight = self.transformer.proj_out.weight.reshape(self.latent_patch_size ** 2, self.latent_dim, self.transformer.proj_out.weight.shape[1])
        bias = self.transformer.proj_out.bias.reshape(self.latent_patch_size ** 2, self.latent_dim)

        extra_weight = torch.randn(self.latent_patch_size ** 2, self.feat_dim, self.transformer.proj_out.weight.shape[1]) * 0.02
        extra_bias = torch.zeros(self.latent_patch_size ** 2, self.feat_dim)

        self.transformer.proj_out.weight = nn.Parameter(torch.cat([weight, extra_weight], dim=1).flatten(0, 1).detach().clone())
        self.transformer.proj_out.bias = nn.Parameter(torch.cat([bias, extra_bias], dim=1).flatten(0, 1).detach().clone())

        self.recon_decoder = WANDecoderPixelAligned3DGSReconstructionModel(self.vae, self.feat_dim, use_render_checkpointing=True, use_network_checkpointing=False).train().requires_grad_(False).to(self.device)

        self.scheduler = MyFlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3)

        self.register_buffer('timesteps', self.scheduler.timesteps.clone().to(self.device))

        self.transformer.disable_gradient_checkpointing()
        self.transformer.gradient_checkpointing = False

        self.add_feedback_for_transformer()

        if ckpt_path is not None:
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from a trusted source.
            state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            self.transformer.load_state_dict(state_dict["transformer"])
            self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
            print(f"Loaded {ckpt_path}.")

        from quant import FluxFp8GeMMProcessor

        FluxFp8GeMMProcessor(self.transformer)

        # Only the VAE encoder is used at inference time.
        del self.vae.post_quant_conv, self.vae.decoder
        self.vae.to(self.device if not self.offload_vae else "cpu")
        self.vae.to(torch.bfloat16)

        self.transformer.to(self.device)

    def latent_scale_fn(self, x):
        """Normalize VAE latents with the configured mean and std."""
        return (x - self.latents_mean) / self.latents_std

    def latent_unscale_fn(self, x):
        """Invert ``latent_scale_fn``."""
        return x * self.latents_std + self.latents_mean

    def add_feedback_for_transformer(self):
        """Enable feedback and widen the patch embedding for fed-back inputs."""
        self.use_feedback = True
        # Extra zero-initialised input channels for previous features and
        # previous latent predictions.
        self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, self.feat_dim + self.latent_dim)))

    def encode_text(self, texts):
        """Encode *texts* into float embeddings zero-padded to 512 tokens."""
        max_sequence_length = 512
        offloaded = getattr(self, "offload_t5", False)

        text_inputs = self.tokenizer(
            texts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        encoder_device = "cpu" if offloaded else self.device
        text_input_ids = text_inputs.input_ids.to(encoder_device)
        mask = text_inputs.attention_mask.to(encoder_device)
        seq_lens = mask.gt(0).sum(dim=1).long()

        if offloaded:
            with torch.no_grad():
                text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state.to(self.device)
        else:
            text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state

        # Zero out everything past each sequence's true length.
        text_embeds = [u[:v] for u, v in zip(text_embeds, seq_lens)]
        text_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in text_embeds], dim=0
        )
        return text_embeds.float()

    def _decode_scene(self, feats, latents, cameras):
        """Run the reconstruction decoder on per-frame features and latents."""
        return self.recon_decoder(
            einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
            einops.rearrange(self.latent_unscale_fn(latents.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
            cameras
        ).flatten(1, -2)

    def forward_generator(self, noisy_latents, raymaps, condition_latents, t, text_embeds, cameras, render_cameras, image_height, image_width, need_3d_mode=True):
        """Run one denoising step and, optionally, the 3D render/re-encode pass.

        Returns:
            dict with keys '2d' (predicted clean latents), 'feat' (extra
            transformer features) and, when *need_3d_mode* is set, '3d'
            (latents of the re-encoded renders), 'rgb_3d' (rendered images)
            and 'scene' (scene parameters); the latter three are None
            otherwise.
        """
        out = self.transformer(
            hidden_states=torch.cat([noisy_latents, raymaps, condition_latents], dim=1),
            timestep=t,
            encoder_hidden_states=text_embeds,
            return_dict=False,
        )[0]

        v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)

        sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
        # Broadcast per-sample sigma over (C, T, H, W); a bare (B,) tensor
        # would align with the last axis and is only right for B == 1.
        sigma = sigma.view(-1, 1, 1, 1, 1)
        latents_pred_2d = noisy_latents - sigma * v_pred

        if need_3d_mode:
            scene_params = self._decode_scene(feats, latents_pred_2d, cameras)

            images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")

            latents_pred_3d = einops.rearrange(self.latent_scale_fn(self.vae.encode(
                einops.rearrange(images_pred, 'B T C H W -> (B T) C H W', T=images_pred.shape[1]).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
            ).latent_dist.sample().to(self.device)).squeeze(2), '(B T) C H W -> B C T H W', T=images_pred.shape[1]).to(noisy_latents.dtype)

        return {
            '2d': latents_pred_2d,
            '3d': latents_pred_3d if need_3d_mode else None,
            'rgb_3d': images_pred if need_3d_mode else None,
            'scene': scene_params if need_3d_mode else None,
            'feat': feats
        }

    @torch.no_grad()
    def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
        """Generate scene parameters for a camera trajectory.

        Args:
            cameras: per-frame camera tensor; resampled to *n_frame* entries
                when its length differs.
            n_frame: number of frames to generate.
            image: optional conditioning image tensor placed at *image_index*.
            text: text prompt ("[Static] " is prepended).
            image_index: frame index the conditioning image belongs to.
            image_height, image_width: output resolution in pixels.
            video_output_path: when given, a preview video rendered along the
                original camera path is written there.

        Returns:
            (scene_params, ref_w2c, T_norm): CPU scene parameters plus the
            camera normalization metadata from ``normalize_cameras``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Respect offload_vae here as well: the encode calls below send their
        # inputs to the CPU when offloading, so the VAE must stay there too.
        self.vae.to(self.device if not self.offload_vae else "cpu")
        self.text_encoder.to(self.device if not self.offload_t5 else "cpu")
        self.transformer.to(self.device)
        self.recon_decoder.to(self.device)
        self.timesteps = self.timesteps.to(self.device)
        self.latents_mean = self.latents_mean.to(self.device)
        self.latents_std = self.latents_std.to(self.device)

        latent_height = image_height // self.spatial_downsample_factor
        latent_width = image_width // self.spatial_downsample_factor

        with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
            batch_size = 1

            cameras = cameras.to(self.device).unsqueeze(0)

            if cameras.shape[1] != n_frame:
                # Keep the dense path for the preview render and subsample it
                # for generation.
                render_cameras = cameras.clone()
                cameras = sample_from_dense_cameras(cameras.squeeze(0), torch.linspace(0, 1, n_frame, device=self.device)).unsqueeze(0)
            else:
                render_cameras = cameras

            cameras, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True, n_frame=None)
            render_cameras = normalize_cameras(render_cameras, ref_w2c=ref_w2c, T_norm=T_norm, n_frame=None)

            text = "[Static] " + text

            text_embeds = self.encode_text([text])
            # neg_text_embeds = self.encode_text([""]).repeat(batch_size, 1, 1)

            condition_latents = torch.zeros(batch_size, self.latent_dim, n_frame, latent_height, latent_width, device=self.device)

            if image is not None:
                image = image.to(self.device)

                latent = self.latent_scale_fn(self.vae.encode(
                    image.unsqueeze(0).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
                ).latent_dist.sample().to(self.device)).squeeze(2)

                condition_latents[:, :, image_index] = latent

            raymaps = create_raymaps(cameras, latent_height, latent_width)
            raymaps = einops.rearrange(raymaps, 'B T H W C -> B C T H W', T=n_frame)

            noise = torch.randn(batch_size, self.latent_dim, n_frame, latent_height, latent_width, device=self.device)

            noisy_latents = noise

            torch.cuda.empty_cache()

            if self.use_feedback:
                prev_latents_pred = torch.zeros(batch_size, self.latent_dim, n_frame, latent_height, latent_width, device=self.device)
                prev_feats = torch.zeros(batch_size, self.feat_dim, n_frame, latent_height, latent_width, device=self.device)

            last_step = len(self.denoising_steps) - 1
            for i, step in enumerate(self.denoising_steps):
                t_ids = torch.full((noisy_latents.shape[0],), step, device=self.device)

                t = self.timesteps[t_ids]

                if self.use_feedback:
                    _condition_latents = torch.cat([condition_latents, prev_feats, prev_latents_pred], dim=1)
                else:
                    _condition_latents = condition_latents

                if i < last_step:
                    out = self.forward_generator(noisy_latents, raymaps, _condition_latents, t, text_embeds, cameras, cameras, image_height, image_width, need_3d_mode=True)

                    latents_pred = out["3d"]

                    if self.use_feedback:
                        prev_latents_pred = latents_pred
                        prev_feats = out['feat']

                    # Re-noise the 3D-consistent prediction to the next step.
                    next_t = self.timesteps[torch.full((noisy_latents.shape[0],), self.denoising_steps[i + 1], device=self.device)]
                    noisy_latents = self.scheduler.scale_noise(latents_pred, next_t, torch.randn_like(noise))

                else:
                    # Final step: no render/re-encode round trip, just decode
                    # the scene from the 2D prediction.
                    out = self.forward_generator(noisy_latents, raymaps, _condition_latents, t, text_embeds, cameras, cameras, image_height, image_width, need_3d_mode=False)

                    latents_pred = out['2d']
                    scene_params = self._decode_scene(out['feat'], latents_pred, cameras)

                    if video_output_path is not None:
                        interpolated_images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")

                        interpolated_images_pred = einops.rearrange(interpolated_images_pred[0].clamp(-1, 1).add(1).div(2), 'T C H W -> T H W C')

                        frames = [img.detach().cpu().mul(255).numpy().astype(np.uint8) for img in interpolated_images_pred.unbind(0)]

                        imageio.mimwrite(video_output_path, frames, fps=15, quality=8, macro_block_size=1)

            scene_params = scene_params[0]

        scene_params = scene_params.detach().cpu()

        return scene_params, ref_w2c, T_norm


@GPU
def process_generation_request(data, generation_system, cache_dir):
    """
    Process the generation request with the same logic as Flask version

    Args:
        data: request dict with 'cameras' and 'resolution' (n_frame, height,
            width) plus optional 'image_prompt', 'text_prompt', 'image_index'.
        generation_system: object exposing ``generate(...)``.
        cache_dir: directory receiving the .mp4, .json and .ply outputs.

    Returns:
        dict describing the generated file on success, or {'error': ...}.
    """
    try:
        image_prompt = data.get('image_prompt', None)
        text_prompt = data.get('text_prompt', "")
        cameras = data.get('cameras')
        resolution = data.get('resolution')
        image_index = data.get('image_index', 0)

        if not cameras:
            return {'error': 'No cameras provided'}
        if not resolution or len(resolution) != 3:
            return {'error': 'resolution must be [n_frame, image_height, image_width]'}

        n_frame, image_height, image_width = resolution

        if not image_prompt and text_prompt == "":
            return {'error': 'No Prompts provided'}

        if image_prompt:
            # image_prompt may be a file path or a base64 string.
            # NOTE(review): accepting a server-side path lets callers open
            # any readable image on this host; restrict if input is untrusted.
            if os.path.exists(image_prompt):
                image_prompt = Image.open(image_prompt)
            else:
                # image_prompt may look like "data:image/png;base64,...."
                if ',' in image_prompt:
                    image_prompt = image_prompt.split(',', 1)[1]

                try:
                    image_bytes = base64.b64decode(image_prompt)
                    image_prompt = Image.open(io.BytesIO(image_bytes))
                except Exception as img_e:
                    return {'error': f'Image decode error: {str(img_e)}'}

            image = image_prompt.convert('RGB')

            w, h = image.size

            # center crop
            if image_height / h > image_width / w:
                scale = image_height / h
            else:
                scale = image_width / w

            new_h = int(image_height / scale)
            new_w = int(image_width / scale)
            left = (w - new_w) // 2
            top = (h - new_h) // 2

            image = image.crop((left, top, new_w + left, new_h + top)).resize((image_width, image_height))

            # Intrinsics follow the crop + resize. The camera dicts inside
            # ``data`` are updated in place, so the JSON saved below holds
            # the adjusted values.
            for camera in cameras:
                camera['fx'] = camera['fx'] * scale
                camera['fy'] = camera['fy'] * scale
                camera['cx'] = (camera['cx'] - left) * scale
                camera['cy'] = (camera['cy'] - top) * scale

            image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0 * 2 - 1
        else:
            image = None

        # Each row: quaternion (4), position (3), then intrinsics normalized
        # by the output resolution.
        cameras = torch.stack([
            torch.from_numpy(np.array([camera['quaternion'][0], camera['quaternion'][1], camera['quaternion'][2], camera['quaternion'][3], camera['position'][0], camera['position'][1], camera['position'][2], camera['fx'] / image_width, camera['fy'] / image_height, camera['cx'] / image_width, camera['cy'] / image_height], dtype=np.float32))
            for camera in cameras
        ], dim=0)

        file_id = str(int(time.time() * 1000))

        start_time = time.time()
        scene_params, ref_w2c, T_norm = generation_system.generate(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=os.path.join(cache_dir, f'{file_id}.mp4'))
        end_time = time.time()

        print(f'生成时间: {end_time - start_time} 秒')

        with open(os.path.join(cache_dir, f'{file_id}.json'), 'w') as f:
            json.dump(data, f)

        splat_path = os.path.join(cache_dir, f'{file_id}.ply')

        export_ply_for_gaussians(splat_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)

        if not os.path.exists(splat_path):
            return {'error': f'{splat_path} not found'}

        file_size = os.path.getsize(splat_path)

        response_data = {
            'success': True,
            'file_id': file_id,
            'file_path': splat_path,
            'file_size': file_size,
            'download_url': f'/download/{file_id}',
            'generation_time': end_time - start_time,
        }

        return response_data

    except Exception as e:
        # Top-level boundary: report the failure to the client as JSON.
        return {'error': f'Processing error: {str(e)}'}


# Periodically remove files whose modification time is older than 15 minutes.
async def periodic_cache_cleanup(stop_event: asyncio.Event, directory: str, max_age_seconds: int = 15 * 60, interval_seconds: int = 300):
    """Delete stale files from *directory* until *stop_event* is set."""
    while not stop_event.is_set():
        try:
            now = time.time()
            for name in os.listdir(directory):
                path = os.path.join(directory, name)
                try:
                    if os.path.isfile(path) and (now - os.path.getmtime(path)) > max_age_seconds:
                        os.remove(path)
                except OSError as e:
                    # Best effort: a file may vanish or be locked mid-scan.
                    logger.warning("cache cleanup skipped %s: %s", path, e)
        except OSError as e:
            logger.warning("cache cleanup could not scan %s: %s", directory, e)

        # Sleep for the interval, waking early when asked to stop.
        try:
            await asyncio.wait_for(stop_event.wait(), timeout=interval_seconds)
        except asyncio.TimeoutError:
            continue


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int, default=7860)
    parser.add_argument("--ckpt", default=None)
    parser.add_argument("--cache_dir", type=str, default=None)
    # type=bool would turn any non-empty string (even "False") into True.
    parser.add_argument("--offload_t5", type=_str2bool, nargs="?", const=True, default=False)
    # NOTE(review): parsed but not wired into the Gradio queue yet.
    parser.add_argument("--max_concurrent", type=int, default=1, help="Maximum concurrent generation tasks")
    args, _ = parser.parse_known_args()

    # Ensure model.ckpt exists, download if not present
    if args.ckpt is None:
        from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
        ckpt_revision = "6a8e88c6f88678ac098e4c82675f0aee555d6e5d"
        ckpt_path = os.path.join(HUGGINGFACE_HUB_CACHE, "models--imlixinyang--FlashWorld", "snapshots", ckpt_revision, "model.ckpt")
        if not os.path.exists(ckpt_path):
            # Use the path the hub actually wrote to and pin the revision so
            # it matches the snapshot expected above.
            ckpt_path = hf_hub_download(repo_id="imlixinyang/FlashWorld", filename="model.ckpt", revision=ckpt_revision)
    else:
        ckpt_path = args.ckpt

    if args.cache_dir is None or args.cache_dir == "":
        GRADIO_TEMP_DIR = tempfile.gettempdir()
        cache_dir = os.path.join(GRADIO_TEMP_DIR, "flashworld_gradio")
    else:
        cache_dir = args.cache_dir

    # Create cache directory
    os.makedirs(cache_dir, exist_ok=True)

    # Initialize GenerationSystem on the CPU; generate() moves it as needed.
    device = torch.device("cpu")
    generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device, offload_t5=args.offload_t5)

    # Create Gradio interface
    with gr.Blocks(title="FlashWorld Backend") as demo:
        gr.Markdown("# FlashWorld Generation Backend")
        gr.Markdown("This backend processes JSON requests for 3D scene generation.")

        with gr.Row():
            with gr.Column():
                json_input = gr.Textbox(
                    label="JSON Input",
                    placeholder="Enter JSON request here...",
                    lines=10,
                    value='{"image_prompt": null, "text_prompt": "A beautiful landscape", "cameras": [...], "resolution": [16, 480, 704], "image_index": 0}'
                )

                generate_btn = gr.Button("Generate", variant="primary")

            with gr.Column():
                json_output = gr.Textbox(
                    label="JSON Output",
                    lines=10,
                    interactive=False
                )

        # File download section
        gr.Markdown("## File Download")
        with gr.Row():
            file_id_input = gr.Textbox(
                label="File ID",
                placeholder="Enter file ID to download..."
            )
            download_btn = gr.Button("Download PLY File")
            download_output = gr.File(label="Downloaded File")

        def gradio_generate(json_input):
            """
            Gradio interface function that processes JSON input and returns JSON output
            """
            try:
                # Parse JSON input
                if isinstance(json_input, str):
                    data = json.loads(json_input)
                else:
                    data = json_input

                # Process the request
                result = process_generation_request(data, generation_system, cache_dir)

                # Return JSON response
                return json.dumps(result, indent=2)

            except Exception as e:
                error_response = {'error': f'JSON processing error: {str(e)}'}
                return json.dumps(error_response, indent=2)

        def download_file(file_id):
            """
            Download generated PLY file
            """
            # Reject ids that could escape cache_dir (e.g. "../x").
            if not _is_safe_file_id(file_id):
                return None

            file_path = os.path.join(cache_dir, f'{file_id}.ply')

            if not os.path.exists(file_path):
                return None

            return file_path

        # Event handlers
        generate_btn.click(
            fn=gradio_generate,
            inputs=[json_input],
            outputs=[json_output]
        )

        download_btn.click(
            fn=download_file,
            inputs=[file_id_input],
            outputs=[download_output]
        )

        # Example JSON format
        gr.Markdown("""
        ## Example JSON Input Format:
        ```json
        {
          "image_prompt": null,
          "text_prompt": "A beautiful landscape with mountains and trees",
          "cameras": [
            {
              "quaternion": [0, 0, 0, 1],
              "position": [0, 0, 5],
              "fx": 500,
              "fy": 500,
              "cx": 240,
              "cy": 240
            },
            {
              "quaternion": [0, 0, 0, 1],
              "position": [0, 0, 5],
              "fx": 500,
              "fy": 500,
              "cx": 240,
              "cy": 240
            }
          ],
          "resolution": [16, 480, 704],
          "image_index": 0
        }
        ```
        """)

    @asynccontextmanager
    async def lifespan_ctx(app):
        """Run the cache cleanup task for the lifetime of the application."""
        app.state._cleanup_stop_event = asyncio.Event()
        app.state._cleanup_task = asyncio.create_task(periodic_cache_cleanup(app.state._cleanup_stop_event, cache_dir))
        try:
            yield
        finally:
            if getattr(app.state, "_cleanup_stop_event", None):
                app.state._cleanup_stop_event.set()
            if getattr(app.state, "_cleanup_task", None):
                try:
                    await app.state._cleanup_task
                except Exception as e:
                    # Shutdown must not fail because the cleanup task did.
                    logger.warning("cache cleanup task ended with error: %s", e)

    app = FastAPI(lifespan=lifespan_ctx)

    @app.get("/app")
    async def read_index():
        return FileResponse('index.html')

    # Mount the cache directory as static files, e.g. under /cache/
    app.mount("/cache", StaticFiles(directory=cache_dir), name="cache")

    # Delete the generated files (and related intermediates) for a file_id.
    @app.post("/delete/{file_id}")
    async def delete_generated_file(file_id: str):
        # Validate outside the try block so the 400 is not rewritten as 500.
        if not _is_safe_file_id(file_id):
            raise HTTPException(status_code=400, detail="invalid file_id")
        try:
            deleted = False
            # Files that may belong to this id: .ply, .json, .mp4
            for ext in (".ply", ".json", ".mp4"):
                p = os.path.join(cache_dir, f"{file_id}{ext}")
                if os.path.exists(p):
                    try:
                        os.remove(p)
                        deleted = True
                    except OSError as e:
                        logger.warning("could not delete %s: %s", p, e)
            return {"success": True, "deleted": deleted}
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    # Mount Gradio last: a mount at "/" matches every path, so any route
    # registered after it would never be reached.
    app = gr.mount_gradio_app(app, demo, path="/")

    uvicorn.run(app, host="0.0.0.0", port=args.port)