Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
b54f31e
1
Parent(s):
9847dc4
app.py
CHANGED
|
@@ -43,7 +43,7 @@ reg_model_bilin = register_model2((512,512), 'bilinear')
|
|
| 43 |
def coords_grid_tensor(perturbed_img_shape):
|
| 44 |
im_x, im_y = np.mgrid[0:perturbed_img_shape[0]-1:complex(perturbed_img_shape[0]), 0:perturbed_img_shape[1]-1:complex(perturbed_img_shape[1])]
|
| 45 |
coords = np.stack((im_y,im_x), axis=2) # 先x后y,行序优先
|
| 46 |
-
coords = th.from_numpy(coords).float().permute(2,0,1).to(
|
| 47 |
return coords.unsqueeze(0) # [2, 512, 512]
|
| 48 |
|
| 49 |
def run_sample_lr_dewarping(
|
|
@@ -84,14 +84,7 @@ def run_sample_lr_dewarping(
|
|
| 84 |
sample = th.clamp(sample, min=-1, max=1)
|
| 85 |
return sample
|
| 86 |
|
| 87 |
-
def visualize_dewarping(settings, sample, data, i, source_vis, data_path, ref_flow=None):
|
| 88 |
-
os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred', exist_ok=True) # pred dewarped
|
| 89 |
-
# warped_src = warp(source_vis.to(sample.device).float(), sample) # [1, 3, 1629, 981]
|
| 90 |
-
warped_src = reg_model_bilin([source_vis.to(sample.device).float(), sample])
|
| 91 |
-
warped_src = warped_src[0].permute(1, 2, 0).detach().cpu().numpy()#*255. # (1873, 1353, 3)
|
| 92 |
-
warped_src = Image.fromarray((warped_src).astype(np.uint8))
|
| 93 |
|
| 94 |
-
return warped_src
|
| 95 |
|
| 96 |
def visualize_dewarping_single(settings, sample, source_vis):
|
| 97 |
# os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred', exist_ok=True) # pred dewarped
|
|
@@ -105,77 +98,20 @@ def visualize_dewarping_single(settings, sample, source_vis):
|
|
| 105 |
|
| 106 |
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def prepare_data(settings, batch_preprocessing, SIZE, data):
|
| 111 |
-
if 'source_image_ori' in data:
|
| 112 |
-
source_vis = data['source_image_ori'] # B, C, 512, 512 torch.uint8 cpu
|
| 113 |
-
else:
|
| 114 |
-
source_vis = data['source_image']
|
| 115 |
-
if 'target_image' in data:
|
| 116 |
-
target_vis = data['target_image']
|
| 117 |
-
else:
|
| 118 |
-
target_vis = None
|
| 119 |
-
|
| 120 |
-
_, _, H_ori, W_ori = source_vis.shape
|
| 121 |
-
|
| 122 |
-
source = data['source_image'].to(dist_util.dev()) # [1, 3, 914, 1380] torch.float32
|
| 123 |
-
if 'source_image_0' in data:
|
| 124 |
-
source_0 = data['source_image_0'].to(dist_util.dev())
|
| 125 |
-
else:
|
| 126 |
-
source_0 = None
|
| 127 |
-
if 'target_image' in data:
|
| 128 |
-
target = data['target_image'] # [1, 3, 914, 1380] torch.float32
|
| 129 |
-
else:
|
| 130 |
-
target = None
|
| 131 |
-
if 'flow_map' in data:
|
| 132 |
-
batch_ori = data['flow_map'] # [1, 2, 914, 1380] torch.float32
|
| 133 |
-
else:
|
| 134 |
-
batch_ori = None
|
| 135 |
-
if 'flow_map_inter' in data:
|
| 136 |
-
batch_ori_inter = data['flow_map_inter'] # [1, 2, 914, 1380] torch.float32
|
| 137 |
-
else:
|
| 138 |
-
batch_ori_inter = None
|
| 139 |
-
if target is not None:
|
| 140 |
-
target = F.interpolate(target, size=512, mode='bilinear', align_corners=False) # [1, 3, 512, 512]
|
| 141 |
-
target_256 = data['target_image_256'].to(dist_util.dev()) # [1, 3, 256, 256]
|
| 142 |
-
else:
|
| 143 |
-
target = None
|
| 144 |
-
target_256 = None
|
| 145 |
-
|
| 146 |
-
if settings.env.eval_dataset == 'hp-240':# false
|
| 147 |
-
source_256 = source
|
| 148 |
-
target_256 = target
|
| 149 |
-
|
| 150 |
-
else: # true
|
| 151 |
-
data['source_image_256'] = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area')
|
| 152 |
-
source_256 = data['source_image_256'].to(dist_util.dev())
|
| 153 |
-
|
| 154 |
-
if 'target_image_256' in data:
|
| 155 |
-
target_256 = data['target_image_256']
|
| 156 |
-
else:
|
| 157 |
-
target_256 = None
|
| 158 |
-
if 'correspondence_mask' in data:
|
| 159 |
-
mask = data['correspondence_mask'] # torch.bool [1, 914, 1380]
|
| 160 |
-
else:
|
| 161 |
-
mask = torch.ones((1, 512, 512), dtype=torch.bool).to(dist_util.dev()) # None
|
| 162 |
-
|
| 163 |
-
return data, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
|
| 164 |
-
|
| 165 |
def prepare_data_single(input_image, input_image_ori):
|
| 166 |
source_vis = input_image_ori
|
| 167 |
target_vis = None
|
| 168 |
_, _, H_ori, W_ori = source_vis.shape
|
| 169 |
-
source = input_image.to(
|
| 170 |
source_0 = None
|
| 171 |
target = None
|
| 172 |
batch_ori = None
|
| 173 |
batch_ori_inter = None
|
| 174 |
target = None
|
| 175 |
target_256 = None
|
| 176 |
-
source_256 = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area').to(
|
| 177 |
target_256 = None
|
| 178 |
-
mask = torch.ones((1, 512, 512), dtype=torch.bool).to(
|
| 179 |
|
| 180 |
return input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
|
| 181 |
|
|
@@ -201,16 +137,16 @@ def run_single_docunet(input_image_ori):
|
|
| 201 |
|
| 202 |
os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
|
| 203 |
batch_preprocessing = None
|
| 204 |
-
pyramid = VGGPyramid(train=False).to(
|
| 205 |
SIZE = None
|
| 206 |
|
| 207 |
|
| 208 |
radius = 4
|
| 209 |
raw_corr = None
|
| 210 |
-
source_288 = F.interpolate(input_image, size=(288), mode='bilinear', align_corners=True).to(
|
| 211 |
|
| 212 |
if settings.env.time_variant == True:
|
| 213 |
-
init_feat = torch.zeros((input_image.shape[0], 256, 64, 64), dtype=torch.float32).to(
|
| 214 |
else:
|
| 215 |
init_feat = None
|
| 216 |
|
|
@@ -220,7 +156,7 @@ def run_single_docunet(input_image_ori):
|
|
| 220 |
if settings.env.use_init_flow:
|
| 221 |
init_flow = F.interpolate(ref_flow, size=(64), mode='bilinear', align_corners=True) # [24, 2, 64, 64]
|
| 222 |
else:
|
| 223 |
-
init_flow = torch.zeros((input_image.shape[0], 2, 64, 64), dtype=torch.float32).to(
|
| 224 |
|
| 225 |
(
|
| 226 |
data,
|
|
@@ -375,13 +311,13 @@ settings.severity = 0
|
|
| 375 |
settings.corruption_number = 0
|
| 376 |
|
| 377 |
|
| 378 |
-
dist_util.setup_dist()
|
| 379 |
logger.configure(dir=f"SAMPLING_{settings.env.eval_dataset}_{settings.name}")
|
| 380 |
logger.log(f"Corruption Disabled. Evaluating on Original {settings.env.eval_dataset}")
|
| 381 |
logger.log("Loading model and diffusion...")
|
| 382 |
|
| 383 |
model, diffusion = create_model_and_diffusion(
|
| 384 |
-
device=
|
| 385 |
train_mode=settings.env.train_mode, # stage 1
|
| 386 |
tv=settings.env.time_variant,
|
| 387 |
**args_to_dict(settings, model_and_diffusion_defaults().keys()),
|
|
@@ -393,29 +329,30 @@ pretrained_dewarp_model = GeoTr_Seg_Inf()
|
|
| 393 |
settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth", token=token)
|
| 394 |
reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
|
| 395 |
# reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
|
| 396 |
-
pretrained_dewarp_model.to(
|
| 397 |
pretrained_dewarp_model.eval()
|
| 398 |
|
| 399 |
if settings.env.use_line_mask:
|
| 400 |
pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
|
| 401 |
pretrained_seg_model = Seg()
|
| 402 |
settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth", token=token)
|
| 403 |
-
line_model_ckpt =
|
| 404 |
pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
|
| 405 |
-
pretrained_line_seg_model.to(
|
| 406 |
pretrained_line_seg_model.eval()
|
| 407 |
|
| 408 |
settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth", token=token)
|
| 409 |
-
seg_model_ckpt =
|
| 410 |
pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
|
| 411 |
-
pretrained_seg_model.to(
|
| 412 |
pretrained_seg_model.eval()
|
| 413 |
|
| 414 |
settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt", token=token)
|
| 415 |
-
model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
|
|
|
|
| 416 |
logger.log(f"Model loaded with {settings.env.model_path}")
|
| 417 |
|
| 418 |
-
model.to(
|
| 419 |
model.eval()
|
| 420 |
|
| 421 |
|
|
|
|
| 43 |
def coords_grid_tensor(perturbed_img_shape, device='cuda'):
    """Build a dense pixel-coordinate grid as a float tensor.

    Args:
        perturbed_img_shape: Sequence whose first two entries are the grid
            height (number of rows) and width (number of columns).
        device: Device the resulting tensor is moved to. Defaults to
            ``'cuda'``, the value this function previously hard-coded.

    Returns:
        Float32 tensor of shape ``[1, 2, H, W]``. Channel 0 holds the column
        index of every cell (0 .. W-1) and channel 1 holds the row index
        (0 .. H-1).
    """
    height, width = perturbed_img_shape[0], perturbed_img_shape[1]
    # A complex "step" makes np.mgrid act like linspace: `height` (resp.
    # `width`) evenly spaced samples with the end point included.
    rows, cols = np.mgrid[0:height - 1:complex(height), 0:width - 1:complex(width)]
    # Column (x) coordinate first, then row (y) coordinate; row-major layout.
    coords = np.stack((cols, rows), axis=2)  # (H, W, 2)
    coords = th.from_numpy(coords).float().permute(2, 0, 1).to(device)  # (2, H, W)
    return coords.unsqueeze(0)  # (1, 2, H, W)
|
| 48 |
|
| 49 |
def run_sample_lr_dewarping(
|
|
|
|
| 84 |
sample = th.clamp(sample, min=-1, max=1)
|
| 85 |
return sample
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
|
|
|
| 88 |
|
| 89 |
def visualize_dewarping_single(settings, sample, source_vis):
|
| 90 |
# os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred', exist_ok=True) # pred dewarped
|
|
|
|
| 98 |
|
| 99 |
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
def prepare_data_single(input_image, input_image_ori, device='cuda'):
    """Assemble the tuple of inputs used to dewarp one image without ground truth.

    Args:
        input_image: 4-D image tensor fed to the model (it is unpacked to
            nothing here, but ``interpolate`` with a 2-D ``size`` requires a
            batched, channelled input).
        input_image_ori: 4-D tensor of the original image; its last two
            dimensions are reported as ``H_ori`` and ``W_ori`` and it is
            returned unchanged as ``source_vis``.
        device: Device that ``source``, ``source_256`` and ``mask`` are placed
            on. Defaults to ``'cuda'``, the value previously hard-coded.

    Returns:
        A 13-tuple ``(input_image, H_ori, W_ori, source, target, batch_ori,
        batch_ori_inter, source_256, target_256, source_vis, target_vis, mask,
        source_0)``. Every target/flow related entry (``target``,
        ``batch_ori``, ``batch_ori_inter``, ``target_256``, ``target_vis``,
        ``source_0``) is ``None`` because no reference data is supplied here.
    """
    source_vis = input_image_ori
    _, _, H_ori, W_ori = source_vis.shape

    source = input_image.to(device)
    # interpolate() keeps its output on the input's device, so no further
    # transfer is needed after resizing to 256x256.
    source_256 = torch.nn.functional.interpolate(
        input=source.float(), size=(256, 256), mode='area'
    )
    # All-true mask of fixed size 512x512, allocated directly on the device.
    mask = torch.ones((1, 512, 512), dtype=torch.bool, device=device)

    # No ground truth is available for a single uploaded image.
    target = None
    target_256 = None
    target_vis = None
    batch_ori = None
    batch_ori_inter = None
    source_0 = None

    return (
        input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter,
        source_256, target_256, source_vis, target_vis, mask, source_0,
    )
|
| 117 |
|
|
|
|
| 137 |
|
| 138 |
os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
|
| 139 |
batch_preprocessing = None
|
| 140 |
+
pyramid = VGGPyramid(train=False).to('cuda')
|
| 141 |
SIZE = None
|
| 142 |
|
| 143 |
|
| 144 |
radius = 4
|
| 145 |
raw_corr = None
|
| 146 |
+
source_288 = F.interpolate(input_image, size=(288), mode='bilinear', align_corners=True).to('cuda')
|
| 147 |
|
| 148 |
if settings.env.time_variant == True:
|
| 149 |
+
init_feat = torch.zeros((input_image.shape[0], 256, 64, 64), dtype=torch.float32).to('cuda')
|
| 150 |
else:
|
| 151 |
init_feat = None
|
| 152 |
|
|
|
|
| 156 |
if settings.env.use_init_flow:
|
| 157 |
init_flow = F.interpolate(ref_flow, size=(64), mode='bilinear', align_corners=True) # [24, 2, 64, 64]
|
| 158 |
else:
|
| 159 |
+
init_flow = torch.zeros((input_image.shape[0], 2, 64, 64), dtype=torch.float32).to('cuda')
|
| 160 |
|
| 161 |
(
|
| 162 |
data,
|
|
|
|
| 311 |
settings.corruption_number = 0
|
| 312 |
|
| 313 |
|
| 314 |
+
# dist_util.setup_dist()
|
| 315 |
logger.configure(dir=f"SAMPLING_{settings.env.eval_dataset}_{settings.name}")
|
| 316 |
logger.log(f"Corruption Disabled. Evaluating on Original {settings.env.eval_dataset}")
|
| 317 |
logger.log("Loading model and diffusion...")
|
| 318 |
|
| 319 |
model, diffusion = create_model_and_diffusion(
|
| 320 |
+
device='cuda',
|
| 321 |
train_mode=settings.env.train_mode, # stage 1
|
| 322 |
tv=settings.env.time_variant,
|
| 323 |
**args_to_dict(settings, model_and_diffusion_defaults().keys()),
|
|
|
|
| 329 |
# --- Segmentation weights for the dewarping front-end -----------------------
settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth", token=token)
reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
# reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
pretrained_dewarp_model.to('cuda')
pretrained_dewarp_model.eval()

# --- Optional line / document segmentation models ----------------------------
if settings.env.use_line_mask:
    pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
    pretrained_seg_model = Seg()

    settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth", token=token)
    # nn.Module.load_state_dict() expects an already-loaded state dict, not a
    # file path, and has no `map_location` argument; read the checkpoint with
    # torch.load() first. The weights live under the 'model' key.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted repository.
    line_model_ckpt = torch.load(settings.env.line_seg_model_path, map_location='cpu')['model']
    pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
    pretrained_line_seg_model.to('cuda')
    pretrained_line_seg_model.eval()

    settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth", token=token)
    seg_model_ckpt = torch.load(settings.env.new_seg_model_path, map_location='cpu')['model']
    pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
    pretrained_seg_model.to('cuda')
    pretrained_seg_model.eval()

# --- Main diffusion model -----------------------------------------------------
settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt", token=token)
# Previously loaded through dist_util.load_state_dict(); with the distributed
# setup disabled, the checkpoint is read directly on the CPU and the model is
# moved to the GPU afterwards. strict=False tolerates missing/unexpected keys.
model.cpu().load_state_dict(torch.load(settings.env.model_path, map_location="cpu"), strict=False)
logger.log(f"Model loaded with {settings.env.model_path}")

model.to('cuda')
model.eval()
|
| 357 |
|
| 358 |
|