Spaces:
Sleeping
Sleeping
| """ | |
| app.py | |
| An interactive demo of text-guided shape generation. | |
| """ | |
| from pathlib import Path | |
| from typing import Literal | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from salad.utils.spaghetti_util import ( | |
| get_mesh_from_spaghetti, | |
| generate_zc_from_sj_gaus, | |
| load_mesher, | |
| load_spaghetti, | |
| ) | |
| import hydra | |
| from omegaconf import OmegaConf | |
| import torch | |
| from pytorch_lightning import seed_everything | |
def load_model(
    model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
    device,
):
    """Instantiate a model from its checkpoint folder and freeze it.

    The checkpoint folder is ``checkpoints/<model_class>`` next to this
    file. ``hparams.yaml`` in that folder is instantiated through Hydra and
    ``state_only.ckpt`` is loaded into the result as its state dict.

    Args:
        model_class: Name of the checkpoint sub-directory to load.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model in eval mode, with ``requires_grad`` disabled on every
        parameter, moved to ``device``.
    """
    ckpt_root = Path(__file__).parent / "checkpoints"
    hparams_path = ckpt_root / f"{model_class}/hparams.yaml"
    weights_path = ckpt_root / f"{model_class}/state_only.ckpt"

    # build the model from its saved hyper-parameters
    net = hydra.utils.instantiate(OmegaConf.load(hparams_path))

    # restore the trained weights
    state = torch.load(weights_path, map_location=device)
    net.load_state_dict(state)

    # inference only: eval mode, no gradients
    net.eval()
    for param in net.parameters():
        param.requires_grad_(False)

    return net.to(device)
def _build_mesh_figure(vertices, faces) -> "go.Figure":
    """Wrap a triangle mesh in a Plotly figure with all three axes hidden."""
    mesh = go.Mesh3d(
        # The mesh axes are remapped for display: the plot's y axis takes the
        # negated third vertex column and its z axis the second column.
        x=vertices[:, 0],
        y=-vertices[:, 2],  # flip front-back (presumably what the negation is for)
        z=vertices[:, 1],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color="gray",
    )
    return go.Figure(
        data=[mesh],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )


def run_inference(prompt: str) -> "go.Figure":
    """The entry point of the demo.

    Samples a shape for ``prompt`` in two phases (extrinsics from the
    language-conditioned phase-1 model, then intrinsics from the phase-2
    model), decodes it to a mesh with SPAGHETTI, and plots the mesh.

    Args:
        prompt: Text description of the shape to generate.

    Returns:
        A Plotly figure containing the generated mesh.
    """
    # Device to run the demo on. Fall back to CPU so that hosts without
    # CUDA do not fail as soon as the first model is moved to the device.
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    # Random seed for reproducibility.
    seed: int = 63

    # set random seed
    seed_everything(seed)

    # NOTE(review): every model below is re-read from disk on each request.
    # Caching them would be faster, but model construction may consume
    # random numbers, so caching could change the RNG state at sampling
    # time and therefore the generated shapes -- confirm before changing.

    # load SPAGHETTI and mesher
    spaghetti = load_spaghetti(device)
    mesher = load_mesher(device)

    # load SALAD (phase 1 is the language-conditioned checkpoint, phase 2
    # is the plain "phase2" checkpoint)
    lang_phase1_model = load_model("lang_phase1", device)
    phase2_model = load_model("phase2", device)
    lang_phase1_model._build_dataset("val")

    # run phase 1
    extrinsics = lang_phase1_model.sampling_gaussians([prompt])

    # run phase 2
    intrinsics = phase2_model.sample(extrinsics)

    # generate mesh (only the first sample is meshed; a single prompt is given)
    zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
    vertices, faces = get_mesh_from_spaghetti(
        spaghetti,
        mesher,
        zcs[0],
        res=256,
    )

    # plot
    return _build_mesh_figure(vertices, faces)
def _build_demo():
    """Assemble the Gradio UI: prompt box, buttons, examples and mesh plot."""
    title = "SALAD: Text-Guided Shape Generation"
    description_text = '''
This demo features text-guided 3D shape generation from our work <a href="https://arxiv.org/abs/2303.12236">SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation, ICCV 2023</a>.
Please refer to our <a href="https://salad3d.github.io/">project page</a> for details.
'''
    example_prompts = [
        "an office chair",
        "a chair with armrests",
        "a chair without armrests",
    ]

    with gr.Blocks(title=title) as demo:
        # short description shown at the top of the page
        gr.Markdown(description_text)

        # text input
        with gr.Row():
            prompt_textbox = gr.Textbox(placeholder="Describe a chair.")

        # action buttons; "Clear" resets the prompt box
        with gr.Row():
            run_button = gr.Button(value="Generate")
            gr.ClearButton(value="Clear", components=[prompt_textbox])

        # example prompts that fill the prompt box
        gr.Examples(examples=example_prompts, inputs=[prompt_textbox])

        # output viewport for the generated mesh
        mesh_viewport = gr.Plot()

        # "Generate" runs inference on the prompt and plots the result
        run_button.click(
            run_inference,
            inputs=[prompt_textbox],
            outputs=[mesh_viewport],
        )

    return demo


if __name__ == "__main__":
    demo = _build_demo()
    demo.queue(max_size=30)
    demo.launch()