import argparse
import random
import os
from datetime import date
from shutil import copyfile

import cv2 as cv
import numpy as np
import torch
import torch as th
import torch.backends.cudnn
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from spaces import GPU
from huggingface_hub import hf_hub_download

import admin.settings as ws_settings
from train_settings.dvd.improved_diffusion import logger
from train_settings.dvd.improved_diffusion.script_util import args_to_dict, create_model_and_diffusion, model_and_diffusion_defaults
from train_settings.models.geotr.geotr_core import GeoTr_Seg_Inf, reload_segmodel, Seg
from train_settings.models.geotr.unet_model import UNet
from train_settings.dvd.feature_backbones.VGG_features import VGGPyramid
from train_settings.dvd.eval_utils import extract_raw_features_single, extract_raw_features_single2
from datasets.utils.warping import register_model2

token = os.getenv("HF_TOKEN", None)

EXAMPLES = [
    ["https://huggingface.co/hanquansanren/DvD/resolve/main/examples/25_1.png"],
    ["https://huggingface.co/hanquansanren/DvD/resolve/main/examples/3_2.png"],
]

# Pre-fetch the example images into the local Hugging Face cache
# (the UI itself consumes the EXAMPLES URLs above).
example_img_list = []
for name in ['3_2 copy.png', '25_1 copy.png']:
    local_path = [hf_hub_download(
        repo_id="hanquansanren/DvD",
        filename=f"examples/{name}",
        repo_type="model",
    )]
    example_img_list.append(local_path)

reg_model_bilin = register_model2((512, 512), 'bilinear')


def coords_grid_tensor(perturbed_img_shape):
    """Build an identity coordinate grid of shape [1, 2, H, W] on the GPU."""
    im_x, im_y = np.mgrid[0:perturbed_img_shape[0] - 1:complex(perturbed_img_shape[0]),
                          0:perturbed_img_shape[1] - 1:complex(perturbed_img_shape[1])]
    coords = np.stack((im_y, im_x), axis=2)  # x first, then y; row-major order
    coords = th.from_numpy(coords).float().permute(2, 0, 1).to('cuda')  # [2, H, W]
    return coords.unsqueeze(0)  # [1, 2, H, W]


def run_sample_lr_dewarping(settings, logger, diffusion, model, radius, source, feature_size,
                            raw_corr, init_flow, c20, source_64, pyramid, doc_mask,
                            seg_map_all=None, textline_map=None, init_feat=None):
    model_kwsettings = {
        'init_flow': init_flow,  # [1, 2, 64, 64]
        'src_feat': c20,         # [1, 64, 64, 64]
        'src_64': None,
        'y512': source,
        'tmode': settings.env.train_mode,
        'mask_cat': doc_mask,
        'init_feat': init_feat,
        'iter': settings.env.iter,
    }
    if not settings.env.use_gt_mask:
        model_kwsettings['mask_y512'] = seg_map_all  # [b, 384, 64, 64]
    if settings.env.use_line_mask:
        model_kwsettings['line_msk'] = textline_map

    image_size_h, image_size_w = feature_size, feature_size
    logger.info("\nStarting sampling")
    sample, _ = diffusion.ddim_sample_loop(
        model,
        (1, 2, image_size_h, image_size_w),  # [1, 2, 64, 64]
        noise=None,
        clip_denoised=settings.env.clip_denoised,  # False
        model_kwargs=model_kwsettings,
        eta=0.0,
        progress=True,
        denoised_fn=None,
        sampling_kwargs={'src_img': source},
        logger=logger,
        n_batch=settings.env.n_batch,
        time_variant=settings.env.time_variant,
        pyramid=pyramid,
    )
    sample = th.clamp(sample, min=-1, max=1)
    return sample
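
# Illustrative sketch (an assumption, not part of the pipeline): `reg_model_bilin`
# above, built by `register_model2`, performs a spatial-transformer-style
# bilinear warp. For a sampling map already normalized to [-1, 1] in (x, y)
# order, the same warp can be expressed directly with `F.grid_sample`:
def grid_sample_warp_sketch(img, coords_norm):
    # img:         [B, 3, H, W] float tensor
    # coords_norm: [B, 2, H, W] absolute sampling coordinates in [-1, 1]
    grid = coords_norm.permute(0, 2, 3, 1)  # -> [B, H, W, 2], (x, y) last
    return F.grid_sample(img, grid, mode='bilinear', align_corners=True)
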
def visualize_dewarping_single(settings, sample, source_vis):
    os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred', exist_ok=True)
    # Warp the original-resolution source image with the predicted sampling grid.
    warped_src = reg_model_bilin([source_vis.to(sample.device).float(), sample])  # [1, 3, H_ori, W_ori]
    warped_src = warped_src[0].permute(1, 2, 0).detach().cpu().numpy()  # (H_ori, W_ori, 3)
    warped_src = Image.fromarray(warped_src.astype(np.uint8))
    return warped_src


def prepare_data_single(input_image, input_image_ori):
    source_vis = input_image_ori  # [1, 3, H_ori, W_ori], kept for visualization only
    target_vis = None
    _, _, H_ori, W_ori = source_vis.shape
    source = input_image.to('cuda')  # [1, 3, 512, 512], torch.float32
    source_0 = None
    target = None
    batch_ori = None
    batch_ori_inter = None
    target_256 = None
    source_256 = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area').to('cuda')
    mask = torch.ones((1, 512, 512), dtype=torch.bool).to('cuda')  # all-ones placeholder mask
    return (input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter,
            source_256, target_256, source_vis, target_vis, mask, source_0)


@GPU
def run_single_docunet(input_image_ori):
    input_image_ori = np.array(input_image_ori, dtype=np.uint8)   # [H, W, 3]
    input_image_resized = cv.resize(input_image_ori, (512, 512))  # [512, 512, 3]
    input_image_ori = np.transpose(input_image_ori, (2, 0, 1))    # [3, H, W]
    input_image = np.transpose(input_image_resized, (2, 0, 1))    # [3, 512, 512]
    input_image = input_image / 255
    input_image_ori = torch.tensor(input_image_ori).unsqueeze(0)  # [1, 3, H, W]
    input_image = torch.tensor(input_image).unsqueeze(0).float()  # [1, 3, 512, 512]

    os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
    pyramid = VGGPyramid(train=False).to('cuda')
    radius = 4
    raw_corr = None
    source_288 = F.interpolate(input_image, size=(288, 288), mode='bilinear', align_corners=True).to('cuda')
    if settings.env.time_variant:
        init_feat = torch.zeros((input_image.shape[0], 256, 64, 64), dtype=torch.float32).to('cuda')
    else:
        init_feat = None

    with torch.inference_mode():
        ref_bm, mask_x = pretrained_dewarp_model(source_288)  # bm: [1, 2, 288, 288] in 0~288; mask in 0~1
        ref_flow = ref_bm / 287.0  # normalized backward map, [1, 2, 288, 288]
    if settings.env.use_init_flow:
        init_flow = F.interpolate(ref_flow, size=(64, 64), mode='bilinear', align_corners=True)  # [1, 2, 64, 64]
    else:
        init_flow = torch.zeros((input_image.shape[0], 2, 64, 64), dtype=torch.float32).to('cuda')

    (
        data,
        H_ori,            # original image height
        W_ori,            # original image width
        source,           # [1, 3, 512, 512], 0~1
        target,           # None
        batch_ori,        # None
        batch_ori_inter,  # None
        source_256,       # [1, 3, 256, 256], 0~1
        target_256,       # None
        source_vis,       # [1, 3, H, W], CPU, used only for visualization
        target_vis,       # None
        mask,             # [1, 512, 512], all ones
        source_0,         # None
    ) = prepare_data_single(input_image, input_image_ori)

    with torch.no_grad():
        if not settings.env.use_gt_mask:
            mskx, d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = pretrained_seg_model(source_288)
            hx6 = F.interpolate(hx6, size=64, mode='bilinear', align_corners=False)
            hx5d = F.interpolate(hx5d, size=64, mode='bilinear', align_corners=False)
            hx4d = F.interpolate(hx4d, size=64, mode='bilinear', align_corners=False)
            hx3d = F.interpolate(hx3d, size=64, mode='bilinear', align_corners=False)
            hx2d = F.interpolate(hx2d, size=64, mode='bilinear', align_corners=False)
            hx1d = F.interpolate(hx1d, size=64, mode='bilinear', align_corners=False)
            seg_map_all = torch.cat((hx6, hx5d, hx4d, hx3d, hx2d, hx1d), dim=1)  # [b, 384, 64, 64]
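            # Note (assumption about the Seg backbone): hx6..hx1d appear to be
            # multi-scale decoder side outputs in a U^2-Net-style network;
            # resizing each to 64x64 and concatenating them gives a 384-channel
            # mask-conditioning volume aligned with the 64x64 diffusion latent.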
            textline_map = None  # default when the text-line mask is disabled
            if settings.env.use_line_mask:
                textline_map, textline_mask = pretrained_line_seg_model(mskx)  # [b, 64, 256, 256]
                textline_map = F.interpolate(textline_map, size=64, mode='bilinear', align_corners=False)  # [b, 64, 64, 64]
        else:
            seg_map_all = None
            textline_map = None

    if settings.env.train_VGG:
        c20 = None
        feature_size = 64
    else:
        feature_size = 64
        if settings.env.train_mode == 'stage_1_dit_cat' or settings.env.train_mode == 'stage_1_dit_cross':
            with th.no_grad():
                # Averaged cross-correlation over downsampled shallow VGG features (512x512 -> 64x64).
                c20 = extract_raw_features_single2(pyramid, source, source_256, feature_size)
        else:
            with th.no_grad():
                # Averaged cross-correlation over downsampled shallow VGG features (512x512 -> 64x64).
                c20 = extract_raw_features_single(pyramid, source, source_256, feature_size)
    source_64 = None  # F.interpolate(source, size=(feature_size), mode='bilinear', align_corners=True)

    logger.info("Starting sampling with VGG features")
    sample = run_sample_lr_dewarping(
        settings,
        logger,
        diffusion,
        model,
        radius,        # 4
        source,        # [B, 3, 512, 512], 0~1
        feature_size,  # 64
        raw_corr,      # None
        init_flow,     # [B, 2, 64, 64], -1~1
        c20,           # [B, 64, 64, 64]
        source_64,     # None
        pyramid,
        mask_x,        # document mask from the pretrained dewarping model
        seg_map_all,
        textline_map,
        init_feat,
    )
    # sample: [1, 2, 64, 64] offset field in [-1, 1], the result of five DDIM steps.

    if not settings.env.use_sr_net:
        sample = F.interpolate(sample, size=(H_ori, W_ori), mode='bilinear', align_corners=True)  # offset field in [-1, 1]
        # Identity grid in [0, 1], resized to the original resolution.
        base = F.interpolate(coords_grid_tensor((512, 512)) / 511., size=(H_ori, W_ori), mode='bilinear', align_corners=True)
        sample = ((sample + base.to(sample.device)) * 2 - 1) * 0.987  # cf. (2 * (bm / 286.8) - 1) * 0.99
        ref_flow = None
    else:
        raise ValueError("use_sr_net=True is not supported in this demo")

    output = visualize_dewarping_single(settings, sample, source_vis)
    return output


parser = argparse.ArgumentParser(description='Run a sampling script from train_settings.')
parser.add_argument('--train_module', type=str, default='dvd', help='Name of module in the "train_settings/" folder.')
parser.add_argument('--train_name', type=str, default='val_TDiff', help='Name of the train settings file.')
parser.add_argument('--cudnn_benchmark', type=lambda x: bool(int(x)), default=True,
                    help='Set cudnn benchmark on (1) or off (0) (default is on).')
parser.add_argument('--seed', type=int, default=1992, help='Pseudo-RNG seed')
parser.add_argument('--name', type=str, default='gradio', help='Name of the experiment')
parser.add_argument('--corruption', action='store_true')  # defaults to False; True when the flag is passed
args = parser.parse_args()

args.seed = torch.initial_seed() & (2 ** 32 - 1)  # derive a fresh 32-bit seed from torch
print('Seed is {}'.format(args.seed))
random.seed(int(args.seed))
np.random.seed(args.seed)

cudnn_benchmark = args.cudnn_benchmark
seed = args.seed
corruption = args.corruption
name = args.name
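
# Note: only `random` and `numpy` are seeded above. For fully deterministic
# torch behavior one would additionally seed torch itself, e.g. (sketch):
#   torch.manual_seed(args.seed)
#   torch.cuda.manual_seed_all(args.seed)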
# This is needed to avoid strange crashes related to opencv
cv.setNumThreads(0)

torch.backends.cudnn.benchmark = cudnn_benchmark

today = date.today()
d1 = today.strftime("%d/%m/%Y")  # dd/mm/YY
print('Sampling: {} {}\nDate: {}'.format(args.train_module, args.train_name, d1))

settings = ws_settings.Settings()
settings.module_name = args.train_module
settings.script_name = args.train_name
settings.project_path = 'train_settings/{}/{}'.format(args.train_module, args.train_name)  # e.g. 'train_settings/dvd/val_TDiff'
settings.seed = seed
settings.name = name

save_dir = os.path.join(settings.env.workspace_dir, settings.project_path)
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
copyfile(settings.project_path + '.py', os.path.join(save_dir, settings.script_name + '.py'))

settings.severity = 0
settings.corruption_number = 0

logger.configure(dir=f"SAMPLING_{settings.env.eval_dataset}_{settings.name}")
logger.log(f"Corruption disabled. Evaluating on original {settings.env.eval_dataset}")

logger.log("Loading model and diffusion...")
model, diffusion = create_model_and_diffusion(
    device='cuda',
    train_mode=settings.env.train_mode,  # stage 1
    tv=settings.env.time_variant,
    **args_to_dict(settings, model_and_diffusion_defaults().keys()),
)
setattr(diffusion, "settings", settings)

pretrained_dewarp_model = GeoTr_Seg_Inf()
settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth", token=token)
reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
# reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
pretrained_dewarp_model.to('cuda')
pretrained_dewarp_model.eval()

if settings.env.use_line_mask:
    pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
    pretrained_seg_model = Seg()
    settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth", token=token)
    line_model_ckpt = torch.load(settings.env.line_seg_model_path, map_location='cpu')['model']
    pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
    pretrained_line_seg_model.to('cuda')
    pretrained_line_seg_model.eval()

    settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth", token=token)
    seg_model_ckpt = torch.load(settings.env.new_seg_model_path, map_location='cpu')['model']
    pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
    pretrained_seg_model.to('cuda')
    pretrained_seg_model.eval()

settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt", token=token)
model_ckpt = torch.load(settings.env.model_path, map_location='cpu')
model.cpu().load_state_dict(model_ckpt, strict=False)
logger.log(f"Model loaded from {settings.env.model_path}")
model.to('cuda')
model.eval()

if __name__ == '__main__':
    # Simpler gr.Interface alternative, kept for reference:
    # demo = gr.Interface(
    #     fn=run_single_docunet,
    #     inputs=[gr.Image(type="pil", label="Input Image")],
    #     outputs=[gr.Image(type="numpy", label="Output Image")],
    #     title="Document Image Dewarping",
    #     description="This is a demo for the SIGGRAPH Asia 2025 paper 'DvD: Unleashing a Generative "
    #                 "Paradigm for Document Dewarping via Coordinates-based Diffusion Model'",
    #     examples=EXAMPLES,
    # )
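    # Note: `run_single_docunet` is wrapped with the `@GPU` decorator from the
    # `spaces` package, so on Hugging Face ZeroGPU hardware a GPU is attached
    # only for the duration of each call.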
    with gr.Blocks() as demo:
        gr.Markdown("## Document Image Dewarping Demo")
        with gr.Row():
            input_image = gr.Image(type="pil", label="Input Image")
            output_image = gr.Image(type="numpy", label="Output Image")
        # Load the examples into the input box.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[input_image],
            label="Click an example to load into Input Image",
        )
        # Button that runs the dewarping function.
        run_btn = gr.Button("Run")
        run_btn.click(fn=run_single_docunet, inputs=[input_image], outputs=[output_image])

    demo.launch(ssr_mode=False)
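
# Example invocation (hypothetical entry-point name; this script is assumed to
# be the Space's app file, e.g. app.py):
#   python app.py --train_module dvd --train_name val_TDiff --name gradio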