import sys
import os

# Fetch the official pix2pix3D implementation so its modules can be imported.
if not os.path.isdir("pix2pix3D"):
    os.system("git clone https://github.com/dunbar12138/pix2pix3D.git")
sys.path.append("pix2pix3D")

from typing import List, Optional, Tuple, Union
import copy
import io
import pickle
from pathlib import Path

import numpy as np
import PIL.Image
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
import plotly.graph_objects as go
import imageio
import trimesh
import mcubes
import gradio as gr
from huggingface_hub import hf_hub_download

import dnnlib
import legacy
from legacy import *
from camera_utils import LookAtPoseSampler
from training.utils import color_mask, color_list
from torch_utils import misc

os.environ["PYOPENGL_PLATFORM"] = "egl"
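
# Sample the generator's density (sigma) field on a regular 3D grid, block by
# block, so the full volume can be queried without exhausting GPU memory.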
def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
    # Returns a (resolution^3) numpy array of sigma values and the half-extent
    # of the sampling cube.
    bound = nerf.rendering_kwargs['box_warp'] * 0.5
    X = torch.linspace(-bound, bound, resolution).split(block_resolution)

    sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)

    for xi, xs in enumerate(X):
        for yi, ys in enumerate(X):
            for zi, zs in enumerate(X):
                xx, yy, zz = torch.meshgrid(xs, ys, zs)
                pts = torch.stack([xx, yy, zz], dim=-1).unsqueeze(0).to(styles.device)  # B, H, H, H, C
                block_shape = [1, len(xs), len(ys), len(zs)]
                out = nerf.sample_mixed(pts.reshape(1, -1, 3), None, ws=styles, noise_mode='const')
                feat_out, sigma_out = out['rgb'], out['sigma']
                sigma_np[xi * block_resolution: xi * block_resolution + len(xs),
                         yi * block_resolution: yi * block_resolution + len(ys),
                         zi * block_resolution: zi * block_resolution + len(zs)] = \
                    sigma_out.reshape(block_shape[1:]).detach().cpu().numpy()

    return sigma_np, bound
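
# Run marching cubes on the sampled sigma field and rescale the vertices from
# grid indices back to world coordinates inside [-bound, bound]^3.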
def extract_geometry(nerf, styles, resolution, threshold):
    u, bound = get_sigma_field_np(nerf, styles, resolution)
    vertices, faces = mcubes.marching_cubes(u, threshold)

    b_min_np = np.array([-bound, -bound, -bound])
    b_max_np = np.array([ bound,  bound,  bound])

    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    return vertices.astype('float32'), faces
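
# Render an orbiting video: sweep the camera along a small sinusoidal pitch/yaw
# trajectory around the object and collect both the RGB frames and the
# color-coded semantic label frames.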
def render_video(G, ws, intrinsics, num_frames=120, pitch_range=0.25, yaw_range=0.35, neural_rendering_resolution=128, device='cuda'):
    frames, frames_label = [], []

    for frame_idx in tqdm(range(num_frames)):
        cam2world_pose = LookAtPoseSampler.sample(
            3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
            3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
            torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
            radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

        with torch.no_grad():
            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)

        frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
        frames_label.append(color_mask(torch.argmax(out['semantic'], dim=1).cpu().numpy()[0]).astype(np.uint8))

    return frames, frames_label
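
# Build an interactive Plotly figure from a trimesh mesh, shading it with the
# mesh's per-vertex colors.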
def return_plot_go(mesh_trimesh):
    vertices = np.asarray(mesh_trimesh.vertices).T
    faces = np.asarray(mesh_trimesh.faces).T
    fig = go.Figure(go.Mesh3d(
        x=vertices[0], y=vertices[1], z=vertices[2],
        i=faces[0], j=faces[1], k=faces[2],
        vertexcolor=np.asarray(mesh_trimesh.visual.vertex_colors),
        lighting=dict(ambient=0.5,
                      diffuse=1,
                      fresnel=4,
                      specular=0.5,
                      roughness=0.05,
                      facenormalsepsilon=0,
                      vertexnormalsepsilon=0),
        lightposition=dict(x=100, y=100, z=1000)))
    return fig
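
# Download the pretrained seg2cat checkpoint from the Hugging Face Hub.
# Further models (e.g. seg2face, edge2car) could be added to this dict the
# same way.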
network_cat = hf_hub_download("SerdarHelli/pix2pix3d_seg2cat", filename="pix2pix3d_seg2cat.pkl", revision="main")

models = {"seg2cat": network_cat}

device = 'cuda' if torch.cuda.is_available() else 'cpu'
outdir = "./"
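
# Unpickler that remaps CUDA tensor storages to CPU, so checkpoints pickled on
# GPU machines can be loaded on CPU-only hardware.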
class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)
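
# CPU variant of legacy.load_network_pkl: identical logic, but deserializes
# through CPU_Unpickler so no CUDA context is required.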
def load_network_pkl_cpu(f, force_fp16=False):
    data = CPU_Unpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
            fp16_kwargs.num_fp16_res = 4
            fp16_kwargs.conv_clamp = 256
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new

    return data
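
# Fixed 19-color segmentation palette (one RGB triple per label index). This
# overrides the color_list imported from training.utils above.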
color_list = [[255, 255, 255], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
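
# Convert an RGB color-coded label map into integer label indices: a pixel gets
# label idx when all three of its channels match color_list[idx].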
def colormap2labelmap(color_img):
    im_base = np.zeros((color_img.shape[0], color_img.shape[1]))
    for idx, color in enumerate(color_list):
        mask = np.all(color_img == np.asarray(color), axis=-1)
        im_base[mask] = idx
    return im_base
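
# Remap whatever label values appear in the image onto a contiguous 0..N-1
# range, in place.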
def checklabelmap(img):
    labels = np.unique(img)
    for idx, label in enumerate(labels):
        img[img == label] = idx
    return img
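
# Main pipeline behind the Gradio demo: load the requested generator, encode
# the input label map, synthesize an image and label map from a canonical pose,
# extract and color a mesh, and render orbit videos.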
def get_all(cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames):
    network = models[cfg]
    if device == "cpu":
        with dnnlib.util.open_url(network) as f:
            G = load_network_pkl_cpu(f)['G_ema'].eval().to(device)
    else:
        with dnnlib.util.open_url(network) as f:
            G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)

    if cfg in ('seg2cat', 'seg2face'):
        neural_rendering_resolution = 128
        data_type = 'seg'
        # Initialize pose sampler.
        forward_cam2world_pose = LookAtPoseSampler.sample(
            3.14 / 2, 3.14 / 2,
            torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
            radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647  # normalized focal length used by the seg models
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    elif cfg == 'edge2car':
        neural_rendering_resolution = 64
        data_type = 'edge'
    else:
        raise ValueError(f'Invalid cfg: {cfg}')

    save_dir = Path(outdir)

    # Load the input label map and convert its RGB color coding to label indices.
    if isinstance(input, str):
        input_label = np.asarray(PIL.Image.open(input))
    else:
        input_label = np.asarray(input)
    input_label = colormap2labelmap(input_label)
    input_label = checklabelmap(input_label)
    input_label = np.asarray(input_label).astype(np.uint8)
    input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)
    input_pose = forward_pose.to(device)

    # Sample a seeded latent, map it to W conditioned on the mask and pose, and
    # synthesize the color image and label map from the canonical pose.
    z = torch.from_numpy(np.random.RandomState(int(random_seed)).randn(1, G.z_dim).astype('float32')).to(device)
    with torch.no_grad():
        ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
        out = G.synthesis(ws, input_pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)

    image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
    image_seg = color_mask(torch.argmax(out['semantic'][0], dim=0).cpu().numpy()).astype(np.uint8)

    # Extract a mesh from the generator's density field with marching cubes.
    mesh_trimesh = trimesh.Trimesh(*extract_geometry(G, ws, resolution=mesh_resolution, threshold=50.))
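    # Color the mesh: query the generator at every vertex and use the argmax of
    # the predicted semantic channels to pick a palette color per vertex.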
    verts_np = np.array(mesh_trimesh.vertices)
    colors = torch.zeros((verts_np.shape[0], 3), device=device)
    semantic_colors = torch.zeros((verts_np.shape[0], 6), device=device)
    samples_color = torch.tensor(verts_np, device=device).unsqueeze(0).float()

    head = 0
    max_batch = 10000000
    with tqdm(total=verts_np.shape[0]) as pbar:
        with torch.no_grad():
            while head < verts_np.shape[0]:
                torch.manual_seed(0)
                out = G.sample_mixed(samples_color[:, head:head + max_batch], None, ws, truncation_psi=truncation_psi, noise_mode='const')
                colors[head:head + max_batch, :] = out['rgb'][0, :, :3]
                # Channels 32:38 of the sampled features are used as the 6 semantic channels.
                seg = out['rgb'][0, :, 32:32 + 6]
                semantic_colors[head:head + max_batch, :] = seg
                head += max_batch
                pbar.update(max_batch)

    semantic_colors = torch.tensor(color_list, device=device)[torch.argmax(semantic_colors, dim=-1)]
    mesh_trimesh.visual.vertex_colors = semantic_colors.cpu().numpy().astype(np.uint8)

    # Render the orbit videos and save them.
    frames, frames_label = render_video(G, ws, intrinsics, num_frames=num_frames, pitch_range=0.25, yaw_range=0.35, neural_rendering_resolution=neural_rendering_resolution, device=device)
    video = os.path.join(save_dir, f'{cfg}_color.mp4')
    video_label = os.path.join(save_dir, f'{cfg}_label.mp4')
    imageio.mimsave(video, frames, fps=fps)
    imageio.mimsave(video_label, frames_label, fps=fps)

    fig_mesh = return_plot_go(mesh_trimesh)
    return fig_mesh, image_color, image_seg, video, video_label
title = "3D-aware Conditional Image Synthesis"
desc = f'''
[Arxiv: "3D-aware Conditional Image Synthesis".](https://arxiv.org/abs/2302.08509)

[Project Page.](https://www.cs.cmu.edu/~pix2pix3D/)

[For the official implementation.](https://github.com/dunbar12138/pix2pix3D)

### Future Work based on interest

- Adding new models for new object types
- New customization options

It is running on {device}.

The process can take a long time, especially when generating videos; the runtime depends on the number of frames, the mesh resolution, and the current device.
'''
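
# Gradio UI definition: the model dropdown, input mask, and sliders map 1:1 to
# the arguments of get_all(...) above.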
demo_inputs = [
    gr.Dropdown(choices=["seg2cat"], label="Choose Model", value="seg2cat"),
    gr.Image(type="filepath", shape=(512, 512), label="Mask"),
    gr.Slider(minimum=0, maximum=2, label='Truncation PSI', value=1),
    gr.Slider(minimum=32, maximum=512, label='Mesh Resolution', value=32),
    gr.Slider(minimum=0, maximum=2**16, label='Seed', value=128),
    gr.Slider(minimum=10, maximum=120, label='FPS', value=30),
    gr.Slider(minimum=10, maximum=120, label='The Number of Frames', value=30),
]

demo_outputs = [
    gr.Plot(label="Generated Mesh"),
    gr.Image(type="pil", shape=(256, 256), label="Generated Image"),
    gr.Image(type="pil", shape=(256, 256), label="Generated Label Map"),
    gr.Video(label="Generated Video"),
    gr.Video(label="Generated Label Video"),
]

examples = [
    ["seg2cat", "img.png", 1, 32, 128, 30, 30],
    ["seg2cat", "img2.png", 1, 32, 128, 30, 30],
    ["seg2cat", "img3.png", 1, 32, 128, 30, 30],
]
demo_app = gr.Interface(
    fn=get_all,
    inputs=demo_inputs,
    outputs=demo_outputs,
    cache_examples=True,
    title=title,
    theme="huggingface",
    description=desc,
    examples=examples,
)
demo_app.launch(debug=True, enable_queue=True)