hanquansanren committed
Commit b54f31e · 1 Parent(s): 9847dc4
Files changed (1)
  1. app.py +18 -81
app.py CHANGED
@@ -43,7 +43,7 @@ reg_model_bilin = register_model2((512,512), 'bilinear')
  def coords_grid_tensor(perturbed_img_shape):
      im_x, im_y = np.mgrid[0:perturbed_img_shape[0]-1:complex(perturbed_img_shape[0]), 0:perturbed_img_shape[1]-1:complex(perturbed_img_shape[1])]
      coords = np.stack((im_y,im_x), axis=2) # x first, then y; row-major order
-     coords = th.from_numpy(coords).float().permute(2,0,1).to(dist_util.dev()) # (2, 512, 512)
+     coords = th.from_numpy(coords).float().permute(2,0,1).to('cuda') # (2, 512, 512)
      return coords.unsqueeze(0) # [1, 2, 512, 512]

  def run_sample_lr_dewarping(
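For reference, outside the commit itself: hard-coding 'cuda' as above will fail on CPU-only hosts. A common portable pattern (an alternative sketch, not what this commit does) resolves the device once with a fallback:

    import torch as th

    # Resolve the device once instead of hard-coding 'cuda'; falls back to CPU.
    device = 'cuda' if th.cuda.is_available() else 'cpu'
    coords = th.zeros(1, 2, 512, 512).to(device)  # stand-in for the coords grid above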
@@ -84,14 +84,7 @@ def run_sample_lr_dewarping(
      sample = th.clamp(sample, min=-1, max=1)
      return sample

- def visualize_dewarping(settings, sample, data, i, source_vis, data_path, ref_flow=None):
-     os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred', exist_ok=True) # pred dewarped
-     # warped_src = warp(source_vis.to(sample.device).float(), sample) # [1, 3, 1629, 981]
-     warped_src = reg_model_bilin([source_vis.to(sample.device).float(), sample])
-     warped_src = warped_src[0].permute(1, 2, 0).detach().cpu().numpy() #*255. # (1873, 1353, 3)
-     warped_src = Image.fromarray((warped_src).astype(np.uint8))
-
-     return warped_src

  def visualize_dewarping_single(settings, sample, source_vis):
      # os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred', exist_ok=True) # pred dewarped
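The removed helper's tensor-to-image conversion is a generic CHW-to-HWC pattern; a self-contained sketch of the same steps:

    import numpy as np
    import torch as th
    from PIL import Image

    # Stand-in for a warped 3xHxW float tensor with values in [0, 255].
    warped = th.rand(3, 512, 512) * 255
    arr = warped.permute(1, 2, 0).detach().cpu().numpy()  # CHW -> HWC
    img = Image.fromarray(arr.astype(np.uint8))           # cast to 8-bit for PIL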
@@ -105,77 +98,20 @@ def visualize_dewarping_single(settings, sample, source_vis):



-
-
- def prepare_data(settings, batch_preprocessing, SIZE, data):
-     if 'source_image_ori' in data:
-         source_vis = data['source_image_ori'] # B, C, 512, 512 torch.uint8 cpu
-     else:
-         source_vis = data['source_image']
-     if 'target_image' in data:
-         target_vis = data['target_image']
-     else:
-         target_vis = None
-
-     _, _, H_ori, W_ori = source_vis.shape
-
-     source = data['source_image'].to(dist_util.dev()) # [1, 3, 914, 1380] torch.float32
-     if 'source_image_0' in data:
-         source_0 = data['source_image_0'].to(dist_util.dev())
-     else:
-         source_0 = None
-     if 'target_image' in data:
-         target = data['target_image'] # [1, 3, 914, 1380] torch.float32
-     else:
-         target = None
-     if 'flow_map' in data:
-         batch_ori = data['flow_map'] # [1, 2, 914, 1380] torch.float32
-     else:
-         batch_ori = None
-     if 'flow_map_inter' in data:
-         batch_ori_inter = data['flow_map_inter'] # [1, 2, 914, 1380] torch.float32
-     else:
-         batch_ori_inter = None
-     if target is not None:
-         target = F.interpolate(target, size=512, mode='bilinear', align_corners=False) # [1, 3, 512, 512]
-         target_256 = data['target_image_256'].to(dist_util.dev()) # [1, 3, 256, 256]
-     else:
-         target = None
-         target_256 = None
-
-     if settings.env.eval_dataset == 'hp-240': # false
-         source_256 = source
-         target_256 = target
-
-     else: # true
-         data['source_image_256'] = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area')
-         source_256 = data['source_image_256'].to(dist_util.dev())
-
-         if 'target_image_256' in data:
-             target_256 = data['target_image_256']
-         else:
-             target_256 = None
-     if 'correspondence_mask' in data:
-         mask = data['correspondence_mask'] # torch.bool [1, 914, 1380]
-     else:
-         mask = torch.ones((1, 512, 512), dtype=torch.bool).to(dist_util.dev()) # None
-
-     return data, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
-
  def prepare_data_single(input_image, input_image_ori):
      source_vis = input_image_ori
      target_vis = None
      _, _, H_ori, W_ori = source_vis.shape
-     source = input_image.to(dist_util.dev()) # [1, 3, 914, 1380] torch.float32
+     source = input_image.to('cuda') # [1, 3, 914, 1380] torch.float32
      source_0 = None
      target = None
      batch_ori = None
      batch_ori_inter = None
      target = None
      target_256 = None
-     source_256 = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area').to(dist_util.dev())
+     source_256 = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area').to('cuda')
      target_256 = None
-     mask = torch.ones((1, 512, 512), dtype=torch.bool).to(dist_util.dev()) # None
+     mask = torch.ones((1, 512, 512), dtype=torch.bool).to('cuda') # None

      return input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
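prepare_data_single downsamples the source with area interpolation; a minimal self-contained sketch of that call (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    src = torch.rand(1, 3, 512, 512)  # stand-in for the source image batch
    src_256 = F.interpolate(input=src.float(), size=(256, 256), mode='area')
    print(src_256.shape)  # torch.Size([1, 3, 256, 256])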
 
@@ -201,16 +137,16 @@ def run_single_docunet(input_image_ori):

      os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
      batch_preprocessing = None
-     pyramid = VGGPyramid(train=False).to(dist_util.dev())
+     pyramid = VGGPyramid(train=False).to('cuda')
      SIZE = None


      radius = 4
      raw_corr = None
-     source_288 = F.interpolate(input_image, size=(288), mode='bilinear', align_corners=True).to(dist_util.dev())
+     source_288 = F.interpolate(input_image, size=(288), mode='bilinear', align_corners=True).to('cuda')

      if settings.env.time_variant == True:
-         init_feat = torch.zeros((input_image.shape[0], 256, 64, 64), dtype=torch.float32).to(dist_util.dev())
+         init_feat = torch.zeros((input_image.shape[0], 256, 64, 64), dtype=torch.float32).to('cuda')
      else:
          init_feat = None
@@ -220,7 +156,7 @@
      if settings.env.use_init_flow:
          init_flow = F.interpolate(ref_flow, size=(64), mode='bilinear', align_corners=True) # [24, 2, 64, 64]
      else:
-         init_flow = torch.zeros((input_image.shape[0], 2, 64, 64), dtype=torch.float32).to(dist_util.dev())
+         init_flow = torch.zeros((input_image.shape[0], 2, 64, 64), dtype=torch.float32).to('cuda')

      (
          data,
@@ -375,13 +311,13 @@ settings.severity = 0
  settings.corruption_number = 0


- dist_util.setup_dist()
+ # dist_util.setup_dist()
  logger.configure(dir=f"SAMPLING_{settings.env.eval_dataset}_{settings.name}")
  logger.log(f"Corruption Disabled. Evaluating on Original {settings.env.eval_dataset}")
  logger.log("Loading model and diffusion...")

  model, diffusion = create_model_and_diffusion(
-     device=dist_util.dev(),
+     device='cuda',
      train_mode=settings.env.train_mode, # stage 1
      tv=settings.env.time_variant,
      **args_to_dict(settings, model_and_diffusion_defaults().keys()),
@@ -393,29 +329,30 @@ pretrained_dewarp_model = GeoTr_Seg_Inf()
  settings.env.seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg.pth", token=token)
  reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
  # reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
- pretrained_dewarp_model.to(dist_util.dev())
+ pretrained_dewarp_model.to('cuda')
  pretrained_dewarp_model.eval()

  if settings.env.use_line_mask:
      pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
      pretrained_seg_model = Seg()
      settings.env.line_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="line_model2.pth", token=token)
-     line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
+     line_model_ckpt = torch.load(settings.env.line_seg_model_path, map_location='cpu')['model']
      pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
-     pretrained_line_seg_model.to(dist_util.dev())
+     pretrained_line_seg_model.to('cuda')
      pretrained_line_seg_model.eval()

      settings.env.new_seg_model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="seg_model.pth", token=token)
-     seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
+     seg_model_ckpt = torch.load(settings.env.new_seg_model_path, map_location='cpu')['model']
      pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
-     pretrained_seg_model.to(dist_util.dev())
+     pretrained_seg_model.to('cuda')
      pretrained_seg_model.eval()

  settings.env.model_path = hf_hub_download(repo_id="hanquansanren/DvD", filename="model1852000.pt", token=token)
- model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
+ # model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
+ model.cpu().load_state_dict(torch.load(settings.env.model_path, map_location="cpu"), strict=False)
  logger.log(f"Model loaded with {settings.env.model_path}")

- model.to(dist_util.dev())
+ model.to('cuda')
  model.eval()
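A note on the checkpoint loads above: nn.Module.load_state_dict expects an in-memory state dict and has no map_location argument, so checkpoint paths must first go through torch.load. A minimal runnable sketch, assuming a checkpoint that nests its weights under a 'model' key as these do ('checkpoint.pth' is a hypothetical path):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in for UNet / Seg / the diffusion model
    torch.save({'model': model.state_dict()}, 'checkpoint.pth')  # toy checkpoint with a nested 'model' key

    ckpt = torch.load('checkpoint.pth', map_location='cpu')  # 1) deserialize to CPU
    model.load_state_dict(ckpt['model'], strict=True)        # 2) copy weights into the module
    model.to('cuda' if torch.cuda.is_available() else 'cpu').eval()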